mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Fix warmup steps calculation when gradient_accumulation is used
This commit is contained in:
parent
12a5439fc4
commit
0ee2d5e075
|
@ -102,18 +102,18 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
|
||||||
class WarmupScheduler(LRScheduler):
|
class WarmupScheduler(LRScheduler):
|
||||||
_step_count: int # defined by LRScheduler
|
_step_count: int # defined by LRScheduler
|
||||||
|
|
||||||
def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_steps: int = 0) -> None:
|
def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_scheduler_steps: int = 0) -> None:
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_scheduler_steps = warmup_scheduler_steps
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
super().__init__(optimizer=optimizer)
|
super().__init__(optimizer=optimizer)
|
||||||
|
|
||||||
def get_lr(self) -> list[float] | float: # type: ignore
|
def get_lr(self) -> list[float] | float: # type: ignore
|
||||||
if self._step_count < self.warmup_steps:
|
if self._step_count <= self.warmup_scheduler_steps:
|
||||||
return [base_lr * self._step_count / self.warmup_steps for base_lr in self.base_lrs]
|
return [base_lr * self._step_count / self.warmup_scheduler_steps for base_lr in self.base_lrs]
|
||||||
return self.scheduler.get_lr()
|
return self.scheduler.get_lr()
|
||||||
|
|
||||||
def step(self, epoch: int | None = None) -> None:
|
def step(self, epoch: int | None = None) -> None:
|
||||||
if self._step_count < self.warmup_steps:
|
if self._step_count < self.warmup_scheduler_steps:
|
||||||
super().step()
|
super().step()
|
||||||
else:
|
else:
|
||||||
self.scheduler.step(epoch=epoch)
|
self.scheduler.step(epoch=epoch)
|
||||||
|
@ -342,19 +342,19 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
@cached_property
|
@cached_property
|
||||||
def lr_scheduler(self) -> LRScheduler:
|
def lr_scheduler(self) -> LRScheduler:
|
||||||
config = self.config.scheduler
|
config = self.config.scheduler
|
||||||
step_size = self.clock.convert_time_unit_to_steps(
|
scheduler_step_size = config.update_interval["number"]
|
||||||
number=config.update_interval["number"], unit=config.update_interval["unit"]
|
|
||||||
)
|
|
||||||
|
|
||||||
match config.scheduler_type:
|
match config.scheduler_type:
|
||||||
case SchedulerType.CONSTANT_LR:
|
case SchedulerType.CONSTANT_LR:
|
||||||
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
|
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
|
||||||
case SchedulerType.STEP_LR:
|
case SchedulerType.STEP_LR:
|
||||||
lr_scheduler = StepLR(optimizer=self.optimizer, step_size=step_size, gamma=config.gamma)
|
lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
|
||||||
case SchedulerType.EXPONENTIAL_LR:
|
case SchedulerType.EXPONENTIAL_LR:
|
||||||
lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
|
lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
|
||||||
case SchedulerType.COSINE_ANNEALING_LR:
|
case SchedulerType.COSINE_ANNEALING_LR:
|
||||||
lr_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=step_size, eta_min=config.eta_min)
|
lr_scheduler = CosineAnnealingLR(
|
||||||
|
optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
|
||||||
|
)
|
||||||
case SchedulerType.REDUCE_LR_ON_PLATEAU:
|
case SchedulerType.REDUCE_LR_ON_PLATEAU:
|
||||||
lr_scheduler = cast(
|
lr_scheduler = cast(
|
||||||
LRScheduler,
|
LRScheduler,
|
||||||
|
@ -372,12 +372,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
|
assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
|
||||||
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
||||||
case SchedulerType.ONE_CYCLE_LR:
|
case SchedulerType.ONE_CYCLE_LR:
|
||||||
lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=config.max_lr, total_steps=step_size)
|
lr_scheduler = OneCycleLR(
|
||||||
|
optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
|
||||||
|
)
|
||||||
case SchedulerType.MULTIPLICATIVE_LR:
|
case SchedulerType.MULTIPLICATIVE_LR:
|
||||||
assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
|
assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
|
||||||
lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
||||||
case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
|
case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
|
||||||
lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=step_size)
|
lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
|
||||||
case SchedulerType.CYCLIC_LR:
|
case SchedulerType.CYCLIC_LR:
|
||||||
lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
|
lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
|
||||||
case SchedulerType.MULTI_STEP_LR:
|
case SchedulerType.MULTI_STEP_LR:
|
||||||
|
@ -385,12 +387,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
|
raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
|
||||||
|
|
||||||
warmup_steps = self.clock.convert_time_unit_to_steps(number=config.warmup["number"], unit=config.warmup["unit"])
|
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
|
||||||
if warmup_steps > 0:
|
if warmup_scheduler_steps > 0:
|
||||||
lr_scheduler = WarmupScheduler(
|
lr_scheduler = WarmupScheduler(
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
scheduler=lr_scheduler,
|
scheduler=lr_scheduler,
|
||||||
warmup_steps=warmup_steps,
|
warmup_scheduler_steps=warmup_scheduler_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
return lr_scheduler
|
return lr_scheduler
|
||||||
|
|
|
@ -6,6 +6,7 @@ from warnings import warn
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from torch.optim import SGD
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
|
@ -14,6 +15,7 @@ from refiners.training_utils.config import BaseConfig, TimeUnit
|
||||||
from refiners.training_utils.trainer import (
|
from refiners.training_utils.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingClock,
|
TrainingClock,
|
||||||
|
WarmupScheduler,
|
||||||
count_learnable_parameters,
|
count_learnable_parameters,
|
||||||
human_readable_number,
|
human_readable_number,
|
||||||
)
|
)
|
||||||
|
@ -183,3 +185,24 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
|
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
|
||||||
|
|
||||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def warmup_scheduler():
|
||||||
|
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, 1)
|
||||||
|
return WarmupScheduler(optimizer, scheduler, warmup_scheduler_steps=100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
|
||||||
|
optimizer = warmup_scheduler.optimizer
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
assert group["lr"] == 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
||||||
|
for _ in range(102):
|
||||||
|
warmup_scheduler.step()
|
||||||
|
optimizer = warmup_scheduler.optimizer
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
assert group["lr"] == 0.1
|
||||||
|
|
Loading…
Reference in a new issue