Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
adding more log messages in training_utils
parent be2368cf20
commit 5e7986ef08
@@ -170,12 +170,18 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         self.log(f"Epoch {trainer.clock.epoch} started.")
 
     def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        self.log(f"Epoch {trainer.clock.epoch} ended.")
         trainer.clock.epoch += 1
         trainer.clock.num_batches_processed = 0
 
     def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        if self.num_minibatches_processed == 0:
+            self.log(f"Iteration {trainer.clock.iteration} started.")
         self.log(f"Step {trainer.clock.step} started.")
 
+    def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
+        self.log(f"Step {trainer.clock.step} ended.")
+
     def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         trainer.clock.step += 1
         trainer.clock.num_batches_processed += 1
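For orientation, the clock hooks above fire in a fixed order during training, and this hunk adds a log line at each transition. The toy loop below is a minimal sketch of that order and of the messages it now produces; it is not the refiners Trainer, and ToyClock / run_one_epoch are hypothetical names invented for illustration (only the hook semantics and the clock fields come from the hunk, and the gradient-accumulation minibatch counter is simplified to num_batches_processed here).

# Illustration only: a toy loop mimicking the hook order and the log lines added
# by this commit. ToyClock and run_one_epoch are hypothetical names, not refiners code.
from dataclasses import dataclass


@dataclass
class ToyClock:
    epoch: int = 0
    step: int = 0
    iteration: int = 0
    num_batches_processed: int = 0


def run_one_epoch(clock: ToyClock, num_batches: int = 2) -> None:
    print(f"Epoch {clock.epoch} started.")                  # on_epoch_begin
    for _ in range(num_batches):
        if clock.num_batches_processed == 0:                # on_batch_begin
            print(f"Iteration {clock.iteration} started.")  # new in this commit
        print(f"Step {clock.step} started.")
        # ... forward / backward pass would run here ...
        print(f"Step {clock.step} ended.")                  # new on_batch_end message
        clock.step += 1                                      # on_backward_end bookkeeping
        clock.num_batches_processed += 1
    print(f"Epoch {clock.epoch} ended.")                     # new on_epoch_end message
    clock.epoch += 1
    clock.num_batches_processed = 0


run_one_epoch(ToyClock())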
@@ -67,6 +67,7 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
             actual_seed = seed(*args) if callable(seed) else seed
             seed_everything(seed=actual_seed)
             result = func(*args, **kwargs)
+            logger.debug(f"Restoring previous seed state")
             random.setstate(random_state)
             np.random.set_state(numpy_state)
             torch.set_rng_state(torch_state)
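The new debug message marks the restore phase of scoped_seed: the wrapper seeds Python's random, NumPy, and torch before calling the wrapped function, then puts the previous RNG state back. The standalone decorator below is a hedged sketch of that same save / seed / restore pattern; with_scoped_seed and the direct seeding calls are assumptions for illustration (refiners itself calls seed_everything), not the library's implementation.

# Self-contained sketch of the save / seed / restore pattern shown in the hunk above.
# with_scoped_seed is a hypothetical name; refiners' own helper is scoped_seed.
import random
from functools import wraps
from typing import Any, Callable

import numpy as np
import torch


def with_scoped_seed(seed: int | None = None) -> Callable[..., Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            random_state = random.getstate()        # save the current RNG states
            numpy_state = np.random.get_state()
            torch_state = torch.get_rng_state()
            if seed is not None:                    # seed all three generators
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
            result = func(*args, **kwargs)
            random.setstate(random_state)           # restore, as the new debug log notes
            np.random.set_state(numpy_state)
            torch.set_rng_state(torch_state)
            return result

        return wrapper

    return decorator


@with_scoped_seed(seed=42)
def sample_noise() -> torch.Tensor:
    # Deterministic under the decorator; the global RNG state is untouched afterwards.
    return torch.randn(2, 2)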
@@ -103,8 +103,11 @@ def register_model():
             model = func(self, config)
             model = model.to(self.device, dtype=self.dtype)
             if config.requires_grad is not None:
+                logger.info(f"Setting requires_grad to {config.requires_grad} for model: {name}")
                 model.requires_grad_(requires_grad=config.requires_grad)
             learnable_parameters = [param for param in model.parameters() if param.requires_grad]
+            numel = sum(param.numel() for param in learnable_parameters)
+            logger.info(f"Number of learnable parameters in {name}: {human_readable_number(numel)}")
             self.models[name] = ModelItem(
                 name=name, config=config, model=model, learnable_parameters=learnable_parameters
             )
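The additions to register_model compute and log how many parameters remain trainable after requires_grad is applied. The helper below reproduces that count on any torch.nn.Module; count_learnable_parameters and format_count are hypothetical names (refiners formats the number with human_readable_number), shown only to make the reported quantity concrete.

# Illustration only: reproduces the parameter count behind the new log line.
import torch


def count_learnable_parameters(model: torch.nn.Module) -> int:
    # Same computation as the hunk: sum of elements over params with requires_grad=True.
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def format_count(numel: int) -> str:
    # Rough human-readable formatting, e.g. 1_500_000 -> "1.5M".
    for threshold, suffix in ((1_000_000_000, "G"), (1_000_000, "M"), (1_000, "K")):
        if numel >= threshold:
            return f"{numel / threshold:.1f}{suffix}"
    return str(numel)


model = torch.nn.Linear(in_features=128, out_features=64)
print(f"Number of learnable parameters: {format_count(count_learnable_parameters(model))}")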