remove a couple of `from torch import ...` statements from the code
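
The pattern is the same throughout: stop importing torch functions by name and call them through the torch namespace instead, keeping only types (Tensor, device, dtype) in the from-imports. A minimal sketch of the before/after on a hypothetical helper (not a function from this repository):

    # before: free functions imported directly from torch
    from torch import Tensor, cat

    def merge(a: Tensor, b: Tensor) -> Tensor:
        return cat([a, b])

    # after: keep only types in the from-import, qualify functions via the module
    import torch
    from torch import Tensor

    def merge(a: Tensor, b: Tensor) -> Tensor:
        return torch.cat([a, b])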

This commit is contained in:
Laurent 2024-08-21 08:46:46 +00:00 committed by Laureηt
parent 45143e2851
commit 2cb0f06119
13 changed files with 76 additions and 80 deletions

View file

@@ -6,7 +6,7 @@ from collections import defaultdict
 from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
 import torch
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, device as Device, dtype as DType
 from refiners.fluxion.context import ContextProvider, Contexts
 from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
@@ -950,7 +950,7 @@ class Concatenate(Chain):
     def forward(self, *args: Any) -> Tensor:
         outputs = [module(*args) for module in self]
-        return cat(
+        return torch.cat(
             [output for output in outputs if output is not None],
             dim=self.dim,
         )

View file

@@ -1,5 +1,6 @@
+import torch
 from jaxtyping import Float
-from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn import (
     GroupNorm as _GroupNorm,
     InstanceNorm2d as _InstanceNorm2d,
@@ -111,8 +112,8 @@ class LayerNorm2d(WeightedModule):
         dtype: DType | None = None,
     ) -> None:
         super().__init__()
-        self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
-        self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
+        self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
+        self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
         self.eps = eps

     def forward(
@@ -121,7 +122,7 @@ class LayerNorm2d(WeightedModule):
     ) -> Float[Tensor, "batch channels height width"]:
         x_mean = x.mean(1, keepdim=True)
         x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
-        x_norm = (x - x_mean) / sqrt(x_var + self.eps)
+        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
         x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x_out

View file

@@ -8,15 +8,7 @@ from numpy import array, float32
 from PIL import Image
 from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
-from torch import (
-    Tensor,
-    cat,
-    device as Device,
-    dtype as DType,
-    manual_seed as _manual_seed,  # type: ignore
-    no_grad as _no_grad,  # type: ignore
-    norm as _norm,  # type: ignore
-)
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore

 T = TypeVar("T")
@@ -24,14 +16,14 @@ E = TypeVar("E")
 def norm(x: Tensor) -> Tensor:
-    return _norm(x)  # type: ignore
+    return torch.norm(x)  # type: ignore


 def manual_seed(seed: int) -> None:
-    _manual_seed(seed)
+    torch.manual_seed(seed)  # type: ignore


-class no_grad(_no_grad):
+class no_grad(torch.no_grad):
     def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
         return object.__new__(cls)
@@ -123,7 +115,7 @@ def gaussian_blur(
 def images_to_tensor(
     images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
 ) -> Tensor:
-    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+    return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])


 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:

View file

@@ -1,4 +1,5 @@
-from torch import Tensor, arange, device as Device, dtype as DType
+import torch
+from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
@@ -25,7 +26,7 @@ class PositionalEncoder(fl.Chain):
     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
+        return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)

     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]

View file

@@ -1,8 +1,9 @@
 import re
 from typing import cast

+import torch
 import torch.nn.functional as F
-from torch import Tensor, cat, zeros
+from torch import Tensor
 from torch.nn import Parameter

 import refiners.fluxion.layers as fl
@@ -22,7 +23,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
         with self.setup_adapter(target):
             super().__init__(fl.Lambda(func=self.lookup))
         p = Parameter(
-            zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
+            torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
         self.new_weight = p
@@ -30,11 +31,18 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
         # Concatenate old and new weights for dynamic embedding updates during training
-        return F.embedding(x, cat([self.old_weight, self.new_weight]))
+        return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))

     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
-        p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+        p = Parameter(
+            torch.cat(
+                [
+                    self.new_weight,
+                    embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
+                ]
+            )
+        )
         self.new_weight = p

     @property

View file

@@ -1,9 +1,10 @@
 import math
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

+import torch
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
+from torch import Tensor, device as Device, dtype as DType, nn

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -98,7 +99,7 @@ class PerceiverScaledDotProductAttention(fl.Module):
         v = self.reshape_tensor(value)

         attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
-        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+        attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
         attention = attention @ v

         return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
@@ -159,7 +160,7 @@ class PerceiverAttention(fl.Chain):
         )

     def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
-        return cat((x, latents), dim=-2)
+        return torch.cat((x, latents), dim=-2)


 class LatentsToken(fl.Chain):
@@ -484,7 +485,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             image_prompt = self.preprocess_image(image_prompt)
         elif isinstance(image_prompt, list):
             assert all(isinstance(image, Image.Image) for image in image_prompt)
-            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+            image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])

         negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
@@ -493,7 +494,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
             if any(weight != 1.0 for weight in weights):
                 conditional_embedding *= (
-                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
                     .unsqueeze(-1)
                     .unsqueeze(-1)
                 )
@@ -501,20 +502,20 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         if batch_size > 1 and concat_batches:
             # Create a longer image tokens sequence when a batch of images is given
             # See https://github.com/tencent-ailab/IP-Adapter/issues/99
-            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
-            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+            negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)

-        return cat((negative_embedding, conditional_embedding))
+        return torch.cat((negative_embedding, conditional_embedding))

     def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
         clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
         if not self.fine_grained:
-            negative_embedding = self.image_proj(zeros_like(clip_embedding))
+            negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
         else:
             # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
-            clip_embedding = image_encoder(zeros_like(image_prompt))
+            clip_embedding = image_encoder(torch.zeros_like(image_prompt))
             negative_embedding = self.image_proj(clip_embedding)
         return negative_embedding, conditional_embedding

View file

@@ -1,7 +1,8 @@
 import math

+import torch
 from jaxtyping import Float, Int
-from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
+from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
     half_dim = embedding_dim // 2
     # Note: it is important that this computation is done in float32.
     # The result can be cast to lower precision later if necessary.
-    exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
+    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
     exponent /= half_dim
-    embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
-    embedding = cat([cos(embedding), sin(embedding)], dim=-1)
+    embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
+    embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
     return embedding

View file

@@ -1,6 +1,7 @@
 import dataclasses

-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
+import torch
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -28,7 +29,7 @@ class DDIM(Solver):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ) -> None:
         """Initializes a new DDIM solver.
@@ -71,7 +72,7 @@ class DDIM(Solver):
             (
                 self.timesteps[step + 1]
                 if step < self.num_inference_steps - 1
-                else tensor(data=[0], device=self.device, dtype=self.dtype)
+                else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
             ),
         )
         current_scale_factor, previous_scale_factor = (
@@ -82,8 +83,8 @@ class DDIM(Solver):
                 else self.cumulative_scale_factors[0]
             ),
         )
-        predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
-        noise_factor = sqrt(1 - previous_scale_factor**2)
+        predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
+        noise_factor = torch.sqrt(1 - previous_scale_factor**2)

         # Do not add noise at the last step to avoid visual artifacts.
         if step == self.num_inference_steps - 1:

View file

@@ -3,7 +3,7 @@ from collections import deque
 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -38,7 +38,7 @@ class DPMSolver(Solver):
         params: BaseSolverParams | None = None,
         last_step_first_order: bool = False,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new DPM solver.
@@ -62,7 +62,7 @@ class DPMSolver(Solver):
             device=device,
             dtype=dtype,
         )
-        self.estimated_data = deque([tensor([])] * 2, maxlen=2)
+        self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
         self.last_step_first_order = last_step_first_order

     def rebuild(
@@ -94,7 +94,7 @@ class DPMSolver(Solver):
         offset = self.params.timesteps_offset
         max_timestep = self.params.num_train_timesteps - 1 + offset
         np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
-        return tensor(np_space).flip(0)
+        return torch.tensor(np_space).flip(0)

     def dpm_solver_first_order_update(
         self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
@@ -110,7 +110,7 @@ class DPMSolver(Solver):
             The denoised version of the input data `x`.
         """
         current_timestep = self.timesteps[step]
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])

         previous_ratio = self.signal_to_noise_ratios[previous_timestep]
         current_ratio = self.signal_to_noise_ratios[current_timestep]
@@ -144,7 +144,7 @@ class DPMSolver(Solver):
         Returns:
             The denoised version of the input data `x`.
         """
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
         current_timestep = self.timesteps[step]
         next_timestep = self.timesteps[step - 1]

View file

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -23,7 +23,7 @@ class Euler(Solver):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new Euler solver.
@@ -57,7 +57,7 @@ class Euler(Solver):
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
         sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
-        sigmas = torch.cat([sigmas, tensor([0.0])])
+        sigmas = torch.cat([sigmas, torch.tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)

     def scale_model_input(self, x: Tensor, step: int) -> Tensor:

View file

@@ -1,7 +1,8 @@
 import dataclasses
 from typing import Any, Callable, Protocol, TypeVar

-from torch import Generator, Tensor, device as Device, dtype as DType, float32
+import torch
+from torch import Generator, Tensor, device as Device, dtype as DType

 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
@@ -60,7 +61,7 @@ class FrankenSolver(Solver):
         num_inference_steps: int,
         first_inference_step: int = 0,
         device: Device | str = "cpu",
-        dtype: DType = float32,
+        dtype: DType = torch.float32,
         **kwargs: Any,  # for typing, ignored
     ) -> None:
         self.get_diffusers_scheduler = get_diffusers_scheduler

View file

@@ -4,19 +4,8 @@ from enum import Enum
 from typing import TypeVar

 import numpy as np
-from torch import (
-    Generator,
-    Tensor,
-    arange,
-    device as Device,
-    dtype as DType,
-    float32,
-    linspace,
-    log,
-    sqrt,
-    stack,
-    tensor,
-)
+import torch
+from torch import Generator, Tensor, device as Device, dtype as DType

 from refiners.fluxion import layers as fl
@@ -161,7 +150,7 @@ class Solver(fl.Module, ABC):
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: DType = float32,
+        dtype: DType = torch.float32,
     ) -> None:
         """Initializes a new `Solver` instance.
@@ -179,9 +168,9 @@ class Solver(fl.Module, ABC):
         self.params = self.resolve_params(params)

         self.scale_factors = self.sample_noise_schedule()
-        self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
-        self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
-        self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
+        self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))
+        self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))
+        self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)
         self.timesteps = self._generate_timesteps()
         self.to(device=device, dtype=dtype)
@@ -227,16 +216,16 @@ class Solver(fl.Module, ABC):
         max_timestep = num_train_timesteps - 1 + offset
         match spacing:
             case TimestepSpacing.LINSPACE:
-                return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
+                return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
             case TimestepSpacing.LINSPACE_ROUNDED:
-                return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
+                return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
             case TimestepSpacing.LEADING:
                 step_ratio = num_train_timesteps // num_inference_steps
-                return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
+                return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
             case TimestepSpacing.TRAILING:
                 step_ratio = num_train_timesteps // num_inference_steps
                 max_timestep = num_train_timesteps - 1 + offset
-                return arange(max_timestep, offset, -step_ratio)
+                return torch.arange(max_timestep, offset, -step_ratio)
             case TimestepSpacing.CUSTOM:
                 raise RuntimeError("generate_timesteps called with custom spacing")
@@ -290,7 +279,7 @@ class Solver(fl.Module, ABC):
         """
         if isinstance(step, list):
             assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
-            return stack(
+            return torch.stack(
                 tensors=[
                     self._add_noise(
                         x=x[i],
@@ -400,7 +389,7 @@ class Solver(fl.Module, ABC):
             A tensor representing the power distribution between the initial and final diffusion rates of the solver.
         """
         return (
-            linspace(
+            torch.linspace(
                 start=self.params.initial_diffusion_rate ** (1 / power),
                 end=self.params.final_diffusion_rate ** (1 / power),
                 steps=self.params.num_train_timesteps,

View file

@@ -1,7 +1,8 @@
 from typing import cast

+import torch
 from jaxtyping import Float
-from torch import Tensor, cat, device as Device, dtype as DType, split
+from torch import Tensor, cat, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -48,7 +49,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         return self.ensure_find(CLIPTokenizer)

     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
-        for str_tokens in split(tokens, 1):
+        for str_tokens in torch.split(tokens, 1):
             position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()  # type: ignore
             end_of_text_index.append(cast(int, position))