deprecate evaluation

This commit is contained in:
limiteinductive 2024-04-24 16:36:10 +00:00 committed by Benjamin Trom
parent 061d44888f
commit 44760ac19f
6 changed files with 1 additions and 52 deletions

View file

@ -25,7 +25,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
verbose: bool = True,
) -> None:
@ -37,7 +36,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.verbose = verbose
self.num_batches_per_epoch = dataset_length // batch_size
@ -85,10 +83,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
def num_step_per_iteration(self) -> int:
return self.convert_time_value_to_steps(self.gradient_accumulation)
@cached_property
def num_step_per_evaluation(self) -> int:
return self.convert_time_value_to_steps(self.evaluation_interval)
def is_due(self, interval: TimeValue) -> bool:
return self.step % self.convert_time_value_to_steps(interval) == 0
@ -171,9 +165,3 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.log(f"Iteration {trainer.clock.iteration} ended.")
trainer.clock.iteration += 1
trainer.clock.num_minibatches_processed = 0
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log("Evaluation started.")
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log("Evaluation ended.")

View file

@ -26,13 +26,11 @@ class TrainingConfig(BaseModel):
seed: int = 0
batch_size: int = 1
gradient_accumulation: Step | Epoch = Step(1)
evaluation_interval: Iteration | Epoch = Iteration(1)
gradient_clipping_max_norm: float | None = None
evaluation_seed: int = 0
model_config = ConfigDict(extra="forbid")
@field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
@field_validator("duration", "gradient_accumulation", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)

View file

@ -24,7 +24,6 @@ from torch.optim.lr_scheduler import (
from torch.utils.data import DataLoader, Dataset
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import no_grad
from refiners.training_utils.callback import (
Callback,
CallbackConfig,
@ -154,7 +153,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
dataset_length=self.dataset_length,
batch_size=self.config.training.batch_size,
training_duration=self.config.training.duration,
evaluation_interval=self.config.training.evaluation_interval,
gradient_accumulation=self.config.training.gradient_accumulation,
lr_scheduler_interval=self.config.lr_scheduler.update_interval,
verbose=config.verbose,
@ -345,9 +343,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@abstractmethod
def compute_loss(self, batch: Batch) -> Tensor: ...
def compute_evaluation(self) -> None:
pass
def backward(self) -> None:
"""Backward pass on the loss."""
self._call_callbacks(event_name="on_backward_begin")
@ -365,8 +360,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
self._call_callbacks(event_name="on_lr_scheduler_step_begin")
self.lr_scheduler.step()
self._call_callbacks(event_name="on_lr_scheduler_step_end")
if self.clock.is_due(self.config.training.evaluation_interval):
self.evaluate()
def step(self, batch: Batch) -> None:
"""Perform a single training step."""
@ -395,27 +388,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
self.set_models_to_mode("train")
self._call_callbacks(event_name="on_train_begin")
assert self.learnable_parameters, "There are no learnable parameters in the models."
self.evaluate()
while not self.clock.done:
self._call_callbacks(event_name="on_epoch_begin")
self.epoch()
self._call_callbacks(event_name="on_epoch_end")
self._call_callbacks(event_name="on_train_end")
@staticmethod
def get_evaluation_seed(instance: "Trainer[BaseConfig, Any]") -> int:
return instance.config.training.evaluation_seed
@no_grad()
@scoped_seed(seed=get_evaluation_seed)
def evaluate(self) -> None:
"""Evaluate the model."""
self.set_models_to_mode(mode="eval")
self._call_callbacks(event_name="on_evaluate_begin")
self.compute_evaluation()
self._call_callbacks(event_name="on_evaluate_end")
self.set_models_to_mode(mode="train")
def set_models_to_mode(self, mode: Literal["train", "eval"]) -> None:
for item in self.models.values():
if mode == "train":

View file

@ -4,7 +4,6 @@ on_batch_end_seed = 42
on_optimizer_step_interval = "2:iteration"
[mock_model]
requires_grad = true
use_activation = true
@ -19,8 +18,6 @@ device = "cpu"
dtype = "float32"
batch_size = 4
gradient_accumulation = "4:step"
evaluation_interval = "5:epoch"
evaluation_seed = 1
gradient_clipping_max_norm = 1.0
[optimizer]

View file

@ -13,8 +13,6 @@ duration = "100:epoch"
seed = 0
batch_size = 4
gradient_accumulation = "4:step"
evaluation_interval = "5:epoch"
evaluation_seed = 1
gradient_clipping_max_norm = 1.0
[optimizer]

View file

@ -186,7 +186,6 @@ def training_clock() -> TrainingClock:
batch_size=10,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
evaluation_interval=Epoch(1),
lr_scheduler_interval=Epoch(1),
)
@ -198,7 +197,6 @@ def test_small_dataset_error():
batch_size=10,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
evaluation_interval=Epoch(1),
lr_scheduler_interval=Epoch(1),
)
@ -210,7 +208,6 @@ def test_zero_batch_size_error():
batch_size=0,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
evaluation_interval=Epoch(1),
lr_scheduler_interval=Epoch(1),
)
@ -244,13 +241,6 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
assert training_clock.time_elapsed >= 0
def test_state_based_properties(training_clock: TrainingClock) -> None:
training_clock.step = 5 # Halfway through the first epoch
assert not training_clock.is_due(training_clock.evaluation_interval) # Assuming evaluation every epoch
training_clock.step = 10 # End of the first epoch
assert training_clock.is_due(training_clock.evaluation_interval)
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None:
assert mock_trainer.config == mock_config
assert isinstance(mock_trainer, MockTrainer)