mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add @register_model and @register_callback decorators
Refactor ClockTrainer to include Callback
This commit is contained in:
parent
f541badcb3
commit
d6546c9026
|
@ -4,8 +4,20 @@ from importlib.metadata import requires
|
|||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.config import (
|
||||
BaseConfig,
|
||||
ModelConfig,
|
||||
OptimizerConfig,
|
||||
Optimizers,
|
||||
SchedulerConfig,
|
||||
SchedulerType,
|
||||
TrainingConfig,
|
||||
)
|
||||
from refiners.training_utils.gradient_clipping import GradientClippingConfig
|
||||
from refiners.training_utils.trainer import Trainer, register_callback, register_model
|
||||
from refiners.training_utils.wandb import WandbConfig, WandbMixin
|
||||
|
||||
refiners_requires = requires("refiners")
|
||||
assert refiners_requires is not None
|
||||
|
@ -29,4 +41,18 @@ for dep in refiners_requires:
|
|||
__all__ = [
|
||||
"Trainer",
|
||||
"BaseConfig",
|
||||
"ModelConfig",
|
||||
"register_callback",
|
||||
"register_model",
|
||||
"Callback",
|
||||
"CallbackConfig",
|
||||
"WandbMixin",
|
||||
"WandbConfig",
|
||||
"SchedulerConfig",
|
||||
"OptimizerConfig",
|
||||
"TrainingConfig",
|
||||
"ClockConfig",
|
||||
"GradientClippingConfig",
|
||||
"Optimizers",
|
||||
"SchedulerType",
|
||||
]
|
||||
|
|
|
@ -1,43 +1,22 @@
|
|||
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
from torch import tensor
|
||||
from torch.nn import Parameter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
|
||||
__all__ = [
|
||||
"Callback",
|
||||
"GradientNormClipping",
|
||||
"GradientValueClipping",
|
||||
"ClockCallback",
|
||||
]
|
||||
T = TypeVar("T", bound="Trainer[BaseConfig, Any]")
|
||||
|
||||
|
||||
def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
||||
class CallbackConfig(BaseModel):
|
||||
"""
|
||||
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
||||
Base configuration for a callback.
|
||||
|
||||
For your callback to be properly configured, you should inherit from this class and add your own configuration.
|
||||
"""
|
||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to clip."
|
||||
clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
||||
for gradient in gradients:
|
||||
gradient.mul_(other=clip_coefficient) # type: ignore
|
||||
|
||||
|
||||
def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None:
|
||||
"""
|
||||
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
||||
"""
|
||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to clip."
|
||||
for gradient in gradients:
|
||||
gradient.clamp_(min=-clip_value, max=clip_value)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class Callback(Generic[T]):
|
||||
|
@ -97,71 +76,3 @@ class Callback(Generic[T]):
|
|||
|
||||
def on_checkpoint_save(self, trainer: T) -> None:
|
||||
...
|
||||
|
||||
|
||||
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.reset()
|
||||
logger.info(
|
||||
(
|
||||
"Starting training for a total of: "
|
||||
f"{trainer.clock.num_steps} steps, "
|
||||
f"{trainer.clock.num_epochs} epochs, "
|
||||
f"{trainer.clock.num_iterations} iterations."
|
||||
)
|
||||
)
|
||||
trainer.clock.start_timer()
|
||||
|
||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.stop_timer()
|
||||
logger.info(
|
||||
(
|
||||
"Training took: "
|
||||
f"{trainer.clock.time_elapsed} seconds, "
|
||||
f"{trainer.clock.iteration} iterations, "
|
||||
f"{trainer.clock.epoch} epochs, "
|
||||
f"{trainer.clock.step} steps."
|
||||
)
|
||||
)
|
||||
|
||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
logger.info(f"Epoch {trainer.clock.epoch} started.")
|
||||
|
||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.epoch += 1
|
||||
trainer.clock.num_batches_processed = 0
|
||||
|
||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
logger.info(f"Step {trainer.clock.step} started.")
|
||||
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.step += 1
|
||||
trainer.clock.num_batches_processed += 1
|
||||
trainer.clock.num_minibatches_processed += 1
|
||||
|
||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
logger.info(f"Iteration {trainer.clock.iteration} ended.")
|
||||
trainer.clock.iteration += 1
|
||||
trainer.clock.num_minibatches_processed = 0
|
||||
|
||||
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
logger.info("Evaluation started.")
|
||||
|
||||
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
logger.info("Evaluation ended.")
|
||||
|
||||
|
||||
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
clip_norm = trainer.config.training.clip_grad_norm
|
||||
if clip_norm is not None:
|
||||
clip_gradient_norm(
|
||||
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
||||
)
|
||||
|
||||
|
||||
class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
clip_value = trainer.config.training.clip_grad_value
|
||||
if clip_value is not None:
|
||||
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
||||
|
|
193
src/refiners/training_utils/clock.py
Normal file
193
src/refiners/training_utils/clock.py
Normal file
|
@ -0,0 +1,193 @@
|
|||
import time
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.common import TimeUnit, TimeValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ClockConfig(CallbackConfig):
|
||||
verbose: bool = True
|
||||
|
||||
|
||||
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_length: int,
|
||||
batch_size: int,
|
||||
training_duration: TimeValue,
|
||||
gradient_accumulation: TimeValue,
|
||||
evaluation_interval: TimeValue,
|
||||
lr_scheduler_interval: TimeValue,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
self.dataset_length = dataset_length
|
||||
self.batch_size = batch_size
|
||||
self.training_duration = training_duration
|
||||
self.gradient_accumulation = gradient_accumulation
|
||||
self.evaluation_interval = evaluation_interval
|
||||
self.lr_scheduler_interval = lr_scheduler_interval
|
||||
self.verbose = verbose
|
||||
self.num_batches_per_epoch = dataset_length // batch_size
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.iteration = 0
|
||||
self.num_batches_processed = 0
|
||||
self.num_minibatches_processed = 0
|
||||
self.loss: Tensor | None = None
|
||||
|
||||
@cached_property
|
||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
||||
return {
|
||||
TimeUnit.STEP: 1,
|
||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
||||
}
|
||||
|
||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||
return number * self.unit_to_steps[unit]
|
||||
|
||||
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||
return steps // self.unit_to_steps[unit]
|
||||
|
||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||
number, unit = time_value["number"], time_value["unit"]
|
||||
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||
|
||||
@cached_property
|
||||
def num_epochs(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
||||
|
||||
@cached_property
|
||||
def num_iterations(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
||||
|
||||
@cached_property
|
||||
def num_steps(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
||||
|
||||
@cached_property
|
||||
def num_step_per_iteration(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def num_step_per_evaluation(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.iteration = 0
|
||||
self.num_batches_processed = 0
|
||||
self.num_minibatches_processed = 0
|
||||
|
||||
def start_timer(self) -> None:
|
||||
self.start_time = time.time()
|
||||
|
||||
def stop_timer(self) -> None:
|
||||
self.end_time = time.time()
|
||||
|
||||
@property
|
||||
def time_elapsed(self) -> int:
|
||||
assert self.start_time is not None, "Timer has not been started yet."
|
||||
return int(time.time() - self.start_time)
|
||||
|
||||
@cached_property
|
||||
def evaluation_interval_steps(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def lr_scheduler_interval_steps(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
||||
)
|
||||
|
||||
@property
|
||||
def is_optimizer_step(self) -> bool:
|
||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||
|
||||
@property
|
||||
def is_lr_scheduler_step(self) -> bool:
|
||||
return self.step % self.lr_scheduler_interval_steps == 0
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return self.step >= self.num_steps
|
||||
|
||||
@property
|
||||
def is_evaluation_step(self) -> bool:
|
||||
return self.step % self.evaluation_interval_steps == 0
|
||||
|
||||
def log(self, message: str, /) -> None:
|
||||
if self.verbose:
|
||||
logger.info(message)
|
||||
|
||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.reset()
|
||||
self.log(
|
||||
(
|
||||
"Starting training for a total of: "
|
||||
f"{trainer.clock.num_steps} steps, "
|
||||
f"{trainer.clock.num_epochs} epochs, "
|
||||
f"{trainer.clock.num_iterations} iterations."
|
||||
)
|
||||
)
|
||||
trainer.clock.start_timer()
|
||||
|
||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.stop_timer()
|
||||
self.log(
|
||||
(
|
||||
"Training took: "
|
||||
f"{trainer.clock.time_elapsed} seconds, "
|
||||
f"{trainer.clock.iteration} iterations, "
|
||||
f"{trainer.clock.epoch} epochs, "
|
||||
f"{trainer.clock.step} steps."
|
||||
)
|
||||
)
|
||||
|
||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Epoch {trainer.clock.epoch} started.")
|
||||
|
||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.epoch += 1
|
||||
trainer.clock.num_batches_processed = 0
|
||||
|
||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Step {trainer.clock.step} started.")
|
||||
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.clock.step += 1
|
||||
trainer.clock.num_batches_processed += 1
|
||||
trainer.clock.num_minibatches_processed += 1
|
||||
|
||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Iteration {trainer.clock.iteration} ended.")
|
||||
trainer.clock.iteration += 1
|
||||
trainer.clock.num_minibatches_processed = 0
|
||||
|
||||
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log("Evaluation started.")
|
||||
|
||||
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log("Evaluation ended.")
|
103
src/refiners/training_utils/common.py
Normal file
103
src/refiners/training_utils/common.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import random
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch import Tensor, cuda, nn
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
|
||||
|
||||
def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
|
||||
"""
|
||||
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
|
||||
"""
|
||||
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to compute the norm."
|
||||
total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
|
||||
return total_norm # type: ignore
|
||||
|
||||
|
||||
def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
|
||||
return sum(p.numel() for p in parameters if p.requires_grad)
|
||||
|
||||
|
||||
def human_readable_number(number: int) -> str:
|
||||
float_number = float(number)
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if abs(float_number) < 1000:
|
||||
return f"{float_number:.1f}{unit}"
|
||||
float_number /= 1000
|
||||
return f"{float_number:.1f}E"
|
||||
|
||||
|
||||
def seed_everything(seed: int | None = None) -> None:
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Using random seed: {seed}")
|
||||
random.seed(a=seed)
|
||||
np.random.seed(seed=seed)
|
||||
manual_seed(seed=seed)
|
||||
cuda.manual_seed_all(seed=seed)
|
||||
|
||||
|
||||
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
||||
"""
|
||||
Decorator for setting a random seed within the scope of a function.
|
||||
|
||||
This decorator sets the random seed for Python's built-in `random` module,
|
||||
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
||||
function is executed, it restores the state of the random number generators
|
||||
to what it was before the function was called. This is useful for ensuring
|
||||
reproducibility for specific parts of the code without affecting randomness
|
||||
elsewhere.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
random_state = random.getstate()
|
||||
numpy_state = np.random.get_state()
|
||||
torch_state = torch.get_rng_state()
|
||||
cuda_torch_state = cuda.get_rng_state()
|
||||
actual_seed = seed(*args) if callable(seed) else seed
|
||||
seed_everything(seed=actual_seed)
|
||||
result = func(*args, **kwargs)
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(numpy_state)
|
||||
torch.set_rng_state(torch_state)
|
||||
cuda.set_rng_state(cuda_torch_state)
|
||||
return result
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TimeUnit(Enum):
|
||||
STEP = "step"
|
||||
EPOCH = "epoch"
|
||||
ITERATION = "iteration"
|
||||
DEFAULT = "step"
|
||||
|
||||
|
||||
class TimeValue(TypedDict):
|
||||
number: int
|
||||
unit: TimeUnit
|
||||
|
||||
|
||||
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
||||
match value:
|
||||
case str(value_str):
|
||||
number, unit = value_str.split(sep=":")
|
||||
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
||||
case int(number):
|
||||
return {"number": number, "unit": TimeUnit.DEFAULT}
|
||||
case {"number": int(number), "unit": str(unit)}:
|
||||
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value format: {value}")
|
|
@ -9,50 +9,16 @@ from prodigyopt import Prodigy # type: ignore
|
|||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
|
||||
from refiners.training_utils.gradient_clipping import GradientClippingConfig
|
||||
|
||||
# PyTorch optimizer parameters type
|
||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||
# See https://github.com/pytorch/pytorch/pull/111114
|
||||
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
__all__ = [
|
||||
"parse_number_unit_field",
|
||||
"TimeUnit",
|
||||
"TimeValue",
|
||||
"TrainingConfig",
|
||||
"OptimizerConfig",
|
||||
"Optimizers",
|
||||
]
|
||||
|
||||
|
||||
class TimeUnit(Enum):
|
||||
STEP = "step"
|
||||
EPOCH = "epoch"
|
||||
ITERATION = "iteration"
|
||||
DEFAULT = "step"
|
||||
|
||||
|
||||
class TimeValue(TypedDict):
|
||||
number: int
|
||||
unit: TimeUnit
|
||||
|
||||
|
||||
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
||||
match value:
|
||||
case str(value_str):
|
||||
number, unit = value_str.split(sep=":")
|
||||
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
||||
case int(number):
|
||||
return {"number": number, "unit": TimeUnit.DEFAULT}
|
||||
case {"number": int(number), "unit": str(unit)}:
|
||||
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported value format: {value}")
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
device: str = "cpu"
|
||||
|
@ -61,8 +27,6 @@ class TrainingConfig(BaseModel):
|
|||
seed: int = 0
|
||||
batch_size: int = 1
|
||||
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
||||
clip_grad_norm: float | None = None
|
||||
clip_grad_value: float | None = None
|
||||
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||
evaluation_seed: int = 0
|
||||
|
||||
|
@ -195,29 +159,6 @@ class ModelConfig(BaseModel):
|
|||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class GyroDropoutConfig(BaseModel):
|
||||
total_subnetworks: int = 512
|
||||
concurrent_subnetworks: int = 64
|
||||
iters_per_epoch: int = 512
|
||||
num_features_threshold: float = 5e5
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DropoutConfig(BaseModel):
|
||||
dropout_probability: float = 0.0
|
||||
gyro_dropout: GyroDropoutConfig | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def apply_dropout(self, model: fl.Chain) -> None:
|
||||
if self.dropout_probability > 0.0:
|
||||
if self.gyro_dropout is not None:
|
||||
apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump())
|
||||
else:
|
||||
apply_dropout(module=model, probability=self.dropout_probability)
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseConfig")
|
||||
|
||||
|
||||
|
@ -226,7 +167,8 @@ class BaseConfig(BaseModel):
|
|||
training: TrainingConfig
|
||||
optimizer: OptimizerConfig
|
||||
scheduler: SchedulerConfig
|
||||
dropout: DropoutConfig
|
||||
clock: ClockConfig = ClockConfig()
|
||||
gradient_clipping: GradientClippingConfig = GradientClippingConfig()
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
|
|
@ -1,200 +0,0 @@
|
|||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from torch import Tensor, cat, rand, randint
|
||||
from torch.nn import Dropout as TorchDropout
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.training_utils.callback import Callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
|
||||
|
||||
__all__ = ["Dropout", "GyroDropout", "DropoutCallback"]
|
||||
|
||||
|
||||
class Dropout(TorchDropout, fl.Module):
|
||||
def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
|
||||
super().__init__(p=probability, inplace=inplace)
|
||||
|
||||
|
||||
class GyroDropout(fl.Module):
|
||||
"""
|
||||
GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training.
|
||||
It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training.
|
||||
This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional
|
||||
dropout.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
total_subnetworks:
|
||||
The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks
|
||||
that are precomputed and stored.
|
||||
|
||||
concurrent_subnetworks:
|
||||
The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of
|
||||
masks from the precomputed set is used to dropout different portions of the input.
|
||||
|
||||
dropout_probability: float, optional (default=0.5)
|
||||
The probability that an element will be zeroed by the dropout.
|
||||
|
||||
iters_per_epoch:
|
||||
Number of iterations per epoch, used to determine how often the masks should be updated.
|
||||
|
||||
num_features_threshold:
|
||||
If the number of features in the input is greater than this threshold, dropout is skipped. This is because
|
||||
gyro dropout mask size vram usage is proportional to the number of features in the input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
total_subnetworks: int,
|
||||
concurrent_subnetworks: int,
|
||||
dropout_probability: float = 0.5,
|
||||
iters_per_epoch: int = 1,
|
||||
num_features_threshold: float = 5e5,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
iters_per_epoch >= total_subnetworks
|
||||
), "The number of iterations per epoch must be greater than the number of masks"
|
||||
self.dropout_probability = dropout_probability
|
||||
self.iters_per_epoch = iters_per_epoch
|
||||
self.total_subnetworks = total_subnetworks
|
||||
self.concurrent_subnetworks = concurrent_subnetworks
|
||||
self.scale = 1 / (1 - self.dropout_probability)
|
||||
self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks
|
||||
self.preselected_masks: Tensor | None = None
|
||||
self.dropout_mask = None
|
||||
self.training_step = 0
|
||||
self.num_features_threshold = num_features_threshold
|
||||
self.skip_high_num_features = False
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if not self.training:
|
||||
return x
|
||||
if self.skip_high_num_features:
|
||||
return self.basic_dropout(x)
|
||||
if self.training_step == 0:
|
||||
num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1]
|
||||
if num_features > self.num_features_threshold:
|
||||
self.skip_high_num_features = True
|
||||
self.basic_dropout = Dropout(probability=self.dropout_probability)
|
||||
return self.basic_dropout(x)
|
||||
self.init_masks(x=x)
|
||||
|
||||
if self.training_step % self.mask_update_interval == 0:
|
||||
self.update_dropout_mask(x=x)
|
||||
|
||||
self.training_step += 1
|
||||
|
||||
return x * self.dropout_mask * self.scale
|
||||
|
||||
def init_masks(self, x: Tensor) -> None:
|
||||
if x.dim() == 2:
|
||||
self.preselected_masks = (
|
||||
rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability
|
||||
)
|
||||
if x.dim() == 3:
|
||||
self.preselected_masks = (
|
||||
rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability
|
||||
)
|
||||
|
||||
assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions"
|
||||
self.preselected_masks = self.preselected_masks.float()
|
||||
|
||||
def update_dropout_mask(self, x: Tensor) -> None:
|
||||
assert self.preselected_masks is not None
|
||||
indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device)
|
||||
selected_masks = self.preselected_masks[indices]
|
||||
|
||||
repeat_factor = x.shape[0] // self.concurrent_subnetworks
|
||||
remaining = x.shape[0] % self.concurrent_subnetworks
|
||||
repeated_masks = [selected_masks] * repeat_factor
|
||||
if remaining > 0:
|
||||
repeated_masks.append(selected_masks[:remaining])
|
||||
final_masks = cat(tensors=repeated_masks, dim=0)
|
||||
|
||||
if x.dim() == 2:
|
||||
self.dropout_mask = final_masks
|
||||
if x.dim() == 3:
|
||||
self.dropout_mask = final_masks.expand_as(x)
|
||||
|
||||
|
||||
class DropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
||||
def __init__(self, target: fl.Linear, probability: float = 0.5):
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target, Dropout(probability=probability))
|
||||
|
||||
|
||||
class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
||||
def __init__(
|
||||
self,
|
||||
target: fl.Linear,
|
||||
probability: float = 0.5,
|
||||
total_subnetworks: int = 512,
|
||||
concurrent_subnetworks: int = 64,
|
||||
iters_per_epoch: int = 512,
|
||||
num_features_threshold: float = 5e5,
|
||||
) -> None:
|
||||
self.probability = probability
|
||||
self.total_subnetworks = total_subnetworks
|
||||
self.concurrent_subnetworks = concurrent_subnetworks
|
||||
self.iters_per_epoch = iters_per_epoch
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
target,
|
||||
GyroDropout(
|
||||
total_subnetworks=total_subnetworks,
|
||||
concurrent_subnetworks=concurrent_subnetworks,
|
||||
dropout_probability=probability,
|
||||
iters_per_epoch=iters_per_epoch,
|
||||
num_features_threshold=num_features_threshold,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
|
||||
for linear, parent in module.walk(fl.Linear):
|
||||
if not linear.weight.requires_grad:
|
||||
continue
|
||||
assert not (
|
||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||
), f"{linear} already has a dropout layer"
|
||||
DropoutAdapter(target=linear, probability=probability).inject(parent)
|
||||
|
||||
|
||||
def apply_gyro_dropout(
|
||||
module: fl.Chain,
|
||||
probability: float = 0.5,
|
||||
total_subnetworks: int = 32,
|
||||
concurrent_subnetworks: int = 16,
|
||||
iters_per_epoch: int = 32,
|
||||
) -> None:
|
||||
for linear, parent in module.walk(fl.Linear):
|
||||
if not linear.weight.requires_grad:
|
||||
continue
|
||||
assert not (
|
||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||
), f"{linear} already has a dropout layer"
|
||||
GyroDropoutAdapter(
|
||||
target=linear,
|
||||
probability=probability,
|
||||
total_subnetworks=total_subnetworks,
|
||||
concurrent_subnetworks=concurrent_subnetworks,
|
||||
iters_per_epoch=iters_per_epoch,
|
||||
).inject(parent)
|
||||
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
|
||||
|
||||
|
||||
class DropoutCallback(Callback["Trainer[ConfigType, Any]"]):
|
||||
def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None:
|
||||
dropout_config = trainer.config.dropout
|
||||
chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)]
|
||||
for model in chain_models:
|
||||
dropout_config.apply_dropout(model=model)
|
52
src/refiners/training_utils/gradient_clipping.py
Normal file
52
src/refiners/training_utils/gradient_clipping.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
|
||||
|
||||
def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
||||
"""
|
||||
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
||||
"""
|
||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to clip."
|
||||
clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
||||
for gradient in gradients:
|
||||
gradient.mul_(other=clip_coefficient) # type: ignore
|
||||
|
||||
|
||||
def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
|
||||
"""
|
||||
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
||||
"""
|
||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to clip."
|
||||
for gradient in gradients:
|
||||
gradient.clamp_(min=-clip_value, max=clip_value)
|
||||
|
||||
|
||||
class GradientClippingConfig(CallbackConfig):
|
||||
clip_grad_norm: float | None = None
|
||||
clip_grad_value: float | None = None
|
||||
|
||||
|
||||
class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def __init__(self, config: GradientClippingConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
clip_norm = self.config.clip_grad_norm
|
||||
if clip_norm is not None:
|
||||
clip_gradient_norm(
|
||||
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
||||
)
|
||||
|
||||
clip_value = self.config.clip_grad_value
|
||||
if clip_value is not None:
|
||||
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
|
@ -1,15 +1,12 @@
|
|||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, wraps
|
||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||
from typing import Any, Callable, Generic, Literal, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
|
||||
from torch import Tensor, device as Device, dtype as DType, nn
|
||||
from torch.autograd import backward
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import (
|
||||
CosineAnnealingLR,
|
||||
|
@ -27,73 +24,20 @@ from torch.optim.lr_scheduler import (
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.training_utils.callback import (
|
||||
Callback,
|
||||
ClockCallback,
|
||||
GradientNormClipping,
|
||||
GradientValueClipping,
|
||||
CallbackConfig,
|
||||
)
|
||||
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
|
||||
from refiners.training_utils.dropout import DropoutCallback
|
||||
|
||||
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
|
||||
|
||||
|
||||
def count_learnable_parameters(parameters: Iterable[Parameter]) -> int:
|
||||
return sum(p.numel() for p in parameters if p.requires_grad)
|
||||
|
||||
|
||||
def human_readable_number(number: int) -> str:
|
||||
float_number = float(number)
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if abs(float_number) < 1000:
|
||||
return f"{float_number:.1f}{unit}"
|
||||
float_number /= 1000
|
||||
return f"{float_number:.1f}E"
|
||||
|
||||
|
||||
def seed_everything(seed: int | None = None) -> None:
|
||||
if seed is None:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Using random seed: {seed}")
|
||||
random.seed(a=seed)
|
||||
np.random.seed(seed=seed)
|
||||
manual_seed(seed=seed)
|
||||
cuda.manual_seed_all(seed=seed)
|
||||
|
||||
|
||||
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
||||
"""
|
||||
Decorator for setting a random seed within the scope of a function.
|
||||
|
||||
This decorator sets the random seed for Python's built-in `random` module,
|
||||
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
||||
function is executed, it restores the state of the random number generators
|
||||
to what it was before the function was called. This is useful for ensuring
|
||||
reproducibility for specific parts of the code without affecting randomness
|
||||
elsewhere.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
random_state = random.getstate()
|
||||
numpy_state = np.random.get_state()
|
||||
torch_state = get_rng_state()
|
||||
cuda_torch_state = cuda.get_rng_state()
|
||||
actual_seed = seed(*args) if callable(seed) else seed
|
||||
seed_everything(seed=actual_seed)
|
||||
result = func(*args, **kwargs)
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(numpy_state)
|
||||
set_rng_state(torch_state)
|
||||
cuda.set_rng_state(cuda_torch_state)
|
||||
return result
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return decorator
|
||||
from refiners.training_utils.clock import ClockConfig, TrainingClock
|
||||
from refiners.training_utils.common import (
|
||||
compute_grad_norm,
|
||||
count_learnable_parameters,
|
||||
human_readable_number,
|
||||
scoped_seed,
|
||||
)
|
||||
from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
|
||||
from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
|
||||
|
||||
|
||||
class WarmupScheduler(LRScheduler):
|
||||
|
@ -117,135 +61,6 @@ class WarmupScheduler(LRScheduler):
|
|||
self._step_count += 1
|
||||
|
||||
|
||||
class TrainingClock:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_length: int,
|
||||
batch_size: int,
|
||||
training_duration: TimeValue,
|
||||
gradient_accumulation: TimeValue,
|
||||
evaluation_interval: TimeValue,
|
||||
lr_scheduler_interval: TimeValue,
|
||||
) -> None:
|
||||
self.dataset_length = dataset_length
|
||||
self.batch_size = batch_size
|
||||
self.training_duration = training_duration
|
||||
self.gradient_accumulation = gradient_accumulation
|
||||
self.evaluation_interval = evaluation_interval
|
||||
self.lr_scheduler_interval = lr_scheduler_interval
|
||||
self.num_batches_per_epoch = dataset_length // batch_size
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.iteration = 0
|
||||
self.num_batches_processed = 0
|
||||
self.num_minibatches_processed = 0
|
||||
self.loss: Tensor | None = None
|
||||
|
||||
@cached_property
|
||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
||||
return {
|
||||
TimeUnit.STEP: 1,
|
||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
||||
}
|
||||
|
||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||
return number * self.unit_to_steps[unit]
|
||||
|
||||
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||
return steps // self.unit_to_steps[unit]
|
||||
|
||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||
number, unit = time_value["number"], time_value["unit"]
|
||||
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||
|
||||
@cached_property
|
||||
def num_epochs(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
||||
|
||||
@cached_property
|
||||
def num_iterations(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
||||
|
||||
@cached_property
|
||||
def num_steps(self) -> int:
|
||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
||||
|
||||
@cached_property
|
||||
def num_step_per_iteration(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def num_step_per_evaluation(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.iteration = 0
|
||||
self.num_batches_processed = 0
|
||||
self.num_minibatches_processed = 0
|
||||
|
||||
def start_timer(self) -> None:
|
||||
self.start_time = time.time()
|
||||
|
||||
def stop_timer(self) -> None:
|
||||
self.end_time = time.time()
|
||||
|
||||
@property
|
||||
def time_elapsed(self) -> int:
|
||||
assert self.start_time is not None, "Timer has not been started yet."
|
||||
return int(time.time() - self.start_time)
|
||||
|
||||
@cached_property
|
||||
def evaluation_interval_steps(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def lr_scheduler_interval_steps(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
||||
)
|
||||
|
||||
@property
|
||||
def is_optimizer_step(self) -> bool:
|
||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||
|
||||
@property
|
||||
def is_lr_scheduler_step(self) -> bool:
|
||||
return self.step % self.lr_scheduler_interval_steps == 0
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return self.step >= self.num_steps
|
||||
|
||||
@property
|
||||
def is_evaluation_step(self) -> bool:
|
||||
return self.step % self.evaluation_interval_steps == 0
|
||||
|
||||
|
||||
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
|
||||
"""
|
||||
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
|
||||
"""
|
||||
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||
assert gradients, "The model has no gradients to compute the norm."
|
||||
total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
|
||||
return total_norm # type: ignore
|
||||
|
||||
|
||||
Batch = TypeVar("Batch")
|
||||
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
||||
|
||||
|
@ -267,44 +82,91 @@ class _Dataset(Dataset[Batch]):
|
|||
return self.length
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelItem:
|
||||
name: str
|
||||
config: ModelConfig
|
||||
model: fl.Module
|
||||
learnable_parameters: list[nn.Parameter]
|
||||
|
||||
|
||||
ModelRegistry = dict[str, ModelItem]
|
||||
ModuleT = TypeVar("ModuleT", bound=fl.Module)
|
||||
|
||||
|
||||
def register_model():
|
||||
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
|
||||
@wraps(func)
|
||||
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
|
||||
name = func.__name__
|
||||
model = func(self, config)
|
||||
model = model.to(self.device, dtype=self.dtype)
|
||||
if config.requires_grad is not None:
|
||||
model.requires_grad_(requires_grad=config.requires_grad)
|
||||
learnable_parameters = [param for param in model.parameters() if param.requires_grad]
|
||||
self.models[name] = ModelItem(
|
||||
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
||||
)
|
||||
setattr(self, name, self.models[name].model)
|
||||
return func(self, config)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
CallbackRegistry = dict[str, Callback[Any]]
|
||||
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
|
||||
|
||||
|
||||
def register_callback():
|
||||
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
|
||||
@wraps(func)
|
||||
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
|
||||
name = func.__name__
|
||||
callback = func(self, config)
|
||||
self.callbacks[name] = callback
|
||||
setattr(self, name, callback)
|
||||
return func(self, config)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Trainer(Generic[ConfigType, Batch], ABC):
|
||||
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
||||
def __init__(self, config: ConfigType) -> None:
|
||||
self._models: ModelRegistry = {}
|
||||
self._callbacks: CallbackRegistry = {}
|
||||
self.config = config
|
||||
self.clock = TrainingClock(
|
||||
dataset_length=self.dataset_length,
|
||||
batch_size=config.training.batch_size,
|
||||
training_duration=config.training.duration,
|
||||
evaluation_interval=config.training.evaluation_interval,
|
||||
gradient_accumulation=config.training.gradient_accumulation,
|
||||
lr_scheduler_interval=config.scheduler.update_interval,
|
||||
)
|
||||
self.callbacks = callbacks or []
|
||||
self.callbacks += self.default_callbacks()
|
||||
self._load_callbacks()
|
||||
self._call_callbacks(event_name="on_init_begin")
|
||||
self.load_models()
|
||||
self.prepare_models()
|
||||
self._load_models()
|
||||
self._call_callbacks(event_name="on_init_end")
|
||||
|
||||
def default_callbacks(self) -> list[Callback[Any]]:
|
||||
callbacks: list[Callback[Any]] = [
|
||||
ClockCallback(),
|
||||
GradientValueClipping(),
|
||||
GradientNormClipping(),
|
||||
DropoutCallback(),
|
||||
]
|
||||
@register_callback()
|
||||
def clock(self, config: ClockConfig) -> TrainingClock:
|
||||
return TrainingClock(
|
||||
dataset_length=self.dataset_length,
|
||||
batch_size=self.config.training.batch_size,
|
||||
training_duration=self.config.training.duration,
|
||||
evaluation_interval=self.config.training.evaluation_interval,
|
||||
gradient_accumulation=self.config.training.gradient_accumulation,
|
||||
lr_scheduler_interval=self.config.scheduler.update_interval,
|
||||
verbose=config.verbose,
|
||||
)
|
||||
|
||||
# look for any Callback that might be a property of the Trainer
|
||||
for attr_name in dir(self):
|
||||
if "__" in attr_name:
|
||||
continue
|
||||
@register_callback()
|
||||
def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
|
||||
return GradientClipping(config)
|
||||
|
||||
try:
|
||||
attr = getattr(self, attr_name)
|
||||
except AssertionError:
|
||||
continue
|
||||
if isinstance(attr, Callback):
|
||||
callbacks.append(cast(Callback[Any], attr))
|
||||
return callbacks
|
||||
@property
|
||||
def models(self) -> ModelRegistry:
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def callbacks(self) -> CallbackRegistry:
|
||||
return self._callbacks
|
||||
|
||||
@cached_property
|
||||
def device(self) -> Device:
|
||||
|
@ -320,38 +182,31 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
return dtype
|
||||
|
||||
@property
|
||||
def parameters(self) -> list[Parameter]:
|
||||
"""Returns a list of all parameters in all models"""
|
||||
return [param for model in self.models.values() for param in model.parameters()]
|
||||
|
||||
@property
|
||||
def learnable_parameters(self) -> list[Parameter]:
|
||||
def learnable_parameters(self) -> list[nn.Parameter]:
|
||||
"""Returns a list of learnable parameters in all models"""
|
||||
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
||||
return [param for item in self.models.values() for param in item.learnable_parameters]
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def optimizer_parameters(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Returns a list of `dict`-s containing the params and optimizer options for each model.
|
||||
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
|
||||
"""
|
||||
params: list[dict[str, Any]] = []
|
||||
for model_name, model in self.models.items():
|
||||
model_params = [param for param in model.parameters() if param.requires_grad]
|
||||
model_config = self.config.models[model_name]
|
||||
for item in self.models.values():
|
||||
config = item.config
|
||||
model_optim_conf: dict[str, Any] = {}
|
||||
|
||||
if model_config.learning_rate is not None:
|
||||
model_optim_conf["lr"] = model_config.learning_rate
|
||||
if model_config.weight_decay is not None:
|
||||
model_optim_conf["weight_decay"] = model_config.learning_rate
|
||||
if model_config.betas is not None:
|
||||
model_optim_conf["betas"] = model_config.learning_rate
|
||||
if model_config.eps is not None:
|
||||
model_optim_conf["eps"] = model_config.learning_rate
|
||||
if config.learning_rate is not None:
|
||||
model_optim_conf["lr"] = config.learning_rate
|
||||
if config.weight_decay is not None:
|
||||
model_optim_conf["weight_decay"] = config.learning_rate
|
||||
if config.betas is not None:
|
||||
model_optim_conf["betas"] = config.learning_rate
|
||||
if config.eps is not None:
|
||||
model_optim_conf["eps"] = config.learning_rate
|
||||
|
||||
for param in model_params:
|
||||
params.append({"params": param, **model_optim_conf})
|
||||
params.append({"params": item.learnable_parameters, **model_optim_conf})
|
||||
|
||||
return params
|
||||
|
||||
|
@ -363,17 +218,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
@property
|
||||
def gradients(self) -> list[Tensor]:
|
||||
"""Returns a list of detached gradients for all learnable parameters in all models"""
|
||||
return [
|
||||
param.grad.detach()
|
||||
for model in self.models.values()
|
||||
for param in model.parameters()
|
||||
if param.grad is not None
|
||||
]
|
||||
return [param.grad.detach() for param in self.learnable_parameters if param.grad is not None]
|
||||
|
||||
@property
|
||||
def total_gradient_norm(self) -> float:
|
||||
"""Returns the total gradient norm for all learnable parameters in all models"""
|
||||
return compute_grad_norm(parameters=self.parameters)
|
||||
return compute_grad_norm(parameters=self.learnable_parameters)
|
||||
|
||||
@cached_property
|
||||
def optimizer(self) -> Optimizer:
|
||||
|
@ -441,38 +291,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
|
||||
return lr_scheduler
|
||||
|
||||
@cached_property
|
||||
def models(self) -> dict[str, fl.Module]:
|
||||
return self.load_models()
|
||||
|
||||
def set_models_to_train_mode(self) -> None:
|
||||
for model in self.models.values():
|
||||
model.train()
|
||||
|
||||
def set_models_to_eval_mode(self) -> None:
|
||||
for model in self.models.values():
|
||||
model.eval()
|
||||
|
||||
def prepare_model(self, model_name: str) -> None:
|
||||
model = self.models[model_name]
|
||||
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
|
||||
model.load_from_safetensors(tensors_path=checkpoint)
|
||||
else:
|
||||
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
|
||||
if (requires_grad := self.config.models[model_name].requires_grad) is not None:
|
||||
model.requires_grad_(requires_grad=requires_grad)
|
||||
model.to(self.device)
|
||||
model.zero_grad()
|
||||
|
||||
def prepare_models(self) -> None:
|
||||
assert self.models, "No models found."
|
||||
for model_name in self.models:
|
||||
self.prepare_model(model_name=model_name)
|
||||
|
||||
@abstractmethod
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_item(self, index: int) -> Batch:
|
||||
"""
|
||||
|
@ -563,7 +381,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
@scoped_seed(seed=get_training_seed)
|
||||
def train(self) -> None:
|
||||
"""Train the model."""
|
||||
self.set_models_to_train_mode()
|
||||
self.set_models_to_mode("train")
|
||||
self._call_callbacks(event_name="on_train_begin")
|
||||
assert self.learnable_parameters, "There are no learnable parameters in the models."
|
||||
self.evaluate()
|
||||
|
@ -581,12 +399,43 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
@scoped_seed(seed=get_evaluation_seed)
|
||||
def evaluate(self) -> None:
|
||||
"""Evaluate the model."""
|
||||
self.set_models_to_eval_mode()
|
||||
self.set_models_to_mode(mode="eval")
|
||||
self._call_callbacks(event_name="on_evaluate_begin")
|
||||
self.compute_evaluation()
|
||||
self._call_callbacks(event_name="on_evaluate_end")
|
||||
self.set_models_to_train_mode()
|
||||
self.set_models_to_mode(mode="train")
|
||||
|
||||
def set_models_to_mode(self, mode: Literal["train", "eval"]) -> None:
|
||||
for item in self.models.values():
|
||||
if mode == "train":
|
||||
item.model.train()
|
||||
elif mode == "eval":
|
||||
item.model.eval()
|
||||
|
||||
def _call_callbacks(self, event_name: str) -> None:
|
||||
for callback in self.callbacks:
|
||||
for callback in self.callbacks.values():
|
||||
getattr(callback, event_name)(self)
|
||||
|
||||
def _load_callbacks(self) -> None:
|
||||
for name, config in self.config:
|
||||
if not isinstance(config, CallbackConfig):
|
||||
continue
|
||||
try:
|
||||
registered_callback = getattr(self, name)
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
f"Callback {name} is in the config but not registered in the Trainer. Create a method with the @register_callback decorator."
|
||||
)
|
||||
assert callable(registered_callback)
|
||||
registered_callback(config)
|
||||
|
||||
def _load_models(self) -> None:
|
||||
for name, config in self.config.models.items():
|
||||
try:
|
||||
registered_model = getattr(self, name)
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
f"Model {name} is in the config but not registered in the Trainer. Create a method with the @register_model decorator."
|
||||
)
|
||||
assert callable(registered_model)
|
||||
registered_model(config)
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
from abc import ABC
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import wandb
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
from refiners.training_utils.trainer import Trainer, register_callback
|
||||
|
||||
number = float | int
|
||||
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
||||
|
@ -64,7 +62,7 @@ class WandbLogger:
|
|||
return self.wandb_run.name or "" # type: ignore
|
||||
|
||||
|
||||
class WandbConfig(BaseModel):
|
||||
class WandbConfig(CallbackConfig):
|
||||
"""
|
||||
Wandb configuration.
|
||||
|
||||
|
@ -87,18 +85,16 @@ class WandbConfig(BaseModel):
|
|||
anonymous: Literal["never", "allow", "must"] | None = None
|
||||
id: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
AnyTrainer = Trainer[BaseConfig, Any]
|
||||
|
||||
|
||||
class WandbCallback(Callback["TrainerWithWandb"]):
|
||||
epoch_losses: list[float]
|
||||
iteration_losses: list[float]
|
||||
|
||||
def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
|
||||
trainer.load_wandb()
|
||||
def __init__(self, config: WandbConfig, /, trainer_config: dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.epoch_losses: list[float] = []
|
||||
self.iteration_losses: list[float] = []
|
||||
self.logger = WandbLogger({**config.model_dump(), "config": trainer_config})
|
||||
|
||||
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
|
||||
self.epoch_losses = []
|
||||
|
@ -131,19 +127,13 @@ class WandbMixin(ABC):
|
|||
config: Any
|
||||
wandb_logger: WandbLogger
|
||||
|
||||
def load_wandb(self) -> None:
|
||||
wandb_config = getattr(self.config, "wandb", None)
|
||||
assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
|
||||
init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
|
||||
self.wandb_logger = WandbLogger(init_config=init_config)
|
||||
@register_callback()
|
||||
def wandb(self, config: WandbConfig) -> WandbCallback:
|
||||
return WandbCallback(config, trainer_config=self.config.model_dump())
|
||||
|
||||
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
|
||||
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
|
||||
self.wandb_logger.log(data=data, step=self.clock.step)
|
||||
|
||||
@cached_property
|
||||
def wandb_callback(self) -> WandbCallback:
|
||||
return WandbCallback()
|
||||
self.wandb.logger.log(data=data, step=self.clock.step)
|
||||
|
||||
|
||||
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
[models.mock_model]
|
||||
requires_grad = true
|
||||
|
||||
[clock]
|
||||
verbose = false
|
||||
|
||||
[gradient_clipping]
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
[training]
|
||||
duration = "100:epoch"
|
||||
|
@ -9,7 +14,6 @@ device = "cpu"
|
|||
dtype = "float32"
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
evaluation_interval = "5:epoch"
|
||||
evaluation_seed = 1
|
||||
|
||||
|
@ -21,6 +25,3 @@ learning_rate = 1
|
|||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
warmup = "20:step"
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0.0
|
||||
|
|
|
@ -5,12 +5,17 @@ learning_rate = 1e-5
|
|||
[models.mock_model2]
|
||||
requires_grad = true
|
||||
|
||||
[clock]
|
||||
verbose = false
|
||||
|
||||
[gradient_clipping]
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
[training]
|
||||
duration = "100:epoch"
|
||||
seed = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
evaluation_interval = "5:epoch"
|
||||
evaluation_seed = 1
|
||||
|
||||
|
@ -21,6 +26,3 @@ learning_rate = 1
|
|||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0.0
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
|
@ -10,13 +9,15 @@ from torch.optim import SGD
|
|||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import norm
|
||||
from refiners.training_utils.config import BaseConfig, TimeUnit
|
||||
from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number
|
||||
from refiners.training_utils.config import BaseConfig, ModelConfig
|
||||
from refiners.training_utils.trainer import (
|
||||
Trainer,
|
||||
TrainingClock,
|
||||
WarmupScheduler,
|
||||
count_learnable_parameters,
|
||||
human_readable_number,
|
||||
register_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -55,13 +56,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
|||
targets=torch.cat([b.targets for b in batch]),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def mock_model(self) -> MockModel:
|
||||
@register_model()
|
||||
def mock_model(self, config: ModelConfig) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
return {"mock_model": self.mock_model}
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
self.step_counter += 1
|
||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||
|
@ -217,17 +215,14 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
|||
|
||||
|
||||
class MockTrainerWith2Models(MockTrainer):
|
||||
@cached_property
|
||||
def mock_model1(self) -> MockModel:
|
||||
@register_model()
|
||||
def mock_model1(self, config: ModelConfig) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
@cached_property
|
||||
def mock_model2(self) -> MockModel:
|
||||
@register_model()
|
||||
def mock_model2(self, config: ModelConfig) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
self.step_counter += 1
|
||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||
|
@ -246,7 +241,5 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M
|
|||
|
||||
|
||||
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
|
||||
assert (
|
||||
len(mock_trainer_2_models.optimizer.param_groups) == 12
|
||||
) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
|
||||
assert len(mock_trainer_2_models.optimizer.param_groups) == 2
|
||||
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
|
||||
|
|
Loading…
Reference in a new issue