Fix warmup steps calculation when gradient_accumulation is used

Author: limiteinductive (2024-01-25 11:36:58 +01:00), committed by Benjamin Trom
parent 12a5439fc4
commit 0ee2d5e075
3 changed files with 41 additions and 16 deletions

@@ -102,18 +102,18 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
 class WarmupScheduler(LRScheduler):
     _step_count: int  # defined by LRScheduler
 
-    def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_steps: int = 0) -> None:
-        self.warmup_steps = warmup_steps
+    def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_scheduler_steps: int = 0) -> None:
+        self.warmup_scheduler_steps = warmup_scheduler_steps
         self.scheduler = scheduler
         super().__init__(optimizer=optimizer)
 
     def get_lr(self) -> list[float] | float:  # type: ignore
-        if self._step_count < self.warmup_steps:
-            return [base_lr * self._step_count / self.warmup_steps for base_lr in self.base_lrs]
+        if self._step_count <= self.warmup_scheduler_steps:
+            return [base_lr * self._step_count / self.warmup_scheduler_steps for base_lr in self.base_lrs]
         return self.scheduler.get_lr()
 
     def step(self, epoch: int | None = None) -> None:
-        if self._step_count < self.warmup_steps:
+        if self._step_count < self.warmup_scheduler_steps:
             super().step()
         else:
             self.scheduler.step(epoch=epoch)
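
For intuition, a minimal sketch in plain Python of the linear ramp that the renamed warmup_scheduler_steps drives in get_lr above. The values are illustrative (they match the test added at the bottom of this diff), and with the new <= bound the warmup branch itself produces the full base learning rate on the last warmup step:

base_lr = 0.1                 # illustrative, same as the new test fixture
warmup_scheduler_steps = 100  # illustrative, same as the new test fixture

def warmup_lr(step_count: int) -> float:
    # Linear ramp used while step_count <= warmup_scheduler_steps.
    return base_lr * step_count / warmup_scheduler_steps

print(warmup_lr(1))    # 0.001 -- PyTorch's LRScheduler starts _step_count at 1
print(warmup_lr(50))   # 0.05
print(warmup_lr(100))  # 0.1   -- full base lr reached by the warmup branch, thanks to <=
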
@@ -342,19 +342,19 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def lr_scheduler(self) -> LRScheduler:
         config = self.config.scheduler
-        step_size = self.clock.convert_time_unit_to_steps(
-            number=config.update_interval["number"], unit=config.update_interval["unit"]
-        )
+        scheduler_step_size = config.update_interval["number"]
 
         match config.scheduler_type:
             case SchedulerType.CONSTANT_LR:
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
             case SchedulerType.STEP_LR:
-                lr_scheduler = StepLR(optimizer=self.optimizer, step_size=step_size, gamma=config.gamma)
+                lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
             case SchedulerType.EXPONENTIAL_LR:
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case SchedulerType.COSINE_ANNEALING_LR:
-                lr_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=step_size, eta_min=config.eta_min)
+                lr_scheduler = CosineAnnealingLR(
+                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                )
             case SchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
                     LRScheduler,
@@ -372,12 +372,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
             case SchedulerType.ONE_CYCLE_LR:
-                lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=config.max_lr, total_steps=step_size)
+                lr_scheduler = OneCycleLR(
+                    optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
+                )
             case SchedulerType.MULTIPLICATIVE_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
                 lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
             case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
-                lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=step_size)
+                lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
             case SchedulerType.CYCLIC_LR:
                 lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
             case SchedulerType.MULTI_STEP_LR:
@@ -385,12 +387,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             case _:
                 raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
 
-        warmup_steps = self.clock.convert_time_unit_to_steps(number=config.warmup["number"], unit=config.warmup["unit"])
-        if warmup_steps > 0:
+        warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
+        if warmup_scheduler_steps > 0:
             lr_scheduler = WarmupScheduler(
                 optimizer=self.optimizer,
                 scheduler=lr_scheduler,
-                warmup_steps=warmup_steps,
+                warmup_scheduler_steps=warmup_scheduler_steps,
             )
 
         return lr_scheduler
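
Outside the Trainer, the wrapping this property builds can be reproduced by hand roughly as follows. This is a sketch, not repository code: the optimizer, T_max, eta_min, and warmup length are illustrative, and only the WarmupScheduler import path is taken from the test file further down.

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from refiners.training_utils.trainer import WarmupScheduler

optimizer = SGD([nn.Parameter(torch.randn(2, 2))], lr=0.1)

# COSINE_ANNEALING_LR branch: T_max plays the role of scheduler_step_size,
# i.e. config.update_interval["number"] in the hunk above.
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=1000, eta_min=0.0)

# The wrapper is applied only when warmup_scheduler_steps resolves to a positive value.
lr_scheduler = WarmupScheduler(
    optimizer=optimizer,
    scheduler=lr_scheduler,
    warmup_scheduler_steps=100,
)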

@@ -29,4 +29,4 @@ save_interval = "10:epoch"
 [wandb]
 mode = "disabled"
 project = "mock_project"

@@ -6,6 +6,7 @@ from warnings import warn
 import pytest
 import torch
 from torch import Tensor, nn
+from torch.optim import SGD
 from torch.utils.data import Dataset
 
 from refiners.fluxion import layers as fl
@@ -14,6 +15,7 @@ from refiners.training_utils.config import BaseConfig, TimeUnit
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
+    WarmupScheduler,
     count_learnable_parameters,
     human_readable_number,
 )
@@ -183,3 +185,24 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
     assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
     assert mock_trainer.step_counter == mock_trainer.clock.step
+
+
+@pytest.fixture
+def warmup_scheduler():
+    optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, 1)
+    return WarmupScheduler(optimizer, scheduler, warmup_scheduler_steps=100)
+
+
+def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
+    optimizer = warmup_scheduler.optimizer
+    for group in optimizer.param_groups:
+        assert group["lr"] == 1e-3
+
+
+def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
+    for _ in range(102):
+        warmup_scheduler.step()
+    optimizer = warmup_scheduler.optimizer
+    for group in optimizer.param_groups:
+        assert group["lr"] == 0.1
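
As a sanity check on the values these tests assert, here is a sketch (not part of the test suite) that mirrors the fixture and traces the learning rate across the warmup boundary. It assumes only the WarmupScheduler shown in the first hunk and standard PyTorch behavior, namely that a freshly constructed LRScheduler has _step_count == 1.

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR

from refiners.training_utils.trainer import WarmupScheduler

optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
scheduler = ConstantLR(optimizer, 1)
warmup = WarmupScheduler(optimizer, scheduler, warmup_scheduler_steps=100)

# test_initial_lr: _step_count is 1 right after construction, so lr == 0.1 * 1 / 100.
print(optimizer.param_groups[0]["lr"])  # 0.001

# test_warmup_lr: 102 calls to step() comfortably pass the 100-step ramp; the wrapped
# ConstantLR with factor 1 then keeps the learning rate at the base value.
for _ in range(102):
    warmup.step()
print(optimizer.param_groups[0]["lr"])  # 0.1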