mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
remove a couple from torch import ...
from the code
This commit is contained in:
parent
45143e2851
commit
2cb0f06119
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
|
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
from refiners.fluxion.context import ContextProvider, Contexts
|
from refiners.fluxion.context import ContextProvider, Contexts
|
||||||
from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
|
from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
|
||||||
|
@ -950,7 +950,7 @@ class Concatenate(Chain):
|
||||||
|
|
||||||
def forward(self, *args: Any) -> Tensor:
|
def forward(self, *args: Any) -> Tensor:
|
||||||
outputs = [module(*args) for module in self]
|
outputs = [module(*args) for module in self]
|
||||||
return cat(
|
return torch.cat(
|
||||||
[output for output in outputs if output is not None],
|
[output for output in outputs if output is not None],
|
||||||
dim=self.dim,
|
dim=self.dim,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
import torch
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn import (
|
from torch.nn import (
|
||||||
GroupNorm as _GroupNorm,
|
GroupNorm as _GroupNorm,
|
||||||
InstanceNorm2d as _InstanceNorm2d,
|
InstanceNorm2d as _InstanceNorm2d,
|
||||||
|
@ -111,8 +112,8 @@ class LayerNorm2d(WeightedModule):
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
|
self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
|
||||||
self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
|
self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -121,7 +122,7 @@ class LayerNorm2d(WeightedModule):
|
||||||
) -> Float[Tensor, "batch channels height width"]:
|
) -> Float[Tensor, "batch channels height width"]:
|
||||||
x_mean = x.mean(1, keepdim=True)
|
x_mean = x.mean(1, keepdim=True)
|
||||||
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
||||||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
|
||||||
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
|
@ -8,15 +8,7 @@ from numpy import array, float32
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors import safe_open as _safe_open # type: ignore
|
from safetensors import safe_open as _safe_open # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
from torch import (
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
Tensor,
|
|
||||||
cat,
|
|
||||||
device as Device,
|
|
||||||
dtype as DType,
|
|
||||||
manual_seed as _manual_seed, # type: ignore
|
|
||||||
no_grad as _no_grad, # type: ignore
|
|
||||||
norm as _norm, # type: ignore
|
|
||||||
)
|
|
||||||
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -24,14 +16,14 @@ E = TypeVar("E")
|
||||||
|
|
||||||
|
|
||||||
def norm(x: Tensor) -> Tensor:
|
def norm(x: Tensor) -> Tensor:
|
||||||
return _norm(x) # type: ignore
|
return torch.norm(x) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def manual_seed(seed: int) -> None:
|
def manual_seed(seed: int) -> None:
|
||||||
_manual_seed(seed)
|
torch.manual_seed(seed) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class no_grad(_no_grad):
|
class no_grad(torch.no_grad):
|
||||||
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
|
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
|
||||||
return object.__new__(cls)
|
return object.__new__(cls)
|
||||||
|
|
||||||
|
@ -123,7 +115,7 @@ def gaussian_blur(
|
||||||
def images_to_tensor(
|
def images_to_tensor(
|
||||||
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
|
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
|
return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from torch import Tensor, arange, device as Device, dtype as DType
|
import torch
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
@ -25,7 +26,7 @@ class PositionalEncoder(fl.Chain):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def position_ids(self) -> Tensor:
|
def position_ids(self) -> Tensor:
|
||||||
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
||||||
|
|
||||||
def get_position_ids(self, x: Tensor) -> Tensor:
|
def get_position_ids(self, x: Tensor) -> Tensor:
|
||||||
return self.position_ids[:, : x.shape[1]]
|
return self.position_ids[:, : x.shape[1]]
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import re
|
import re
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor, cat, zeros
|
from torch import Tensor
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
@ -22,7 +23,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(fl.Lambda(func=self.lookup))
|
super().__init__(fl.Lambda(func=self.lookup))
|
||||||
p = Parameter(
|
p = Parameter(
|
||||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||||
) # requires_grad=True by default
|
) # requires_grad=True by default
|
||||||
self.old_weight = cast(Parameter, target.weight)
|
self.old_weight = cast(Parameter, target.weight)
|
||||||
self.new_weight = p
|
self.new_weight = p
|
||||||
|
@ -30,11 +31,18 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||||
def lookup(self, x: Tensor) -> Tensor:
|
def lookup(self, x: Tensor) -> Tensor:
|
||||||
# Concatenate old and new weights for dynamic embedding updates during training
|
# Concatenate old and new weights for dynamic embedding updates during training
|
||||||
return F.embedding(x, cat([self.old_weight, self.new_weight]))
|
return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))
|
||||||
|
|
||||||
def add_embedding(self, embedding: Tensor) -> None:
|
def add_embedding(self, embedding: Tensor) -> None:
|
||||||
assert embedding.shape == (self.old_weight.shape[1],)
|
assert embedding.shape == (self.old_weight.shape[1],)
|
||||||
p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
|
p = Parameter(
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
self.new_weight,
|
||||||
|
embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
self.new_weight = p
|
self.new_weight = p
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
||||||
|
|
||||||
|
import torch
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
|
from torch import Tensor, device as Device, dtype as DType, nn
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -98,7 +99,7 @@ class PerceiverScaledDotProductAttention(fl.Module):
|
||||||
v = self.reshape_tensor(value)
|
v = self.reshape_tensor(value)
|
||||||
|
|
||||||
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
|
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
|
||||||
attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
|
attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
|
||||||
attention = attention @ v
|
attention = attention @ v
|
||||||
|
|
||||||
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
|
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
|
||||||
|
@ -159,7 +160,7 @@ class PerceiverAttention(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
|
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
|
||||||
return cat((x, latents), dim=-2)
|
return torch.cat((x, latents), dim=-2)
|
||||||
|
|
||||||
|
|
||||||
class LatentsToken(fl.Chain):
|
class LatentsToken(fl.Chain):
|
||||||
|
@ -484,7 +485,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
image_prompt = self.preprocess_image(image_prompt)
|
image_prompt = self.preprocess_image(image_prompt)
|
||||||
elif isinstance(image_prompt, list):
|
elif isinstance(image_prompt, list):
|
||||||
assert all(isinstance(image, Image.Image) for image in image_prompt)
|
assert all(isinstance(image, Image.Image) for image in image_prompt)
|
||||||
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
|
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
|
||||||
|
|
||||||
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
||||||
|
|
||||||
|
@ -493,7 +494,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
|
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
|
||||||
if any(weight != 1.0 for weight in weights):
|
if any(weight != 1.0 for weight in weights):
|
||||||
conditional_embedding *= (
|
conditional_embedding *= (
|
||||||
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
||||||
.unsqueeze(-1)
|
.unsqueeze(-1)
|
||||||
.unsqueeze(-1)
|
.unsqueeze(-1)
|
||||||
)
|
)
|
||||||
|
@ -501,20 +502,20 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
if batch_size > 1 and concat_batches:
|
if batch_size > 1 and concat_batches:
|
||||||
# Create a longer image tokens sequence when a batch of images is given
|
# Create a longer image tokens sequence when a batch of images is given
|
||||||
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
|
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
|
||||||
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
|
negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
|
||||||
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
|
conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
|
||||||
|
|
||||||
return cat((negative_embedding, conditional_embedding))
|
return torch.cat((negative_embedding, conditional_embedding))
|
||||||
|
|
||||||
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
|
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||||
clip_embedding = image_encoder(image_prompt)
|
clip_embedding = image_encoder(image_prompt)
|
||||||
conditional_embedding = self.image_proj(clip_embedding)
|
conditional_embedding = self.image_proj(clip_embedding)
|
||||||
if not self.fine_grained:
|
if not self.fine_grained:
|
||||||
negative_embedding = self.image_proj(zeros_like(clip_embedding))
|
negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
|
||||||
else:
|
else:
|
||||||
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
||||||
clip_embedding = image_encoder(zeros_like(image_prompt))
|
clip_embedding = image_encoder(torch.zeros_like(image_prompt))
|
||||||
negative_embedding = self.image_proj(clip_embedding)
|
negative_embedding = self.image_proj(clip_embedding)
|
||||||
return negative_embedding, conditional_embedding
|
return negative_embedding, conditional_embedding
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
from jaxtyping import Float, Int
|
from jaxtyping import Float, Int
|
||||||
from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
# Note: it is important that this computation is done in float32.
|
# Note: it is important that this computation is done in float32.
|
||||||
# The result can be cast to lower precision later if necessary.
|
# The result can be cast to lower precision later if necessary.
|
||||||
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
|
exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
|
||||||
exponent /= half_dim
|
exponent /= half_dim
|
||||||
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
|
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
|
||||||
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
|
embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
import torch
|
||||||
|
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
BaseSolverParams,
|
BaseSolverParams,
|
||||||
|
@ -28,7 +29,7 @@ class DDIM(Solver):
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: BaseSolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes a new DDIM solver.
|
"""Initializes a new DDIM solver.
|
||||||
|
|
||||||
|
@ -71,7 +72,7 @@ class DDIM(Solver):
|
||||||
(
|
(
|
||||||
self.timesteps[step + 1]
|
self.timesteps[step + 1]
|
||||||
if step < self.num_inference_steps - 1
|
if step < self.num_inference_steps - 1
|
||||||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
current_scale_factor, previous_scale_factor = (
|
current_scale_factor, previous_scale_factor = (
|
||||||
|
@ -82,8 +83,8 @@ class DDIM(Solver):
|
||||||
else self.cumulative_scale_factors[0]
|
else self.cumulative_scale_factors[0]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
|
predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
|
||||||
noise_factor = sqrt(1 - previous_scale_factor**2)
|
noise_factor = torch.sqrt(1 - previous_scale_factor**2)
|
||||||
|
|
||||||
# Do not add noise at the last step to avoid visual artifacts.
|
# Do not add noise at the last step to avoid visual artifacts.
|
||||||
if step == self.num_inference_steps - 1:
|
if step == self.num_inference_steps - 1:
|
||||||
|
|
|
@ -3,7 +3,7 @@ from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
BaseSolverParams,
|
BaseSolverParams,
|
||||||
|
@ -38,7 +38,7 @@ class DPMSolver(Solver):
|
||||||
params: BaseSolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
last_step_first_order: bool = False,
|
last_step_first_order: bool = False,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = torch.float32,
|
||||||
):
|
):
|
||||||
"""Initializes a new DPM solver.
|
"""Initializes a new DPM solver.
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class DPMSolver(Solver):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
||||||
self.last_step_first_order = last_step_first_order
|
self.last_step_first_order = last_step_first_order
|
||||||
|
|
||||||
def rebuild(
|
def rebuild(
|
||||||
|
@ -94,7 +94,7 @@ class DPMSolver(Solver):
|
||||||
offset = self.params.timesteps_offset
|
offset = self.params.timesteps_offset
|
||||||
max_timestep = self.params.num_train_timesteps - 1 + offset
|
max_timestep = self.params.num_train_timesteps - 1 + offset
|
||||||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||||
return tensor(np_space).flip(0)
|
return torch.tensor(np_space).flip(0)
|
||||||
|
|
||||||
def dpm_solver_first_order_update(
|
def dpm_solver_first_order_update(
|
||||||
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
||||||
|
@ -110,7 +110,7 @@ class DPMSolver(Solver):
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||||
|
|
||||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||||
|
@ -144,7 +144,7 @@ class DPMSolver(Solver):
|
||||||
Returns:
|
Returns:
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
next_timestep = self.timesteps[step - 1]
|
next_timestep = self.timesteps[step - 1]
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
BaseSolverParams,
|
BaseSolverParams,
|
||||||
|
@ -23,7 +23,7 @@ class Euler(Solver):
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: BaseSolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = torch.float32,
|
||||||
):
|
):
|
||||||
"""Initializes a new Euler solver.
|
"""Initializes a new Euler solver.
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ class Euler(Solver):
|
||||||
"""Generate the sigmas used by the solver."""
|
"""Generate the sigmas used by the solver."""
|
||||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
|
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
|
||||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
sigmas = torch.cat([sigmas, torch.tensor([0.0])])
|
||||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Callable, Protocol, TypeVar
|
from typing import Any, Callable, Protocol, TypeVar
|
||||||
|
|
||||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32
|
import torch
|
||||||
|
from torch import Generator, Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||||
|
|
||||||
|
@ -60,7 +61,7 @@ class FrankenSolver(Solver):
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = float32,
|
dtype: DType = torch.float32,
|
||||||
**kwargs: Any, # for typing, ignored
|
**kwargs: Any, # for typing, ignored
|
||||||
) -> None:
|
) -> None:
|
||||||
self.get_diffusers_scheduler = get_diffusers_scheduler
|
self.get_diffusers_scheduler = get_diffusers_scheduler
|
||||||
|
|
|
@ -4,19 +4,8 @@ from enum import Enum
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import (
|
import torch
|
||||||
Generator,
|
from torch import Generator, Tensor, device as Device, dtype as DType
|
||||||
Tensor,
|
|
||||||
arange,
|
|
||||||
device as Device,
|
|
||||||
dtype as DType,
|
|
||||||
float32,
|
|
||||||
linspace,
|
|
||||||
log,
|
|
||||||
sqrt,
|
|
||||||
stack,
|
|
||||||
tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
|
|
||||||
|
@ -161,7 +150,7 @@ class Solver(fl.Module, ABC):
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: BaseSolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes a new `Solver` instance.
|
"""Initializes a new `Solver` instance.
|
||||||
|
|
||||||
|
@ -179,9 +168,9 @@ class Solver(fl.Module, ABC):
|
||||||
self.params = self.resolve_params(params)
|
self.params = self.resolve_params(params)
|
||||||
|
|
||||||
self.scale_factors = self.sample_noise_schedule()
|
self.scale_factors = self.sample_noise_schedule()
|
||||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))
|
||||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)
|
||||||
self.timesteps = self._generate_timesteps()
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
self.to(device=device, dtype=dtype)
|
self.to(device=device, dtype=dtype)
|
||||||
|
@ -227,16 +216,16 @@ class Solver(fl.Module, ABC):
|
||||||
max_timestep = num_train_timesteps - 1 + offset
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
match spacing:
|
match spacing:
|
||||||
case TimestepSpacing.LINSPACE:
|
case TimestepSpacing.LINSPACE:
|
||||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
|
||||||
case TimestepSpacing.LINSPACE_ROUNDED:
|
case TimestepSpacing.LINSPACE_ROUNDED:
|
||||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||||
case TimestepSpacing.LEADING:
|
case TimestepSpacing.LEADING:
|
||||||
step_ratio = num_train_timesteps // num_inference_steps
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
||||||
case TimestepSpacing.TRAILING:
|
case TimestepSpacing.TRAILING:
|
||||||
step_ratio = num_train_timesteps // num_inference_steps
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
max_timestep = num_train_timesteps - 1 + offset
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
return arange(max_timestep, offset, -step_ratio)
|
return torch.arange(max_timestep, offset, -step_ratio)
|
||||||
case TimestepSpacing.CUSTOM:
|
case TimestepSpacing.CUSTOM:
|
||||||
raise RuntimeError("generate_timesteps called with custom spacing")
|
raise RuntimeError("generate_timesteps called with custom spacing")
|
||||||
|
|
||||||
|
@ -290,7 +279,7 @@ class Solver(fl.Module, ABC):
|
||||||
"""
|
"""
|
||||||
if isinstance(step, list):
|
if isinstance(step, list):
|
||||||
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
|
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
|
||||||
return stack(
|
return torch.stack(
|
||||||
tensors=[
|
tensors=[
|
||||||
self._add_noise(
|
self._add_noise(
|
||||||
x=x[i],
|
x=x[i],
|
||||||
|
@ -400,7 +389,7 @@ class Solver(fl.Module, ABC):
|
||||||
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
|
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
linspace(
|
torch.linspace(
|
||||||
start=self.params.initial_diffusion_rate ** (1 / power),
|
start=self.params.initial_diffusion_rate ** (1 / power),
|
||||||
end=self.params.final_diffusion_rate ** (1 / power),
|
end=self.params.final_diffusion_rate ** (1 / power),
|
||||||
steps=self.params.num_train_timesteps,
|
steps=self.params.num_train_timesteps,
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
import torch
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType, split
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -48,7 +49,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||||
return self.ensure_find(CLIPTokenizer)
|
return self.ensure_find(CLIPTokenizer)
|
||||||
|
|
||||||
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
|
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
|
||||||
for str_tokens in split(tokens, 1):
|
for str_tokens in torch.split(tokens, 1):
|
||||||
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
|
position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() # type: ignore
|
||||||
end_of_text_index.append(cast(int, position))
|
end_of_text_index.append(cast(int, position))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue