mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Implement EventConfig
This commit is contained in:
parent
07985694ed
commit
5dde281ada
|
@ -1,12 +1,80 @@
|
||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
|
from refiners.training_utils.common import (
|
||||||
|
Epoch,
|
||||||
|
Iteration,
|
||||||
|
Step,
|
||||||
|
TimeValue,
|
||||||
|
TimeValueInput,
|
||||||
|
parse_number_unit_field,
|
||||||
|
scoped_seed,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from refiners.training_utils.config import BaseConfig
|
|
||||||
from refiners.training_utils.trainer import Trainer
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
T = TypeVar("T", bound="Trainer[BaseConfig, Any]")
|
T = TypeVar("T", bound="Trainer[Any, Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class StepEventConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Base configuration for an event that is triggered at every step.
|
||||||
|
|
||||||
|
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
|
||||||
|
restored after the event.
|
||||||
|
- `interval`: Interval at which the event should be triggered. The interval is defined by either a `Step` object,
|
||||||
|
an `Iteration` object, or an `Epoch` object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
seed: int | None = None
|
||||||
|
interval: Step | Iteration | Epoch = Step(1)
|
||||||
|
|
||||||
|
@field_validator("interval", mode="before")
|
||||||
|
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
class IterationEventConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Base configuration for an event that is triggered only once per iteration.
|
||||||
|
|
||||||
|
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
|
||||||
|
restored after the event.
|
||||||
|
- `interval`: Interval at which the event should be triggered. The interval is defined by an `Iteration` object or
|
||||||
|
a `Epoch` object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
seed: int | None = None
|
||||||
|
interval: Iteration | Epoch = Iteration(1)
|
||||||
|
|
||||||
|
@field_validator("interval", mode="before")
|
||||||
|
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
class EpochEventConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Base configuration for an event that is triggered only once per epoch.
|
||||||
|
|
||||||
|
- `seed`: Seed to use for the event. If `None`, the seed will not be set. The random state will be saved and
|
||||||
|
restored after the event.
|
||||||
|
- `interval`: Interval at which the event should be triggered. The interval is defined by a `Epoch` object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
seed: int | None = None
|
||||||
|
interval: Epoch = Epoch(1)
|
||||||
|
|
||||||
|
@field_validator("interval", mode="before")
|
||||||
|
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
EventConfig = StepEventConfig | IterationEventConfig | EpochEventConfig
|
||||||
|
|
||||||
|
|
||||||
class CallbackConfig(BaseModel):
|
class CallbackConfig(BaseModel):
|
||||||
|
@ -17,9 +85,37 @@ class CallbackConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
on_epoch_begin: EpochEventConfig = EpochEventConfig()
|
||||||
|
on_epoch_end: EpochEventConfig = EpochEventConfig()
|
||||||
|
on_batch_begin: StepEventConfig = StepEventConfig()
|
||||||
|
on_batch_end: StepEventConfig = StepEventConfig()
|
||||||
|
on_backward_begin: StepEventConfig = StepEventConfig()
|
||||||
|
on_backward_end: StepEventConfig = StepEventConfig()
|
||||||
|
on_optimizer_step_begin: IterationEventConfig = IterationEventConfig()
|
||||||
|
on_optimizer_step_end: IterationEventConfig = IterationEventConfig()
|
||||||
|
on_compute_loss_begin: StepEventConfig = StepEventConfig()
|
||||||
|
on_compute_loss_end: StepEventConfig = StepEventConfig()
|
||||||
|
on_evaluate_begin: IterationEventConfig = IterationEventConfig()
|
||||||
|
on_evaluate_end: IterationEventConfig = IterationEventConfig()
|
||||||
|
on_lr_scheduler_step_begin: IterationEventConfig = IterationEventConfig()
|
||||||
|
on_lr_scheduler_step_end: IterationEventConfig = IterationEventConfig()
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
|
def run_event(self, trainer: T, callback_name: str, event_name: str) -> None:
|
||||||
|
if not hasattr(self, event_name):
|
||||||
|
return
|
||||||
|
callback_config = getattr(trainer.config, callback_name)
|
||||||
|
# For event that run once, there is no configuration to check, e.g. on_train_begin
|
||||||
|
if not hasattr(callback_config, event_name):
|
||||||
|
getattr(self, event_name)(trainer)
|
||||||
|
return
|
||||||
|
event_config = cast(EventConfig, getattr(callback_config, event_name))
|
||||||
|
if not trainer.clock.is_due(event_config.interval):
|
||||||
|
return
|
||||||
|
with scoped_seed(event_config.seed):
|
||||||
|
getattr(self, event_name)(trainer)
|
||||||
|
|
||||||
def on_init_begin(self, trainer: T) -> None: ...
|
def on_init_begin(self, trainer: T) -> None: ...
|
||||||
|
|
||||||
def on_init_end(self, trainer: T) -> None: ...
|
def on_init_end(self, trainer: T) -> None: ...
|
||||||
|
|
|
@ -89,6 +89,9 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
def num_step_per_evaluation(self) -> int:
|
def num_step_per_evaluation(self) -> int:
|
||||||
return self.convert_time_value_to_steps(self.evaluation_interval)
|
return self.convert_time_value_to_steps(self.evaluation_interval)
|
||||||
|
|
||||||
|
def is_due(self, interval: TimeValue) -> bool:
|
||||||
|
return self.step % self.convert_time_value_to_steps(interval) == 0
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.start_time = None
|
self.start_time = None
|
||||||
self.end_time = None
|
self.end_time = None
|
||||||
|
@ -109,30 +112,14 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
assert self.start_time is not None, "Timer has not been started yet."
|
assert self.start_time is not None, "Timer has not been started yet."
|
||||||
return int(time.time() - self.start_time)
|
return int(time.time() - self.start_time)
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def evaluation_interval_steps(self) -> int:
|
|
||||||
return self.convert_time_value_to_steps(self.evaluation_interval)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def lr_scheduler_interval_steps(self) -> int:
|
|
||||||
return self.convert_time_value_to_steps(self.lr_scheduler_interval)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_optimizer_step(self) -> bool:
|
def is_optimizer_step(self) -> bool:
|
||||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||||
|
|
||||||
@property
|
|
||||||
def is_lr_scheduler_step(self) -> bool:
|
|
||||||
return self.step % self.lr_scheduler_interval_steps == 0
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
return self.step >= self.num_steps
|
return self.step >= self.num_steps
|
||||||
|
|
||||||
@property
|
|
||||||
def is_evaluation_step(self) -> bool:
|
|
||||||
return self.step % self.evaluation_interval_steps == 0
|
|
||||||
|
|
||||||
def log(self, message: str, /) -> None:
|
def log(self, message: str, /) -> None:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
|
@ -361,11 +361,11 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self._call_callbacks(event_name="on_optimizer_step_end")
|
self._call_callbacks(event_name="on_optimizer_step_end")
|
||||||
if self.clock.is_lr_scheduler_step:
|
if self.clock.is_due(self.config.lr_scheduler.update_interval):
|
||||||
self._call_callbacks(event_name="on_lr_scheduler_step_begin")
|
self._call_callbacks(event_name="on_lr_scheduler_step_begin")
|
||||||
self.lr_scheduler.step()
|
self.lr_scheduler.step()
|
||||||
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
||||||
if self.clock.is_evaluation_step:
|
if self.clock.is_due(self.config.training.evaluation_interval):
|
||||||
self.evaluate()
|
self.evaluate()
|
||||||
|
|
||||||
def step(self, batch: Batch) -> None:
|
def step(self, batch: Batch) -> None:
|
||||||
|
@ -424,8 +424,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
item.model.eval()
|
item.model.eval()
|
||||||
|
|
||||||
def _call_callbacks(self, event_name: str) -> None:
|
def _call_callbacks(self, event_name: str) -> None:
|
||||||
for callback in self.callbacks.values():
|
for name, callback in self.callbacks.items():
|
||||||
getattr(callback, event_name)(self)
|
callback.run_event(trainer=self, callback_name=name, event_name=event_name)
|
||||||
|
|
||||||
def _load_callbacks(self) -> None:
|
def _load_callbacks(self) -> None:
|
||||||
for name, config in self.config:
|
for name, config in self.config:
|
||||||
|
|
|
@ -1,3 +1,12 @@
|
||||||
|
[mock_callback.on_optimizer_step_begin]
|
||||||
|
interval = "2:iteration"
|
||||||
|
seed = 42
|
||||||
|
|
||||||
|
|
||||||
|
[mock_callback.on_batch_end]
|
||||||
|
interval = "3:step"
|
||||||
|
|
||||||
|
|
||||||
[mock_model]
|
[mock_model]
|
||||||
requires_grad = true
|
requires_grad = true
|
||||||
use_activation = true
|
use_activation = true
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -10,6 +11,7 @@ from torch.optim import SGD
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import norm
|
from refiners.fluxion.utils import norm
|
||||||
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.common import (
|
from refiners.training_utils.common import (
|
||||||
Epoch,
|
Epoch,
|
||||||
Iteration,
|
Iteration,
|
||||||
|
@ -24,6 +26,7 @@ from refiners.training_utils.trainer import (
|
||||||
WarmupScheduler,
|
WarmupScheduler,
|
||||||
count_learnable_parameters,
|
count_learnable_parameters,
|
||||||
human_readable_number,
|
human_readable_number,
|
||||||
|
register_callback,
|
||||||
register_model,
|
register_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,6 +43,7 @@ class MockModelConfig(ModelConfig):
|
||||||
|
|
||||||
class MockConfig(BaseConfig):
|
class MockConfig(BaseConfig):
|
||||||
mock_model: MockModelConfig
|
mock_model: MockModelConfig
|
||||||
|
mock_callback: CallbackConfig
|
||||||
|
|
||||||
|
|
||||||
class MockModel(fl.Chain):
|
class MockModel(fl.Chain):
|
||||||
|
@ -55,6 +59,25 @@ class MockModel(fl.Chain):
|
||||||
self.insert(3, fl.SiLU())
|
self.insert(3, fl.SiLU())
|
||||||
|
|
||||||
|
|
||||||
|
class MockCallback(Callback["MockTrainer"]):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.optimizer_step_count = 0
|
||||||
|
self.batch_end_count = 0
|
||||||
|
self.optimizer_step_random_int: int | None = None
|
||||||
|
self.batch_end_random_int: int | None = None
|
||||||
|
|
||||||
|
def on_init_begin(self, trainer: "MockTrainer") -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_optimizer_step_begin(self, trainer: "MockTrainer") -> None:
|
||||||
|
self.optimizer_step_count += 1
|
||||||
|
self.optimizer_step_random_int = random.randint(0, 100)
|
||||||
|
|
||||||
|
def on_batch_end(self, trainer: "MockTrainer") -> None:
|
||||||
|
self.batch_end_count += 1
|
||||||
|
self.batch_end_random_int = random.randint(0, 100)
|
||||||
|
|
||||||
|
|
||||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
step_counter: int = 0
|
step_counter: int = 0
|
||||||
model_registration_counter: int = 0
|
model_registration_counter: int = 0
|
||||||
|
@ -72,6 +95,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
targets=torch.cat([b.targets for b in batch]),
|
targets=torch.cat([b.targets for b in batch]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@register_callback()
|
||||||
|
def mock_callback(self, config: CallbackConfig) -> MockCallback:
|
||||||
|
return MockCallback()
|
||||||
|
|
||||||
@register_model()
|
@register_model()
|
||||||
def mock_model(self, config: MockModelConfig) -> MockModel:
|
def mock_model(self, config: MockModelConfig) -> MockModel:
|
||||||
model = MockModel()
|
model = MockModel()
|
||||||
|
@ -198,9 +225,9 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
|
||||||
|
|
||||||
def test_state_based_properties(training_clock: TrainingClock) -> None:
|
def test_state_based_properties(training_clock: TrainingClock) -> None:
|
||||||
training_clock.step = 5 # Halfway through the first epoch
|
training_clock.step = 5 # Halfway through the first epoch
|
||||||
assert not training_clock.is_evaluation_step # Assuming evaluation every epoch
|
assert not training_clock.is_due(training_clock.evaluation_interval) # Assuming evaluation every epoch
|
||||||
training_clock.step = 10 # End of the first epoch
|
training_clock.step = 10 # End of the first epoch
|
||||||
assert training_clock.is_evaluation_step
|
assert training_clock.is_due(training_clock.evaluation_interval)
|
||||||
|
|
||||||
|
|
||||||
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None:
|
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None:
|
||||||
|
@ -219,7 +246,7 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
||||||
|
|
||||||
assert mock_trainer.step_counter == 0
|
assert mock_trainer.step_counter == 0
|
||||||
assert mock_trainer.clock.epoch == 0
|
assert clock.epoch == 0
|
||||||
|
|
||||||
mock_trainer.train()
|
mock_trainer.train()
|
||||||
|
|
||||||
|
@ -229,6 +256,18 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
||||||
|
mock_trainer.train()
|
||||||
|
|
||||||
|
# Check that the callback skips every other iteration
|
||||||
|
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
||||||
|
assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
|
||||||
|
|
||||||
|
# Check that the random seed was set
|
||||||
|
assert mock_trainer.mock_callback.optimizer_step_random_int == 81
|
||||||
|
assert mock_trainer.mock_callback.batch_end_random_int == 72
|
||||||
|
|
||||||
|
|
||||||
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||||
clock = mock_trainer_short.clock
|
clock = mock_trainer_short.clock
|
||||||
config = mock_trainer_short.config
|
config = mock_trainer_short.config
|
||||||
|
|
Loading…
Reference in a new issue