mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
adding more log messages in training_utils
This commit is contained in:
parent
be2368cf20
commit
5e7986ef08
|
@ -170,12 +170,18 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
self.log(f"Epoch {trainer.clock.epoch} started.")
|
||||
|
||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
    """Log the end of the current epoch, then advance the trainer's clock.

    Increments the epoch counter and resets the per-epoch batch counter
    so the next epoch starts counting batches from zero.
    """
    self.log(f"Epoch {trainer.clock.epoch} ended.")
    clock = trainer.clock
    clock.epoch = clock.epoch + 1
    clock.num_batches_processed = 0
|
||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
    """Announce the start of a step; on the first minibatch of an
    iteration, announce the iteration as well."""
    starting_new_iteration = self.num_minibatches_processed == 0
    if starting_new_iteration:
        self.log(f"Iteration {trainer.clock.iteration} started.")
    self.log(f"Step {trainer.clock.step} started.")
|
||||
def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
    """Record that the current step has finished."""
    message = f"Step {trainer.clock.step} ended."
    self.log(message)
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
    """Advance the clock after the backward pass: one more step taken
    and one more batch processed."""
    clock = trainer.clock
    clock.step = clock.step + 1
    clock.num_batches_processed = clock.num_batches_processed + 1
|
|
|
@ -38,7 +38,7 @@ def human_readable_number(number: int) -> str:
|
|||
def seed_everything(seed: int | None = None) -> None:
    """Seed the `random`, NumPy and torch RNGs with `seed`.

    When `seed` is None, a random seed in the 32-bit range is drawn and
    used instead, so every run has a concrete, reproducible seed.

    Args:
        seed: the seed to use, or None to draw one at random.
    """
    if seed is None:
        # 32-bit upper bound: np.random.seed rejects seeds >= 2**32.
        seed = random.randint(0, 2**32 - 1)
    # Log the seed exactly once so runs can be reproduced from the logs.
    # (The original had this log statement duplicated.)
    logger.info(f"Using random seed: {seed}")
    random.seed(a=seed)
    np.random.seed(seed=seed)
    manual_seed(seed=seed)
|
@ -67,6 +67,7 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
|
|||
actual_seed = seed(*args) if callable(seed) else seed
|
||||
seed_everything(seed=actual_seed)
|
||||
result = func(*args, **kwargs)
|
||||
logger.debug(f"Restoring previous seed state")
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(numpy_state)
|
||||
torch.set_rng_state(torch_state)
|
||||
|
|
|
@ -103,8 +103,11 @@ def register_model():
|
|||
model = func(self, config)
|
||||
model = model.to(self.device, dtype=self.dtype)
|
||||
if config.requires_grad is not None:
|
||||
logger.info(f"Setting requires_grad to {config.requires_grad} for model: {name}")
|
||||
model.requires_grad_(requires_grad=config.requires_grad)
|
||||
learnable_parameters = [param for param in model.parameters() if param.requires_grad]
|
||||
numel = sum(param.numel() for param in learnable_parameters)
|
||||
logger.info(f"Number of learnable parameters in {name}: {human_readable_number(numel)}")
|
||||
self.models[name] = ModelItem(
|
||||
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue