mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
add before/after init callback to trainer
This commit is contained in:
parent
42a0fc4aa0
commit
7992258dd2
|
@ -43,6 +43,12 @@ T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
|
def on_init_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_init_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_train_begin(self, trainer: T) -> None:
|
def on_train_begin(self, trainer: T) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
|
@ -278,10 +278,12 @@ class Trainer(Generic[ConfigType, Batch]):
|
||||||
)
|
)
|
||||||
self.callbacks = callbacks or []
|
self.callbacks = callbacks or []
|
||||||
self.callbacks += self.default_callbacks()
|
self.callbacks += self.default_callbacks()
|
||||||
|
self._call_callbacks(event_name="on_init_begin")
|
||||||
self.load_wandb()
|
self.load_wandb()
|
||||||
self.load_models()
|
self.load_models()
|
||||||
self.prepare_models()
|
self.prepare_models()
|
||||||
self.prepare_checkpointing()
|
self.prepare_checkpointing()
|
||||||
|
self._call_callbacks(event_name="on_init_end")
|
||||||
|
|
||||||
def default_callbacks(self) -> list[Callback[Any]]:
|
def default_callbacks(self) -> list[Callback[Any]]:
|
||||||
return [
|
return [
|
||||||
|
|
Loading…
Reference in a new issue