mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
remove dataset length
This commit is contained in:
parent
b497b27cd3
commit
de8334b6fc
|
@ -1,9 +1,8 @@
|
|||
import time
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
|
@ -23,7 +22,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
self,
|
||||
batch_size: int,
|
||||
training_duration: TimeValue,
|
||||
gradient_accumulation: int,
|
||||
gradient_accumulation: Step,
|
||||
lr_scheduler_interval: TimeValue,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
|
@ -75,7 +74,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
|
||||
@property
|
||||
def is_optimizer_step(self) -> bool:
|
||||
return self.num_minibatches_processed == self.gradient_accumulation
|
||||
return self.num_minibatches_processed == self.gradient_accumulation.number
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
|
@ -94,41 +93,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
logger.info(message)
|
||||
|
||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.reset()
|
||||
trainer.clock.start_timer()
|
||||
self.log(f"Starting training for {self.training_duration}.")
|
||||
self.reset()
|
||||
self.start_timer()
|
||||
|
||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.stop_timer()
|
||||
self.stop_timer()
|
||||
self.log(
|
||||
(
|
||||
"Training took: "
|
||||
f"{trainer.clock.time_elapsed} seconds, "
|
||||
f"{trainer.clock.iteration} iterations, "
|
||||
f"{trainer.clock.epoch} epochs, "
|
||||
f"{trainer.clock.step} steps."
|
||||
f"{self.time_elapsed} seconds, "
|
||||
f"{self.iteration} iterations, "
|
||||
f"{self.epoch} epochs, "
|
||||
f"{self.step} steps."
|
||||
)
|
||||
)
|
||||
|
||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Epoch {trainer.clock.epoch} started.")
|
||||
self.log(f"Epoch {self.epoch} started.")
|
||||
|
||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Epoch {trainer.clock.epoch} ended.")
|
||||
trainer.clock.epoch += 1
|
||||
trainer.clock.num_batches_processed = 0
|
||||
self.log(f"Epoch {self.epoch} ended.")
|
||||
self.epoch += 1
|
||||
self.num_batches_processed = 0
|
||||
|
||||
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
if self.num_minibatches_processed == 0:
|
||||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
||||
self.log(f"Step {trainer.clock.step} started.")
|
||||
self.log(f"Iteration {self.iteration} started.")
|
||||
self.log(f"Step {self.step} started.")
|
||||
|
||||
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Step {trainer.clock.step} ended.")
|
||||
trainer.clock.step += 1
|
||||
trainer.clock.num_batches_processed += 1
|
||||
trainer.clock.num_minibatches_processed += 1
|
||||
self.log(f"Step {self.step} ended.")
|
||||
self.step += 1
|
||||
self.num_batches_processed += 1
|
||||
self.num_minibatches_processed += 1
|
||||
|
||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Iteration {trainer.clock.iteration} ended.")
|
||||
trainer.clock.iteration += 1
|
||||
trainer.clock.num_minibatches_processed = 0
|
||||
self.log(f"Iteration {self.iteration} ended.")
|
||||
self.iteration += 1
|
||||
self.num_minibatches_processed = 0
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
|||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
|
||||
# PyTorch optimizer parameters type
|
||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||
|
@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
|
|||
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||
seed: int = 0
|
||||
batch_size: int = 1
|
||||
gradient_accumulation: int = 1
|
||||
gradient_accumulation: Step = Step(1)
|
||||
gradient_clipping_max_norm: float | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("duration", mode="before")
|
||||
@field_validator("duration", "gradient_accumulation", mode="before")
|
||||
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
|
|
@ -282,7 +282,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
warmup_scheduler_steps = (
|
||||
config.warmup.number
|
||||
if isinstance(config.warmup, Step)
|
||||
else config.warmup.number * self.clock.gradient_accumulation
|
||||
else config.warmup.number * self.clock.gradient_accumulation.number
|
||||
)
|
||||
if warmup_scheduler_steps > 0:
|
||||
lr_scheduler = WarmupScheduler(
|
||||
|
@ -350,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
def backward(self) -> None:
|
||||
"""Backward pass on the loss."""
|
||||
self._call_callbacks(event_name="on_backward_begin")
|
||||
scaled_loss = self.loss / self.config.training.gradient_accumulation
|
||||
scaled_loss = self.loss / self.config.training.gradient_accumulation.number
|
||||
backward(tensors=scaled_loss)
|
||||
self._call_callbacks(event_name="on_backward_end")
|
||||
if self.clock.is_optimizer_step:
|
||||
|
|
|
@ -17,7 +17,7 @@ seed = 0
|
|||
device = "cpu"
|
||||
dtype = "float32"
|
||||
batch_size = 4
|
||||
gradient_accumulation = 4
|
||||
gradient_accumulation = "4:step"
|
||||
gradient_clipping_max_norm = 1.0
|
||||
|
||||
[optimizer]
|
||||
|
|
|
@ -12,7 +12,7 @@ verbose = false
|
|||
duration = "100:epoch"
|
||||
seed = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = 4
|
||||
gradient_accumulation = "4:step"
|
||||
gradient_clipping_max_norm = 1.0
|
||||
|
||||
[optimizer]
|
||||
|
|
|
@ -206,27 +206,17 @@ def training_clock() -> TrainingClock:
|
|||
return TrainingClock(
|
||||
batch_size=10,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=1,
|
||||
gradient_accumulation=Step(1),
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
||||
def test_small_dataset_error():
|
||||
with pytest.raises(AssertionError):
|
||||
TrainingClock(
|
||||
batch_size=10,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=1,
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
||||
def test_zero_batch_size_error():
|
||||
with pytest.raises(AssertionError):
|
||||
TrainingClock(
|
||||
batch_size=0,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=1,
|
||||
gradient_accumulation=Step(1),
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue