remove dataset length (bis)

limiteinductive 2024-04-24 16:50:27 +00:00 committed by Benjamin Trom
parent 1db0845db2
commit b497b27cd3
6 changed files with 35 additions and 93 deletions


@@ -21,24 +21,18 @@ class ClockConfig(CallbackConfig):
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
def __init__(
self,
dataset_length: int,
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: TimeValue,
gradient_accumulation: int,
lr_scheduler_interval: TimeValue,
verbose: bool = True,
) -> None:
assert batch_size > 0, "Batch size must be greater than 0."
assert (
dataset_length >= batch_size
), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
self.dataset_length = dataset_length
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.lr_scheduler_interval = lr_scheduler_interval
self.verbose = verbose
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
self.step = 0
@@ -48,43 +42,16 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.num_minibatches_processed = 0
self.loss: Tensor | None = None
@cached_property
def unit_to_steps(self) -> dict[TimeUnit, int]:
iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
return {
Step: 1,
Epoch: self.num_batches_per_epoch,
Iteration: self.gradient_accumulation.number * iteration_factor,
}
def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
return time_value.number * self.unit_to_steps[time_value.unit]
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
return steps // self.unit_to_steps[unit]
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
steps = self.convert_time_value_to_steps(time_value=time_value)
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
@cached_property
def num_epochs(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)
@cached_property
def num_iterations(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)
@cached_property
def num_steps(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=Step)
@cached_property
def num_step_per_iteration(self) -> int:
return self.convert_time_value_to_steps(self.gradient_accumulation)
def is_due(self, interval: TimeValue) -> bool:
return self.step % self.convert_time_value_to_steps(interval) == 0
match interval:
case Step(number):
return self.step % number == 0
case Iteration(number):
return self.iteration % number == 0
case Epoch(number):
return self.epoch % number == 0
case _:
raise ValueError(f"Unsupported TimeValue: {interval}")
def reset(self) -> None:
self.start_time = None
@@ -108,11 +75,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
return self.num_minibatches_processed == self.gradient_accumulation
@property
def done(self) -> bool:
return self.step >= self.num_steps
match self.training_duration:
case Step(number):
return self.step >= number
case Iteration(number):
return self.iteration >= number
case Epoch(number):
return self.epoch >= number
case _:
raise ValueError(f"Unsupported TimeValue: {self.training_duration}")
def log(self, message: str, /) -> None:
if self.verbose:
@@ -120,14 +95,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.reset()
self.log(
(
"Starting training for a total of: "
f"{trainer.clock.num_steps} steps, "
f"{trainer.clock.num_epochs} epochs, "
f"{trainer.clock.num_iterations} iterations."
)
)
trainer.clock.start_timer()
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
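
Note (not part of the commit): a minimal sketch of the reworked clock with illustrative values; the constructor arguments and the Step/Epoch/Iteration classes are taken from this diff.

from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Iteration, Step

# The clock no longer needs the dataset length, and gradient_accumulation is a plain int.
clock = TrainingClock(
    batch_size=4,
    training_duration=Iteration(100),
    gradient_accumulation=4,
    lr_scheduler_interval=Iteration(1),
)
clock.step = 8
assert clock.is_due(Step(4))  # 8 % 4 == 0
# `done` now matches on the duration's own unit: with Iteration(100),
# training stops once clock.iteration >= 100.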


@@ -11,7 +11,7 @@ from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field
# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
seed: int = 0
batch_size: int = 1
gradient_accumulation: Step | Epoch = Step(1)
gradient_accumulation: int = 1
gradient_clipping_max_norm: float | None = None
model_config = ConfigDict(extra="forbid")
@field_validator("duration", "gradient_accumulation", mode="before")
@field_validator("duration", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
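
Note (not part of the commit): a sketch of the resulting config shape; the import path is assumed and the values are illustrative. Old "4:step" strings for gradient_accumulation must be migrated to a plain integer, while duration keeps its "number:unit" parsing.

from refiners.training_utils.config import TrainingConfig  # import path assumed

config = TrainingConfig(duration="100:epoch", batch_size=4, gradient_accumulation=4)
assert config.gradient_accumulation == 4  # plain int, no TimeValue parsing for this field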


@@ -30,6 +30,7 @@ from refiners.training_utils.callback import (
)
from refiners.training_utils.clock import ClockConfig, TrainingClock
from refiners.training_utils.common import (
Step,
compute_grad_norm,
count_learnable_parameters,
human_readable_number,
@@ -150,7 +151,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@register_callback()
def clock(self, config: ClockConfig) -> TrainingClock:
return TrainingClock(
dataset_length=self.dataset_length,
batch_size=self.config.training.batch_size,
training_duration=self.config.training.duration,
gradient_accumulation=self.config.training.gradient_accumulation,
@@ -279,7 +279,11 @@ class Trainer(Generic[ConfigType, Batch], ABC):
case _:
raise ValueError(f"Unknown scheduler type: {config.type}")
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
warmup_scheduler_steps = (
config.warmup.number
if isinstance(config.warmup, Step)
else config.warmup.number * self.clock.gradient_accumulation
)
if warmup_scheduler_steps > 0:
lr_scheduler = WarmupScheduler(
optimizer=self.optimizer,
@@ -346,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
def backward(self) -> None:
"""Backward pass on the loss."""
self._call_callbacks(event_name="on_backward_begin")
scaled_loss = self.loss / self.clock.num_step_per_iteration
scaled_loss = self.loss / self.config.training.gradient_accumulation
backward(tensors=scaled_loss)
self._call_callbacks(event_name="on_backward_end")
if self.clock.is_optimizer_step:
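
Note (not part of the commit): illustrative arithmetic for the two changes above, with assumed values.

from refiners.training_utils.common import Iteration, Step

gradient_accumulation = 4
warmup = Iteration(10)
# Warmup expressed in iterations is expanded by the accumulation factor; Step values are used as-is.
warmup_scheduler_steps = (
    warmup.number if isinstance(warmup, Step) else warmup.number * gradient_accumulation
)
assert warmup_scheduler_steps == 40
# Loss scaling in backward(): each of the 4 accumulated minibatch losses is divided by 4,
# so the summed gradients approximate the gradient of the mean loss over the full batch.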


@@ -17,7 +17,7 @@ seed = 0
device = "cpu"
dtype = "float32"
batch_size = 4
gradient_accumulation = "4:step"
gradient_accumulation = 4
gradient_clipping_max_norm = 1.0
[optimizer]


@@ -12,7 +12,7 @@ verbose = false
duration = "100:epoch"
seed = 0
batch_size = 4
gradient_accumulation = "4:step"
gradient_accumulation = 4
gradient_clipping_max_norm = 1.0
[optimizer]


@@ -204,10 +204,9 @@ def test_human_readable_number() -> None:
@pytest.fixture
def training_clock() -> TrainingClock:
return TrainingClock(
dataset_length=100,
batch_size=10,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
gradient_accumulation=1,
lr_scheduler_interval=Epoch(1),
)
@@ -215,10 +214,9 @@ def training_clock() -> TrainingClock:
def test_small_dataset_error():
with pytest.raises(AssertionError):
TrainingClock(
dataset_length=3,
batch_size=10,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
gradient_accumulation=1,
lr_scheduler_interval=Epoch(1),
)
@@ -226,35 +224,13 @@ def test_small_dataset_error():
def test_zero_batch_size_error():
with pytest.raises(AssertionError):
TrainingClock(
dataset_length=3,
batch_size=0,
training_duration=Epoch(5),
gradient_accumulation=Epoch(1),
gradient_accumulation=1,
lr_scheduler_interval=Epoch(1),
)
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
assert training_clock.convert_time_value_to_steps(Step(1)) == 1
assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
assert training_clock.convert_steps_to_time_unit(1, Step) == 1
assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1
def test_clock_properties(training_clock: TrainingClock) -> None:
assert training_clock.num_batches_per_epoch == 10
assert training_clock.num_epochs == 5
assert training_clock.num_iterations == 5
assert training_clock.num_steps == 50
def test_timer_functionality(training_clock: TrainingClock) -> None:
training_clock.start_timer()
assert training_clock.start_time is not None
@@ -275,17 +251,12 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
clock = mock_trainer.clock
config = mock_trainer.config
assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
assert mock_trainer.step_counter == 0
assert clock.epoch == 0
mock_trainer.train()
assert clock.epoch == config.training.duration.number
assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
assert mock_trainer.step_counter == mock_trainer.clock.step
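
Worked example (illustrative, not from the commit): with batch_size = 4 and gradient_accumulation = 4, the optimizer steps once every 4 minibatches, so a duration of Iteration(100) ends training after 400 minibatch steps regardless of dataset length, whereas Epoch(100) still depends on how many batches the data loader yields per epoch.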