remove dataset length

Author: limiteinductive
Date: 2024-04-24 16:50:27 +00:00
Committed by: Benjamin Trom
Parent: b497b27cd3
Commit: de8334b6fc
6 changed files with 33 additions and 43 deletions

View file

@@ -1,9 +1,8 @@
 import time
-from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from refiners.training_utils.callback import Callback, CallbackConfig
-from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue

 if TYPE_CHECKING:
     from refiners.training_utils.config import BaseConfig
@@ -23,7 +22,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         self,
         batch_size: int,
         training_duration: TimeValue,
-        gradient_accumulation: int,
+        gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
@@ -75,7 +74,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @property
     def is_optimizer_step(self) -> bool:
-        return self.num_minibatches_processed == self.gradient_accumulation
+        return self.num_minibatches_processed == self.gradient_accumulation.number

     @property
     def done(self) -> bool:
@@ -94,41 +93,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         logger.info(message)

     def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.clock.reset()
-        trainer.clock.start_timer()
+        self.log(f"Starting training for {self.training_duration}.")
+        self.reset()
+        self.start_timer()

     def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.clock.stop_timer()
+        self.stop_timer()
         self.log(
             (
                 "Training took: "
-                f"{trainer.clock.time_elapsed} seconds, "
-                f"{trainer.clock.iteration} iterations, "
-                f"{trainer.clock.epoch} epochs, "
-                f"{trainer.clock.step} steps."
+                f"{self.time_elapsed} seconds, "
+                f"{self.iteration} iterations, "
+                f"{self.epoch} epochs, "
+                f"{self.step} steps."
             )
         )

     def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Epoch {trainer.clock.epoch} started.")
+        self.log(f"Epoch {self.epoch} started.")

     def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Epoch {trainer.clock.epoch} ended.")
-        trainer.clock.epoch += 1
-        trainer.clock.num_batches_processed = 0
+        self.log(f"Epoch {self.epoch} ended.")
+        self.epoch += 1
+        self.num_batches_processed = 0

     def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         if self.num_minibatches_processed == 0:
-            self.log(f"Iteration {trainer.clock.iteration} started.")
-        self.log(f"Step {trainer.clock.step} started.")
+            self.log(f"Iteration {self.iteration} started.")
+        self.log(f"Step {self.step} started.")

     def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Step {trainer.clock.step} ended.")
-        trainer.clock.step += 1
-        trainer.clock.num_batches_processed += 1
-        trainer.clock.num_minibatches_processed += 1
+        self.log(f"Step {self.step} ended.")
+        self.step += 1
+        self.num_batches_processed += 1
+        self.num_minibatches_processed += 1

     def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.log(f"Iteration {trainer.clock.iteration} ended.")
-        trainer.clock.iteration += 1
-        trainer.clock.num_minibatches_processed = 0
+        self.log(f"Iteration {self.iteration} ended.")
+        self.iteration += 1
+        self.num_minibatches_processed = 0
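Sketch of the counting behaviour the clock now exposes: the accumulation setting is stored as a Step, so is_optimizer_step compares the minibatch counter against its .number field. This is a standalone illustration rather than the Trainer's actual loop; only the Step type, the counter name, and the comparison come from the diff above.

from refiners.training_utils.common import Step

# Accumulate over 4 minibatches per optimizer step, as with "4:step".
gradient_accumulation = Step(4)
num_minibatches_processed = 0

for batch_index in range(8):
    num_minibatches_processed += 1
    # mirrors TrainingClock.is_optimizer_step
    if num_minibatches_processed == gradient_accumulation.number:
        print(f"optimizer step after batch {batch_index}")
        num_minibatches_processed = 0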

View file

@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer

 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
     duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: int = 1
+    gradient_accumulation: Step = Step(1)
     gradient_clipping_max_norm: float | None = None

     model_config = ConfigDict(extra="forbid")

-    @field_validator("duration", mode="before")
+    @field_validator("duration", "gradient_accumulation", mode="before")
     def parse_field(cls, value: TimeValueInput) -> TimeValue:
         return parse_number_unit_field(value)
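Because gradient_accumulation is now a Step, the same before-validator that handles duration also covers it, so the "<number>:<unit>" string form used in the TOML configs below should round-trip to a Step value. A minimal sketch, assuming parse_number_unit_field maps "4:step" to Step(4):

from refiners.training_utils.common import Step, parse_number_unit_field

# String form as written in the TOML configs ("4:step").
value = parse_number_unit_field("4:step")
assert isinstance(value, Step)
assert value.number == 4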

View file

@@ -282,7 +282,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
        warmup_scheduler_steps = (
            config.warmup.number
            if isinstance(config.warmup, Step)
-            else config.warmup.number * self.clock.gradient_accumulation
+            else config.warmup.number * self.clock.gradient_accumulation.number
        )
        if warmup_scheduler_steps > 0:
            lr_scheduler = WarmupScheduler(
@@ -350,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
-        scaled_loss = self.loss / self.config.training.gradient_accumulation
+        scaled_loss = self.loss / self.config.training.gradient_accumulation.number
         backward(tensors=scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
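For context, dividing each minibatch loss by gradient_accumulation.number keeps the gradient accumulated over several backward passes comparable to one large-batch update, which is then applied when is_optimizer_step fires. A generic PyTorch sketch of that pattern (hypothetical model and data, not the Trainer's actual loop):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accumulation = 4  # corresponds to gradient_accumulation = "4:step"

for step, x in enumerate(torch.randn(8, 4)):
    loss = model(x).pow(2).mean()
    # same scaling as Trainer.backward above
    (loss / accumulation).backward()
    if (step + 1) % accumulation == 0:
        optimizer.step()
        optimizer.zero_grad()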

View file

@@ -17,7 +17,7 @@ seed = 0
 device = "cpu"
 dtype = "float32"
 batch_size = 4
-gradient_accumulation = 4
+gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

 [optimizer]

View file

@@ -12,7 +12,7 @@ verbose = false
 duration = "100:epoch"
 seed = 0
 batch_size = 4
-gradient_accumulation = 4
+gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

 [optimizer]

View file

@@ -206,17 +206,7 @@ def training_clock() -> TrainingClock:
     return TrainingClock(
         batch_size=10,
         training_duration=Epoch(5),
-        gradient_accumulation=1,
-        lr_scheduler_interval=Epoch(1),
-    )
-
-
-def test_small_dataset_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=10,
-            training_duration=Epoch(5),
-            gradient_accumulation=1,
-            lr_scheduler_interval=Epoch(1),
-        )
+        gradient_accumulation=Step(1),
+        lr_scheduler_interval=Epoch(1),
+    )
@@ -226,7 +216,7 @@ def test_zero_batch_size_error():
         TrainingClock(
             batch_size=0,
             training_duration=Epoch(5),
-            gradient_accumulation=1,
+            gradient_accumulation=Step(1),
             lr_scheduler_interval=Epoch(1),
         )
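A rough usage sketch of the updated clock, mirroring the fixture arguments above. The constructor keywords, counter attribute, and property come from the diffs in this commit; driving the counter by direct assignment is only for illustration, not how the test suite exercises it.

from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Epoch, Step

clock = TrainingClock(
    batch_size=10,
    training_duration=Epoch(5),
    gradient_accumulation=Step(2),
    lr_scheduler_interval=Epoch(1),
)

# After two minibatches, the clock should report an optimizer step.
clock.num_minibatches_processed = 2
assert clock.is_optimizer_step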