Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-22 06:08:46 +00:00
add karras sampling to the Scheduler abstract class, default is quadratic
This commit is contained in: parent f22f969d65, commit ad8f02e555
Changes to the DDIM scheduler:

@@ -1,5 +1,5 @@
 from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


 class DDIM(Scheduler):
@@ -9,6 +9,7 @@ class DDIM(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:
@@ -17,6 +18,7 @@ class DDIM(Scheduler):
             num_train_timesteps,
             initial_diffusion_rate,
             final_diffusion_rate,
+            noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
         )
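A minimal usage sketch of the new argument follows. It assumes the DDIM module path and the num_inference_steps argument, neither of which appears in this diff; treat both as illustrative.

# Hypothetical usage of the new noise_schedule keyword (import path and
# num_inference_steps argument assumed, not shown in this diff).
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule

# The default stays QUADRATIC, so existing callers keep the previous behaviour.
ddim = DDIM(num_inference_steps=30)

# Opt in to the Karras-style (seventh-power) beta spacing.
ddim_karras = DDIM(num_inference_steps=30, noise_schedule=NoiseSchedule.KARRAS)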
Changes to the DPMSolver scheduler:

@@ -1,4 +1,4 @@
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
 import numpy as np
 from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
 from collections import deque
@@ -16,6 +16,7 @@ class DPMSolver(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
@@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
             num_train_timesteps=num_train_timesteps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
         )
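DPMSolver is threaded through the same keyword; a short sketch, again with the module path and num_inference_steps assumed:

# Hypothetical: request the uniform (first-power) schedule instead of the
# QUADRATIC default; module path and num_inference_steps are assumptions.
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule

dpm = DPMSolver(num_inference_steps=20, noise_schedule=NoiseSchedule.UNIFORM)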
Changes to the Scheduler base class (refiners.foundationals.latent_diffusion.schedulers.scheduler):

@@ -1,10 +1,17 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
 from typing import TypeVar

 T = TypeVar("T", bound="Scheduler")


+class NoiseSchedule(str, Enum):
+    UNIFORM = "uniform"
+    QUADRATIC = "quadratic"
+    KARRAS = "karras"
+
+
 class Scheduler(ABC):
     """
     A base class for creating a diffusion model scheduler.
@@ -24,6 +31,7 @@ class Scheduler(ABC):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: DType = float32,
     ):
@@ -33,17 +41,8 @@ class Scheduler(ABC):
         self.num_train_timesteps = num_train_timesteps
         self.initial_diffusion_rate = initial_diffusion_rate
         self.final_diffusion_rate = final_diffusion_rate
-        self.scale_factors = (
-            1.0
-            - linspace(
-                start=initial_diffusion_rate**0.5,
-                end=final_diffusion_rate**0.5,
-                steps=num_train_timesteps,
-                device=device,
-                dtype=dtype,
-            )
-            ** 2
-        )
+        self.noise_schedule = noise_schedule
+        self.scale_factors = self.sample_noise_schedule()
         self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
         self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
@@ -71,6 +70,29 @@ class Scheduler(ABC):
     def steps(self) -> list[int]:
         return list(range(self.num_inference_steps))

+    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
+        return (
+            linspace(
+                start=self.initial_diffusion_rate ** (1 / power),
+                end=self.final_diffusion_rate ** (1 / power),
+                steps=self.num_train_timesteps,
+                device=self.device,
+                dtype=self.dtype,
+            )
+            ** power
+        )
+
+    def sample_noise_schedule(self) -> Tensor:
+        match self.noise_schedule:
+            case "uniform":
+                return 1 - self.sample_power_distribution(1)
+            case "quadratic":
+                return 1 - self.sample_power_distribution(2)
+            case "karras":
+                return 1 - self.sample_power_distribution(7)
+            case _:
+                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
+
     def add_noise(
         self,
         x: Tensor,
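To make the new methods concrete, here is a standalone sketch (not part of the library) that reproduces the three schedules with the default rates from the diff. The quadratic case matches the linspace-squared computation that was removed from __init__; "karras" here simply means a seventh-power spacing of the diffusion rates.

# Standalone sketch of what sample_noise_schedule computes, using the
# defaults from the diff (8.5e-4 to 1.2e-2 over 1_000 train timesteps).
from torch import linspace, float32

initial_diffusion_rate = 8.5e-4
final_diffusion_rate = 1.2e-2
num_train_timesteps = 1_000

def power_distribution(power: float):
    # linspace between the power-th roots of the rates, then raised back to `power`
    return (
        linspace(
            start=initial_diffusion_rate ** (1 / power),
            end=final_diffusion_rate ** (1 / power),
            steps=num_train_timesteps,
            dtype=float32,
        )
        ** power
    )

uniform_scale_factors = 1 - power_distribution(1)    # NoiseSchedule.UNIFORM
quadratic_scale_factors = 1 - power_distribution(2)  # NoiseSchedule.QUADRATIC (default, same as the old hard-coded schedule)
karras_scale_factors = 1 - power_distribution(7)     # NoiseSchedule.KARRAS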