diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 644fd83..bcca3fd 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -1,12 +1,80 @@ -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, field_validator + +from refiners.training_utils.common import ( + Epoch, + Iteration, + Step, + TimeValue, + TimeValueInput, + parse_number_unit_field, + scoped_seed, +) if TYPE_CHECKING: - from refiners.training_utils.config import BaseConfig from refiners.training_utils.trainer import Trainer -T = TypeVar("T", bound="Trainer[BaseConfig, Any]") +T = TypeVar("T", bound="Trainer[Any, Any]") + + +class StepEventConfig(BaseModel): + """ + Base configuration for an event that is triggered at every step. + + - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and + restored after the event. + - `interval`: Interval at which the event should be triggered. The interval is defined by either a `Step` object, + an `Iteration` object, or an `Epoch` object. + """ + + model_config = ConfigDict(extra="forbid") + seed: int | None = None + interval: Step | Iteration | Epoch = Step(1) + + @field_validator("interval", mode="before") + def parse_field(cls, value: TimeValueInput) -> TimeValue: + return parse_number_unit_field(value) + + +class IterationEventConfig(BaseModel): + """ + Base configuration for an event that is triggered only once per iteration. + + - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and + restored after the event. + - `interval`: Interval at which the event should be triggered. The interval is defined by an `Iteration` object or + a `Epoch` object. + """ + + model_config = ConfigDict(extra="forbid") + seed: int | None = None + interval: Iteration | Epoch = Iteration(1) + + @field_validator("interval", mode="before") + def parse_field(cls, value: TimeValueInput) -> TimeValue: + return parse_number_unit_field(value) + + +class EpochEventConfig(BaseModel): + """ + Base configuration for an event that is triggered only once per epoch. + + - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and + restored after the event. + - `interval`: Interval at which the event should be triggered. The interval is defined by a `Epoch` object. + """ + + model_config = ConfigDict(extra="forbid") + seed: int | None = None + interval: Epoch = Epoch(1) + + @field_validator("interval", mode="before") + def parse_field(cls, value: TimeValueInput) -> TimeValue: + return parse_number_unit_field(value) + + +EventConfig = StepEventConfig | IterationEventConfig | EpochEventConfig class CallbackConfig(BaseModel): @@ -17,9 +85,37 @@ class CallbackConfig(BaseModel): """ model_config = ConfigDict(extra="forbid") + on_epoch_begin: EpochEventConfig = EpochEventConfig() + on_epoch_end: EpochEventConfig = EpochEventConfig() + on_batch_begin: StepEventConfig = StepEventConfig() + on_batch_end: StepEventConfig = StepEventConfig() + on_backward_begin: StepEventConfig = StepEventConfig() + on_backward_end: StepEventConfig = StepEventConfig() + on_optimizer_step_begin: IterationEventConfig = IterationEventConfig() + on_optimizer_step_end: IterationEventConfig = IterationEventConfig() + on_compute_loss_begin: StepEventConfig = StepEventConfig() + on_compute_loss_end: StepEventConfig = StepEventConfig() + on_evaluate_begin: IterationEventConfig = IterationEventConfig() + on_evaluate_end: IterationEventConfig = IterationEventConfig() + on_lr_scheduler_step_begin: IterationEventConfig = IterationEventConfig() + on_lr_scheduler_step_end: IterationEventConfig = IterationEventConfig() class Callback(Generic[T]): + def run_event(self, trainer: T, callback_name: str, event_name: str) -> None: + if not hasattr(self, event_name): + return + callback_config = getattr(trainer.config, callback_name) + # For event that run once, there is no configuration to check, e.g. on_train_begin + if not hasattr(callback_config, event_name): + getattr(self, event_name)(trainer) + return + event_config = cast(EventConfig, getattr(callback_config, event_name)) + if not trainer.clock.is_due(event_config.interval): + return + with scoped_seed(event_config.seed): + getattr(self, event_name)(trainer) + def on_init_begin(self, trainer: T) -> None: ... def on_init_end(self, trainer: T) -> None: ... diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index ba0f3f3..559ee1b 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -89,6 +89,9 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): def num_step_per_evaluation(self) -> int: return self.convert_time_value_to_steps(self.evaluation_interval) + def is_due(self, interval: TimeValue) -> bool: + return self.step % self.convert_time_value_to_steps(interval) == 0 + def reset(self) -> None: self.start_time = None self.end_time = None @@ -109,30 +112,14 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): assert self.start_time is not None, "Timer has not been started yet." return int(time.time() - self.start_time) - @cached_property - def evaluation_interval_steps(self) -> int: - return self.convert_time_value_to_steps(self.evaluation_interval) - - @cached_property - def lr_scheduler_interval_steps(self) -> int: - return self.convert_time_value_to_steps(self.lr_scheduler_interval) - @property def is_optimizer_step(self) -> bool: return self.num_minibatches_processed == self.num_step_per_iteration - @property - def is_lr_scheduler_step(self) -> bool: - return self.step % self.lr_scheduler_interval_steps == 0 - @property def done(self) -> bool: return self.step >= self.num_steps - @property - def is_evaluation_step(self) -> bool: - return self.step % self.evaluation_interval_steps == 0 - def log(self, message: str, /) -> None: if self.verbose: logger.info(message) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 0eb2335..392779a 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -361,11 +361,11 @@ class Trainer(Generic[ConfigType, Batch], ABC): self.optimizer.step() self.optimizer.zero_grad() self._call_callbacks(event_name="on_optimizer_step_end") - if self.clock.is_lr_scheduler_step: + if self.clock.is_due(self.config.lr_scheduler.update_interval): self._call_callbacks(event_name="on_lr_scheduler_step_begin") self.lr_scheduler.step() self._call_callbacks(event_name="on_lr_scheduler_step_end") - if self.clock.is_evaluation_step: + if self.clock.is_due(self.config.training.evaluation_interval): self.evaluate() def step(self, batch: Batch) -> None: @@ -424,8 +424,8 @@ class Trainer(Generic[ConfigType, Batch], ABC): item.model.eval() def _call_callbacks(self, event_name: str) -> None: - for callback in self.callbacks.values(): - getattr(callback, event_name)(self) + for name, callback in self.callbacks.items(): + callback.run_event(trainer=self, callback_name=name, event_name=event_name) def _load_callbacks(self) -> None: for name, config in self.config: diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 0b48702..f48ae9c 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -1,3 +1,12 @@ +[mock_callback.on_optimizer_step_begin] +interval = "2:iteration" +seed = 42 + + +[mock_callback.on_batch_end] +interval = "3:step" + + [mock_model] requires_grad = true use_activation = true diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 3c7f45c..0349c4a 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -1,3 +1,4 @@ +import random import warnings from dataclasses import dataclass from pathlib import Path @@ -10,6 +11,7 @@ from torch.optim import SGD from refiners.fluxion import layers as fl from refiners.fluxion.utils import norm +from refiners.training_utils.callback import Callback, CallbackConfig from refiners.training_utils.common import ( Epoch, Iteration, @@ -24,6 +26,7 @@ from refiners.training_utils.trainer import ( WarmupScheduler, count_learnable_parameters, human_readable_number, + register_callback, register_model, ) @@ -40,6 +43,7 @@ class MockModelConfig(ModelConfig): class MockConfig(BaseConfig): mock_model: MockModelConfig + mock_callback: CallbackConfig class MockModel(fl.Chain): @@ -55,6 +59,25 @@ class MockModel(fl.Chain): self.insert(3, fl.SiLU()) +class MockCallback(Callback["MockTrainer"]): + def __init__(self) -> None: + self.optimizer_step_count = 0 + self.batch_end_count = 0 + self.optimizer_step_random_int: int | None = None + self.batch_end_random_int: int | None = None + + def on_init_begin(self, trainer: "MockTrainer") -> None: + pass + + def on_optimizer_step_begin(self, trainer: "MockTrainer") -> None: + self.optimizer_step_count += 1 + self.optimizer_step_random_int = random.randint(0, 100) + + def on_batch_end(self, trainer: "MockTrainer") -> None: + self.batch_end_count += 1 + self.batch_end_random_int = random.randint(0, 100) + + class MockTrainer(Trainer[MockConfig, MockBatch]): step_counter: int = 0 model_registration_counter: int = 0 @@ -72,6 +95,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): targets=torch.cat([b.targets for b in batch]), ) + @register_callback() + def mock_callback(self, config: CallbackConfig) -> MockCallback: + return MockCallback() + @register_model() def mock_model(self, config: MockModelConfig) -> MockModel: model = MockModel() @@ -198,9 +225,9 @@ def test_timer_functionality(training_clock: TrainingClock) -> None: def test_state_based_properties(training_clock: TrainingClock) -> None: training_clock.step = 5 # Halfway through the first epoch - assert not training_clock.is_evaluation_step # Assuming evaluation every epoch + assert not training_clock.is_due(training_clock.evaluation_interval) # Assuming evaluation every epoch training_clock.step = 10 # End of the first epoch - assert training_clock.is_evaluation_step + assert training_clock.is_due(training_clock.evaluation_interval) def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: @@ -219,7 +246,7 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None: assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size assert mock_trainer.step_counter == 0 - assert mock_trainer.clock.epoch == 0 + assert clock.epoch == 0 mock_trainer.train() @@ -229,6 +256,18 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None: assert mock_trainer.step_counter == mock_trainer.clock.step +def test_callback_registration(mock_trainer: MockTrainer) -> None: + mock_trainer.train() + + # Check that the callback skips every other iteration + assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3 + + # Check that the random seed was set + assert mock_trainer.mock_callback.optimizer_step_random_int == 81 + assert mock_trainer.mock_callback.batch_end_random_int == 72 + + def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None: clock = mock_trainer_short.clock config = mock_trainer_short.config