# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
    """Standardize `tensor` by computing `(tensor - mean) / std`.

    `mean` and `std` are converted to tensors sharing the dtype and device of `tensor`. When a converted
    statistic is one-dimensional it is viewed as `(-1, 1, 1)`, so its entries line up with the third-to-last
    dimension of `tensor` (presumably the channel axis - confirm against callers).

    Args:
        tensor: Floating-point tensor with at least 3 dimensions.
        mean: Values subtracted from `tensor`.
        std: Values the shifted tensor is divided by.
        inplace: If False (default), a clone of `tensor` is modified and returned, leaving the input
            untouched. If True, `tensor` itself is modified and returned.

    Returns:
        The normalized tensor (the very same object as `tensor` when `inplace` is True).

    Raises:
        AssertionError: If `tensor` is not floating point or has fewer than 3 dimensions.
        ValueError: If any entry of `std` equals zero once converted to the dtype of `tensor`.
    """
    assert tensor.is_floating_point()
    assert tensor.ndim >= 3

    # Work on a private copy unless the caller explicitly asked for in-place modification.
    result = tensor if inplace else tensor.clone()
    dtype = result.dtype

    shift = as_tensor(mean, dtype=dtype, device=result.device)
    scale = as_tensor(std, dtype=dtype, device=result.device)

    # The check runs on the converted values: a small std can round to zero in a low-precision dtype.
    if scale.eq(0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

    # Flat statistics get two trailing singleton axes so they broadcast over the last two dimensions.
    shift, scale = (stat.view(-1, 1, 1) if stat.ndim == 1 else stat for stat in (shift, scale))

    result.sub_(shift)
    result.div_(scale)
    return result
b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 2250c2e..9dc358d 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -2,7 +2,7 @@ from enum import IntEnum from functools import partial from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING -from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType +from torch import Tensor, cat, zeros_like, device as Device, dtype as DType from PIL import Image from refiners.fluxion.adapters.adapter import Adapter @@ -10,7 +10,7 @@ from refiners.fluxion.adapters.lora import Lora from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.fluxion.layers.attentions import ScaledDotProductAttention -from refiners.fluxion.utils import image_to_tensor +from refiners.fluxion.utils import image_to_tensor, normalize import refiners.fluxion.layers as fl if TYPE_CHECKING: @@ -228,33 +228,8 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): std: list[float] | None = None, ) -> Tensor: # Default mean and std are parameters from https://github.com/openai/CLIP - return self._normalize( + return normalize( image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype), mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, std=[0.26862954, 0.26130258, 0.27577711] if std is None else std, ) - - # Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py - @staticmethod - def _normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor: - assert tensor.is_floating_point() - assert tensor.ndim >= 3 - - if not inplace: - tensor = tensor.clone() - - dtype = tensor.dtype - - mean_tensor = as_tensor(mean, dtype=tensor.dtype, device=tensor.device) - std_tensor = 
as_tensor(std, dtype=tensor.dtype, device=tensor.device) - - if (std_tensor == 0).any(): - raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") - - if mean_tensor.ndim == 1: - mean_tensor = mean_tensor.view(-1, 1, 1) - - if std_tensor.ndim == 1: - std_tensor = std_tensor.view(-1, 1, 1) - - return tensor.sub_(mean_tensor).div_(std_tensor)