Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-09 23:12:02 +00:00
rename Scheduler -> LRScheduler
Commit 432e32f94f (parent 684303230d)

@@ -8,11 +8,11 @@ from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.config import (
     BaseConfig,
+    LRSchedulerConfig,
+    LRSchedulerType,
     ModelConfig,
     OptimizerConfig,
     Optimizers,
-    SchedulerConfig,
-    SchedulerType,
     TrainingConfig,
 )
 from refiners.training_utils.gradient_clipping import GradientClippingConfig
@@ -48,11 +48,11 @@ __all__ = [
     "CallbackConfig",
     "WandbMixin",
     "WandbConfig",
-    "SchedulerConfig",
+    "LRSchedulerConfig",
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
     "GradientClippingConfig",
     "Optimizers",
-    "SchedulerType",
+    "LRSchedulerType",
 ]

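Downstream code that imported the old names must follow the rename; a minimal sketch of the change, assuming only the public re-exports shown in the hunk above:

# Before this commit (removed from the public API):
#   from refiners.training_utils import SchedulerConfig, SchedulerType
# After the rename:
from refiners.training_utils import LRSchedulerConfig, LRSchedulerType
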
@@ -46,7 +46,7 @@ class Optimizers(str, Enum):
     Prodigy = "Prodigy"


-class SchedulerType(str, Enum):
+class LRSchedulerType(str, Enum):
     STEP_LR = "StepLR"
     EXPONENTIAL_LR = "ExponentialLR"
     REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
@@ -61,8 +61,8 @@ class SchedulerType(str, Enum):
     DEFAULT = "ConstantLR"


-class SchedulerConfig(BaseModel):
-    scheduler_type: SchedulerType = SchedulerType.DEFAULT
+class LRSchedulerConfig(BaseModel):
+    type: LRSchedulerType = LRSchedulerType.DEFAULT
     update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
     gamma: float = 0.1
@@ -165,7 +165,7 @@ T = TypeVar("T", bound="BaseConfig")
 class BaseConfig(BaseModel):
     training: TrainingConfig
     optimizer: OptimizerConfig
-    scheduler: SchedulerConfig
+    lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
     gradient_clipping: GradientClippingConfig = GradientClippingConfig()

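For reference, the renamed config can be built directly from the fields shown above (type, update_interval, warmup, gamma); a minimal sketch that assumes the remaining scheduler-specific fields keep their defaults:

from refiners.training_utils.config import LRSchedulerConfig, LRSchedulerType

# Exponential decay: the learning rate is multiplied by `gamma` at every
# update interval (by default every iteration, per the defaults above).
schedule = LRSchedulerConfig(type=LRSchedulerType.EXPONENTIAL_LR, gamma=0.97)
assert schedule.type == LRSchedulerType.EXPONENTIAL_LR

In a full BaseConfig this object is now assigned to the lr_scheduler field rather than scheduler.
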
@@ -36,7 +36,7 @@ from refiners.training_utils.common import (
     human_readable_number,
     scoped_seed,
 )
-from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
+from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
 from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig


@@ -154,7 +154,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             training_duration=self.config.training.duration,
             evaluation_interval=self.config.training.evaluation_interval,
             gradient_accumulation=self.config.training.gradient_accumulation,
-            lr_scheduler_interval=self.config.scheduler.update_interval,
+            lr_scheduler_interval=self.config.lr_scheduler.update_interval,
             verbose=config.verbose,
         )

@@ -237,21 +237,21 @@ class Trainer(Generic[ConfigType, Batch], ABC):

     @cached_property
     def lr_scheduler(self) -> LRScheduler:
-        config = self.config.scheduler
+        config = self.config.lr_scheduler
         scheduler_step_size = config.update_interval["number"]

-        match config.scheduler_type:
-            case SchedulerType.CONSTANT_LR:
+        match config.type:
+            case LRSchedulerType.CONSTANT_LR:
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
-            case SchedulerType.STEP_LR:
+            case LRSchedulerType.STEP_LR:
                 lr_scheduler = StepLR(optimizer=self.optimizer, step_size=scheduler_step_size, gamma=config.gamma)
-            case SchedulerType.EXPONENTIAL_LR:
+            case LRSchedulerType.EXPONENTIAL_LR:
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
-            case SchedulerType.COSINE_ANNEALING_LR:
+            case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
                     optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
                 )
-            case SchedulerType.REDUCE_LR_ON_PLATEAU:
+            case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
                     LRScheduler,
                     ReduceLROnPlateau(
@@ -264,24 +264,24 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                         min_lr=config.min_lr,
                     ),
                 )
-            case SchedulerType.LAMBDA_LR:
+            case LRSchedulerType.LAMBDA_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
                 lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.ONE_CYCLE_LR:
+            case LRSchedulerType.ONE_CYCLE_LR:
                 lr_scheduler = OneCycleLR(
                     optimizer=self.optimizer, max_lr=config.max_lr, total_steps=scheduler_step_size
                 )
-            case SchedulerType.MULTIPLICATIVE_LR:
+            case LRSchedulerType.MULTIPLICATIVE_LR:
                 assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
                 lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
-            case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
+            case LRSchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
                 lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=scheduler_step_size)
-            case SchedulerType.CYCLIC_LR:
+            case LRSchedulerType.CYCLIC_LR:
                 lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
-            case SchedulerType.MULTI_STEP_LR:
+            case LRSchedulerType.MULTI_STEP_LR:
                 lr_scheduler = MultiStepLR(optimizer=self.optimizer, milestones=config.milestones, gamma=config.gamma)
             case _:
-                raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
+                raise ValueError(f"Unknown scheduler type: {config.type}")

         warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
         if warmup_scheduler_steps > 0:

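The branches above delegate to the standard PyTorch schedulers; a standalone sketch (using torch directly, independent of the Trainer) of what the STEP_LR branch produces:

import torch
from torch.optim.lr_scheduler import StepLR

# A dummy parameter and optimizer, only to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)

# StepLR multiplies the learning rate by `gamma` every `step_size` steps,
# which is what the LRSchedulerType.STEP_LR case builds from the config.
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

for step in range(4):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # drops 10x every 2 steps: 1.0, 0.1, 0.1, 0.01
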
@@ -22,7 +22,7 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1

-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
 warmup = "20:step"
@@ -23,6 +23,6 @@ evaluation_seed = 1
 optimizer = "SGD"
 learning_rate = 1

-[scheduler]
-scheduler_type = "ConstantLR"
+[lr_scheduler]
+type = "ConstantLR"
 update_interval = "1:step"
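The same rename applies to these test configs: the section becomes [lr_scheduler] and the key scheduler_type becomes type. A quick sketch checking the renamed keys with the standard library parser (assumes Python 3.11+ for tomllib):

import tomllib

config_text = """
[lr_scheduler]
type = "ConstantLR"
update_interval = "1:step"
warmup = "20:step"
"""

data = tomllib.loads(config_text)
assert "lr_scheduler" in data                          # section renamed from [scheduler]
assert data["lr_scheduler"]["type"] == "ConstantLR"    # key renamed from scheduler_type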