mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
remove dataset length (bis)
This commit is contained in:
parent
1db0845db2
commit
b497b27cd3
|
@ -21,24 +21,18 @@ class ClockConfig(CallbackConfig):
|
|||
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_length: int,
|
||||
batch_size: int,
|
||||
training_duration: TimeValue,
|
||||
gradient_accumulation: TimeValue,
|
||||
gradient_accumulation: int,
|
||||
lr_scheduler_interval: TimeValue,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
assert batch_size > 0, "Batch size must be greater than 0."
|
||||
assert (
|
||||
dataset_length >= batch_size
|
||||
), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
|
||||
self.dataset_length = dataset_length
|
||||
self.batch_size = batch_size
|
||||
self.training_duration = training_duration
|
||||
self.gradient_accumulation = gradient_accumulation
|
||||
self.lr_scheduler_interval = lr_scheduler_interval
|
||||
self.verbose = verbose
|
||||
self.num_batches_per_epoch = dataset_length // batch_size
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.step = 0
|
||||
|
@ -48,43 +42,16 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
self.num_minibatches_processed = 0
|
||||
self.loss: Tensor | None = None
|
||||
|
||||
@cached_property
|
||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||
iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
|
||||
return {
|
||||
Step: 1,
|
||||
Epoch: self.num_batches_per_epoch,
|
||||
Iteration: self.gradient_accumulation.number * iteration_factor,
|
||||
}
|
||||
|
||||
def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
|
||||
return time_value.number * self.unit_to_steps[time_value.unit]
|
||||
|
||||
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||
return steps // self.unit_to_steps[unit]
|
||||
|
||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||
steps = self.convert_time_value_to_steps(time_value=time_value)
|
||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||
|
||||
@cached_property
|
||||
def num_epochs(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)
|
||||
|
||||
@cached_property
|
||||
def num_iterations(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)
|
||||
|
||||
@cached_property
|
||||
def num_steps(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=Step)
|
||||
|
||||
@cached_property
|
||||
def num_step_per_iteration(self) -> int:
|
||||
return self.convert_time_value_to_steps(self.gradient_accumulation)
|
||||
|
||||
def is_due(self, interval: TimeValue) -> bool:
|
||||
return self.step % self.convert_time_value_to_steps(interval) == 0
|
||||
match interval:
|
||||
case Step(number):
|
||||
return self.step % number == 0
|
||||
case Iteration(number):
|
||||
return self.iteration % number == 0
|
||||
case Epoch(number):
|
||||
return self.epoch % number == 0
|
||||
case _:
|
||||
raise ValueError(f"Unsupported TimeValue: {interval}")
|
||||
|
||||
def reset(self) -> None:
|
||||
self.start_time = None
|
||||
|
@ -108,11 +75,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
|
||||
@property
|
||||
def is_optimizer_step(self) -> bool:
|
||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||
return self.num_minibatches_processed == self.gradient_accumulation
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return self.step >= self.num_steps
|
||||
match self.training_duration:
|
||||
case Step(number):
|
||||
return self.step >= number
|
||||
case Iteration(number):
|
||||
return self.iteration >= number
|
||||
case Epoch(number):
|
||||
return self.epoch >= number
|
||||
case _:
|
||||
raise ValueError(f"Unsupported TimeValue: {self.training_duration}")
|
||||
|
||||
def log(self, message: str, /) -> None:
|
||||
if self.verbose:
|
||||
|
@ -120,14 +95,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
|
||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.reset()
|
||||
self.log(
|
||||
(
|
||||
"Starting training for a total of: "
|
||||
f"{trainer.clock.num_steps} steps, "
|
||||
f"{trainer.clock.num_epochs} epochs, "
|
||||
f"{trainer.clock.num_iterations} iterations."
|
||||
)
|
||||
)
|
||||
trainer.clock.start_timer()
|
||||
|
||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
|||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
|
||||
# PyTorch optimizer parameters type
|
||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||
|
@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
|
|||
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||
seed: int = 0
|
||||
batch_size: int = 1
|
||||
gradient_accumulation: Step | Epoch = Step(1)
|
||||
gradient_accumulation: int = 1
|
||||
gradient_clipping_max_norm: float | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("duration", "gradient_accumulation", mode="before")
|
||||
@field_validator("duration", mode="before")
|
||||
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from refiners.training_utils.callback import (
|
|||
)
|
||||
from refiners.training_utils.clock import ClockConfig, TrainingClock
|
||||
from refiners.training_utils.common import (
|
||||
Step,
|
||||
compute_grad_norm,
|
||||
count_learnable_parameters,
|
||||
human_readable_number,
|
||||
|
@ -150,7 +151,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
@register_callback()
|
||||
def clock(self, config: ClockConfig) -> TrainingClock:
|
||||
return TrainingClock(
|
||||
dataset_length=self.dataset_length,
|
||||
batch_size=self.config.training.batch_size,
|
||||
training_duration=self.config.training.duration,
|
||||
gradient_accumulation=self.config.training.gradient_accumulation,
|
||||
|
@ -279,7 +279,11 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
case _:
|
||||
raise ValueError(f"Unknown scheduler type: {config.type}")
|
||||
|
||||
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
|
||||
warmup_scheduler_steps = (
|
||||
config.warmup.number
|
||||
if isinstance(config.warmup, Step)
|
||||
else config.warmup.number * self.clock.gradient_accumulation
|
||||
)
|
||||
if warmup_scheduler_steps > 0:
|
||||
lr_scheduler = WarmupScheduler(
|
||||
optimizer=self.optimizer,
|
||||
|
@ -346,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
def backward(self) -> None:
|
||||
"""Backward pass on the loss."""
|
||||
self._call_callbacks(event_name="on_backward_begin")
|
||||
scaled_loss = self.loss / self.clock.num_step_per_iteration
|
||||
scaled_loss = self.loss / self.config.training.gradient_accumulation
|
||||
backward(tensors=scaled_loss)
|
||||
self._call_callbacks(event_name="on_backward_end")
|
||||
if self.clock.is_optimizer_step:
|
||||
|
|
|
@ -17,7 +17,7 @@ seed = 0
|
|||
device = "cpu"
|
||||
dtype = "float32"
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
gradient_accumulation = 4
|
||||
gradient_clipping_max_norm = 1.0
|
||||
|
||||
[optimizer]
|
||||
|
|
|
@ -12,7 +12,7 @@ verbose = false
|
|||
duration = "100:epoch"
|
||||
seed = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
gradient_accumulation = 4
|
||||
gradient_clipping_max_norm = 1.0
|
||||
|
||||
[optimizer]
|
||||
|
|
|
@ -204,10 +204,9 @@ def test_human_readable_number() -> None:
|
|||
@pytest.fixture
|
||||
def training_clock() -> TrainingClock:
|
||||
return TrainingClock(
|
||||
dataset_length=100,
|
||||
batch_size=10,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=Epoch(1),
|
||||
gradient_accumulation=1,
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
@ -215,10 +214,9 @@ def training_clock() -> TrainingClock:
|
|||
def test_small_dataset_error():
|
||||
with pytest.raises(AssertionError):
|
||||
TrainingClock(
|
||||
dataset_length=3,
|
||||
batch_size=10,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=Epoch(1),
|
||||
gradient_accumulation=1,
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
@ -226,35 +224,13 @@ def test_small_dataset_error():
|
|||
def test_zero_batch_size_error():
|
||||
with pytest.raises(AssertionError):
|
||||
TrainingClock(
|
||||
dataset_length=3,
|
||||
batch_size=0,
|
||||
training_duration=Epoch(5),
|
||||
gradient_accumulation=Epoch(1),
|
||||
gradient_accumulation=1,
|
||||
lr_scheduler_interval=Epoch(1),
|
||||
)
|
||||
|
||||
|
||||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
|
||||
assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
|
||||
assert training_clock.convert_time_value_to_steps(Step(1)) == 1
|
||||
assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10
|
||||
|
||||
|
||||
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
|
||||
assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
|
||||
assert training_clock.convert_steps_to_time_unit(1, Step) == 1
|
||||
assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1
|
||||
|
||||
|
||||
def test_clock_properties(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.num_batches_per_epoch == 10
|
||||
assert training_clock.num_epochs == 5
|
||||
assert training_clock.num_iterations == 5
|
||||
assert training_clock.num_steps == 50
|
||||
|
||||
|
||||
def test_timer_functionality(training_clock: TrainingClock) -> None:
|
||||
training_clock.start_timer()
|
||||
assert training_clock.start_time is not None
|
||||
|
@ -275,17 +251,12 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
|||
clock = mock_trainer.clock
|
||||
config = mock_trainer.config
|
||||
|
||||
assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
|
||||
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
||||
|
||||
assert mock_trainer.step_counter == 0
|
||||
assert clock.epoch == 0
|
||||
|
||||
mock_trainer.train()
|
||||
|
||||
assert clock.epoch == config.training.duration.number
|
||||
assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
|
||||
|
||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue