mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add before/after init callback to trainer
This commit is contained in:
parent
42a0fc4aa0
commit
7992258dd2
|
@ -43,6 +43,12 @@ T = TypeVar("T")
|
|||
|
||||
|
||||
class Callback(Generic[T]):
|
||||
def on_init_begin(self, trainer: T) -> None:
|
||||
...
|
||||
|
||||
def on_init_end(self, trainer: T) -> None:
|
||||
...
|
||||
|
||||
def on_train_begin(self, trainer: T) -> None:
|
||||
...
|
||||
|
||||
|
|
|
@ -278,10 +278,12 @@ class Trainer(Generic[ConfigType, Batch]):
|
|||
)
|
||||
self.callbacks = callbacks or []
|
||||
self.callbacks += self.default_callbacks()
|
||||
self._call_callbacks(event_name="on_init_begin")
|
||||
self.load_wandb()
|
||||
self.load_models()
|
||||
self.prepare_models()
|
||||
self.prepare_checkpointing()
|
||||
self._call_callbacks(event_name="on_init_end")
|
||||
|
||||
def default_callbacks(self) -> list[Callback[Any]]:
|
||||
return [
|
||||
|
|
Loading…
Reference in a new issue