From bce391038359abe3e630aff45554a65acd1c0fdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Mon, 18 Sep 2023 15:03:35 +0200 Subject: [PATCH] utils: simplify normalize a bit --- src/refiners/fluxion/utils.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index f231bca..f77a7bb 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -4,10 +4,11 @@ from numpy import array, float32 from pathlib import Path from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import save_file as _save_file # type: ignore -from torch import as_tensor, norm as _norm, manual_seed as _manual_seed # type: ignore +from torch import norm as _norm, manual_seed as _manual_seed # type: ignore import torch from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore from torch import Tensor, device as Device, dtype as DType +from jaxtyping import Float T = TypeVar("T") @@ -35,7 +36,9 @@ def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> # Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py -def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor: +def normalize( + tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float], inplace: bool = False +) -> Float[Tensor, "*batch channels height width"]: assert tensor.is_floating_point() assert tensor.ndim >= 3 @@ -43,19 +46,11 @@ def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool tensor = tensor.clone() dtype = tensor.dtype - - mean_tensor = as_tensor(mean, dtype=tensor.dtype, device=tensor.device) - std_tensor = as_tensor(std, dtype=tensor.dtype, device=tensor.device) - + mean_tensor = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1) + std_tensor = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1) if (std_tensor == 0).any(): raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") - if mean_tensor.ndim == 1: - mean_tensor = mean_tensor.view(-1, 1, 1) - - if std_tensor.ndim == 1: - std_tensor = std_tensor.view(-1, 1, 1) - return tensor.sub_(mean_tensor).div_(std_tensor)