From b7bb8bba8045ca5ffa9ff04531bfc915f260b3dd Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Wed, 24 Apr 2024 16:25:55 +0000
Subject: [PATCH] remove EventConfig

This is a partial rollback of commit 5dde281
---
 src/refiners/training_utils/callback.py | 101 +-----------------------
 src/refiners/training_utils/trainer.py  |   4 +-
 tests/training_utils/mock_config.toml   |  10 +--
 tests/training_utils/test_trainer.py    |  35 ++++++--
 4 files changed, 36 insertions(+), 114 deletions(-)

diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py
index bcca3fd..b5471a2 100644
--- a/src/refiners/training_utils/callback.py
+++ b/src/refiners/training_utils/callback.py
@@ -1,16 +1,6 @@
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
-from pydantic import BaseModel, ConfigDict, field_validator
-
-from refiners.training_utils.common import (
-    Epoch,
-    Iteration,
-    Step,
-    TimeValue,
-    TimeValueInput,
-    parse_number_unit_field,
-    scoped_seed,
-)
+from pydantic import BaseModel, ConfigDict
 
 if TYPE_CHECKING:
     from refiners.training_utils.trainer import Trainer
@@ -18,65 +8,6 @@ if TYPE_CHECKING:
 T = TypeVar("T", bound="Trainer[Any, Any]")
 
 
-class StepEventConfig(BaseModel):
-    """
-    Base configuration for an event that is triggered at every step.
-
-    - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
-      restored after the event.
-    - `interval`: Interval at which the event should be triggered. The interval is defined by either a `Step` object,
-      an `Iteration` object, or an `Epoch` object.
-    """
-
-    model_config = ConfigDict(extra="forbid")
-    seed: int | None = None
-    interval: Step | Iteration | Epoch = Step(1)
-
-    @field_validator("interval", mode="before")
-    def parse_field(cls, value: TimeValueInput) -> TimeValue:
-        return parse_number_unit_field(value)
-
-
-class IterationEventConfig(BaseModel):
-    """
-    Base configuration for an event that is triggered only once per iteration.
-
-    - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
-      restored after the event.
-    - `interval`: Interval at which the event should be triggered. The interval is defined by an `Iteration` object or
-      a `Epoch` object.
-    """
-
-    model_config = ConfigDict(extra="forbid")
-    seed: int | None = None
-    interval: Iteration | Epoch = Iteration(1)
-
-    @field_validator("interval", mode="before")
-    def parse_field(cls, value: TimeValueInput) -> TimeValue:
-        return parse_number_unit_field(value)
-
-
-class EpochEventConfig(BaseModel):
-    """
-    Base configuration for an event that is triggered only once per epoch.
-
-    - `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
-      restored after the event.
-    - `interval`: Interval at which the event should be triggered. The interval is defined by a `Epoch` object.
-    """
-
-    model_config = ConfigDict(extra="forbid")
-    seed: int | None = None
-    interval: Epoch = Epoch(1)
-
-    @field_validator("interval", mode="before")
-    def parse_field(cls, value: TimeValueInput) -> TimeValue:
-        return parse_number_unit_field(value)
-
-
-EventConfig = StepEventConfig | IterationEventConfig | EpochEventConfig
-
-
 class CallbackConfig(BaseModel):
     """
     Base configuration for a callback.
@@ -85,37 +16,9 @@ class CallbackConfig(BaseModel):
     """
 
     model_config = ConfigDict(extra="forbid")
-    on_epoch_begin: EpochEventConfig = EpochEventConfig()
-    on_epoch_end: EpochEventConfig = EpochEventConfig()
-    on_batch_begin: StepEventConfig = StepEventConfig()
-    on_batch_end: StepEventConfig = StepEventConfig()
-    on_backward_begin: StepEventConfig = StepEventConfig()
-    on_backward_end: StepEventConfig = StepEventConfig()
-    on_optimizer_step_begin: IterationEventConfig = IterationEventConfig()
-    on_optimizer_step_end: IterationEventConfig = IterationEventConfig()
-    on_compute_loss_begin: StepEventConfig = StepEventConfig()
-    on_compute_loss_end: StepEventConfig = StepEventConfig()
-    on_evaluate_begin: IterationEventConfig = IterationEventConfig()
-    on_evaluate_end: IterationEventConfig = IterationEventConfig()
-    on_lr_scheduler_step_begin: IterationEventConfig = IterationEventConfig()
-    on_lr_scheduler_step_end: IterationEventConfig = IterationEventConfig()
 
 
 class Callback(Generic[T]):
-    def run_event(self, trainer: T, callback_name: str, event_name: str) -> None:
-        if not hasattr(self, event_name):
-            return
-        callback_config = getattr(trainer.config, callback_name)
-        # For event that run once, there is no configuration to check, e.g. on_train_begin
-        if not hasattr(callback_config, event_name):
-            getattr(self, event_name)(trainer)
-            return
-        event_config = cast(EventConfig, getattr(callback_config, event_name))
-        if not trainer.clock.is_due(event_config.interval):
-            return
-        with scoped_seed(event_config.seed):
-            getattr(self, event_name)(trainer)
-
     def on_init_begin(self, trainer: T) -> None: ...
 
     def on_init_end(self, trainer: T) -> None: ...

diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 392779a..10cbf48 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -424,8 +424,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             item.model.eval()
 
     def _call_callbacks(self, event_name: str) -> None:
-        for name, callback in self.callbacks.items():
-            callback.run_event(trainer=self, callback_name=name, event_name=event_name)
+        for callback in self.callbacks.values():
+            getattr(callback, event_name)(self)
 
     def _load_callbacks(self) -> None:
         for name, config in self.config:

diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index f48ae9c..a42a746 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -1,11 +1,9 @@
-[mock_callback.on_optimizer_step_begin]
-interval = "2:iteration"
-seed = 42
+[mock_callback]
+on_batch_end_interval = "3:step"
+on_batch_end_seed = 42
+on_optimizer_step_interval = "2:iteration"
 
-[mock_callback.on_batch_end]
-interval = "3:step"
-
 [mock_model]
 requires_grad = true

diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 0349c4a..fd62312 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -6,6 +6,7 @@ from typing import cast
 
 import pytest
 import torch
+from pydantic import field_validator
 from torch import Tensor, nn
 from torch.optim import SGD
 
@@ -16,8 +17,12 @@ from refiners.training_utils.common import (
     Epoch,
     Iteration,
     Step,
+    TimeValue,
+    TimeValueInput,
     count_learnable_parameters,
     human_readable_number,
+    parse_number_unit_field,
+    scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
 from refiners.training_utils.trainer import (
@@ -41,9 +46,19 @@ class MockModelConfig(ModelConfig):
     use_activation: bool
 
 
+class MockCallbackConfig(CallbackConfig):
+    on_batch_end_interval: Step | Iteration | Epoch
+    on_batch_end_seed: int
+    on_optimizer_step_interval: Iteration | Epoch
+
+    @field_validator("on_batch_end_interval", "on_optimizer_step_interval", mode="before")
+    def parse_field(cls, value: TimeValueInput) -> TimeValue:
+        return parse_number_unit_field(value)
+
+
 class MockConfig(BaseConfig):
     mock_model: MockModelConfig
-    mock_callback: CallbackConfig
+    mock_callback: MockCallbackConfig
 
 
 class MockModel(fl.Chain):
@@ -60,7 +75,8 @@ class MockModel(fl.Chain):
 
 
 class MockCallback(Callback["MockTrainer"]):
-    def __init__(self) -> None:
+    def __init__(self, config: MockCallbackConfig) -> None:
+        self.config = config
         self.optimizer_step_count = 0
         self.batch_end_count = 0
         self.optimizer_step_random_int: int | None = None
@@ -70,12 +86,17 @@ class MockCallback(Callback["MockTrainer"]):
         pass
 
     def on_optimizer_step_begin(self, trainer: "MockTrainer") -> None:
+        if not trainer.clock.is_due(self.config.on_optimizer_step_interval):
+            return
         self.optimizer_step_count += 1
         self.optimizer_step_random_int = random.randint(0, 100)
 
     def on_batch_end(self, trainer: "MockTrainer") -> None:
+        if not trainer.clock.is_due(self.config.on_batch_end_interval):
+            return
         self.batch_end_count += 1
-        self.batch_end_random_int = random.randint(0, 100)
+        with scoped_seed(self.config.on_batch_end_seed):
+            self.batch_end_random_int = random.randint(0, 100)
 
 
 class MockTrainer(Trainer[MockConfig, MockBatch]):
@@ -96,8 +117,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
         )
 
     @register_callback()
-    def mock_callback(self, config: CallbackConfig) -> MockCallback:
-        return MockCallback()
+    def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
+        return MockCallback(config)
 
     @register_model()
     def mock_model(self, config: MockModelConfig) -> MockModel:
@@ -264,8 +285,8 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
     assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
 
     # Check that the random seed was set
-    assert mock_trainer.mock_callback.optimizer_step_random_int == 81
-    assert mock_trainer.mock_callback.batch_end_random_int == 72
+    assert mock_trainer.mock_callback.optimizer_step_random_int == 93
+    assert mock_trainer.mock_callback.batch_end_random_int == 81
 
 
 def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
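
Note: with run_event gone, the trainer dispatches every event unconditionally
and each callback gates itself, as MockCallback does in the tests above. Below
is a minimal sketch of a downstream callback following the same pattern; the
LoggingCallbackConfig and LoggingCallback names are hypothetical and not part
of this patch, while the refiners imports are the ones used elsewhere in the
diff:

    from typing import Any

    from pydantic import field_validator

    from refiners.training_utils.callback import Callback, CallbackConfig
    from refiners.training_utils.common import (
        Epoch,
        Iteration,
        Step,
        TimeValue,
        TimeValueInput,
        parse_number_unit_field,
        scoped_seed,
    )
    from refiners.training_utils.trainer import Trainer


    class LoggingCallbackConfig(CallbackConfig):
        # Plain config fields replace the removed StepEventConfig and friends;
        # "3:step"-style strings are parsed by the validator, as in the tests.
        log_interval: Step | Iteration | Epoch = Step(1)
        log_seed: int | None = None

        @field_validator("log_interval", mode="before")
        def parse_field(cls, value: TimeValueInput) -> TimeValue:
            return parse_number_unit_field(value)


    class LoggingCallback(Callback[Trainer[Any, Any]]):
        def __init__(self, config: LoggingCallbackConfig) -> None:
            self.config = config

        def on_batch_end(self, trainer: Trainer[Any, Any]) -> None:
            # The callback checks its own interval and scopes its own seed;
            # scoped_seed(None) leaves the random state untouched.
            if not trainer.clock.is_due(self.config.log_interval):
                return
            with scoped_seed(self.config.log_seed):
                pass  # seeded work, e.g. sampling examples to log

The matching TOML stays flat, mirroring mock_config.toml above: assuming the
callback is registered under "logging", its section would read [logging] with
log_interval = "3:step" and log_seed = 42.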