diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 40cad58..9c2ee17 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -102,18 +102,18 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
 class WarmupScheduler(LRScheduler):
     _step_count: int  # defined by LRScheduler
 
-    def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_steps: int = 0) -> None:
-        self.warmup_steps = warmup_steps
+    def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_scheduler_steps: int = 0) -> None:
+        self.warmup_scheduler_steps = warmup_scheduler_steps
         self.scheduler = scheduler
         super().__init__(optimizer=optimizer)
 
     def get_lr(self) -> list[float] | float:  # type: ignore
-        if self._step_count < self.warmup_steps:
-            return [base_lr * self._step_count / self.warmup_steps for base_lr in self.base_lrs]
+        if self._step_count <= self.warmup_scheduler_steps:
+            return [base_lr * self._step_count / self.warmup_scheduler_steps for base_lr in self.base_lrs]
         return self.scheduler.get_lr()
 
     def step(self, epoch: int | None = None) -> None:
-        if self._step_count < self.warmup_steps:
+        if self._step_count < self.warmup_scheduler_steps:
             super().step()
         else:
             self.scheduler.step(epoch=epoch)
@@ -342,19 +342,19 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def lr_scheduler(self) -> LRScheduler:
         config = self.config.scheduler
-        step_size = self.clock.convert_time_unit_to_steps(
-            number=config.update_interval["number"], unit=config.update_interval["unit"]
-        )
+        scheduler_step_size = config.update_interval["number"]
 
         match config.scheduler_type:
             case SchedulerType.CONSTANT_LR:
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
             case SchedulerType.STEP_LR:
-                lr_scheduler = StepLR(optimizer=self.optimizer, step_size=step_size, gamma=config.gamma)
+                lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
             case SchedulerType.EXPONENTIAL_LR:
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case SchedulerType.COSINE_ANNEALING_LR:
-                lr_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=step_size, eta_min=config.eta_min)
+                lr_scheduler = CosineAnnealingLR(
+                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                )
             case SchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
                     LRScheduler,
@@ -372,12 +372,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
             case SchedulerType.ONE_CYCLE_LR:
-                lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=config.max_lr, total_steps=step_size)
+                lr_scheduler = OneCycleLR(
+                    optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
+                )
             case SchedulerType.MULTIPLICATIVE_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
                 lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
             case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
-                lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=step_size)
+                lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
             case SchedulerType.CYCLIC_LR:
                 lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
             case SchedulerType.MULTI_STEP_LR:
@@ -385,12 +387,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             case _:
                 raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
 
-        warmup_steps = self.clock.convert_time_unit_to_steps(number=config.warmup["number"], unit=config.warmup["unit"])
-        if warmup_steps > 0:
+        warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
+        if warmup_scheduler_steps > 0:
             lr_scheduler = WarmupScheduler(
                 optimizer=self.optimizer,
                 scheduler=lr_scheduler,
-                warmup_steps=warmup_steps,
+                warmup_scheduler_steps=warmup_scheduler_steps,
             )
 
         return lr_scheduler
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index 6064f49..20c2f3a 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -29,4 +29,4 @@ save_interval = "10:epoch"
 
 [wandb]
 mode = "disabled"
-project = "mock_project"
\ No newline at end of file
+project = "mock_project"
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 38e45d1..3d7be73 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -6,6 +6,7 @@ from warnings import warn
 import pytest
 import torch
 from torch import Tensor, nn
+from torch.optim import SGD
 from torch.utils.data import Dataset
 
 from refiners.fluxion import layers as fl
@@ -14,6 +15,7 @@ from refiners.training_utils.config import BaseConfig, TimeUnit
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
+    WarmupScheduler,
     count_learnable_parameters,
     human_readable_number,
 )
@@ -183,3 +185,24 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
     assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
 
     assert mock_trainer.step_counter == mock_trainer.clock.step
+
+
+@pytest.fixture
+def warmup_scheduler():
+    optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, 1)
+    return WarmupScheduler(optimizer, scheduler, warmup_scheduler_steps=100)
+
+
+def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
+    optimizer = warmup_scheduler.optimizer
+    for group in optimizer.param_groups:
+        assert group["lr"] == 1e-3
+
+
+def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
+    for _ in range(102):
+        warmup_scheduler.step()
+    optimizer = warmup_scheduler.optimizer
+    for group in optimizer.param_groups:
+        assert group["lr"] == 0.1
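
Note (not part of the patch): a minimal, standalone sketch of how the patched WarmupScheduler behaves, mirroring the new tests above. It assumes refiners is installed and exports WarmupScheduler from refiners.training_utils.trainer with the renamed warmup_scheduler_steps argument introduced here; the optimizer, the wrapped ConstantLR, and the warmup length of 10 are illustrative choices, not part of the library's defaults.

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR

from refiners.training_utils.trainer import WarmupScheduler

# Single dummy parameter; optimizer.step() is never called because we only
# want to inspect the learning-rate schedule, not train anything.
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
wrapped = ConstantLR(optimizer, factor=1.0)  # keeps the LR flat after warmup
scheduler = WarmupScheduler(optimizer, wrapped, warmup_scheduler_steps=10)

for step in range(12):
    # During warmup the base LR (0.1) is scaled by _step_count / warmup_scheduler_steps,
    # so the printed LR ramps linearly from 0.01 toward 0.1 (up to floating-point
    # noise), then stays at 0.1 once the wrapped ConstantLR takes over.
    print(step, [group["lr"] for group in optimizer.param_groups])
    scheduler.step()

This is the same ramp the new tests exercise at warmup_scheduler_steps=100: an initial LR of 0.1 / 100 = 1e-3 right after construction (test_initial_lr), and 0.1 once the warmup has completed (test_warmup_lr).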