remove EventConfig

This is a partial rollback of commit 5dde281
This commit is contained in:
limiteinductive 2024-04-24 16:25:55 +00:00 committed by Benjamin Trom
parent 7aff743019
commit b7bb8bba80
4 changed files with 36 additions and 114 deletions

View file

@ -1,16 +1,6 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict
from refiners.training_utils.common import (
Epoch,
Iteration,
Step,
TimeValue,
TimeValueInput,
parse_number_unit_field,
scoped_seed,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from refiners.training_utils.trainer import Trainer from refiners.training_utils.trainer import Trainer
@ -18,65 +8,6 @@ if TYPE_CHECKING:
T = TypeVar("T", bound="Trainer[Any, Any]") T = TypeVar("T", bound="Trainer[Any, Any]")
class StepEventConfig(BaseModel):
"""
Base configuration for an event that is triggered at every step.
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
restored after the event.
- `interval`: Interval at which the event should be triggered. The interval is defined by either a `Step` object,
an `Iteration` object, or an `Epoch` object.
"""
model_config = ConfigDict(extra="forbid")
seed: int | None = None
interval: Step | Iteration | Epoch = Step(1)
@field_validator("interval", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
class IterationEventConfig(BaseModel):
"""
Base configuration for an event that is triggered only once per iteration.
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
restored after the event.
- `interval`: Interval at which the event should be triggered. The interval is defined by an `Iteration` object or
a `Epoch` object.
"""
model_config = ConfigDict(extra="forbid")
seed: int | None = None
interval: Iteration | Epoch = Iteration(1)
@field_validator("interval", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
class EpochEventConfig(BaseModel):
"""
Base configuration for an event that is triggered only once per epoch.
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
restored after the event.
- `interval`: Interval at which the event should be triggered. The interval is defined by a `Epoch` object.
"""
model_config = ConfigDict(extra="forbid")
seed: int | None = None
interval: Epoch = Epoch(1)
@field_validator("interval", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
EventConfig = StepEventConfig | IterationEventConfig | EpochEventConfig
class CallbackConfig(BaseModel): class CallbackConfig(BaseModel):
""" """
Base configuration for a callback. Base configuration for a callback.
@ -85,37 +16,9 @@ class CallbackConfig(BaseModel):
""" """
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
on_epoch_begin: EpochEventConfig = EpochEventConfig()
on_epoch_end: EpochEventConfig = EpochEventConfig()
on_batch_begin: StepEventConfig = StepEventConfig()
on_batch_end: StepEventConfig = StepEventConfig()
on_backward_begin: StepEventConfig = StepEventConfig()
on_backward_end: StepEventConfig = StepEventConfig()
on_optimizer_step_begin: IterationEventConfig = IterationEventConfig()
on_optimizer_step_end: IterationEventConfig = IterationEventConfig()
on_compute_loss_begin: StepEventConfig = StepEventConfig()
on_compute_loss_end: StepEventConfig = StepEventConfig()
on_evaluate_begin: IterationEventConfig = IterationEventConfig()
on_evaluate_end: IterationEventConfig = IterationEventConfig()
on_lr_scheduler_step_begin: IterationEventConfig = IterationEventConfig()
on_lr_scheduler_step_end: IterationEventConfig = IterationEventConfig()
class Callback(Generic[T]): class Callback(Generic[T]):
def run_event(self, trainer: T, callback_name: str, event_name: str) -> None:
if not hasattr(self, event_name):
return
callback_config = getattr(trainer.config, callback_name)
# For event that run once, there is no configuration to check, e.g. on_train_begin
if not hasattr(callback_config, event_name):
getattr(self, event_name)(trainer)
return
event_config = cast(EventConfig, getattr(callback_config, event_name))
if not trainer.clock.is_due(event_config.interval):
return
with scoped_seed(event_config.seed):
getattr(self, event_name)(trainer)
def on_init_begin(self, trainer: T) -> None: ... def on_init_begin(self, trainer: T) -> None: ...
def on_init_end(self, trainer: T) -> None: ... def on_init_end(self, trainer: T) -> None: ...

View file

@ -424,8 +424,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
item.model.eval() item.model.eval()
def _call_callbacks(self, event_name: str) -> None: def _call_callbacks(self, event_name: str) -> None:
for name, callback in self.callbacks.items(): for callback in self.callbacks.values():
callback.run_event(trainer=self, callback_name=name, event_name=event_name) getattr(callback, event_name)(self)
def _load_callbacks(self) -> None: def _load_callbacks(self) -> None:
for name, config in self.config: for name, config in self.config:

View file

@ -1,11 +1,9 @@
[mock_callback.on_optimizer_step_begin] [mock_callback]
interval = "2:iteration" on_batch_end_interval = "3:step"
seed = 42 on_batch_end_seed = 42
on_optimizer_step_interval = "2:iteration"
[mock_callback.on_batch_end]
interval = "3:step"
[mock_model] [mock_model]
requires_grad = true requires_grad = true

View file

@ -6,6 +6,7 @@ from typing import cast
import pytest import pytest
import torch import torch
from pydantic import field_validator
from torch import Tensor, nn from torch import Tensor, nn
from torch.optim import SGD from torch.optim import SGD
@ -16,8 +17,12 @@ from refiners.training_utils.common import (
Epoch, Epoch,
Iteration, Iteration,
Step, Step,
TimeValue,
TimeValueInput,
count_learnable_parameters, count_learnable_parameters,
human_readable_number, human_readable_number,
parse_number_unit_field,
scoped_seed,
) )
from refiners.training_utils.config import BaseConfig, ModelConfig from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.trainer import ( from refiners.training_utils.trainer import (
@ -41,9 +46,19 @@ class MockModelConfig(ModelConfig):
use_activation: bool use_activation: bool
class MockCallbackConfig(CallbackConfig):
on_batch_end_interval: Step | Iteration | Epoch
on_batch_end_seed: int
on_optimizer_step_interval: Iteration | Epoch
@field_validator("on_batch_end_interval", "on_optimizer_step_interval", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
class MockConfig(BaseConfig): class MockConfig(BaseConfig):
mock_model: MockModelConfig mock_model: MockModelConfig
mock_callback: CallbackConfig mock_callback: MockCallbackConfig
class MockModel(fl.Chain): class MockModel(fl.Chain):
@ -60,7 +75,8 @@ class MockModel(fl.Chain):
class MockCallback(Callback["MockTrainer"]): class MockCallback(Callback["MockTrainer"]):
def __init__(self) -> None: def __init__(self, config: MockCallbackConfig) -> None:
self.config = config
self.optimizer_step_count = 0 self.optimizer_step_count = 0
self.batch_end_count = 0 self.batch_end_count = 0
self.optimizer_step_random_int: int | None = None self.optimizer_step_random_int: int | None = None
@ -70,11 +86,16 @@ class MockCallback(Callback["MockTrainer"]):
pass pass
def on_optimizer_step_begin(self, trainer: "MockTrainer") -> None: def on_optimizer_step_begin(self, trainer: "MockTrainer") -> None:
if not trainer.clock.is_due(self.config.on_optimizer_step_interval):
return
self.optimizer_step_count += 1 self.optimizer_step_count += 1
self.optimizer_step_random_int = random.randint(0, 100) self.optimizer_step_random_int = random.randint(0, 100)
def on_batch_end(self, trainer: "MockTrainer") -> None: def on_batch_end(self, trainer: "MockTrainer") -> None:
if not trainer.clock.is_due(self.config.on_batch_end_interval):
return
self.batch_end_count += 1 self.batch_end_count += 1
with scoped_seed(self.config.on_batch_end_seed):
self.batch_end_random_int = random.randint(0, 100) self.batch_end_random_int = random.randint(0, 100)
@ -96,8 +117,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
) )
@register_callback() @register_callback()
def mock_callback(self, config: CallbackConfig) -> MockCallback: def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
return MockCallback() return MockCallback(config)
@register_model() @register_model()
def mock_model(self, config: MockModelConfig) -> MockModel: def mock_model(self, config: MockModelConfig) -> MockModel:
@ -264,8 +285,8 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3 assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
# Check that the random seed was set # Check that the random seed was set
assert mock_trainer.mock_callback.optimizer_step_random_int == 81 assert mock_trainer.mock_callback.optimizer_step_random_int == 93
assert mock_trainer.mock_callback.batch_end_random_int == 72 assert mock_trainer.mock_callback.batch_end_random_int == 81
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None: def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None: