mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
add @register_model and @register_callback decorators
Refactor ClockTrainer to include Callback
This commit is contained in:
parent
f541badcb3
commit
d6546c9026
|
@ -4,8 +4,20 @@ from importlib.metadata import requires
|
||||||
|
|
||||||
from packaging.requirements import Requirement
|
from packaging.requirements import Requirement
|
||||||
|
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.trainer import Trainer
|
from refiners.training_utils.clock import ClockConfig
|
||||||
|
from refiners.training_utils.config import (
|
||||||
|
BaseConfig,
|
||||||
|
ModelConfig,
|
||||||
|
OptimizerConfig,
|
||||||
|
Optimizers,
|
||||||
|
SchedulerConfig,
|
||||||
|
SchedulerType,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
from refiners.training_utils.gradient_clipping import GradientClippingConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer, register_callback, register_model
|
||||||
|
from refiners.training_utils.wandb import WandbConfig, WandbMixin
|
||||||
|
|
||||||
refiners_requires = requires("refiners")
|
refiners_requires = requires("refiners")
|
||||||
assert refiners_requires is not None
|
assert refiners_requires is not None
|
||||||
|
@ -29,4 +41,18 @@ for dep in refiners_requires:
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"BaseConfig",
|
"BaseConfig",
|
||||||
|
"ModelConfig",
|
||||||
|
"register_callback",
|
||||||
|
"register_model",
|
||||||
|
"Callback",
|
||||||
|
"CallbackConfig",
|
||||||
|
"WandbMixin",
|
||||||
|
"WandbConfig",
|
||||||
|
"SchedulerConfig",
|
||||||
|
"OptimizerConfig",
|
||||||
|
"TrainingConfig",
|
||||||
|
"ClockConfig",
|
||||||
|
"GradientClippingConfig",
|
||||||
|
"Optimizers",
|
||||||
|
"SchedulerType",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,43 +1,22 @@
|
||||||
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||||
|
|
||||||
from loguru import logger
|
from pydantic import BaseModel, ConfigDict
|
||||||
from torch import tensor
|
|
||||||
from torch.nn import Parameter
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
from refiners.training_utils.trainer import Trainer
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
__all__ = [
|
T = TypeVar("T", bound="Trainer[BaseConfig, Any]")
|
||||||
"Callback",
|
|
||||||
"GradientNormClipping",
|
|
||||||
"GradientValueClipping",
|
|
||||||
"ClockCallback",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
class CallbackConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
Base configuration for a callback.
|
||||||
|
|
||||||
|
For your callback to be properly configured, you should inherit from this class and add your own configuration.
|
||||||
"""
|
"""
|
||||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
|
||||||
assert gradients, "The model has no gradients to clip."
|
|
||||||
clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
|
||||||
for gradient in gradients:
|
|
||||||
gradient.mul_(other=clip_coefficient) # type: ignore
|
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None:
|
|
||||||
"""
|
|
||||||
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
|
||||||
"""
|
|
||||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
|
||||||
assert gradients, "The model has no gradients to clip."
|
|
||||||
for gradient in gradients:
|
|
||||||
gradient.clamp_(min=-clip_value, max=clip_value)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
|
@ -97,71 +76,3 @@ class Callback(Generic[T]):
|
||||||
|
|
||||||
def on_checkpoint_save(self, trainer: T) -> None:
|
def on_checkpoint_save(self, trainer: T) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
|
||||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
trainer.clock.reset()
|
|
||||||
logger.info(
|
|
||||||
(
|
|
||||||
"Starting training for a total of: "
|
|
||||||
f"{trainer.clock.num_steps} steps, "
|
|
||||||
f"{trainer.clock.num_epochs} epochs, "
|
|
||||||
f"{trainer.clock.num_iterations} iterations."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
trainer.clock.start_timer()
|
|
||||||
|
|
||||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
trainer.clock.stop_timer()
|
|
||||||
logger.info(
|
|
||||||
(
|
|
||||||
"Training took: "
|
|
||||||
f"{trainer.clock.time_elapsed} seconds, "
|
|
||||||
f"{trainer.clock.iteration} iterations, "
|
|
||||||
f"{trainer.clock.epoch} epochs, "
|
|
||||||
f"{trainer.clock.step} steps."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
logger.info(f"Epoch {trainer.clock.epoch} started.")
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
trainer.clock.epoch += 1
|
|
||||||
trainer.clock.num_batches_processed = 0
|
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
logger.info(f"Step {trainer.clock.step} started.")
|
|
||||||
|
|
||||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
trainer.clock.step += 1
|
|
||||||
trainer.clock.num_batches_processed += 1
|
|
||||||
trainer.clock.num_minibatches_processed += 1
|
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
logger.info(f"Iteration {trainer.clock.iteration} ended.")
|
|
||||||
trainer.clock.iteration += 1
|
|
||||||
trainer.clock.num_minibatches_processed = 0
|
|
||||||
|
|
||||||
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
logger.info("Evaluation started.")
|
|
||||||
|
|
||||||
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
logger.info("Evaluation ended.")
|
|
||||||
|
|
||||||
|
|
||||||
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
|
|
||||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
clip_norm = trainer.config.training.clip_grad_norm
|
|
||||||
if clip_norm is not None:
|
|
||||||
clip_gradient_norm(
|
|
||||||
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
|
|
||||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
clip_value = trainer.config.training.clip_grad_value
|
|
||||||
if clip_value is not None:
|
|
||||||
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
|
||||||
|
|
193
src/refiners/training_utils/clock.py
Normal file
193
src/refiners/training_utils/clock.py
Normal file
|
@ -0,0 +1,193 @@
|
||||||
|
import time
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
|
from refiners.training_utils.common import TimeUnit, TimeValue
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class ClockConfig(CallbackConfig):
|
||||||
|
verbose: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_length: int,
|
||||||
|
batch_size: int,
|
||||||
|
training_duration: TimeValue,
|
||||||
|
gradient_accumulation: TimeValue,
|
||||||
|
evaluation_interval: TimeValue,
|
||||||
|
lr_scheduler_interval: TimeValue,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.dataset_length = dataset_length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.training_duration = training_duration
|
||||||
|
self.gradient_accumulation = gradient_accumulation
|
||||||
|
self.evaluation_interval = evaluation_interval
|
||||||
|
self.lr_scheduler_interval = lr_scheduler_interval
|
||||||
|
self.verbose = verbose
|
||||||
|
self.num_batches_per_epoch = dataset_length // batch_size
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self.step = 0
|
||||||
|
self.epoch = 0
|
||||||
|
self.iteration = 0
|
||||||
|
self.num_batches_processed = 0
|
||||||
|
self.num_minibatches_processed = 0
|
||||||
|
self.loss: Tensor | None = None
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||||
|
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
||||||
|
return {
|
||||||
|
TimeUnit.STEP: 1,
|
||||||
|
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||||
|
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
||||||
|
}
|
||||||
|
|
||||||
|
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||||
|
return number * self.unit_to_steps[unit]
|
||||||
|
|
||||||
|
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||||
|
return steps // self.unit_to_steps[unit]
|
||||||
|
|
||||||
|
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||||
|
number, unit = time_value["number"], time_value["unit"]
|
||||||
|
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
||||||
|
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_epochs(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_iterations(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_steps(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_step_per_iteration(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_step_per_evaluation(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self.step = 0
|
||||||
|
self.epoch = 0
|
||||||
|
self.iteration = 0
|
||||||
|
self.num_batches_processed = 0
|
||||||
|
self.num_minibatches_processed = 0
|
||||||
|
|
||||||
|
def start_timer(self) -> None:
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def stop_timer(self) -> None:
|
||||||
|
self.end_time = time.time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_elapsed(self) -> int:
|
||||||
|
assert self.start_time is not None, "Timer has not been started yet."
|
||||||
|
return int(time.time() - self.start_time)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def evaluation_interval_steps(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def lr_scheduler_interval_steps(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_optimizer_step(self) -> bool:
|
||||||
|
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_lr_scheduler_step(self) -> bool:
|
||||||
|
return self.step % self.lr_scheduler_interval_steps == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
return self.step >= self.num_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_evaluation_step(self) -> bool:
|
||||||
|
return self.step % self.evaluation_interval_steps == 0
|
||||||
|
|
||||||
|
def log(self, message: str, /) -> None:
|
||||||
|
if self.verbose:
|
||||||
|
logger.info(message)
|
||||||
|
|
||||||
|
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.reset()
|
||||||
|
self.log(
|
||||||
|
(
|
||||||
|
"Starting training for a total of: "
|
||||||
|
f"{trainer.clock.num_steps} steps, "
|
||||||
|
f"{trainer.clock.num_epochs} epochs, "
|
||||||
|
f"{trainer.clock.num_iterations} iterations."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
trainer.clock.start_timer()
|
||||||
|
|
||||||
|
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.stop_timer()
|
||||||
|
self.log(
|
||||||
|
(
|
||||||
|
"Training took: "
|
||||||
|
f"{trainer.clock.time_elapsed} seconds, "
|
||||||
|
f"{trainer.clock.iteration} iterations, "
|
||||||
|
f"{trainer.clock.epoch} epochs, "
|
||||||
|
f"{trainer.clock.step} steps."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.log(f"Epoch {trainer.clock.epoch} started.")
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.epoch += 1
|
||||||
|
trainer.clock.num_batches_processed = 0
|
||||||
|
|
||||||
|
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.log(f"Step {trainer.clock.step} started.")
|
||||||
|
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.step += 1
|
||||||
|
trainer.clock.num_batches_processed += 1
|
||||||
|
trainer.clock.num_minibatches_processed += 1
|
||||||
|
|
||||||
|
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.log(f"Iteration {trainer.clock.iteration} ended.")
|
||||||
|
trainer.clock.iteration += 1
|
||||||
|
trainer.clock.num_minibatches_processed = 0
|
||||||
|
|
||||||
|
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.log("Evaluation started.")
|
||||||
|
|
||||||
|
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.log("Evaluation ended.")
|
103
src/refiners/training_utils/common.py
Normal file
103
src/refiners/training_utils/common.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
import random
|
||||||
|
from enum import Enum
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from torch import Tensor, cuda, nn
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
|
|
||||||
|
def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
|
||||||
|
"""
|
||||||
|
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
|
||||||
|
"""
|
||||||
|
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to compute the norm."
|
||||||
|
total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
|
||||||
|
return total_norm # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
|
||||||
|
return sum(p.numel() for p in parameters if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
def human_readable_number(number: int) -> str:
|
||||||
|
float_number = float(number)
|
||||||
|
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||||
|
if abs(float_number) < 1000:
|
||||||
|
return f"{float_number:.1f}{unit}"
|
||||||
|
float_number /= 1000
|
||||||
|
return f"{float_number:.1f}E"
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed: int | None = None) -> None:
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
logger.info(f"Using random seed: {seed}")
|
||||||
|
random.seed(a=seed)
|
||||||
|
np.random.seed(seed=seed)
|
||||||
|
manual_seed(seed=seed)
|
||||||
|
cuda.manual_seed_all(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
|
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
||||||
|
"""
|
||||||
|
Decorator for setting a random seed within the scope of a function.
|
||||||
|
|
||||||
|
This decorator sets the random seed for Python's built-in `random` module,
|
||||||
|
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
||||||
|
function is executed, it restores the state of the random number generators
|
||||||
|
to what it was before the function was called. This is useful for ensuring
|
||||||
|
reproducibility for specific parts of the code without affecting randomness
|
||||||
|
elsewhere.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
@wraps(func)
|
||||||
|
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
random_state = random.getstate()
|
||||||
|
numpy_state = np.random.get_state()
|
||||||
|
torch_state = torch.get_rng_state()
|
||||||
|
cuda_torch_state = cuda.get_rng_state()
|
||||||
|
actual_seed = seed(*args) if callable(seed) else seed
|
||||||
|
seed_everything(seed=actual_seed)
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
random.setstate(random_state)
|
||||||
|
np.random.set_state(numpy_state)
|
||||||
|
torch.set_rng_state(torch_state)
|
||||||
|
cuda.set_rng_state(cuda_torch_state)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return inner_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class TimeUnit(Enum):
|
||||||
|
STEP = "step"
|
||||||
|
EPOCH = "epoch"
|
||||||
|
ITERATION = "iteration"
|
||||||
|
DEFAULT = "step"
|
||||||
|
|
||||||
|
|
||||||
|
class TimeValue(TypedDict):
|
||||||
|
number: int
|
||||||
|
unit: TimeUnit
|
||||||
|
|
||||||
|
|
||||||
|
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
||||||
|
match value:
|
||||||
|
case str(value_str):
|
||||||
|
number, unit = value_str.split(sep=":")
|
||||||
|
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
||||||
|
case int(number):
|
||||||
|
return {"number": number, "unit": TimeUnit.DEFAULT}
|
||||||
|
case {"number": int(number), "unit": str(unit)}:
|
||||||
|
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported value format: {value}")
|
|
@ -9,50 +9,16 @@ from prodigyopt import Prodigy # type: ignore
|
||||||
from pydantic import BaseModel, ConfigDict, validator
|
from pydantic import BaseModel, ConfigDict, validator
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
from refiners.training_utils.clock import ClockConfig
|
||||||
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
|
||||||
|
from refiners.training_utils.gradient_clipping import GradientClippingConfig
|
||||||
|
|
||||||
# PyTorch optimizer parameters type
|
# PyTorch optimizer parameters type
|
||||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||||
# See https://github.com/pytorch/pytorch/pull/111114
|
# See https://github.com/pytorch/pytorch/pull/111114
|
||||||
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"parse_number_unit_field",
|
|
||||||
"TimeUnit",
|
|
||||||
"TimeValue",
|
|
||||||
"TrainingConfig",
|
|
||||||
"OptimizerConfig",
|
|
||||||
"Optimizers",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TimeUnit(Enum):
|
|
||||||
STEP = "step"
|
|
||||||
EPOCH = "epoch"
|
|
||||||
ITERATION = "iteration"
|
|
||||||
DEFAULT = "step"
|
|
||||||
|
|
||||||
|
|
||||||
class TimeValue(TypedDict):
|
|
||||||
number: int
|
|
||||||
unit: TimeUnit
|
|
||||||
|
|
||||||
|
|
||||||
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
|
||||||
match value:
|
|
||||||
case str(value_str):
|
|
||||||
number, unit = value_str.split(sep=":")
|
|
||||||
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
|
||||||
case int(number):
|
|
||||||
return {"number": number, "unit": TimeUnit.DEFAULT}
|
|
||||||
case {"number": int(number), "unit": str(unit)}:
|
|
||||||
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported value format: {value}")
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
@ -61,8 +27,6 @@ class TrainingConfig(BaseModel):
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
||||||
clip_grad_norm: float | None = None
|
|
||||||
clip_grad_value: float | None = None
|
|
||||||
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||||
evaluation_seed: int = 0
|
evaluation_seed: int = 0
|
||||||
|
|
||||||
|
@ -195,29 +159,6 @@ class ModelConfig(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
class GyroDropoutConfig(BaseModel):
|
|
||||||
total_subnetworks: int = 512
|
|
||||||
concurrent_subnetworks: int = 64
|
|
||||||
iters_per_epoch: int = 512
|
|
||||||
num_features_threshold: float = 5e5
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
|
||||||
|
|
||||||
|
|
||||||
class DropoutConfig(BaseModel):
|
|
||||||
dropout_probability: float = 0.0
|
|
||||||
gyro_dropout: GyroDropoutConfig | None = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
|
||||||
|
|
||||||
def apply_dropout(self, model: fl.Chain) -> None:
|
|
||||||
if self.dropout_probability > 0.0:
|
|
||||||
if self.gyro_dropout is not None:
|
|
||||||
apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump())
|
|
||||||
else:
|
|
||||||
apply_dropout(module=model, probability=self.dropout_probability)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound="BaseConfig")
|
T = TypeVar("T", bound="BaseConfig")
|
||||||
|
|
||||||
|
|
||||||
|
@ -226,7 +167,8 @@ class BaseConfig(BaseModel):
|
||||||
training: TrainingConfig
|
training: TrainingConfig
|
||||||
optimizer: OptimizerConfig
|
optimizer: OptimizerConfig
|
||||||
scheduler: SchedulerConfig
|
scheduler: SchedulerConfig
|
||||||
dropout: DropoutConfig
|
clock: ClockConfig = ClockConfig()
|
||||||
|
gradient_clipping: GradientClippingConfig = GradientClippingConfig()
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
|
@ -1,200 +0,0 @@
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar
|
|
||||||
|
|
||||||
from torch import Tensor, cat, rand, randint
|
|
||||||
from torch.nn import Dropout as TorchDropout
|
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
|
||||||
from refiners.training_utils.callback import Callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from refiners.training_utils.config import BaseConfig
|
|
||||||
from refiners.training_utils.trainer import Trainer
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Dropout", "GyroDropout", "DropoutCallback"]
|
|
||||||
|
|
||||||
|
|
||||||
class Dropout(TorchDropout, fl.Module):
|
|
||||||
def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
|
|
||||||
super().__init__(p=probability, inplace=inplace)
|
|
||||||
|
|
||||||
|
|
||||||
class GyroDropout(fl.Module):
|
|
||||||
"""
|
|
||||||
GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training.
|
|
||||||
It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training.
|
|
||||||
This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional
|
|
||||||
dropout.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
total_subnetworks:
|
|
||||||
The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks
|
|
||||||
that are precomputed and stored.
|
|
||||||
|
|
||||||
concurrent_subnetworks:
|
|
||||||
The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of
|
|
||||||
masks from the precomputed set is used to dropout different portions of the input.
|
|
||||||
|
|
||||||
dropout_probability: float, optional (default=0.5)
|
|
||||||
The probability that an element will be zeroed by the dropout.
|
|
||||||
|
|
||||||
iters_per_epoch:
|
|
||||||
Number of iterations per epoch, used to determine how often the masks should be updated.
|
|
||||||
|
|
||||||
num_features_threshold:
|
|
||||||
If the number of features in the input is greater than this threshold, dropout is skipped. This is because
|
|
||||||
gyro dropout mask size vram usage is proportional to the number of features in the input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
total_subnetworks: int,
|
|
||||||
concurrent_subnetworks: int,
|
|
||||||
dropout_probability: float = 0.5,
|
|
||||||
iters_per_epoch: int = 1,
|
|
||||||
num_features_threshold: float = 5e5,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
assert (
|
|
||||||
iters_per_epoch >= total_subnetworks
|
|
||||||
), "The number of iterations per epoch must be greater than the number of masks"
|
|
||||||
self.dropout_probability = dropout_probability
|
|
||||||
self.iters_per_epoch = iters_per_epoch
|
|
||||||
self.total_subnetworks = total_subnetworks
|
|
||||||
self.concurrent_subnetworks = concurrent_subnetworks
|
|
||||||
self.scale = 1 / (1 - self.dropout_probability)
|
|
||||||
self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks
|
|
||||||
self.preselected_masks: Tensor | None = None
|
|
||||||
self.dropout_mask = None
|
|
||||||
self.training_step = 0
|
|
||||||
self.num_features_threshold = num_features_threshold
|
|
||||||
self.skip_high_num_features = False
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
if not self.training:
|
|
||||||
return x
|
|
||||||
if self.skip_high_num_features:
|
|
||||||
return self.basic_dropout(x)
|
|
||||||
if self.training_step == 0:
|
|
||||||
num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1]
|
|
||||||
if num_features > self.num_features_threshold:
|
|
||||||
self.skip_high_num_features = True
|
|
||||||
self.basic_dropout = Dropout(probability=self.dropout_probability)
|
|
||||||
return self.basic_dropout(x)
|
|
||||||
self.init_masks(x=x)
|
|
||||||
|
|
||||||
if self.training_step % self.mask_update_interval == 0:
|
|
||||||
self.update_dropout_mask(x=x)
|
|
||||||
|
|
||||||
self.training_step += 1
|
|
||||||
|
|
||||||
return x * self.dropout_mask * self.scale
|
|
||||||
|
|
||||||
def init_masks(self, x: Tensor) -> None:
|
|
||||||
if x.dim() == 2:
|
|
||||||
self.preselected_masks = (
|
|
||||||
rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability
|
|
||||||
)
|
|
||||||
if x.dim() == 3:
|
|
||||||
self.preselected_masks = (
|
|
||||||
rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions"
|
|
||||||
self.preselected_masks = self.preselected_masks.float()
|
|
||||||
|
|
||||||
def update_dropout_mask(self, x: Tensor) -> None:
|
|
||||||
assert self.preselected_masks is not None
|
|
||||||
indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device)
|
|
||||||
selected_masks = self.preselected_masks[indices]
|
|
||||||
|
|
||||||
repeat_factor = x.shape[0] // self.concurrent_subnetworks
|
|
||||||
remaining = x.shape[0] % self.concurrent_subnetworks
|
|
||||||
repeated_masks = [selected_masks] * repeat_factor
|
|
||||||
if remaining > 0:
|
|
||||||
repeated_masks.append(selected_masks[:remaining])
|
|
||||||
final_masks = cat(tensors=repeated_masks, dim=0)
|
|
||||||
|
|
||||||
if x.dim() == 2:
|
|
||||||
self.dropout_mask = final_masks
|
|
||||||
if x.dim() == 3:
|
|
||||||
self.dropout_mask = final_masks.expand_as(x)
|
|
||||||
|
|
||||||
|
|
||||||
class DropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
|
||||||
def __init__(self, target: fl.Linear, probability: float = 0.5):
|
|
||||||
with self.setup_adapter(target):
|
|
||||||
super().__init__(target, Dropout(probability=probability))
|
|
||||||
|
|
||||||
|
|
||||||
class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
target: fl.Linear,
|
|
||||||
probability: float = 0.5,
|
|
||||||
total_subnetworks: int = 512,
|
|
||||||
concurrent_subnetworks: int = 64,
|
|
||||||
iters_per_epoch: int = 512,
|
|
||||||
num_features_threshold: float = 5e5,
|
|
||||||
) -> None:
|
|
||||||
self.probability = probability
|
|
||||||
self.total_subnetworks = total_subnetworks
|
|
||||||
self.concurrent_subnetworks = concurrent_subnetworks
|
|
||||||
self.iters_per_epoch = iters_per_epoch
|
|
||||||
|
|
||||||
with self.setup_adapter(target):
|
|
||||||
super().__init__(
|
|
||||||
target,
|
|
||||||
GyroDropout(
|
|
||||||
total_subnetworks=total_subnetworks,
|
|
||||||
concurrent_subnetworks=concurrent_subnetworks,
|
|
||||||
dropout_probability=probability,
|
|
||||||
iters_per_epoch=iters_per_epoch,
|
|
||||||
num_features_threshold=num_features_threshold,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
|
|
||||||
for linear, parent in module.walk(fl.Linear):
|
|
||||||
if not linear.weight.requires_grad:
|
|
||||||
continue
|
|
||||||
assert not (
|
|
||||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
|
||||||
), f"{linear} already has a dropout layer"
|
|
||||||
DropoutAdapter(target=linear, probability=probability).inject(parent)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_gyro_dropout(
|
|
||||||
module: fl.Chain,
|
|
||||||
probability: float = 0.5,
|
|
||||||
total_subnetworks: int = 32,
|
|
||||||
concurrent_subnetworks: int = 16,
|
|
||||||
iters_per_epoch: int = 32,
|
|
||||||
) -> None:
|
|
||||||
for linear, parent in module.walk(fl.Linear):
|
|
||||||
if not linear.weight.requires_grad:
|
|
||||||
continue
|
|
||||||
assert not (
|
|
||||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
|
||||||
), f"{linear} already has a dropout layer"
|
|
||||||
GyroDropoutAdapter(
|
|
||||||
target=linear,
|
|
||||||
probability=probability,
|
|
||||||
total_subnetworks=total_subnetworks,
|
|
||||||
concurrent_subnetworks=concurrent_subnetworks,
|
|
||||||
iters_per_epoch=iters_per_epoch,
|
|
||||||
).inject(parent)
|
|
||||||
|
|
||||||
|
|
||||||
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
|
|
||||||
|
|
||||||
|
|
||||||
class DropoutCallback(Callback["Trainer[ConfigType, Any]"]):
|
|
||||||
def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None:
|
|
||||||
dropout_config = trainer.config.dropout
|
|
||||||
chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)]
|
|
||||||
for model in chain_models:
|
|
||||||
dropout_config.apply_dropout(model=model)
|
|
52
src/refiners/training_utils/gradient_clipping.py
Normal file
52
src/refiners/training_utils/gradient_clipping.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
||||||
|
"""
|
||||||
|
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
||||||
|
"""
|
||||||
|
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to clip."
|
||||||
|
clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
||||||
|
for gradient in gradients:
|
||||||
|
gradient.mul_(other=clip_coefficient) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
|
||||||
|
"""
|
||||||
|
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
||||||
|
"""
|
||||||
|
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to clip."
|
||||||
|
for gradient in gradients:
|
||||||
|
gradient.clamp_(min=-clip_value, max=clip_value)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientClippingConfig(CallbackConfig):
|
||||||
|
clip_grad_norm: float | None = None
|
||||||
|
clip_grad_value: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def __init__(self, config: GradientClippingConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
clip_norm = self.config.clip_grad_norm
|
||||||
|
if clip_norm is not None:
|
||||||
|
clip_gradient_norm(
|
||||||
|
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_value = self.config.clip_grad_value
|
||||||
|
if clip_value is not None:
|
||||||
|
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
|
@ -1,15 +1,12 @@
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod, abstractproperty
|
from abc import ABC, abstractmethod, abstractproperty
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import cached_property, wraps
|
from functools import cached_property, wraps
|
||||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
from typing import Any, Callable, Generic, Literal, TypeVar, cast
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
|
from torch import Tensor, device as Device, dtype as DType, nn
|
||||||
from torch.autograd import backward
|
from torch.autograd import backward
|
||||||
from torch.nn import Parameter
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import (
|
from torch.optim.lr_scheduler import (
|
||||||
CosineAnnealingLR,
|
CosineAnnealingLR,
|
||||||
|
@ -27,73 +24,20 @@ from torch.optim.lr_scheduler import (
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import manual_seed, no_grad
|
from refiners.fluxion.utils import no_grad
|
||||||
from refiners.training_utils.callback import (
|
from refiners.training_utils.callback import (
|
||||||
Callback,
|
Callback,
|
||||||
ClockCallback,
|
CallbackConfig,
|
||||||
GradientNormClipping,
|
|
||||||
GradientValueClipping,
|
|
||||||
)
|
)
|
||||||
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
|
from refiners.training_utils.clock import ClockConfig, TrainingClock
|
||||||
from refiners.training_utils.dropout import DropoutCallback
|
from refiners.training_utils.common import (
|
||||||
|
compute_grad_norm,
|
||||||
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
|
count_learnable_parameters,
|
||||||
|
human_readable_number,
|
||||||
|
scoped_seed,
|
||||||
def count_learnable_parameters(parameters: Iterable[Parameter]) -> int:
|
)
|
||||||
return sum(p.numel() for p in parameters if p.requires_grad)
|
from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
|
||||||
|
from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
|
||||||
|
|
||||||
def human_readable_number(number: int) -> str:
|
|
||||||
float_number = float(number)
|
|
||||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
|
||||||
if abs(float_number) < 1000:
|
|
||||||
return f"{float_number:.1f}{unit}"
|
|
||||||
float_number /= 1000
|
|
||||||
return f"{float_number:.1f}E"
|
|
||||||
|
|
||||||
|
|
||||||
def seed_everything(seed: int | None = None) -> None:
|
|
||||||
if seed is None:
|
|
||||||
seed = random.randint(0, 2**32 - 1)
|
|
||||||
logger.info(f"Using random seed: {seed}")
|
|
||||||
random.seed(a=seed)
|
|
||||||
np.random.seed(seed=seed)
|
|
||||||
manual_seed(seed=seed)
|
|
||||||
cuda.manual_seed_all(seed=seed)
|
|
||||||
|
|
||||||
|
|
||||||
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
|
||||||
"""
|
|
||||||
Decorator for setting a random seed within the scope of a function.
|
|
||||||
|
|
||||||
This decorator sets the random seed for Python's built-in `random` module,
|
|
||||||
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
|
||||||
function is executed, it restores the state of the random number generators
|
|
||||||
to what it was before the function was called. This is useful for ensuring
|
|
||||||
reproducibility for specific parts of the code without affecting randomness
|
|
||||||
elsewhere.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
||||||
@wraps(func)
|
|
||||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
random_state = random.getstate()
|
|
||||||
numpy_state = np.random.get_state()
|
|
||||||
torch_state = get_rng_state()
|
|
||||||
cuda_torch_state = cuda.get_rng_state()
|
|
||||||
actual_seed = seed(*args) if callable(seed) else seed
|
|
||||||
seed_everything(seed=actual_seed)
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
random.setstate(random_state)
|
|
||||||
np.random.set_state(numpy_state)
|
|
||||||
set_rng_state(torch_state)
|
|
||||||
cuda.set_rng_state(cuda_torch_state)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return inner_wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
class WarmupScheduler(LRScheduler):
|
class WarmupScheduler(LRScheduler):
|
||||||
|
@ -117,135 +61,6 @@ class WarmupScheduler(LRScheduler):
|
||||||
self._step_count += 1
|
self._step_count += 1
|
||||||
|
|
||||||
|
|
||||||
class TrainingClock:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataset_length: int,
|
|
||||||
batch_size: int,
|
|
||||||
training_duration: TimeValue,
|
|
||||||
gradient_accumulation: TimeValue,
|
|
||||||
evaluation_interval: TimeValue,
|
|
||||||
lr_scheduler_interval: TimeValue,
|
|
||||||
) -> None:
|
|
||||||
self.dataset_length = dataset_length
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.training_duration = training_duration
|
|
||||||
self.gradient_accumulation = gradient_accumulation
|
|
||||||
self.evaluation_interval = evaluation_interval
|
|
||||||
self.lr_scheduler_interval = lr_scheduler_interval
|
|
||||||
self.num_batches_per_epoch = dataset_length // batch_size
|
|
||||||
self.start_time = None
|
|
||||||
self.end_time = None
|
|
||||||
self.step = 0
|
|
||||||
self.epoch = 0
|
|
||||||
self.iteration = 0
|
|
||||||
self.num_batches_processed = 0
|
|
||||||
self.num_minibatches_processed = 0
|
|
||||||
self.loss: Tensor | None = None
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
|
||||||
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
|
||||||
return {
|
|
||||||
TimeUnit.STEP: 1,
|
|
||||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
|
||||||
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
|
||||||
}
|
|
||||||
|
|
||||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
|
||||||
return number * self.unit_to_steps[unit]
|
|
||||||
|
|
||||||
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
|
||||||
return steps // self.unit_to_steps[unit]
|
|
||||||
|
|
||||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
|
||||||
number, unit = time_value["number"], time_value["unit"]
|
|
||||||
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
|
||||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def num_epochs(self) -> int:
|
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def num_iterations(self) -> int:
|
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def num_steps(self) -> int:
|
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def num_step_per_iteration(self) -> int:
|
|
||||||
return self.convert_time_unit_to_steps(
|
|
||||||
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def num_step_per_evaluation(self) -> int:
|
|
||||||
return self.convert_time_unit_to_steps(
|
|
||||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self.start_time = None
|
|
||||||
self.end_time = None
|
|
||||||
self.step = 0
|
|
||||||
self.epoch = 0
|
|
||||||
self.iteration = 0
|
|
||||||
self.num_batches_processed = 0
|
|
||||||
self.num_minibatches_processed = 0
|
|
||||||
|
|
||||||
def start_timer(self) -> None:
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def stop_timer(self) -> None:
|
|
||||||
self.end_time = time.time()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def time_elapsed(self) -> int:
|
|
||||||
assert self.start_time is not None, "Timer has not been started yet."
|
|
||||||
return int(time.time() - self.start_time)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def evaluation_interval_steps(self) -> int:
|
|
||||||
return self.convert_time_unit_to_steps(
|
|
||||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def lr_scheduler_interval_steps(self) -> int:
|
|
||||||
return self.convert_time_unit_to_steps(
|
|
||||||
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_optimizer_step(self) -> bool:
|
|
||||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_lr_scheduler_step(self) -> bool:
|
|
||||||
return self.step % self.lr_scheduler_interval_steps == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def done(self) -> bool:
|
|
||||||
return self.step >= self.num_steps
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_evaluation_step(self) -> bool:
|
|
||||||
return self.step % self.evaluation_interval_steps == 0
|
|
||||||
|
|
||||||
|
|
||||||
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
|
|
||||||
"""
|
|
||||||
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
|
|
||||||
"""
|
|
||||||
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
|
|
||||||
assert gradients, "The model has no gradients to compute the norm."
|
|
||||||
total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
|
|
||||||
return total_norm # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
Batch = TypeVar("Batch")
|
Batch = TypeVar("Batch")
|
||||||
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
||||||
|
|
||||||
|
@ -267,44 +82,91 @@ class _Dataset(Dataset[Batch]):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
class Trainer(Generic[ConfigType, Batch], ABC):
|
@dataclass
|
||||||
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
class ModelItem:
|
||||||
self.config = config
|
name: str
|
||||||
self.clock = TrainingClock(
|
config: ModelConfig
|
||||||
dataset_length=self.dataset_length,
|
model: fl.Module
|
||||||
batch_size=config.training.batch_size,
|
learnable_parameters: list[nn.Parameter]
|
||||||
training_duration=config.training.duration,
|
|
||||||
evaluation_interval=config.training.evaluation_interval,
|
|
||||||
gradient_accumulation=config.training.gradient_accumulation,
|
ModelRegistry = dict[str, ModelItem]
|
||||||
lr_scheduler_interval=config.scheduler.update_interval,
|
ModuleT = TypeVar("ModuleT", bound=fl.Module)
|
||||||
|
|
||||||
|
|
||||||
|
def register_model():
|
||||||
|
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
|
||||||
|
name = func.__name__
|
||||||
|
model = func(self, config)
|
||||||
|
model = model.to(self.device, dtype=self.dtype)
|
||||||
|
if config.requires_grad is not None:
|
||||||
|
model.requires_grad_(requires_grad=config.requires_grad)
|
||||||
|
learnable_parameters = [param for param in model.parameters() if param.requires_grad]
|
||||||
|
self.models[name] = ModelItem(
|
||||||
|
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
||||||
)
|
)
|
||||||
self.callbacks = callbacks or []
|
setattr(self, name, self.models[name].model)
|
||||||
self.callbacks += self.default_callbacks()
|
return func(self, config)
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
CallbackRegistry = dict[str, Callback[Any]]
|
||||||
|
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
|
||||||
|
|
||||||
|
|
||||||
|
def register_callback():
|
||||||
|
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
|
||||||
|
name = func.__name__
|
||||||
|
callback = func(self, config)
|
||||||
|
self.callbacks[name] = callback
|
||||||
|
setattr(self, name, callback)
|
||||||
|
return func(self, config)
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
|
def __init__(self, config: ConfigType) -> None:
|
||||||
|
self._models: ModelRegistry = {}
|
||||||
|
self._callbacks: CallbackRegistry = {}
|
||||||
|
self.config = config
|
||||||
|
self._load_callbacks()
|
||||||
self._call_callbacks(event_name="on_init_begin")
|
self._call_callbacks(event_name="on_init_begin")
|
||||||
self.load_models()
|
self._load_models()
|
||||||
self.prepare_models()
|
|
||||||
self._call_callbacks(event_name="on_init_end")
|
self._call_callbacks(event_name="on_init_end")
|
||||||
|
|
||||||
def default_callbacks(self) -> list[Callback[Any]]:
|
@register_callback()
|
||||||
callbacks: list[Callback[Any]] = [
|
def clock(self, config: ClockConfig) -> TrainingClock:
|
||||||
ClockCallback(),
|
return TrainingClock(
|
||||||
GradientValueClipping(),
|
dataset_length=self.dataset_length,
|
||||||
GradientNormClipping(),
|
batch_size=self.config.training.batch_size,
|
||||||
DropoutCallback(),
|
training_duration=self.config.training.duration,
|
||||||
]
|
evaluation_interval=self.config.training.evaluation_interval,
|
||||||
|
gradient_accumulation=self.config.training.gradient_accumulation,
|
||||||
|
lr_scheduler_interval=self.config.scheduler.update_interval,
|
||||||
|
verbose=config.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
# look for any Callback that might be a property of the Trainer
|
@register_callback()
|
||||||
for attr_name in dir(self):
|
def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
|
||||||
if "__" in attr_name:
|
return GradientClipping(config)
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
@property
|
||||||
attr = getattr(self, attr_name)
|
def models(self) -> ModelRegistry:
|
||||||
except AssertionError:
|
return self._models
|
||||||
continue
|
|
||||||
if isinstance(attr, Callback):
|
@property
|
||||||
callbacks.append(cast(Callback[Any], attr))
|
def callbacks(self) -> CallbackRegistry:
|
||||||
return callbacks
|
return self._callbacks
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def device(self) -> Device:
|
def device(self) -> Device:
|
||||||
|
@ -320,38 +182,31 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> list[Parameter]:
|
def learnable_parameters(self) -> list[nn.Parameter]:
|
||||||
"""Returns a list of all parameters in all models"""
|
|
||||||
return [param for model in self.models.values() for param in model.parameters()]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def learnable_parameters(self) -> list[Parameter]:
|
|
||||||
"""Returns a list of learnable parameters in all models"""
|
"""Returns a list of learnable parameters in all models"""
|
||||||
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
return [param for item in self.models.values() for param in item.learnable_parameters]
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def optimizer_parameters(self) -> list[dict[str, Any]]:
|
def optimizer_parameters(self) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Returns a list of `dict`-s containing the params and optimizer options for each model.
|
Returns a list of `dict`-s containing the params and optimizer options for each model.
|
||||||
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
|
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
|
||||||
"""
|
"""
|
||||||
params: list[dict[str, Any]] = []
|
params: list[dict[str, Any]] = []
|
||||||
for model_name, model in self.models.items():
|
for item in self.models.values():
|
||||||
model_params = [param for param in model.parameters() if param.requires_grad]
|
config = item.config
|
||||||
model_config = self.config.models[model_name]
|
|
||||||
model_optim_conf: dict[str, Any] = {}
|
model_optim_conf: dict[str, Any] = {}
|
||||||
|
|
||||||
if model_config.learning_rate is not None:
|
if config.learning_rate is not None:
|
||||||
model_optim_conf["lr"] = model_config.learning_rate
|
model_optim_conf["lr"] = config.learning_rate
|
||||||
if model_config.weight_decay is not None:
|
if config.weight_decay is not None:
|
||||||
model_optim_conf["weight_decay"] = model_config.learning_rate
|
model_optim_conf["weight_decay"] = config.learning_rate
|
||||||
if model_config.betas is not None:
|
if config.betas is not None:
|
||||||
model_optim_conf["betas"] = model_config.learning_rate
|
model_optim_conf["betas"] = config.learning_rate
|
||||||
if model_config.eps is not None:
|
if config.eps is not None:
|
||||||
model_optim_conf["eps"] = model_config.learning_rate
|
model_optim_conf["eps"] = config.learning_rate
|
||||||
|
|
||||||
for param in model_params:
|
params.append({"params": item.learnable_parameters, **model_optim_conf})
|
||||||
params.append({"params": param, **model_optim_conf})
|
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
@ -363,17 +218,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
@property
|
@property
|
||||||
def gradients(self) -> list[Tensor]:
|
def gradients(self) -> list[Tensor]:
|
||||||
"""Returns a list of detached gradients for all learnable parameters in all models"""
|
"""Returns a list of detached gradients for all learnable parameters in all models"""
|
||||||
return [
|
return [param.grad.detach() for param in self.learnable_parameters if param.grad is not None]
|
||||||
param.grad.detach()
|
|
||||||
for model in self.models.values()
|
|
||||||
for param in model.parameters()
|
|
||||||
if param.grad is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_gradient_norm(self) -> float:
|
def total_gradient_norm(self) -> float:
|
||||||
"""Returns the total gradient norm for all learnable parameters in all models"""
|
"""Returns the total gradient norm for all learnable parameters in all models"""
|
||||||
return compute_grad_norm(parameters=self.parameters)
|
return compute_grad_norm(parameters=self.learnable_parameters)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def optimizer(self) -> Optimizer:
|
def optimizer(self) -> Optimizer:
|
||||||
|
@ -441,38 +291,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
|
|
||||||
return lr_scheduler
|
return lr_scheduler
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def models(self) -> dict[str, fl.Module]:
|
|
||||||
return self.load_models()
|
|
||||||
|
|
||||||
def set_models_to_train_mode(self) -> None:
|
|
||||||
for model in self.models.values():
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
def set_models_to_eval_mode(self) -> None:
|
|
||||||
for model in self.models.values():
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
def prepare_model(self, model_name: str) -> None:
|
|
||||||
model = self.models[model_name]
|
|
||||||
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
|
|
||||||
model.load_from_safetensors(tensors_path=checkpoint)
|
|
||||||
else:
|
|
||||||
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
|
|
||||||
if (requires_grad := self.config.models[model_name].requires_grad) is not None:
|
|
||||||
model.requires_grad_(requires_grad=requires_grad)
|
|
||||||
model.to(self.device)
|
|
||||||
model.zero_grad()
|
|
||||||
|
|
||||||
def prepare_models(self) -> None:
|
|
||||||
assert self.models, "No models found."
|
|
||||||
for model_name in self.models:
|
|
||||||
self.prepare_model(model_name=model_name)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_models(self) -> dict[str, fl.Module]:
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_item(self, index: int) -> Batch:
|
def get_item(self, index: int) -> Batch:
|
||||||
"""
|
"""
|
||||||
|
@ -563,7 +381,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
@scoped_seed(seed=get_training_seed)
|
@scoped_seed(seed=get_training_seed)
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
"""Train the model."""
|
"""Train the model."""
|
||||||
self.set_models_to_train_mode()
|
self.set_models_to_mode("train")
|
||||||
self._call_callbacks(event_name="on_train_begin")
|
self._call_callbacks(event_name="on_train_begin")
|
||||||
assert self.learnable_parameters, "There are no learnable parameters in the models."
|
assert self.learnable_parameters, "There are no learnable parameters in the models."
|
||||||
self.evaluate()
|
self.evaluate()
|
||||||
|
@ -581,12 +399,43 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
@scoped_seed(seed=get_evaluation_seed)
|
@scoped_seed(seed=get_evaluation_seed)
|
||||||
def evaluate(self) -> None:
|
def evaluate(self) -> None:
|
||||||
"""Evaluate the model."""
|
"""Evaluate the model."""
|
||||||
self.set_models_to_eval_mode()
|
self.set_models_to_mode(mode="eval")
|
||||||
self._call_callbacks(event_name="on_evaluate_begin")
|
self._call_callbacks(event_name="on_evaluate_begin")
|
||||||
self.compute_evaluation()
|
self.compute_evaluation()
|
||||||
self._call_callbacks(event_name="on_evaluate_end")
|
self._call_callbacks(event_name="on_evaluate_end")
|
||||||
self.set_models_to_train_mode()
|
self.set_models_to_mode(mode="train")
|
||||||
|
|
||||||
|
def set_models_to_mode(self, mode: Literal["train", "eval"]) -> None:
|
||||||
|
for item in self.models.values():
|
||||||
|
if mode == "train":
|
||||||
|
item.model.train()
|
||||||
|
elif mode == "eval":
|
||||||
|
item.model.eval()
|
||||||
|
|
||||||
def _call_callbacks(self, event_name: str) -> None:
|
def _call_callbacks(self, event_name: str) -> None:
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks.values():
|
||||||
getattr(callback, event_name)(self)
|
getattr(callback, event_name)(self)
|
||||||
|
|
||||||
|
def _load_callbacks(self) -> None:
|
||||||
|
for name, config in self.config:
|
||||||
|
if not isinstance(config, CallbackConfig):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
registered_callback = getattr(self, name)
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Callback {name} is in the config but not registered in the Trainer. Create a method with the @register_callback decorator."
|
||||||
|
)
|
||||||
|
assert callable(registered_callback)
|
||||||
|
registered_callback(config)
|
||||||
|
|
||||||
|
def _load_models(self) -> None:
|
||||||
|
for name, config in self.config.models.items():
|
||||||
|
try:
|
||||||
|
registered_model = getattr(self, name)
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {name} is in the config but not registered in the Trainer. Create a method with the @register_model decorator."
|
||||||
|
)
|
||||||
|
assert callable(registered_model)
|
||||||
|
registered_model(config)
|
||||||
|
|
|
@ -1,15 +1,13 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from functools import cached_property
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from refiners.training_utils.callback import Callback
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
from refiners.training_utils.trainer import Trainer
|
from refiners.training_utils.trainer import Trainer, register_callback
|
||||||
|
|
||||||
number = float | int
|
number = float | int
|
||||||
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
||||||
|
@ -64,7 +62,7 @@ class WandbLogger:
|
||||||
return self.wandb_run.name or "" # type: ignore
|
return self.wandb_run.name or "" # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class WandbConfig(BaseModel):
|
class WandbConfig(CallbackConfig):
|
||||||
"""
|
"""
|
||||||
Wandb configuration.
|
Wandb configuration.
|
||||||
|
|
||||||
|
@ -87,18 +85,16 @@ class WandbConfig(BaseModel):
|
||||||
anonymous: Literal["never", "allow", "must"] | None = None
|
anonymous: Literal["never", "allow", "must"] | None = None
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
|
||||||
|
|
||||||
|
|
||||||
AnyTrainer = Trainer[BaseConfig, Any]
|
AnyTrainer = Trainer[BaseConfig, Any]
|
||||||
|
|
||||||
|
|
||||||
class WandbCallback(Callback["TrainerWithWandb"]):
|
class WandbCallback(Callback["TrainerWithWandb"]):
|
||||||
epoch_losses: list[float]
|
def __init__(self, config: WandbConfig, /, trainer_config: dict[str, Any]) -> None:
|
||||||
iteration_losses: list[float]
|
self.config = config
|
||||||
|
self.epoch_losses: list[float] = []
|
||||||
def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
|
self.iteration_losses: list[float] = []
|
||||||
trainer.load_wandb()
|
self.logger = WandbLogger({**config.model_dump(), "config": trainer_config})
|
||||||
|
|
||||||
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
|
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
|
||||||
self.epoch_losses = []
|
self.epoch_losses = []
|
||||||
|
@ -131,19 +127,13 @@ class WandbMixin(ABC):
|
||||||
config: Any
|
config: Any
|
||||||
wandb_logger: WandbLogger
|
wandb_logger: WandbLogger
|
||||||
|
|
||||||
def load_wandb(self) -> None:
|
@register_callback()
|
||||||
wandb_config = getattr(self.config, "wandb", None)
|
def wandb(self, config: WandbConfig) -> WandbCallback:
|
||||||
assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
|
return WandbCallback(config, trainer_config=self.config.model_dump())
|
||||||
init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
|
|
||||||
self.wandb_logger = WandbLogger(init_config=init_config)
|
|
||||||
|
|
||||||
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
|
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
|
||||||
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
|
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
|
||||||
self.wandb_logger.log(data=data, step=self.clock.step)
|
self.wandb.logger.log(data=data, step=self.clock.step)
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def wandb_callback(self) -> WandbCallback:
|
|
||||||
return WandbCallback()
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
|
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
[models.mock_model]
|
[models.mock_model]
|
||||||
requires_grad = true
|
requires_grad = true
|
||||||
|
|
||||||
|
[clock]
|
||||||
|
verbose = false
|
||||||
|
|
||||||
|
[gradient_clipping]
|
||||||
|
clip_grad_norm = 1.0
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
duration = "100:epoch"
|
duration = "100:epoch"
|
||||||
|
@ -9,7 +14,6 @@ device = "cpu"
|
||||||
dtype = "float32"
|
dtype = "float32"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
gradient_accumulation = "4:step"
|
gradient_accumulation = "4:step"
|
||||||
clip_grad_norm = 1.0
|
|
||||||
evaluation_interval = "5:epoch"
|
evaluation_interval = "5:epoch"
|
||||||
evaluation_seed = 1
|
evaluation_seed = 1
|
||||||
|
|
||||||
|
@ -21,6 +25,3 @@ learning_rate = 1
|
||||||
scheduler_type = "ConstantLR"
|
scheduler_type = "ConstantLR"
|
||||||
update_interval = "1:step"
|
update_interval = "1:step"
|
||||||
warmup = "20:step"
|
warmup = "20:step"
|
||||||
|
|
||||||
[dropout]
|
|
||||||
dropout_probability = 0.0
|
|
||||||
|
|
|
@ -5,12 +5,17 @@ learning_rate = 1e-5
|
||||||
[models.mock_model2]
|
[models.mock_model2]
|
||||||
requires_grad = true
|
requires_grad = true
|
||||||
|
|
||||||
|
[clock]
|
||||||
|
verbose = false
|
||||||
|
|
||||||
|
[gradient_clipping]
|
||||||
|
clip_grad_norm = 1.0
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
duration = "100:epoch"
|
duration = "100:epoch"
|
||||||
seed = 0
|
seed = 0
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
gradient_accumulation = "4:step"
|
gradient_accumulation = "4:step"
|
||||||
clip_grad_norm = 1.0
|
|
||||||
evaluation_interval = "5:epoch"
|
evaluation_interval = "5:epoch"
|
||||||
evaluation_seed = 1
|
evaluation_seed = 1
|
||||||
|
|
||||||
|
@ -21,6 +26,3 @@ learning_rate = 1
|
||||||
[scheduler]
|
[scheduler]
|
||||||
scheduler_type = "ConstantLR"
|
scheduler_type = "ConstantLR"
|
||||||
update_interval = "1:step"
|
update_interval = "1:step"
|
||||||
|
|
||||||
[dropout]
|
|
||||||
dropout_probability = 0.0
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
@ -10,13 +9,15 @@ from torch.optim import SGD
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import norm
|
from refiners.fluxion.utils import norm
|
||||||
from refiners.training_utils.config import BaseConfig, TimeUnit
|
from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number
|
||||||
|
from refiners.training_utils.config import BaseConfig, ModelConfig
|
||||||
from refiners.training_utils.trainer import (
|
from refiners.training_utils.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingClock,
|
TrainingClock,
|
||||||
WarmupScheduler,
|
WarmupScheduler,
|
||||||
count_learnable_parameters,
|
count_learnable_parameters,
|
||||||
human_readable_number,
|
human_readable_number,
|
||||||
|
register_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,13 +56,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
targets=torch.cat([b.targets for b in batch]),
|
targets=torch.cat([b.targets for b in batch]),
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@register_model()
|
||||||
def mock_model(self) -> MockModel:
|
def mock_model(self, config: ModelConfig) -> MockModel:
|
||||||
return MockModel()
|
return MockModel()
|
||||||
|
|
||||||
def load_models(self) -> dict[str, fl.Module]:
|
|
||||||
return {"mock_model": self.mock_model}
|
|
||||||
|
|
||||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||||
self.step_counter += 1
|
self.step_counter += 1
|
||||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||||
|
@ -217,17 +215,14 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
||||||
|
|
||||||
|
|
||||||
class MockTrainerWith2Models(MockTrainer):
|
class MockTrainerWith2Models(MockTrainer):
|
||||||
@cached_property
|
@register_model()
|
||||||
def mock_model1(self) -> MockModel:
|
def mock_model1(self, config: ModelConfig) -> MockModel:
|
||||||
return MockModel()
|
return MockModel()
|
||||||
|
|
||||||
@cached_property
|
@register_model()
|
||||||
def mock_model2(self) -> MockModel:
|
def mock_model2(self, config: ModelConfig) -> MockModel:
|
||||||
return MockModel()
|
return MockModel()
|
||||||
|
|
||||||
def load_models(self) -> dict[str, fl.Module]:
|
|
||||||
return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
|
|
||||||
|
|
||||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||||
self.step_counter += 1
|
self.step_counter += 1
|
||||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||||
|
@ -246,7 +241,5 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M
|
||||||
|
|
||||||
|
|
||||||
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
|
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
|
||||||
assert (
|
assert len(mock_trainer_2_models.optimizer.param_groups) == 2
|
||||||
len(mock_trainer_2_models.optimizer.param_groups) == 12
|
|
||||||
) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
|
|
||||||
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
|
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
|
||||||
|
|
Loading…
Reference in a new issue