utils: simplify normalize a bit

This commit is contained in:
Cédric Deltheil 2023-09-18 15:03:35 +02:00 committed by Cédric Deltheil
parent d6046e1fbf
commit bce3910383

View file

@ -4,10 +4,11 @@ from numpy import array, float32
from pathlib import Path
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import as_tensor, norm as _norm, manual_seed as _manual_seed # type: ignore
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
import torch
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
from torch import Tensor, device as Device, dtype as DType
from jaxtyping import Float
T = TypeVar("T")
@ -35,7 +36,9 @@ def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") ->
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
def normalize(
tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float], inplace: bool = False
) -> Float[Tensor, "*batch channels height width"]:
assert tensor.is_floating_point()
assert tensor.ndim >= 3
@ -43,19 +46,11 @@ def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool
tensor = tensor.clone()
dtype = tensor.dtype
mean_tensor = as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
std_tensor = as_tensor(std, dtype=tensor.dtype, device=tensor.device)
mean_tensor = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
std_tensor = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
if (std_tensor == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean_tensor.ndim == 1:
mean_tensor = mean_tensor.view(-1, 1, 1)
if std_tensor.ndim == 1:
std_tensor = std_tensor.view(-1, 1, 1)
return tensor.sub_(mean_tensor).div_(std_tensor)