remove a couple from torch import ... from the code

This commit is contained in:
Laurent 2024-08-21 08:46:46 +00:00 committed by Laureηt
parent 45143e2851
commit 2cb0f06119
13 changed files with 76 additions and 80 deletions

View file

@ -6,7 +6,7 @@ from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.context import ContextProvider, Contexts
from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
@ -950,7 +950,7 @@ class Concatenate(Chain):
def forward(self, *args: Any) -> Tensor:
outputs = [module(*args) for module in self]
return cat(
return torch.cat(
[output for output in outputs if output is not None],
dim=self.dim,
)

View file

@ -1,5 +1,6 @@
import torch
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
from torch import Tensor, device as Device, dtype as DType
from torch.nn import (
GroupNorm as _GroupNorm,
InstanceNorm2d as _InstanceNorm2d,
@ -111,8 +112,8 @@ class LayerNorm2d(WeightedModule):
dtype: DType | None = None,
) -> None:
super().__init__()
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
self.eps = eps
def forward(
@ -121,7 +122,7 @@ class LayerNorm2d(WeightedModule):
) -> Float[Tensor, "batch channels height width"]:
x_mean = x.mean(1, keepdim=True)
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
return x_out

View file

@ -8,15 +8,7 @@ from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import (
Tensor,
cat,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
no_grad as _no_grad, # type: ignore
norm as _norm, # type: ignore
)
from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
T = TypeVar("T")
@ -24,14 +16,14 @@ E = TypeVar("E")
def norm(x: Tensor) -> Tensor:
return _norm(x) # type: ignore
return torch.norm(x) # type: ignore
def manual_seed(seed: int) -> None:
_manual_seed(seed)
torch.manual_seed(seed) # type: ignore
class no_grad(_no_grad):
class no_grad(torch.no_grad):
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
return object.__new__(cls)
@ -123,7 +115,7 @@ def gaussian_blur(
def images_to_tensor(
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
) -> Tensor:
return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:

View file

@ -1,4 +1,5 @@
from torch import Tensor, arange, device as Device, dtype as DType
import torch
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
@ -25,7 +26,7 @@ class PositionalEncoder(fl.Chain):
@property
def position_ids(self) -> Tensor:
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]

View file

@ -1,8 +1,9 @@
import re
from typing import cast
import torch
import torch.nn.functional as F
from torch import Tensor, cat, zeros
from torch import Tensor
from torch.nn import Parameter
import refiners.fluxion.layers as fl
@ -22,7 +23,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
with self.setup_adapter(target):
super().__init__(fl.Lambda(func=self.lookup))
p = Parameter(
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default
self.old_weight = cast(Parameter, target.weight)
self.new_weight = p
@ -30,11 +31,18 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor:
# Concatenate old and new weights for dynamic embedding updates during training
return F.embedding(x, cat([self.old_weight, self.new_weight]))
return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))
def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],)
p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
p = Parameter(
torch.cat(
[
self.new_weight,
embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
]
)
)
self.new_weight = p
@property

View file

@ -1,9 +1,10 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
from torch import Tensor, device as Device, dtype as DType, nn
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@ -98,7 +99,7 @@ class PerceiverScaledDotProductAttention(fl.Module):
v = self.reshape_tensor(value)
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
attention = attention @ v
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
@ -159,7 +160,7 @@ class PerceiverAttention(fl.Chain):
)
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
return cat((x, latents), dim=-2)
return torch.cat((x, latents), dim=-2)
class LatentsToken(fl.Chain):
@ -484,7 +485,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
assert all(isinstance(image, Image.Image) for image in image_prompt)
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
@ -493,7 +494,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
if any(weight != 1.0 for weight in weights):
conditional_embedding *= (
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
.unsqueeze(-1)
.unsqueeze(-1)
)
@ -501,20 +502,20 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
if batch_size > 1 and concat_batches:
# Create a longer image tokens sequence when a batch of images is given
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
return cat((negative_embedding, conditional_embedding))
return torch.cat((negative_embedding, conditional_embedding))
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
clip_embedding = image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
if not self.fine_grained:
negative_embedding = self.image_proj(zeros_like(clip_embedding))
negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
else:
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
clip_embedding = image_encoder(zeros_like(image_prompt))
clip_embedding = image_encoder(torch.zeros_like(image_prompt))
negative_embedding = self.image_proj(clip_embedding)
return negative_embedding, conditional_embedding

View file

@ -1,7 +1,8 @@
import math
import torch
from jaxtyping import Float, Int
from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
half_dim = embedding_dim // 2
# Note: it is important that this computation is done in float32.
# The result can be cast to lower precision later if necessary.
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
exponent /= half_dim
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
return embedding

View file

@ -1,6 +1,7 @@
import dataclasses
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@ -28,7 +29,7 @@ class DDIM(Solver):
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
dtype: Dtype = torch.float32,
) -> None:
"""Initializes a new DDIM solver.
@ -71,7 +72,7 @@ class DDIM(Solver):
(
self.timesteps[step + 1]
if step < self.num_inference_steps - 1
else tensor(data=[0], device=self.device, dtype=self.dtype)
else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
@ -82,8 +83,8 @@ class DDIM(Solver):
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
noise_factor = sqrt(1 - previous_scale_factor**2)
predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
noise_factor = torch.sqrt(1 - previous_scale_factor**2)
# Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1:

View file

@ -3,7 +3,7 @@ from collections import deque
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from torch import Generator, Tensor, device as Device, dtype as Dtype
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@ -38,7 +38,7 @@ class DPMSolver(Solver):
params: BaseSolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
dtype: Dtype = float32,
dtype: Dtype = torch.float32,
):
"""Initializes a new DPM solver.
@ -62,7 +62,7 @@ class DPMSolver(Solver):
device=device,
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
def rebuild(
@ -94,7 +94,7 @@ class DPMSolver(Solver):
offset = self.params.timesteps_offset
max_timestep = self.params.num_train_timesteps - 1 + offset
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0)
return torch.tensor(np_space).flip(0)
def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
@ -110,7 +110,7 @@ class DPMSolver(Solver):
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
@ -144,7 +144,7 @@ class DPMSolver(Solver):
Returns:
The denoised version of the input data `x`.
"""
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]

View file

@ -1,6 +1,6 @@
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from torch import Generator, Tensor, device as Device, dtype as Dtype
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@ -23,7 +23,7 @@ class Euler(Solver):
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
dtype: Dtype = torch.float32,
):
"""Initializes a new Euler solver.
@ -57,7 +57,7 @@ class Euler(Solver):
"""Generate the sigmas used by the solver."""
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
sigmas = torch.cat([sigmas, tensor([0.0])])
sigmas = torch.cat([sigmas, torch.tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:

View file

@ -1,7 +1,8 @@
import dataclasses
from typing import Any, Callable, Protocol, TypeVar
from torch import Generator, Tensor, device as Device, dtype as DType, float32
import torch
from torch import Generator, Tensor, device as Device, dtype as DType
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
@ -60,7 +61,7 @@ class FrankenSolver(Solver):
num_inference_steps: int,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: DType = float32,
dtype: DType = torch.float32,
**kwargs: Any, # for typing, ignored
) -> None:
self.get_diffusers_scheduler = get_diffusers_scheduler

View file

@ -4,19 +4,8 @@ from enum import Enum
from typing import TypeVar
import numpy as np
from torch import (
Generator,
Tensor,
arange,
device as Device,
dtype as DType,
float32,
linspace,
log,
sqrt,
stack,
tensor,
)
import torch
from torch import Generator, Tensor, device as Device, dtype as DType
from refiners.fluxion import layers as fl
@ -161,7 +150,7 @@ class Solver(fl.Module, ABC):
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
dtype: DType = torch.float32,
) -> None:
"""Initializes a new `Solver` instance.
@ -179,9 +168,9 @@ class Solver(fl.Module, ABC):
self.params = self.resolve_params(params)
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)
self.timesteps = self._generate_timesteps()
self.to(device=device, dtype=dtype)
@ -227,16 +216,16 @@ class Solver(fl.Module, ABC):
max_timestep = num_train_timesteps - 1 + offset
match spacing:
case TimestepSpacing.LINSPACE:
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
case TimestepSpacing.LINSPACE_ROUNDED:
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
case TimestepSpacing.LEADING:
step_ratio = num_train_timesteps // num_inference_steps
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
case TimestepSpacing.TRAILING:
step_ratio = num_train_timesteps // num_inference_steps
max_timestep = num_train_timesteps - 1 + offset
return arange(max_timestep, offset, -step_ratio)
return torch.arange(max_timestep, offset, -step_ratio)
case TimestepSpacing.CUSTOM:
raise RuntimeError("generate_timesteps called with custom spacing")
@ -290,7 +279,7 @@ class Solver(fl.Module, ABC):
"""
if isinstance(step, list):
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
return stack(
return torch.stack(
tensors=[
self._add_noise(
x=x[i],
@ -400,7 +389,7 @@ class Solver(fl.Module, ABC):
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
"""
return (
linspace(
torch.linspace(
start=self.params.initial_diffusion_rate ** (1 / power),
end=self.params.final_diffusion_rate ** (1 / power),
steps=self.params.num_train_timesteps,

View file

@ -1,7 +1,8 @@
from typing import cast
import torch
from jaxtyping import Float
from torch import Tensor, cat, device as Device, dtype as DType, split
from torch import Tensor, cat, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@ -48,7 +49,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
return self.ensure_find(CLIPTokenizer)
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
for str_tokens in split(tokens, 1):
for str_tokens in torch.split(tokens, 1):
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
end_of_text_index.append(cast(int, position))