Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-24 07:08:45 +00:00)

Commit 38c86f59f4 (parent 68fe725767): Switch gradient clipping to native torch.nn.utils.clip_grad_norm_
@@ -15,7 +15,6 @@ from refiners.training_utils.config import (
     Optimizers,
     TrainingConfig,
 )
-from refiners.training_utils.gradient_clipping import GradientClippingConfig
 from refiners.training_utils.trainer import Trainer, register_callback, register_model
 from refiners.training_utils.wandb import WandbConfig, WandbMixin

@@ -52,7 +51,6 @@ __all__ = [
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
-    "GradientClippingConfig",
     "Optimizers",
     "LRSchedulerType",
 ]

@@ -7,19 +7,18 @@ from typing import Any, Callable, Iterable
 import numpy as np
 import torch
 from loguru import logger
-from torch import Tensor, cuda, nn
+from torch import cuda, nn

 from refiners.fluxion.utils import manual_seed


 def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
     """
-    Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
+    Computes the gradient norm of the parameters in the given iterable.
+
+    We use the `torch.nn.utils.clip_grad_norm_` function to process the gradients efficiently on the GPU or CPU.
     """
-    gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to compute the norm."
-    total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item()  # type: ignore
-    return total_norm  # type: ignore
+    return nn.utils.clip_grad.clip_grad_norm_(parameters, float("inf")).item()


 def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
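
Why delegating works: with max_norm set to infinity, clip_grad_norm_ computes the total 2-norm of all gradients and its clip coefficient clamps to 1, so nothing is rescaled and the return value matches the hand-rolled computation that was removed. A minimal standalone sketch (the toy nn.Linear model is illustrative, not part of the commit):

import torch
from torch import nn

# toy model standing in for the trainer's learnable parameters
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# native call: returns the total 2-norm, leaves gradients untouched (max_norm=inf)
native_norm = nn.utils.clip_grad.clip_grad_norm_(model.parameters(), float("inf")).item()

# the removed hand-rolled equivalent: norm of the per-parameter gradient norms
manual_norm = torch.stack([p.grad.detach().norm() for p in model.parameters() if p.grad is not None]).norm().item()

assert abs(native_norm - manual_norm) < 1e-5
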
@@ -12,7 +12,6 @@ from torch.optim import SGD, Adam, AdamW, Optimizer

 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
-from refiners.training_utils.gradient_clipping import GradientClippingConfig

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -28,6 +27,7 @@ class TrainingConfig(BaseModel):
     batch_size: int = 1
     gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
     evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    gradient_clipping_max_norm: float | None = None
     evaluation_seed: int = 0

     model_config = ConfigDict(extra="forbid")
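
The clipping threshold is now a plain optional field on the training config rather than a nested GradientClippingConfig. A hedged sketch of the behaviour using a simplified stand-in model (only batch_size and gradient_clipping_max_norm mirror the diff; the rest of the real TrainingConfig is omitted):

from pydantic import BaseModel, ConfigDict

class TrainingConfigSketch(BaseModel):
    batch_size: int = 1
    gradient_clipping_max_norm: float | None = None  # None means "do not clip"

    model_config = ConfigDict(extra="forbid")

cfg = TrainingConfigSketch(batch_size=4, gradient_clipping_max_norm=1.0)
# the trainer turns None into an infinite max_norm, so clipping becomes a no-op
max_norm = cfg.gradient_clipping_max_norm or float("inf")
print(max_norm)  # 1.0 here; inf when the field is left unset
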
@@ -167,7 +167,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    gradient_clipping: GradientClippingConfig = GradientClippingConfig()

     model_config = ConfigDict(extra="forbid")

@@ -1,52 +0,0 @@
-from typing import TYPE_CHECKING, Any, Iterable
-
-import torch
-from torch import nn
-
-from refiners.training_utils.callback import Callback, CallbackConfig
-
-if TYPE_CHECKING:
-    from refiners.training_utils.config import BaseConfig
-    from refiners.training_utils.trainer import Trainer
-
-
-def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
-    """
-    Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
-    """
-    gradients = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to clip."
-    clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
-    for gradient in gradients:
-        gradient.mul_(other=clip_coefficient)  # type: ignore
-
-
-def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
-    """
-    Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
-    """
-    gradients = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to clip."
-    for gradient in gradients:
-        gradient.clamp_(min=-clip_value, max=clip_value)
-
-
-class GradientClippingConfig(CallbackConfig):
-    clip_grad_norm: float | None = None
-    clip_grad_value: float | None = None
-
-
-class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
-    def __init__(self, config: GradientClippingConfig) -> None:
-        self.config = config
-
-    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        clip_norm = self.config.clip_grad_norm
-        if clip_norm is not None:
-            clip_gradient_norm(
-                parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
-            )
-
-        clip_value = self.config.clip_grad_value
-        if clip_value is not None:
-            clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
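
For reference, the removed helpers have direct native counterparts; a sketch of the mapping (note that the commit only keeps norm clipping via gradient_clipping_max_norm, value clipping is dropped entirely):

import torch
from torch import nn

params = list(nn.Linear(4, 2).parameters())
sum(p.sum() for p in params).backward()

# clip_gradient_norm(parameters, total_norm, clip_norm) -> native norm clipping
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

# clip_gradient_value(parameters, clip_value) -> native value clipping
torch.nn.utils.clip_grad_value_(params, clip_value=0.5)
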
@@ -37,7 +37,6 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
-from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig


 class WarmupScheduler(LRScheduler):
@@ -161,10 +160,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             verbose=config.verbose,
         )

-    @register_callback()
-    def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
-        return GradientClipping(config)
-
     @property
     def models(self) -> ModelRegistry:
         return self._models
@@ -351,6 +346,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
+            max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
+            self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
             self.optimizer.step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
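
The optimizer step now always routes through clip_grad_norm_: an unset gradient_clipping_max_norm becomes float("inf"), so the call only measures the norm (stored as grad_norm for logging), while a finite value also rescales the gradients in place before optimizer.step(). A standalone sketch of that pattern (the toy model and optimizer are illustrative):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(8, 4)).sum().backward()

gradient_clipping_max_norm: float | None = None  # unset in the config

max_norm = gradient_clipping_max_norm or float("inf")
grad_norm = nn.utils.clip_grad.clip_grad_norm_(model.parameters(), max_norm=max_norm).item()
# grad_norm is measured either way; gradients are only rescaled when max_norm is finite
optimizer.step()
optimizer.zero_grad()
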
@@ -112,6 +112,7 @@ class WandbCallback(Callback["TrainerWithWandb"]):
         trainer.wandb_log(data={"step_loss": loss_value})

     def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
+        trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm})
         avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
         trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
         self.iteration_losses = []
@@ -124,9 +125,6 @@ class WandbCallback(Callback["TrainerWithWandb"]):
     def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
         trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})

-    def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
-        trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
-

 class WandbMixin(ABC):
     config: Any
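
Logging of the gradient norm moves from on_backward_end (which recomputed trainer.total_gradient_norm) to on_optimizer_step_end, which simply reads the grad_norm value cached by the trainer during the step. A self-contained sketch with a stub trainer (TrainerStub and its wandb_log are stand-ins for the real Trainer/WandbMixin API):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainerStub:
    grad_norm: float = 0.0
    logged: dict[str, Any] = field(default_factory=dict)

    def wandb_log(self, data: dict[str, Any]) -> None:
        self.logged.update(data)

def on_optimizer_step_end(trainer: TrainerStub) -> None:
    # no recomputation: reuse the norm returned by clip_grad_norm_ during the step
    trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm})

trainer = TrainerStub(grad_norm=0.42)
on_optimizer_step_end(trainer)
assert trainer.logged["total_grad_norm"] == 0.42
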
@@ -5,9 +5,6 @@ use_activation = true
 [clock]
 verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -17,6 +14,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

 [optimizer]
 optimizer = "SGD"
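
In the test configs the standalone [gradient_clipping] table disappears and the threshold moves under [training]. A sketch of reading the new layout with the standard-library tomllib (Python 3.11+; the snippet below is illustrative, not one of the repository's config files):

import tomllib

toml_text = """
[training]
duration = "100:epoch"
batch_size = 4
gradient_clipping_max_norm = 1.0
"""

config = tomllib.loads(toml_text)
print(config["training"]["gradient_clipping_max_norm"])  # 1.0
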
@@ -8,9 +8,6 @@ requires_grad = true
 [clock]
 verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -18,6 +15,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

 [optimizer]
 optimizer = "SGD"