mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
utils: remove inplace opt-in from normalize
This commit is contained in:
parent
bce3910383
commit
e319f13d05
|
@ -37,21 +37,18 @@ def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") ->
|
||||||
|
|
||||||
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
||||||
def normalize(
|
def normalize(
|
||||||
tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float], inplace: bool = False
|
tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float]
|
||||||
) -> Float[Tensor, "*batch channels height width"]:
|
) -> Float[Tensor, "*batch channels height width"]:
|
||||||
assert tensor.is_floating_point()
|
assert tensor.is_floating_point()
|
||||||
assert tensor.ndim >= 3
|
assert tensor.ndim >= 3
|
||||||
|
|
||||||
if not inplace:
|
|
||||||
tensor = tensor.clone()
|
|
||||||
|
|
||||||
dtype = tensor.dtype
|
dtype = tensor.dtype
|
||||||
mean_tensor = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
|
pixel_mean = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
|
||||||
std_tensor = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
|
pixel_std = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
|
||||||
if (std_tensor == 0).any():
|
if (pixel_std == 0).any():
|
||||||
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
|
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
|
||||||
|
|
||||||
return tensor.sub_(mean_tensor).div_(std_tensor)
|
return (tensor - pixel_mean) / pixel_std
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
|
|
Loading…
Reference in a new issue