add @register_model and @register_callback decorators

Refactor ClockTrainer to include Callback
This commit is contained in:
limiteinductive 2024-02-12 08:28:41 +00:00 committed by Benjamin Trom
parent f541badcb3
commit d6546c9026
12 changed files with 565 additions and 703 deletions

View file

@ -4,8 +4,20 @@ from importlib.metadata import requires
from packaging.requirements import Requirement
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.config import (
BaseConfig,
ModelConfig,
OptimizerConfig,
Optimizers,
SchedulerConfig,
SchedulerType,
TrainingConfig,
)
from refiners.training_utils.gradient_clipping import GradientClippingConfig
from refiners.training_utils.trainer import Trainer, register_callback, register_model
from refiners.training_utils.wandb import WandbConfig, WandbMixin
refiners_requires = requires("refiners")
assert refiners_requires is not None
@ -29,4 +41,18 @@ for dep in refiners_requires:
__all__ = [
"Trainer",
"BaseConfig",
"ModelConfig",
"register_callback",
"register_model",
"Callback",
"CallbackConfig",
"WandbMixin",
"WandbConfig",
"SchedulerConfig",
"OptimizerConfig",
"TrainingConfig",
"ClockConfig",
"GradientClippingConfig",
"Optimizers",
"SchedulerType",
]

View file

@ -1,43 +1,22 @@
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from loguru import logger
from torch import tensor
from torch.nn import Parameter
from pydantic import BaseModel, ConfigDict
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
__all__ = [
"Callback",
"GradientNormClipping",
"GradientValueClipping",
"ClockCallback",
]
T = TypeVar("T", bound="Trainer[BaseConfig, Any]")
def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
class CallbackConfig(BaseModel):
"""
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
Base configuration for a callback.
For your callback to be properly configured, you should inherit from this class and add your own configuration.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
for gradient in gradients:
gradient.mul_(other=clip_coefficient) # type: ignore
def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None:
"""
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
for gradient in gradients:
gradient.clamp_(min=-clip_value, max=clip_value)
T = TypeVar("T")
model_config = ConfigDict(extra="forbid")
class Callback(Generic[T]):
@ -97,71 +76,3 @@ class Callback(Generic[T]):
def on_checkpoint_save(self, trainer: T) -> None:
...
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.reset()
logger.info(
(
"Starting training for a total of: "
f"{trainer.clock.num_steps} steps, "
f"{trainer.clock.num_epochs} epochs, "
f"{trainer.clock.num_iterations} iterations."
)
)
trainer.clock.start_timer()
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.stop_timer()
logger.info(
(
"Training took: "
f"{trainer.clock.time_elapsed} seconds, "
f"{trainer.clock.iteration} iterations, "
f"{trainer.clock.epoch} epochs, "
f"{trainer.clock.step} steps."
)
)
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Epoch {trainer.clock.epoch} started.")
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.epoch += 1
trainer.clock.num_batches_processed = 0
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Step {trainer.clock.step} started.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.step += 1
trainer.clock.num_batches_processed += 1
trainer.clock.num_minibatches_processed += 1
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Iteration {trainer.clock.iteration} ended.")
trainer.clock.iteration += 1
trainer.clock.num_minibatches_processed = 0
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info("Evaluation started.")
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info("Evaluation ended.")
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_norm = trainer.config.training.clip_grad_norm
if clip_norm is not None:
clip_gradient_norm(
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
)
class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_value = trainer.config.training.clip_grad_value
if clip_value is not None:
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)

View file

@ -0,0 +1,193 @@
import time
from functools import cached_property
from typing import TYPE_CHECKING, Any
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.common import TimeUnit, TimeValue
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
from loguru import logger
from torch import Tensor
class ClockConfig(CallbackConfig):
verbose: bool = True
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
def __init__(
self,
dataset_length: int,
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
verbose: bool = True,
) -> None:
self.dataset_length = dataset_length
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.verbose = verbose
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
self.loss: Tensor | None = None
@cached_property
def unit_to_steps(self) -> dict[TimeUnit, int]:
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
return {
TimeUnit.STEP: 1,
TimeUnit.EPOCH: self.num_batches_per_epoch,
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
}
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
return number * self.unit_to_steps[unit]
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
return steps // self.unit_to_steps[unit]
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
number, unit = time_value["number"], time_value["unit"]
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
@cached_property
def num_epochs(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
@cached_property
def num_iterations(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
@cached_property
def num_steps(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
@cached_property
def num_step_per_iteration(self) -> int:
return self.convert_time_unit_to_steps(
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
)
@cached_property
def num_step_per_evaluation(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
def reset(self) -> None:
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
def start_timer(self) -> None:
self.start_time = time.time()
def stop_timer(self) -> None:
self.end_time = time.time()
@property
def time_elapsed(self) -> int:
assert self.start_time is not None, "Timer has not been started yet."
return int(time.time() - self.start_time)
@cached_property
def evaluation_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
@cached_property
def lr_scheduler_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
)
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
@property
def is_lr_scheduler_step(self) -> bool:
return self.step % self.lr_scheduler_interval_steps == 0
@property
def done(self) -> bool:
return self.step >= self.num_steps
@property
def is_evaluation_step(self) -> bool:
return self.step % self.evaluation_interval_steps == 0
def log(self, message: str, /) -> None:
if self.verbose:
logger.info(message)
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.reset()
self.log(
(
"Starting training for a total of: "
f"{trainer.clock.num_steps} steps, "
f"{trainer.clock.num_epochs} epochs, "
f"{trainer.clock.num_iterations} iterations."
)
)
trainer.clock.start_timer()
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.stop_timer()
self.log(
(
"Training took: "
f"{trainer.clock.time_elapsed} seconds, "
f"{trainer.clock.iteration} iterations, "
f"{trainer.clock.epoch} epochs, "
f"{trainer.clock.step} steps."
)
)
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Epoch {trainer.clock.epoch} started.")
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.epoch += 1
trainer.clock.num_batches_processed = 0
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} started.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.step += 1
trainer.clock.num_batches_processed += 1
trainer.clock.num_minibatches_processed += 1
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Iteration {trainer.clock.iteration} ended.")
trainer.clock.iteration += 1
trainer.clock.num_minibatches_processed = 0
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log("Evaluation started.")
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log("Evaluation ended.")

View file

@ -0,0 +1,103 @@
import random
from enum import Enum
from functools import wraps
from typing import Any, Callable, Iterable
import numpy as np
import torch
from loguru import logger
from torch import Tensor, cuda, nn
from typing_extensions import TypedDict
from refiners.fluxion.utils import manual_seed
def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
"""
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
"""
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to compute the norm."
total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
return total_norm # type: ignore
def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
return sum(p.numel() for p in parameters if p.requires_grad)
def human_readable_number(number: int) -> str:
float_number = float(number)
for unit in ["", "K", "M", "G", "T", "P"]:
if abs(float_number) < 1000:
return f"{float_number:.1f}{unit}"
float_number /= 1000
return f"{float_number:.1f}E"
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
manual_seed(seed=seed)
cuda.manual_seed_all(seed=seed)
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
"""
Decorator for setting a random seed within the scope of a function.
This decorator sets the random seed for Python's built-in `random` module,
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
function is executed, it restores the state of the random number generators
to what it was before the function was called. This is useful for ensuring
reproducibility for specific parts of the code without affecting randomness
elsewhere.
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
random_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = torch.get_rng_state()
cuda_torch_state = cuda.get_rng_state()
actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed)
result = func(*args, **kwargs)
random.setstate(random_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)
cuda.set_rng_state(cuda_torch_state)
return result
return inner_wrapper
return decorator
class TimeUnit(Enum):
STEP = "step"
EPOCH = "epoch"
ITERATION = "iteration"
DEFAULT = "step"
class TimeValue(TypedDict):
number: int
unit: TimeUnit
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
match value:
case str(value_str):
number, unit = value_str.split(sep=":")
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
case int(number):
return {"number": number, "unit": TimeUnit.DEFAULT}
case {"number": int(number), "unit": str(unit)}:
return {"number": number, "unit": TimeUnit(value=unit.lower())}
case _:
raise ValueError(f"Unsupported value format: {value}")

View file

@ -9,50 +9,16 @@ from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, ConfigDict, validator
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
import refiners.fluxion.layers as fl
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
from refiners.training_utils.gradient_clipping import GradientClippingConfig
# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
# See https://github.com/pytorch/pytorch/pull/111114
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
__all__ = [
"parse_number_unit_field",
"TimeUnit",
"TimeValue",
"TrainingConfig",
"OptimizerConfig",
"Optimizers",
]
class TimeUnit(Enum):
STEP = "step"
EPOCH = "epoch"
ITERATION = "iteration"
DEFAULT = "step"
class TimeValue(TypedDict):
number: int
unit: TimeUnit
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
match value:
case str(value_str):
number, unit = value_str.split(sep=":")
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
case int(number):
return {"number": number, "unit": TimeUnit.DEFAULT}
case {"number": int(number), "unit": str(unit)}:
return {"number": number, "unit": TimeUnit(value=unit.lower())}
case _:
raise ValueError(f"Unsupported value format: {value}")
class TrainingConfig(BaseModel):
device: str = "cpu"
@ -61,8 +27,6 @@ class TrainingConfig(BaseModel):
seed: int = 0
batch_size: int = 1
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
clip_grad_norm: float | None = None
clip_grad_value: float | None = None
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
evaluation_seed: int = 0
@ -195,29 +159,6 @@ class ModelConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
class GyroDropoutConfig(BaseModel):
total_subnetworks: int = 512
concurrent_subnetworks: int = 64
iters_per_epoch: int = 512
num_features_threshold: float = 5e5
model_config = ConfigDict(extra="forbid")
class DropoutConfig(BaseModel):
dropout_probability: float = 0.0
gyro_dropout: GyroDropoutConfig | None = None
model_config = ConfigDict(extra="forbid")
def apply_dropout(self, model: fl.Chain) -> None:
if self.dropout_probability > 0.0:
if self.gyro_dropout is not None:
apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump())
else:
apply_dropout(module=model, probability=self.dropout_probability)
T = TypeVar("T", bound="BaseConfig")
@ -226,7 +167,8 @@ class BaseConfig(BaseModel):
training: TrainingConfig
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
clock: ClockConfig = ClockConfig()
gradient_clipping: GradientClippingConfig = GradientClippingConfig()
model_config = ConfigDict(extra="forbid")

View file

@ -1,200 +0,0 @@
from typing import TYPE_CHECKING, Any, TypeVar
from torch import Tensor, cat, rand, randint
from torch.nn import Dropout as TorchDropout
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.training_utils.callback import Callback
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
__all__ = ["Dropout", "GyroDropout", "DropoutCallback"]
class Dropout(TorchDropout, fl.Module):
def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
super().__init__(p=probability, inplace=inplace)
class GyroDropout(fl.Module):
"""
GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training.
It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training.
This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional
dropout.
Parameters:
-----------
total_subnetworks:
The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks
that are precomputed and stored.
concurrent_subnetworks:
The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of
masks from the precomputed set is used to dropout different portions of the input.
dropout_probability: float, optional (default=0.5)
The probability that an element will be zeroed by the dropout.
iters_per_epoch:
Number of iterations per epoch, used to determine how often the masks should be updated.
num_features_threshold:
If the number of features in the input is greater than this threshold, dropout is skipped. This is because
gyro dropout mask size vram usage is proportional to the number of features in the input.
"""
def __init__(
self,
total_subnetworks: int,
concurrent_subnetworks: int,
dropout_probability: float = 0.5,
iters_per_epoch: int = 1,
num_features_threshold: float = 5e5,
) -> None:
super().__init__()
assert (
iters_per_epoch >= total_subnetworks
), "The number of iterations per epoch must be greater than the number of masks"
self.dropout_probability = dropout_probability
self.iters_per_epoch = iters_per_epoch
self.total_subnetworks = total_subnetworks
self.concurrent_subnetworks = concurrent_subnetworks
self.scale = 1 / (1 - self.dropout_probability)
self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks
self.preselected_masks: Tensor | None = None
self.dropout_mask = None
self.training_step = 0
self.num_features_threshold = num_features_threshold
self.skip_high_num_features = False
def forward(self, x: Tensor) -> Tensor:
if not self.training:
return x
if self.skip_high_num_features:
return self.basic_dropout(x)
if self.training_step == 0:
num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1]
if num_features > self.num_features_threshold:
self.skip_high_num_features = True
self.basic_dropout = Dropout(probability=self.dropout_probability)
return self.basic_dropout(x)
self.init_masks(x=x)
if self.training_step % self.mask_update_interval == 0:
self.update_dropout_mask(x=x)
self.training_step += 1
return x * self.dropout_mask * self.scale
def init_masks(self, x: Tensor) -> None:
if x.dim() == 2:
self.preselected_masks = (
rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability
)
if x.dim() == 3:
self.preselected_masks = (
rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability
)
assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions"
self.preselected_masks = self.preselected_masks.float()
def update_dropout_mask(self, x: Tensor) -> None:
assert self.preselected_masks is not None
indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device)
selected_masks = self.preselected_masks[indices]
repeat_factor = x.shape[0] // self.concurrent_subnetworks
remaining = x.shape[0] % self.concurrent_subnetworks
repeated_masks = [selected_masks] * repeat_factor
if remaining > 0:
repeated_masks.append(selected_masks[:remaining])
final_masks = cat(tensors=repeated_masks, dim=0)
if x.dim() == 2:
self.dropout_mask = final_masks
if x.dim() == 3:
self.dropout_mask = final_masks.expand_as(x)
class DropoutAdapter(fl.Chain, Adapter[fl.Linear]):
def __init__(self, target: fl.Linear, probability: float = 0.5):
with self.setup_adapter(target):
super().__init__(target, Dropout(probability=probability))
class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]):
def __init__(
self,
target: fl.Linear,
probability: float = 0.5,
total_subnetworks: int = 512,
concurrent_subnetworks: int = 64,
iters_per_epoch: int = 512,
num_features_threshold: float = 5e5,
) -> None:
self.probability = probability
self.total_subnetworks = total_subnetworks
self.concurrent_subnetworks = concurrent_subnetworks
self.iters_per_epoch = iters_per_epoch
with self.setup_adapter(target):
super().__init__(
target,
GyroDropout(
total_subnetworks=total_subnetworks,
concurrent_subnetworks=concurrent_subnetworks,
dropout_probability=probability,
iters_per_epoch=iters_per_epoch,
num_features_threshold=num_features_threshold,
),
)
def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
for linear, parent in module.walk(fl.Linear):
if not linear.weight.requires_grad:
continue
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
DropoutAdapter(target=linear, probability=probability).inject(parent)
def apply_gyro_dropout(
module: fl.Chain,
probability: float = 0.5,
total_subnetworks: int = 32,
concurrent_subnetworks: int = 16,
iters_per_epoch: int = 32,
) -> None:
for linear, parent in module.walk(fl.Linear):
if not linear.weight.requires_grad:
continue
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
GyroDropoutAdapter(
target=linear,
probability=probability,
total_subnetworks=total_subnetworks,
concurrent_subnetworks=concurrent_subnetworks,
iters_per_epoch=iters_per_epoch,
).inject(parent)
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
class DropoutCallback(Callback["Trainer[ConfigType, Any]"]):
def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None:
dropout_config = trainer.config.dropout
chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)]
for model in chain_models:
dropout_config.apply_dropout(model=model)

View file

@ -0,0 +1,52 @@
from typing import TYPE_CHECKING, Any, Iterable
import torch
from torch import nn
from refiners.training_utils.callback import Callback, CallbackConfig
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
"""
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
for gradient in gradients:
gradient.mul_(other=clip_coefficient) # type: ignore
def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
"""
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
for gradient in gradients:
gradient.clamp_(min=-clip_value, max=clip_value)
class GradientClippingConfig(CallbackConfig):
clip_grad_norm: float | None = None
clip_grad_value: float | None = None
class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
def __init__(self, config: GradientClippingConfig) -> None:
self.config = config
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_norm = self.config.clip_grad_norm
if clip_norm is not None:
clip_gradient_norm(
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
)
clip_value = self.config.clip_grad_value
if clip_value is not None:
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)

View file

@ -1,15 +1,12 @@
import random
import time
from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass
from functools import cached_property, wraps
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
from typing import Any, Callable, Generic, Literal, TypeVar, cast
import numpy as np
import torch
from loguru import logger
from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
from torch import Tensor, device as Device, dtype as DType, nn
from torch.autograd import backward
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
@ -27,73 +24,20 @@ from torch.optim.lr_scheduler import (
from torch.utils.data import DataLoader, Dataset
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.fluxion.utils import no_grad
from refiners.training_utils.callback import (
Callback,
ClockCallback,
GradientNormClipping,
GradientValueClipping,
CallbackConfig,
)
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
from refiners.training_utils.dropout import DropoutCallback
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
def count_learnable_parameters(parameters: Iterable[Parameter]) -> int:
return sum(p.numel() for p in parameters if p.requires_grad)
def human_readable_number(number: int) -> str:
float_number = float(number)
for unit in ["", "K", "M", "G", "T", "P"]:
if abs(float_number) < 1000:
return f"{float_number:.1f}{unit}"
float_number /= 1000
return f"{float_number:.1f}E"
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
manual_seed(seed=seed)
cuda.manual_seed_all(seed=seed)
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
"""
Decorator for setting a random seed within the scope of a function.
This decorator sets the random seed for Python's built-in `random` module,
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
function is executed, it restores the state of the random number generators
to what it was before the function was called. This is useful for ensuring
reproducibility for specific parts of the code without affecting randomness
elsewhere.
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
random_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = get_rng_state()
cuda_torch_state = cuda.get_rng_state()
actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed)
result = func(*args, **kwargs)
random.setstate(random_state)
np.random.set_state(numpy_state)
set_rng_state(torch_state)
cuda.set_rng_state(cuda_torch_state)
return result
return inner_wrapper
return decorator
from refiners.training_utils.clock import ClockConfig, TrainingClock
from refiners.training_utils.common import (
compute_grad_norm,
count_learnable_parameters,
human_readable_number,
scoped_seed,
)
from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType
from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
class WarmupScheduler(LRScheduler):
@ -117,135 +61,6 @@ class WarmupScheduler(LRScheduler):
self._step_count += 1
class TrainingClock:
def __init__(
self,
dataset_length: int,
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
) -> None:
self.dataset_length = dataset_length
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
self.loss: Tensor | None = None
@cached_property
def unit_to_steps(self) -> dict[TimeUnit, int]:
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
return {
TimeUnit.STEP: 1,
TimeUnit.EPOCH: self.num_batches_per_epoch,
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
}
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
return number * self.unit_to_steps[unit]
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
return steps // self.unit_to_steps[unit]
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
number, unit = time_value["number"], time_value["unit"]
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
@cached_property
def num_epochs(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
@cached_property
def num_iterations(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
@cached_property
def num_steps(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
@cached_property
def num_step_per_iteration(self) -> int:
return self.convert_time_unit_to_steps(
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
)
@cached_property
def num_step_per_evaluation(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
def reset(self) -> None:
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
def start_timer(self) -> None:
self.start_time = time.time()
def stop_timer(self) -> None:
self.end_time = time.time()
@property
def time_elapsed(self) -> int:
assert self.start_time is not None, "Timer has not been started yet."
return int(time.time() - self.start_time)
@cached_property
def evaluation_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
@cached_property
def lr_scheduler_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
)
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
@property
def is_lr_scheduler_step(self) -> bool:
return self.step % self.lr_scheduler_interval_steps == 0
@property
def done(self) -> bool:
return self.step >= self.num_steps
@property
def is_evaluation_step(self) -> bool:
return self.step % self.evaluation_interval_steps == 0
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
"""
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
"""
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to compute the norm."
total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
return total_norm # type: ignore
Batch = TypeVar("Batch")
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
@ -267,44 +82,91 @@ class _Dataset(Dataset[Batch]):
return self.length
@dataclass
class ModelItem:
name: str
config: ModelConfig
model: fl.Module
learnable_parameters: list[nn.Parameter]
ModelRegistry = dict[str, ModelItem]
ModuleT = TypeVar("ModuleT", bound=fl.Module)
def register_model():
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
@wraps(func)
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
name = func.__name__
model = func(self, config)
model = model.to(self.device, dtype=self.dtype)
if config.requires_grad is not None:
model.requires_grad_(requires_grad=config.requires_grad)
learnable_parameters = [param for param in model.parameters() if param.requires_grad]
self.models[name] = ModelItem(
name=name, config=config, model=model, learnable_parameters=learnable_parameters
)
setattr(self, name, self.models[name].model)
return func(self, config)
return wrapper # type: ignore
return decorator
CallbackRegistry = dict[str, Callback[Any]]
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
def register_callback():
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
@wraps(func)
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
name = func.__name__
callback = func(self, config)
self.callbacks[name] = callback
setattr(self, name, callback)
return func(self, config)
return wrapper # type: ignore
return decorator
class Trainer(Generic[ConfigType, Batch], ABC):
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
def __init__(self, config: ConfigType) -> None:
self._models: ModelRegistry = {}
self._callbacks: CallbackRegistry = {}
self.config = config
self.clock = TrainingClock(
dataset_length=self.dataset_length,
batch_size=config.training.batch_size,
training_duration=config.training.duration,
evaluation_interval=config.training.evaluation_interval,
gradient_accumulation=config.training.gradient_accumulation,
lr_scheduler_interval=config.scheduler.update_interval,
)
self.callbacks = callbacks or []
self.callbacks += self.default_callbacks()
self._load_callbacks()
self._call_callbacks(event_name="on_init_begin")
self.load_models()
self.prepare_models()
self._load_models()
self._call_callbacks(event_name="on_init_end")
def default_callbacks(self) -> list[Callback[Any]]:
callbacks: list[Callback[Any]] = [
ClockCallback(),
GradientValueClipping(),
GradientNormClipping(),
DropoutCallback(),
]
@register_callback()
def clock(self, config: ClockConfig) -> TrainingClock:
return TrainingClock(
dataset_length=self.dataset_length,
batch_size=self.config.training.batch_size,
training_duration=self.config.training.duration,
evaluation_interval=self.config.training.evaluation_interval,
gradient_accumulation=self.config.training.gradient_accumulation,
lr_scheduler_interval=self.config.scheduler.update_interval,
verbose=config.verbose,
)
# look for any Callback that might be a property of the Trainer
for attr_name in dir(self):
if "__" in attr_name:
continue
@register_callback()
def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
return GradientClipping(config)
try:
attr = getattr(self, attr_name)
except AssertionError:
continue
if isinstance(attr, Callback):
callbacks.append(cast(Callback[Any], attr))
return callbacks
@property
def models(self) -> ModelRegistry:
return self._models
@property
def callbacks(self) -> CallbackRegistry:
return self._callbacks
@cached_property
def device(self) -> Device:
@ -320,38 +182,31 @@ class Trainer(Generic[ConfigType, Batch], ABC):
return dtype
@property
def parameters(self) -> list[Parameter]:
"""Returns a list of all parameters in all models"""
return [param for model in self.models.values() for param in model.parameters()]
@property
def learnable_parameters(self) -> list[Parameter]:
def learnable_parameters(self) -> list[nn.Parameter]:
"""Returns a list of learnable parameters in all models"""
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
return [param for item in self.models.values() for param in item.learnable_parameters]
@property
@cached_property
def optimizer_parameters(self) -> list[dict[str, Any]]:
"""
Returns a list of `dict`-s containing the params and optimizer options for each model.
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
"""
params: list[dict[str, Any]] = []
for model_name, model in self.models.items():
model_params = [param for param in model.parameters() if param.requires_grad]
model_config = self.config.models[model_name]
for item in self.models.values():
config = item.config
model_optim_conf: dict[str, Any] = {}
if model_config.learning_rate is not None:
model_optim_conf["lr"] = model_config.learning_rate
if model_config.weight_decay is not None:
model_optim_conf["weight_decay"] = model_config.learning_rate
if model_config.betas is not None:
model_optim_conf["betas"] = model_config.learning_rate
if model_config.eps is not None:
model_optim_conf["eps"] = model_config.learning_rate
if config.learning_rate is not None:
model_optim_conf["lr"] = config.learning_rate
if config.weight_decay is not None:
model_optim_conf["weight_decay"] = config.learning_rate
if config.betas is not None:
model_optim_conf["betas"] = config.learning_rate
if config.eps is not None:
model_optim_conf["eps"] = config.learning_rate
for param in model_params:
params.append({"params": param, **model_optim_conf})
params.append({"params": item.learnable_parameters, **model_optim_conf})
return params
@ -363,17 +218,12 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@property
def gradients(self) -> list[Tensor]:
"""Returns a list of detached gradients for all learnable parameters in all models"""
return [
param.grad.detach()
for model in self.models.values()
for param in model.parameters()
if param.grad is not None
]
return [param.grad.detach() for param in self.learnable_parameters if param.grad is not None]
@property
def total_gradient_norm(self) -> float:
"""Returns the total gradient norm for all learnable parameters in all models"""
return compute_grad_norm(parameters=self.parameters)
return compute_grad_norm(parameters=self.learnable_parameters)
@cached_property
def optimizer(self) -> Optimizer:
@ -441,38 +291,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
return lr_scheduler
@cached_property
def models(self) -> dict[str, fl.Module]:
return self.load_models()
def set_models_to_train_mode(self) -> None:
for model in self.models.values():
model.train()
def set_models_to_eval_mode(self) -> None:
for model in self.models.values():
model.eval()
def prepare_model(self, model_name: str) -> None:
model = self.models[model_name]
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
model.load_from_safetensors(tensors_path=checkpoint)
else:
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
if (requires_grad := self.config.models[model_name].requires_grad) is not None:
model.requires_grad_(requires_grad=requires_grad)
model.to(self.device)
model.zero_grad()
def prepare_models(self) -> None:
assert self.models, "No models found."
for model_name in self.models:
self.prepare_model(model_name=model_name)
@abstractmethod
def load_models(self) -> dict[str, fl.Module]:
...
@abstractmethod
def get_item(self, index: int) -> Batch:
"""
@ -563,7 +381,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@scoped_seed(seed=get_training_seed)
def train(self) -> None:
"""Train the model."""
self.set_models_to_train_mode()
self.set_models_to_mode("train")
self._call_callbacks(event_name="on_train_begin")
assert self.learnable_parameters, "There are no learnable parameters in the models."
self.evaluate()
@ -581,12 +399,43 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@scoped_seed(seed=get_evaluation_seed)
def evaluate(self) -> None:
"""Evaluate the model."""
self.set_models_to_eval_mode()
self.set_models_to_mode(mode="eval")
self._call_callbacks(event_name="on_evaluate_begin")
self.compute_evaluation()
self._call_callbacks(event_name="on_evaluate_end")
self.set_models_to_train_mode()
self.set_models_to_mode(mode="train")
def set_models_to_mode(self, mode: Literal["train", "eval"]) -> None:
for item in self.models.values():
if mode == "train":
item.model.train()
elif mode == "eval":
item.model.eval()
def _call_callbacks(self, event_name: str) -> None:
for callback in self.callbacks:
for callback in self.callbacks.values():
getattr(callback, event_name)(self)
def _load_callbacks(self) -> None:
for name, config in self.config:
if not isinstance(config, CallbackConfig):
continue
try:
registered_callback = getattr(self, name)
except AttributeError:
raise ValueError(
f"Callback {name} is in the config but not registered in the Trainer. Create a method with the @register_callback decorator."
)
assert callable(registered_callback)
registered_callback(config)
def _load_models(self) -> None:
for name, config in self.config.models.items():
try:
registered_model = getattr(self, name)
except AttributeError:
raise ValueError(
f"Model {name} is in the config but not registered in the Trainer. Create a method with the @register_model decorator."
)
assert callable(registered_model)
registered_model(config)

View file

@ -1,15 +1,13 @@
from abc import ABC
from functools import cached_property
from pathlib import Path
from typing import Any, Literal
import wandb
from PIL import Image
from pydantic import BaseModel, ConfigDict
from refiners.training_utils.callback import Callback
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.trainer import Trainer, register_callback
number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
@ -64,7 +62,7 @@ class WandbLogger:
return self.wandb_run.name or "" # type: ignore
class WandbConfig(BaseModel):
class WandbConfig(CallbackConfig):
"""
Wandb configuration.
@ -87,18 +85,16 @@ class WandbConfig(BaseModel):
anonymous: Literal["never", "allow", "must"] | None = None
id: str | None = None
model_config = ConfigDict(extra="forbid")
AnyTrainer = Trainer[BaseConfig, Any]
class WandbCallback(Callback["TrainerWithWandb"]):
epoch_losses: list[float]
iteration_losses: list[float]
def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
trainer.load_wandb()
def __init__(self, config: WandbConfig, /, trainer_config: dict[str, Any]) -> None:
self.config = config
self.epoch_losses: list[float] = []
self.iteration_losses: list[float] = []
self.logger = WandbLogger({**config.model_dump(), "config": trainer_config})
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
self.epoch_losses = []
@ -131,19 +127,13 @@ class WandbMixin(ABC):
config: Any
wandb_logger: WandbLogger
def load_wandb(self) -> None:
wandb_config = getattr(self.config, "wandb", None)
assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
self.wandb_logger = WandbLogger(init_config=init_config)
@register_callback()
def wandb(self, config: WandbConfig) -> WandbCallback:
return WandbCallback(config, trainer_config=self.config.model_dump())
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
self.wandb_logger.log(data=data, step=self.clock.step)
@cached_property
def wandb_callback(self) -> WandbCallback:
return WandbCallback()
self.wandb.logger.log(data=data, step=self.clock.step)
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):

View file

@ -1,6 +1,11 @@
[models.mock_model]
requires_grad = true
[clock]
verbose = false
[gradient_clipping]
clip_grad_norm = 1.0
[training]
duration = "100:epoch"
@ -9,7 +14,6 @@ device = "cpu"
dtype = "float32"
batch_size = 4
gradient_accumulation = "4:step"
clip_grad_norm = 1.0
evaluation_interval = "5:epoch"
evaluation_seed = 1
@ -21,6 +25,3 @@ learning_rate = 1
scheduler_type = "ConstantLR"
update_interval = "1:step"
warmup = "20:step"
[dropout]
dropout_probability = 0.0

View file

@ -5,12 +5,17 @@ learning_rate = 1e-5
[models.mock_model2]
requires_grad = true
[clock]
verbose = false
[gradient_clipping]
clip_grad_norm = 1.0
[training]
duration = "100:epoch"
seed = 0
batch_size = 4
gradient_accumulation = "4:step"
clip_grad_norm = 1.0
evaluation_interval = "5:epoch"
evaluation_seed = 1
@ -21,6 +26,3 @@ learning_rate = 1
[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"
[dropout]
dropout_probability = 0.0

View file

@ -1,5 +1,4 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import cast
@ -10,13 +9,15 @@ from torch.optim import SGD
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import norm
from refiners.training_utils.config import BaseConfig, TimeUnit
from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number
from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.trainer import (
Trainer,
TrainingClock,
WarmupScheduler,
count_learnable_parameters,
human_readable_number,
register_model,
)
@ -55,13 +56,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
targets=torch.cat([b.targets for b in batch]),
)
@cached_property
def mock_model(self) -> MockModel:
@register_model()
def mock_model(self, config: ModelConfig) -> MockModel:
return MockModel()
def load_models(self) -> dict[str, fl.Module]:
return {"mock_model": self.mock_model}
def compute_loss(self, batch: MockBatch) -> Tensor:
self.step_counter += 1
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
@ -217,17 +215,14 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
class MockTrainerWith2Models(MockTrainer):
@cached_property
def mock_model1(self) -> MockModel:
@register_model()
def mock_model1(self, config: ModelConfig) -> MockModel:
return MockModel()
@cached_property
def mock_model2(self) -> MockModel:
@register_model()
def mock_model2(self, config: ModelConfig) -> MockModel:
return MockModel()
def load_models(self) -> dict[str, fl.Module]:
return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
def compute_loss(self, batch: MockBatch) -> Tensor:
self.step_counter += 1
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
@ -246,7 +241,5 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
assert (
len(mock_trainer_2_models.optimizer.param_groups) == 12
) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
assert len(mock_trainer_2_models.optimizer.param_groups) == 2
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5