diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 0c5c2f9..f0331a7 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -6,7 +6,7 @@ from collections import defaultdict
 from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload

 import torch
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, device as Device, dtype as DType

 from refiners.fluxion.context import ContextProvider, Contexts
 from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
@@ -950,7 +950,7 @@ class Concatenate(Chain):

     def forward(self, *args: Any) -> Tensor:
         outputs = [module(*args) for module in self]
-        return cat(
+        return torch.cat(
             [output for output in outputs if output is not None],
             dim=self.dim,
         )
diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py
index bc7f0dd..56546cd 100644
--- a/src/refiners/fluxion/layers/norm.py
+++ b/src/refiners/fluxion/layers/norm.py
@@ -1,5 +1,6 @@
+import torch
 from jaxtyping import Float
-from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn import (
     GroupNorm as _GroupNorm,
     InstanceNorm2d as _InstanceNorm2d,
@@ -111,8 +112,8 @@ class LayerNorm2d(WeightedModule):
         dtype: DType | None = None,
     ) -> None:
         super().__init__()
-        self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
-        self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
+        self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
+        self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
         self.eps = eps

     def forward(
@@ -121,7 +122,7 @@ class LayerNorm2d(WeightedModule):
     ) -> Float[Tensor, "batch channels height width"]:
         x_mean = x.mean(1, keepdim=True)
         x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
-        x_norm = (x - x_mean) / sqrt(x_var + self.eps)
+        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
         x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x_out

diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index 31bc41f..824cae8 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -8,15 +8,7 @@ from numpy import array, float32
 from PIL import Image
 from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
-from torch import (
-    Tensor,
-    cat,
-    device as Device,
-    dtype as DType,
-    manual_seed as _manual_seed,  # type: ignore
-    no_grad as _no_grad,  # type: ignore
-    norm as _norm,  # type: ignore
-)
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore

 T = TypeVar("T")
@@ -24,14 +16,14 @@ E = TypeVar("E")


 def norm(x: Tensor) -> Tensor:
-    return _norm(x)  # type: ignore
+    return torch.norm(x)  # type: ignore


 def manual_seed(seed: int) -> None:
-    _manual_seed(seed)
+    torch.manual_seed(seed)  # type: ignore


-class no_grad(_no_grad):
+class no_grad(torch.no_grad):
     def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
         return object.__new__(cls)

@@ -123,7 +115,7 @@ def gaussian_blur(
 def images_to_tensor(
     images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
 ) -> Tensor:
-    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+    return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])


 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
diff --git a/src/refiners/foundationals/clip/common.py b/src/refiners/foundationals/clip/common.py
index 7a10949..347d2be 100644
--- a/src/refiners/foundationals/clip/common.py
+++ b/src/refiners/foundationals/clip/common.py
@@ -1,4 +1,5 @@
-from torch import Tensor, arange, device as Device, dtype as DType
+import torch
+from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl

@@ -25,7 +26,7 @@

     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
+        return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)

     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index bda9ce2..94a171e 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -1,8 +1,9 @@
 import re
 from typing import cast

+import torch
 import torch.nn.functional as F
-from torch import Tensor, cat, zeros
+from torch import Tensor
 from torch.nn import Parameter

 import refiners.fluxion.layers as fl
@@ -22,7 +23,7 @@
         with self.setup_adapter(target):
             super().__init__(fl.Lambda(func=self.lookup))
         p = Parameter(
-            zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
+            torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
         self.new_weight = p
@@ -30,11 +31,18 @@
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
         # Concatenate old and new weights for dynamic embedding updates during training
-        return F.embedding(x, cat([self.old_weight, self.new_weight]))
+        return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))

     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
-        p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+        p = Parameter(
+            torch.cat(
+                [
+                    self.new_weight,
+                    embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
+                ]
+            )
+        )
         self.new_weight = p

     @property
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 72b61f1..6cf95ee 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,9 +1,10 @@
 import math
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

+import torch
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
+from torch import Tensor, device as Device, dtype as DType, nn

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -98,7 +99,7 @@ class PerceiverScaledDotProductAttention(fl.Module):
         v = self.reshape_tensor(value)

         attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
-        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+        attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
         attention = attention @ v

         return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
@@ -159,7 +160,7 @@ class PerceiverAttention(fl.Chain):
         )

     def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
-        return cat((x, latents), dim=-2)
+        return torch.cat((x, latents), dim=-2)


 class LatentsToken(fl.Chain):
@@ -484,7 +485,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             image_prompt = self.preprocess_image(image_prompt)
         elif isinstance(image_prompt, list):
             assert all(isinstance(image, Image.Image) for image in image_prompt)
-            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+            image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])

         negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)

@@ -493,7 +494,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
             if any(weight != 1.0 for weight in weights):
                 conditional_embedding *= (
-                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
                     .unsqueeze(-1)
                     .unsqueeze(-1)
                 )
@@ -501,20 +502,20 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         if batch_size > 1 and concat_batches:
             # Create a longer image tokens sequence when a batch of images is given
             # See https://github.com/tencent-ailab/IP-Adapter/issues/99
-            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
-            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+            negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)

-        return cat((negative_embedding, conditional_embedding))
+        return torch.cat((negative_embedding, conditional_embedding))

     def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
         clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
         if not self.fine_grained:
-            negative_embedding = self.image_proj(zeros_like(clip_embedding))
+            negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
         else:
             # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
-            clip_embedding = image_encoder(zeros_like(image_prompt))
+            clip_embedding = image_encoder(torch.zeros_like(image_prompt))
             negative_embedding = self.image_proj(clip_embedding)
         return negative_embedding, conditional_embedding

diff --git a/src/refiners/foundationals/latent_diffusion/range_adapter.py b/src/refiners/foundationals/latent_diffusion/range_adapter.py
index 67d054e..3977f02 100644
--- a/src/refiners/foundationals/latent_diffusion/range_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/range_adapter.py
@@ -1,7 +1,8 @@
 import math

+import torch
 from jaxtyping import Float, Int
-from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
+from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
     half_dim = embedding_dim // 2
     # Note: it is important that this computation is done in float32.
     # The result can be cast to lower precision later if necessary.
-    exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
+    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
     exponent /= half_dim
-    embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
-    embedding = cat([cos(embedding), sin(embedding)], dim=-1)
+    embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
+    embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
     return embedding


diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
index 31e64b9..77423de 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
@@ -1,6 +1,7 @@
 import dataclasses

-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
+import torch
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -28,7 +29,7 @@ class DDIM(Solver):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ) -> None:
         """Initializes a new DDIM solver.

@@ -71,7 +72,7 @@ class DDIM(Solver):
             (
                 self.timesteps[step + 1]
                 if step < self.num_inference_steps - 1
-                else tensor(data=[0], device=self.device, dtype=self.dtype)
+                else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
             ),
         )
         current_scale_factor, previous_scale_factor = (
@@ -82,8 +83,8 @@ class DDIM(Solver):
                 else self.cumulative_scale_factors[0]
             ),
         )
-        predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
-        noise_factor = sqrt(1 - previous_scale_factor**2)
+        predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
+        noise_factor = torch.sqrt(1 - previous_scale_factor**2)

         # Do not add noise at the last step to avoid visual artifacts.
         if step == self.num_inference_steps - 1:
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
index c7b808b..b729829 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
@@ -3,7 +3,7 @@ from collections import deque

 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -38,7 +38,7 @@ class DPMSolver(Solver):
         params: BaseSolverParams | None = None,
         last_step_first_order: bool = False,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new DPM solver.

@@ -62,7 +62,7 @@ class DPMSolver(Solver):
             device=device,
             dtype=dtype,
         )
-        self.estimated_data = deque([tensor([])] * 2, maxlen=2)
+        self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
         self.last_step_first_order = last_step_first_order

     def rebuild(
@@ -94,7 +94,7 @@ class DPMSolver(Solver):
         offset = self.params.timesteps_offset
         max_timestep = self.params.num_train_timesteps - 1 + offset
         np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
-        return tensor(np_space).flip(0)
+        return torch.tensor(np_space).flip(0)

     def dpm_solver_first_order_update(
         self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
@@ -110,7 +110,7 @@
             The denoised version of the input data `x`.
         """
         current_timestep = self.timesteps[step]
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])

         previous_ratio = self.signal_to_noise_ratios[previous_timestep]
         current_ratio = self.signal_to_noise_ratios[current_timestep]
@@ -144,7 +144,7 @@
         Returns:
             The denoised version of the input data `x`.
         """
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
         current_timestep = self.timesteps[step]
         next_timestep = self.timesteps[step - 1]

diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
index 422a242..1f2aa7c 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -23,7 +23,7 @@ class Euler(Solver):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new Euler solver.

@@ -57,7 +57,7 @@
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
         sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
-        sigmas = torch.cat([sigmas, tensor([0.0])])
+        sigmas = torch.cat([sigmas, torch.tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)

     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/franken.py b/src/refiners/foundationals/latent_diffusion/solvers/franken.py
index b661c6c..255afe9 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/franken.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/franken.py
@@ -1,7 +1,8 @@
 import dataclasses
 from typing import Any, Callable, Protocol, TypeVar

-from torch import Generator, Tensor, device as Device, dtype as DType, float32
+import torch
+from torch import Generator, Tensor, device as Device, dtype as DType

 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing

@@ -60,7 +61,7 @@ class FrankenSolver(Solver):
         num_inference_steps: int,
         first_inference_step: int = 0,
         device: Device | str = "cpu",
-        dtype: DType = float32,
+        dtype: DType = torch.float32,
         **kwargs: Any,  # for typing, ignored
     ) -> None:
         self.get_diffusers_scheduler = get_diffusers_scheduler
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
index 6ff2be9..088a149 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
@@ -4,19 +4,8 @@ from enum import Enum
 from typing import TypeVar

 import numpy as np
-from torch import (
-    Generator,
-    Tensor,
-    arange,
-    device as Device,
-    dtype as DType,
-    float32,
-    linspace,
-    log,
-    sqrt,
-    stack,
-    tensor,
-)
+import torch
+from torch import Generator, Tensor, device as Device, dtype as DType

 from refiners.fluxion import layers as fl

@@ -161,7 +150,7 @@ class Solver(fl.Module, ABC):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: DType = float32,
+        dtype: DType = torch.float32,
     ) -> None:
         """Initializes a new `Solver` instance.

@@ -179,9 +168,9 @@
         self.params = self.resolve_params(params)

         self.scale_factors = self.sample_noise_schedule()
-        self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
-        self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
-        self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
+        self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))
+        self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))
+        self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)
         self.timesteps = self._generate_timesteps()
         self.to(device=device, dtype=dtype)

@@ -227,16 +216,16 @@
         max_timestep = num_train_timesteps - 1 + offset
         match spacing:
             case TimestepSpacing.LINSPACE:
-                return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
+                return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
             case TimestepSpacing.LINSPACE_ROUNDED:
-                return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
+                return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
             case TimestepSpacing.LEADING:
                 step_ratio = num_train_timesteps // num_inference_steps
-                return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
+                return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
             case TimestepSpacing.TRAILING:
                 step_ratio = num_train_timesteps // num_inference_steps
                 max_timestep = num_train_timesteps - 1 + offset
-                return arange(max_timestep, offset, -step_ratio)
+                return torch.arange(max_timestep, offset, -step_ratio)
             case TimestepSpacing.CUSTOM:
                 raise RuntimeError("generate_timesteps called with custom spacing")

@@ -290,7 +279,7 @@
         """
         if isinstance(step, list):
             assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
-            return stack(
+            return torch.stack(
                 tensors=[
                     self._add_noise(
                         x=x[i],
@@ -400,7 +389,7 @@
             A tensor representing the power distribution between the initial and final diffusion rates of the solver.
         """
         return (
-            linspace(
+            torch.linspace(
                 start=self.params.initial_diffusion_rate ** (1 / power),
                 end=self.params.final_diffusion_rate ** (1 / power),
                 steps=self.params.num_train_timesteps,
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
index 39051a8..6050d8e 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
@@ -1,7 +1,8 @@
 from typing import cast

+import torch
 from jaxtyping import Float
-from torch import Tensor, cat, device as Device, dtype as DType, split
+from torch import Tensor, cat, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -48,7 +49,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         return self.ensure_find(CLIPTokenizer)

     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
-        for str_tokens in split(tokens, 1):
+        for str_tokens in torch.split(tokens, 1):
             position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()  # type: ignore
             end_of_text_index.append(cast(int, position))