mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
move image tensor normalize under fluxion's utils
This commit is contained in:
parent
91ac2353e7
commit
d6046e1fbf
|
@ -4,7 +4,7 @@ from numpy import array, float32
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors import safe_open as _safe_open # type: ignore
|
from safetensors import safe_open as _safe_open # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
from torch import as_tensor, norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
@ -34,6 +34,31 @@ def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") ->
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
||||||
|
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
    """Standardize a floating-point tensor as ``(tensor - mean) / std``.

    ``mean`` and ``std`` are converted to tensors with the dtype and device of
    ``tensor``. One-dimensional statistics are reshaped to ``(-1, 1, 1)`` so
    that each entry broadcasts over the last two dimensions of ``tensor``
    (presumably a channels-first ``(..., C, H, W)`` image layout, which is what
    ``image_to_tensor`` produces).

    Args:
        tensor: Floating-point tensor with at least 3 dimensions.
        mean: Values subtracted from ``tensor``.
        std: Values ``tensor`` is divided by after the subtraction.
        inplace: When true, ``tensor`` itself is modified and returned;
            otherwise a clone is modified and the input is left untouched.

    Returns:
        The normalized tensor (the input object itself when ``inplace`` is true).

    Raises:
        AssertionError: If ``tensor`` is not floating point or has fewer than
            3 dimensions.
        ValueError: If any entry of ``std`` is zero once converted to the
            dtype of ``tensor``.
    """
    assert tensor.is_floating_point()
    assert tensor.ndim >= 3

    # Only touch the caller's data when explicitly asked to.
    output = tensor if inplace else tensor.clone()
    dtype, device = output.dtype, output.device

    offset = as_tensor(mean, dtype=dtype, device=device)
    scale = as_tensor(std, dtype=dtype, device=device)

    # Checked after the dtype conversion: a tiny std may underflow to zero
    # in a low-precision dtype.
    if scale.eq(0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

    # Flat statistics get two trailing singleton axes so they broadcast over
    # the last two dimensions; anything else is used as provided.
    offset, scale = (stat.view(-1, 1, 1) if stat.ndim == 1 else stat for stat in (offset, scale))

    output.sub_(offset)
    output.div_(scale)
    return output
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
||||||
0
|
0
|
||||||
|
|
|
@ -2,7 +2,7 @@ from enum import IntEnum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
|
from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
|
from torch import Tensor, cat, zeros_like, device as Device, dtype as DType
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -10,7 +10,7 @@ from refiners.fluxion.adapters.lora import Lora
|
||||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor
|
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -228,33 +228,8 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
std: list[float] | None = None,
|
std: list[float] | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# Default mean and std are parameters from https://github.com/openai/CLIP
|
# Default mean and std are parameters from https://github.com/openai/CLIP
|
||||||
return self._normalize(
|
return normalize(
|
||||||
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
||||||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||||
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
|
||||||
@staticmethod
def _normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
    """Return ``(tensor - mean) / std`` for a floating-point tensor.

    ``mean`` and ``std`` are turned into tensors sharing the dtype and device
    of ``tensor``; when one-dimensional they are reshaped to ``(-1, 1, 1)`` so
    each value broadcasts across the last two dimensions of ``tensor``.

    Args:
        tensor: Floating-point tensor with at least 3 dimensions.
        mean: Values to subtract.
        std: Values to divide by after the subtraction.
        inplace: Modify and return ``tensor`` itself when true; otherwise
            operate on a clone.

    Returns:
        The normalized tensor.

    Raises:
        AssertionError: If ``tensor`` is not floating point or has fewer than
            3 dimensions.
        ValueError: If ``std`` contains a zero after conversion to the dtype
            of ``tensor``.
    """
    assert tensor.is_floating_point()
    assert tensor.ndim >= 3

    # Clone first so the caller's tensor is preserved unless inplace is set.
    result = tensor if inplace else tensor.clone()
    target_dtype = result.dtype

    shift = as_tensor(mean, dtype=target_dtype, device=result.device)
    divisor = as_tensor(std, dtype=target_dtype, device=result.device)

    # The check runs on the converted values, not on the Python floats.
    if divisor.eq(0).any():
        raise ValueError(f"std evaluated to zero after conversion to {target_dtype}, leading to division by zero.")

    if shift.ndim == 1:
        shift = shift.view(-1, 1, 1)
    if divisor.ndim == 1:
        divisor = divisor.view(-1, 1, 1)

    result.sub_(shift)
    result.div_(divisor)
    return result
|
|
||||||
|
|
Loading…
Reference in a new issue