diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py
index 8850b34..8913ea1 100644
--- a/src/refiners/training_utils/__init__.py
+++ b/src/refiners/training_utils/__init__.py
@@ -8,11 +8,11 @@ from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.config import (
     BaseConfig,
+    LRSchedulerConfig,
+    LRSchedulerType,
     ModelConfig,
     OptimizerConfig,
     Optimizers,
-    SchedulerConfig,
-    SchedulerType,
     TrainingConfig,
 )
 from refiners.training_utils.gradient_clipping import GradientClippingConfig
@@ -48,11 +48,11 @@ __all__ = [
     "CallbackConfig",
     "WandbMixin",
     "WandbConfig",
-    "SchedulerConfig",
+    "LRSchedulerConfig",
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
     "GradientClippingConfig",
     "Optimizers",
-    "SchedulerType",
+    "LRSchedulerType",
 ]
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 91e41e2..4098ea0 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -46,7 +46,7 @@ class Optimizers(str, Enum):
     Prodigy = "Prodigy"
 
 
-class SchedulerType(str, Enum):
+class LRSchedulerType(str, Enum):
     STEP_LR = "StepLR"
     EXPONENTIAL_LR = "ExponentialLR"
     REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
@@ -61,8 +61,8 @@ class SchedulerType(str, Enum):
     DEFAULT = "ConstantLR"
 
 
-class SchedulerConfig(BaseModel):
-    scheduler_type: SchedulerType = SchedulerType.DEFAULT
+class LRSchedulerConfig(BaseModel):
+    type: LRSchedulerType = LRSchedulerType.DEFAULT
     update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
     gamma: float = 0.1
@@ -165,7 +165,7 @@ T = TypeVar("T", bound="BaseConfig")
 class BaseConfig(BaseModel):
     training: TrainingConfig
     optimizer: OptimizerConfig
-    scheduler: SchedulerConfig
+    lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
     gradient_clipping: GradientClippingConfig = GradientClippingConfig()
 
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index eb6c25a..c68a08e 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -36,7 +36,7 @@ from refiners.training_utils.common import (
     human_readable_number,
     scoped_seed,
 )
-from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
+from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
 from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
 
 
@@ -154,7 +154,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             training_duration=self.config.training.duration,
             evaluation_interval=self.config.training.evaluation_interval,
             gradient_accumulation=self.config.training.gradient_accumulation,
-            lr_scheduler_interval=self.config.scheduler.update_interval,
+            lr_scheduler_interval=self.config.lr_scheduler.update_interval,
             verbose=config.verbose,
         )
 
@@ -237,21 +237,21 @@ class Trainer(Generic[ConfigType, Batch], ABC):
 
     @cached_property
     def lr_scheduler(self) -> LRScheduler:
-        config = self.config.scheduler
+        config = self.config.lr_scheduler
         scheduler_step_size = config.update_interval["number"]
 
-        match config.scheduler_type:
-            case SchedulerType.CONSTANT_LR:
+        match config.type:
+            case LRSchedulerType.CONSTANT_LR:
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
-            case SchedulerType.STEP_LR:
+            case LRSchedulerType.STEP_LR:
                 lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
-            case SchedulerType.EXPONENTIAL_LR:
+            case LRSchedulerType.EXPONENTIAL_LR:
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
-            case SchedulerType.COSINE_ANNEALING_LR:
+            case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
                     optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
                 )
-            case SchedulerType.REDUCE_LR_ON_PLATEAU:
+            case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
                     LRScheduler,
                     ReduceLROnPlateau(
@@ -264,24 +264,24 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                         min_lr=config.min_lr,
                     ),
                 )
-            case SchedulerType.LAMBDA_LR:
+            case LRSchedulerType.LAMBDA_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.ONE_CYCLE_LR:
+            case LRSchedulerType.ONE_CYCLE_LR:
                 lr_scheduler = OneCycleLR(
                     optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
                 )
-            case SchedulerType.MULTIPLICATIVE_LR:
+            case LRSchedulerType.MULTIPLICATIVE_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
                 lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
+            case LRSchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
                 lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
-            case SchedulerType.CYCLIC_LR:
+            case LRSchedulerType.CYCLIC_LR:
                 lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
-            case SchedulerType.MULTI_STEP_LR:
+            case LRSchedulerType.MULTI_STEP_LR:
                 lr_scheduler = MultiStepLR(optimizer=self.optimizer, milestones=config.milestones, gamma=config.gamma)
             case _:
-                raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
+                raise ValueError(f"Unknown scheduler type: {config.type}")
 
         warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
         if warmup_scheduler_steps > 0:
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index a5017a9..bfc8b2d 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -22,7 +22,7 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1
 
-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
 warmup = "20:step"
diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml
index 474551f..9980a6f 100644
--- a/tests/training_utils/mock_config_2_models.toml
+++ b/tests/training_utils/mock_config_2_models.toml
@@ -23,6 +23,6 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1
 
-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
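For downstream users the rename is one-to-one: the TOML section `[scheduler]` with key `scheduler_type` becomes `[lr_scheduler]` with key `type`, and code reads `config.lr_scheduler` instead of `config.scheduler`. Below is a minimal sketch of the renamed Python surface, assuming only the exports shown in this diff; the chosen scheduler type is illustrative, not part of this change.

```python
from refiners.training_utils import LRSchedulerConfig, LRSchedulerType

# `[lr_scheduler]` / `type` in TOML map onto `BaseConfig.lr_scheduler` and
# `LRSchedulerConfig.type`; unset fields keep their defaults (ConstantLR,
# an update_interval of 1 iteration, no warmup).
lr_scheduler_config = LRSchedulerConfig(type=LRSchedulerType.COSINE_ANNEALING_LR)
print(lr_scheduler_config.type.value)  # "CosineAnnealingLR"
```

Dropping the redundant `scheduler_` prefix from the field reads naturally now that the table itself is named `lr_scheduler`.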