mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
remove a couple from torch import ...
from the code
This commit is contained in:
parent
45143e2851
commit
e423ba4291
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
|||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
|
||||
|
||||
import torch
|
||||
from torch import Tensor, cat, device as Device, dtype as DType
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
from refiners.fluxion.context import ContextProvider, Contexts
|
||||
from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
|
||||
|
@ -950,7 +950,7 @@ class Concatenate(Chain):
|
|||
|
||||
def forward(self, *args: Any) -> Tensor:
|
||||
outputs = [module(*args) for module in self]
|
||||
return cat(
|
||||
return torch.cat(
|
||||
[output for output in outputs if output is not None],
|
||||
dim=self.dim,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import (
|
||||
GroupNorm as _GroupNorm,
|
||||
InstanceNorm2d as _InstanceNorm2d,
|
||||
|
@ -111,8 +112,8 @@ class LayerNorm2d(WeightedModule):
|
|||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
|
||||
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
|
||||
self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
|
||||
self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
|
||||
self.eps = eps
|
||||
|
||||
def forward(
|
||||
|
@ -121,7 +122,7 @@ class LayerNorm2d(WeightedModule):
|
|||
) -> Float[Tensor, "batch channels height width"]:
|
||||
x_mean = x.mean(1, keepdim=True)
|
||||
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
||||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||
x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
|
||||
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||
return x_out
|
||||
|
||||
|
|
|
@ -8,15 +8,7 @@ from numpy import array, float32
|
|||
from PIL import Image
|
||||
from safetensors import safe_open as _safe_open # type: ignore
|
||||
from safetensors.torch import save_file as _save_file # type: ignore
|
||||
from torch import (
|
||||
Tensor,
|
||||
cat,
|
||||
device as Device,
|
||||
dtype as DType,
|
||||
manual_seed as _manual_seed, # type: ignore
|
||||
no_grad as _no_grad, # type: ignore
|
||||
norm as _norm, # type: ignore
|
||||
)
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
||||
|
||||
T = TypeVar("T")
|
||||
|
@ -24,14 +16,14 @@ E = TypeVar("E")
|
|||
|
||||
|
||||
def norm(x: Tensor) -> Tensor:
|
||||
return _norm(x) # type: ignore
|
||||
return torch.norm(x) # type: ignore
|
||||
|
||||
|
||||
def manual_seed(seed: int) -> None:
|
||||
_manual_seed(seed)
|
||||
torch.manual_seed(seed) # type: ignore
|
||||
|
||||
|
||||
class no_grad(_no_grad):
|
||||
class no_grad(torch.no_grad):
|
||||
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
|
||||
return object.__new__(cls)
|
||||
|
||||
|
@ -123,7 +115,7 @@ def gaussian_blur(
|
|||
def images_to_tensor(
|
||||
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> Tensor:
|
||||
return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
|
||||
return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
|
||||
|
||||
|
||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from torch import Tensor, arange, device as Device, dtype as DType
|
||||
import torch
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
@ -25,7 +26,7 @@ class PositionalEncoder(fl.Chain):
|
|||
|
||||
@property
|
||||
def position_ids(self) -> Tensor:
|
||||
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
||||
return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
||||
|
||||
def get_position_ids(self, x: Tensor) -> Tensor:
|
||||
return self.position_ids[:, : x.shape[1]]
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import re
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, cat, zeros
|
||||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
@ -22,7 +23,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
|||
with self.setup_adapter(target):
|
||||
super().__init__(fl.Lambda(func=self.lookup))
|
||||
p = Parameter(
|
||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||
torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||
) # requires_grad=True by default
|
||||
self.old_weight = cast(Parameter, target.weight)
|
||||
self.new_weight = p
|
||||
|
@ -30,11 +31,18 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
|||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||
def lookup(self, x: Tensor) -> Tensor:
|
||||
# Concatenate old and new weights for dynamic embedding updates during training
|
||||
return F.embedding(x, cat([self.old_weight, self.new_weight]))
|
||||
return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))
|
||||
|
||||
def add_embedding(self, embedding: Tensor) -> None:
|
||||
assert embedding.shape == (self.old_weight.shape[1],)
|
||||
p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
|
||||
p = Parameter(
|
||||
torch.cat(
|
||||
[
|
||||
self.new_weight,
|
||||
embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.new_weight = p
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import math
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
|
||||
from torch import Tensor, device as Device, dtype as DType, nn
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
@ -98,7 +99,7 @@ class PerceiverScaledDotProductAttention(fl.Module):
|
|||
v = self.reshape_tensor(value)
|
||||
|
||||
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
|
||||
attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
|
||||
attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
|
||||
attention = attention @ v
|
||||
|
||||
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
|
||||
|
@ -159,7 +160,7 @@ class PerceiverAttention(fl.Chain):
|
|||
)
|
||||
|
||||
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
|
||||
return cat((x, latents), dim=-2)
|
||||
return torch.cat((x, latents), dim=-2)
|
||||
|
||||
|
||||
class LatentsToken(fl.Chain):
|
||||
|
@ -484,7 +485,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
image_prompt = self.preprocess_image(image_prompt)
|
||||
elif isinstance(image_prompt, list):
|
||||
assert all(isinstance(image, Image.Image) for image in image_prompt)
|
||||
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
|
||||
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
|
||||
|
||||
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
||||
|
||||
|
@ -493,7 +494,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
|
||||
if any(weight != 1.0 for weight in weights):
|
||||
conditional_embedding *= (
|
||||
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
||||
torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
@ -501,20 +502,20 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
if batch_size > 1 and concat_batches:
|
||||
# Create a longer image tokens sequence when a batch of images is given
|
||||
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
|
||||
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
|
||||
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
|
||||
negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
|
||||
conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
|
||||
|
||||
return cat((negative_embedding, conditional_embedding))
|
||||
return torch.cat((negative_embedding, conditional_embedding))
|
||||
|
||||
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
|
||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||
clip_embedding = image_encoder(image_prompt)
|
||||
conditional_embedding = self.image_proj(clip_embedding)
|
||||
if not self.fine_grained:
|
||||
negative_embedding = self.image_proj(zeros_like(clip_embedding))
|
||||
negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
|
||||
else:
|
||||
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
||||
clip_embedding = image_encoder(zeros_like(image_prompt))
|
||||
clip_embedding = image_encoder(torch.zeros_like(image_prompt))
|
||||
negative_embedding = self.image_proj(clip_embedding)
|
||||
return negative_embedding, conditional_embedding
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float, Int
|
||||
from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
|
|||
half_dim = embedding_dim // 2
|
||||
# Note: it is important that this computation is done in float32.
|
||||
# The result can be cast to lower precision later if necessary.
|
||||
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
|
||||
exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
|
||||
exponent /= half_dim
|
||||
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
|
||||
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
|
||||
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
|
||||
embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import dataclasses
|
||||
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
BaseSolverParams,
|
||||
|
@ -28,7 +29,7 @@ class DDIM(Solver):
|
|||
first_inference_step: int = 0,
|
||||
params: BaseSolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
dtype: Dtype = torch.float32,
|
||||
) -> None:
|
||||
"""Initializes a new DDIM solver.
|
||||
|
||||
|
@ -71,7 +72,7 @@ class DDIM(Solver):
|
|||
(
|
||||
self.timesteps[step + 1]
|
||||
if step < self.num_inference_steps - 1
|
||||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||
else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||
),
|
||||
)
|
||||
current_scale_factor, previous_scale_factor = (
|
||||
|
@ -82,8 +83,8 @@ class DDIM(Solver):
|
|||
else self.cumulative_scale_factors[0]
|
||||
),
|
||||
)
|
||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
|
||||
noise_factor = sqrt(1 - previous_scale_factor**2)
|
||||
predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
|
||||
noise_factor = torch.sqrt(1 - previous_scale_factor**2)
|
||||
|
||||
# Do not add noise at the last step to avoid visual artifacts.
|
||||
if step == self.num_inference_steps - 1:
|
||||
|
|
|
@ -3,7 +3,7 @@ from collections import deque
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
BaseSolverParams,
|
||||
|
@ -38,7 +38,7 @@ class DPMSolver(Solver):
|
|||
params: BaseSolverParams | None = None,
|
||||
last_step_first_order: bool = False,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
dtype: Dtype = torch.float32,
|
||||
):
|
||||
"""Initializes a new DPM solver.
|
||||
|
||||
|
@ -62,7 +62,7 @@ class DPMSolver(Solver):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
||||
self.last_step_first_order = last_step_first_order
|
||||
|
||||
def rebuild(
|
||||
|
@ -94,7 +94,7 @@ class DPMSolver(Solver):
|
|||
offset = self.params.timesteps_offset
|
||||
max_timestep = self.params.num_train_timesteps - 1 + offset
|
||||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||
return tensor(np_space).flip(0)
|
||||
return torch.tensor(np_space).flip(0)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
||||
|
@ -110,7 +110,7 @@ class DPMSolver(Solver):
|
|||
The denoised version of the input data `x`.
|
||||
"""
|
||||
current_timestep = self.timesteps[step]
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||
|
||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||
|
@ -144,7 +144,7 @@ class DPMSolver(Solver):
|
|||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||
current_timestep = self.timesteps[step]
|
||||
next_timestep = self.timesteps[step - 1]
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
BaseSolverParams,
|
||||
|
@ -23,7 +23,7 @@ class Euler(Solver):
|
|||
first_inference_step: int = 0,
|
||||
params: BaseSolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
dtype: Dtype = torch.float32,
|
||||
):
|
||||
"""Initializes a new Euler solver.
|
||||
|
||||
|
@ -57,7 +57,7 @@ class Euler(Solver):
|
|||
"""Generate the sigmas used by the solver."""
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
|
||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||
sigmas = torch.cat([sigmas, torch.tensor([0.0])])
|
||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import dataclasses
|
||||
from typing import Any, Callable, Protocol, TypeVar
|
||||
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||
|
||||
|
@ -60,7 +61,7 @@ class FrankenSolver(Solver):
|
|||
num_inference_steps: int,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
dtype: DType = torch.float32,
|
||||
**kwargs: Any, # for typing, ignored
|
||||
) -> None:
|
||||
self.get_diffusers_scheduler = get_diffusers_scheduler
|
||||
|
|
|
@ -4,19 +4,8 @@ from enum import Enum
|
|||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from torch import (
|
||||
Generator,
|
||||
Tensor,
|
||||
arange,
|
||||
device as Device,
|
||||
dtype as DType,
|
||||
float32,
|
||||
linspace,
|
||||
log,
|
||||
sqrt,
|
||||
stack,
|
||||
tensor,
|
||||
)
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
|
||||
|
@ -161,7 +150,7 @@ class Solver(fl.Module, ABC):
|
|||
first_inference_step: int = 0,
|
||||
params: BaseSolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
"""Initializes a new `Solver` instance.
|
||||
|
||||
|
@ -179,9 +168,9 @@ class Solver(fl.Module, ABC):
|
|||
self.params = self.resolve_params(params)
|
||||
|
||||
self.scale_factors = self.sample_noise_schedule()
|
||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||
self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))
|
||||
self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)
|
||||
self.timesteps = self._generate_timesteps()
|
||||
|
||||
self.to(device=device, dtype=dtype)
|
||||
|
@ -227,16 +216,16 @@ class Solver(fl.Module, ABC):
|
|||
max_timestep = num_train_timesteps - 1 + offset
|
||||
match spacing:
|
||||
case TimestepSpacing.LINSPACE:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
||||
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
|
||||
case TimestepSpacing.LINSPACE_ROUNDED:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||
case TimestepSpacing.LEADING:
|
||||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
||||
return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
||||
case TimestepSpacing.TRAILING:
|
||||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
max_timestep = num_train_timesteps - 1 + offset
|
||||
return arange(max_timestep, offset, -step_ratio)
|
||||
return torch.arange(max_timestep, offset, -step_ratio)
|
||||
case TimestepSpacing.CUSTOM:
|
||||
raise RuntimeError("generate_timesteps called with custom spacing")
|
||||
|
||||
|
@ -290,7 +279,7 @@ class Solver(fl.Module, ABC):
|
|||
"""
|
||||
if isinstance(step, list):
|
||||
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
|
||||
return stack(
|
||||
return torch.stack(
|
||||
tensors=[
|
||||
self._add_noise(
|
||||
x=x[i],
|
||||
|
@ -400,7 +389,7 @@ class Solver(fl.Module, ABC):
|
|||
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
|
||||
"""
|
||||
return (
|
||||
linspace(
|
||||
torch.linspace(
|
||||
start=self.params.initial_diffusion_rate ** (1 / power),
|
||||
end=self.params.final_diffusion_rate ** (1 / power),
|
||||
steps=self.params.num_train_timesteps,
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from typing import cast
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, split
|
||||
from torch import Tensor, cat, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
@ -48,7 +49,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
|||
return self.ensure_find(CLIPTokenizer)
|
||||
|
||||
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
|
||||
for str_tokens in split(tokens, 1):
|
||||
for str_tokens in torch.split(tokens, 1):
|
||||
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
|
||||
end_of_text_index.append(cast(int, position))
|
||||
|
||||
|
|
Loading…
Reference in a new issue