From 7992258dd21cd61ff0201e985f3b2dbc3d553478 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Mon, 11 Dec 2023 14:33:06 +0100 Subject: [PATCH] add before/after init callback to trainer --- src/refiners/training_utils/callback.py | 6 ++++++ src/refiners/training_utils/trainer.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 7f9bdef..f1471ec 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -43,6 +43,12 @@ T = TypeVar("T") class Callback(Generic[T]): + def on_init_begin(self, trainer: T) -> None: + ... + + def on_init_end(self, trainer: T) -> None: + ... + def on_train_begin(self, trainer: T) -> None: ... diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 4ccb014..87276d8 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -278,10 +278,12 @@ class Trainer(Generic[ConfigType, Batch]): ) self.callbacks = callbacks or [] self.callbacks += self.default_callbacks() + self._call_callbacks(event_name="on_init_begin") self.load_wandb() self.load_models() self.prepare_models() self.prepare_checkpointing() + self._call_callbacks(event_name="on_init_end") def default_callbacks(self) -> list[Callback[Any]]: return [