add before/after init callback to trainer

This commit is contained in:
limiteinductive 2023-12-11 14:33:06 +01:00 committed by Benjamin Trom
parent 42a0fc4aa0
commit 7992258dd2
2 changed files with 8 additions and 0 deletions

View file

@ -43,6 +43,12 @@ T = TypeVar("T")
class Callback(Generic[T]): class Callback(Generic[T]):
def on_init_begin(self, trainer: T) -> None:
...
def on_init_end(self, trainer: T) -> None:
...
def on_train_begin(self, trainer: T) -> None: def on_train_begin(self, trainer: T) -> None:
... ...

View file

@ -278,10 +278,12 @@ class Trainer(Generic[ConfigType, Batch]):
) )
self.callbacks = callbacks or [] self.callbacks = callbacks or []
self.callbacks += self.default_callbacks() self.callbacks += self.default_callbacks()
self._call_callbacks(event_name="on_init_begin")
self.load_wandb() self.load_wandb()
self.load_models() self.load_models()
self.prepare_models() self.prepare_models()
self.prepare_checkpointing() self.prepare_checkpointing()
self._call_callbacks(event_name="on_init_end")
def default_callbacks(self) -> list[Callback[Any]]: def default_callbacks(self) -> list[Callback[Any]]:
return [ return [