adding more log messages in training_utils

This commit is contained in:
Laurent 2024-03-07 14:22:04 +00:00 committed by Laureηt
parent be2368cf20
commit 5e7986ef08
3 changed files with 11 additions and 1 deletions

View file

@ -170,12 +170,18 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.log(f"Epoch {trainer.clock.epoch} started.") self.log(f"Epoch {trainer.clock.epoch} started.")
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Epoch {trainer.clock.epoch} ended.")
trainer.clock.epoch += 1 trainer.clock.epoch += 1
trainer.clock.num_batches_processed = 0 trainer.clock.num_batches_processed = 0
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
if self.num_minibatches_processed == 0:
self.log(f"Iteration {trainer.clock.iteration} started.")
self.log(f"Step {trainer.clock.step} started.") self.log(f"Step {trainer.clock.step} started.")
def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} ended.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.step += 1 trainer.clock.step += 1
trainer.clock.num_batches_processed += 1 trainer.clock.num_batches_processed += 1

View file

@ -67,6 +67,7 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
actual_seed = seed(*args) if callable(seed) else seed actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed) seed_everything(seed=actual_seed)
result = func(*args, **kwargs) result = func(*args, **kwargs)
logger.debug(f"Restoring previous seed state")
random.setstate(random_state) random.setstate(random_state)
np.random.set_state(numpy_state) np.random.set_state(numpy_state)
torch.set_rng_state(torch_state) torch.set_rng_state(torch_state)

View file

@ -103,8 +103,11 @@ def register_model():
model = func(self, config) model = func(self, config)
model = model.to(self.device, dtype=self.dtype) model = model.to(self.device, dtype=self.dtype)
if config.requires_grad is not None: if config.requires_grad is not None:
logger.info(f"Setting requires_grad to {config.requires_grad} for model: {name}")
model.requires_grad_(requires_grad=config.requires_grad) model.requires_grad_(requires_grad=config.requires_grad)
learnable_parameters = [param for param in model.parameters() if param.requires_grad] learnable_parameters = [param for param in model.parameters() if param.requires_grad]
numel = sum(param.numel() for param in learnable_parameters)
logger.info(f"Number of learnable parameters in {name}: {human_readable_number(numel)}")
self.models[name] = ModelItem( self.models[name] = ModelItem(
name=name, config=config, model=model, learnable_parameters=learnable_parameters name=name, config=config, model=model, learnable_parameters=learnable_parameters
) )