mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
make Scheduler a fl.Module + Change name Scheduler -> Solver
This commit is contained in:
parent
07cb2ff21c
commit
73f6ccfc98
|
@ -19,7 +19,7 @@ ______________________________________________________________________
|
|||
|
||||
## Latest News 🔥
|
||||
|
||||
- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
|
||||
- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr))
|
||||
- Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
|
||||
- Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
|
||||
- Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))
|
||||
|
|
|
@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
|
|||
clip_text_embedding = torch.rand(1, 77, 768)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||
|
||||
scheduler = DPMSolver(num_inference_steps=10)
|
||||
timestep = scheduler.timesteps[0].unsqueeze(dim=0)
|
||||
solver = DPMSolver(num_inference_steps=10)
|
||||
timestep = solver.timesteps[0].unsqueeze(dim=0)
|
||||
unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
|
||||
|
||||
x = torch.randn(1, 4, 64, 64)
|
||||
|
|
|
@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
|
|||
LatentDiffusionAutoencoder,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||
SD1ControlnetAdapter,
|
||||
SD1IPAdapter,
|
||||
|
@ -33,7 +33,7 @@ __all__ = [
|
|||
"SDXLIPAdapter",
|
||||
"SDXLT2IAdapter",
|
||||
"DPMSolver",
|
||||
"Scheduler",
|
||||
"Solver",
|
||||
"CLIPTextEncoderL",
|
||||
"LatentDiffusionAutoencoder",
|
||||
"SDFreeUAdapter",
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch import Tensor, device as Device, dtype as DType
|
|||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
|
||||
T = TypeVar("T", bound="fl.Module")
|
||||
|
||||
|
@ -20,7 +20,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
unet: fl.Module,
|
||||
lda: LatentDiffusionAutoencoder,
|
||||
clip_text_encoder: fl.Module,
|
||||
scheduler: Scheduler,
|
||||
solver: Solver,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
|
@ -30,10 +30,10 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
self.unet = unet.to(device=self.device, dtype=self.dtype)
|
||||
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||
self.solver = solver.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||
self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||
|
||||
def init_latents(
|
||||
self,
|
||||
|
@ -51,15 +51,15 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
if init_image is None:
|
||||
return noise
|
||||
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
|
||||
return self.scheduler.add_noise(
|
||||
return self.solver.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
step=self.scheduler.first_inference_step,
|
||||
step=self.solver.first_inference_step,
|
||||
)
|
||||
|
||||
@property
|
||||
def steps(self) -> list[int]:
|
||||
return self.scheduler.inference_steps
|
||||
return self.solver.inference_steps
|
||||
|
||||
@abstractmethod
|
||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||
|
@ -82,12 +82,12 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
def forward(
|
||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||
) -> Tensor:
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
||||
|
||||
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
||||
# scale latents for schedulers that need it
|
||||
latents = self.scheduler.scale_model_input(latents, step=step)
|
||||
# scale latents for solvers that need it
|
||||
latents = self.solver.scale_model_input(latents, step=step)
|
||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||
|
||||
# classifier-free guidance
|
||||
|
@ -101,14 +101,14 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
|
||||
)
|
||||
|
||||
return self.scheduler(x, predicted_noise=predicted_noise, step=step)
|
||||
return self.solver(x, predicted_noise=predicted_noise, step=step)
|
||||
|
||||
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
||||
return self.__class__(
|
||||
unet=self.unet.structural_copy(),
|
||||
lda=self.lda.structural_copy(),
|
||||
clip_text_encoder=self.clip_text_encoder.structural_copy(),
|
||||
scheduler=self.scheduler,
|
||||
solver=self.solver,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
|
|
@ -51,7 +51,7 @@ class MultiDiffusion(Generic[T, D], ABC):
|
|||
match step:
|
||||
case step if step == target.start_step and target.init_latents is not None:
|
||||
noise_view = target.crop(noise)
|
||||
view = self.ldm.scheduler.add_noise(
|
||||
view = self.ldm.solver.add_noise(
|
||||
x=target.init_latents,
|
||||
noise=noise_view,
|
||||
step=step,
|
||||
|
|
|
@ -5,22 +5,22 @@ from typing import Generic, TypeVar
|
|||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
|
||||
T = TypeVar("T", bound=LatentDiffusionModel)
|
||||
|
||||
|
||||
def add_noise_interval(
|
||||
scheduler: Scheduler,
|
||||
solver: Solver,
|
||||
/,
|
||||
x: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
initial_timestep: torch.Tensor,
|
||||
target_timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
|
||||
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
|
||||
initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep]
|
||||
target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep]
|
||||
|
||||
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
|
||||
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
|
||||
|
@ -33,7 +33,7 @@ class Restart(Generic[T]):
|
|||
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
|
||||
(https://arxiv.org/pdf/2306.14878.pdf)
|
||||
|
||||
Works only with the DDIM scheduler for now.
|
||||
Works only with the DDIM solver for now.
|
||||
"""
|
||||
|
||||
ldm: T
|
||||
|
@ -43,7 +43,7 @@ class Restart(Generic[T]):
|
|||
end_time: float = 2
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
|
||||
assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -53,15 +53,15 @@ class Restart(Generic[T]):
|
|||
condition_scale: float = 7.5,
|
||||
**kwargs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
original_scheduler = self.ldm.scheduler
|
||||
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
|
||||
new_scheduler.timesteps = self.timesteps
|
||||
self.ldm.scheduler = new_scheduler
|
||||
original_solver = self.ldm.solver
|
||||
new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype)
|
||||
new_solver.timesteps = self.timesteps
|
||||
self.ldm.solver = new_solver
|
||||
|
||||
for _ in range(self.num_iterations):
|
||||
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
|
||||
x = add_noise_interval(
|
||||
new_scheduler,
|
||||
new_solver,
|
||||
x=x,
|
||||
noise=noise,
|
||||
initial_timestep=self.timesteps[-1],
|
||||
|
@ -73,18 +73,18 @@ class Restart(Generic[T]):
|
|||
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
|
||||
)
|
||||
|
||||
self.ldm.scheduler = original_scheduler
|
||||
self.ldm.solver = original_solver
|
||||
|
||||
return x
|
||||
|
||||
@cached_property
|
||||
def start_step(self) -> int:
|
||||
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
|
||||
sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
|
||||
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time)))
|
||||
|
||||
@cached_property
|
||||
def end_timestep(self) -> int:
|
||||
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||
sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
|
||||
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
|
||||
|
||||
@cached_property
|
||||
|
@ -92,7 +92,7 @@ class Restart(Generic[T]):
|
|||
return (
|
||||
torch.round(
|
||||
torch.linspace(
|
||||
start=int(self.ldm.scheduler.timesteps[self.start_step]),
|
||||
start=int(self.ldm.solver.timesteps[self.start_step]),
|
||||
end=self.end_timestep,
|
||||
steps=self.num_steps,
|
||||
)
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
|
|
@ -9,7 +9,7 @@ import refiners.fluxion.layers as fl
|
|||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.utils import gaussian_blur, interpolate
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||
|
@ -89,13 +89,13 @@ class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
return interpolate(attn_mask, Size((h, w)))
|
||||
|
||||
def compute_degraded_latents(
|
||||
self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
|
||||
self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
|
||||
) -> Tensor:
|
||||
sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
|
||||
original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
|
||||
original_latents = solver.remove_noise(x=latents, noise=noise, step=step)
|
||||
degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
|
||||
degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
|
||||
return scheduler.add_noise(degraded_latents, noise=noise, step=step)
|
||||
return solver.add_noise(degraded_latents, noise=noise, step=step)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"]
|
|
@ -1,9 +1,9 @@
|
|||
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
|
||||
class DDIM(Scheduler):
|
||||
class DDIM(Solver):
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
|
@ -25,15 +25,14 @@ class DDIM(Scheduler):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.timesteps = self._generate_timesteps()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""
|
||||
Generates decreasing timesteps with 'leading' spacing and offset of 1
|
||||
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
|
||||
similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
|
||||
"""
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
|
@ -1,9 +1,9 @@
|
|||
from torch import Generator, Tensor, arange, device as Device, dtype as DType
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
|
||||
class DDPM(Scheduler):
|
||||
class DDPM(Solver):
|
||||
"""
|
||||
Denoising Diffusion Probabilistic Model
|
||||
|
|
@ -3,10 +3,10 @@ from collections import deque
|
|||
import numpy as np
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
|
||||
class DPMSolver(Scheduler):
|
||||
class DPMSolver(Solver):
|
||||
"""
|
||||
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
|
||||
|
@ -48,7 +48,6 @@ class DPMSolver(Scheduler):
|
|||
# ...and we want the same result as the original codebase.
|
||||
return tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||
device=self.device,
|
||||
).flip(0)
|
||||
|
||||
def rebuild(
|
|
@ -2,10 +2,10 @@ import numpy as np
|
|||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
|
||||
class EulerScheduler(Scheduler):
|
||||
class Euler(Solver):
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
|
@ -40,9 +40,7 @@ class EulerScheduler(Scheduler):
|
|||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# ...and we want the same result as the original codebase.
|
||||
timesteps = torch.tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
|
||||
).flip(0)
|
||||
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
||||
return timesteps
|
||||
|
||||
def _generate_sigmas(self) -> Tensor:
|
|
@ -4,7 +4,9 @@ from typing import TypeVar
|
|||
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
|
||||
|
||||
T = TypeVar("T", bound="Scheduler")
|
||||
from refiners.fluxion import layers as fl
|
||||
|
||||
T = TypeVar("T", bound="Solver")
|
||||
|
||||
|
||||
class NoiseSchedule(str, Enum):
|
||||
|
@ -13,11 +15,11 @@ class NoiseSchedule(str, Enum):
|
|||
KARRAS = "karras"
|
||||
|
||||
|
||||
class Scheduler(ABC):
|
||||
class Solver(fl.Module, ABC):
|
||||
"""
|
||||
A base class for creating a diffusion model scheduler.
|
||||
A base class for creating a diffusion model solver.
|
||||
|
||||
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
|
||||
Solver creates a sequence of noise and scaling factors used in the diffusion process,
|
||||
which gradually transforms the original data distribution into a Gaussian one.
|
||||
|
||||
This process is described using several parameters such as initial and final diffusion rates,
|
||||
|
@ -36,9 +38,8 @@ class Scheduler(ABC):
|
|||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
):
|
||||
self.device: Device = Device(device)
|
||||
self.dtype: DType = dtype
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.initial_diffusion_rate = initial_diffusion_rate
|
||||
|
@ -50,6 +51,7 @@ class Scheduler(ABC):
|
|||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||
self.timesteps = self._generate_timesteps()
|
||||
self.to(device=device, dtype=dtype)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
|
@ -69,57 +71,6 @@ class Scheduler(ABC):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def all_steps(self) -> list[int]:
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
@property
|
||||
def inference_steps(self) -> list[int]:
|
||||
return self.all_steps[self.first_inference_step :]
|
||||
|
||||
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||
return self.__class__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
initial_diffusion_rate=self.initial_diffusion_rate,
|
||||
final_diffusion_rate=self.final_diffusion_rate,
|
||||
noise_schedule=self.noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with schedulers that need to scale the input according to the current timestep.
|
||||
"""
|
||||
return x
|
||||
|
||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||
return (
|
||||
linspace(
|
||||
start=self.initial_diffusion_rate ** (1 / power),
|
||||
end=self.final_diffusion_rate ** (1 / power),
|
||||
steps=self.num_train_timesteps,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
** power
|
||||
)
|
||||
|
||||
def sample_noise_schedule(self) -> Tensor:
|
||||
match self.noise_schedule:
|
||||
case "uniform":
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
case "quadratic":
|
||||
return 1 - self.sample_power_distribution(2)
|
||||
case "karras":
|
||||
return 1 - self.sample_power_distribution(7)
|
||||
case _:
|
||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
x: Tensor,
|
||||
|
@ -141,14 +92,77 @@ class Scheduler(ABC):
|
|||
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
||||
return denoised_x
|
||||
|
||||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||
if device is not None:
|
||||
self.device = Device(device)
|
||||
self.timesteps = self.timesteps.to(device)
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
|
||||
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
|
||||
self.noise_std = self.noise_std.to(device, dtype=dtype)
|
||||
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
|
||||
@property
|
||||
def all_steps(self) -> list[int]:
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
@property
|
||||
def inference_steps(self) -> list[int]:
|
||||
return self.all_steps[self.first_inference_step :]
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
return self.scale_factors.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
return self.scale_factors.dtype
|
||||
|
||||
@device.setter
|
||||
def device(self, device: Device | str | None = None) -> None:
|
||||
self.to(device=device)
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, dtype: DType | None = None) -> None:
|
||||
self.to(dtype=dtype)
|
||||
|
||||
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||
return self.__class__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
initial_diffusion_rate=self.initial_diffusion_rate,
|
||||
final_diffusion_rate=self.final_diffusion_rate,
|
||||
noise_schedule=self.noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with solvers that need to scale the input according to the current timestep.
|
||||
"""
|
||||
return x
|
||||
|
||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||
return (
|
||||
linspace(
|
||||
start=self.initial_diffusion_rate ** (1 / power),
|
||||
end=self.final_diffusion_rate ** (1 / power),
|
||||
steps=self.num_train_timesteps,
|
||||
)
|
||||
** power
|
||||
)
|
||||
|
||||
def sample_noise_schedule(self) -> Tensor:
|
||||
match self.noise_schedule:
|
||||
case "uniform":
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
case "quadratic":
|
||||
return 1 - self.sample_power_distribution(2)
|
||||
case "karras":
|
||||
return 1 - self.sample_power_distribution(7)
|
||||
case _:
|
||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||
|
||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||
super().to(device=device, dtype=dtype)
|
||||
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
|
||||
match name:
|
||||
case "timesteps":
|
||||
setattr(self, name, attr.to(device=device))
|
||||
case _:
|
||||
setattr(self, name, attr.to(device=device, dtype=dtype))
|
||||
return self
|
|
@ -7,8 +7,8 @@ from refiners.fluxion.utils import image_to_tensor, interpolate
|
|||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||
|
||||
|
@ -26,20 +26,20 @@ class StableDiffusion_1(LatentDiffusionModel):
|
|||
unet: SD1UNet | None = None,
|
||||
lda: SD1Autoencoder | None = None,
|
||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
solver: Solver | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
unet = unet or SD1UNet(in_channels=4)
|
||||
lda = lda or SD1Autoencoder()
|
||||
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
||||
solver = solver or DPMSolver(num_inference_steps=30)
|
||||
|
||||
super().__init__(
|
||||
unet=unet,
|
||||
lda=lda,
|
||||
clip_text_encoder=clip_text_encoder,
|
||||
scheduler=scheduler,
|
||||
solver=solver,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -82,14 +82,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
|||
assert sag is not None
|
||||
|
||||
degraded_latents = sag.compute_degraded_latents(
|
||||
scheduler=self.scheduler,
|
||||
solver=self.solver,
|
||||
latents=x,
|
||||
noise=noise,
|
||||
step=step,
|
||||
classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
if "ip_adapter" in self.unet.provider.contexts:
|
||||
|
@ -111,14 +111,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|||
unet: SD1UNet | None = None,
|
||||
lda: SD1Autoencoder | None = None,
|
||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
solver: Solver | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
self.mask_latents: Tensor | None = None
|
||||
self.target_image_latents: Tensor | None = None
|
||||
super().__init__(
|
||||
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
|
||||
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -162,7 +162,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|||
assert self.target_image_latents is not None
|
||||
|
||||
degraded_latents = sag.compute_degraded_latents(
|
||||
scheduler=self.scheduler,
|
||||
solver=self.solver,
|
||||
latents=x,
|
||||
noise=noise,
|
||||
step=step,
|
||||
|
@ -173,7 +173,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|||
dim=1,
|
||||
)
|
||||
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ from torch import Tensor, device as Device, dtype as DType
|
|||
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
|
@ -23,20 +23,20 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
unet: SDXLUNet | None = None,
|
||||
lda: SDXLAutoencoder | None = None,
|
||||
clip_text_encoder: DoubleTextEncoder | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
solver: Solver | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
unet = unet or SDXLUNet(in_channels=4)
|
||||
lda = lda or SDXLAutoencoder()
|
||||
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
|
||||
scheduler = scheduler or DDIM(num_inference_steps=30)
|
||||
solver = solver or DDIM(num_inference_steps=30)
|
||||
|
||||
super().__init__(
|
||||
unet=unet,
|
||||
lda=lda,
|
||||
clip_text_encoder=clip_text_encoder,
|
||||
scheduler=scheduler,
|
||||
solver=solver,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -131,7 +131,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
assert sag is not None
|
||||
|
||||
degraded_latents = sag.compute_degraded_latents(
|
||||
scheduler=self.scheduler,
|
||||
solver=self.solver,
|
||||
latents=x,
|
||||
noise=noise,
|
||||
step=step,
|
||||
|
@ -140,7 +140,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
|
||||
negative_text_embedding, _ = clip_text_embedding.chunk(2)
|
||||
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||
time_ids, _ = time_ids.chunk(2)
|
||||
|
||||
self.set_unet_context(
|
||||
|
|
|
@ -19,7 +19,8 @@ from refiners.foundationals.latent_diffusion import (
|
|||
SD1UNet,
|
||||
StableDiffusion_1,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDPM
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDPM
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
|
@ -150,7 +151,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
return TextEmbeddingLatentsDataset(trainer=self)
|
||||
|
||||
@cached_property
|
||||
def ddpm_scheduler(self) -> DDPM:
|
||||
def ddpm_solver(self) -> Solver:
|
||||
return DDPM(
|
||||
num_inference_steps=1000,
|
||||
device=self.device,
|
||||
|
@ -159,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
def sample_timestep(self) -> Tensor:
|
||||
random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
|
||||
self.current_step = random_step
|
||||
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
||||
return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)
|
||||
|
||||
def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
|
||||
return sample_noise(
|
||||
|
@ -170,7 +171,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
clip_text_embedding, latents = batch.text_embeddings, batch.latents
|
||||
timestep = self.sample_timestep()
|
||||
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
|
||||
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
|
||||
noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
|
||||
self.unet.set_timestep(timestep=timestep)
|
||||
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||
prediction = self.unet(noisy_latents)
|
||||
|
@ -182,7 +183,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
unet=self.unet,
|
||||
lda=self.lda,
|
||||
clip_text_encoder=self.text_encoder,
|
||||
scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
|
||||
solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
|
||||
device=self.device,
|
||||
)
|
||||
prompts = self.config.test_diffusion.prompts
|
||||
|
|
|
@ -24,8 +24,8 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
|||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
from tests.utils import ensure_similar_images
|
||||
|
@ -491,8 +491,8 @@ def sd15_ddim(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
ddim_scheduler = DDIM(num_inference_steps=20)
|
||||
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||
ddim_solver = DDIM(num_inference_steps=20)
|
||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
|
@ -509,8 +509,8 @@ def sd15_ddim_karras(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
|
||||
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
|
||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
|
@ -527,8 +527,8 @@ def sd15_euler(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
euler_scheduler = EulerScheduler(num_inference_steps=30)
|
||||
sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device)
|
||||
euler_solver = Euler(num_inference_steps=30)
|
||||
sd15 = StableDiffusion_1(solver=euler_solver, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
|
@ -545,8 +545,8 @@ def sd15_ddim_lda_ft_mse(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
ddim_scheduler = DDIM(num_inference_steps=20)
|
||||
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||
ddim_solver = DDIM(num_inference_steps=20)
|
||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
|
||||
|
@ -599,8 +599,8 @@ def sdxl_ddim(
|
|||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
scheduler = DDIM(num_inference_steps=30)
|
||||
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
|
||||
solver = DDIM(num_inference_steps=30)
|
||||
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
||||
|
||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
|
||||
|
@ -617,8 +617,8 @@ def sdxl_ddim_lda_fp16_fix(
|
|||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
scheduler = DDIM(num_inference_steps=30)
|
||||
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
|
||||
solver = DDIM(num_inference_steps=30)
|
||||
sdxl = StableDiffusion_XL(solver=solver, device=test_device)
|
||||
|
||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
|
||||
|
@ -659,8 +659,8 @@ def test_diffusion_std_random_init_euler(
|
|||
sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_euler
|
||||
euler_scheduler = sd15_euler.scheduler
|
||||
assert isinstance(euler_scheduler, EulerScheduler)
|
||||
euler_solver = sd15_euler.solver
|
||||
assert isinstance(euler_solver, Euler)
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
@ -670,7 +670,7 @@ def test_diffusion_std_random_init_euler(
|
|||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
x = x * euler_scheduler.init_noise_sigma
|
||||
x = x * euler_solver.init_noise_sigma
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
|
@ -1202,7 +1202,7 @@ def test_diffusion_refonly(
|
|||
|
||||
for step in sd15.steps:
|
||||
noise = torch.randn(2, 4, 64, 64, device=test_device)
|
||||
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
|
||||
noised_guide = sd15.solver.add_noise(guide, noise, step)
|
||||
refonly_adapter.set_controlnet_condition(noised_guide)
|
||||
x = sd15(
|
||||
x,
|
||||
|
@ -1244,7 +1244,7 @@ def test_diffusion_inpainting_refonly(
|
|||
|
||||
for step in sd15.steps:
|
||||
noise = torch.randn_like(guide)
|
||||
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
|
||||
noised_guide = sd15.solver.add_noise(guide, noise, step)
|
||||
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
|
||||
# inpaint variation models")
|
||||
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from torch import Tensor, allclose, device as Device, equal, isclose, randn
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler
|
||||
|
||||
|
||||
def test_ddpm_diffusers():
|
||||
|
@ -83,7 +83,7 @@ def test_euler_diffusers():
|
|||
use_karras_sigmas=False,
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = EulerScheduler(num_inference_steps=30)
|
||||
refiners_scheduler = Euler(num_inference_steps=30)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
predicted_noise = randn(1, 4, 32, 32)
|
Loading…
Reference in a new issue