mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/ld) add Solver
docstrings
This commit is contained in:
parent
289261f2fb
commit
0c5a7a8269
|
@ -10,16 +10,23 @@ T = TypeVar("T", bound="Solver")
|
|||
|
||||
|
||||
class NoiseSchedule(str, Enum):
|
||||
"""An enumeration of noise schedules used to sample the noise schedule.
|
||||
|
||||
Attributes:
|
||||
UNIFORM: A uniform noise schedule.
|
||||
QUADRATIC: A quadratic noise schedule.
|
||||
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
|
||||
"""
|
||||
|
||||
UNIFORM = "uniform"
|
||||
QUADRATIC = "quadratic"
|
||||
KARRAS = "karras"
|
||||
|
||||
|
||||
class Solver(fl.Module, ABC):
|
||||
"""
|
||||
A base class for creating a diffusion model solver.
|
||||
"""The base class for creating a diffusion model solver.
|
||||
|
||||
Solver creates a sequence of noise and scaling factors used in the diffusion process,
|
||||
Solvers create a sequence of noise and scaling factors used in the diffusion process,
|
||||
which gradually transforms the original data distribution into a Gaussian one.
|
||||
|
||||
This process is described using several parameters such as initial and final diffusion rates,
|
||||
|
@ -39,6 +46,18 @@ class Solver(fl.Module, ABC):
|
|||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
) -> None:
|
||||
"""Initializes a new `Solver` instance.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||
first_inference_step: The first inference step to perform.
|
||||
device: The PyTorch device to use for the solver's tensors.
|
||||
dtype: The PyTorch data type to use for the solver's tensors.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
|
@ -55,19 +74,25 @@ class Solver(fl.Module, ABC):
|
|||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""
|
||||
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
|
||||
"""Apply a step of the diffusion process using the Solver.
|
||||
|
||||
This method should be overridden by subclasses to implement the specific diffusion process.
|
||||
Note:
|
||||
This method should be overridden by subclasses to implement the specific diffusion process.
|
||||
|
||||
Args:
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""
|
||||
Generates a tensor of timesteps.
|
||||
"""Generate a tensor of timesteps.
|
||||
|
||||
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
||||
Note:
|
||||
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -77,6 +102,16 @@ class Solver(fl.Module, ABC):
|
|||
noise: Tensor,
|
||||
step: int,
|
||||
) -> Tensor:
|
||||
"""Add noise to the input tensor using the solver's parameters.
|
||||
|
||||
Args:
|
||||
x: The input tensor to add noise to.
|
||||
noise: The noise tensor to add to the input tensor.
|
||||
step: The current step of the diffusion process.
|
||||
|
||||
Returns:
|
||||
The input tensor with added noise.
|
||||
"""
|
||||
timestep = self.timesteps[step]
|
||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||
noise_stds = self.noise_std[timestep]
|
||||
|
@ -84,28 +119,43 @@ class Solver(fl.Module, ABC):
|
|||
return noised_x
|
||||
|
||||
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
"""Remove noise from the input tensor using the current step of the diffusion process.
|
||||
|
||||
Args:
|
||||
x: The input tensor to remove noise from.
|
||||
noise: The noise tensor to remove from the input tensor.
|
||||
step: The current step of the diffusion process.
|
||||
|
||||
Returns:
|
||||
The denoised input tensor.
|
||||
"""
|
||||
timestep = self.timesteps[step]
|
||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||
noise_stds = self.noise_std[timestep]
|
||||
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
|
||||
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
|
||||
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
|
||||
# Useful to preview progress or for guidance
|
||||
# See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
|
||||
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
||||
return denoised_x
|
||||
|
||||
@property
|
||||
def all_steps(self) -> list[int]:
|
||||
"""Return a list of all inference steps."""
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
@property
|
||||
def inference_steps(self) -> list[int]:
|
||||
"""Return a list of inference steps to perform."""
|
||||
return self.all_steps[self.first_inference_step :]
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
"""The PyTorch device used for the solver's tensors."""
|
||||
return self.scale_factors.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""The PyTorch data type used for the solver's tensors."""
|
||||
return self.scale_factors.dtype
|
||||
|
||||
@device.setter
|
||||
|
@ -117,6 +167,15 @@ class Solver(fl.Module, ABC):
|
|||
self.to(dtype=dtype)
|
||||
|
||||
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||
"""Rebuild the solver with new parameters.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
|
||||
Returns:
|
||||
A new solver instance with the specified parameters.
|
||||
"""
|
||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||
return self.__class__(
|
||||
|
@ -131,12 +190,30 @@ class Solver(fl.Module, ABC):
|
|||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with solvers that need to scale the input according to the current timestep.
|
||||
"""Scale the model's input according to the current timestep.
|
||||
|
||||
Note:
|
||||
This method should only be overridden by solvers that
|
||||
need to scale the input according to the current timestep.
|
||||
|
||||
Args:
|
||||
x: The input tensor to scale.
|
||||
step: The current step of the diffusion process.
|
||||
|
||||
Returns:
|
||||
The scaled input tensor.
|
||||
"""
|
||||
return x
|
||||
|
||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||
"""Sample a power distribution.
|
||||
|
||||
Args:
|
||||
power: The power to use for the distribution.
|
||||
|
||||
Returns:
|
||||
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
|
||||
"""
|
||||
return (
|
||||
linspace(
|
||||
start=self.initial_diffusion_rate ** (1 / power),
|
||||
|
@ -147,6 +224,11 @@ class Solver(fl.Module, ABC):
|
|||
)
|
||||
|
||||
def sample_noise_schedule(self) -> Tensor:
|
||||
"""Sample the noise schedule.
|
||||
|
||||
Returns:
|
||||
A tensor representing the noise schedule.
|
||||
"""
|
||||
match self.noise_schedule:
|
||||
case "uniform":
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
|
@ -158,6 +240,15 @@ class Solver(fl.Module, ABC):
|
|||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||
|
||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||
"""Move the solver to the specified device and data type.
|
||||
|
||||
Args:
|
||||
device: The PyTorch device to move the solver to.
|
||||
dtype: The PyTorch data type to move the solver to.
|
||||
|
||||
Returns:
|
||||
The solver instance, moved to the specified device and data type.
|
||||
"""
|
||||
super().to(device=device, dtype=dtype)
|
||||
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
|
||||
match name:
|
||||
|
|
Loading…
Reference in a new issue