Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 13:48:46 +00:00)

Switch gradient clipping to native torch torch.nn.utils.clip_grad_norm_

commit 38c86f59f4, parent 68fe725767
@@ -15,7 +15,6 @@ from refiners.training_utils.config import (
     Optimizers,
     TrainingConfig,
 )
-from refiners.training_utils.gradient_clipping import GradientClippingConfig
 from refiners.training_utils.trainer import Trainer, register_callback, register_model
 from refiners.training_utils.wandb import WandbConfig, WandbMixin
 
@@ -52,7 +51,6 @@ __all__ = [
     "OptimizerConfig",
     "TrainingConfig",
     "ClockConfig",
-    "GradientClippingConfig",
     "Optimizers",
     "LRSchedulerType",
 ]
@@ -7,19 +7,18 @@ from typing import Any, Callable, Iterable
 import numpy as np
 import torch
 from loguru import logger
-from torch import Tensor, cuda, nn
+from torch import cuda, nn
 
 from refiners.fluxion.utils import manual_seed
 
 
 def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
     """
-    Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
+    Computes the gradient norm of the parameters in the given iterable.
+
+    We use the `torch.nn.utils.clip_grad_norm_` function to process the gradients efficiently on the GPU or CPU.
     """
-    gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
-    assert gradients, "The model has no gradients to compute the norm."
-    total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item()  # type: ignore
-    return total_norm  # type: ignore
+    return nn.utils.clip_grad.clip_grad_norm_(parameters, float("inf")).item()
 
 
 def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
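The hunk above replaces the hand-rolled norm computation with a single call to torch.nn.utils.clip_grad_norm_ using an infinite max_norm, which returns the total 2-norm of the gradients without modifying them. A minimal standalone sketch to check that both computations agree (the nn.Linear model and random data are illustrative, not from the commit):

import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# clip_grad_norm_ with max_norm=inf only measures the norm: clip_coef >= 1, so grads are untouched.
before = [p.grad.clone() for p in model.parameters()]
native = nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf")).item()

# Same value as the removed implementation: the norm of the per-parameter gradient norms.
manual = torch.stack([p.grad.norm() for p in model.parameters()]).norm().item()

assert abs(native - manual) < 1e-6
assert all(torch.equal(b, p.grad) for b, p in zip(before, model.parameters()))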
@@ -12,7 +12,6 @@ from torch.optim import SGD, Adam, AdamW, Optimizer
 
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
-from refiners.training_utils.gradient_clipping import GradientClippingConfig
 
 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -28,6 +27,7 @@ class TrainingConfig(BaseModel):
     batch_size: int = 1
     gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
     evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    gradient_clipping_max_norm: float | None = None
     evaluation_seed: int = 0
 
     model_config = ConfigDict(extra="forbid")
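The new gradient_clipping_max_norm field defaults to None, which the trainer later maps to max_norm=float("inf"), i.e. no clipping. A small self-contained sketch of that pattern (MiniTrainingConfig is a hypothetical stand-in, not the refiners class):

from pydantic import BaseModel, ConfigDict

class MiniTrainingConfig(BaseModel):
    batch_size: int = 1
    gradient_clipping_max_norm: float | None = None  # None -> clipping disabled
    model_config = ConfigDict(extra="forbid")

cfg = MiniTrainingConfig(batch_size=4, gradient_clipping_max_norm=1.0)
max_norm = cfg.gradient_clipping_max_norm or float("inf")  # same fallback the trainer uses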
@@ -167,7 +167,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    gradient_clipping: GradientClippingConfig = GradientClippingConfig()
 
     model_config = ConfigDict(extra="forbid")
 
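Because the config uses extra="forbid", an old-style gradient_clipping section is now rejected at validation time rather than silently ignored. A quick sketch with a hypothetical stand-in model:

from pydantic import BaseModel, ConfigDict, ValidationError

class MiniBaseConfig(BaseModel):  # hypothetical stand-in for BaseConfig
    model_config = ConfigDict(extra="forbid")

try:
    MiniBaseConfig(gradient_clipping={"clip_grad_norm": 1.0})
except ValidationError as err:
    print(err)  # pydantic reports that extra inputs are not permitted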
@ -1,52 +0,0 @@
|
||||||
from typing import TYPE_CHECKING, Any, Iterable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from refiners.training_utils.config import BaseConfig
|
|
||||||
from refiners.training_utils.trainer import Trainer
|
|
||||||
|
|
||||||
|
|
||||||
def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
|
||||||
"""
|
|
||||||
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
|
||||||
"""
|
|
||||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
|
||||||
assert gradients, "The model has no gradients to clip."
|
|
||||||
clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
|
||||||
for gradient in gradients:
|
|
||||||
gradient.mul_(other=clip_coefficient) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None:
|
|
||||||
"""
|
|
||||||
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
|
||||||
"""
|
|
||||||
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
|
||||||
assert gradients, "The model has no gradients to clip."
|
|
||||||
for gradient in gradients:
|
|
||||||
gradient.clamp_(min=-clip_value, max=clip_value)
|
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingConfig(CallbackConfig):
|
|
||||||
clip_grad_norm: float | None = None
|
|
||||||
clip_grad_value: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class GradientClipping(Callback["Trainer[BaseConfig, Any]"]):
|
|
||||||
def __init__(self, config: GradientClippingConfig) -> None:
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
|
||||||
clip_norm = self.config.clip_grad_norm
|
|
||||||
if clip_norm is not None:
|
|
||||||
clip_gradient_norm(
|
|
||||||
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_value = self.config.clip_grad_value
|
|
||||||
if clip_value is not None:
|
|
||||||
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
|
|
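The two helpers in the deleted module map onto native PyTorch utilities; a rough sketch of the correspondence (toy nn.Linear model and illustrative values, not from the commit). Note that only the norm-based variant survives in the new TrainingConfig; value clipping has no replacement after this commit.

import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# clip_gradient_norm(parameters, total_norm, clip_norm) roughly corresponds to:
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# clip_gradient_value(parameters, clip_value) roughly corresponds to:
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)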
@@ -37,7 +37,6 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
-from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig
 
 
 class WarmupScheduler(LRScheduler):
@@ -161,10 +160,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             verbose=config.verbose,
         )
 
-    @register_callback()
-    def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
-        return GradientClipping(config)
-
     @property
     def models(self) -> ModelRegistry:
         return self._models
@@ -351,6 +346,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
+            max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
+            self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
             self.optimizer.step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
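Clipping now happens inline in the optimizer step, between backward and optimizer.step(). A condensed sketch of that ordering outside the Trainer (toy model and optimizer, not the refiners loop):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
gradient_clipping_max_norm = 1.0  # None would mean "no clipping"

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    max_norm = gradient_clipping_max_norm or float("inf")
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm).item()
    optimizer.step()
    optimizer.zero_grad()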
@ -112,6 +112,7 @@ class WandbCallback(Callback["TrainerWithWandb"]):
|
||||||
trainer.wandb_log(data={"step_loss": loss_value})
|
trainer.wandb_log(data={"step_loss": loss_value})
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
|
def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
|
||||||
|
trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm})
|
||||||
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
||||||
trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
|
trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
|
||||||
self.iteration_losses = []
|
self.iteration_losses = []
|
||||||
|
@ -124,9 +125,6 @@ class WandbCallback(Callback["TrainerWithWandb"]):
|
||||||
def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
|
def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
|
||||||
trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
|
trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
|
||||||
|
|
||||||
def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
|
|
||||||
trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
|
|
||||||
|
|
||||||
|
|
||||||
class WandbMixin(ABC):
|
class WandbMixin(ABC):
|
||||||
config: Any
|
config: Any
|
||||||
|
|
|
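The grad-norm logging moves from on_backward_end to on_optimizer_step_end: with gradient accumulation, backward runs several times per update, and the value worth recording is the norm measured (and possibly clipped) right before optimizer.step(), which the trainer now stores in grad_norm. A small sketch of that timing (toy model, hypothetical accumulation count):

import torch
from torch import nn

model = nn.Linear(4, 2)

# Gradients accumulate over several backward passes...
for _ in range(4):
    model(torch.randn(8, 4)).sum().backward()

# ...and the single norm worth logging per update is the one clip_grad_norm_ returns.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
print({"total_grad_norm": grad_norm})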
@@ -5,9 +5,6 @@ use_activation = true
 [clock]
 verbose = false
 
-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -17,6 +14,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0
 
 [optimizer]
 optimizer = "SGD"
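With the [gradient_clipping] table gone, the max norm is a plain key under [training]. A minimal parsing sketch (requires Python 3.11+ for tomllib; MiniTraining is a hypothetical stand-in, real configs go through BaseConfig):

import tomllib
from pydantic import BaseModel

class MiniTraining(BaseModel):
    batch_size: int = 1
    gradient_clipping_max_norm: float | None = None

raw = tomllib.loads("""
[training]
batch_size = 4
gradient_clipping_max_norm = 1.0
""")

training = MiniTraining(**raw["training"])
print(training.gradient_clipping_max_norm)  # 1.0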
@@ -8,9 +8,6 @@ requires_grad = true
 [clock]
 verbose = false
 
-[gradient_clipping]
-clip_grad_norm = 1.0
-
 [training]
 duration = "100:epoch"
 seed = 0
@@ -18,6 +15,7 @@ batch_size = 4
 gradient_accumulation = "4:step"
 evaluation_interval = "5:epoch"
 evaluation_seed = 1
+gradient_clipping_max_norm = 1.0
 
 [optimizer]
 optimizer = "SGD"