Skip to content

Latent Diffusion

FixedGroupNorm

FixedGroupNorm(target: GroupNorm)

Bases: Chain, Adapter[GroupNorm]

Adapter for GroupNorm layers to fix the running mean and variance.

This is useful when running tiled inference with an autoencoder to ensure that the statistics of the GroupNorm layers are consistent across tiles.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(self, target: fl.GroupNorm) -> None:
    """Wrap a GroupNorm layer so that its statistics can be fixed.

    Both `mean` and `var` start out as None. The adapter's only child is a
    `fl.Lambda` that delegates to `self.compute_group_norm`.

    Args:
        target: The GroupNorm layer to adapt.
    """
    self.mean, self.var = None, None
    with self.setup_adapter(target):
        norm_layer = fl.Lambda(self.compute_group_norm)
        super().__init__(norm_layer)

LatentDiffusionAutoencoder

LatentDiffusionAutoencoder(
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Latent diffusion autoencoder model.

Attributes:

Name Type Description
encoder_scale

The encoder scale to use.

Parameters:

Name Type Description Default
device device | str | None

The PyTorch device to use.

None
dtype dtype | None

The PyTorch data type to use.

None
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(
    self,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Build the autoencoder as an encoder followed by a decoder.

    The tiled-inference settings (`_tile_size`, `_blending`) start unset. They
    are only populated while the `tiled_inference` context manager is active.

    Args:
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    encoder = Encoder(device=device, dtype=dtype)
    decoder = Decoder(device=device, dtype=dtype)
    super().__init__(encoder, decoder)
    self._tile_size = self._blending = None

decode

decode(x: Tensor) -> Tensor

Decode a latent tensor.

Parameters:

Name Type Description Default
x Tensor

The latent to decode.

required

Returns:

Type Description
Tensor

The decoded image tensor.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def decode(self, x: Tensor) -> Tensor:
    """Decode a latent tensor.

    The latent is first divided by `encoder_scale`, undoing the multiplication
    done in `encode`. The result is then passed through the decoder, which is
    the second child layer (`self[1]`).

    Args:
        x: The latent to decode.

    Returns:
        The decoded image tensor.
    """
    unscaled_latent = x / self.encoder_scale
    return self[1](unscaled_latent)

encode

encode(x: Tensor) -> Tensor

Encode an image.

Parameters:

Name Type Description Default
x Tensor

The image tensor to encode.

required

Returns:

Type Description
Tensor

The encoded tensor.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def encode(self, x: Tensor) -> Tensor:
    """Encode an image.

    The image goes through the encoder, which is the first child layer
    (`self[0]`). The output is then multiplied by `encoder_scale`.

    Args:
        x: The image tensor to encode.

    Returns:
        The encoded tensor.
    """
    scale = self.encoder_scale
    raw_latent = self[0](x)
    return scale * raw_latent

image_to_latents

image_to_latents(image: Image) -> Tensor

Encode an image to latents.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Encode a single image to latents.

    This wraps `images_to_latents` with a one-element batch.
    """
    single_image_batch = [image]
    return self.images_to_latents(single_image_batch)

images_to_latents

images_to_latents(images: list[Image]) -> Tensor

Convert a list of images to latents.

Parameters:

Name Type Description Default
images list[Image]

The list of images to convert.

required

Returns:

Type Description
Tensor

A tensor containing the latents associated with the images.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
    """Convert a list of images to latents.

    The images are batched into a tensor on the model's device and dtype. They
    are rescaled with `x * 2 - 1` before encoding (presumably from [0, 1] to
    [-1, 1], which depends on `images_to_tensor` — confirm).

    Args:
        images: The list of images to convert.

    Returns:
        A tensor containing the latents associated with the images.
    """
    pixels = images_to_tensor(images, device=self.device, dtype=self.dtype)
    return self.encode(pixels * 2 - 1)

latents_to_image

latents_to_image(x: Tensor) -> Image

Decode latents to an image.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Decode a batch holding exactly one latent to an image.

    Raises:
        ValueError: If the first (batch) dimension of `x` is not 1.
    """
    batch_size = x.shape[0]
    if batch_size == 1:
        return self.latents_to_images(x)[0]
    raise ValueError(f"Expected batch size of 1, got {batch_size}")

latents_to_images

latents_to_images(x: Tensor) -> list[Image]

Convert a tensor of latents to images.

Parameters:

Name Type Description Default
x Tensor

The tensor of latents to convert.

required

Returns:

Type Description
list[Image]

A list of images associated with the latents.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
    """Convert a tensor of latents to images.

    The decoded output is mapped back with `(y + 1) / 2` (the inverse of the
    `2 * x - 1` applied in `images_to_latents`) before conversion to images.

    Args:
        x: The tensor of latents to convert.

    Returns:
        A list of images associated with the latents.
    """
    decoded = self.decode(x)
    return tensor_to_images((decoded + 1) / 2)

tiled_image_to_latents

tiled_image_to_latents(image: Image) -> Tensor

Convert an image to latents with gradient blending to smooth tile edges.

You need to activate the tiled inference context manager with the tiled_inference method to use this method.

```python
with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
    latents = lda.tiled_image_to_latents(sample_image)
```

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Convert an image to latents with gradient blending to smooth tile edges.

    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        latents = lda.tiled_image_to_latents(sample_image)
    ```

    Args:
        image: The image to encode.

    Returns:
        The latents of the image.

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    if self._tile_size is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")

    # Narrows both optional attributes for the type checker; `tiled_inference` sets them together.
    assert self._tile_size is not None and self._blending is not None
    image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
    # Same pixel rescaling as `images_to_latents`.
    image_tensor = 2 * image_tensor - 1
    return self._tiled_encode(image_tensor, self._tile_size, self._blending)

tiled_inference

tiled_inference(
    image: Image,
    tile_size: tuple[int, int] = (512, 512),
    blending: int = 64,
) -> Generator[None, None, None]

Context manager for tiled inference operations to save VRAM for large images.

This context manager sets up consistent GroupNorm statistics for performing tiled operations on the autoencoder, including setting and resetting the GroupNorm statistics. This ensures that the result is consistent across tiles, by capturing the statistics of the GroupNorm layers on a downsampled version of the image.

Be careful not to use the normal image_to_latents and latents_to_image methods while this context manager is active, as this will fail silently and run the operation without tiling.

```python
with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
    latents = lda.tiled_image_to_latents(sample_image)
    decoded_image = lda.tiled_latents_to_image(latents)
```

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
@contextmanager
def tiled_inference(
    self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64
) -> Generator[None, None, None]:
    """
    Context manager for tiled inference operations to save VRAM for large images.

    This context manager sets up consistent GroupNorm statistics for performing tiled operations on the
    autoencoder: the statistics are set on entry and reset on exit. This ensures that the result is consistent
    across tiles, by capturing the statistics of the GroupNorm layers on a downsampled version of the image.

    Be careful not to use the normal `image_to_latents` and `latents_to_image` methods while this context manager is
    active, as this will fail silently and run the operation without tiling.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
        latents = lda.tiled_image_to_latents(sample_image)
        decoded_image = lda.tiled_latents_to_image(latents)
    ```

    Args:
        image: The image used to capture the GroupNorm statistics (passed to `_add_fixed_group_norm`).
        tile_size: The tile size, as `(width, height)`.
        blending: Passed as-is to the tiled encode/decode helpers (presumably the width, in pixels, of the
            blended overlap between tiles — confirm).
    """
    try:
        self._blending = blending
        self._tile_size = _ImageSize(width=tile_size[0], height=tile_size[1])
        self._add_fixed_group_norm(image, inference_size=self._tile_size)
        yield
    finally:
        # Always restore the non-tiled state, even if the setup or the `with` body raised.
        self._remove_fixed_group_norm()
        self._tile_size = None
        self._blending = None

tiled_latents_to_image

tiled_latents_to_image(x: Tensor) -> Image

Convert latents to an image with gradient blending to smooth tile edges.

You need to activate the tiled inference context manager with the tiled_inference method to use this method.

```python
with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
    image = lda.tiled_latents_to_image(latents)
```

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Convert latents to an image with gradient blending to smooth tile edges.

    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        image = lda.tiled_latents_to_image(latents)
    ```

    Args:
        x: The latents to decode.

    Returns:
        The decoded image.

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    if self._tile_size is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")

    # Narrows both optional attributes for the type checker; `tiled_inference` sets them together.
    assert self._tile_size is not None and self._blending is not None
    result = self._tiled_decode(x, self._tile_size, self._blending)
    # Inverse of the `2 * x - 1` pixel rescaling applied before encoding.
    return tensor_to_image((result + 1) / 2)

LatentDiffusionModel

LatentDiffusionModel(
    unet: Chain,
    lda: LatentDiffusionAutoencoder,
    clip_text_encoder: Chain,
    solver: Solver,
    classifier_free_guidance: bool = True,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Module, ABC

Source code in src/refiners/foundationals/latent_diffusion/model.py
def __init__(
    self,
    unet: fl.Chain,
    lda: LatentDiffusionAutoencoder,
    clip_text_encoder: fl.Chain,
    solver: Solver,
    classifier_free_guidance: bool = True,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Assemble the diffusion model and move every component to one device/dtype.

    Args:
        unet: The denoising U-Net.
        lda: The latent diffusion autoencoder.
        clip_text_encoder: The CLIP text encoder.
        solver: The diffusion solver.
        classifier_free_guidance: Whether classifier-free guidance is enabled.
        device: The device (or its string name) to place the components on.
        dtype: The data type to cast the components to.
    """
    super().__init__()
    if not isinstance(device, Device):
        device = Device(device=device)
    self.device: Device = device
    self.dtype = dtype
    placement = dict(device=self.device, dtype=self.dtype)
    self.unet = unet.to(**placement)
    self.lda = lda.to(**placement)
    self.clip_text_encoder = clip_text_encoder.to(**placement)
    self.solver = solver.to(**placement)
    self.classifier_free_guidance = classifier_free_guidance

init_latents

init_latents(
    size: tuple[int, int],
    init_image: Image | None = None,
    noise: Tensor | None = None,
) -> Tensor

Initialize the latents for the diffusion process.

Parameters:

Name Type Description Default
size tuple[int, int]

The size of the latent (in pixel space).

required
init_image Image | None

The image to use as initialization for the latents.

None
noise Tensor | None

The noise to add to the latents.

None
Source code in src/refiners/foundationals/latent_diffusion/model.py
def init_latents(
    self,
    size: tuple[int, int],
    init_image: Image.Image | None = None,
    noise: Tensor | None = None,
) -> Tensor:
    """Initialize the latents for the diffusion process.

    The latent grid is the pixel size divided by 8 (integer division). Fresh
    noise, when not supplied, has shape `(1, 4, height // 8, width // 8)`. When
    `init_image` is given, it is resized, encoded, and noised at the solver's
    first inference step; otherwise the noise itself is the starting latent.

    Args:
        size: The size of the latent (in pixel space), as `(height, width)`.
        init_image: The image to use as initialization for the latents.
        noise: The noise to add to the latents.

    Returns:
        The initial latents, passed through `solver.scale_model_input`.
    """
    height, width = size
    latent_size = (height // 8, width // 8)

    if noise is None:
        noise = LatentDiffusionModel.sample_noise(
            size=(1, 4, *latent_size),
            device=self.device,
            dtype=self.dtype,
        )

    assert list(noise.shape[2:]) == list(latent_size), f"noise shape is not compatible: {noise.shape}, with size: {size}"

    latent = noise
    if init_image is not None:
        # PIL's `resize` takes (width, height), whereas `size` is (height, width).
        resized = init_image.resize(size=(width, height))  # type: ignore
        latent = self.solver.add_noise(
            x=self.lda.image_to_latents(resized),
            noise=noise,
            step=self.solver.first_inference_step,
        )

    return self.solver.scale_model_input(latent, step=-1)

sample_noise staticmethod

sample_noise(
    size: tuple[int, ...],
    device: device | None = None,
    dtype: dtype | None = None,
    offset_noise: float | None = None,
) -> Tensor

Sample noise from a normal distribution with an optional offset.

Parameters:

Name Type Description Default
size tuple[int, ...]

The size of the noise tensor.

required
device device | None

The device to put the noise tensor on.

None
dtype dtype | None

The data type of the noise tensor.

None
offset_noise float | None

The offset of the noise tensor. Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.

None
Source code in src/refiners/foundationals/latent_diffusion/model.py
@staticmethod
def sample_noise(
    size: tuple[int, ...],
    device: Device | None = None,
    dtype: DType | None = None,
    offset_noise: float | None = None,
) -> torch.Tensor:
    """Sample noise from a normal distribution with an optional offset.

    Args:
        size: The size of the noise tensor.
        device: The device to put the noise tensor on.
        dtype: The data type of the noise tensor.
        offset_noise: The offset of the noise tensor.
            Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
    """
    noise = torch.randn(size=size, device=device, dtype=dtype)
    if offset_noise is not None:
        noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)
    return noise

set_inference_steps

set_inference_steps(
    num_steps: int, first_step: int = 0
) -> None

Set the steps of the diffusion process.

Parameters:

Name Type Description Default
num_steps int

The number of inference steps.

required
first_step int

The first inference step, used for image-to-image diffusion. You may be used to setting a float in [0, 1] called strength instead, which is an abstraction for this. The first step is round((1 - strength) * (num_steps - 1)).

0
Source code in src/refiners/foundationals/latent_diffusion/model.py
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
    """Set the steps of the diffusion process.

    The current solver is replaced by the one returned from `solver.rebuild`.

    Args:
        num_steps: The number of inference steps.
        first_step: The first inference step, used for image-to-image diffusion.
            You may be used to setting a float in `[0, 1]` called `strength` instead,
            which is an abstraction for this. The first step is
            `round((1 - strength) * (num_steps - 1))`.
    """
    rebuilt_solver = self.solver.rebuild(
        num_inference_steps=num_steps,
        first_inference_step=first_step,
    )
    self.solver = rebuilt_solver

ControlLora

ControlLora(
    name: str,
    unet: SDXLUNet,
    scale: float = 1.0,
    condition_channels: int = 3,
)

Bases: Passthrough

ControlLora is a Half-UNet clone of the target UNet, patched with various LoRA layers, ZeroConvolution layers, and a ConditionEncoder.

Like ControlNet, it injects residual tensors into the target UNet. See https://github.com/HighCWu/control-lora-v2 for more details.

Gets context:

Type Description
Float[Tensor, 'batch condition_channels width height']

The input image.

Sets context:

Type Description
list[Tensor]

The residuals to be added to the target UNet's residuals. (context="unet", key="residuals")

Parameters:

Name Type Description Default
name str

The name of the ControlLora.

required
unet SDXLUNet

The target UNet.

required
scale float

The scale to multiply the residuals by.

1.0
condition_channels int

The number of channels of the input condition tensor.

3
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(
    self,
    name: str,
    unet: SDXLUNet,
    scale: float = 1.0,
    condition_channels: int = 3,
) -> None:
    """Initialize the ControlLora.

    The TimestepEncoder, DownBlocks and MiddleBlock of `unet` are structurally
    copied, and the copies are then patched:
    - the timestep context key is made unique per `name`;
    - a `ConditionEncoder` is appended to the first DownBlock;
    - every `ResidualAccumulator` is replaced by a `ZeroConvolution`;
    - a final `ZeroConvolution` is appended to the MiddleBlock.

    Args:
        name: The name of the ControlLora.
        unet: The target UNet.
        scale: The scale to multiply the residuals by.
        condition_channels: The number of channels of the input condition tensor.
    """
    self.name = name

    super().__init__(
        timestep_encoder := unet.layer("TimestepEncoder", Chain).structural_copy(),
        downblocks := unet.layer("DownBlocks", Chain).structural_copy(),
        middle_block := unet.layer("MiddleBlock", Chain).structural_copy(),
    )

    # per-name context key shared by the copied TimestepEncoder and every RangeAdapter2d
    timestep_context_key = f"timestep_embedding_control_lora_{name}"

    # modify the context_key of the copied TimestepEncoder to avoid conflicts
    timestep_encoder.context_key = timestep_context_key

    # modify the context_key of each RangeAdapter2d to avoid conflicts
    for range_adapter in self.layers(RangeAdapter2d):
        range_adapter.context_key = timestep_context_key

    # insert the ConditionEncoder in the first DownBlock
    first_downblock = downblocks.layer(0, Chain)
    out_channels = first_downblock.layer(0, Conv2d).out_channels
    first_downblock.append(
        Residual(
            UseContext(f"control_lora_{name}", "condition"),
            ConditionEncoder(
                in_channels=condition_channels,
                out_channels=out_channels,
                device=unet.device,
                dtype=unet.dtype,
            ),
        )
    )

    # replace each ResidualAccumulator by a ZeroConvolution
    for residual_accumulator in self.layers(ResidualAccumulator):
        downblock = self.ensure_find_parent(residual_accumulator)

        first_layer = downblock[0]
        assert hasattr(first_layer, "out_channels"), f"{first_layer} has no out_channels attribute"

        block_channels = first_layer.out_channels
        assert isinstance(block_channels, int)

        downblock.replace(
            residual_accumulator,
            ZeroConvolution(
                scale=scale,
                residual_index=residual_accumulator.n,
                in_channels=block_channels,
                out_channels=block_channels,
                device=unet.device,
                dtype=unet.dtype,
            ),
        )

    # append a ZeroConvolution to middle_block; its residual index follows the last DownBlock's
    middle_block_channels = middle_block.layer(0, ResidualBlock).out_channels
    middle_block.append(
        ZeroConvolution(
            scale=scale,
            residual_index=len(downblocks),
            in_channels=middle_block_channels,
            out_channels=middle_block_channels,
            device=unet.device,
            dtype=unet.dtype,
        )
    )

scale property writable

scale: float

The scale of the residuals stored in the context.

ControlLoraAdapter

ControlLoraAdapter(
    name: str,
    target: SDXLUNet,
    scale: float = 1.0,
    condition_channels: int = 3,
    weights: dict[str, Tensor] | None = None,
)

Bases: Chain, Adapter[SDXLUNet]

Adapter for ControlLora.

This adapter simply prepends a ControlLora model inside the target SDXLUNet.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(
    self,
    name: str,
    target: SDXLUNet,
    scale: float = 1.0,
    condition_channels: int = 3,
    weights: dict[str, Tensor] | None = None,
) -> None:
    """Create the adapter and its ControlLora for `target`.

    The ControlLora is stored inside a one-element list (`_control_lora`)
    rather than as a direct attribute (presumably so it is not registered as
    a child module of the adapter — confirm).

    Args:
        name: The name of the ControlLora.
        target: The SDXLUNet to adapt.
        scale: The scale of the injected residuals.
        condition_channels: The number of channels of the input condition tensor.
        weights: Optional state dict, loaded via `load_weights` when non-empty.
    """
    with self.setup_adapter(target):
        self.name = name
        control_lora = ControlLora(
            name=name,
            unet=target,
            scale=scale,
            condition_channels=condition_channels,
        )
        self._control_lora = [control_lora]
        super().__init__(target)

    if not weights:
        return
    self.load_weights(weights)

control_lora property

control_lora: ControlLora

The ControlLora model.

scale property writable

scale: float

The scale of the injected residuals.

load_condition_encoder staticmethod

load_condition_encoder(
    state_dict: dict[str, Tensor], control_lora: ControlLora
)

Load the ConditionEncoder's layers from the state_dict into the ControlLora.

Parameters:

Name Type Description Default
state_dict dict[str, Tensor]

The state_dict containing the ConditionEncoder layers to load.

required
control_lora ControlLora

The ControlLora to load the ConditionEncoder layers into.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod
def load_condition_encoder(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.

    Every entry whose key contains "ConditionEncoder" is kept, with the
    "ConditionEncoder." prefix stripped. The result is loaded into the
    ControlLora's `ConditionEncoder` layer.

    Args:
        state_dict: The state_dict containing the ConditionEncoder layers to load.
        control_lora: The ControlLora to load the ConditionEncoder layers into.
    """
    prefix = "ConditionEncoder."
    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)
    # NOTE(review): the filter is a substring test, but only a leading prefix is stripped. A key
    # containing "ConditionEncoder" elsewhere would be passed through unstripped — confirm the key format.
    condition_encoder_state_dict = {
        key.removeprefix(prefix): value for key, value in state_dict.items() if "ConditionEncoder" in key
    }
    condition_encoder_layer.load_state_dict(condition_encoder_state_dict)

load_lora_layers staticmethod

load_lora_layers(
    name: str,
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None

Load the LoRA layers from the state_dict into the ControlLora.

Parameters:

Name Type Description Default
name str

The name of the ControlLora.

required
state_dict dict[str, Tensor]

The state_dict containing the LoRA layers to load.

required
control_lora ControlLora

The ControlLora to load the LoRA layers into.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod
def load_lora_layers(
    name: str,
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.

    Args:
        name: The name of the ControlLora.
        state_dict: The state_dict containing the LoRA layers to load.
        control_lora: The ControlLora to load the LoRA layers into.
    """
    # In a single pass over the state_dict:
    # - keep only the keys containing "ControlLora";
    # - strip their "ControlLora." prefix and append ".weight";
    # - move each tensor to the ControlLora's dtype/device.
    lora_weights = {
        f"{key.removeprefix('ControlLora.')}.weight": value.to(dtype=control_lora.dtype, device=control_lora.device)
        for key, value in state_dict.items()
        if "ControlLora" in key
    }

    loras = Lora.from_dict(name, state_dict=lora_weights)

    # Resolve every target layer before injecting anything. Presumably this is because injection
    # mutates the ControlLora tree, which could disturb the remaining path lookups.
    pending: list[LoraAdapter] = []
    for layer_path, lora in loras.items():
        target_layer = control_lora.layer(layer_path.split("."), WeightedModule)
        assert lora.is_compatible(target_layer)
        pending.append(LoraAdapter(target_layer, lora))

    for lora_adapter in pending:
        lora_adapter.inject(control_lora)

load_weights

load_weights(state_dict: dict[str, Tensor]) -> None

Load the weights from the state_dict into the ControlLora.

Parameters:

Name Type Description Default
state_dict dict[str, Tensor]

The state_dict containing the weights to load.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def load_weights(
    self,
    state_dict: dict[str, Tensor],
) -> None:
    """Load the weights from the state_dict into the `ControlLora`.

    The LoRA layers are loaded first, then the ZeroConvolution layers, and
    finally the ConditionEncoder.

    Args:
        state_dict: The state_dict containing the weights to load.
    """
    control_lora = self.control_lora
    ControlLoraAdapter.load_lora_layers(name=self.name, state_dict=state_dict, control_lora=control_lora)
    ControlLoraAdapter.load_zero_convolution_layers(state_dict=state_dict, control_lora=control_lora)
    ControlLoraAdapter.load_condition_encoder(state_dict=state_dict, control_lora=control_lora)

load_zero_convolution_layers staticmethod

load_zero_convolution_layers(
    state_dict: dict[str, Tensor], control_lora: ControlLora
)

Load the ZeroConvolution layers from the state_dict into the ControlLora.

Parameters:

Name Type Description Default
state_dict dict[str, Tensor]

The state_dict containing the ZeroConvolution layers to load.

required
control_lora ControlLora

The ControlLora to load the ZeroConvolution layers into.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod
def load_zero_convolution_layers(
    state_dict: dict[str, Tensor],
    control_lora: ControlLora,
) -> None:
    """Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.

    The i-th `ZeroConvolution` found in the ControlLora (1-based) receives
    every entry whose key contains `ZeroConvolution_{i:02d}`, with the
    `ZeroConvolution_{i:02d}.` prefix stripped.

    Args:
        state_dict: The state_dict containing the ZeroConvolution layers to load.
        control_lora: The ControlLora to load the ZeroConvolution layers into.
    """
    # Materialize the layers up front rather than loading while the tree is being traversed.
    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))
    for index, zero_convolution_layer in enumerate(zero_convolution_layers, start=1):
        layer_name = f"ZeroConvolution_{index:02d}"
        zero_convolution_state_dict = {
            key.removeprefix(f"{layer_name}."): value for key, value in state_dict.items() if layer_name in key
        }
        zero_convolution_layer.load_state_dict(zero_convolution_state_dict)

SDXLAutoencoder

SDXLAutoencoder(
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: LatentDiffusionAutoencoder

Stable Diffusion XL autoencoder model.

Attributes:

Name Type Description
encoder_scale float

The encoder scale to use.

Parameters:

Name Type Description Default
device device | str | None

The PyTorch device to use.

None
dtype dtype | None

The PyTorch data type to use.

None
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(
    self,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Build the autoencoder from an `Encoder` and a `Decoder`, in that order.

    `_tile_size` and `_blending` are both initialized to None (no tiled
    inference configured).

    Args:
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    layers = (
        Encoder(device=device, dtype=dtype),
        Decoder(device=device, dtype=dtype),
    )
    super().__init__(*layers)
    self._tile_size = self._blending = None

SDXLIPAdapter

SDXLIPAdapter(
    target: SDXLUNet,
    clip_image_encoder: CLIPImageEncoderH | None = None,
    image_proj: (
        ImageProjection | PerceiverResampler | None
    ) = None,
    scale: float = 1.0,
    fine_grained: bool = False,
    weights: dict[str, Tensor] | None = None,
)

Bases: IPAdapter[SDXLUNet]

Image Prompt adapter for the Stable Diffusion XL U-Net model.

Parameters:

Name Type Description Default
target SDXLUNet

The SDXLUNet model to adapt.

required
clip_image_encoder CLIPImageEncoderH | None

The CLIP image encoder to use.

None
image_proj ImageProjection | PerceiverResampler | None

The image projection to use.

None
scale float

The scale to use for the image prompt.

1.0
fine_grained bool

Whether to use fine-grained image prompt.

False
weights dict[str, Tensor] | None

The weights of the IPAdapter.

None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py
def __init__(
    self,
    target: SDXLUNet,
    clip_image_encoder: CLIPImageEncoderH | None = None,
    image_proj: ImageProjection | PerceiverResampler | None = None,
    scale: float = 1.0,
    fine_grained: bool = False,
    weights: dict[str, Tensor] | None = None,
) -> None:
    """Initialize the adapter.

    When `image_proj` is not supplied, one is built on `target`'s device and
    dtype: a `PerceiverResampler` if `fine_grained` is set, an
    `ImageProjection` otherwise. A supplied `image_proj` must be a
    `PerceiverResampler` when `fine_grained` is set.

    Args:
        target: The SDXLUNet model to adapt.
        clip_image_encoder: The CLIP image encoder to use.
        image_proj: The image projection to use.
        scale: The scale to use for the image prompt.
        fine_grained: Whether to use fine-grained image prompt.
        weights: The weights of the IPAdapter.
    """
    clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

    if image_proj is not None:
        if fine_grained:
            assert isinstance(image_proj, PerceiverResampler)
    else:
        cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
        if fine_grained:
            image_proj = PerceiverResampler(
                latents_dim=1280,  # not `cross_attn_2d.context_embedding_dim` in this case
                num_attention_layers=4,
                num_attention_heads=20,
                head_dim=64,
                num_tokens=16,
                input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
                output_dim=cross_attn_2d.context_embedding_dim,
                device=target.device,
                dtype=target.dtype,
            )
        else:
            image_proj = ImageProjection(
                clip_image_embedding_dim=clip_image_encoder.output_dim,
                clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
                device=target.device,
                dtype=target.dtype,
            )

    super().__init__(
        target=target,
        clip_image_encoder=clip_image_encoder,
        image_proj=image_proj,
        scale=scale,
        fine_grained=fine_grained,
        weights=weights,
    )

SDXLLcmAdapter

SDXLLcmAdapter(
    target: SDXLUNet,
    condition_scale_embedding_dim: int = 256,
    condition_scale: float = 7.5,
)

Bases: Chain, Adapter[SDXLUNet]

Note that LCM must be used without CFG. You can disable CFG on SD by setting the classifier_free_guidance attribute to False.

Parameters:

Name Type Description Default
target SDXLUNet

A SDXL UNet.

required
condition_scale_embedding_dim int

LCM uses a condition scale embedding, this is its dimension.

256
condition_scale float

Because of the embedding, the condition scale must be passed to this adapter instead of SD. The condition scale passed to SD will be ignored.

7.5
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py
def __init__(
    self,
    target: SDXLUNet,
    condition_scale_embedding_dim: int = 256,
    condition_scale: float = 7.5,
) -> None:
    """Adapt [the SDXL UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]
    for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].

    Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the
    `classifier_free_guidance` attribute to `False`.

    Args:
        target: An SDXL UNet.
        condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.
            Must be even.
        condition_scale: Because of the embedding, the condition scale must be passed to this adapter
            instead of SD. The condition scale passed to SD will be ignored.

    Raises:
        AssertionError: If `condition_scale_embedding_dim` is odd.
    """
    # NOTE(review): presumably the embedding is built from two equal halves (e.g. sin/cos) — confirm.
    assert condition_scale_embedding_dim % 2 == 0
    self.condition_scale_embedding_dim = condition_scale_embedding_dim
    self.condition_scale = condition_scale
    with self.setup_adapter(target):
        super().__init__(target)

SDXLUNet

SDXLUNet(
    in_channels: int,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Stable Diffusion XL U-Net.

See [arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis for more details.

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

required
device device | str | None

Device to use for computation.

None
dtype dtype | None

Data type to use for computation.

None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def __init__(
    self,
    in_channels: int,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the U-Net.

    The chain is: TimestepEncoder, DownBlocks, MiddleBlock, a residual connection on the last
    stored down-block residual, UpBlocks and OutputBlock. After construction, the second
    convolution of every ResidualBlock is wrapped in a RangeAdapter2d fed by the
    `timestep_embedding` context. Residual accumulators and concatenators are then wired into
    the down and up blocks for the skip connections.

    Args:
        in_channels: Number of input channels.
        device: Device to use for computation.
        dtype: Data type to use for computation.
    """
    self.in_channels = in_channels
    super().__init__(
        TimestepEncoder(device=device, dtype=dtype),
        DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
        MiddleBlock(device=device, dtype=dtype),
        # Reads the last entry of the "residuals" list from the "unet" context; fl.Residual
        # presumably adds it to the middle block output (TODO confirm fl.Residual semantics).
        fl.Residual(fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1])),
        UpBlocks(device=device, dtype=dtype),
        OutputBlock(device=device, dtype=dtype),
    )
    # Condition every residual block on the timestep: its inner "Conv2d_1" layer is wrapped
    # by a RangeAdapter2d reading the 1280-dim "timestep_embedding" context.
    for residual_block in self.layers(ResidualBlock):
        chain = residual_block.layer("Chain", fl.Chain)
        RangeAdapter2d(
            target=chain.layer("Conv2d_1", fl.Conv2d),
            channels=residual_block.out_channels,
            embedding_dim=1280,
            context_key="timestep_embedding",
            device=device,
            dtype=dtype,
        ).inject(chain)
    # Each down block n gets a trailing ResidualAccumulator(n); presumably it records the block
    # output in the "residuals" context for the skip connections.
    for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
        block.append(module=ResidualAccumulator(n=n))
    # Each up block n gets a leading ResidualConcatenator indexing residuals from the end
    # (-2, -3, ...); presumably it concatenates the matching skip connection.
    for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
        block.insert(index=0, module=ResidualConcatenator(n=-n - 2))

set_clip_text_embedding

set_clip_text_embedding(
    clip_text_embedding: Tensor,
) -> None

Set the clip text embedding context.

Note

This context is required by the SDXLCrossAttention blocks.

Parameters:

Name Type Description Default
clip_text_embedding Tensor

The CLIP text embedding tensor.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
    """Store the CLIP text embedding in the `cross_attention_block` context.

    Note:
        The `SDXLCrossAttention` blocks read this context.

    Args:
        clip_text_embedding: The CLIP text embedding tensor.
    """
    context_value = {"clip_text_embedding": clip_text_embedding}
    self.set_context(context="cross_attention_block", value=context_value)

set_pooled_text_embedding

set_pooled_text_embedding(
    pooled_text_embedding: Tensor,
) -> None

Set the pooled text embedding context.

Note

This is required by TextTimeEmbedding.

Parameters:

Name Type Description Default
pooled_text_embedding Tensor

The pooled text embedding tensor.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
    """Store the pooled text embedding in the `diffusion` context.

    Note:
        `TextTimeEmbedding` reads this value.

    Args:
        pooled_text_embedding: The pooled text embedding tensor.
    """
    self.set_context(
        context="diffusion",
        value=dict(pooled_text_embedding=pooled_text_embedding),
    )

set_time_ids

set_time_ids(time_ids: Tensor) -> None

Set the time IDs context.

Note

This is required by TextTimeEmbedding.

Parameters:

Name Type Description Default
time_ids Tensor

The time IDs tensor.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_time_ids(self, time_ids: Tensor) -> None:
    """Store the time IDs in the `diffusion` context.

    Note:
        `TextTimeEmbedding` reads this value.

    Args:
        time_ids: The time IDs tensor.
    """
    payload = dict(time_ids=time_ids)
    self.set_context(context="diffusion", value=payload)

set_timestep

set_timestep(timestep: Tensor) -> None

Set the timestep context.

Note

This is required by TimestepEncoder.

Parameters:

Name Type Description Default
timestep Tensor

The timestep tensor.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_timestep(self, timestep: Tensor) -> None:
    """Store the current timestep in the `diffusion` context.

    Note:
        `TimestepEncoder` reads this value.

    Args:
        timestep: The timestep tensor.
    """
    diffusion_context = {"timestep": timestep}
    self.set_context(context="diffusion", value=diffusion_context)

StableDiffusion_XL

StableDiffusion_XL(
    unet: SDXLUNet | None = None,
    lda: SDXLAutoencoder | None = None,
    clip_text_encoder: DoubleTextEncoder | None = None,
    solver: Solver | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: LatentDiffusionModel

Stable Diffusion XL model.

Attributes:

Name Type Description
unet SDXLUNet

The U-Net model.

clip_text_encoder DoubleTextEncoder

The text encoder.

lda SDXLAutoencoder

The image autoencoder.

Parameters:

Name Type Description Default
unet SDXLUNet | None

The SDXLUNet U-Net model to use.

None
lda SDXLAutoencoder | None

The SDXLAutoencoder image autoencoder to use.

None
clip_text_encoder DoubleTextEncoder | None

The DoubleTextEncoder text encoder to use.

None
solver Solver | None

The solver to use.

None
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def __init__(
    self,
    unet: SDXLUNet | None = None,
    lda: SDXLAutoencoder | None = None,
    clip_text_encoder: DoubleTextEncoder | None = None,
    solver: Solver | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initializes the model.

    Any component left to `None` is replaced by a freshly constructed default. All components
    are then passed to `LatentDiffusionModel` together with `device` and `dtype`.

    Args:
        unet: The SDXLUNet U-Net model to use. Defaults to `SDXLUNet(in_channels=4)`.
        lda: The SDXLAutoencoder image autoencoder to use. Defaults to `SDXLAutoencoder()`.
        clip_text_encoder: The DoubleTextEncoder text encoder to use. Defaults to `DoubleTextEncoder()`.
        solver: The solver to use. Defaults to `DDIM(num_inference_steps=30)`.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    # Explicit `None` checks instead of `or`: these components are modules whose truthiness
    # may be customized (e.g. a container defining `__len__`). Only a missing argument should
    # trigger the default.
    if unet is None:
        unet = SDXLUNet(in_channels=4)
    if lda is None:
        lda = SDXLAutoencoder()
    if clip_text_encoder is None:
        clip_text_encoder = DoubleTextEncoder()
    if solver is None:
        solver = DDIM(num_inference_steps=30)

    super().__init__(
        unet=unet,
        lda=lda,
        clip_text_encoder=clip_text_encoder,
        solver=solver,
        device=device,
        dtype=dtype,
    )

default_time_ids property

default_time_ids: Tensor

The default time IDs to use.

compute_clip_text_embedding

compute_clip_text_embedding(
    text: str | list[str],
    negative_text: str | list[str] = "",
) -> tuple[Tensor, Tensor]

Compute the CLIP text embedding associated with the given prompt and negative prompt.

Parameters:

Name Type Description Default
text str | list[str]

The prompt to compute the CLIP text embedding of.

required
negative_text str | list[str]

The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., "").

''
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_clip_text_embedding(
    self, text: str | list[str], negative_text: str | list[str] = ""
) -> tuple[Tensor, Tensor]:
    """Compute the CLIP text embedding associated with the given prompt and negative prompt.

    Args:
        text: The prompt to compute the CLIP text embedding of.
        negative_text: The negative prompt to compute the CLIP text embedding of.
            If not provided, the negative prompt is assumed to be empty (i.e., `""`).
    """

    text = [text] if isinstance(text, str) else text

    if not self.classifier_free_guidance:
        return self.clip_text_encoder(text)

    negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
    assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"

    conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
    negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)

    return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
        tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
    )

compute_self_attention_guidance

compute_self_attention_guidance(
    x: Tensor,
    noise: Tensor,
    step: int,
    *,
    clip_text_embedding: Tensor,
    pooled_text_embedding: Tensor,
    time_ids: Tensor,
    **kwargs: Tensor
) -> Tensor

Compute the self-attention guidance.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
noise Tensor

The noise tensor.

required
step int

The step to compute the self-attention guidance at.

required
clip_text_embedding Tensor

The CLIP text embedding to compute the self-attention guidance with.

required
pooled_text_embedding Tensor

The pooled CLIP text embedding to compute the self-attention guidance with.

required
time_ids Tensor

The time IDs to compute the self-attention guidance with.

required

Returns:

Type Description
Tensor

The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_self_attention_guidance(
    self,
    x: Tensor,
    noise: Tensor,
    step: int,
    *,
    clip_text_embedding: Tensor,
    pooled_text_embedding: Tensor,
    time_ids: Tensor,
    **kwargs: Tensor,
) -> Tensor:
    """Compute the self-attention guidance.

    The U-Net is run on SAG-degraded latents, using only the first half of the text embedding,
    pooled embedding and time IDs. That first half is the negative batch half, as laid out by
    `compute_clip_text_embedding`. If an `ip_adapter` context is present, its image embedding is
    temporarily narrowed the same way. It is always restored afterwards, even if the U-Net call
    raises.

    A SAG adapter must be injected (checked with an assertion).

    Args:
        x: The input tensor.
        noise: The noise tensor.
        step: The step to compute the self-attention guidance at.
        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
        pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with.
        time_ids: The time IDs to compute the self-attention guidance with.

    Returns:
        The computed self-attention guidance, i.e. `sag.scale * (noise - degraded_noise)`.
    """
    sag = self._find_sag_adapter()
    assert sag is not None

    degraded_latents = sag.compute_degraded_latents(
        solver=self.solver,
        latents=x,
        noise=noise,
        step=step,
        classifier_free_guidance=True,
    )

    negative_text_embedding, _ = clip_text_embedding.chunk(2)
    negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
    timestep = self.solver.timesteps[step].unsqueeze(dim=0)
    time_ids, _ = time_ids.chunk(2)

    self.set_unet_context(
        timestep=timestep,
        clip_text_embedding=negative_text_embedding,
        pooled_text_embedding=negative_pooled_embedding,
        time_ids=time_ids,
    )
    if "ip_adapter" in self.unet.provider.contexts:
        # this implementation is a bit hacky, it should be refactored in the future
        ip_adapter_context = self.unet.use_context("ip_adapter")
        image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
        ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
        try:
            degraded_noise = self.unet(degraded_latents)
        finally:
            # Always restore the full image embedding so a failing U-Net call does not leave
            # the shared context holding only its first half.
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
    else:
        degraded_noise = self.unet(degraded_latents)

    return sag.scale * (noise - degraded_noise)

has_self_attention_guidance

has_self_attention_guidance() -> bool

Whether the model has self-attention guidance or not.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def has_self_attention_guidance(self) -> bool:
    """Return `True` when a self-attention guidance adapter can be found, `False` otherwise."""
    sag_adapter = self._find_sag_adapter()
    return sag_adapter is not None

set_self_attention_guidance

set_self_attention_guidance(
    enable: bool, scale: float = 1.0
) -> None

Sets the self-attention guidance.

See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.

Parameters:

Name Type Description Default
enable bool

Whether to enable self-attention guidance or not.

required
scale float

The scale to use.

1.0
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
    """Enable, rescale, or disable self-attention guidance (SAG).

    See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
    for more details.

    Args:
        enable: Whether to enable self-attention guidance or not.
        scale: The scale to use.
    """
    sag = self._find_sag_adapter()
    if not enable:
        # Disabling: eject the adapter if one is currently injected.
        if sag:
            sag.eject()
        return
    if sag:
        # Already injected: only update its guidance scale.
        sag.scale = scale
    else:
        SDXLSAGAdapter(target=self.unet, scale=scale).inject()

set_unet_context

set_unet_context(
    *,
    timestep: Tensor,
    clip_text_embedding: Tensor,
    pooled_text_embedding: Tensor,
    time_ids: Tensor,
    **_: Tensor
) -> None

Set the various context parameters required by the U-Net model.

Parameters:

Name Type Description Default
timestep Tensor

The timestep to set.

required
clip_text_embedding Tensor

The CLIP text embedding to set.

required
pooled_text_embedding Tensor

The pooled CLIP text embedding to set.

required
time_ids Tensor

The time IDs to set.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_unet_context(
    self,
    *,
    timestep: Tensor,
    clip_text_embedding: Tensor,
    pooled_text_embedding: Tensor,
    time_ids: Tensor,
    **_: Tensor,
) -> None:
    """Push every context value the SDXL U-Net needs before a forward pass.

    Extra keyword arguments are accepted and ignored.

    Args:
        timestep: The timestep to set.
        clip_text_embedding: The CLIP text embedding to set.
        pooled_text_embedding: The pooled CLIP text embedding to set.
        time_ids: The time IDs to set.
    """
    unet = self.unet
    unet.set_timestep(timestep=timestep)
    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
    unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
    unet.set_time_ids(time_ids=time_ids)

add_lcm_lora

add_lcm_lora(
    manager: SDLoraManager,
    tensors: dict[str, Tensor],
    name: str = "lcm",
    scale: float = 8.0 / 64.0,
    check_validity: bool = True,
) -> None

Add an LCM-LoRA, or a LoRA with a similar structure such as SDXL-Lightning, to SDXLUNet.

This is a complex LoRA, so SDLoraManager.add_loras() is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of auto_attach_loras.

LCM-LoRA can be used with or without CFG in SD. If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.

Parameters:

Name Type Description Default
manager SDLoraManager

A SDLoraManager for SDXL.

required
tensors dict[str, Tensor]

The state_dict of the LoRA.

required
name str

The name of the LoRA.

'lcm'
scale float

The scale to use for the LoRA (should generally not be changed, those LoRAs must use alpha / rank).

8.0 / 64.0
check_validity bool

Perform additional checks, raise an exception if they fail.

True
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
def add_lcm_lora(
    manager: SDLoraManager,
    tensors: dict[str, torch.Tensor],
    name: str = "lcm",
    scale: float = 8.0 / 64.0,
    check_validity: bool = True,
) -> None:
    """Add an [LCM-LoRA](https://arxiv.org/abs/2311.05556), or a LoRA with a similar structure
    such as [SDXL-Lightning](https://arxiv.org/abs/2402.13929), to SDXLUNet.

    This is a complex LoRA, so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]
    is not enough. Instead, the LoRAs are attached to the UNet in several passes, using the filtering
    mechanism of [auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].

    LCM-LoRA can be used with or without CFG in SD.
    If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.

    Args:
        manager: A SDLoraManager for SDXL.
        tensors: The `state_dict` of the LoRA.
        name: The name of the LoRA.
        scale: The scale to use for the LoRA (should generally not be changed, those LoRAs must use alpha / rank).
        check_validity: Perform additional checks, raise an exception if they fail.
    """
    sd = manager.target
    assert isinstance(sd, StableDiffusion_XL)
    unet = sd.unet

    # Move every LoRA weight to the UNet's device and dtype before building the LoRA layers.
    converted = {key: tensor.to(unet.device, unet.dtype) for key, tensor in tensors.items()}
    loras = Lora.from_dict(name, converted)
    assert all(key.startswith("lora_unet_") for key in loras)
    loras = dict(sorted(loras.items(), key=lambda item: SDLoraManager.sort_keys(item[0])))

    debug_map: list[tuple[str, str]] | None = [] if check_validity else None

    # Projections are in `SDXLCrossAttention` but not in `CrossAttentionBlock`, so they are
    # attached separately with an include/exclude filter.
    projection_loras = {key: lora for key, lora in loras.items() if key.endswith(("proj_in", "proj_out"))}
    auto_attach_loras(
        projection_loras,
        unet,
        exclude=["CrossAttentionBlock"],
        include=["SDXLCrossAttention"],
        debug_map=debug_map,
    )

    # All remaining LoRAs go through the manager's regular UNet path.
    remaining_loras = {key: lora for key, lora in loras.items() if key not in projection_loras}
    manager.add_loras_to_unet(remaining_loras, debug_map=debug_map)

    if debug_map is not None:
        _check_validity(debug_map)

    # LoRAs are finally injected, set the scale with the manager.
    manager.set_scale(name, scale)

ICLight

ICLight(
    patch_weights: dict[str, Tensor],
    unet: SD1UNet,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: StableDiffusion_1

IC-Light is a Stable Diffusion model that can be used to relight a reference image.

At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.

Example
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.clip import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
no_grad().__enter__()
manual_seed(42)

sd = ICLight(
    patch_weights=load_from_safetensors(
        path=hf_hub_download(
            repo_id="refiners/ic_light.sd1_5.fc",
            filename="model.safetensors",
        ),
        device=device,
    ),
    unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
        tensors_path=hf_hub_download(
            repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
            filename="model.safetensors",
        )
    ),
    clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=hf_hub_download(
            repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
            filename="model.safetensors",
        )
    ),
    lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=hf_hub_download(
            repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
            filename="model.safetensors",
        )
    ),
    device=device,
    dtype=dtype,
)

prompt = "soft lighting, high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

image = Image.open("reference-image.png").resize((512, 512))
sd.set_ic_light_condition(image)

x = torch.randn(
    size=(1, 4, 64, 64),
    device=device,
    dtype=dtype,
)

for step in sd.steps:
    x = sd(
        x=x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        condition_scale=1.5,
    )
predicted_image = sd.lda.latents_to_image(x)

predicted_image.save("ic-light-output.png")
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def __init__(
    self,
    patch_weights: dict[str, torch.Tensor],
    unet: SD1UNet,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: torch.device | str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Initializes the IC-Light model.

    The components are first handed to `StableDiffusion_1`. Then the U-Net input convolution is
    extended, so the U-Net accepts the additional condition input channels. Finally, the IC-Light
    patch weights are applied.

    Args:
        patch_weights: The IC-Light patch weights, applied after the input convolution is extended.
        unet: The SD1UNet U-Net model to patch (required, unlike in `StableDiffusion_1`).
        lda: The SD1Autoencoder image autoencoder to use.
        clip_text_encoder: The CLIPTextEncoderL text encoder to use.
        solver: The solver to use.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    super().__init__(
        unet=unet,
        lda=lda,
        clip_text_encoder=clip_text_encoder,
        solver=solver,
        device=device,
        dtype=dtype,
    )
    # Order matters: the patch weights target the extended input convolution.
    self._extend_conv_in()
    self._apply_patch(weights=patch_weights)

compute_gray_composite staticmethod

compute_gray_composite(image: Image, mask: Image) -> Image

Compute a grayscale composite of an image and a mask.

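Pixels where the mask is white are kept from the image. Pixels where it is black are replaced by a uniform (127, 127, 127) gray background, which IC-Light requires.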

Parameters:

Name Type Description Default
image Image

The image to composite.

required
mask Image

The mask to use for the composite.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
@staticmethod
def compute_gray_composite(
    image: Image.Image,
    mask: Image.Image,
) -> Image.Image:
    """Composite an image over a neutral gray background using a mask.

    Pixels where the mask is white are taken from `image`, and pixels where it is black are
    replaced by a uniform (127, 127, 127) gray. This is the background IC-Light requires.
    Intermediate mask values blend the two.

    Args:
        image: The image to composite. Non-RGB images (e.g. RGBA or L) are converted to RGB first.
        mask: The mask to use for the composite. Must be a grayscale ("L") image of the same size
            as `image`.

    Returns:
        The RGB composite image.
    """
    assert mask.mode == "L", "Mask must be a grayscale image"
    assert image.size == mask.size, "Image and mask must have the same size"
    # `Image.composite` requires both images to share the same mode, so match the RGB background.
    if image.mode != "RGB":
        image = image.convert("RGB")
    background = Image.new("RGB", image.size, (127, 127, 127))
    return Image.composite(image, background, mask)

set_ic_light_condition

set_ic_light_condition(
    image: Image, mask: Image | None = None
) -> None

Set the IC light condition.

Parameters:

Name Type Description Default
image Image

The reference image.

required
mask Image | None

The mask to use for the reference image.

None

If a mask is provided, it will be used to compute a grayscale composite of the image and the mask; otherwise, the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def set_ic_light_condition(
    self,
    image: Image.Image,
    mask: Image.Image | None = None,
) -> None:
    """Set the IC light condition.

    If a mask is provided, it is used to compute a gray composite of the image and the mask
    (see `compute_gray_composite`). Otherwise, the image is used as is, but note that IC-Light
    requires a 127-valued gray background to work. The resulting image is encoded with the
    autoencoder and stored as the condition latents.

    Args:
        image: The reference image.
        mask: The mask to use for the reference image.
    """
    reference = image if mask is None else self.compute_gray_composite(image=image, mask=mask)
    self._ic_light_condition = self.lda.image_to_latents(reference)

SD1Autoencoder

SD1Autoencoder(
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: LatentDiffusionAutoencoder

Stable Diffusion 1.5 autoencoder model.

Attributes:

Name Type Description
encoder_scale float

The encoder scale to use.

Parameters:

Name Type Description Default
device device | str | None

The PyTorch device to use.

None
dtype dtype | None

The PyTorch data type to use.

None
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(
    self,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Build the autoencoder as an encoder followed by a decoder.

    The private tiling settings (`_tile_size`, `_blending`) start unset (`None`).

    Args:
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__(Encoder(**factory_kwargs), Decoder(**factory_kwargs))
    self._tile_size = None
    self._blending = None

SD1ELLAAdapter

SD1ELLAAdapter(
    target: SD1UNet,
    weights: dict[str, Tensor] | None = None,
)

Bases: ELLAAdapter[SD1UNet]

ELLA adapter for Stable Diffusion 1.5.

Parameters:

Name Type Description Default
target SD1UNet

The target model to adapt.

required
weights dict[str, Tensor] | None

The weights of the ELLA adapter (see scripts/conversion/convert_ella_adapter.py).

None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py
def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:
    """Build the ELLA latents encoder sized for SD 1.5 and attach it to `target`.

    Args:
        target: The target model to adapt.
        weights: The weights of the ELLA adapter (see `scripts/conversion/convert_ella_adapter.py`).
    """
    # The latents encoder is created on the target U-Net's device and dtype.
    device, dtype = target.device, target.dtype
    ella = ELLA(
        time_channel=320,
        timestep_embedding_dim=768,
        width=768,
        num_layers=6,
        num_heads=8,
        num_latents=64,
        input_dim=2048,
        device=device,
        dtype=dtype,
    )
    super().__init__(target=target, latents_encoder=ella, weights=weights)

SD1UNet

SD1UNet(
    in_channels: int,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Stable Diffusion 1.5 U-Net.

See [arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models for more details.

Parameters:

Name Type Description Default
in_channels int

The number of input channels.

required
device device | str | None

The PyTorch device to use for computation.

None
dtype dtype | None

The PyTorch dtype to use for computation.

None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def __init__(
    self,
    in_channels: int,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the U-Net.

    The chain is: TimestepEncoder, DownBlocks, a sum of the last stored residual with the
    MiddleBlock output, UpBlocks, and an output head. The head is GroupNorm(32 groups,
    320 channels), SiLU, and a 3x3 Conv2d to 4 channels; it always outputs 4 channels,
    whatever `in_channels` is. After construction, the second convolution of every
    ResidualBlock is wrapped in a RangeAdapter2d fed by the `timestep_embedding` context.
    Residual accumulators and concatenators are then wired into the down and up blocks.

    Args:
        in_channels: The number of input channels.
        device: The PyTorch device to use for computation.
        dtype: The PyTorch dtype to use for computation.
    """
    self.in_channels = in_channels
    super().__init__(
        TimestepEncoder(device=device, dtype=dtype),
        DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
        # Sums the last entry of the "residuals" list (from the "unet" context) with the
        # middle block output.
        fl.Sum(
            fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
            MiddleBlock(device=device, dtype=dtype),
        ),
        UpBlocks(device=device, dtype=dtype),
        fl.Chain(
            fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
            fl.SiLU(),
            fl.Conv2d(
                in_channels=320,
                out_channels=4,
                kernel_size=3,
                stride=1,
                padding=1,
                device=device,
                dtype=dtype,
            ),
        ),
    )
    # Condition every residual block on the timestep: its inner "Conv2d_1" layer is wrapped
    # by a RangeAdapter2d reading the 1280-dim "timestep_embedding" context.
    for residual_block in self.layers(ResidualBlock):
        chain = residual_block.layer("Chain", fl.Chain)
        RangeAdapter2d(
            target=chain.layer("Conv2d_1", fl.Conv2d),
            channels=residual_block.out_channels,
            embedding_dim=1280,
            context_key="timestep_embedding",
            device=device,
            dtype=dtype,
        ).inject(chain)
    # Each down block n gets a trailing ResidualAccumulator(n); presumably it records the block
    # output in the "residuals" context for the skip connections.
    for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
        block.append(ResidualAccumulator(n))
    # Each up block n gets a leading ResidualConcatenator indexing residuals from the end
    # (-2, -3, ...); presumably it concatenates the matching skip connection.
    for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
        block.insert(0, ResidualConcatenator(-n - 2))

set_clip_text_embedding

set_clip_text_embedding(
    clip_text_embedding: Tensor,
) -> None

Set the CLIP text embedding.

Note

This context is required by the CLIPLCrossAttention blocks.

Parameters:

Name Type Description Default
clip_text_embedding Tensor

The CLIP text embedding.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
    """Store the CLIP text embedding in the `cross_attention_block` context.

    Note:
        The `CLIPLCrossAttention` blocks read this context.

    Args:
        clip_text_embedding: The CLIP text embedding.
    """
    embedding_context = dict(clip_text_embedding=clip_text_embedding)
    self.set_context(context="cross_attention_block", value=embedding_context)

set_timestep

set_timestep(timestep: Tensor) -> None

Set the timestep.

Note

This context is required by TimestepEncoder.

Parameters:

Name Type Description Default
timestep Tensor

The timestep.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_timestep(self, timestep: Tensor) -> None:
    """Store the current timestep in the `diffusion` context.

    Note:
        `TimestepEncoder` reads this value.

    Args:
        timestep: The timestep.
    """
    self.set_context(context="diffusion", value=dict(timestep=timestep))

StableDiffusion_1

StableDiffusion_1(
    unet: SD1UNet | None = None,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: LatentDiffusionModel

Stable Diffusion 1.5 model.

Attributes:

Name Type Description
unet SD1UNet

The U-Net model.

clip_text_encoder CLIPTextEncoderL

The text encoder.

lda SD1Autoencoder

The image autoencoder.

Example:

import torch

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1

# Load SD
sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16)

sd15.clip_text_encoder.load_from_safetensors("sd1_5.text_encoder.safetensors")
sd15.unet.load_from_safetensors("sd1_5.unet.safetensors")
sd15.lda.load_from_safetensors("sd1_5.autoencoder.safetensors")

# Hyperparameters
prompt = "a cute cat, best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
seed = 42

sd15.set_inference_steps(50)

with no_grad():  # Disable gradient calculation for memory-efficient inference
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    manual_seed(seed)

    x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)

    # Diffusion process
    for step in sd15.steps:
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)

    predicted_image = sd15.lda.decode_latents(x)
    predicted_image.save("output.png")

Parameters:

Name Type Description Default
unet SD1UNet | None

The SD1UNet U-Net model to use.

None
lda SD1Autoencoder | None

The SD1Autoencoder image autoencoder to use.

None
clip_text_encoder CLIPTextEncoderL | None

The CLIPTextEncoderL text encoder to use.

None
solver Solver | None

The solver to use.

None
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(
    self,
    unet: SD1UNet | None = None,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initializes the model.

    Any component left to `None` is replaced by a freshly constructed default. All components
    are then passed to `LatentDiffusionModel` together with `device` and `dtype`.

    Args:
        unet: The SD1UNet U-Net model to use. Defaults to `SD1UNet(in_channels=4)`.
        lda: The SD1Autoencoder image autoencoder to use. Defaults to `SD1Autoencoder()`.
        clip_text_encoder: The CLIPTextEncoderL text encoder to use. Defaults to `CLIPTextEncoderL()`.
        solver: The solver to use. Defaults to `DPMSolver(num_inference_steps=30)`.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    # Explicit `None` checks instead of `or`: these components are modules whose truthiness
    # may be customized (e.g. a container defining `__len__`). Only a missing argument should
    # trigger the default.
    if unet is None:
        unet = SD1UNet(in_channels=4)
    if lda is None:
        lda = SD1Autoencoder()
    if clip_text_encoder is None:
        clip_text_encoder = CLIPTextEncoderL()
    if solver is None:
        solver = DPMSolver(num_inference_steps=30)

    super().__init__(
        unet=unet,
        lda=lda,
        clip_text_encoder=clip_text_encoder,
        solver=solver,
        device=device,
        dtype=dtype,
    )

compute_clip_text_embedding

compute_clip_text_embedding(
    text: str | list[str],
    negative_text: str | list[str] = "",
) -> Tensor

Compute the CLIP text embedding associated with the given prompt and negative prompt.

Parameters:

Name Type Description Default
text str | list[str]

The prompt to compute the CLIP text embedding of.

required
negative_text str | list[str]

The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., "").

''
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
    """Compute the CLIP text embedding associated with the given prompt and negative prompt.

    Args:
        text: The prompt to compute the CLIP text embedding of.
        negative_text: The negative prompt to compute the CLIP text embedding of.
            If not provided, the negative prompt is assumed to be empty (i.e., `""`).
    """
    text = [text] if isinstance(text, str) else text

    if not self.classifier_free_guidance:
        return self.clip_text_encoder(text)

    negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
    assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"

    conditional_embedding = self.clip_text_encoder(text)
    negative_embedding = self.clip_text_encoder(negative_text)

    return torch.cat((negative_embedding, conditional_embedding))

compute_self_attention_guidance

compute_self_attention_guidance(
    x: Tensor,
    noise: Tensor,
    step: int,
    *,
    clip_text_embedding: Tensor,
    **kwargs: Tensor
) -> Tensor

Compute the self-attention guidance.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
noise Tensor

The noise tensor.

required
step int

The step to compute the self-attention guidance at.

required
clip_text_embedding Tensor

The CLIP text embedding to compute the self-attention guidance with.

required

Returns:

Type Description
Tensor

The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(
    self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
    """Compute the self-attention guidance.

    The U-Net is run once on SAG-degraded latents, conditioned only on the first chunk
    (negative part) of `clip_text_embedding`. The guidance is
    `sag.scale * (noise - degraded_noise)`.

    Args:
        x: The input tensor.
        noise: The noise tensor.
        step: The step to compute the self-attention guidance at.
        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
            It is split in two along dim 0, and only the first chunk is used.
        **kwargs: Extra context forwarded to `set_unet_context`.

    Returns:
        The computed self-attention guidance.

    Raises:
        AssertionError: If no SAG adapter is injected.
    """
    sag = self._find_sag_adapter()
    assert sag is not None

    degraded_latents = sag.compute_degraded_latents(
        solver=self.solver,
        latents=x,
        noise=noise,
        step=step,
        classifier_free_guidance=True,
    )

    timestep = self.solver.timesteps[step].unsqueeze(dim=0)
    negative_embedding, _ = clip_text_embedding.chunk(2)
    self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
    if "ip_adapter" in self.unet.provider.contexts:
        # this implementation is a bit hacky, it should be refactored in the future
        # The IP-Adapter image embedding is temporarily narrowed to its first chunk to match the
        # negative-only text conditioning, then restored. `finally` guarantees the restore even if
        # the U-Net call raises, so the shared context is never left truncated.
        ip_adapter_context = self.unet.use_context("ip_adapter")
        image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
        ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
        try:
            degraded_noise = self.unet(degraded_latents)
        finally:
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
    else:
        degraded_noise = self.unet(degraded_latents)

    return sag.scale * (noise - degraded_noise)

has_self_attention_guidance

has_self_attention_guidance() -> bool

Whether the model has self-attention guidance or not.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def has_self_attention_guidance(self) -> bool:
    """Tell whether a self-attention guidance (SAG) adapter is currently injected.

    Returns:
        `True` when `_find_sag_adapter` returns anything other than `None`, `False` otherwise.
    """
    sag_adapter = self._find_sag_adapter()
    return sag_adapter is not None

set_self_attention_guidance

set_self_attention_guidance(
    enable: bool, scale: float = 1.0
) -> None

Set whether to enable self-attention guidance.

See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.

Parameters:

Name Type Description Default
enable bool

Whether to enable self-attention guidance.

required
scale float

The scale to use.

1.0
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
    """Turn self-attention guidance on or off.

    See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
    for more details.

    When enabling, an already-injected SAG adapter just has its scale updated; otherwise a new
    `SD1SAGAdapter` targeting the U-Net is injected. When disabling, the adapter (if any) is ejected.

    Args:
        enable: Whether to enable self-attention guidance.
        scale: The scale to use (only used when enabling).
    """
    sag = self._find_sag_adapter()

    if not enable:
        if sag:
            sag.eject()
        return

    if sag:
        # Already injected: only refresh its scale.
        sag.scale = scale
    else:
        SD1SAGAdapter(target=self.unet, scale=scale).inject()

set_unet_context

set_unet_context(
    *,
    timestep: Tensor,
    clip_text_embedding: Tensor,
    **_: Tensor
) -> None

Set the various context parameters required by the U-Net model.

Parameters:

Name Type Description Default
timestep Tensor

The timestep tensor to use.

required
clip_text_embedding Tensor

The CLIP text embedding tensor to use.

required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
    """Push the timestep and the CLIP text embedding into the U-Net's context.

    Any extra keyword arguments are accepted and ignored.

    Args:
        timestep: The timestep tensor to use.
        clip_text_embedding: The CLIP text embedding tensor to use.
    """
    unet = self.unet
    unet.set_timestep(timestep=timestep)
    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

StableDiffusion_1_Inpainting

StableDiffusion_1_Inpainting(
    unet: SD1UNet | None = None,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: StableDiffusion_1

Stable Diffusion 1.5 inpainting model.

Attributes:

Name Type Description
unet

The U-Net model.

clip_text_encoder

The text encoder.

lda

The image autoencoder.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(
    self,
    unet: SD1UNet | None = None,
    lda: SD1Autoencoder | None = None,
    clip_text_encoder: CLIPTextEncoderL | None = None,
    solver: Solver | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Build a Stable Diffusion 1.5 inpainting model.

    If `unet` is not given (or is falsy), a 9-channel `SD1UNet` is created. The other
    components are forwarded to the parent constructor unchanged.

    Args:
        unet: The SD1UNet U-Net model to use.
        lda: The SD1Autoencoder image autoencoder to use.
        clip_text_encoder: The CLIPTextEncoderL text encoder to use.
        solver: The solver to use.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    # Inpainting conditions start unset; `set_inpainting_conditions` fills them in.
    self.mask_latents: Tensor | None = None
    self.target_image_latents: Tensor | None = None
    super().__init__(
        unet=unet or SD1UNet(in_channels=9),
        lda=lda,
        clip_text_encoder=clip_text_encoder,
        solver=solver,
        device=device,
        dtype=dtype,
    )

compute_self_attention_guidance

compute_self_attention_guidance(
    x: Tensor,
    noise: Tensor,
    step: int,
    *,
    clip_text_embedding: Tensor,
    **kwargs: Tensor
) -> Tensor

Compute the self-attention guidance.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required
noise Tensor

The noise tensor.

required
step int

The step to compute the self-attention guidance at.

required
clip_text_embedding Tensor

The CLIP text embedding to compute the self-attention guidance with.

required

Returns:

Type Description
Tensor

The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(
    self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
    """Compute the self-attention guidance.

    The SAG-degraded latents are concatenated (along dim 1) with the stored mask latents and
    target image latents. The U-Net is then run on the result, conditioned only on the first
    chunk (negative part) of `clip_text_embedding`. The guidance is
    `sag.scale * (noise - degraded_noise)`.

    Args:
        x: The input tensor.
        noise: The noise tensor.
        step: The step to compute the self-attention guidance at.
        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
            It is split in two along dim 0, and only the first chunk is used.
        **kwargs: Extra context forwarded to `set_unet_context`.

    Returns:
        The computed self-attention guidance.

    Raises:
        AssertionError: If no SAG adapter is injected, or if the inpainting conditions have not
            been set (see `set_inpainting_conditions`).
    """
    sag = self._find_sag_adapter()
    assert sag is not None
    assert self.mask_latents is not None
    assert self.target_image_latents is not None

    degraded_latents = sag.compute_degraded_latents(
        solver=self.solver,
        latents=x,
        noise=noise,
        step=step,
        classifier_free_guidance=True,
    )
    x = torch.cat(
        tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
        dim=1,
    )

    timestep = self.solver.timesteps[step].unsqueeze(dim=0)
    negative_embedding, _ = clip_text_embedding.chunk(2)
    self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)

    if "ip_adapter" in self.unet.provider.contexts:
        # this implementation is a bit hacky, it should be refactored in the future
        # The IP-Adapter image embedding is temporarily narrowed to its first chunk to match the
        # negative-only text conditioning, then restored. `finally` guarantees the restore even if
        # the U-Net call raises, so the shared context is never left truncated.
        ip_adapter_context = self.unet.use_context("ip_adapter")
        image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
        ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
        try:
            degraded_noise = self.unet(x)
        finally:
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
    else:
        degraded_noise = self.unet(x)

    return sag.scale * (noise - degraded_noise)

set_inpainting_conditions

set_inpainting_conditions(
    target_image: Image,
    mask: Image,
    latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]

Set the inpainting conditions.

Parameters:

Name Type Description Default
target_image Image

The target image to inpaint.

required
mask Image

The mask to use for inpainting.

required
latents_size tuple[int, int]

The size of the latents to use.

(64, 64)

Returns:

Type Description
tuple[Tensor, Tensor]

The mask latents and the target image latents.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_inpainting_conditions(
    self,
    target_image: Image.Image,
    mask: Image.Image,
    latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
    """Prepare, store and return the inpainting conditions.

    The mask is binarized (pixels brighter than half intensity are 1, others 0) and resized
    to `latents_size` to give `mask_latents`. The target image has the masked region zeroed
    out and is then encoded by the autoencoder to give `target_image_latents`.

    Args:
        target_image: The target image to inpaint.
        mask: The mask to use for inpainting.
        latents_size: The size of the latents to use.

    Returns:
        The mask latents and the target image latents.
    """
    rgb_image = target_image.convert(mode="RGB")
    grayscale_mask = mask.convert(mode="L")

    # Binarize the single-channel mask, then add batch and channel dims: (1, 1, H, W).
    mask_array = np.array(object=grayscale_mask).astype(dtype=np.float32) / 255.0
    binary_mask = torch.tensor(data=mask_array).to(device=self.device) > 0.5
    mask_tensor = binary_mask.unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
    self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))

    # `* 2 - 1` presumably maps image_to_tensor's [0, 1] range to [-1, 1] — TODO confirm.
    # Multiplying by (1 - mask) zeroes the region to inpaint before encoding.
    image_tensor = image_to_tensor(image=rgb_image, device=self.device, dtype=self.dtype) * 2 - 1
    self.target_image_latents = self.lda.encode(x=image_tensor * (1 - mask_tensor))

    return self.mask_latents, self.target_image_latents

DDIM

DDIM(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Solver

Denoising Diffusion Implicit Model (DDIM) solver.

See [arXiv:2010.02502] Denoising Diffusion Implicit Models for more details.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/solvers/ddim.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: Device | str = "cpu",
    dtype: Dtype = torch.float32,
) -> None:
    """Initializes a new DDIM solver.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Raises:
        NotImplementedError: If `params` requests a model prediction type other than
            `ModelPredictionType.NOISE`, or a non-zero `sde_variance`.
    """
    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
        raise NotImplementedError(
            f"DDIM only supports model_prediction_type=NOISE, got {params.model_prediction_type}"
        )
    if params and params.sde_variance != 0.0:
        raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")

    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        params=params,
        device=device,
        dtype=dtype,
    )

DDPM

DDPM(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: device | str = "cpu",
)

Bases: Solver

Denoising Diffusion Probabilistic Model (DDPM) solver.

Warning

Only used for training Latent Diffusion models. Cannot be called.

See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models for more details.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
device device | str

The PyTorch device to use.

'cpu'
Source code in src/refiners/foundationals/latent_diffusion/solvers/ddpm.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: Device | str = "cpu",
) -> None:
    """Initializes a new DDPM solver.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        device: The PyTorch device to use.

    Raises:
        NotImplementedError: If `params` requests a model prediction type other than
            `ModelPredictionType.NOISE`.
    """

    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
        raise NotImplementedError(
            f"DDPM only supports model_prediction_type=NOISE, got {params.model_prediction_type}"
        )

    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        params=params,
        device=device,
    )

DPMSolver

DPMSolver(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    last_step_first_order: bool = False,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Solver

Diffusion probabilistic models (DPMs) solver.

See [arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models for more details.

Note

Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts when used with SDXL and few steps. This parameter is a way to mitigate that effect by using a first-order (Euler) update instead of a second-order update for the last step of the diffusion.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
last_step_first_order bool

Use a first-order update for the last step.

False
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    last_step_first_order: bool = False,
    device: torch.device | str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> None:
    """Initializes a new DPM solver.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        last_step_first_order: Use a first-order update for the last step.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Raises:
        NotImplementedError: If `params` requests a model prediction type other than
            `ModelPredictionType.NOISE`, or an `sde_variance` other than 0.0 or 1.0.
    """
    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
        raise NotImplementedError(
            f"DPMSolver only supports model_prediction_type=NOISE, got {params.model_prediction_type}"
        )
    if params and params.sde_variance not in (0.0, 1.0):
        raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")

    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        params=params,
        device=device,
        dtype=dtype,
    )
    # The two most recent data estimations, read as [-1] (current) and [-2] (previous) by the
    # second-order update. Starts with two empty placeholder tensors.
    self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
    self.last_step_first_order = last_step_first_order
    # sigma = noise_std / cumulative_scale_factors, then rescaled per `params.sigma_schedule`.
    sigmas = self.noise_std / self.cumulative_scale_factors
    self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
    # NOTE(review): `sigma_min` and the timesteps below come from the *un-rescaled* `sigmas`,
    # not from `self.sigmas`; presumably intentional — confirm.
    sigma_min = sigmas[0:1]  # corresponds to `final_sigmas_type="sigma_min" in diffusers`
    # One trailing sigma is appended, presumably so the `step + 1` lookups in the update
    # methods stay in range on the last step.
    self.sigmas = torch.cat([self.sigmas, sigma_min])
    # Re-derive the solver tensors from the final sigmas so all of them stay consistent.
    self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
        self.sigmas
    )
    self.timesteps = self._timesteps_from_sigmas(sigmas)

dpm_solver_first_order_update

dpm_solver_first_order_update(
    x: Tensor,
    noise: Tensor,
    step: int,
    sde_noise: Tensor | None = None,
) -> Tensor

Applies a first-order backward Euler update to the input data x.

Parameters:

Name Type Description Default
x Tensor

The input data.

required
noise Tensor

The predicted noise.

required
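step int

The current step.

required
sde_noise Tensor | None

Optional noise sample. When provided, the stochastic (SDE) variant of the update is used and this noise is injected into the result.

None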

Returns:

Type Description
Tensor

The denoised version of the input data x.

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def dpm_solver_first_order_update(
    self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
    """Applies a first-order backward Euler update to the input data `x`.

    Args:
        x: The input data.
        noise: The predicted noise.
        step: The current step.

    Returns:
        The denoised version of the input data `x`.
    """
    current_ratio = self.signal_to_noise_ratios[step]
    next_ratio = self.signal_to_noise_ratios[step + 1]

    next_scale_factor = self.cumulative_scale_factors[step + 1]

    next_noise_std = self.noise_std[step + 1]
    current_noise_std = self.noise_std[step]

    ratio_delta = current_ratio - next_ratio

    if sde_noise is None:
        return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise

    factor = 1.0 - torch.exp(2.0 * ratio_delta)
    return (
        (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
        + next_scale_factor * factor * noise
        + next_noise_std * safe_sqrt(factor) * sde_noise
    )

multistep_dpm_solver_second_order_update

multistep_dpm_solver_second_order_update(
    x: Tensor, step: int, sde_noise: Tensor | None = None
) -> Tensor

Applies a second-order backward Euler update to the input data x.

Parameters:

Name Type Description Default
x Tensor

The input data.

required
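step int

The current step.

required
sde_noise Tensor | None

Optional noise sample. When provided, the stochastic (SDE) variant of the update is used and this noise is injected into the result.

None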

Returns:

Type Description
Tensor

The denoised version of the input data x.

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def multistep_dpm_solver_second_order_update(
    self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
    """Applies a second-order backward Euler update to the input data `x`.

    Uses the two most recent entries of `self.estimated_data` (the current and previous data
    estimations) to add a finite-difference correction term to the first-order update.

    Args:
        x: The input data.
        step: The current step. `step - 1` and `step + 1` are also read from the solver
            tensors; at `step == 0` the `step - 1` index would wrap to the last element, so this
            presumably must not be called on the first step — confirm.
        sde_noise: Optional noise sample. When `None`, the deterministic update is applied;
            otherwise the stochastic (SDE) variant is used and `sde_noise` is injected,
            scaled by the next noise standard deviation.

    Returns:
        The denoised version of the input data `x`.
    """
    # Most recent estimation last, as stored in the bounded history.
    current_data_estimation = self.estimated_data[-1]
    previous_data_estimation = self.estimated_data[-2]

    next_ratio = self.signal_to_noise_ratios[step + 1]
    current_ratio = self.signal_to_noise_ratios[step]
    previous_ratio = self.signal_to_noise_ratios[step - 1]

    next_scale_factor = self.cumulative_scale_factors[step + 1]
    next_noise_std = self.noise_std[step + 1]
    current_noise_std = self.noise_std[step]

    # Difference of the last two data estimations, scaled by
    # (next_ratio - current_ratio) / (current_ratio - previous_ratio).
    estimation_delta = (current_data_estimation - previous_data_estimation) / (
        (current_ratio - previous_ratio) / (next_ratio - current_ratio)
    )
    ratio_delta = current_ratio - next_ratio

    if sde_noise is None:
        # Deterministic update.
        factor = 1.0 - torch.exp(ratio_delta)
        return (
            (next_noise_std / current_noise_std) * x
            + next_scale_factor * factor * current_data_estimation
            + 0.5 * next_scale_factor * factor * estimation_delta
        )

    # Stochastic (SDE) update: doubled exponent in `factor`, an extra exp(ratio_delta) on `x`,
    # and injected noise scaled by next_noise_std * sqrt(factor).
    factor = 1.0 - torch.exp(2.0 * ratio_delta)
    return (
        (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
        + next_scale_factor * factor * current_data_estimation
        + 0.5 * next_scale_factor * factor * estimation_delta
        + next_noise_std * safe_sqrt(factor) * sde_noise
    )

rebuild

rebuild(
    num_inference_steps: int | None,
    first_inference_step: int | None = None,
) -> DPMSolver

Rebuilds the solver with new parameters.

Parameters:

Name Type Description Default
num_inference_steps int | None

The number of inference steps.

required
first_inference_step int | None

The first inference step.

None
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def rebuild(
    self: "DPMSolver",
    num_inference_steps: int | None,
    first_inference_step: int | None = None,
) -> "DPMSolver":
    """Rebuilds the solver with new parameters.

    Delegates to the base `rebuild`, then copies `last_step_first_order` onto the returned solver.

    Args:
        num_inference_steps: The number of inference steps.
        first_inference_step: The first inference step.

    Returns:
        The rebuilt solver.
    """
    rebuilt = super().rebuild(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
    )
    # Carry the DPM-specific option over to the new instance.
    rebuilt.last_step_first_order = self.last_step_first_order
    return rebuilt

Euler

Euler(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Solver

Euler solver.

See [arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models for more details.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: Device | str = "cpu",
    dtype: Dtype = torch.float32,
) -> None:
    """Initializes a new Euler solver.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Raises:
        NotImplementedError: If `params` requests a noise schedule other than
            `NoiseSchedule.QUADRATIC`, or a non-zero `sde_variance`.
    """
    if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
        raise NotImplementedError(f"Euler only supports the quadratic noise schedule, got {params.noise_schedule}")
    if params and params.sde_variance != 0.0:
        raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")

    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        params=params,
        device=device,
        dtype=dtype,
    )
    self.sigmas = self._generate_sigmas()

init_noise_sigma property

init_noise_sigma: Tensor

The initial noise sigma.

scale_model_input

scale_model_input(x: Tensor, step: int) -> Tensor

Scales the model input according to the current step.

Parameters:

Name Type Description Default
x Tensor

The model input.

required
step int

The current step. This method is called with step=-1 in init_latents.

required

Returns:

Type Description
Tensor

The scaled model input.

Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
    """Scale the model input for the given step.

    For a regular step, `x` is divided by `sqrt(sigma**2 + 1)` using that step's sigma. The
    special value `step=-1`, used by `init_latents`, instead multiplies `x` by `init_noise_sigma`.

    Args:
        x: The model input.
        step: The current step. This method is called with `step=-1` in `init_latents`.

    Returns:
        The scaled model input.
    """
    if step != -1:
        sigma = self.sigmas[step]
        return x / ((sigma**2 + 1) ** 0.5)
    return x * self.init_noise_sigma

FrankenSolver

FrankenSolver(
    get_diffusers_scheduler: Callable[[], SchedulerLike],
    num_inference_steps: int,
    first_inference_step: int = 0,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
    **kwargs: Any
)

Bases: Solver

Lets you use Diffusers Schedulers as Refiners Solvers.

For instance
from diffusers import EulerDiscreteScheduler
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

scheduler = EulerDiscreteScheduler(...)
solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
Source code in src/refiners/foundationals/latent_diffusion/solvers/franken.py
def __init__(
    self,
    get_diffusers_scheduler: Callable[[], SchedulerLike],
    num_inference_steps: int,
    first_inference_step: int = 0,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
    **kwargs: Any,  # for typing, ignored
) -> None:
    """Wrap a Diffusers scheduler so it can be used as a Refiners solver.

    Args:
        get_diffusers_scheduler: Factory returning the Diffusers scheduler to wrap. It is stored
            on the instance and called once here.
        num_inference_steps: The number of inference steps, also passed to the scheduler's
            `set_timesteps`.
        first_inference_step: The first inference step to perform.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
        **kwargs: Accepted for signature compatibility; ignored.
    """
    self.get_diffusers_scheduler = get_diffusers_scheduler
    scheduler = get_diffusers_scheduler()
    self.diffusers_scheduler = scheduler
    scheduler.set_timesteps(num_inference_steps)

    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        device=device,
        dtype=dtype,
    )

LCMSolver

LCMSolver(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    num_orig_steps: int = 50,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Solver

Latent Consistency Model solver.

This solver is designed for use either with a specific base model or a specific LoRA.

See [arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference for details.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
num_orig_steps int

The number of inference steps of the emulated DPM solver.

50
device device | str

The PyTorch device to use.

'cpu'
dtype dtype

The PyTorch data type to use.

float32
Source code in src/refiners/foundationals/latent_diffusion/solvers/lcm.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    num_orig_steps: int = 50,
    device: torch.device | str = "cpu",
    dtype: torch.dtype = torch.float32,
):
    """Initializes a new LCM solver.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        num_orig_steps: The number of inference steps of the emulated DPM solver.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Raises:
        AssertionError: If `num_orig_steps` is smaller than `num_inference_steps`.
        NotImplementedError: If the resolved params use a model prediction type other than
            `ModelPredictionType.NOISE`.
    """

    assert (
        num_orig_steps >= num_inference_steps
    ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"

    params = self.resolve_params(params)
    if params.model_prediction_type != ModelPredictionType.NOISE:
        raise NotImplementedError(
            f"LCMSolver only supports model_prediction_type=NOISE, got {params.model_prediction_type}"
        )

    # The emulated DPM solver is kept inside a plain list: a torch Module cannot be assigned as
    # an attribute before `Module.__init__` has run (it runs below via `super().__init__`), and
    # the list also keeps it from being registered as a submodule.
    self._dpm = [
        DPMSolver(
            num_inference_steps=num_orig_steps,
            params=SolverParams(
                num_train_timesteps=params.num_train_timesteps,
                timesteps_spacing=params.timesteps_spacing,
            ),
            device=device,
            dtype=dtype,
        )
    ]
    super().__init__(
        num_inference_steps=num_inference_steps,
        first_inference_step=first_inference_step,
        params=params,
        device=device,
        dtype=dtype,
    )

ModelPredictionType

Bases: str, Enum

An enumeration of possible outputs of the model.

Attributes:

Name Type Description
NOISE

The model predicts the noise (epsilon).

SAMPLE

The model predicts the denoised sample (x0).

NoiseSchedule

Bases: str, Enum

An enumeration of schedules used to sample the noise.

Attributes:

Name Type Description
UNIFORM

A uniform noise schedule.

QUADRATIC

A quadratic noise schedule. Corresponds to "Stable Diffusion" in [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed table 1.

KARRAS

Solver

Solver(
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Module, ABC

The base class for creating a diffusion model solver.

Solvers create a sequence of noise and scaling factors used in the diffusion process, which gradually transforms the original data distribution into a Gaussian one.

This process is described using several parameters such as initial and final diffusion rates, and is encapsulated into a __call__ method that applies a step of the diffusion process.

Attributes:

Name Type Description
params ResolvedSolverParams

The common parameters for solvers. See SolverParams.

num_inference_steps

The number of inference steps to perform.

first_inference_step

The step to start the inference process from.

scale_factors

The scale factors used to denoise the input. These are called "alphas" in other implementations (e.g. DDPM), and 1 - scale_factors is called "betas"; their cumulative product is the usual "alphas_cumprod".

cumulative_scale_factors

The cumulative scale factors used to denoise the input. These are called "alpha_t" in other implementations.

noise_std

The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other implementations.

signal_to_noise_ratios

The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other implementations.

Parameters:

Name Type Description Default
num_inference_steps int

The number of inference steps to perform.

required
first_inference_step int

The first inference step to perform.

0
params BaseSolverParams | None

The common parameters for solvers.

None
device device | str

The PyTorch device to use for the solver's tensors.

'cpu'
dtype dtype

The PyTorch data type to use for the solver's tensors.

float32
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def __init__(
    self,
    num_inference_steps: int,
    first_inference_step: int = 0,
    params: BaseSolverParams | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initializes a new `Solver` instance.

    The noise schedule is sampled once here; the cumulative scale factors,
    noise standard deviations, log signal-to-noise ratios and timesteps are all
    derived from it, then every tensor is moved to `device` / `dtype`.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.
        params: The common parameters for solvers.
        device: The PyTorch device to use for the solver's tensors.
        dtype: The PyTorch data type to use for the solver's tensors.
    """
    super().__init__()

    self.num_inference_steps = num_inference_steps
    self.first_inference_step = first_inference_step
    self.params = self.resolve_params(params)

    self.scale_factors = self.sample_noise_schedule()
    # The running product of the per-timestep factors feeds both alpha_t and sigma_t.
    running_product = self.scale_factors.cumprod(dim=0)
    self.cumulative_scale_factors = running_product.sqrt()
    self.noise_std = (1.0 - running_product).sqrt()
    # log(alpha_t) - log(sigma_t)
    self.signal_to_noise_ratios = self.cumulative_scale_factors.log() - self.noise_std.log()
    self.timesteps = self._generate_timesteps()

    self.to(device=device, dtype=dtype)

all_steps property

all_steps: list[int]

Return a list of all inference steps.

device property writable

device: device

The PyTorch device used for the solver's tensors.

dtype property writable

dtype: dtype

The PyTorch data type used for the solver's tensors.

inference_steps property

inference_steps: list[int]

Return a list of inference steps to perform.

add_noise

add_noise(
    x: Tensor, noise: Tensor, step: int | list[int]
) -> Tensor

Add noise to the input tensor using the solver's parameters.

Parameters:

Name Type Description Default
x Tensor

The input tensor to add noise to.

required
noise Tensor

The noise tensor to add to the input tensor.

required
step int | list[int]

The current step(s) of the diffusion process.

required

Returns:

Type Description
Tensor

The input tensor with added noise.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def add_noise(
    self,
    x: Tensor,
    noise: Tensor,
    step: int | list[int],
) -> Tensor:
    """Add noise to the input tensor using the solver's parameters.

    Args:
        x: The input tensor to add noise to.
        noise: The noise tensor to add to the input tensor.
        step: The current step(s) of the diffusion process.

    Returns:
        The input tensor with added noise.
    """
    if isinstance(step, list):
        assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
        return torch.stack(
            tensors=[
                self._add_noise(
                    x=x[i],
                    noise=noise[i],
                    step=step[i],
                )
                for i in range(x.shape[0])
            ],
            dim=0,
        )

    return self._add_noise(x=x, noise=noise, step=step)

generate_timesteps staticmethod

generate_timesteps(
    spacing: TimestepSpacing,
    num_inference_steps: int,
    num_train_timesteps: int = 1000,
    offset: int = 0,
) -> Tensor

Generate a tensor of timesteps according to a given spacing.

Parameters:

Name Type Description Default
spacing TimestepSpacing

The spacing to use for the timesteps.

required
num_inference_steps int

The number of inference steps to perform.

required
num_train_timesteps int

The number of timesteps used to train the diffusion process.

1000
offset int

The offset to use for the timesteps.

0
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
@staticmethod
def generate_timesteps(
    spacing: TimestepSpacing,
    num_inference_steps: int,
    num_train_timesteps: int = 1000,
    offset: int = 0,
) -> Tensor:
    """Generate a tensor of timesteps according to a given spacing.

    Args:
        spacing: The spacing to use for the timesteps.
        num_inference_steps: The number of inference steps to perform.
        num_train_timesteps: The number of timesteps used to train the diffusion process.
        offset: The offset to use for the timesteps.
    """
    max_timestep = num_train_timesteps - 1 + offset
    match spacing:
        case TimestepSpacing.LINSPACE:
            return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)
        case TimestepSpacing.LINSPACE_ROUNDED:
            return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
        case TimestepSpacing.LEADING:
            step_ratio = num_train_timesteps // num_inference_steps
            return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
        case TimestepSpacing.TRAILING:
            step_ratio = num_train_timesteps // num_inference_steps
            max_timestep = num_train_timesteps - 1 + offset
            return torch.arange(max_timestep, offset, -step_ratio)
        case TimestepSpacing.CUSTOM:
            raise RuntimeError("generate_timesteps called with custom spacing")

rebuild

rebuild(
    num_inference_steps: int | None,
    first_inference_step: int | None = None,
) -> T

Rebuild the solver with new parameters.

Parameters:

Name Type Description Default
num_inference_steps int | None

The number of inference steps to perform.

required
first_inference_step int | None

The first inference step to perform.

None

Returns:

Type Description
T

A new solver instance with the specified parameters.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
    """Rebuild the solver with new parameters.

    A fresh instance of the same solver class is created. Any argument left to
    `None` keeps the current solver's value; the solver params are copied with
    `dataclasses.replace`, and the device and dtype are carried over.

    Args:
        num_inference_steps: The number of inference steps to perform.
        first_inference_step: The first inference step to perform.

    Returns:
        A new solver instance with the specified parameters.
    """
    solver_class = self.__class__
    steps = num_inference_steps if num_inference_steps is not None else self.num_inference_steps
    first_step = first_inference_step if first_inference_step is not None else self.first_inference_step
    return solver_class(
        num_inference_steps=steps,
        first_inference_step=first_step,
        params=dataclasses.replace(self.params),
        device=self.device,
        dtype=self.dtype,
    )

remove_noise

remove_noise(x: Tensor, noise: Tensor, step: int) -> Tensor

Remove noise from the input tensor using the current step of the diffusion process.

Note

See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15 and [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance.

Parameters:

Name Type Description Default
x Tensor

The input tensor to remove noise from.

required
noise Tensor

The noise tensor to remove from the input tensor.

required
step int

The current step of the diffusion process.

required

Returns:

Type Description
Tensor

The denoised input tensor.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
    """Remove noise from the input tensor using the current step of the diffusion process.

    Note:
        See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15](https://arxiv.org/abs/2006.11239)
        and [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939).

    Args:
        x: The input tensor to remove noise from.
        noise: The noise tensor to remove from the input tensor.
        step: The current step of the diffusion process.

    Returns:
        The denoised input tensor.
    """
    t = self.timesteps[step]
    alpha_t = self.cumulative_scale_factors[t]
    sigma_t = self.noise_std[t]
    # Solve x = alpha_t * x0 + sigma_t * noise for x0.
    return (x - sigma_t * noise) / alpha_t

sample_noise_schedule

sample_noise_schedule() -> Tensor

Sample the noise schedule.

Returns:

Type Description
Tensor

A tensor representing the noise schedule.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def sample_noise_schedule(self) -> Tensor:
    """Sample the noise schedule.

    Returns:
        A tensor representing the noise schedule.
    """
    match self.params.noise_schedule:
        case NoiseSchedule.UNIFORM:
            return 1 - self.sample_power_distribution(1)
        case NoiseSchedule.QUADRATIC:
            return 1 - self.sample_power_distribution(2)
        case NoiseSchedule.KARRAS:
            return 1 - self.sample_power_distribution(7)

sample_power_distribution

sample_power_distribution(power: float = 2) -> Tensor

Sample a power distribution.

Parameters:

Name Type Description Default
power float

The power to use for the distribution.

2

Returns:

Type Description
Tensor

A tensor representing the power distribution between the initial and final diffusion rates of the solver.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
    """Sample a power distribution.

    `num_train_timesteps` points are spaced linearly between the `power`-th
    roots of the initial and final diffusion rates, then raised back to `power`.

    Args:
        power: The power to use for the distribution.

    Returns:
        A tensor representing the power distribution between the initial and final diffusion rates of the solver.
    """
    params = self.params
    root = 1 / power
    roots = torch.linspace(
        start=params.initial_diffusion_rate**root,
        end=params.final_diffusion_rate**root,
        steps=params.num_train_timesteps,
    )
    return roots**power

scale_model_input

scale_model_input(x: Tensor, step: int) -> Tensor

Scale the model's input according to the current timestep.

Note

This method should only be overridden by solvers that need to scale the input according to the current timestep.

By default, this method does not scale the input. (scale=1)

Parameters:

Name Type Description Default
x Tensor

The input tensor to scale.

required
step int

The current step of the diffusion process.

required

Returns:

Type Description
Tensor

The scaled input tensor.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
    """Scale the model's input according to the current timestep.

    Note:
        The base implementation is the identity (scale=1): `x` is returned
        as-is and `step` is ignored. Only solvers that need to scale the input
        according to the current timestep should override it.

    Args:
        x: The input tensor to scale.
        step: The current step of the diffusion process.

    Returns:
        The scaled input tensor (here, `x` itself).
    """
    scaled = x
    return scaled

to

to(
    device: device | str | None = None,
    dtype: dtype | None = None,
) -> Solver

Move the solver to the specified device and data type.

Parameters:

Name Type Description Default
device device | str | None

The PyTorch device to move the solver to.

None
dtype dtype | None

The PyTorch data type to move the solver to.

None

Returns:

Type Description
Solver

The solver instance, moved to the specified device and data type.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
    """Move the solver to the specified device and data type.

    After the base module move, every tensor stored as a plain attribute is
    moved as well. `timesteps` only changes device and keeps its dtype
    (presumably so integer timesteps stay usable as indices).

    Args:
        device: The PyTorch device to move the solver to.
        dtype: The PyTorch data type to move the solver to.

    Returns:
        The solver instance, moved to the specified device and data type.
    """
    super().to(device=device, dtype=dtype)
    # Snapshot first: attributes are reassigned while walking them.
    tensor_attributes = [(key, value) for key, value in self.__dict__.items() if isinstance(value, Tensor)]
    for key, value in tensor_attributes:
        if key == "timesteps":
            moved = value.to(device=device)
        else:
            moved = value.to(device=device, dtype=dtype)
        setattr(self, key, moved)
    return self

SolverParams dataclass

SolverParams(
    *,
    num_train_timesteps: int | None = None,
    timesteps_spacing: TimestepSpacing | None = None,
    timesteps_offset: int | None = None,
    initial_diffusion_rate: float | None = None,
    final_diffusion_rate: float | None = None,
    noise_schedule: NoiseSchedule | None = None,
    sigma_schedule: NoiseSchedule | None = None,
    model_prediction_type: (
        ModelPredictionType | None
    ) = None,
    sde_variance: float = 0.0
)

Bases: BaseSolverParams

Common parameters for solvers.

Parameters:

Name Type Description Default
num_train_timesteps int | None

The number of timesteps used to train the diffusion process.

None
timesteps_spacing TimestepSpacing | None

The spacing to use for the timesteps.

None
timesteps_offset int | None

The offset to use for the timesteps.

None
initial_diffusion_rate float | None

The initial diffusion rate used to sample the noise schedule.

None
final_diffusion_rate float | None

The final diffusion rate used to sample the noise schedule.

None
noise_schedule NoiseSchedule | None

The noise schedule used to sample the noise schedule.

None
model_prediction_type ModelPredictionType | None

Defines what the model predicts.

None

TimestepSpacing

Bases: str, Enum

An enumeration of methods to space the timesteps.

See [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed table 2.

Attributes:

Name Type Description
LINSPACE

Sample N steps with linear interpolation, return a floating-point tensor.

LINSPACE_ROUNDED

Same as LINSPACE but return an integer tensor with rounded timesteps.

LEADING

Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM, with a mitigation for that issue.

TRAILING

Sample N+1 steps, do not include the first timestep.

CUSTOM

Use a custom timestep spacing defined by the solver itself (override _generate_timesteps; see the DPM solver).

SDLoraManager

SDLoraManager(target: LatentDiffusionModel)

Manage LoRAs for a Stable Diffusion model.

Note

In the context of SDLoraManager, a "LoRA" is a set of "LoRA layers" that can be attached to a target model.

Parameters:

Name Type Description Default
target LatentDiffusionModel

The target model to manage the LoRAs for.

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def __init__(
    self,
    target: LatentDiffusionModel,
) -> None:
    """Initialize the LoRA manager.

    Only a reference to the target model is stored; LoRAs are attached to it
    later, e.g. through `add_loras`.

    Args:
        target: The target model to manage the LoRAs for.
    """
    self.target = target

clip_text_encoder property

clip_text_encoder: Chain

The Stable Diffusion's text encoder.

lora_adapters property

lora_adapters: list[LoraAdapter]

List of all the LoraAdapters managed by the SDLoraManager.

loras property

loras: list[Lora[Any]]

List of all the LoRA layers managed by the SDLoraManager.

names property

names: list[str]

List of all the LoRA names managed by the SDLoraManager.

scales property

scales: dict[str, float]

The scales of all the LoRAs managed by the SDLoraManager.

unet property

unet: Chain

The Stable Diffusion's U-Net model.

add_loras

add_loras(
    name: str,
    /,
    tensors: dict[str, Tensor],
    scale: float = 1.0,
    unet_inclusions: list[str] | None = None,
    unet_exclusions: list[str] | None = None,
    unet_preprocess: dict[str, str] | None = None,
    text_encoder_inclusions: list[str] | None = None,
    text_encoder_exclusions: list[str] | None = None,
) -> None

Load a single LoRA from a state_dict.

Warning

This method expects the keys of the state_dict to be in the commonly found formats on CivitAI's hub.

Parameters:

Name Type Description Default
name str

The name of the LoRA.

required
tensors dict[str, Tensor]

The state_dict of the LoRA to load.

required
scale float

The scale to use for the LoRA.

1.0
unet_inclusions list[str] | None

A list of layer names, only layers with such a layer in their ancestors will be considered when patching the UNet.

None
unet_exclusions list[str] | None

A list of layer names, layers with such a layer in their ancestors will not be considered when patching the UNet. If this is None then it defaults to ["TimestepEncoder"].

None
unet_preprocess dict[str, str] | None

A map between parts of state dict keys and layer names. This is used to attach some keys to specific parts of the UNet. You should leave it set to None (it has a default value), otherwise read the source code to understand how it works.

None
text_encoder_inclusions list[str] | None

A list of layer names, only layers with such a layer in their ancestors will be considered when patching the text encoder.

None
text_encoder_exclusions list[str] | None

A list of layer names, layers with such a layer in their ancestors will not be considered when patching the text encoder.

None

Raises:

Type Description
AssertionError

If the Manager already has a LoRA with the same name.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras(
    self,
    name: str,
    /,
    tensors: dict[str, Tensor],
    scale: float = 1.0,
    unet_inclusions: list[str] | None = None,
    unet_exclusions: list[str] | None = None,
    unet_preprocess: dict[str, str] | None = None,
    text_encoder_inclusions: list[str] | None = None,
    text_encoder_exclusions: list[str] | None = None,
) -> None:
    """Load a single LoRA from a `state_dict`.

    Warning:
        This method expects the keys of the `state_dict` to be in the commonly found formats on CivitAI's hub.

    Args:
        name: The name of the LoRA.
        tensors: The `state_dict` of the LoRA to load.
        scale: The scale to use for the LoRA.
        unet_inclusions: A list of layer names, only layers with such a layer
            in their ancestors will be considered when patching the UNet.
        unet_exclusions: A list of layer names, layers with such a layer in
            their ancestors will not be considered when patching the UNet.
            If this is `None` then it defaults to `["TimestepEncoder"]`.
        unet_preprocess: A map between parts of state dict keys and layer names.
            This is used to attach some keys to specific parts of the UNet.
            You should leave it set to `None` (it has a default value),
            otherwise read the source code to understand how it works.
        text_encoder_inclusions: A list of layer names, only layers with such a layer
            in their ancestors will be considered when patching the text encoder.
        text_encoder_exclusions: A list of layer names, layers with such a layer in
            their ancestors will not be considered when patching the text encoder.

    Raises:
        AssertionError: If the Manager already has a LoRA with the same name.
    """
    assert name not in self.names, f"LoRA {name} already exists"

    # Build the LoRA layers from tensors moved to the target's device and dtype.
    converted = {
        key: value.to(device=self.target.device, dtype=self.target.dtype)
        for key, value in tensors.items()
    }
    loras = Lora.from_dict(name, state_dict=converted)

    # Reorder so closely related layers (q/k/v, in/out) come in the expected order.
    ordered_keys = sorted(loras.keys(), key=SDLoraManager.sort_keys)
    loras = {key: loras[key] for key in ordered_keys}

    # Keys mentioning neither "unet" nor "text" are all assumed to target the U-Net.
    if not any("unet" in key or "text" in key for key in loras):
        loras = {f"unet_{key}": value for key, value in loras.items()}

    self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)
    self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)

    self.set_scale(name, scale)

add_loras_to_text_encoder

add_loras_to_text_encoder(
    loras: dict[str, Lora[Any]],
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    debug_map: list[tuple[str, str]] | None = None,
) -> None

Add multiple LoRAs to the text encoder. See add_loras for details about arguments.

Parameters:

Name Type Description Default
loras dict[str, Lora[Any]]

The dictionary of LoRAs to add to the text encoder. (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_text_encoder(
    self,
    loras: dict[str, Lora[Any]],
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    debug_map: list[tuple[str, str]] | None = None,
) -> None:
    """Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.

    Only entries whose key contains "text" are kept; they are attached to
    `self.clip_text_encoder` with `auto_attach_loras`.

    Args:
        loras: The dictionary of LoRAs to add to the text encoder.
            (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
        include: Layer names; only layers with such a layer in their ancestors are patched.
        exclude: Layer names; layers with such a layer in their ancestors are not patched.
        debug_map: Forwarded unchanged to `auto_attach_loras`.
    """
    selected = {key: lora for key, lora in loras.items() if "text" in key}
    auto_attach_loras(
        selected,
        self.clip_text_encoder,
        include=include,
        exclude=exclude,
        debug_map=debug_map,
    )

add_loras_to_unet

add_loras_to_unet(
    loras: dict[str, Lora[Any]],
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    preprocess: dict[str, str] | None = None,
    debug_map: list[tuple[str, str]] | None = None,
) -> None

Add multiple LoRAs to the U-Net. See add_loras for details about arguments.

Parameters:

Name Type Description Default
loras dict[str, Lora[Any]]

The dictionary of LoRAs to add to the U-Net. (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_unet(
    self,
    loras: dict[str, Lora[Any]],
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    preprocess: dict[str, str] | None = None,
    debug_map: list[tuple[str, str]] | None = None,
) -> None:
    """Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.

    Only entries whose key contains "unet" are considered. Keys containing one
    of the `preprocess` fragments are attached with the matching layer name as
    the sole inclusion; all other keys are attached with those layer names
    added to the exclusions.

    Args:
        loras: The dictionary of LoRAs to add to the U-Net.
            (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
        include: Layer names; only layers with such a layer in their ancestors are patched.
        exclude: Layer names; layers with such a layer in their ancestors are not patched.
            Defaults to `["TimestepEncoder"]`.
        preprocess: Maps a fragment of the state dict keys to a layer name. Defaults to
            `{"res": "ResidualBlock", "downsample": "Downsample", "upsample": "Upsample"}`.
        debug_map: Forwarded unchanged to `auto_attach_loras`.
    """
    unet_loras = {key: lora for key, lora in loras.items() if "unet" in key}

    exclude = ["TimestepEncoder"] if exclude is None else exclude
    if preprocess is None:
        preprocess = {"res": "ResidualBlock", "downsample": "Downsample", "upsample": "Upsample"}

    # Keep only the preprocess target layers allowed by `include` and `exclude`.
    preprocess = {
        fragment: layer
        for fragment, layer in preprocess.items()
        if (include is None or layer in include) and layer not in exclude
    }

    routed = {key: lora for key, lora in unet_loras.items() if any(fragment in key for fragment in preprocess)}
    remaining = {key: lora for key, lora in unet_loras.items() if key not in routed}

    for fragment, layer in preprocess.items():
        auto_attach_loras(
            {key: lora for key, lora in routed.items() if fragment in key},
            self.unet,
            include=[layer],
            exclude=exclude,
            debug_map=debug_map,
        )

    auto_attach_loras(
        remaining,
        self.unet,
        exclude=[*exclude, *preprocess.values()],
        include=include,
        debug_map=debug_map,
    )

get_loras_by_name

get_loras_by_name(name: str) -> list[Lora[Any]]

Get the LoRA layers with the given name.

Parameters:

Name Type Description Default
name str

The name of the LoRA.

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
    """Get the LoRA layers with the given name.

    Args:
        name: The name of the LoRA.

    Returns:
        The layers of `self.loras` whose `name` equals `name`, in their original
        order (an empty list if none match).
    """
    return list(filter(lambda lora: lora.name == name, self.loras))

get_scale

get_scale(name: str) -> float

Get the scale of the LoRA with the given name.

Parameters:

Name Type Description Default
name str

The name of the LoRA.

required

Returns:

Type Description
float

The scale of the LoRA layers with the given name.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_scale(self, name: str, /) -> float:
    """Get the scale of the LoRA with the given name.

    Args:
        name: The name of the LoRA.

    Returns:
        The scale of the LoRA layers with the given name.

    Raises:
        AssertionError: If no LoRA layer has this name, or if the layers with
            this name do not all share the same scale.
    """
    loras = self.get_loras_by_name(name)
    # Without this, an unknown name surfaces as a bare IndexError on `loras[0]`.
    assert loras, f"LoRA {name} not found"
    scale = loras[0].scale
    assert all(lora.scale == scale for lora in loras), "lora scales are not all the same"
    return scale

remove_all

remove_all() -> None

Remove all the LoRAs from the target.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_all(self) -> None:
    """Remove all the LoRAs from the target.

    Every adapter returned by `self.lora_adapters` is ejected.
    """
    adapters = self.lora_adapters
    for adapter in adapters:
        adapter.eject()

remove_loras

remove_loras(*names: str) -> None

Remove multiple LoRAs from the target.

Parameters:

Name Type Description Default
names str

The names of the LoRAs to remove.

()
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_loras(self, *names: str) -> None:
    """Remove multiple LoRAs from the target.

    Each adapter is asked to drop every given name; an adapter left with no
    LoRA afterwards is ejected.

    Args:
        names: The names of the LoRAs to remove.
    """
    for adapter in self.lora_adapters:
        for lora_name in names:
            adapter.remove_lora(lora_name)
        now_empty = len(adapter.loras) == 0
        if now_empty:
            adapter.eject()

set_scale

set_scale(name: str, scale: float) -> None

Set the scale of the LoRA with the given name.

Parameters:

Name Type Description Default
name str

The name of the LoRA.

required
scale float

The new scale to set.

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def set_scale(self, name: str, scale: float, /) -> None:
    """Set the scale of the LoRA with the given name.

    Shortcut for `update_scales` with a single entry.

    Args:
        name: The name of the LoRA.
        scale: The new scale to set.
    """
    single_update = {name: scale}
    self.update_scales(single_update)

sort_keys staticmethod

sort_keys(key: str) -> tuple[str, int]

Compute the score of a key, relatively to its suffix.

When used by sorted, the keys will only be sorted "at the suffix level". The idea is that sometimes closely related keys in the state dict are not in the same order as the one we expect, for instance q -> k -> v or in -> out. This attempts to fix that issue, not cases where distant layers are called in a different order.

Parameters:

Name Type Description Default
key str

The key to sort.

required

Returns:

Type Description
str

The padded prefix of the key.

int

A score depending on the key's suffix.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
    """Compute the score of a key, relatively to its suffix.

    When used by [`sorted`][sorted], the keys will only be sorted "at the suffix level".
    The idea is that sometimes closely related keys in the state dict are not in the
    same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This
    attempts to fix that issue, not cases where distant layers are called in a different
    order.

    Args:
        key: The key to sort.

    Returns:
        The padded prefix of the key.
        A score depending on the key's suffix.
    """
    # this dict might not be exhaustive
    suffix_scores = {"q": 1, "k": 2, "v": 3, "in": 3, "out": 4, "out0": 4, "out_0": 4}
    patterns = ["_{}", "_{}_lora"]

    # Every concrete suffix ("_q", "_q_lora", "_k", ...) with its score, in lookup order.
    candidates = [
        (pattern.format(base), base_score) for base, base_score in suffix_scores.items() for pattern in patterns
    ]

    # First matching suffix wins; without a match: no suffix and the highest score (5).
    suffix, score = "", 5
    for candidate, candidate_score in candidates:
        if key.endswith(candidate):
            suffix, score = candidate, candidate_score
            break

    return (SDLoraManager._pad(key.removesuffix(suffix)), score)

update_scales

update_scales(scales: dict[str, float]) -> None

Update the scales of multiple LoRAs.

Parameters:

Name Type Description Default
scales dict[str, float]

The scales to update. (keys are the names of the LoRAs, values are the new scales to set)

required
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def update_scales(self, scales: dict[str, float], /) -> None:
    """Update the scales of multiple LoRAs.

    Every layer of each named LoRA receives the new scale.

    Args:
        scales: The scales to update.
            (keys are the names of the LoRAs, values are the new scales to set)

    Raises:
        AssertionError: If a key of `scales` is not a managed LoRA name.
    """
    unknown = [lora_name for lora_name in scales if lora_name not in self.names]
    assert not unknown, f"Scales keys must be a subset of {self.names}"
    for lora_name, new_scale in scales.items():
        layers = self.get_loras_by_name(lora_name)
        for layer in layers:
            layer.scale = new_scale

IPAdapter

IPAdapter(
    target: T,
    clip_image_encoder: CLIPImageEncoderH,
    image_proj: Module,
    scale: float = 1.0,
    fine_grained: bool = False,
    weights: dict[str, Tensor] | None = None,
)

Bases: Generic[T], Chain, Adapter[T]

Image Prompt adapter for a Stable Diffusion U-Net model.

See [arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models for more details.

Parameters:

Name Type Description Default
target T

The target model to adapt.

required
clip_image_encoder CLIPImageEncoderH

The CLIP image encoder to use.

required
image_proj Module

The image projection to use.

required
scale float

The scale to use for the image prompt.

1.0
fine_grained bool

Whether to use fine-grained image prompt.

False
weights dict[str, Tensor] | None

The weights of the IPAdapter.

None
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def __init__(
    self,
    target: T,
    clip_image_encoder: CLIPImageEncoderH,
    image_proj: fl.Module,
    scale: float = 1.0,
    fine_grained: bool = False,
    weights: dict[str, Tensor] | None = None,
) -> None:
    """Initialize the adapter.

    One `CrossAttentionAdapter` is created for every `fl.Attention` layer of the
    target whose type is not exactly `fl.SelfAttention`. When `weights` are
    given, the `image_proj.*` entries are loaded into the image projection and
    the `ip_adapter.{index:03d}.*` entries (exactly two per sub-adapter) into
    each cross-attention adapter.

    Args:
        target: The target model to adapt.
        clip_image_encoder: The CLIP image encoder to use.
        image_proj: The image projection to use.
        scale: The scale to use for the image prompt.
        fine_grained: Whether to use fine-grained image prompt.
        weights: The weights of the IPAdapter.
    """
    with self.setup_adapter(target):
        super().__init__(target)

    self.fine_grained = fine_grained
    # Stored inside single-element lists (presumably so they are not registered
    # as child modules of this chain; TODO confirm).
    self._clip_image_encoder = [clip_image_encoder]
    if fine_grained:
        self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
    self._image_proj = [image_proj]

    # Exact type comparison: subclasses of fl.SelfAttention are still adapted.
    self.sub_adapters = [
        CrossAttentionAdapter(target=attention, scale=scale)
        for attention in target.layers(fl.Attention)
        if type(attention) != fl.SelfAttention
    ]

    if weights is None:
        return

    image_proj_state_dict: dict[str, Tensor] = {
        key.removeprefix("image_proj."): value for key, value in weights.items() if key.startswith("image_proj.")
    }
    self.image_proj.load_state_dict(image_proj_state_dict)

    for index, adapter in enumerate(self.sub_adapters):
        prefix = f"ip_adapter.{index:03d}."
        # Passed positionally, in the order the keys appear in `weights`.
        adapter_weights = [value for key, value in weights.items() if key.startswith(prefix)]
        assert len(adapter_weights) == 2
        adapter.load_weights(*adapter_weights)

clip_image_encoder property

clip_image_encoder: CLIPImageEncoderH

The CLIP image encoder of the adapter.

scale property writable

scale: float

The scale of the adapter.

compute_clip_image_embedding

compute_clip_image_embedding(
    image_prompt: Tensor | Image | list[Image],
    weights: list[float] | None = None,
    concat_batches: bool = True,
) -> Tensor

Compute the CLIP image embedding.

Parameters:

Name Type Description Default
image_prompt Tensor | Image | list[Image]

The image prompt to use.

required
weights list[float] | None

Optional per-image weights multiplied into the conditional image embeddings; one value per image in the batch.

None
concat_batches bool

Whether to concatenate the batches.

True

Returns:

Type Description
Tensor

The CLIP image embedding.

Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def compute_clip_image_embedding(
    self,
    image_prompt: Tensor | Image.Image | list[Image.Image],
    weights: list[float] | None = None,
    concat_batches: bool = True,
) -> Tensor:
    """Compute the CLIP image embedding of one or several image prompts.

    Args:
        image_prompt: A single PIL image, a list of PIL images, or an already preprocessed tensor.
        weights: Optional per-image weights (one per image) multiplied into the conditional embedding.
        concat_batches: When several images are given, whether to merge them into a single, longer
            token sequence instead of keeping them as separate batch entries.

    Returns:
        The negative and conditional embeddings concatenated along the first dimension.
    """
    # Normalize the input into a single preprocessed tensor.
    if isinstance(image_prompt, Image.Image):
        image_prompt = self.preprocess_image(image_prompt)
    elif isinstance(image_prompt, list):
        assert all(isinstance(image, Image.Image) for image in image_prompt)
        preprocessed = [self.preprocess_image(image) for image in image_prompt]
        image_prompt = torch.cat(preprocessed)

    negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
    batch_size = image_prompt.shape[0]

    if weights is not None:
        assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
        # Skip the multiplication entirely when every weight is neutral.
        if not all(weight == 1.0 for weight in weights):
            per_image_scale = torch.tensor(
                weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype
            )
            # Shape (batch,) -> (batch, 1, 1) so it broadcasts over the remaining two dimensions.
            conditional_embedding *= per_image_scale[:, None, None]

    if concat_batches and batch_size > 1:
        # Turn a batch of images into one longer image-token sequence.
        # See https://github.com/tencent-ailab/IP-Adapter/issues/99
        negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
        conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)

    return torch.cat((negative_embedding, conditional_embedding))

preprocess_image

preprocess_image(
    image: Image,
    size: tuple[int, int] = (224, 224),
    mean: list[float] | None = None,
    std: list[float] | None = None,
) -> Tensor

Preprocess the image.

Note

The default mean and std are parameters from https://github.com/openai/CLIP

Parameters:

Name Type Description Default
image Image

The image to preprocess.

required
size tuple[int, int]

The size to resize the image to.

(224, 224)
mean list[float] | None

The mean to use for normalization.

None
std list[float] | None

The standard deviation to use for normalization.

None
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def preprocess_image(
    self,
    image: Image.Image,
    size: tuple[int, int] = (224, 224),
    mean: list[float] | None = None,
    std: list[float] | None = None,
) -> Tensor:
    """Resize and normalize an image so it can be fed to the CLIP image encoder.

    Note:
        When `mean` or `std` is omitted, the normalization constants of
        https://github.com/openai/CLIP are used.

    Args:
        image: The image to preprocess.
        size: The size to resize the image to.
        mean: The per-channel mean to use for normalization.
        std: The per-channel standard deviation to use for normalization.

    Returns:
        The normalized image tensor, on the device and with the dtype of `self.target`.
    """
    if mean is None:
        mean = [0.48145466, 0.4578275, 0.40821073]
    if std is None:
        std = [0.26862954, 0.26130258, 0.27577711]

    image_tensor = image_to_tensor(
        image.resize(size),  # type: ignore
        device=self.target.device,
        dtype=self.target.dtype,
    )
    return normalize(image_tensor, mean=mean, std=std)

set_clip_image_embedding

set_clip_image_embedding(image_embedding: Tensor) -> None

Set the CLIP image embedding context.

Note

This is required by ImageCrossAttention.

Parameters:

Name Type Description Default
image_embedding Tensor

The CLIP image embedding to set.

required
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
    """Store the CLIP image embedding in the `ip_adapter` context.

    Note:
        `ImageCrossAttention` relies on this context entry.

    Args:
        image_embedding: The CLIP image embedding to store under the `clip_image_embedding` key.
    """
    ip_adapter_context = {"clip_image_embedding": image_embedding}
    self.set_context("ip_adapter", ip_adapter_context)

AdaIN

AdaIN(epsilon: float = 1e-08)

Bases: Module

Apply Adaptive Instance Normalization (AdaIN) to the target features.

See [arXiv:1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization for more details.

Receives:

Name Type Description
reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features.

targets Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The target features.

Returns:

Name Type Description
reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features (unchanged).

targets Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The target features, renormalized.

Parameters:

Name Type Description Default
epsilon float

A small value to avoid division by zero.

1e-08
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(self, epsilon: float = 1e-8) -> None:
    """Initialize the AdaIN module.

    Args:
        epsilon: A small value to avoid division by zero; stored as `self.epsilon`.
    """
    # The parent constructor runs before any attribute is assigned on the module.
    super().__init__()
    self.epsilon = epsilon

ExtractReferenceFeatures

ExtractReferenceFeatures(*args: Any, **kwargs: Any)

Bases: Module

Extract the reference features from the input features.

Note

This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using Classifier-free guidance (CFG).

The reference features are the first features of the conditional and unconditional input features. They are extracted and repeated to match the batch size of the input features.

Receives:

Name Type Description
features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The input features.

Returns:

Name Type Description
reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the module, forwarding every argument to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to the parent constructor.
        **kwargs: Keyword arguments forwarded unchanged to the parent constructor.
    """
    # Use `**kwargs`, not `*kwargs`: a single star would unpack only the dict's keys and pass them as
    # extra positional arguments instead of forwarding the keyword arguments.
    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]

ScaleReferenceFeatures

ScaleReferenceFeatures(scale: float = 1.0)

Bases: Module

Scale the reference features.

Note

This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using Classifier-free guidance (CFG).

This layer scales the reference features, which will later be used (in the attention dot product) with the target features.

Receives:

Name Type Description
features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The input reference features.

Returns:

Name Type Description
features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The rescaled reference features.

Parameters:

Name Type Description Default
scale float

The scaling factor.

1.0
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(
    self,
    scale: float = 1.0,
) -> None:
    """Initialize the ScaleReferenceFeatures module.

    Args:
        scale: The scaling factor; stored as `self.scale`.
    """
    # The parent constructor runs before any attribute is assigned on the module.
    super().__init__()
    self.scale = scale

SharedSelfAttentionAdapter

SharedSelfAttentionAdapter(
    target: SelfAttention, scale: float = 1.0
)

Bases: Chain, Adapter[SelfAttention]

Upgrades a SelfAttention layer into a SharedSelfAttention layer.

This adapter inserts 3 StyleAligned modules right after the original Q, K, V Linear layers (wrapped inside a fl.Distribute).

Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(
    self,
    target: fl.SelfAttention,
    scale: float = 1.0,
) -> None:
    """Initialize the adapter around a SelfAttention layer.

    Args:
        target: The SelfAttention layer to adapt.
        scale: The scaling factor given to each of the three StyleAligned layers.
    """
    with self.setup_adapter(target):
        super().__init__(target)

    # (adain, concatenate) flags for the query, key and value StyleAligned layers, in that order.
    qkv_flags = (
        (True, False),  # Query
        (True, True),  # Key
        (False, True),  # Value
    )
    self._style_aligned_layers = [
        StyleAligned(adain=adain, concatenate=concatenate, scale=scale) for adain, concatenate in qkv_flags
    ]

StyleAligned

StyleAligned(
    adain: bool, concatenate: bool, scale: float = 1.0
)

Bases: Chain

StyleAligned module.

This layer encapsulates the logic of the StyleAligned method, as described in [arXiv:2312.02133] Style Aligned Image Generation via Shared Attention.

See also https://blog.finegrain.ai/posts/implementing-style-aligned/.

Receives:

Name Type Description
features Float[Tensor, 'cfg_batch_size sequence_length_in embedding_dim']

The input features.

Returns:

Name Type Description
shared_features Float[Tensor, 'cfg_batch_size sequence_length_out embedding_dim']

The transformed features.

Parameters:

Name Type Description Default
adain bool

Whether to apply Adaptive Instance Normalization to the target features.

required
scale float

The scaling factor for the reference features.

1.0
concatenate bool

Whether to concatenate the reference and target features.

required
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(
    self,
    adain: bool,
    concatenate: bool,
    scale: float = 1.0,
) -> None:
    """Initialize the StyleAligned module.

    The chain is first assembled with all of its stages and then trimmed according to the flags.

    Args:
        adain: Whether to apply Adaptive Instance Normalization to the target features.
        concatenate: Whether to concatenate the reference and target features.
        scale: The scaling factor for the reference features.
    """
    stages = (
        # (features): (cfg_batch_size sequence_length embedding_dim) -> (targets, reference)
        fl.Parallel(
            fl.Identity(),
            ExtractReferenceFeatures(),
        ),
        # (targets, reference) -> (targets_renormalized, reference)
        AdaIN(),
        # (targets_renormalized, reference) -> (targets_renormalized, reference_scaled)
        fl.Distribute(
            fl.Identity(),
            ScaleReferenceFeatures(scale=scale),
        ),
        # (targets_renormalized, reference_scaled) -> (features_with_shared_reference)
        fl.Concatenate(
            fl.GetArg(index=0),  # targets
            fl.GetArg(index=1),  # reference
            dim=-2,  # sequence_length
        ),
    )
    super().__init__(*stages)

    if not adain:
        # Skip the AdaIN renormalization of the targets.
        self.remove(self.ensure_find(AdaIN))

    if not concatenate:
        # Output only the targets instead of concatenating them with the reference.
        self.replace(
            old_module=self.ensure_find(fl.Concatenate),
            new_module=fl.GetArg(index=0),  # targets
        )

scale property writable

scale: float

The scaling factor for the reference features.

StyleAlignedAdapter

StyleAlignedAdapter(target: T, scale: float = 1.0)

Bases: Generic[T], Chain, Adapter[T]

Upgrade each SelfAttention layer of a UNet into a SharedSelfAttention layer.

Parameters:

Name Type Description Default
target T

The target module.

required
scale float

The scaling factor for the reference features.

1.0
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(
    self,
    target: T,
    scale: float = 1.0,
) -> None:
    """Initialize the StyleAlignedAdapter.

    One `SharedSelfAttentionAdapter` is created for every `fl.SelfAttention` layer yielded by
    `self.target.layers`, and they are stored in that order in `self.shared_self_attention_adapters`.

    Args:
        target: The target module (e.g. a UNet).
        scale: The scaling factor for the reference features, shared by every created adapter.
    """
    with self.setup_adapter(target):
        super().__init__(target)

    self_attention_layers = self.target.layers(fl.SelfAttention)
    self.shared_self_attention_adapters = tuple(
        SharedSelfAttentionAdapter(target=layer, scale=scale) for layer in self_attention_layers
    )

scale property writable

scale: float

The scaling factor for the reference features.

DiffusionTarget dataclass

DiffusionTarget(
    *,
    tile: Tile,
    solver: Solver,
    init_latents: Tensor | None = None,
    opacity_mask: Tensor | None = None,
    weight: int = 1,
    start_step: int = 0,
    end_step: int = MAX_STEPS
)

Represents a target for the tiled diffusion process.

This class encapsulates the parameters and properties needed to define a specific area (target) within a larger diffusion process, allowing for fine-grained control over different regions of the generated image.

Attributes:

Name Type Description
tile Tile

The tile defining the area of the target within the latent image.

solver Solver

The solver to use for this target's diffusion process. This is useful because some solvers have an internal state that needs to be updated during the diffusion process. Using the same solver instance for multiple targets would interfere with this internal state.

init_latents Tensor | None

The initial latents for this target. If None, the target will be initialized with noise.

opacity_mask Tensor | None

Mask controlling the target's visibility in the final image. If None, the target will be fully visible. Otherwise, 1 means fully opaque and 0 means fully transparent, which means the target has no influence.

weight int

The importance of this target in the final image. Higher values increase the target's influence.

start_step int

The diffusion step at which this target begins to influence the process.

end_step int

The diffusion step at which this target stops influencing the process.

size Size

The size of the target area.

offset tuple[int, int]

The top-left offset of the target area within the latent image.

The combination of opacity_mask and weight determines the target's overall contribution to the final generated image. The solver is responsible for the actual diffusion calculations for this target.

MultiDiffusion

Bases: ABC, Generic[T]

MultiDiffusion class for performing multi-target diffusion using tiled diffusion.

For more details, refer to the paper: MultiDiffusion

generate_latent_tiles staticmethod

generate_latent_tiles(
    size: Size, tile_size: Size, min_overlap: int = 8
) -> list[Tile]

Generate tiles for a latent image with the given size and tile size.

If one dimension of the tile_size is larger than the corresponding dimension of the image size, a single tile is used to cover the entire image - and therefore tile_size is ignored. This algorithm ensures that the tile size is respected as much as possible, while still covering the entire image and respecting the minimum overlap.

Source code in src/refiners/foundationals/latent_diffusion/multi_diffusion.py
@staticmethod
def generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:
    """
    Generate tiles for a latent image with the given size and tile size.

    If one dimension of `tile_size` exceeds the corresponding dimension of `size`, a single tile covering the whole
    image is returned and `tile_size` is ignored. Otherwise, each axis gets the smallest number of full-sized tiles
    whose overlap is at least `min_overlap`, and the last tile on each axis ends exactly at the image border.

    Args:
        size: The size of the latent image.
        tile_size: The desired size of each tile.
        min_overlap: The minimum overlap between neighboring tiles.

    Returns:
        The tiles in row-major order (top to bottom, then left to right within a row).
    """
    assert (
        0 <= min_overlap < min(tile_size.height, tile_size.width)
    ), "Overlap must be non-negative and less than the tile size"

    if tile_size.width > size.width or tile_size.height > size.height:
        return [Tile(top=0, left=0, bottom=size.height, right=size.width)]

    def _tile_offsets(length: int, tile_length: int) -> list[int]:
        # Start offsets of the tiles along one axis; a tile that would overflow is shifted back inside the image.
        if tile_length >= length:
            count, overlap = 1, 0
        else:
            count = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
            overlap = (count * tile_length - length) // (count - 1)
        stride = tile_length - overlap
        return [min(k * stride, length - tile_length) for k in range(count)]

    lefts = _tile_offsets(size.width, tile_size.width)
    tops = _tile_offsets(size.height, tile_size.height)

    return [
        Tile(top=top, left=left, bottom=top + tile_size.height, right=left + tile_size.width)
        for top in tops
        for left in lefts
    ]

ELLA

ELLA(
    time_channel: int,
    timestep_embedding_dim: int,
    width: int,
    num_layers: int,
    num_heads: int,
    num_latents: int,
    input_dim: int | None = None,
    out_dim: int | None = None,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Passthrough

ELLA latents encoder.

See [arXiv:2403.05135] ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment for more details.

Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(
    self,
    time_channel: int,
    timestep_embedding_dim: int,
    width: int,
    num_layers: int,
    num_heads: int,
    num_latents: int,
    input_dim: int | None = None,
    out_dim: int | None = None,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the ELLA latents encoder.

    The chain runs, in order: a `TimestepEncoder`, `fl.UseContext("adapted_cross_attention_block",
    "llm_text_embedding")`, a `PerceiverResampler`, and `fl.SetContext("ella", "latents")`.

    Args:
        time_channel: Forwarded to the `TimestepEncoder`.
        timestep_embedding_dim: Forwarded to both the `TimestepEncoder` and the `PerceiverResampler`.
        width: Forwarded to the `PerceiverResampler`.
        num_layers: Forwarded to the `PerceiverResampler`.
        num_heads: Forwarded to the `PerceiverResampler`.
        num_latents: Forwarded to the `PerceiverResampler`.
        input_dim: Forwarded to the `PerceiverResampler`.
        out_dim: Forwarded to the `PerceiverResampler`.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.
    """
    # Sub-modules are built in the same order as they appear in the chain.
    timestep_encoder = TimestepEncoder(timestep_embedding_dim, time_channel, device=device, dtype=dtype)
    read_text_embedding = fl.UseContext("adapted_cross_attention_block", "llm_text_embedding")
    # NOTE(review): `out_dim` is passed positionally before `input_dim`; presumably this matches
    # PerceiverResampler's parameter order. Verify against its signature.
    resampler = PerceiverResampler(
        timestep_embedding_dim,
        width,
        num_layers,
        num_heads,
        num_latents,
        out_dim,
        input_dim,
        device=device,
        dtype=dtype,
    )
    write_latents = fl.SetContext("ella", "latents")
    super().__init__(timestep_encoder, read_text_embedding, resampler, write_latents)

ELLAAdapter

ELLAAdapter(
    target: T,
    latents_encoder: ELLA,
    weights: dict[str, Tensor] | None = None,
)

Bases: Generic[T], Chain, Adapter[T]

Adapter for ELLA.

Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:
    """Initialize the ELLA adapter.

    Args:
        target: The module to adapt.
        latents_encoder: The ELLA latents encoder used by this adapter.
        weights: Optional state dict loaded into `latents_encoder` before anything else is set up.
    """
    if weights is not None:
        latents_encoder.load_state_dict(weights)

    # Kept inside a plain list: torch only registers attributes that are Module instances as submodules.
    self._latents_encoder = [latents_encoder]
    with self.setup_adapter(target):
        super().__init__(target)

    # One ELLACrossAttentionAdapter per UseContext found in each CrossAttentionBlock of the target.
    self.sub_adapters = [
        ELLACrossAttentionAdapter(use_context)
        for block in target.layers(CrossAttentionBlock)
        for use_context in block.layers(fl.UseContext)
    ]