remove dataset length

This commit is contained in:
limiteinductive 2024-04-24 16:50:27 +00:00 committed by Benjamin Trom
parent b497b27cd3
commit de8334b6fc
6 changed files with 33 additions and 43 deletions

View file

@@ -1,9 +1,8 @@
 import time
 from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from refiners.training_utils.callback import Callback, CallbackConfig
-from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue

 if TYPE_CHECKING:
     from refiners.training_utils.config import BaseConfig
@@ -23,7 +22,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         self,
         batch_size: int,
         training_duration: TimeValue,
-        gradient_accumulation: int,
+        gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
@@ -75,7 +74,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @property
     def is_optimizer_step(self) -> bool:
-        return self.num_minibatches_processed == self.gradient_accumulation
+        return self.num_minibatches_processed == self.gradient_accumulation.number

     @property
     def done(self) -> bool:
@@ -94,41 +93,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         logger.info(message)

     def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.clock.reset()
-        trainer.clock.start_timer()
+        self.log(f"Starting training for {self.training_duration}.")
+        self.reset()
+        self.start_timer()

     def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.clock.stop_timer()
+        self.stop_timer()
         self.log(
             (
                 "Training took: "
-                f"{trainer.clock.time_elapsed} seconds, "
-                f"{trainer.clock.iteration} iterations, "
-                f"{trainer.clock.epoch} epochs, "
-                f"{trainer.clock.step} steps."
+                f"{self.time_elapsed} seconds, "
+                f"{self.iteration} iterations, "
+                f"{self.epoch} epochs, "
+                f"{self.step} steps."
             )
         )

     def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Epoch {trainer.clock.epoch} started.")
+        self.log(f"Epoch {self.epoch} started.")

     def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Epoch {trainer.clock.epoch} ended.")
-        trainer.clock.epoch += 1
-        trainer.clock.num_batches_processed = 0
+        self.log(f"Epoch {self.epoch} ended.")
+        self.epoch += 1
+        self.num_batches_processed = 0

     def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         if self.num_minibatches_processed == 0:
-            self.log(f"Iteration {trainer.clock.iteration} started.")
-        self.log(f"Step {trainer.clock.step} started.")
+            self.log(f"Iteration {self.iteration} started.")
+        self.log(f"Step {self.step} started.")

     def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Step {trainer.clock.step} ended.")
-        trainer.clock.step += 1
-        trainer.clock.num_batches_processed += 1
-        trainer.clock.num_minibatches_processed += 1
+        self.log(f"Step {self.step} ended.")
+        self.step += 1
+        self.num_batches_processed += 1
+        self.num_minibatches_processed += 1

     def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Iteration {trainer.clock.iteration} ended.")
-        trainer.clock.iteration += 1
-        trainer.clock.num_minibatches_processed = 0
+        self.log(f"Iteration {self.iteration} ended.")
+        self.iteration += 1
+        self.num_minibatches_processed = 0
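
Taken together, the clock changes do two things: gradient_accumulation becomes a typed Step value instead of a bare int, and the callbacks mutate the clock's own state through self rather than reaching back through trainer.clock. A minimal sketch of the new gating, assuming only the names visible in this diff (the loop is illustrative, not the library's actual training loop):

from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Epoch, Step

# The updated signature: gradient_accumulation is a Step, not an int.
clock = TrainingClock(
    batch_size=4,
    training_duration=Epoch(5),
    gradient_accumulation=Step(4),
    lr_scheduler_interval=Epoch(1),
)

# is_optimizer_step now compares the minibatch counter against
# gradient_accumulation.number, so it fires on every 4th minibatch.
for _ in range(8):
    clock.num_minibatches_processed += 1
    if clock.is_optimizer_step:
        clock.iteration += 1
        clock.num_minibatches_processed = 0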

View file

@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer

 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
     duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: int = 1
+    gradient_accumulation: Step = Step(1)
     gradient_clipping_max_norm: float | None = None

     model_config = ConfigDict(extra="forbid")

-    @field_validator("duration", mode="before")
+    @field_validator("duration", "gradient_accumulation", mode="before")
     def parse_field(cls, value: TimeValueInput) -> TimeValue:
         return parse_number_unit_field(value)
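
With the validator widened to cover gradient_accumulation, the field now accepts either a Step or the same "number:unit" string syntax already used for duration. A hedged sketch of the two equivalent spellings (it assumes the other TrainingConfig fields keep their defaults and that Step compares by value):

from refiners.training_utils.common import Step
from refiners.training_utils.config import TrainingConfig

# Both forms should validate to Step(4) now that the "before" validator
# also runs on gradient_accumulation.
from_string = TrainingConfig(gradient_accumulation="4:step")
from_value = TrainingConfig(gradient_accumulation=Step(4))
assert from_string.gradient_accumulation == from_value.gradient_accumulation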

View file

@@ -282,7 +282,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         warmup_scheduler_steps = (
             config.warmup.number
             if isinstance(config.warmup, Step)
-            else config.warmup.number * self.clock.gradient_accumulation
+            else config.warmup.number * self.clock.gradient_accumulation.number
         )
         if warmup_scheduler_steps > 0:
             lr_scheduler = WarmupScheduler(
@@ -350,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
-        scaled_loss = self.loss / self.config.training.gradient_accumulation
+        scaled_loss = self.loss / self.config.training.gradient_accumulation.number
         backward(tensors=scaled_loss)
         self._call_callbacks(event_name="on_backward_end")

         if self.clock.is_optimizer_step:
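
The .number access preserves the existing scaling semantics: dividing each minibatch loss by the accumulation count makes the gradients summed over one accumulation window equal the gradient of the mean loss, as if one larger batch had been used. A self-contained toy check in plain PyTorch (independent of refiners):

import torch

n = 4  # stands in for gradient_accumulation.number
w = torch.tensor(1.0, requires_grad=True)
xs = torch.arange(1.0, 5.0)  # four "minibatches"

for x in xs:
    loss = w * x
    (loss / n).backward()  # same scaling as Trainer.backward applies

# The accumulated gradient equals the gradient of the mean loss.
assert torch.isclose(w.grad, xs.mean())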

View file

@@ -17,7 +17,7 @@ seed = 0
 device = "cpu"
 dtype = "float32"
 batch_size = 4
-gradient_accumulation = 4
+gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

 [optimizer]

View file

@@ -12,7 +12,7 @@ verbose = false
 duration = "100:epoch"
 seed = 0
 batch_size = 4
-gradient_accumulation = 4
+gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

 [optimizer]
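
Both test configs switch from a bare integer to the "4:step" string form, which the new validator parses back into Step(4). A hypothetical loading snippet (it assumes the [training] table is passed straight to TrainingConfig; the real project may load configs differently):

import tomllib  # Python 3.11+

from refiners.training_utils.config import TrainingConfig

toml_text = """
duration = "100:epoch"
seed = 0
batch_size = 4
gradient_accumulation = "4:step"
gradient_clipping_max_norm = 1.0
"""
training = TrainingConfig(**tomllib.loads(toml_text))
print(training.gradient_accumulation)  # expected: Step(number=4)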

View file

@@ -206,27 +206,17 @@ def training_clock() -> TrainingClock:
     return TrainingClock(
         batch_size=10,
         training_duration=Epoch(5),
-        gradient_accumulation=1,
+        gradient_accumulation=Step(1),
         lr_scheduler_interval=Epoch(1),
     )

-def test_small_dataset_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=10,
-            training_duration=Epoch(5),
-            gradient_accumulation=1,
-            lr_scheduler_interval=Epoch(1),
-        )

 def test_zero_batch_size_error():
     with pytest.raises(AssertionError):
         TrainingClock(
             batch_size=0,
             training_duration=Epoch(5),
-            gradient_accumulation=1,
+            gradient_accumulation=Step(1),
             lr_scheduler_interval=Epoch(1),
         )