rename Scheduler -> LRScheduler

Authored by limiteinductive on 2024-02-15 09:48:12 +00:00; committed by Benjamin Trom
parent 684303230d
commit 432e32f94f
5 changed files with 28 additions and 28 deletions


@@ -8,11 +8,11 @@ from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.config import (
     BaseConfig,
+    LRSchedulerConfig,
+    LRSchedulerType,
     ModelConfig,
     OptimizerConfig,
     Optimizers,
-    SchedulerConfig,
-    SchedulerType,
     TrainingConfig,
 )
 from refiners.training_utils.gradient_clipping import GradientClippingConfig
@@ -48,11 +48,11 @@ __all__ = [
     "CallbackConfig",
     "WandbMixin",
     "WandbConfig",
-    "SchedulerConfig",
+    "LRSchedulerConfig",
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
     "GradientClippingConfig",
     "Optimizers",
-    "SchedulerType",
+    "LRSchedulerType",
 ]
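
Since the old names are dropped from __all__, downstream code that imports them from the package root has to switch to the new ones. A minimal sketch of the updated import (the comments noting the old names are mine, not part of the commit):

from refiners.training_utils import (
    LRSchedulerConfig,  # previously SchedulerConfig
    LRSchedulerType,    # previously SchedulerType
)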


@@ -46,7 +46,7 @@ class Optimizers(str, Enum):
     Prodigy = "Prodigy"
 
 
-class SchedulerType(str, Enum):
+class LRSchedulerType(str, Enum):
     STEP_LR = "StepLR"
     EXPONENTIAL_LR = "ExponentialLR"
     REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
@@ -61,8 +61,8 @@ class SchedulerType(str, Enum):
     DEFAULT = "ConstantLR"
 
 
-class SchedulerConfig(BaseModel):
-    scheduler_type: SchedulerType = SchedulerType.DEFAULT
+class LRSchedulerConfig(BaseModel):
+    type: LRSchedulerType = LRSchedulerType.DEFAULT
     update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
     gamma: float = 0.1
@@ -165,7 +165,7 @@ T = TypeVar("T", bound="BaseConfig")
 class BaseConfig(BaseModel):
     training: TrainingConfig
     optimizer: OptimizerConfig
-    scheduler: SchedulerConfig
+    lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
     gradient_clipping: GradientClippingConfig = GradientClippingConfig()
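
For illustration only, a small sketch of how the renamed config class could be constructed directly; it uses just the fields visible in the hunks above and leaves everything else at its defaults:

from refiners.training_utils.config import LRSchedulerConfig, LRSchedulerType

# `type` replaces the old `scheduler_type` field; gamma is consumed by the
# StepLR / ExponentialLR / MultiStepLR branches shown in the trainer below.
lr_scheduler_config = LRSchedulerConfig(
    type=LRSchedulerType.STEP_LR,
    gamma=0.5,
)

# On a full BaseConfig the attribute is now `lr_scheduler` rather than `scheduler`,
# e.g. config.lr_scheduler.update_interval.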


@@ -36,7 +36,7 @@ from refiners.training_utils.common import (
     human_readable_number,
     scoped_seed,
 )
-from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
+from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
 from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
@@ -154,7 +154,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             training_duration=self.config.training.duration,
             evaluation_interval=self.config.training.evaluation_interval,
             gradient_accumulation=self.config.training.gradient_accumulation,
-            lr_scheduler_interval=self.config.scheduler.update_interval,
+            lr_scheduler_interval=self.config.lr_scheduler.update_interval,
             verbose=config.verbose,
         )
@@ -237,21 +237,21 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def lr_scheduler(self) -> LRScheduler:
-        config = self.config.scheduler
+        config = self.config.lr_scheduler
         scheduler_step_size = config.update_interval["number"]
-        match config.scheduler_type:
-            case SchedulerType.CONSTANT_LR:
+        match config.type:
+            case LRSchedulerType.CONSTANT_LR:
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
-            case SchedulerType.STEP_LR:
+            case LRSchedulerType.STEP_LR:
                 lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
-            case SchedulerType.EXPONENTIAL_LR:
+            case LRSchedulerType.EXPONENTIAL_LR:
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
-            case SchedulerType.COSINE_ANNEALING_LR:
+            case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
                     optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
                 )
-            case SchedulerType.REDUCE_LR_ON_PLATEAU:
+            case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
                     LRScheduler,
                     ReduceLROnPlateau(
@@ -264,24 +264,24 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                         min_lr=config.min_lr,
                     ),
                 )
-            case SchedulerType.LAMBDA_LR:
+            case LRSchedulerType.LAMBDA_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.ONE_CYCLE_LR:
+            case LRSchedulerType.ONE_CYCLE_LR:
                 lr_scheduler = OneCycleLR(
                     optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
                 )
-            case SchedulerType.MULTIPLICATIVE_LR:
+            case LRSchedulerType.MULTIPLICATIVE_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
                 lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
+            case LRSchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
                 lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
-            case SchedulerType.CYCLIC_LR:
+            case LRSchedulerType.CYCLIC_LR:
                 lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
-            case SchedulerType.MULTI_STEP_LR:
+            case LRSchedulerType.MULTI_STEP_LR:
                 lr_scheduler = MultiStepLR(optimizer=self.optimizer, milestones=config.milestones, gamma=config.gamma)
             case _:
-                raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
+                raise ValueError(f"Unknown scheduler type: {config.type}")
 
         warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
         if warmup_scheduler_steps > 0:
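
To make the dispatch easier to follow outside the Trainer, here is a standalone sketch reduced to two branches. build_lr_scheduler is a hypothetical free function, not something this commit adds, but the torch scheduler signatures and the config fields match the code above:

import torch
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler, StepLR

from refiners.training_utils.config import LRSchedulerConfig, LRSchedulerType


def build_lr_scheduler(optimizer: torch.optim.Optimizer, config: LRSchedulerConfig) -> LRScheduler:
    # update_interval carries the step size, exactly as in Trainer.lr_scheduler above.
    step_size = config.update_interval["number"]
    match config.type:
        case LRSchedulerType.STEP_LR:
            return StepLR(optimizer=optimizer, step_size=step_size, gamma=config.gamma)
        case LRSchedulerType.EXPONENTIAL_LR:
            return ExponentialLR(optimizer=optimizer, gamma=config.gamma)
        case _:
            raise ValueError(f"Unknown scheduler type: {config.type}")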


@@ -22,7 +22,7 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1
 
-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
 warmup = "20:step"


@@ -23,6 +23,6 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1
 
-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
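
Existing TOML files therefore need the same rename, from [scheduler] with scheduler_type to [lr_scheduler] with type. A hypothetical migration check (not part of the commit) that loads the renamed table and validates it, assuming the config class parses time strings such as "1:step" the way the test configs above imply:

import tomllib  # stdlib in Python 3.11+

from refiners.training_utils.config import LRSchedulerConfig

# "trainer.toml" is a placeholder path for one of the renamed config files.
with open("trainer.toml", "rb") as f:
    data = tomllib.load(f)

lr_scheduler = LRSchedulerConfig(**data["lr_scheduler"])
print(lr_scheduler.type, lr_scheduler.update_interval, lr_scheduler.warmup)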