From d6546c90268d00483f87bc66d9f79fc5315523c7 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Mon, 12 Feb 2024 08:28:41 +0000 Subject: [PATCH] add @register_model and @register_callback decorators Refactor ClockTrainer to include Callback --- src/refiners/training_utils/__init__.py | 30 +- src/refiners/training_utils/callback.py | 105 +---- src/refiners/training_utils/clock.py | 193 ++++++++ src/refiners/training_utils/common.py | 103 +++++ src/refiners/training_utils/config.py | 68 +-- src/refiners/training_utils/dropout.py | 200 -------- .../training_utils/gradient_clipping.py | 52 +++ src/refiners/training_utils/trainer.py | 437 ++++++------------ src/refiners/training_utils/wandb.py | 34 +- tests/training_utils/mock_config.toml | 9 +- .../training_utils/mock_config_2_models.toml | 10 +- tests/training_utils/test_trainer.py | 27 +- 12 files changed, 565 insertions(+), 703 deletions(-) create mode 100644 src/refiners/training_utils/clock.py create mode 100644 src/refiners/training_utils/common.py delete mode 100644 src/refiners/training_utils/dropout.py create mode 100644 src/refiners/training_utils/gradient_clipping.py diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index 617a8b4..8850b34 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -4,8 +4,20 @@ from importlib.metadata import requires from packaging.requirements import Requirement -from refiners.training_utils.config import BaseConfig -from refiners.training_utils.trainer import Trainer +from refiners.training_utils.callback import Callback, CallbackConfig +from refiners.training_utils.clock import ClockConfig +from refiners.training_utils.config import ( + BaseConfig, + ModelConfig, + OptimizerConfig, + Optimizers, + SchedulerConfig, + SchedulerType, + TrainingConfig, +) +from refiners.training_utils.gradient_clipping import GradientClippingConfig +from refiners.training_utils.trainer import Trainer, register_callback, register_model +from refiners.training_utils.wandb import WandbConfig, WandbMixin refiners_requires = requires("refiners") assert refiners_requires is not None @@ -29,4 +41,18 @@ for dep in refiners_requires: __all__ = [ "Trainer", "BaseConfig", + "ModelConfig", + "register_callback", + "register_model", + "Callback", + "CallbackConfig", + "WandbMixin", + "WandbConfig", + "SchedulerConfig", + "OptimizerConfig", + "TrainingConfig", + "ClockConfig", + "GradientClippingConfig", + "Optimizers", + "SchedulerType", ] diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 2d7c0b2..c6275b0 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -1,43 +1,22 @@ -from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from loguru import logger -from torch import tensor -from torch.nn import Parameter +from pydantic import BaseModel, ConfigDict if TYPE_CHECKING: from refiners.training_utils.config import BaseConfig from refiners.training_utils.trainer import Trainer -__all__ = [ - "Callback", - "GradientNormClipping", - "GradientValueClipping", - "ClockCallback", -] +T = TypeVar("T", bound="Trainer[BaseConfig, Any]") -def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None: +class CallbackConfig(BaseModel): """ - Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`. 
+ Base configuration for a callback. + + For your callback to be properly configured, you should inherit from this class and add your own configuration. """ - gradients = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to clip." - clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1) - for gradient in gradients: - gradient.mul_(other=clip_coefficient) # type: ignore - -def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None: - """ - Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`. - """ - gradients = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to clip." - for gradient in gradients: - gradient.clamp_(min=-clip_value, max=clip_value) - - -T = TypeVar("T") + model_config = ConfigDict(extra="forbid") class Callback(Generic[T]): @@ -97,71 +76,3 @@ class Callback(Generic[T]): def on_checkpoint_save(self, trainer: T) -> None: ... - - -class ClockCallback(Callback["Trainer[BaseConfig, Any]"]): - def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.reset() - logger.info( - ( - "Starting training for a total of: " - f"{trainer.clock.num_steps} steps, " - f"{trainer.clock.num_epochs} epochs, " - f"{trainer.clock.num_iterations} iterations." - ) - ) - trainer.clock.start_timer() - - def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.stop_timer() - logger.info( - ( - "Training took: " - f"{trainer.clock.time_elapsed} seconds, " - f"{trainer.clock.iteration} iterations, " - f"{trainer.clock.epoch} epochs, " - f"{trainer.clock.step} steps." - ) - ) - - def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - logger.info(f"Epoch {trainer.clock.epoch} started.") - - def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.epoch += 1 - trainer.clock.num_batches_processed = 0 - - def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - logger.info(f"Step {trainer.clock.step} started.") - - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.step += 1 - trainer.clock.num_batches_processed += 1 - trainer.clock.num_minibatches_processed += 1 - - def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - logger.info(f"Iteration {trainer.clock.iteration} ended.") - trainer.clock.iteration += 1 - trainer.clock.num_minibatches_processed = 0 - - def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - logger.info("Evaluation started.") - - def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - logger.info("Evaluation ended.") - - -class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - clip_norm = trainer.config.training.clip_grad_norm - if clip_norm is not None: - clip_gradient_norm( - parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm - ) - - -class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]): - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - clip_value = trainer.config.training.clip_grad_value - if clip_value is not None: - clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value) diff --git a/src/refiners/training_utils/clock.py 
b/src/refiners/training_utils/clock.py new file mode 100644 index 0000000..8eb3936 --- /dev/null +++ b/src/refiners/training_utils/clock.py @@ -0,0 +1,193 @@ +import time +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from refiners.training_utils.callback import Callback, CallbackConfig +from refiners.training_utils.common import TimeUnit, TimeValue + +if TYPE_CHECKING: + from refiners.training_utils.config import BaseConfig + from refiners.training_utils.trainer import Trainer + + +from loguru import logger +from torch import Tensor + + +class ClockConfig(CallbackConfig): + verbose: bool = True + + +class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): + def __init__( + self, + dataset_length: int, + batch_size: int, + training_duration: TimeValue, + gradient_accumulation: TimeValue, + evaluation_interval: TimeValue, + lr_scheduler_interval: TimeValue, + verbose: bool = True, + ) -> None: + self.dataset_length = dataset_length + self.batch_size = batch_size + self.training_duration = training_duration + self.gradient_accumulation = gradient_accumulation + self.evaluation_interval = evaluation_interval + self.lr_scheduler_interval = lr_scheduler_interval + self.verbose = verbose + self.num_batches_per_epoch = dataset_length // batch_size + self.start_time = None + self.end_time = None + self.step = 0 + self.epoch = 0 + self.iteration = 0 + self.num_batches_processed = 0 + self.num_minibatches_processed = 0 + self.loss: Tensor | None = None + + @cached_property + def unit_to_steps(self) -> dict[TimeUnit, int]: + iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1 + return { + TimeUnit.STEP: 1, + TimeUnit.EPOCH: self.num_batches_per_epoch, + TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor, + } + + def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int: + return number * self.unit_to_steps[unit] + + def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int: + return steps // self.unit_to_steps[unit] + + def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int: + number, unit = time_value["number"], time_value["unit"] + steps = self.convert_time_unit_to_steps(number=number, unit=unit) + return self.convert_steps_to_time_unit(steps=steps, unit=target_unit) + + @cached_property + def num_epochs(self) -> int: + return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH) + + @cached_property + def num_iterations(self) -> int: + return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION) + + @cached_property + def num_steps(self) -> int: + return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP) + + @cached_property + def num_step_per_iteration(self) -> int: + return self.convert_time_unit_to_steps( + number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"] + ) + + @cached_property + def num_step_per_evaluation(self) -> int: + return self.convert_time_unit_to_steps( + number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] + ) + + def reset(self) -> None: + self.start_time = None + self.end_time = None + self.step = 0 + self.epoch = 0 + self.iteration = 0 + self.num_batches_processed = 0 + self.num_minibatches_processed = 0 + + def start_timer(self) -> None: + self.start_time = time.time() + + def stop_timer(self) -> None: + self.end_time = time.time() + + 
@property + def time_elapsed(self) -> int: + assert self.start_time is not None, "Timer has not been started yet." + return int(time.time() - self.start_time) + + @cached_property + def evaluation_interval_steps(self) -> int: + return self.convert_time_unit_to_steps( + number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] + ) + + @cached_property + def lr_scheduler_interval_steps(self) -> int: + return self.convert_time_unit_to_steps( + number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"] + ) + + @property + def is_optimizer_step(self) -> bool: + return self.num_minibatches_processed == self.num_step_per_iteration + + @property + def is_lr_scheduler_step(self) -> bool: + return self.step % self.lr_scheduler_interval_steps == 0 + + @property + def done(self) -> bool: + return self.step >= self.num_steps + + @property + def is_evaluation_step(self) -> bool: + return self.step % self.evaluation_interval_steps == 0 + + def log(self, message: str, /) -> None: + if self.verbose: + logger.info(message) + + def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.reset() + self.log( + ( + "Starting training for a total of: " + f"{trainer.clock.num_steps} steps, " + f"{trainer.clock.num_epochs} epochs, " + f"{trainer.clock.num_iterations} iterations." + ) + ) + trainer.clock.start_timer() + + def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.stop_timer() + self.log( + ( + "Training took: " + f"{trainer.clock.time_elapsed} seconds, " + f"{trainer.clock.iteration} iterations, " + f"{trainer.clock.epoch} epochs, " + f"{trainer.clock.step} steps." + ) + ) + + def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.log(f"Epoch {trainer.clock.epoch} started.") + + def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.epoch += 1 + trainer.clock.num_batches_processed = 0 + + def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.log(f"Step {trainer.clock.step} started.") + + def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.step += 1 + trainer.clock.num_batches_processed += 1 + trainer.clock.num_minibatches_processed += 1 + + def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.log(f"Iteration {trainer.clock.iteration} ended.") + trainer.clock.iteration += 1 + trainer.clock.num_minibatches_processed = 0 + + def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.log("Evaluation started.") + + def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.log("Evaluation ended.") diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py new file mode 100644 index 0000000..eff12aa --- /dev/null +++ b/src/refiners/training_utils/common.py @@ -0,0 +1,103 @@ +import random +from enum import Enum +from functools import wraps +from typing import Any, Callable, Iterable + +import numpy as np +import torch +from loguru import logger +from torch import Tensor, cuda, nn +from typing_extensions import TypedDict + +from refiners.fluxion.utils import manual_seed + + +def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float: + """ + Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value. 
+ """ + gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None] + assert gradients, "The model has no gradients to compute the norm." + total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore + return total_norm # type: ignore + + +def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int: + return sum(p.numel() for p in parameters if p.requires_grad) + + +def human_readable_number(number: int) -> str: + float_number = float(number) + for unit in ["", "K", "M", "G", "T", "P"]: + if abs(float_number) < 1000: + return f"{float_number:.1f}{unit}" + float_number /= 1000 + return f"{float_number:.1f}E" + + +def seed_everything(seed: int | None = None) -> None: + if seed is None: + seed = random.randint(0, 2**32 - 1) + logger.info(f"Using random seed: {seed}") + random.seed(a=seed) + np.random.seed(seed=seed) + manual_seed(seed=seed) + cuda.manual_seed_all(seed=seed) + + +def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]: + """ + Decorator for setting a random seed within the scope of a function. + + This decorator sets the random seed for Python's built-in `random` module, + `numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the + function is executed, it restores the state of the random number generators + to what it was before the function was called. This is useful for ensuring + reproducibility for specific parts of the code without affecting randomness + elsewhere. + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def inner_wrapper(*args: Any, **kwargs: Any) -> Any: + random_state = random.getstate() + numpy_state = np.random.get_state() + torch_state = torch.get_rng_state() + cuda_torch_state = cuda.get_rng_state() + actual_seed = seed(*args) if callable(seed) else seed + seed_everything(seed=actual_seed) + result = func(*args, **kwargs) + random.setstate(random_state) + np.random.set_state(numpy_state) + torch.set_rng_state(torch_state) + cuda.set_rng_state(cuda_torch_state) + return result + + return inner_wrapper + + return decorator + + +class TimeUnit(Enum): + STEP = "step" + EPOCH = "epoch" + ITERATION = "iteration" + DEFAULT = "step" + + +class TimeValue(TypedDict): + number: int + unit: TimeUnit + + +def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue: + match value: + case str(value_str): + number, unit = value_str.split(sep=":") + return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())} + case int(number): + return {"number": number, "unit": TimeUnit.DEFAULT} + case {"number": int(number), "unit": str(unit)}: + return {"number": number, "unit": TimeUnit(value=unit.lower())} + case _: + raise ValueError(f"Unsupported value format: {value}") diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index f333f6d..9a4c9a5 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -9,50 +9,16 @@ from prodigyopt import Prodigy # type: ignore from pydantic import BaseModel, ConfigDict, validator from torch import Tensor from torch.optim import SGD, Adam, AdamW, Optimizer -from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version -import refiners.fluxion.layers as fl -from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout +from refiners.training_utils.clock import ClockConfig 
+from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field +from refiners.training_utils.gradient_clipping import GradientClippingConfig # PyTorch optimizer parameters type # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced # See https://github.com/pytorch/pytorch/pull/111114 ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]] -__all__ = [ - "parse_number_unit_field", - "TimeUnit", - "TimeValue", - "TrainingConfig", - "OptimizerConfig", - "Optimizers", -] - - -class TimeUnit(Enum): - STEP = "step" - EPOCH = "epoch" - ITERATION = "iteration" - DEFAULT = "step" - - -class TimeValue(TypedDict): - number: int - unit: TimeUnit - - -def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue: - match value: - case str(value_str): - number, unit = value_str.split(sep=":") - return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())} - case int(number): - return {"number": number, "unit": TimeUnit.DEFAULT} - case {"number": int(number), "unit": str(unit)}: - return {"number": number, "unit": TimeUnit(value=unit.lower())} - case _: - raise ValueError(f"Unsupported value format: {value}") - class TrainingConfig(BaseModel): device: str = "cpu" @@ -61,8 +27,6 @@ class TrainingConfig(BaseModel): seed: int = 0 batch_size: int = 1 gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP} - clip_grad_norm: float | None = None - clip_grad_value: float | None = None evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION} evaluation_seed: int = 0 @@ -195,29 +159,6 @@ class ModelConfig(BaseModel): model_config = ConfigDict(extra="forbid") -class GyroDropoutConfig(BaseModel): - total_subnetworks: int = 512 - concurrent_subnetworks: int = 64 - iters_per_epoch: int = 512 - num_features_threshold: float = 5e5 - - model_config = ConfigDict(extra="forbid") - - -class DropoutConfig(BaseModel): - dropout_probability: float = 0.0 - gyro_dropout: GyroDropoutConfig | None = None - - model_config = ConfigDict(extra="forbid") - - def apply_dropout(self, model: fl.Chain) -> None: - if self.dropout_probability > 0.0: - if self.gyro_dropout is not None: - apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump()) - else: - apply_dropout(module=model, probability=self.dropout_probability) - - T = TypeVar("T", bound="BaseConfig") @@ -226,7 +167,8 @@ class BaseConfig(BaseModel): training: TrainingConfig optimizer: OptimizerConfig scheduler: SchedulerConfig - dropout: DropoutConfig + clock: ClockConfig = ClockConfig() + gradient_clipping: GradientClippingConfig = GradientClippingConfig() model_config = ConfigDict(extra="forbid") diff --git a/src/refiners/training_utils/dropout.py b/src/refiners/training_utils/dropout.py deleted file mode 100644 index 37c188e..0000000 --- a/src/refiners/training_utils/dropout.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import TYPE_CHECKING, Any, TypeVar - -from torch import Tensor, cat, rand, randint -from torch.nn import Dropout as TorchDropout - -import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.adapter import Adapter -from refiners.training_utils.callback import Callback - -if TYPE_CHECKING: - from refiners.training_utils.config import BaseConfig - from refiners.training_utils.trainer import Trainer - - -__all__ = ["Dropout", "GyroDropout", "DropoutCallback"] - - -class Dropout(TorchDropout, fl.Module): - def __init__(self, probability: float = 0.5, inplace: bool = False) 
-> None: - super().__init__(p=probability, inplace=inplace) - - -class GyroDropout(fl.Module): - """ - GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training. - It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training. - This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional - dropout. - - Parameters: - ----------- - total_subnetworks: - The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks - that are precomputed and stored. - - concurrent_subnetworks: - The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of - masks from the precomputed set is used to dropout different portions of the input. - - dropout_probability: float, optional (default=0.5) - The probability that an element will be zeroed by the dropout. - - iters_per_epoch: - Number of iterations per epoch, used to determine how often the masks should be updated. - - num_features_threshold: - If the number of features in the input is greater than this threshold, dropout is skipped. This is because - gyro dropout mask size vram usage is proportional to the number of features in the input. - """ - - def __init__( - self, - total_subnetworks: int, - concurrent_subnetworks: int, - dropout_probability: float = 0.5, - iters_per_epoch: int = 1, - num_features_threshold: float = 5e5, - ) -> None: - super().__init__() - assert ( - iters_per_epoch >= total_subnetworks - ), "The number of iterations per epoch must be greater than the number of masks" - self.dropout_probability = dropout_probability - self.iters_per_epoch = iters_per_epoch - self.total_subnetworks = total_subnetworks - self.concurrent_subnetworks = concurrent_subnetworks - self.scale = 1 / (1 - self.dropout_probability) - self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks - self.preselected_masks: Tensor | None = None - self.dropout_mask = None - self.training_step = 0 - self.num_features_threshold = num_features_threshold - self.skip_high_num_features = False - - def forward(self, x: Tensor) -> Tensor: - if not self.training: - return x - if self.skip_high_num_features: - return self.basic_dropout(x) - if self.training_step == 0: - num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1] - if num_features > self.num_features_threshold: - self.skip_high_num_features = True - self.basic_dropout = Dropout(probability=self.dropout_probability) - return self.basic_dropout(x) - self.init_masks(x=x) - - if self.training_step % self.mask_update_interval == 0: - self.update_dropout_mask(x=x) - - self.training_step += 1 - - return x * self.dropout_mask * self.scale - - def init_masks(self, x: Tensor) -> None: - if x.dim() == 2: - self.preselected_masks = ( - rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability - ) - if x.dim() == 3: - self.preselected_masks = ( - rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability - ) - - assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions" - self.preselected_masks = self.preselected_masks.float() - - def update_dropout_mask(self, x: Tensor) -> None: - assert self.preselected_masks is not None - indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device) - selected_masks = 
self.preselected_masks[indices] - - repeat_factor = x.shape[0] // self.concurrent_subnetworks - remaining = x.shape[0] % self.concurrent_subnetworks - repeated_masks = [selected_masks] * repeat_factor - if remaining > 0: - repeated_masks.append(selected_masks[:remaining]) - final_masks = cat(tensors=repeated_masks, dim=0) - - if x.dim() == 2: - self.dropout_mask = final_masks - if x.dim() == 3: - self.dropout_mask = final_masks.expand_as(x) - - -class DropoutAdapter(fl.Chain, Adapter[fl.Linear]): - def __init__(self, target: fl.Linear, probability: float = 0.5): - with self.setup_adapter(target): - super().__init__(target, Dropout(probability=probability)) - - -class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]): - def __init__( - self, - target: fl.Linear, - probability: float = 0.5, - total_subnetworks: int = 512, - concurrent_subnetworks: int = 64, - iters_per_epoch: int = 512, - num_features_threshold: float = 5e5, - ) -> None: - self.probability = probability - self.total_subnetworks = total_subnetworks - self.concurrent_subnetworks = concurrent_subnetworks - self.iters_per_epoch = iters_per_epoch - - with self.setup_adapter(target): - super().__init__( - target, - GyroDropout( - total_subnetworks=total_subnetworks, - concurrent_subnetworks=concurrent_subnetworks, - dropout_probability=probability, - iters_per_epoch=iters_per_epoch, - num_features_threshold=num_features_threshold, - ), - ) - - -def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None: - for linear, parent in module.walk(fl.Linear): - if not linear.weight.requires_grad: - continue - assert not ( - isinstance(parent, Dropout) or isinstance(parent, GyroDropout) - ), f"{linear} already has a dropout layer" - DropoutAdapter(target=linear, probability=probability).inject(parent) - - -def apply_gyro_dropout( - module: fl.Chain, - probability: float = 0.5, - total_subnetworks: int = 32, - concurrent_subnetworks: int = 16, - iters_per_epoch: int = 32, -) -> None: - for linear, parent in module.walk(fl.Linear): - if not linear.weight.requires_grad: - continue - assert not ( - isinstance(parent, Dropout) or isinstance(parent, GyroDropout) - ), f"{linear} already has a dropout layer" - GyroDropoutAdapter( - target=linear, - probability=probability, - total_subnetworks=total_subnetworks, - concurrent_subnetworks=concurrent_subnetworks, - iters_per_epoch=iters_per_epoch, - ).inject(parent) - - -ConfigType = TypeVar("ConfigType", bound="BaseConfig") - - -class DropoutCallback(Callback["Trainer[ConfigType, Any]"]): - def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None: - dropout_config = trainer.config.dropout - chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)] - for model in chain_models: - dropout_config.apply_dropout(model=model) diff --git a/src/refiners/training_utils/gradient_clipping.py b/src/refiners/training_utils/gradient_clipping.py new file mode 100644 index 0000000..28701c8 --- /dev/null +++ b/src/refiners/training_utils/gradient_clipping.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING, Any, Iterable + +import torch +from torch import nn + +from refiners.training_utils.callback import Callback, CallbackConfig + +if TYPE_CHECKING: + from refiners.training_utils.config import BaseConfig + from refiners.training_utils.trainer import Trainer + + +def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None: + """ + Clips the gradient norm of the parameters of a given model similar to 
`clip_grad_norm_`. + """ + gradients = [p.grad.detach() for p in parameters if p.grad is not None] + assert gradients, "The model has no gradients to clip." + clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1) + for gradient in gradients: + gradient.mul_(other=clip_coefficient) # type: ignore + + +def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None: + """ + Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`. + """ + gradients = [p.grad.detach() for p in parameters if p.grad is not None] + assert gradients, "The model has no gradients to clip." + for gradient in gradients: + gradient.clamp_(min=-clip_value, max=clip_value) + + +class GradientClippingConfig(CallbackConfig): + clip_grad_norm: float | None = None + clip_grad_value: float | None = None + + +class GradientClipping(Callback["Trainer[BaseConfig, Any]"]): + def __init__(self, config: GradientClippingConfig) -> None: + self.config = config + + def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + clip_norm = self.config.clip_grad_norm + if clip_norm is not None: + clip_gradient_norm( + parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm + ) + + clip_value = self.config.clip_grad_value + if clip_value is not None: + clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 433d10b..cb3d2b6 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -1,15 +1,12 @@ -import random -import time from abc import ABC, abstractmethod, abstractproperty +from dataclasses import dataclass from functools import cached_property, wraps -from typing import Any, Callable, Generic, Iterable, TypeVar, cast +from typing import Any, Callable, Generic, Literal, TypeVar, cast -import numpy as np import torch from loguru import logger -from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack +from torch import Tensor, device as Device, dtype as DType, nn from torch.autograd import backward -from torch.nn import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import ( CosineAnnealingLR, @@ -27,73 +24,20 @@ from torch.optim.lr_scheduler import ( from torch.utils.data import DataLoader, Dataset from refiners.fluxion import layers as fl -from refiners.fluxion.utils import manual_seed, no_grad +from refiners.fluxion.utils import no_grad from refiners.training_utils.callback import ( Callback, - ClockCallback, - GradientNormClipping, - GradientValueClipping, + CallbackConfig, ) -from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue -from refiners.training_utils.dropout import DropoutCallback - -__all__ = ["seed_everything", "scoped_seed", "Trainer"] - - -def count_learnable_parameters(parameters: Iterable[Parameter]) -> int: - return sum(p.numel() for p in parameters if p.requires_grad) - - -def human_readable_number(number: int) -> str: - float_number = float(number) - for unit in ["", "K", "M", "G", "T", "P"]: - if abs(float_number) < 1000: - return f"{float_number:.1f}{unit}" - float_number /= 1000 - return f"{float_number:.1f}E" - - -def seed_everything(seed: int | None = None) -> None: - if seed is None: - seed = random.randint(0, 2**32 - 1) - logger.info(f"Using random seed: {seed}") - 
random.seed(a=seed) - np.random.seed(seed=seed) - manual_seed(seed=seed) - cuda.manual_seed_all(seed=seed) - - -def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]: - """ - Decorator for setting a random seed within the scope of a function. - - This decorator sets the random seed for Python's built-in `random` module, - `numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the - function is executed, it restores the state of the random number generators - to what it was before the function was called. This is useful for ensuring - reproducibility for specific parts of the code without affecting randomness - elsewhere. - """ - - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - def inner_wrapper(*args: Any, **kwargs: Any) -> Any: - random_state = random.getstate() - numpy_state = np.random.get_state() - torch_state = get_rng_state() - cuda_torch_state = cuda.get_rng_state() - actual_seed = seed(*args) if callable(seed) else seed - seed_everything(seed=actual_seed) - result = func(*args, **kwargs) - random.setstate(random_state) - np.random.set_state(numpy_state) - set_rng_state(torch_state) - cuda.set_rng_state(cuda_torch_state) - return result - - return inner_wrapper - - return decorator +from refiners.training_utils.clock import ClockConfig, TrainingClock +from refiners.training_utils.common import ( + compute_grad_norm, + count_learnable_parameters, + human_readable_number, + scoped_seed, +) +from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType +from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig class WarmupScheduler(LRScheduler): @@ -117,135 +61,6 @@ class WarmupScheduler(LRScheduler): self._step_count += 1 -class TrainingClock: - def __init__( - self, - dataset_length: int, - batch_size: int, - training_duration: TimeValue, - gradient_accumulation: TimeValue, - evaluation_interval: TimeValue, - lr_scheduler_interval: TimeValue, - ) -> None: - self.dataset_length = dataset_length - self.batch_size = batch_size - self.training_duration = training_duration - self.gradient_accumulation = gradient_accumulation - self.evaluation_interval = evaluation_interval - self.lr_scheduler_interval = lr_scheduler_interval - self.num_batches_per_epoch = dataset_length // batch_size - self.start_time = None - self.end_time = None - self.step = 0 - self.epoch = 0 - self.iteration = 0 - self.num_batches_processed = 0 - self.num_minibatches_processed = 0 - self.loss: Tensor | None = None - - @cached_property - def unit_to_steps(self) -> dict[TimeUnit, int]: - iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1 - return { - TimeUnit.STEP: 1, - TimeUnit.EPOCH: self.num_batches_per_epoch, - TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor, - } - - def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int: - return number * self.unit_to_steps[unit] - - def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int: - return steps // self.unit_to_steps[unit] - - def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int: - number, unit = time_value["number"], time_value["unit"] - steps = self.convert_time_unit_to_steps(number=number, unit=unit) - return self.convert_steps_to_time_unit(steps=steps, unit=target_unit) - - @cached_property - def num_epochs(self) -> int: - return 
self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH) - - @cached_property - def num_iterations(self) -> int: - return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION) - - @cached_property - def num_steps(self) -> int: - return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP) - - @cached_property - def num_step_per_iteration(self) -> int: - return self.convert_time_unit_to_steps( - number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"] - ) - - @cached_property - def num_step_per_evaluation(self) -> int: - return self.convert_time_unit_to_steps( - number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] - ) - - def reset(self) -> None: - self.start_time = None - self.end_time = None - self.step = 0 - self.epoch = 0 - self.iteration = 0 - self.num_batches_processed = 0 - self.num_minibatches_processed = 0 - - def start_timer(self) -> None: - self.start_time = time.time() - - def stop_timer(self) -> None: - self.end_time = time.time() - - @property - def time_elapsed(self) -> int: - assert self.start_time is not None, "Timer has not been started yet." - return int(time.time() - self.start_time) - - @cached_property - def evaluation_interval_steps(self) -> int: - return self.convert_time_unit_to_steps( - number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] - ) - - @cached_property - def lr_scheduler_interval_steps(self) -> int: - return self.convert_time_unit_to_steps( - number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"] - ) - - @property - def is_optimizer_step(self) -> bool: - return self.num_minibatches_processed == self.num_step_per_iteration - - @property - def is_lr_scheduler_step(self) -> bool: - return self.step % self.lr_scheduler_interval_steps == 0 - - @property - def done(self) -> bool: - return self.step >= self.num_steps - - @property - def is_evaluation_step(self) -> bool: - return self.step % self.evaluation_interval_steps == 0 - - -def compute_grad_norm(parameters: Iterable[Parameter]) -> float: - """ - Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value. - """ - gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to compute the norm." 
- total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore - return total_norm # type: ignore - - Batch = TypeVar("Batch") ConfigType = TypeVar("ConfigType", bound=BaseConfig) @@ -267,44 +82,91 @@ class _Dataset(Dataset[Batch]): return self.length +@dataclass +class ModelItem: + name: str + config: ModelConfig + model: fl.Module + learnable_parameters: list[nn.Parameter] + + +ModelRegistry = dict[str, ModelItem] +ModuleT = TypeVar("ModuleT", bound=fl.Module) + + +def register_model(): + def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT: + @wraps(func) + def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module: + name = func.__name__ + model = func(self, config) + model = model.to(self.device, dtype=self.dtype) + if config.requires_grad is not None: + model.requires_grad_(requires_grad=config.requires_grad) + learnable_parameters = [param for param in model.parameters() if param.requires_grad] + self.models[name] = ModelItem( + name=name, config=config, model=model, learnable_parameters=learnable_parameters + ) + setattr(self, name, self.models[name].model) + return func(self, config) + + return wrapper # type: ignore + + return decorator + + +CallbackRegistry = dict[str, Callback[Any]] +CallbackT = TypeVar("CallbackT", bound=Callback[Any]) + + +def register_callback(): + def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT: + @wraps(func) + def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT: + name = func.__name__ + callback = func(self, config) + self.callbacks[name] = callback + setattr(self, name, callback) + return func(self, config) + + return wrapper # type: ignore + + return decorator + + class Trainer(Generic[ConfigType, Batch], ABC): - def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: + def __init__(self, config: ConfigType) -> None: + self._models: ModelRegistry = {} + self._callbacks: CallbackRegistry = {} self.config = config - self.clock = TrainingClock( - dataset_length=self.dataset_length, - batch_size=config.training.batch_size, - training_duration=config.training.duration, - evaluation_interval=config.training.evaluation_interval, - gradient_accumulation=config.training.gradient_accumulation, - lr_scheduler_interval=config.scheduler.update_interval, - ) - self.callbacks = callbacks or [] - self.callbacks += self.default_callbacks() + self._load_callbacks() self._call_callbacks(event_name="on_init_begin") - self.load_models() - self.prepare_models() + self._load_models() self._call_callbacks(event_name="on_init_end") - def default_callbacks(self) -> list[Callback[Any]]: - callbacks: list[Callback[Any]] = [ - ClockCallback(), - GradientValueClipping(), - GradientNormClipping(), - DropoutCallback(), - ] + @register_callback() + def clock(self, config: ClockConfig) -> TrainingClock: + return TrainingClock( + dataset_length=self.dataset_length, + batch_size=self.config.training.batch_size, + training_duration=self.config.training.duration, + evaluation_interval=self.config.training.evaluation_interval, + gradient_accumulation=self.config.training.gradient_accumulation, + lr_scheduler_interval=self.config.scheduler.update_interval, + verbose=config.verbose, + ) - # look for any Callback that might be a property of the Trainer - for attr_name in dir(self): - if "__" in attr_name: - continue + @register_callback() + def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping: + return 
GradientClipping(config) - try: - attr = getattr(self, attr_name) - except AssertionError: - continue - if isinstance(attr, Callback): - callbacks.append(cast(Callback[Any], attr)) - return callbacks + @property + def models(self) -> ModelRegistry: + return self._models + + @property + def callbacks(self) -> CallbackRegistry: + return self._callbacks @cached_property def device(self) -> Device: @@ -320,38 +182,31 @@ class Trainer(Generic[ConfigType, Batch], ABC): return dtype @property - def parameters(self) -> list[Parameter]: - """Returns a list of all parameters in all models""" - return [param for model in self.models.values() for param in model.parameters()] - - @property - def learnable_parameters(self) -> list[Parameter]: + def learnable_parameters(self) -> list[nn.Parameter]: """Returns a list of learnable parameters in all models""" - return [param for model in self.models.values() for param in model.parameters() if param.requires_grad] + return [param for item in self.models.values() for param in item.learnable_parameters] - @property + @cached_property def optimizer_parameters(self) -> list[dict[str, Any]]: """ Returns a list of `dict`-s containing the params and optimizer options for each model. See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details """ params: list[dict[str, Any]] = [] - for model_name, model in self.models.items(): - model_params = [param for param in model.parameters() if param.requires_grad] - model_config = self.config.models[model_name] + for item in self.models.values(): + config = item.config model_optim_conf: dict[str, Any] = {} - if model_config.learning_rate is not None: - model_optim_conf["lr"] = model_config.learning_rate - if model_config.weight_decay is not None: - model_optim_conf["weight_decay"] = model_config.learning_rate - if model_config.betas is not None: - model_optim_conf["betas"] = model_config.learning_rate - if model_config.eps is not None: - model_optim_conf["eps"] = model_config.learning_rate + if config.learning_rate is not None: + model_optim_conf["lr"] = config.learning_rate + if config.weight_decay is not None: + model_optim_conf["weight_decay"] = config.learning_rate + if config.betas is not None: + model_optim_conf["betas"] = config.learning_rate + if config.eps is not None: + model_optim_conf["eps"] = config.learning_rate - for param in model_params: - params.append({"params": param, **model_optim_conf}) + params.append({"params": item.learnable_parameters, **model_optim_conf}) return params @@ -363,17 +218,12 @@ class Trainer(Generic[ConfigType, Batch], ABC): @property def gradients(self) -> list[Tensor]: """Returns a list of detached gradients for all learnable parameters in all models""" - return [ - param.grad.detach() - for model in self.models.values() - for param in model.parameters() - if param.grad is not None - ] + return [param.grad.detach() for param in self.learnable_parameters if param.grad is not None] @property def total_gradient_norm(self) -> float: """Returns the total gradient norm for all learnable parameters in all models""" - return compute_grad_norm(parameters=self.parameters) + return compute_grad_norm(parameters=self.learnable_parameters) @cached_property def optimizer(self) -> Optimizer: @@ -441,38 +291,6 @@ class Trainer(Generic[ConfigType, Batch], ABC): return lr_scheduler - @cached_property - def models(self) -> dict[str, fl.Module]: - return self.load_models() - - def set_models_to_train_mode(self) -> None: - for model in self.models.values(): - model.train() - - 
def set_models_to_eval_mode(self) -> None: - for model in self.models.values(): - model.eval() - - def prepare_model(self, model_name: str) -> None: - model = self.models[model_name] - if (checkpoint := self.config.models[model_name].checkpoint) is not None: - model.load_from_safetensors(tensors_path=checkpoint) - else: - logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.") - if (requires_grad := self.config.models[model_name].requires_grad) is not None: - model.requires_grad_(requires_grad=requires_grad) - model.to(self.device) - model.zero_grad() - - def prepare_models(self) -> None: - assert self.models, "No models found." - for model_name in self.models: - self.prepare_model(model_name=model_name) - - @abstractmethod - def load_models(self) -> dict[str, fl.Module]: - ... - @abstractmethod def get_item(self, index: int) -> Batch: """ @@ -563,7 +381,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): @scoped_seed(seed=get_training_seed) def train(self) -> None: """Train the model.""" - self.set_models_to_train_mode() + self.set_models_to_mode("train") self._call_callbacks(event_name="on_train_begin") assert self.learnable_parameters, "There are no learnable parameters in the models." self.evaluate() @@ -581,12 +399,43 @@ class Trainer(Generic[ConfigType, Batch], ABC): @scoped_seed(seed=get_evaluation_seed) def evaluate(self) -> None: """Evaluate the model.""" - self.set_models_to_eval_mode() + self.set_models_to_mode(mode="eval") self._call_callbacks(event_name="on_evaluate_begin") self.compute_evaluation() self._call_callbacks(event_name="on_evaluate_end") - self.set_models_to_train_mode() + self.set_models_to_mode(mode="train") + + def set_models_to_mode(self, mode: Literal["train", "eval"]) -> None: + for item in self.models.values(): + if mode == "train": + item.model.train() + elif mode == "eval": + item.model.eval() def _call_callbacks(self, event_name: str) -> None: - for callback in self.callbacks: + for callback in self.callbacks.values(): getattr(callback, event_name)(self) + + def _load_callbacks(self) -> None: + for name, config in self.config: + if not isinstance(config, CallbackConfig): + continue + try: + registered_callback = getattr(self, name) + except AttributeError: + raise ValueError( + f"Callback {name} is in the config but not registered in the Trainer. Create a method with the @register_callback decorator." + ) + assert callable(registered_callback) + registered_callback(config) + + def _load_models(self) -> None: + for name, config in self.config.models.items(): + try: + registered_model = getattr(self, name) + except AttributeError: + raise ValueError( + f"Model {name} is in the config but not registered in the Trainer. Create a method with the @register_model decorator." 
+ ) + assert callable(registered_model) + registered_model(config) diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 6422803..d4683d6 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -1,15 +1,13 @@ from abc import ABC -from functools import cached_property from pathlib import Path from typing import Any, Literal import wandb from PIL import Image -from pydantic import BaseModel, ConfigDict -from refiners.training_utils.callback import Callback +from refiners.training_utils.callback import Callback, CallbackConfig from refiners.training_utils.config import BaseConfig -from refiners.training_utils.trainer import Trainer +from refiners.training_utils.trainer import Trainer, register_callback number = float | int WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]] @@ -64,7 +62,7 @@ class WandbLogger: return self.wandb_run.name or "" # type: ignore -class WandbConfig(BaseModel): +class WandbConfig(CallbackConfig): """ Wandb configuration. @@ -87,18 +85,16 @@ class WandbConfig(BaseModel): anonymous: Literal["never", "allow", "must"] | None = None id: str | None = None - model_config = ConfigDict(extra="forbid") - AnyTrainer = Trainer[BaseConfig, Any] class WandbCallback(Callback["TrainerWithWandb"]): - epoch_losses: list[float] - iteration_losses: list[float] - - def on_init_begin(self, trainer: "TrainerWithWandb") -> None: - trainer.load_wandb() + def __init__(self, config: WandbConfig, /, trainer_config: dict[str, Any]) -> None: + self.config = config + self.epoch_losses: list[float] = [] + self.iteration_losses: list[float] = [] + self.logger = WandbLogger({**config.model_dump(), "config": trainer_config}) def on_train_begin(self, trainer: "TrainerWithWandb") -> None: self.epoch_losses = [] @@ -131,19 +127,13 @@ class WandbMixin(ABC): config: Any wandb_logger: WandbLogger - def load_wandb(self) -> None: - wandb_config = getattr(self.config, "wandb", None) - assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set" - init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()} - self.wandb_logger = WandbLogger(init_config=init_config) + @register_callback() + def wandb(self, config: WandbConfig) -> WandbCallback: + return WandbCallback(config, trainer_config=self.config.model_dump()) def wandb_log(self, data: dict[str, WandbLoggable]) -> None: assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer" - self.wandb_logger.log(data=data, step=self.clock.step) - - @cached_property - def wandb_callback(self) -> WandbCallback: - return WandbCallback() + self.wandb.logger.log(data=data, step=self.clock.step) class TrainerWithWandb(AnyTrainer, WandbMixin, ABC): diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 900d97b..06652b5 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -1,6 +1,11 @@ [models.mock_model] requires_grad = true +[clock] +verbose = false + +[gradient_clipping] +clip_grad_norm = 1.0 [training] duration = "100:epoch" @@ -9,7 +14,6 @@ device = "cpu" dtype = "float32" batch_size = 4 gradient_accumulation = "4:step" -clip_grad_norm = 1.0 evaluation_interval = "5:epoch" evaluation_seed = 1 @@ -21,6 +25,3 @@ learning_rate = 1 scheduler_type = "ConstantLR" update_interval = "1:step" warmup = "20:step" - -[dropout] -dropout_probability = 0.0 diff --git a/tests/training_utils/mock_config_2_models.toml 
b/tests/training_utils/mock_config_2_models.toml index 9808d5c..e279464 100644 --- a/tests/training_utils/mock_config_2_models.toml +++ b/tests/training_utils/mock_config_2_models.toml @@ -5,12 +5,17 @@ learning_rate = 1e-5 [models.mock_model2] requires_grad = true +[clock] +verbose = false + +[gradient_clipping] +clip_grad_norm = 1.0 + [training] duration = "100:epoch" seed = 0 batch_size = 4 gradient_accumulation = "4:step" -clip_grad_norm = 1.0 evaluation_interval = "5:epoch" evaluation_seed = 1 @@ -21,6 +26,3 @@ learning_rate = 1 [scheduler] scheduler_type = "ConstantLR" update_interval = "1:step" - -[dropout] -dropout_probability = 0.0 diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 62a6041..c1f8b08 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import cached_property from pathlib import Path from typing import cast @@ -10,13 +9,15 @@ from torch.optim import SGD from refiners.fluxion import layers as fl from refiners.fluxion.utils import norm -from refiners.training_utils.config import BaseConfig, TimeUnit +from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number +from refiners.training_utils.config import BaseConfig, ModelConfig from refiners.training_utils.trainer import ( Trainer, TrainingClock, WarmupScheduler, count_learnable_parameters, human_readable_number, + register_model, ) @@ -55,13 +56,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): targets=torch.cat([b.targets for b in batch]), ) - @cached_property - def mock_model(self) -> MockModel: + @register_model() + def mock_model(self, config: ModelConfig) -> MockModel: return MockModel() - def load_models(self) -> dict[str, fl.Module]: - return {"mock_model": self.mock_model} - def compute_loss(self, batch: MockBatch) -> Tensor: self.step_counter += 1 inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device) @@ -217,17 +215,14 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None: class MockTrainerWith2Models(MockTrainer): - @cached_property - def mock_model1(self) -> MockModel: + @register_model() + def mock_model1(self, config: ModelConfig) -> MockModel: return MockModel() - @cached_property - def mock_model2(self) -> MockModel: + @register_model() + def mock_model2(self, config: ModelConfig) -> MockModel: return MockModel() - def load_models(self) -> dict[str, fl.Module]: - return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2} - def compute_loss(self, batch: MockBatch) -> Tensor: self.step_counter += 1 inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device) @@ -246,7 +241,5 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None: - assert ( - len(mock_trainer_2_models.optimizer.param_groups) == 12 - ) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models] + assert len(mock_trainer_2_models.optimizer.param_groups) == 2 assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
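
Usage sketch (reviewer note, not part of the diff above): with this patch, models and callbacks are declared directly on the Trainer subclass and wired from the config by `_load_models` / `_load_callbacks`. The skeleton below illustrates the intended pattern; `EchoTrainer`, `EchoTrainerConfig`, `EchoConfig`, `EchoCallback` and the model name `net` are made up for the example, while the imports, decorators and method signatures come from the patch itself.

from torch import Tensor

from refiners.fluxion import layers as fl
from refiners.training_utils import (
    BaseConfig,
    Callback,
    CallbackConfig,
    ModelConfig,
    Trainer,
    register_callback,
    register_model,
)


class EchoConfig(CallbackConfig):
    prefix: str = "echo"  # hypothetical option; extra="forbid" is inherited from CallbackConfig


class EchoCallback(Callback["EchoTrainer"]):
    def __init__(self, config: EchoConfig) -> None:
        self.config = config

    def on_epoch_begin(self, trainer: "EchoTrainer") -> None:
        # `trainer.clock` is the TrainingClock registered by the base Trainer
        print(f"{self.config.prefix}: starting epoch {trainer.clock.epoch}")


class EchoTrainerConfig(BaseConfig):
    # any config field whose type subclasses CallbackConfig is picked up by _load_callbacks
    echo: EchoConfig = EchoConfig()


class EchoTrainer(Trainer[EchoTrainerConfig, Tensor]):
    # the method name must match the key under [models.net] in the TOML config;
    # the decorator moves the model to the trainer's device/dtype, applies
    # `requires_grad` from the ModelConfig, and exposes it as `self.net`
    @register_model()
    def net(self, config: ModelConfig) -> fl.Module:
        return fl.Linear(in_features=1, out_features=1)

    # the method name must match the `echo` field of EchoTrainerConfig;
    # the returned callback is exposed as `self.echo` and receives trainer events
    @register_callback()
    def echo(self, config: EchoConfig) -> EchoCallback:
        return EchoCallback(config)

    # the data/loss hooks required by the base Trainer (get_item, collate_fn,
    # dataset_length, compute_loss) are omitted from this sketch
    ...

The mock config files changed in this patch (tests/training_utils/mock_config.toml) show the matching TOML layout: a [models.<name>] table per registered model plus one table per callback config, e.g. [clock] and [gradient_clipping]; an [echo] table would be the hypothetical analogue for the sketch above.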