Switch gradient clipping to the native torch.nn.utils.clip_grad_norm_

limiteinductive 2024-03-19 16:34:34 +00:00 committed by Benjamin Trom
parent 68fe725767
commit 38c86f59f4
8 changed files with 11 additions and 76 deletions

View file

@@ -15,7 +15,6 @@ from refiners.training_utils.config import (
     Optimizers,
     TrainingConfig,
 )
-from refiners.training_utils.gradient_clipping import GradientClippingConfig
 from refiners.training_utils.trainer import Trainer, register_callback, register_model
 from refiners.training_utils.wandb import WandbConfig, WandbMixin
@@ -52,7 +51,6 @@ __all__ = [
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
-    "GradientClippingConfig",
     "Optimizers",
     "LRSchedulerType",
 ]

View file

@@ -7,19 +7,18 @@ from typing import Any, Callable, Iterable
 import numpy as np
 import torch
 from loguru import logger
-from torch import Tensor, cuda, nn
+from torch import cuda, nn

 from refiners.fluxion.utils import manual_seed


 def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
     """
-    Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
+    Computes the gradient norm of the parameters in the given iterable.
+
+    We use the `torch.nn.utils.clip_grad_norm_` function to process the gradients efficiently on the GPU or CPU.
     """
-    gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to compute the norm."
-    total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item()  # type: ignore
-    return total_norm  # type: ignore
+    return nn.utils.clip_grad.clip_grad_norm_(parameters, float("inf")).item()


 def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
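
The refactor above relies on `torch.nn.utils.clip_grad_norm_` returning the total 2-norm of the gradients while, with max_norm=inf, never rescaling them, so it can stand in for the removed manual computation. A minimal sketch of that equivalence on a toy model (illustration only, not code from this repository):

import torch
from torch import nn

# Toy model with populated gradients, only for illustration.
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Manual norm, as the removed helper computed it: 2-norm of the per-parameter gradient norms.
manual = torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None]).norm().item()

# Native computation: with max_norm=inf the clip coefficient clamps to 1, so the gradients are
# left untouched and only the total norm is returned.
native = nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf")).item()

assert abs(manual - native) < 1e-5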

View file

@@ -12,7 +12,6 @@ from torch.optim import SGD, Adam, AdamW, Optimizer
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
-from refiners.training_utils.gradient_clipping import GradientClippingConfig

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -28,6 +27,7 @@ class TrainingConfig(BaseModel):
     batch_size: int = 1
     gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
     evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    gradient_clipping_max_norm: float | None = None
     evaluation_seed: int = 0

     model_config = ConfigDict(extra="forbid")
@@ -167,7 +167,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    gradient_clipping: GradientClippingConfig = GradientClippingConfig()

     model_config = ConfigDict(extra="forbid")
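
With this change, clipping is controlled by a single optional gradient_clipping_max_norm field on TrainingConfig instead of a dedicated GradientClippingConfig sub-config; leaving it unset means no clipping. A hypothetical sketch of the None-to-infinity fallback the trainer applies (ClippingSettings is a stand-in model for illustration, not the library's class):

from pydantic import BaseModel

class ClippingSettings(BaseModel):  # stand-in for the new TrainingConfig field, illustration only
    gradient_clipping_max_norm: float | None = None

# Field omitted: fall back to an infinite max norm, i.e. no clipping.
assert (ClippingSettings().gradient_clipping_max_norm or float("inf")) == float("inf")

# Field set (e.g. gradient_clipping_max_norm = 1.0 in the TOML config): clip to that norm.
assert (ClippingSettings(gradient_clipping_max_norm=1.0).gradient_clipping_max_norm or float("inf")) == 1.0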

View file

@@ -1,52 +0,0 @@
-from typing import TYPE_CHECKING, Any, Iterable
-
-import torch
-from torch import nn
-
-from refiners.training_utils.callback import Callback, CallbackConfig
-
-if TYPE_CHECKING:
-    from refiners.training_utils.config import BaseConfig
-    from refiners.training_utils.trainer import Trainer
-
-
-def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
-    """
-    Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
-    """
-    gradients = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to clip."
-    clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
-    for gradient in gradients:
-        gradient.mul_(other=clip_coefficient)  # type: ignore
-
-
-def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
-    """
-    Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
-    """
-    gradients = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to clip."
-    for gradient in gradients:
-        gradient.clamp_(min=-clip_value, max=clip_value)
-
-
-class GradientClippingConfig(CallbackConfig):
-    clip_grad_norm: float | None = None
-    clip_grad_value: float | None = None
-
-
-class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
-    def __init__(self, config: GradientClippingConfig) -> None:
-        self.config = config
-
-    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        clip_norm = self.config.clip_grad_norm
-        if clip_norm is not None:
-            clip_gradient_norm(
-                parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
-            )
-
-        clip_value = self.config.clip_grad_value
-        if clip_value is not None:
-            clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
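
Both helpers removed here have direct native counterparts in PyTorch; the commit keeps only norm clipping, wired through gradient_clipping_max_norm. A sketch of the native equivalents on a toy model (assumed names, not code from this repository):

import torch
from torch import nn

model = nn.Linear(3, 3)
model(torch.randn(5, 3)).sum().backward()

# clip_gradient_norm -> torch.nn.utils.clip_grad_norm_: rescales all gradients so that their
# total 2-norm is at most max_norm, and returns the pre-clipping total norm.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# clip_gradient_value -> torch.nn.utils.clip_grad_value_: clamps each gradient element into
# [-clip_value, clip_value]; this variant is not exposed by the new TrainingConfig field.
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)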

View file

@@ -37,7 +37,6 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
-from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig


 class WarmupScheduler(LRScheduler):
@@ -161,10 +160,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             verbose=config.verbose,
         )

-    @register_callback()
-    def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
-        return GradientClipping(config)
-
     @property
     def models(self) -> ModelRegistry:
         return self._models
@@ -351,6 +346,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
+            max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
+            self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
             self.optimizer.step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
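
Clipping now happens inline in the optimizer step rather than in an on_backward_end callback: gradients are clipped right before optimizer.step(), and the returned norm is kept as grad_norm for logging. A standalone sketch of that ordering with assumed toy names (not the Trainer class itself):

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
gradient_clipping_max_norm: float | None = 1.0  # None would disable clipping

for _ in range(3):
    loss = model(torch.randn(16, 10)).pow(2).mean()
    loss.backward()
    # Same fallback as the trainer: no configured max norm means an infinite one.
    max_norm = gradient_clipping_max_norm or float("inf")
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm).item()
    optimizer.step()
    optimizer.zero_grad()
    # grad_norm is what the WandB callback now logs as "total_grad_norm" at the end of the step.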

View file

@@ -112,6 +112,7 @@ class WandbCallback(Callback["TrainerWithWandb"]):
         trainer.wandb_log(data={"step_loss": loss_value})

     def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
+        trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm})
         avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
         trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
         self.iteration_losses = []
@@ -124,9 +125,6 @@ class WandbCallback(Callback["TrainerWithWandb"]):
     def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
         trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})

-    def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
-        trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
-

 class WandbMixin(ABC):
     config: Any

View file

@@ -5,9 +5,6 @@ use_activation = true
 [clock]
 verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -17,6 +14,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

 [optimizer]
 optimizer = "SGD"

View file

@@ -8,9 +8,6 @@ requires_grad = true
 [clock]
 verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -18,6 +15,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

 [optimizer]
 optimizer = "SGD"