mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
remove dataset length
This commit is contained in:
parent
b497b27cd3
commit
de8334b6fc
|
@ -1,9 +1,8 @@
|
||||||
import time
|
import time
|
||||||
from functools import cached_property
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
|
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
@ -23,7 +22,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
self,
|
self,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
training_duration: TimeValue,
|
training_duration: TimeValue,
|
||||||
gradient_accumulation: int,
|
gradient_accumulation: Step,
|
||||||
lr_scheduler_interval: TimeValue,
|
lr_scheduler_interval: TimeValue,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -75,7 +74,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_optimizer_step(self) -> bool:
|
def is_optimizer_step(self) -> bool:
|
||||||
return self.num_minibatches_processed == self.gradient_accumulation
|
return self.num_minibatches_processed == self.gradient_accumulation.number
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
|
@ -94,41 +93,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
trainer.clock.reset()
|
self.log(f"Starting training for {self.training_duration}.")
|
||||||
trainer.clock.start_timer()
|
self.reset()
|
||||||
|
self.start_timer()
|
||||||
|
|
||||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
trainer.clock.stop_timer()
|
self.stop_timer()
|
||||||
self.log(
|
self.log(
|
||||||
(
|
(
|
||||||
"Training took: "
|
"Training took: "
|
||||||
f"{trainer.clock.time_elapsed} seconds, "
|
f"{self.time_elapsed} seconds, "
|
||||||
f"{trainer.clock.iteration} iterations, "
|
f"{self.iteration} iterations, "
|
||||||
f"{trainer.clock.epoch} epochs, "
|
f"{self.epoch} epochs, "
|
||||||
f"{trainer.clock.step} steps."
|
f"{self.step} steps."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Epoch {trainer.clock.epoch} started.")
|
self.log(f"Epoch {self.epoch} started.")
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Epoch {trainer.clock.epoch} ended.")
|
self.log(f"Epoch {self.epoch} ended.")
|
||||||
trainer.clock.epoch += 1
|
self.epoch += 1
|
||||||
trainer.clock.num_batches_processed = 0
|
self.num_batches_processed = 0
|
||||||
|
|
||||||
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
if self.num_minibatches_processed == 0:
|
if self.num_minibatches_processed == 0:
|
||||||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
self.log(f"Iteration {self.iteration} started.")
|
||||||
self.log(f"Step {trainer.clock.step} started.")
|
self.log(f"Step {self.step} started.")
|
||||||
|
|
||||||
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Step {trainer.clock.step} ended.")
|
self.log(f"Step {self.step} ended.")
|
||||||
trainer.clock.step += 1
|
self.step += 1
|
||||||
trainer.clock.num_batches_processed += 1
|
self.num_batches_processed += 1
|
||||||
trainer.clock.num_minibatches_processed += 1
|
self.num_minibatches_processed += 1
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Iteration {trainer.clock.iteration} ended.")
|
self.log(f"Iteration {self.iteration} ended.")
|
||||||
trainer.clock.iteration += 1
|
self.iteration += 1
|
||||||
trainer.clock.num_minibatches_processed = 0
|
self.num_minibatches_processed = 0
|
||||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
||||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||||
|
|
||||||
from refiners.training_utils.clock import ClockConfig
|
from refiners.training_utils.clock import ClockConfig
|
||||||
from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
|
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
|
||||||
|
|
||||||
# PyTorch optimizer parameters type
|
# PyTorch optimizer parameters type
|
||||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||||
|
@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
|
||||||
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
|
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
gradient_accumulation: int = 1
|
gradient_accumulation: Step = Step(1)
|
||||||
gradient_clipping_max_norm: float | None = None
|
gradient_clipping_max_norm: float | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
@field_validator("duration", mode="before")
|
@field_validator("duration", "gradient_accumulation", mode="before")
|
||||||
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||||
return parse_number_unit_field(value)
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
|
@ -282,7 +282,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
warmup_scheduler_steps = (
|
warmup_scheduler_steps = (
|
||||||
config.warmup.number
|
config.warmup.number
|
||||||
if isinstance(config.warmup, Step)
|
if isinstance(config.warmup, Step)
|
||||||
else config.warmup.number * self.clock.gradient_accumulation
|
else config.warmup.number * self.clock.gradient_accumulation.number
|
||||||
)
|
)
|
||||||
if warmup_scheduler_steps > 0:
|
if warmup_scheduler_steps > 0:
|
||||||
lr_scheduler = WarmupScheduler(
|
lr_scheduler = WarmupScheduler(
|
||||||
|
@ -350,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
def backward(self) -> None:
|
def backward(self) -> None:
|
||||||
"""Backward pass on the loss."""
|
"""Backward pass on the loss."""
|
||||||
self._call_callbacks(event_name="on_backward_begin")
|
self._call_callbacks(event_name="on_backward_begin")
|
||||||
scaled_loss = self.loss / self.config.training.gradient_accumulation
|
scaled_loss = self.loss / self.config.training.gradient_accumulation.number
|
||||||
backward(tensors=scaled_loss)
|
backward(tensors=scaled_loss)
|
||||||
self._call_callbacks(event_name="on_backward_end")
|
self._call_callbacks(event_name="on_backward_end")
|
||||||
if self.clock.is_optimizer_step:
|
if self.clock.is_optimizer_step:
|
||||||
|
|
|
@ -17,7 +17,7 @@ seed = 0
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
dtype = "float32"
|
dtype = "float32"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
gradient_accumulation = 4
|
gradient_accumulation = "4:step"
|
||||||
gradient_clipping_max_norm = 1.0
|
gradient_clipping_max_norm = 1.0
|
||||||
|
|
||||||
[optimizer]
|
[optimizer]
|
||||||
|
|
|
@ -12,7 +12,7 @@ verbose = false
|
||||||
duration = "100:epoch"
|
duration = "100:epoch"
|
||||||
seed = 0
|
seed = 0
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
gradient_accumulation = 4
|
gradient_accumulation = "4:step"
|
||||||
gradient_clipping_max_norm = 1.0
|
gradient_clipping_max_norm = 1.0
|
||||||
|
|
||||||
[optimizer]
|
[optimizer]
|
||||||
|
|
|
@ -206,17 +206,7 @@ def training_clock() -> TrainingClock:
|
||||||
return TrainingClock(
|
return TrainingClock(
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
training_duration=Epoch(5),
|
training_duration=Epoch(5),
|
||||||
gradient_accumulation=1,
|
gradient_accumulation=Step(1),
|
||||||
lr_scheduler_interval=Epoch(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_small_dataset_error():
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
TrainingClock(
|
|
||||||
batch_size=10,
|
|
||||||
training_duration=Epoch(5),
|
|
||||||
gradient_accumulation=1,
|
|
||||||
lr_scheduler_interval=Epoch(1),
|
lr_scheduler_interval=Epoch(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -226,7 +216,7 @@ def test_zero_batch_size_error():
|
||||||
TrainingClock(
|
TrainingClock(
|
||||||
batch_size=0,
|
batch_size=0,
|
||||||
training_duration=Epoch(5),
|
training_duration=Epoch(5),
|
||||||
gradient_accumulation=1,
|
gradient_accumulation=Step(1),
|
||||||
lr_scheduler_interval=Epoch(1),
|
lr_scheduler_interval=Epoch(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue