mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/ld) add Solver
docstrings
This commit is contained in:
parent
289261f2fb
commit
0c5a7a8269
|
@ -10,16 +10,23 @@ T = TypeVar("T", bound="Solver")
|
||||||
|
|
||||||
|
|
||||||
class NoiseSchedule(str, Enum):
|
class NoiseSchedule(str, Enum):
|
||||||
|
"""An enumeration of noise schedules used to sample the noise schedule.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
UNIFORM: A uniform noise schedule.
|
||||||
|
QUADRATIC: A quadratic noise schedule.
|
||||||
|
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
|
||||||
|
"""
|
||||||
|
|
||||||
UNIFORM = "uniform"
|
UNIFORM = "uniform"
|
||||||
QUADRATIC = "quadratic"
|
QUADRATIC = "quadratic"
|
||||||
KARRAS = "karras"
|
KARRAS = "karras"
|
||||||
|
|
||||||
|
|
||||||
class Solver(fl.Module, ABC):
|
class Solver(fl.Module, ABC):
|
||||||
"""
|
"""The base class for creating a diffusion model solver.
|
||||||
A base class for creating a diffusion model solver.
|
|
||||||
|
|
||||||
Solver creates a sequence of noise and scaling factors used in the diffusion process,
|
Solvers create a sequence of noise and scaling factors used in the diffusion process,
|
||||||
which gradually transforms the original data distribution into a Gaussian one.
|
which gradually transforms the original data distribution into a Gaussian one.
|
||||||
|
|
||||||
This process is described using several parameters such as initial and final diffusion rates,
|
This process is described using several parameters such as initial and final diffusion rates,
|
||||||
|
@ -39,6 +46,18 @@ class Solver(fl.Module, ABC):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = float32,
|
dtype: DType = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes a new `Solver` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps to perform.
|
||||||
|
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||||
|
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||||
|
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||||
|
first_inference_step: The first inference step to perform.
|
||||||
|
device: The PyTorch device to use for the solver's tensors.
|
||||||
|
dtype: The PyTorch data type to use for the solver's tensors.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.num_train_timesteps = num_train_timesteps
|
self.num_train_timesteps = num_train_timesteps
|
||||||
|
@ -55,19 +74,25 @@ class Solver(fl.Module, ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""
|
"""Apply a step of the diffusion process using the Solver.
|
||||||
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
|
|
||||||
|
|
||||||
This method should be overridden by subclasses to implement the specific diffusion process.
|
Note:
|
||||||
|
This method should be overridden by subclasses to implement the specific diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to apply the diffusion process to.
|
||||||
|
predicted_noise: The predicted noise tensor for the current step.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
generator: The random number generator to use for sampling noise.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
"""
|
"""Generate a tensor of timesteps.
|
||||||
Generates a tensor of timesteps.
|
|
||||||
|
|
||||||
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
Note:
|
||||||
|
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -77,6 +102,16 @@ class Solver(fl.Module, ABC):
|
||||||
noise: Tensor,
|
noise: Tensor,
|
||||||
step: int,
|
step: int,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Add noise to the input tensor using the solver's parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to add noise to.
|
||||||
|
noise: The noise tensor to add to the input tensor.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The input tensor with added noise.
|
||||||
|
"""
|
||||||
timestep = self.timesteps[step]
|
timestep = self.timesteps[step]
|
||||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||||
noise_stds = self.noise_std[timestep]
|
noise_stds = self.noise_std[timestep]
|
||||||
|
@ -84,28 +119,43 @@ class Solver(fl.Module, ABC):
|
||||||
return noised_x
|
return noised_x
|
||||||
|
|
||||||
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
"""Remove noise from the input tensor using the current step of the diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to remove noise from.
|
||||||
|
noise: The noise tensor to remove from the input tensor.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised input tensor.
|
||||||
|
"""
|
||||||
timestep = self.timesteps[step]
|
timestep = self.timesteps[step]
|
||||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||||
noise_stds = self.noise_std[timestep]
|
noise_stds = self.noise_std[timestep]
|
||||||
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
|
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
|
||||||
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
|
# Useful to preview progress or for guidance
|
||||||
|
# See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
|
||||||
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_steps(self) -> list[int]:
|
def all_steps(self) -> list[int]:
|
||||||
|
"""Return a list of all inference steps."""
|
||||||
return list(range(self.num_inference_steps))
|
return list(range(self.num_inference_steps))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inference_steps(self) -> list[int]:
|
def inference_steps(self) -> list[int]:
|
||||||
|
"""Return a list of inference steps to perform."""
|
||||||
return self.all_steps[self.first_inference_step :]
|
return self.all_steps[self.first_inference_step :]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Device:
|
def device(self) -> Device:
|
||||||
|
"""The PyTorch device used for the solver's tensors."""
|
||||||
return self.scale_factors.device
|
return self.scale_factors.device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> DType:
|
def dtype(self) -> DType:
|
||||||
|
"""The PyTorch data type used for the solver's tensors."""
|
||||||
return self.scale_factors.dtype
|
return self.scale_factors.dtype
|
||||||
|
|
||||||
@device.setter
|
@device.setter
|
||||||
|
@ -117,6 +167,15 @@ class Solver(fl.Module, ABC):
|
||||||
self.to(dtype=dtype)
|
self.to(dtype=dtype)
|
||||||
|
|
||||||
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||||
|
"""Rebuild the solver with new parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps to perform.
|
||||||
|
first_inference_step: The first inference step to perform.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new solver instance with the specified parameters.
|
||||||
|
"""
|
||||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
|
@ -131,12 +190,30 @@ class Solver(fl.Module, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||||
"""
|
"""Scale the model's input according to the current timestep.
|
||||||
For compatibility with solvers that need to scale the input according to the current timestep.
|
|
||||||
|
Note:
|
||||||
|
This method should only be overridden by solvers that
|
||||||
|
need to scale the input according to the current timestep.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to scale.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The scaled input tensor.
|
||||||
"""
|
"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||||
|
"""Sample a power distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
power: The power to use for the distribution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
|
||||||
|
"""
|
||||||
return (
|
return (
|
||||||
linspace(
|
linspace(
|
||||||
start=self.initial_diffusion_rate ** (1 / power),
|
start=self.initial_diffusion_rate ** (1 / power),
|
||||||
|
@ -147,6 +224,11 @@ class Solver(fl.Module, ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample_noise_schedule(self) -> Tensor:
|
def sample_noise_schedule(self) -> Tensor:
|
||||||
|
"""Sample the noise schedule.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor representing the noise schedule.
|
||||||
|
"""
|
||||||
match self.noise_schedule:
|
match self.noise_schedule:
|
||||||
case "uniform":
|
case "uniform":
|
||||||
return 1 - self.sample_power_distribution(1)
|
return 1 - self.sample_power_distribution(1)
|
||||||
|
@ -158,6 +240,15 @@ class Solver(fl.Module, ABC):
|
||||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||||
|
|
||||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||||
|
"""Move the solver to the specified device and data type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The PyTorch device to move the solver to.
|
||||||
|
dtype: The PyTorch data type to move the solver to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The solver instance, moved to the specified device and data type.
|
||||||
|
"""
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype)
|
||||||
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
|
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
|
||||||
match name:
|
match name:
|
||||||
|
|
Loading…
Reference in a new issue