From 5e7986ef085346ffbb985bb1ae7134b8a5d5092e Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 7 Mar 2024 14:22:04 +0000
Subject: [PATCH] add more log messages in training_utils

---
 src/refiners/training_utils/clock.py   | 6 ++++++
 src/refiners/training_utils/common.py  | 3 ++-
 src/refiners/training_utils/trainer.py | 3 +++
 3 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py
index 8eb3936..49d70a4 100644
--- a/src/refiners/training_utils/clock.py
+++ b/src/refiners/training_utils/clock.py
@@ -170,12 +170,18 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         self.log(f"Epoch {trainer.clock.epoch} started.")
 
     def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        self.log(f"Epoch {trainer.clock.epoch} ended.")
         trainer.clock.epoch += 1
         trainer.clock.num_batches_processed = 0
 
     def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        if self.num_minibatches_processed == 0:
+            self.log(f"Iteration {trainer.clock.iteration} started.")
         self.log(f"Step {trainer.clock.step} started.")
 
+    def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        self.log(f"Step {trainer.clock.step} ended.")
+
     def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         trainer.clock.step += 1
         trainer.clock.num_batches_processed += 1
diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py
index eff12aa..b6fb131 100644
--- a/src/refiners/training_utils/common.py
+++ b/src/refiners/training_utils/common.py
@@ -38,7 +38,7 @@ def human_readable_number(number: int) -> str:
 def seed_everything(seed: int | None = None) -> None:
     if seed is None:
         seed = random.randint(0, 2**32 - 1)
-        logger.info(f"Using random seed: {seed}")
+    logger.info(f"Using random seed: {seed}")
     random.seed(a=seed)
     np.random.seed(seed=seed)
     manual_seed(seed=seed)
@@ -67,6 +67,7 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
             actual_seed = seed(*args) if callable(seed) else seed
             seed_everything(seed=actual_seed)
             result = func(*args, **kwargs)
+            logger.debug("Restoring previous seed state")
             random.setstate(random_state)
             np.random.set_state(numpy_state)
             torch.set_rng_state(torch_state)
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 8653286..3875a9b 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -103,8 +103,11 @@ def register_model():
             model = func(self, config)
             model = model.to(self.device, dtype=self.dtype)
             if config.requires_grad is not None:
+                logger.info(f"Setting requires_grad to {config.requires_grad} for model: {name}")
                 model.requires_grad_(requires_grad=config.requires_grad)
             learnable_parameters = [param for param in model.parameters() if param.requires_grad]
+            numel = sum(param.numel() for param in learnable_parameters)
+            logger.info(f"Number of learnable parameters in {name}: {human_readable_number(numel)}")
             self.models[name] = ModelItem(
                 name=name, config=config, model=model, learnable_parameters=learnable_parameters
             )
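
Note (not part of the patch, ignored by git am): a minimal sketch of how the
`scoped_seed` helper touched above is expected to behave after this change,
assuming `refiners` is importable; the `sample` function is hypothetical.

    import random

    from refiners.training_utils.common import scoped_seed

    @scoped_seed(seed=42)
    def sample() -> float:
        # Within the decorated call, the RNGs are seeded with 42,
        # so this value is reproducible across calls.
        return random.random()

    state_before = random.getstate()
    value = sample()  # also logs "Restoring previous seed state" at debug level
    assert sample() == value  # same seed, same draw
    assert random.getstate() == state_before  # outer RNG state is restored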