{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"A PyTorch microframework for foundation model adaptation","text":"

The simplest way to train and run adapters on top of foundation models.

In the era of foundation models, adaptation is quickly rising as the method of choice for bridging the last mile quality gap. We couldn't find a framework with first class citizen APIs for foundation model adaptation, so we created one. It's called Refiners, and we're building it on top of PyTorch, in the open, under the MIT License. Read our manifesto.

"},{"location":"concepts/chain/","title":"Chain","text":"

When we say models are implemented in a declarative way in Refiners, what this means in practice is they are implemented as Chains. Chain is a Python class to implement trees of modules. It is a subclass of Refiners' Module, which is in turn a subclass of PyTorch's Module. All inner nodes of a Chain are subclasses of Chain, and leaf nodes are subclasses of Refiners' Module.

"},{"location":"concepts/chain/#a-first-example","title":"A first example","text":"

To give you an idea of how it looks, let us take a simple convolutional network to classify MNIST as an example. First, let us define a few variables.

img_res = 28\nchannels = 128\nkernel_size = 3\nhidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels\nhidden_layer_out = 200\noutput_size = 10\n
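
As a quick sanity check of the arithmetic above, here is the same computation in plain Python with the intermediate values spelled out. The names conv_out and pooled are introduced here for illustration only, and the sketch assumes a convolution without padding followed by a 2x2 max pooling, as in the model defined next.

```python
img_res = 28
channels = 128
kernel_size = 3

# A 3x3 convolution without padding shrinks each side by kernel_size - 1.
conv_out = img_res - kernel_size + 1  # 26
# A 2x2 max pooling with stride 2 halves each side (rounding down).
pooled = conv_out // 2  # 13
# Flattening yields one feature per spatial position and per channel.
hidden_layer_in = (pooled ** 2) * channels

print(conv_out, pooled, hidden_layer_in)  # 26 13 21632
```

This is the in_features value that the first Linear layer reports when the model is printed.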

Now, here is the model in PyTorch:

from torch import nn\n\nclass BasicModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = nn.Conv2d(1, channels, kernel_size)\n        self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)\n        self.maxpool = nn.MaxPool2d(2)\n        self.linear_2 = nn.Linear(hidden_layer_out, output_size)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = nn.functional.relu(x)\n        x = self.maxpool(x)\n        x = x.flatten(start_dim=1)\n        x = self.linear_1(x)\n        x = nn.functional.relu(x)\n        x = self.linear_2(x)\n        return nn.functional.softmax(x, dim=0)\n

And here is how we could implement the same model in Refiners:

import torch\n\nimport refiners.fluxion.layers as fl\n\nclass BasicModel(fl.Chain):\n    def __init__(self):\n        super().__init__(\n            fl.Conv2d(1, channels, kernel_size),\n            fl.ReLU(),\n            fl.MaxPool2d(2),\n            fl.Flatten(start_dim=1),\n            fl.Linear(hidden_layer_in, hidden_layer_out),\n            fl.ReLU(),\n            fl.Linear(hidden_layer_out, output_size),\n            fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=0)),\n        )\n

Note

We often use the namespace fl which means fluxion, which is the name of the part of Refiners that implements basic layers.

As of writing, Refiners does not include a Softmax layer by default, but as you can see you can easily call arbitrary code using fl.Lambda. Alternatively, if you just wanted to write Softmax(), you could implement it like this:

class Softmax(fl.Module):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.softmax(x, dim=0)\n

Note

Notice the type hints here. All of Refiners' codebase is typed, which makes it a pleasure to use if your downstream code is typed too.

"},{"location":"concepts/chain/#inspecting-and-manipulating","title":"Inspecting and manipulating","text":"

Let us instantiate the BasicModel we just defined and inspect its representation in a Python REPL:

>>> m = BasicModel()\n>>> m\n(CHAIN) BasicModel()\n    \u251c\u2500\u2500 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n    \u251c\u2500\u2500 ReLU() #1\n    \u251c\u2500\u2500 MaxPool2d(kernel_size=2, stride=2)\n    \u251c\u2500\u2500 Flatten(start_dim=1)\n    \u251c\u2500\u2500 Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1\n    \u251c\u2500\u2500 ReLU() #2\n    \u251c\u2500\u2500 Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2\n    \u2514\u2500\u2500 Softmax()\n

The children of a Chain are stored in a dictionary and can be accessed by name or index. When layers of the same type appear in the Chain, distinct suffixed keys are automatically generated.
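
To illustrate the naming rule, here is a minimal sketch in plain Python of how such keys could be derived from a list of type names. The helper suffixed_keys is hypothetical and written for this page only; it is not the code Refiners uses, and it assumes that unique types keep their bare name while repeated types get a 1-based suffix.

```python
from collections import Counter


def suffixed_keys(type_names):
    # Count how many times each type name appears in the chain.
    totals = Counter(type_names)
    seen = Counter()
    keys = []
    for name in type_names:
        if totals[name] == 1:
            # Unique types keep their bare name.
            keys.append(name)
        else:
            # Repeated types get a 1-based numeric suffix.
            seen[name] += 1
            keys.append(f'{name}_{seen[name]}')
    return keys


layers = ['Conv2d', 'ReLU', 'MaxPool2d', 'Flatten', 'Linear', 'ReLU', 'Linear', 'Softmax']
keys = suffixed_keys(layers)
print(keys[0], keys[6])  # Conv2d Linear_2
```

Under these assumptions, the child at index 6 ends up with the key Linear_2.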

>>> m[0]\nConv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n>>> m.Conv2d\nConv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n>>> m[6]\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n>>> m.Linear_2\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n

The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by wrapping each layer of the convnet in a subchain. Here is how I could do it:

class ConvLayer(fl.Chain):\n    pass\n\nclass HiddenLayer(fl.Chain):\n    pass\n\nclass OutputLayer(fl.Chain):\n    pass\n\nm.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))\nm.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))\nm.append(OutputLayer(m.pop(2), m.pop(2)))\n

Did it work? Let's see:

>>> m\n(CHAIN) BasicModel()\n    \u251c\u2500\u2500 (CHAIN) ConvLayer()\n    \u2502   \u251c\u2500\u2500 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n    \u2502   \u251c\u2500\u2500 ReLU()\n    \u2502   \u2514\u2500\u2500 MaxPool2d(kernel_size=2, stride=2)\n    \u251c\u2500\u2500 (CHAIN) HiddenLayer()\n    \u2502   \u251c\u2500\u2500 Flatten(start_dim=1)\n    \u2502   \u251c\u2500\u2500 Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)\n    \u2502   \u2514\u2500\u2500 ReLU()\n    \u2514\u2500\u2500 (CHAIN) OutputLayer()\n        \u251c\u2500\u2500 Linear(in_features=200, out_features=10, device=cpu, dtype=float32)\n        \u2514\u2500\u2500 Softmax()\n

Note

Organizing models like this is actually a good idea: it makes them easier to understand and adapt.

"},{"location":"concepts/chain/#accessing-and-iterating","title":"Accessing and iterating","text":"

There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named walk. However, most of the time, you can use simpler helpers. For instance, to iterate all the modules in the tree that hold weights (the Conv2d and the Linears), we can just do:

for x in m.layers(fl.WeightedModule):\n    print(x)\n

It prints:

Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\nLinear(in_features=21632, out_features=200, device=cpu, dtype=float32)\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n
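
To give an intuition of what such a traversal does, here is a minimal sketch of a depth-first walk over a toy tree in plain Python. Node, Weighted and this walk function are hypothetical stand-ins written for illustration; they are not Refiners classes, and the real iterator may behave differently in its details.

```python
class Node:
    # Hypothetical stand-in for a module; not a Refiners class.
    def __init__(self, name, *children):
        self.name = name
        self.children = list(children)


class Weighted(Node):
    # Hypothetical marker for nodes that hold weights.
    pass


def walk(node, node_type=Node):
    # Depth-first traversal yielding (descendant, parent) pairs whose
    # descendant is an instance of node_type.
    for child in node.children:
        if isinstance(child, node_type):
            yield child, node
        yield from walk(child, node_type)


tree = Node(
    'BasicModel',
    Node('ConvLayer', Weighted('Conv2d'), Node('ReLU'), Node('MaxPool2d')),
    Node('HiddenLayer', Node('Flatten'), Weighted('Linear_1'), Node('ReLU')),
    Node('OutputLayer', Weighted('Linear_2'), Node('Softmax')),
)

weighted = [(child.name, parent.name) for child, parent in walk(tree, Weighted)]
print(weighted)
```

Each weighted node is reported together with the subchain that contains it, no matter how deep it sits in the tree.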
"},{"location":"concepts/context/","title":"Context","text":""},{"location":"concepts/context/#motivation-avoiding-props-drilling","title":"Motivation: avoiding \"props drilling\"","text":"

Chains are a powerful tool to represent computational graphs, but they are not always convenient.

Many adapters add extra input to the model. For instance, ControlNet and T2I-Adapter require a guide (condition image), inpainting adapters require a mask, Latent Consistency Models use a condition scale embedding, other adapters may leverage time or context embeddings... Those inputs are often passed by the user in a high-level format (numbers, text...) and converted to embeddings by the model before being consumed in downstream layers.

Managing this extra input is inconvenient. Typically, you would add them to the inputs and outputs of each layer somehow. But if you add them as channels or concatenate them you get composability issues, and if you try to pass them as extra arguments you end up needing to deal with them in layers that should not be concerned with their presence.

The same problem of having to pass extra contextual information up and down a tree exists in other fields, in particular in JavaScript frameworks that deal with a Virtual DOM such as React, where it is called \"props drilling\". To make it easier to manage, React introduced the Context API, and we went with a similar idea in Refiners.

"},{"location":"concepts/context/#a-simple-example","title":"A simple example","text":"

Here is an example of how contexts work:

from refiners.fluxion.context import Contexts\n\nclass MyProvider(fl.Chain):\n    def init_context(self) -> Contexts:\n        return {\"my context\": {\"my key\": None}}\n\nm = MyProvider(\n    fl.Chain(\n        fl.Sum(\n            fl.UseContext(\"my context\", \"my key\"),\n            fl.Lambda(lambda: 2),\n        ),\n        fl.SetContext(\"my context\", \"my key\"),\n    ),\n    fl.Chain(\n        fl.UseContext(\"my context\", \"my key\"),\n        fl.Lambda(print),\n    ),\n)\n\nm.set_context(\"my context\", {\"my key\": 4})\nm()  # prints 6\n

As you can see, to use the context, you define it by subclassing any Chain and defining init_context. You can set the context with the set_context method or the SetContext layer, and you can access it anywhere down the provider's tree with UseContext.
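
To follow the flow of values in the example above, here is a step-by-step sketch in plain Python where an ordinary dictionary stands in for the context store. It only mirrors the data flow; the variable names are made up for illustration and nothing here is Refiners code.

```python
# Plain dictionary standing in for the context store of the provider.
contexts = {'my context': {'my key': None}}

# Stands in for the set_context call: the key now holds 4.
contexts['my context'] = {'my key': 4}

# First sub-chain: read the key, add the constant 2, write the sum back.
total = contexts['my context']['my key'] + 2
contexts['my context']['my key'] = total

# Second sub-chain: read the key again and print it.
result = contexts['my context']['my key']
print(result)  # 6
```

The second sub-chain never receives the value as an argument: it only reads it from the shared store, which is the point of the context.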

"},{"location":"concepts/context/#simplifying-complex-models-with-context","title":"Simplifying complex models with Context","text":"

Another use of the context is simplifying complex models, in particular those with long-range nested skip connections.

To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:

square = lambda: fl.Lambda(lambda x: x ** 2)\nsqrt = lambda: fl.Lambda(lambda x: x ** 0.5)\n\nm1 = fl.Chain(\n    fl.Residual(\n        square(),\n        fl.Residual(\n            square(),\n            fl.Residual(\n                square(),\n            ),\n            sqrt(),\n        ),\n        sqrt(),\n    ),\n    sqrt(),\n)\n

You can see two problems here:

Let us solve those two issues using the context:

from refiners.fluxion.context import Contexts\n\nclass MyModel(fl.Chain):\n    def init_context(self) -> Contexts:\n        return {\"mymodel\": {\"residuals\": []}}\n\ndef push_residual():\n    return fl.SetContext(\n        \"mymodel\",\n        \"residuals\",\n        callback=lambda l, x: l.append(x),\n    )\n\nclass ApplyResidual(fl.Sum):\n    def __init__(self):\n        super().__init__(\n            fl.Identity(),\n            fl.UseContext(\"mymodel\", \"residuals\").compose(lambda x: x.pop()),\n        )\n\nsquares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))\nsqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))\nm2 = MyModel(squares, sqrts)\n

As you can see, despite squares and sqrts being completely independent chains, they can access the same context due to being nested under the same provider.

Does it work?

>>> m1(2.0)\n2.5547711633552384\n>>> m2(2.0)\n2.5547711633552384\n

Yes!\u2728
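
To see why both formulations agree, here is the same computation written in plain Python, once with nested functions and once with an explicit stack. The functions nested and with_stack are illustrative names only; a list plays the role of the residuals entry of the context.

```python
def square(x):
    return x ** 2


def sqrt(x):
    return x ** 0.5


def nested(x):
    # Mirrors m1: each residual adds its input to the output of its body.
    def inner(a):
        return a + square(a)

    def middle(a):
        return a + sqrt(inner(square(a)))

    def outer(a):
        return a + sqrt(middle(square(a)))

    return sqrt(outer(x))


def with_stack(x):
    # Mirrors m2: values are pushed before each square and popped
    # in reverse order before each square root.
    residuals = []
    for _ in range(3):
        residuals.append(x)
        x = square(x)
    for _ in range(3):
        x = x + residuals.pop()
        x = sqrt(x)
    return x


print(nested(2.0), with_stack(2.0))
```

The stack version is flat: adding a level only means changing the range, whereas the nested version needs one more level of indentation.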

"},{"location":"concepts/adapter/","title":"Adapter","text":"

Adapters are the final and most high-level abstraction in Refiners. They are the concept of adaptation turned into code.

An Adapter is generally a Chain that replaces a Module (the target) in another Chain (the parent). Typically the target will become a child of the adapter.

In code terms, Adapter is a generic mixin. Adapters subclass type(parent) and Adapter[type(target)]. For instance, if you adapt a Conv2d in a Sum, the definition of the Adapter could look like:

class MyAdapter(fl.Sum, fl.Adapter[fl.Conv2d]):\n    ...\n
"},{"location":"concepts/adapter/#a-simple-example-adapting-a-linear","title":"A simple example: adapting a Linear","text":"

Let us take a simple example to see how this works. Consider this model:

In code, it could look like this:

my_model = MyModel(fl.Chain(fl.Linear(), fl.Chain(...)))\n

Suppose we want to adapt the Linear to sum its output with the result of another chain. We can define and initialize an adapter like this:

class MyAdapter(fl.Sum, fl.Adapter[fl.Linear]):\n    def __init__(self, target: fl.Linear) -> None:\n        with self.setup_adapter(target):\n            super().__init__(fl.Chain(...), target)\n\n# Find the target and its parent in the chain.\n# For simplicity let us assume it is the only Linear.\nfor target, parent in my_model.walk(fl.Linear):\n    break\n\nadapter = MyAdapter(target)\n

The result is now this:

Note that the original chain is unmodified. You can still run inference on it as if the adapter did not exist. To use the adapter, you must inject it into the chain:

adapter.inject(parent)\n

The result will be:

Now if you run inference it will go through the Adapter. You can go back to the previous situation by calling adapter.eject().
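
The inject and eject mechanics can be pictured with a toy model in plain Python. ToyChain and ToyAdapter below are hypothetical classes written for this illustration; they are far simpler than the Refiners ones and only model swapping a child in its parent.

```python
class ToyChain:
    # Hypothetical container; it only models an ordered list of children.
    def __init__(self, *children):
        self.children = list(children)

    def __call__(self, x):
        for child in self.children:
            x = child(x)
        return x


class ToyAdapter:
    # Hypothetical adapter: sums the output of the target with an extra branch.
    def __init__(self, target, extra):
        self.target = target
        self.extra = extra
        self.parent = None

    def __call__(self, x):
        return self.target(x) + self.extra(x)

    def inject(self, parent):
        # Swap the target for the adapter in the parent.
        index = parent.children.index(self.target)
        parent.children[index] = self
        self.parent = parent

    def eject(self):
        # Put the target back where the adapter was.
        index = self.parent.children.index(self)
        self.parent.children[index] = self.target
        self.parent = None


def double(x):
    return 2 * x


def add_one(x):
    return x + 1


model = ToyChain(double, add_one)
adapter = ToyAdapter(double, extra=lambda x: 10 * x)

before = model(3)   # (2 * 3) + 1
adapter.inject(model)
adapted = model(3)  # (2 * 3 + 10 * 3) + 1
adapter.eject()
after = model(3)
print(before, adapted, after)  # 7 37 7
```

Creating the adapter does not change the model; only inject does, and eject restores the original behavior.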

"},{"location":"concepts/adapter/#a-more-complicated-example-adapting-a-chain","title":"A more complicated example: adapting a Chain","text":"

We are not limited to adapting base modules: we can also adapt Chains.

Starting from the same model as earlier, let us assume we want to:

This Adapter will perform a structural_copy of part of its target, which means it will duplicate all Chain nodes but keep pointers to the same WeightedModules, and hence not use extra GPU memory.

class MyAdapter(fl.Chain, fl.Adapter[fl.Chain]):\n    def __init__(self, target: fl.Chain) -> None:\n        with self.setup_adapter(target):\n            new_b = fl.Chain(target, target.Chain.Chain_2.structural_copy())\n            super().__init__(new_b, target.Linear)\n\nadapter = MyAdapter(my_model.Chain_1)  # Chain A in the diagram\n

We end up with this:

We can now inject it into the original graph. We do not even need to pass the parent this time, since Chains know their parents.

adapter.inject()\n

We obtain this:

Note that the Linear is in the Chain twice now, but that does not matter as long as you really want it to be the same Linear layer with the same weights.
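
The idea of a structural copy, duplicating the containers while sharing the leaves, can be sketched in a few lines of plain Python. Leaf and Tree are hypothetical stand-ins for a weighted module and a Chain; this is an illustration of the concept, not the Refiners implementation.

```python
class Leaf:
    # Hypothetical stand-in for a module that holds weights.
    def __init__(self, name):
        self.name = name


class Tree:
    # Hypothetical stand-in for a Chain.
    def __init__(self, *children):
        self.children = list(children)

    def structural_copy(self):
        # Duplicate container nodes recursively, reuse leaf objects as-is.
        return Tree(*(
            child.structural_copy() if isinstance(child, Tree) else child
            for child in self.children
        ))


linear = Leaf('Linear')
original = Tree(Tree(linear))
duplicate = original.structural_copy()

print(duplicate is original)  # False
print(duplicate.children[0] is original.children[0])  # False
print(duplicate.children[0].children[0] is linear)  # True
```

Because the leaf is the same object in both trees, any change to its weights is visible from both, and no extra memory is spent on a second copy.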

As before, we can eject the adapter to go back to the original model.

"},{"location":"concepts/adapter/#a-real-world-example-loraadapter","title":"A real-world example: LoraAdapter","text":"

A popular example of adaptation is LoRA. You can check out how we implement it in Refiners.

"},{"location":"concepts/adapter/#higher-level-adapters","title":"Higher-level adapters","text":"

If you use Refiners, you will find Adapters that go beyond the simple definition given at the top of this page. Some adapters inject multiple smaller adapters in models, others implement helper methods to be used by their caller...

From a bird's eye view, you can just consider Adapters as things you inject into models to adapt them, and that can be ejected to return the model to its original state. You will get a better feel for what an adapter is and how to leverage one by actually using the framework.

"},{"location":"getting-started/advanced/","title":"Advanced usage","text":""},{"location":"getting-started/advanced/#using-other-package-managers-pip-poetry","title":"Using other package managers (pip, Poetry...)","text":"

We use Rye to maintain and release Refiners but it conforms to the standard Python packaging guidelines and can be used with other package managers. Please refer to their respective documentation to figure out how to install a package from Git if you intend to use the development branch, as well as how to specify features.

"},{"location":"getting-started/advanced/#using-stable-releases-from-pypi","title":"Using stable releases from PyPI","text":"

Although we recommend using our development branch, we do publish more stable releases to PyPI and you are welcome to use them in your project. However, note that the format of weights can be different from the current state of the development branch, so you will need the conversion scripts from the corresponding tag in GitHub, for instance here for v0.2.0.

"},{"location":"getting-started/recommended/","title":"Recommended usage","text":"

Refiners is still a young project and development is active, so to use the latest and greatest version of the framework we recommend you use the main branch from our development repository.

Moreover, we recommend using Rye which simplifies several things related to Python package management, so start by following the instructions to install it on your system.

"},{"location":"getting-started/recommended/#installing","title":"Installing","text":"

To try Refiners, clone the GitHub repository and install it with all optional features:

git clone \"git@github.com:finegrain-ai/refiners.git\"\ncd refiners\nrye sync --all-features\n
"},{"location":"getting-started/recommended/#converting-weights","title":"Converting weights","text":"

The format of state dicts used by Refiners is custom and we do not redistribute model weights, but we provide conversion tools and working scripts for popular models. For instance, let us convert the autoencoder from Stable Diffusion 1.5:

python \"scripts/conversion/convert_diffusers_autoencoder_kl.py\" --to \"lda.safetensors\"\n

If you need to convert weights for all models, check out scripts/prepare_test_weights.py.

Warning

Using scripts/prepare_test_weights.py requires a GPU with significant VRAM and a lot of disk space.

Now, to check that it works, copy your favorite 512x512 picture to the current directory as input.png and create ldatest.py with this content:

from PIL import Image\nfrom refiners.fluxion.utils import no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder\n\nwith no_grad():\n    lda = SD1Autoencoder()\n    lda.load_from_safetensors(\"lda.safetensors\")\n\n    image = Image.open(\"input.png\")\n    latents = lda.image_to_latents(image)\n    decoded = lda.latents_to_image(latents)\n    decoded.save(\"output.png\")\n

Run it:

python ldatest.py\n

Inspect output.png: it should be similar to input.png but have a few differences. Latent Autoencoders are good compressors!

"},{"location":"getting-started/recommended/#using-refiners-in-your-own-project","title":"Using Refiners in your own project","text":"

So far you used Refiners as a standalone package, but if you want to create your own project using it as a dependency here is how you can proceed:

rye init --py \"3.11\" myproject\ncd myproject\nrye add --git \"git@github.com:finegrain-ai/refiners.git\" --features training refiners\nrye sync\n

If you only intend to do inference and no training, you can drop --features training.

To convert weights, you can either use a copy of the refiners repository as described above or add the conversion feature as a development dependency:

rye add --dev --git \"git@github.com:finegrain-ai/refiners.git\" --features conversion refiners\n

Note

You will still need to download the conversion scripts independently if you go that route.

"},{"location":"getting-started/recommended/#whats-next","title":"What's next?","text":"

We suggest you check out the guides section to dive into the usage of Refiners, or the Key Concepts section for a better understanding of how the framework works.

"},{"location":"guides/adapting_sdxl/","title":"Adapting Stable Diffusion XL","text":"

Stable Diffusion XL (SDXL) is a very popular text-to-image open source foundation model. This guide will show you how to boost its capabilities with Refiners, using iconic adapters the framework supports out-of-the-box, i.e. without the need for tedious prompt engineering. We'll follow a step by step approach, progressively increasing the number of adapters involved to showcase how simple adapter composition is using Refiners. Our use case will be the generation of an image with \"a futuristic castle surrounded by a forest, mountains in the background\".

"},{"location":"guides/adapting_sdxl/#prerequisites","title":"Prerequisites","text":"

Make sure Refiners is installed in your local environment - see Getting started - and you have access to a decent GPU.

Warning

As the examples in this guide's code snippets use CUDA, a minimum of 24GB VRAM is needed.

Before diving into the adapters themselves, let's establish a baseline by simply prompting SDXL with Refiners.

Reminder

A StableDiffusion model is composed of three modules: a text encoder (clip_text_encoder), a U-Net (unet) and a latent autoencoder (lda).

As Refiners comes with a new model representation (see Chain), you need to download and convert the weights of each module by calling our conversion scripts directly from your terminal (make sure you're in your local refiners directory, with your local environment active):

python scripts/conversion/convert_transformers_clip_text_model.py --from \"stabilityai/stable-diffusion-xl-base-1.0\" --subfolder2 text_encoder_2 --to DoubleCLIPTextEncoder.safetensors --half\npython scripts/conversion/convert_diffusers_unet.py --from \"stabilityai/stable-diffusion-xl-base-1.0\" --to sdxl-unet.safetensors --half\npython scripts/conversion/convert_diffusers_autoencoder_kl.py --from \"madebyollin/sdxl-vae-fp16-fix\" --subfolder \"\" --to sdxl-lda.safetensors --half\n

Note

This will download the original weights from https://huggingface.co/ which takes some time. If you already have this repo cloned locally, use the --from /path/to/stabilityai/stable-diffusion-xl-base-1.0 option instead.

Now, we can write the Python script responsible for inference. Just create a simple inference.py file, and open it in your favorite editor.

Start by instantiating a StableDiffusion_XL model and load it with the converted weights:

import torch\n\nfrom refiners.fluxion.utils import manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)  # Using half-precision for memory efficiency\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n

Then, define the inference parameters by setting the appropriate prompt / seed / inference steps:

# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\n\n# Enable self-attention guidance to enhance the quality of the generated images\nsdxl.set_self_attention_guidance(enable=True, scale=0.75)\n\n# ... Inference process\n

You can now define and run the proper inference process:

with no_grad():  # Disable gradient calculation for memory-efficient inference\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    manual_seed(seed)\n\n    # SDXL typically generates 1024x1024, here we use a higher resolution.\n    x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"vanilla_sdxl.png\")\n
The entire end-to-end code:
import torch\n\nfrom refiners.fluxion.utils import manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n\n# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\nsdxl.set_self_attention_guidance(\n    enable=True, scale=0.75\n)  # Enable self-attention guidance to enhance the quality of the generated images\n\n\nwith no_grad():  # Disable gradient calculation for memory-efficient inference\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    manual_seed(seed=seed)\n\n    # SDXL typically generates 1024x1024, here we use a higher resolution.\n    x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"vanilla_sdxl.png\")\n

It's time to execute your code. The resulting image should look like this:

Generated image of a castle using default SDXL weights.

It is not really what we prompted the model for, unfortunately. To get a more futuristic-looking castle, you can either go for tedious prompt engineering, or use a pretrained LoRA tailored to our use case, like the Sci-fi Environments LoRA available on Civitai. We'll now show you how the LoRA option works with Refiners.

"},{"location":"guides/adapting_sdxl/#single-lora","title":"Single LoRA","text":"

To use the Sci-fi Environments LoRA, all you have to do is download its weights to disk as a .safetensors, and inject them into SDXL using SDLoraManager right after instantiating StableDiffusion_XL:

from refiners.fluxion.utils import load_from_safetensors\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\n\n# Load LoRA weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\")\nmanager.add_loras(\"scifi-lora\", tensors=scifi_lora_weights)\n
The entire end-to-end code:
import torch\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n\n# Load LoRA weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\")\nmanager.add_loras(\"scifi-lora\", tensors=scifi_lora_weights)\n\n# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\nsdxl.set_self_attention_guidance(\n    enable=True, scale=0.75\n)  # Enable self-attention guidance to enhance the quality of the generated images\n\nwith no_grad():\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    manual_seed(seed=seed)\n\n    # SDXL typically generates 1024x1024, here we use a higher resolution.\n    x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"scifi_sdxl.png\")\n

You should get something like this - pretty neat, isn't it?

Generated image of a castle in sci-fi style."},{"location":"guides/adapting_sdxl/#multiple-loras","title":"Multiple LoRAs","text":"

Continuing with our futuristic castle example, we might want to turn it, for instance, into pixel art.

Again, we could either try some tedious prompt engineering, or instead use another LoRA found on the web, such as Pixel Art LoRA, found on Civitai. This is dead simple as SDLoraManager allows loading multiple LoRAs:

# Load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\"))\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixel-art-xl-v1.1.safetensors\"))\n

Adapters such as LoRAs also have a scale (roughly) quantifying the effect of this Adapter. Refiners allows setting different scales for each Adapter, allowing the user to balance the effect of each Adapter:

# Load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\"), scale=1.0)\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixel-art-xl-v1.1.safetensors\"), scale=1.4)\n
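
To make the role of scale more concrete, here is a minimal sketch of the usual LoRA formulation, where a low-rank product is multiplied by the scale and added to the output of the frozen layer. It uses plain Python lists and hypothetical helper names (matvec, lora_forward) with made-up numbers; it is an assumption-laden illustration, not how Refiners implements LoRA internally.

```python
def matvec(matrix, vector):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def lora_forward(x, weight, down, up, scale):
    # Output of the frozen layer plus the scaled low-rank update.
    base = matvec(weight, x)
    update = matvec(up, matvec(down, x))
    return [b + scale * u for b, u in zip(base, update)]


weight = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 weight
down = [[1.0, 1.0]]                # rank-1 down projection (1x2)
up = [[0.5], [0.25]]               # rank-1 up projection (2x1)
x = [2.0, 4.0]

print(lora_forward(x, weight, down, up, scale=0.0))  # [2.0, 4.0]
print(lora_forward(x, weight, down, up, scale=1.0))  # [5.0, 5.5]
print(lora_forward(x, weight, down, up, scale=2.0))  # [8.0, 7.0]
```

A scale of 0 leaves the frozen layer untouched, and larger values push the output further in the direction learned by the LoRA, which is why balancing scales matters when several LoRAs are combined.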
The entire end-to-end code:
import torch\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n\n# Load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\")\npixel_art_lora_weights = load_from_safetensors(\"pixel-art-xl-v1.1.safetensors\")\nmanager.add_loras(\"scifi-lora\", scifi_lora_weights, scale=1.0)\nmanager.add_loras(\"pixel-art-lora\", pixel_art_lora_weights, scale=1.4)\n\n# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\nsdxl.set_self_attention_guidance(\n    enable=True, scale=0.75\n)  # Enable self-attention guidance to enhance the quality of the generated images\n\nwith no_grad():\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    manual_seed(seed=seed)\n\n    # SDXL typically generates 1024x1024, here we use a higher resolution.\n    x = sdxl.init_latents((2048, 2048)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"scifi_pixel_sdxl.png\")\n

The results are looking great:

Generated image of a castle in sci-fi, pixel art style."},{"location":"guides/adapting_sdxl/#multiple-loras-ip-adapter","title":"Multiple LoRAs + IP-Adapter","text":"

Refiners really shines when it comes to composing different Adapters to fully exploit the possibilities of foundation models.

For instance, IP-Adapter (covered in a previous blog post) is a common choice for practitioners wanting to guide the diffusion process towards a specific prompt image.

In our example, consider this image of the Neuschwanstein Castle:

Credits: Bayerische Schl\u00f6sserverwaltung, Anton Brandl

We would like to guide the diffusion process to align with this image, using IP-Adapter. First, download the image as well as the weights of IP-Adapter by running the following commands from your terminal (again, make sure you're in your local refiners directory):

curl -O https://refine.rs/guides/adapting_sdxl/german-castle.jpg\npython scripts/conversion/convert_transformers_clip_image_model.py --from \"stabilityai/stable-diffusion-2-1-unclip\" --to CLIPImageEncoderH.safetensors --half\ncurl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin\npython scripts/conversion/convert_diffusers_ip_adapter.py --from ip-adapter-plus_sdxl_vit-h.bin --half\n

This will download and convert both IP-Adapter and CLIP Image Encoder pretrained weights.

Then, in your Python code, instantiate an SDXLIPAdapter targeting sdxl.unet, and inject it with a single .inject() call:

# IP-Adapter\nip_adapter = SDXLIPAdapter(\n    target=sdxl.unet, \n    weights=load_from_safetensors(\"ip-adapter-plus_sdxl_vit-h.safetensors\"),\n    scale=1.0,\n    fine_grained=True  # Use fine-grained IP-Adapter (i.e IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\"CLIPImageEncoderH.safetensors\")\nip_adapter.inject()\n

Then, at runtime, we compute the embedding of the image prompt through the ip_adapter object, and set it by calling .set_clip_image_embedding():

from PIL import Image\nimage_prompt = Image.open(\"german-castle.jpg\")\n\nwith torch.no_grad():\n    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n    ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n# And start the diffusion process\n

Note

Be aware that composing Adapters (especially ones of different natures, such as LoRAs and IP-Adapter) can be tricky, as their respective effects can be adversarial. This is visible in our example below, where we tuned the LoRA scales to 1.5 and 1.55 respectively. We invite you to try different seeds and scales to find the perfect combination!
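
To build intuition for what a LoRA scale does, here is a toy sketch of the standard LoRA update, W' = W + scale * (up @ down), in plain Python. It is not Refiners' implementation: matmul, apply_lora and the tiny matrices are made up for illustration, and the two hypothetical LoRAs are chosen so that their updates pull the same weight in opposite directions.

```python
# Toy illustration of the LoRA update rule; not Refiners code.
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    '''Multiply two small matrices given as nested lists.'''
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def apply_lora(weight: Matrix, down: Matrix, up: Matrix, scale: float) -> Matrix:
    '''Return weight + scale * (up @ down), the low-rank update.'''
    delta = matmul(up, down)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]


weight = [[1.0, 0.0], [0.0, 1.0]]
# Two rank-1 LoRAs (hypothetical values), each stored as (down, up).
lora_a = ([[1.0, 0.0]], [[0.5], [0.0]])
lora_b = ([[1.0, 0.0]], [[-0.5], [0.0]])

adapted = apply_lora(weight, *lora_a, scale=1.0)
adapted = apply_lora(adapted, *lora_b, scale=1.0)
print(adapted)  # [[1.0, 0.0], [0.0, 1.0]]: at equal scales the two updates cancel
```

Real adapters rarely cancel this cleanly, but the sketch shows why the scales have to be tuned together rather than one at a time.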

Expand to see the entire end-to-end code
import torch\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n\n# Load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\")\npixel_art_lora_weights = load_from_safetensors(\"pixel-art-xl-v1.1.safetensors\")\nmanager.add_loras(\"scifi-lora\", scifi_lora_weights, scale=1.5)\nmanager.add_loras(\"pixel-art-lora\", pixel_art_lora_weights, scale=1.55)\n\n# Load IP-Adapter\nip_adapter = SDXLIPAdapter(\n    target=sdxl.unet,\n    weights=load_from_safetensors(\"ip-adapter-plus_sdxl_vit-h.safetensors\"),\n    scale=1.0,\n    fine_grained=True,  # Use fine-grained IP-Adapter (IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\"CLIPImageEncoderH.safetensors\")\nip_adapter.inject()\n\n# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nimage_prompt = Image.open(\"german-castle.jpg\")\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\nsdxl.set_self_attention_guidance(\n    enable=True, scale=0.75\n)  # Enable self-attention guidance to enhance the quality of the generated images\n\nwith no_grad():\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, 
worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n    ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n    manual_seed(seed=seed)\n    x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"scifi_pixel_IP_sdxl.png\")\n

The result looks convincing: we do get a pixel-art, futuristic-looking Neuschwanstein castle!

Generated image in sci-fi, pixel art style, using IP-Adapter."},{"location":"guides/adapting_sdxl/#everything-else-t2i-adapter","title":"Everything else + T2I-Adapter","text":"

T2I-Adapters [1] are a powerful class of Adapters that aim to control the Text-to-Image (T2I) diffusion process with external control signals, such as canny edges or pose estimation inputs. In this section, we will compose our previous example with the Depth-Zoe Adapter, providing a depth condition to the diffusion process using the following depth map as input signal:

Input depth map of the initial castle image.

First, download the image as well as the weights of T2I-Depth-Zoe-Adapter by calling the following commands:

curl -O https://refine.rs/guides/adapting_sdxl/zoe-depth-map-german-castle.png\npython scripts/conversion/convert_diffusers_t2i_adapter.py --from \"TencentARC/t2i-adapter-depth-zoe-sdxl-1.0\" --to t2i_depth_zoe_xl.safetensors --half\n

Then, just inject it as usual:

# Load T2I-Adapter\nt2i_adapter = SDXLT2IAdapter(\n    target=sdxl.unet, \n    name=\"zoe-depth\", \n    weights=load_from_safetensors(\"t2i_depth_zoe_xl.safetensors\"),\n    scale=0.72,\n).inject()\n

Finally, at runtime, compute the features of the input condition through the t2i_adapter object, and set them by calling .set_condition_features():

from refiners.fluxion.utils import image_to_tensor, interpolate\n\nimage_depth_condition = Image.open(\"zoe-depth-map-german-castle.png\")\n\nwith torch.no_grad():\n    condition = image_to_tensor(image_depth_condition.convert(\"RGB\"), device=sdxl.device, dtype=sdxl.dtype)\n    # Spatial dimensions should be divisible by default downscale factor (=16 for T2IAdapter ConditionEncoder)\n    condition = interpolate(condition, torch.Size((1024, 1024)))\n    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))\n
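
The comment in the snippet above says the condition's spatial dimensions should be divisible by 16. Below is a minimal sketch of snapping an arbitrary image size to that constraint in plain Python. The helpers round_to_multiple and condition_size are hypothetical and not part of Refiners; the guide itself simply resizes to 1024x1024.

```python
def round_to_multiple(value: int, factor: int = 16) -> int:
    '''Round value to the nearest multiple of factor, never below factor itself.'''
    return max(factor, (value + factor // 2) // factor * factor)


def condition_size(width: int, height: int, factor: int = 16) -> tuple[int, int]:
    '''Return (height, width) snapped to multiples of factor.'''
    return round_to_multiple(height, factor), round_to_multiple(width, factor)


print(condition_size(1000, 750))    # (752, 1008)
print(condition_size(1024, 1024))   # (1024, 1024)
```

The resulting (height, width) pair is the kind of value that would be wrapped in torch.Size and passed to interpolate.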
Expand to see the entire end-to-end code
import torch\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad, image_to_tensor\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL, SDXLT2IAdapter\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter\n\n# Load SDXL\nsdxl = StableDiffusion_XL(device=\"cuda\", dtype=torch.float16)\nsdxl.clip_text_encoder.load_from_safetensors(\"DoubleCLIPTextEncoder.safetensors\")\nsdxl.unet.load_from_safetensors(\"sdxl-unet.safetensors\")\nsdxl.lda.load_from_safetensors(\"sdxl-lda.safetensors\")\n\n# Load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"Sci-fi_Environments_sdxl.safetensors\")\npixel_art_lora_weights = load_from_safetensors(\"pixel-art-xl-v1.1.safetensors\")\nmanager.add_loras(\"scifi-lora\", scifi_lora_weights, scale=1.5)\nmanager.add_loras(\"pixel-art-lora\", pixel_art_lora_weights, scale=1.55)\n\n# Load IP-Adapter\nip_adapter = SDXLIPAdapter(\n    target=sdxl.unet,\n    weights=load_from_safetensors(\"ip-adapter-plus_sdxl_vit-h.safetensors\"),\n    scale=1.0,\n    fine_grained=True,  # Use fine-grained IP-Adapter (IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\"CLIPImageEncoderH.safetensors\")\nip_adapter.inject()\n\n# Load T2I-Adapter\nt2i_adapter = SDXLT2IAdapter(\n    target=sdxl.unet, \n    name=\"zoe-depth\", \n    weights=load_from_safetensors(\"t2i_depth_zoe_xl.safetensors\"),\n    scale=0.72,\n).inject()\n\n# Hyperparameters\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nimage_prompt = Image.open(\"german-castle.jpg\")\nimage_depth_condition = Image.open(\"zoe-depth-map-german-castle.png\")\nseed = 42\nsdxl.set_inference_steps(50, first_step=0)\nsdxl.set_self_attention_guidance(\n    enable=True, 
scale=0.75\n)  # Enable self-attention guidance to enhance the quality of the generated images\n\nwith no_grad():\n    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n        text=prompt + \", best quality, high quality\",\n        negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    )\n    time_ids = sdxl.default_time_ids\n\n    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n    ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n    # Spatial dimensions should be divisible by default downscale factor (=16 for T2IAdapter ConditionEncoder)\n    condition = image_to_tensor(image_depth_condition.convert(\"RGB\").resize((1024, 1024)), device=sdxl.device, dtype=sdxl.dtype)\n    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))\n\n    manual_seed(seed=seed)\n    x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)\n\n    # Diffusion process\n    for step in sdxl.steps:\n        if step % 10 == 0:\n            print(f\"Step {step}\")\n        x = sdxl(\n            x,\n            step=step,\n            clip_text_embedding=clip_text_embedding,\n            pooled_text_embedding=pooled_text_embedding,\n            time_ids=time_ids,\n        )\n    predicted_image = sdxl.lda.decode_latents(x)\n\npredicted_image.save(\"scifi_pixel_IP_T2I_sdxl.png\")\n

The results look convincing: the depth and proportions of the initial castle are more faithful, while preserving our futuristic, pixel-art style!

Generated image in sci-fi, pixel art style, using IP and T2I Adapters."},{"location":"guides/adapting_sdxl/#wrap-up","title":"Wrap up","text":"

As you can see in this guide, composing Adapters on top of foundation models is pretty seamless in Refiners, allowing practitioners to quickly test out different combinations of Adapters for their needs. We encourage you to try out different ones, and even train some yourselves!

  1. Mou, C., Wang, X., Xie, L., Zhang, J., Qi, Z., Shan, Y., & Qie, X. (2023). T2I-Adapter: Learning adapters to dig out more controllable ability for text-to-image diffusion models.

"},{"location":"guides/training_101/","title":"Training 101","text":"

This guide will walk you through training a model using Refiners. We built the training_utils module to provide a simple, flexible, statically type-safe interface.

We will use a simple model and a toy dataset for demonstration purposes. The model will be a simple autoencoder, and the dataset will be a synthetic dataset of rectangles of different sizes.

"},{"location":"guides/training_101/#pre-requisites","title":"Pre-requisites","text":"

We recommend installing Refiners targeting a specific commit hash to avoid unexpected changes in the API. You also get the benefit of having a perfectly reproducible environment.

"},{"location":"guides/training_101/#model","title":"Model","text":"

Let's start by building our autoencoder using Refiners.

Expand to see the autoencoder model.
from refiners.fluxion import layers as fl\n\n\nclass ConvBlock(fl.Chain):\n    def __init__(self, in_channels: int, out_channels: int) -> None:\n        super().__init__(\n            fl.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=3,\n                padding=1,\n                groups=min(in_channels, out_channels)\n            ),\n            fl.LayerNorm2d(out_channels),\n            fl.SiLU(),\n            fl.Conv2d(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                kernel_size=1,\n                padding=0,\n            ),\n            fl.LayerNorm2d(out_channels),\n            fl.SiLU(),\n        )\n\n\nclass ResidualBlock(fl.Sum):\n    def __init__(self, in_channels: int, out_channels: int) -> None:\n        super().__init__(\n            ConvBlock(in_channels=in_channels, out_channels=out_channels),\n            fl.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=3,\n                padding=1,\n            ),\n        )\n\n\nclass Encoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            ResidualBlock(in_channels=1, out_channels=8),\n            fl.Downsample(channels=8, scale_factor=2, register_shape=False),\n            ResidualBlock(in_channels=8, out_channels=16),\n            fl.Downsample(channels=16, scale_factor=2, register_shape=False),\n            ResidualBlock(in_channels=16, out_channels=32),\n            fl.Downsample(channels=32, scale_factor=2, register_shape=False),\n            fl.Reshape(2048),\n            fl.Linear(in_features=2048, out_features=256),\n            fl.SiLU(),\n            fl.Linear(in_features=256, out_features=256),\n        )\n\n\nclass Decoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            fl.Linear(in_features=256, out_features=256),\n          
  fl.SiLU(),\n            fl.Linear(in_features=256, out_features=2048),\n            fl.Reshape(32, 8, 8),\n            ResidualBlock(in_channels=32, out_channels=32),\n            ResidualBlock(in_channels=32, out_channels=32),\n            fl.Upsample(channels=32, upsample_factor=2),\n            ResidualBlock(in_channels=32, out_channels=16),\n            ResidualBlock(in_channels=16, out_channels=16),\n            fl.Upsample(channels=16, upsample_factor=2),\n            ResidualBlock(in_channels=16, out_channels=8),\n            ResidualBlock(in_channels=8, out_channels=8),\n            fl.Upsample(channels=8, upsample_factor=2),\n            ResidualBlock(in_channels=8, out_channels=8),\n            ResidualBlock(in_channels=8, out_channels=1),\n            fl.Sigmoid(),\n        )\n\n\nclass Autoencoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            Encoder(),\n            Decoder(),\n        )\n\n    @property\n    def encoder(self) -> Encoder:\n        return self.ensure_find(Encoder)\n\n    @property\n    def decoder(self) -> Decoder:\n        return self.ensure_find(Decoder)\n

We now have a fully functional autoencoder that takes an image with one channel of size 64x64 and compresses it to a vector of size 256 (x16 compression). The decoder then takes this vector and reconstructs the original image.

import torch\n\nautoencoder = Autoencoder()\n\nx = torch.randn(2, 1, 64, 64) # batch of 2 images\n\nz = autoencoder.encoder(x) # [2, 256]\n\nx_reconstructed = autoencoder.decoder(z) # [2, 1, 64, 64]\n
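
The sizes quoted above (2048 flattened features, a 256-dimensional latent, x16 compression) can be checked with plain arithmetic. The sketch below only does shape bookkeeping based on the layer arguments shown in the model; it creates no tensors and the variable names are ours.

```python
# Shape bookkeeping for the autoencoder above (plain arithmetic, no tensors).
image_size = 64
num_downsamples = 3      # three Downsample layers with scale_factor=2
final_channels = 32      # out_channels of the last encoder ResidualBlock
latent_size = 256        # out_features of the last encoder Linear

spatial = image_size // (2 ** num_downsamples)      # 64 -> 32 -> 16 -> 8
flattened = final_channels * spatial * spatial      # 32 * 8 * 8
compression = (1 * image_size * image_size) / latent_size

print(spatial, flattened, compression)  # 8 2048 16.0
```

This is also why the encoder reshapes to 2048 and the decoder reshapes back to (32, 8, 8).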
"},{"location":"guides/training_101/#dataset","title":"Dataset","text":"

We will use a synthetic dataset of rectangles of different sizes. The dataset will be generated on the fly using this simple function:

import random\nfrom typing import Generator\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import image_to_tensor\n\ndef generate_mask(size: int, seed: int | None = None) -> Generator[torch.Tensor, None, None]:\n    \"\"\"Generate a tensor of a binary mask of size `size` using random rectangles.\"\"\"\n    if seed is None:\n        seed = random.randint(0, 2**32 - 1)\n    random.seed(seed)\n\n    while True:\n        rectangle = Image.new(\n            \"L\", (random.randint(1, size), random.randint(1, size)), color=255\n        )\n        mask = Image.new(\"L\", (size, size))\n        mask.paste(\n            rectangle,\n            (\n                random.randint(0, size - rectangle.width),\n                random.randint(0, size - rectangle.height),\n            ),\n        )\n        tensor = image_to_tensor(mask)\n\n        if random.random() > 0.5:\n            tensor = 1 - tensor\n\n        yield tensor\n

To generate a mask, do:

from refiners.fluxion.utils import tensor_to_image\n\nmask = next(generate_mask(64, seed=42))\ntensor_to_image(mask).save(\"mask.png\")\n

Here are two examples of generated masks:

"},{"location":"guides/training_101/#trainer","title":"Trainer","text":"

We will now create a Trainer class to handle the training loop. This class will manage the model, the optimizer, the loss function, and the dataset. It will also orchestrate the training loop and the evaluation loop.

But first, we need to define the batch type that will be used to represent a batch for the forward and backward pass and the configuration associated with the trainer.

"},{"location":"guides/training_101/#batch","title":"Batch","text":"

Our batches are composed of a single tensor representing the images. We will define a simple Batch type to implement this.

from dataclasses import dataclass\n\n@dataclass\nclass Batch:\n    image: torch.Tensor\n
"},{"location":"guides/training_101/#config","title":"Config","text":"

We will now define the configuration for the autoencoder. It holds the configuration for the training loop, the optimizer, and the learning rate scheduler. It should inherit from refiners.training_utils.BaseConfig and has the following mandatory attributes: training (a TrainingConfig), optimizer (an OptimizerConfig), and lr_scheduler (an LRSchedulerConfig).

Example:

from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch\n\nclass AutoencoderConfig(BaseConfig):\n    ...\n\ntraining = TrainingConfig(\n    duration=Epoch(1000),\n    device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n    dtype=\"float32\"\n)\n\noptimizer = OptimizerConfig(\n    optimizer=Optimizers.AdamW,\n    learning_rate=1e-4,\n)\n\nlr_scheduler = LRSchedulerConfig(\n    type=LRSchedulerType.CONSTANT_LR\n)\n\nconfig = AutoencoderConfig(\n    training=training,\n    optimizer=optimizer,\n    lr_scheduler=lr_scheduler,\n)\n
"},{"location":"guides/training_101/#subclass","title":"Subclass","text":"

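We can now define the Trainer subclass. It should inherit from refiners.training_utils.Trainer and implement the following methods: create_data_iterable, which builds the data the training loop iterates over, and compute_loss, which returns the loss for a batch.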

Here is a simple implementation of the create_data_iterable method. For this toy example, we will generate a simple list of Batch objects containing random masks. Later you can replace this with torch.utils.data.DataLoader or any other data loader with more complex features that support shuffling, parallel loading, etc.

from functools import cached_property\nfrom refiners.training_utils import Trainer\n\n\nclass AutoencoderConfig(BaseConfig):\n    num_images: int = 2048\n    batch_size: int = 32\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    def create_data_iterable(self) -> list[Batch]:\n        dataset: list[Batch] = []\n        generator = generate_mask(size=64)\n\n        for _ in range(self.config.num_images // self.config.batch_size):\n            masks = [next(generator) for _ in range(self.config.batch_size)]\n            dataset.append(Batch(image=torch.cat(masks, dim=0)))\n\n        return dataset\n\n    def compute_loss(self, batch: Batch) -> torch.Tensor:\n        raise NotImplementedError(\"We'll implement this later\")\n\n\ntrainer = AutoencoderTrainer(config)\n
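
With the default config values, the loop in create_data_iterable produces num_images // batch_size batches. Here is a small sketch of that arithmetic; count_batches is a hypothetical helper, shown only to make the floor division explicit, including the images that would be dropped if the two numbers do not divide evenly.

```python
def count_batches(num_images: int, batch_size: int) -> tuple[int, int]:
    '''Return (number of full batches, images left over) for floor-division batching.'''
    return num_images // batch_size, num_images % batch_size


print(count_batches(2048, 32))  # (64, 0)
print(count_batches(2050, 32))  # (64, 2)
```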
"},{"location":"guides/training_101/#model-registration","title":"Model registration","text":"

For the Trainer to be able to handle the model, we need to register it.

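We need two things to do so: a ModelConfig attribute in the config, named autoencoder, and a method on the Trainer subclass, also named autoencoder, that is decorated with @register_model and returns the model.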

After registering the model, the self.autoencoder attribute will be available in the Trainer.

from torch.nn import functional as F\n\nfrom refiners.training_utils import ModelConfig, register_model\n\n\nclass AutoencoderModelConfig(ModelConfig):\n    pass\n\n\nclass AutoencoderConfig(BaseConfig):\n    num_images: int = 2048\n    batch_size: int = 32\n    autoencoder: AutoencoderModelConfig\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    # ... other methods\n\n    @register_model()\n    def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:\n        return Autoencoder()\n\n    def compute_loss(self, batch: Batch) -> torch.Tensor:\n        batch.image = batch.image.to(self.device, self.dtype)\n        x_reconstructed = self.autoencoder.decoder(\n            self.autoencoder.encoder(batch.image)\n        )\n        return F.binary_cross_entropy(x_reconstructed, batch.image)\n

We now have a fully functional Trainer that can train our autoencoder. We can now call the train method to start the training loop.

trainer.train()\n

"},{"location":"guides/training_101/#logging","title":"Logging","text":"

Let's write a simple logging callback to log the loss during training. A callback is a class that inherits from refiners.training_utils.Callback and implements any of its hook methods, such as on_compute_loss_end or on_epoch_end.

We will implement the on_compute_loss_end method to store each loss in a list, and the on_epoch_end method to log the mean loss of the epoch.

from refiners.training_utils import Callback, Trainer\nfrom loguru import logger\nfrom typing import Any\n\n\nclass LoggingCallback(Callback[Trainer[Any, Any]]):\n    losses: list[float] = []\n\n    def on_compute_loss_end(self, trainer: Trainer[Any, Any]) -> None:\n        self.losses.append(trainer.loss.detach().cpu().item())\n\n    def on_epoch_end(self, trainer: Trainer[Any, Any]) -> None:\n        mean_loss = sum(self.losses) / len(self.losses)\n        logger.info(f\"Mean loss: {mean_loss}, epoch: {trainer.clock.epoch}\")\n        self.losses = []\n

Exactly like models, callbacks need to be registered with the Trainer. We do so by adding a CallbackConfig attribute named logging to the config, and adding a method to the Trainer class, decorated with @register_callback, that returns the callback.

from refiners.training_utils import CallbackConfig, register_callback\n\nclass AutoencoderConfig(BaseConfig):\n    # ... other properties\n    logging: CallbackConfig = CallbackConfig()\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    # ... other methods\n\n    @register_callback()\n    def logging(self, config: CallbackConfig) -> LoggingCallback:\n        return LoggingCallback()\n
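
To make the hook mechanism concrete, here is a toy, dependency-free sketch of how a training loop can dispatch to callbacks. ToyCallback, MeanLossLogger and ToyTrainer are hypothetical stand-ins and do not reflect how Refiners implements its Trainer; the only assumption carried over from the guide is that each hook receives the trainer.

```python
# Toy callback dispatch; these classes are illustrative, not Refiners classes.
class ToyCallback:
    def on_compute_loss_end(self, trainer: 'ToyTrainer') -> None: ...
    def on_epoch_end(self, trainer: 'ToyTrainer') -> None: ...


class MeanLossLogger(ToyCallback):
    def __init__(self) -> None:
        self.losses: list[float] = []
        self.history: list[float] = []

    def on_compute_loss_end(self, trainer: 'ToyTrainer') -> None:
        self.losses.append(trainer.loss)

    def on_epoch_end(self, trainer: 'ToyTrainer') -> None:
        # Record the mean loss of the epoch, then reset the buffer.
        self.history.append(sum(self.losses) / len(self.losses))
        self.losses = []


class ToyTrainer:
    def __init__(self, callbacks: list[ToyCallback]) -> None:
        self.callbacks = callbacks
        self.loss = 0.0

    def _call(self, hook: str) -> None:
        for callback in self.callbacks:
            getattr(callback, hook)(self)

    def train(self, epochs: list[list[float]]) -> None:
        for batch_losses in epochs:
            for loss in batch_losses:
                self.loss = loss  # stand-in for a real forward/backward pass
                self._call('on_compute_loss_end')
            self._call('on_epoch_end')


log = MeanLossLogger()
ToyTrainer([log]).train([[1.0, 3.0], [0.5, 0.5]])
print(log.history)  # [2.0, 0.5]
```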

"},{"location":"guides/training_101/#evaluation","title":"Evaluation","text":"

Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a compute_evaluation method, then we register a callback to call this method at regular intervals.

class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    # ... other methods\n\n    def compute_evaluation(self) -> None:\n        generator = generate_mask(size=64, seed=0)\n\n        grid: list[tuple[Image.Image, Image.Image]] = []\n        for _ in range(4):\n            mask = next(generator).to(self.device, self.dtype)\n            x_reconstructed = self.autoencoder.decoder(\n                self.autoencoder.encoder(mask)\n            )\n            loss = F.mse_loss(x_reconstructed, mask)\n            logger.info(f\"Validation loss: {loss.detach().cpu().item()}\")\n            grid.append(\n                (tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))\n            )\n\n        import matplotlib.pyplot as plt\n\n        _, axes = plt.subplots(4, 2, figsize=(8, 16))\n\n        for i, (mask, reconstructed) in enumerate(grid):\n            axes[i, 0].imshow(mask, cmap='gray')\n            axes[i, 0].axis('off')\n            axes[i, 0].set_title('Mask')\n\n            axes[i, 1].imshow(reconstructed, cmap='gray')\n            axes[i, 1].axis('off')\n            axes[i, 1].set_title('Reconstructed')\n\n        plt.tight_layout()\n        plt.savefig(f\"result_{self.clock.epoch}.png\")\n        plt.close()\n

Next, we implement an EvaluationConfig that controls the evaluation interval and the seed for the random generator.

from refiners.training_utils.config import TimeValueField\n\nclass EvaluationConfig(CallbackConfig):\n    interval: TimeValueField\n    seed: int\n

The TimeValueField is a custom field that allows Pydantic to parse a string representing a time value (e.g., \"50:epochs\") into a TimeValue object. This is useful to specify the evaluation interval in the configuration file.
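
As a rough illustration of the kind of parsing involved, the sketch below splits a number:unit string into its two parts. ToyTimeValue and parse_time_value are hypothetical names; this is not how Refiners or Pydantic implement TimeValueField, only a minimal stand-in for the idea.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTimeValue:
    number: int
    unit: str


def parse_time_value(text: str) -> ToyTimeValue:
    '''Parse a number:unit string (for example 50:epochs) into a ToyTimeValue.'''
    number, separator, unit = text.partition(':')
    if not separator or not number.strip().isdigit() or not unit.strip():
        raise ValueError(f'expected <number>:<unit>, got {text!r}')
    return ToyTimeValue(number=int(number), unit=unit.strip().lower())


print(parse_time_value('50:epochs'))  # ToyTimeValue(number=50, unit='epochs')
```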

from refiners.training_utils import Callback\nfrom refiners.training_utils.common import scoped_seed\n\nclass EvaluationCallback(Callback[Any]):\n    def __init__(self, config: EvaluationConfig) -> None:\n        self.config = config\n\n    def on_epoch_end(self, trainer: Trainer) -> None:\n        # The `is_due` method checks if the current epoch is a multiple of the interval.\n        if not trainer.clock.is_due(self.config.interval):\n            return\n\n        # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the\n        # evaluation.\n        with scoped_seed(self.config.seed):\n            trainer.compute_evaluation()\n
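
The two ideas in the comments above (run only when the epoch is a multiple of the interval, and isolate the random state during evaluation) can be sketched with the standard library alone. toy_scoped_seed and is_due below are hypothetical re-implementations for illustration: they only touch Python's random module, and they are not the Refiners functions of similar names.

```python
import random
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def toy_scoped_seed(seed: int) -> Iterator[None]:
    '''Seed the stdlib RNG inside the block, then restore the previous state.'''
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)


def is_due(epoch: int, interval: int) -> bool:
    '''True when epoch is a multiple of interval.'''
    return epoch % interval == 0


random.seed(123)
expected_next = random.random()  # what the training stream would produce next
random.seed(123)

with toy_scoped_seed(0):
    first = random.random()      # evaluation draw
with toy_scoped_seed(0):
    second = random.random()     # same seed, same draw

print(first == second)                     # True: evaluation is reproducible
print(random.random() == expected_next)    # True: the outer stream is undisturbed
print([e for e in range(1, 201) if is_due(e, 50)])  # [50, 100, 150, 200]
```

Restoring the state in a finally clause means the training stream stays intact even if evaluation raises.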

We can now register the callback to the Trainer.

class AutoencoderConfig(BaseConfig):\n    # ... other properties\n    evaluation: EvaluationConfig\n
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    # ... other methods\n\n    @register_callback()\n    def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:\n        return EvaluationCallback(config)\n

We can now train the model and see the results in the result_{epoch}.png files.

"},{"location":"guides/training_101/#wrap-up","title":"Wrap up","text":"

You can train this toy model using the code below:

Expand to see the full code.
import random\nfrom dataclasses import dataclass\nfrom typing import Any, Generator\n\nimport torch\nfrom loguru import logger\nfrom PIL import Image\nfrom torch.nn import functional as F\n\nfrom refiners.fluxion import layers as fl\nfrom refiners.fluxion.utils import image_to_tensor, tensor_to_image\nfrom refiners.training_utils import (\n    BaseConfig,\n    Callback,\n    CallbackConfig,\n    ClockConfig,\n    Epoch,\n    LRSchedulerConfig,\n    LRSchedulerType,\n    ModelConfig,\n    OptimizerConfig,\n    Optimizers,\n    Trainer,\n    TrainingConfig,\n    register_callback,\n    register_model,\n)\nfrom refiners.training_utils.common import scoped_seed\nfrom refiners.training_utils.config import TimeValueField\n\n\nclass ConvBlock(fl.Chain):\n    def __init__(self, in_channels: int, out_channels: int) -> None:\n        super().__init__(\n            fl.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=3,\n                padding=1,\n                groups=min(in_channels, out_channels),\n            ),\n            fl.LayerNorm2d(out_channels),\n            fl.SiLU(),\n            fl.Conv2d(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                kernel_size=1,\n                padding=0,\n            ),\n            fl.LayerNorm2d(out_channels),\n            fl.SiLU(),\n        )\n\n\nclass ResidualBlock(fl.Sum):\n    def __init__(self, in_channels: int, out_channels: int) -> None:\n        super().__init__(\n            ConvBlock(in_channels=in_channels, out_channels=out_channels),\n            fl.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=3,\n                padding=1,\n            ),\n        )\n\n\nclass Encoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            ResidualBlock(in_channels=1, out_channels=8),\n            
fl.Downsample(channels=8, scale_factor=2, register_shape=False),\n            ResidualBlock(in_channels=8, out_channels=16),\n            fl.Downsample(channels=16, scale_factor=2, register_shape=False),\n            ResidualBlock(in_channels=16, out_channels=32),\n            fl.Downsample(channels=32, scale_factor=2, register_shape=False),\n            fl.Reshape(2048),\n            fl.Linear(in_features=2048, out_features=256),\n            fl.SiLU(),\n            fl.Linear(in_features=256, out_features=256),\n        )\n\n\nclass Decoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            fl.Linear(in_features=256, out_features=256),\n            fl.SiLU(),\n            fl.Linear(in_features=256, out_features=2048),\n            fl.Reshape(32, 8, 8),\n            ResidualBlock(in_channels=32, out_channels=32),\n            ResidualBlock(in_channels=32, out_channels=32),\n            fl.Upsample(channels=32, upsample_factor=2),\n            ResidualBlock(in_channels=32, out_channels=16),\n            ResidualBlock(in_channels=16, out_channels=16),\n            fl.Upsample(channels=16, upsample_factor=2),\n            ResidualBlock(in_channels=16, out_channels=8),\n            ResidualBlock(in_channels=8, out_channels=8),\n            fl.Upsample(channels=8, upsample_factor=2),\n            ResidualBlock(in_channels=8, out_channels=8),\n            ResidualBlock(in_channels=8, out_channels=1),\n            fl.Sigmoid(),\n        )\n\n\nclass Autoencoder(fl.Chain):\n    def __init__(self) -> None:\n        super().__init__(\n            Encoder(),\n            Decoder(),\n        )\n\n    @property\n    def encoder(self) -> Encoder:\n        return self.ensure_find(Encoder)\n\n    @property\n    def decoder(self) -> Decoder:\n        return self.ensure_find(Decoder)\n\n\ndef generate_mask(size: int, seed: int | None = None) -> Generator[torch.Tensor, None, None]:\n    \"\"\"Generate a tensor of a binary mask of size `size` using 
random rectangles.\"\"\"\n    if seed is None:\n        seed = random.randint(0, 2**32 - 1)\n    random.seed(seed)\n\n    while True:\n        rectangle = Image.new(\"L\", (random.randint(1, size), random.randint(1, size)), color=255)\n        mask = Image.new(\"L\", (size, size))\n        mask.paste(\n            rectangle,\n            (\n                random.randint(0, size - rectangle.width),\n                random.randint(0, size - rectangle.height),\n            ),\n        )\n        tensor = image_to_tensor(mask)\n\n        if random.random() > 0.5:\n            tensor = 1 - tensor\n\n        yield tensor\n\n\n@dataclass\nclass Batch:\n    image: torch.Tensor\n\n\nclass AutoencoderModelConfig(ModelConfig):\n    pass\n\n\nclass LoggingCallback(Callback[Trainer[Any, Any]]):\n    losses: list[float] = []\n\n    def on_compute_loss_end(self, trainer: Trainer[Any, Any]) -> None:\n        self.losses.append(trainer.loss.detach().cpu().item())\n\n    def on_epoch_end(self, trainer: Trainer[Any, Any]) -> None:\n        mean_loss = sum(self.losses) / len(self.losses)\n        logger.info(f\"Mean loss: {mean_loss}, epoch: {trainer.clock.epoch}\")\n        self.losses = []\n\n\nclass EvaluationConfig(CallbackConfig):\n    interval: TimeValueField\n    seed: int\n\n\nclass EvaluationCallback(Callback[\"AutoencoderTrainer\"]):\n    def __init__(self, config: EvaluationConfig) -> None:\n        self.config = config\n\n    def on_epoch_end(self, trainer: \"AutoencoderTrainer\") -> None:\n        # The `is_due` method checks if the current epoch is a multiple of the interval.\n        if not trainer.clock.is_due(self.config.interval):\n            return\n\n        # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the\n        # evaluation.\n        with scoped_seed(self.config.seed):\n            trainer.compute_evaluation()\n\n\nclass AutoencoderConfig(BaseConfig):\n    num_images: int = 2048\n    batch_size: 
int = 32\n    autoencoder: AutoencoderModelConfig\n    evaluation: EvaluationConfig\n    logging: CallbackConfig = CallbackConfig()\n\n\nautoencoder_config = AutoencoderModelConfig(\n    requires_grad=True,  # set during registration to set the requires_grad attribute of the model.\n)\n\ntraining = TrainingConfig(\n    duration=Epoch(200),\n    device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n    dtype=\"float32\",\n)\n\noptimizer = OptimizerConfig(\n    optimizer=Optimizers.AdamW,\n    learning_rate=1e-4,\n)\n\nlr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)\n\nconfig = AutoencoderConfig(\n    training=training,\n    optimizer=optimizer,\n    lr_scheduler=lr_scheduler,\n    autoencoder=autoencoder_config,\n    evaluation=EvaluationConfig(interval=Epoch(50), seed=0),\n    clock=ClockConfig(verbose=False),  # to disable the default clock logging\n)\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n    def create_data_iterable(self) -> list[Batch]:\n        dataset: list[Batch] = []\n        generator = generate_mask(size=64)\n\n        for _ in range(self.config.num_images // self.config.batch_size):\n            masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)]\n            dataset.append(Batch(image=torch.cat(masks, dim=0)))\n\n        return dataset\n\n    @register_model()\n    def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:\n        return Autoencoder()\n\n    def compute_loss(self, batch: Batch) -> torch.Tensor:\n        batch.image = batch.image.to(self.device, self.dtype)\n        x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))\n        return F.binary_cross_entropy(x_reconstructed, batch.image)\n\n    def compute_evaluation(self) -> None:\n        generator = generate_mask(size=64, seed=0)\n\n        grid: list[tuple[Image.Image, Image.Image]] = []\n        validation_losses: list[float] = []\n        for _ in 
range(4):\n            mask = next(generator).to(self.device, self.dtype)\n            x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(mask))\n            loss = F.mse_loss(x_reconstructed, mask)\n            validation_losses.append(loss.detach().cpu().item())\n            grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed > 0.5).float())))\n\n        mean_loss = sum(validation_losses) / len(validation_losses)\n        logger.info(f\"Mean validation loss: {mean_loss}, epoch: {self.clock.epoch}\")\n\n        import matplotlib.pyplot as plt\n\n        _, axes = plt.subplots(4, 2, figsize=(8, 16))  # type: ignore\n\n        for i, (mask, reconstructed) in enumerate(grid):\n            axes[i, 0].imshow(mask, cmap=\"gray\")\n            axes[i, 0].axis(\"off\")\n            axes[i, 0].set_title(\"Mask\")\n\n            axes[i, 1].imshow(reconstructed, cmap=\"gray\")\n            axes[i, 1].axis(\"off\")\n            axes[i, 1].set_title(\"Reconstructed\")\n\n        plt.tight_layout()  # type: ignore\n        plt.savefig(f\"result_{self.clock.epoch}.png\")  # type: ignore\n        plt.close()  # type: ignore\n\n    @register_callback()\n    def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:\n        return EvaluationCallback(config)\n\n    @register_callback()\n    def logging(self, config: CallbackConfig) -> LoggingCallback:\n        return LoggingCallback()\n\n\ntrainer = AutoencoderTrainer(config)\n\ntrainer.train()\n
"},{"location":"home/why/","title":"Why Refiners?","text":""},{"location":"home/why/#pytorch-an-imperative-framework","title":"PyTorch: an imperative framework","text":"

PyTorch is a great framework to implement deep learning models, widely adopted in academia and industry around the globe. A core design principle of PyTorch is that users write imperative Python code that manipulates Tensors [1]. This code can be organized in Modules, which are just Python classes whose constructors typically initialize parameters and load weights, and which implement a forward method that computes the forward pass. Reconstructing an inference graph, backpropagation and so on are left to the framework.

This approach works very well in general, as demonstrated by the popularity of PyTorch. However, the growing importance of the Adaptation pattern is challenging it.

"},{"location":"home/why/#adaptation-patching-foundation-models","title":"Adaptation: patching foundation models","text":"

Adaptation is the idea of patching existing powerful models to implement new capabilities. Those models are called foundation models; they are typically trained from scratch on amounts of data inaccessible to most individuals, small companies or research labs, and exhibit emergent properties. Examples of such models are LLMs (GPT, LLaMa, Mistral), image generation models (Stable Diffusion, Muse), vision models (BLIP-2, LLaVA 1.5, Fuyu-8B) but also models trained on more specific tasks such as embedding extraction (CLIP, DINOv2) or image segmentation (SAM).

Adaptation of foundation models can take many forms. One of the simplest but most powerful derives from fine-tuning: re-training a subset of the weights of the model on a specific task, then distributing only those weights. Add to this a trick to significantly reduce the size of the fine-tuned weights and you get LoRA [2], which is probably the best-known adaptation method. However, adaptation can go beyond that and change the shape of the model or its inputs.

"},{"location":"home/why/#imperative-code-is-hard-to-patch-cleanly","title":"Imperative code is hard to patch cleanly","text":"

There are several approaches to patch the code of a foundation model implemented in typical PyTorch imperative style to support adaptation, including:

As believers in adaptation, we found none of those approaches appealing, so we designed Refiners as a better option. Refiners is a micro-framework built on top of PyTorch which does away with its imperative style. In Refiners, models are implemented in a declarative way instead, which makes them by nature easier to manipulate and patch.

"},{"location":"home/why/#whats-next","title":"What's next?","text":"

Now that you know why we wrote a declarative framework, you can check out how. It's not that complicated, we promise!

  1. Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.

  2. Hu et al., 2022. LoRA: Low-Rank Adaptation of Large Language Models.

"},{"location":"reference/SUMMARY/","title":"SUMMARY","text":""},{"location":"reference/fluxion/adapters/","title":" Adapters","text":""},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter","title":"Adapter","text":"

Bases: Generic[T]

Base class for adapters.

An Adapter modifies the structure of a Module (typically by adding, removing or replacing layers), to adapt it to a new task.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.target","title":"target property","text":"
target: T\n

The target of the adapter.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.eject","title":"eject","text":"
eject() -> None\n

Eject the adapter.

This method is the inverse of inject, and should leave the target in the same state as before the injection.

Source code in src/refiners/fluxion/adapters/adapter.py
def eject(self) -> None:\n    \"\"\"Eject the adapter.\n\n    This method is the inverse of [`inject`][refiners.fluxion.adapters.Adapter.inject],\n    and should leave the target in the same state as before the injection.\n    \"\"\"\n    assert isinstance(self, fl.Chain)\n\n    # In general, the \"actual target\" is the target.\n    # Here we deal with the edge case where the target\n    # is part of the replacement block and has been adapted by\n    # another adapter after this one. For instance, this is the\n    # case when stacking Controlnets.\n    actual_target = lookup_top_adapter(self, self.target)\n\n    if (parent := self.parent) is None:\n        if isinstance(actual_target, fl.ContextModule):\n            actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]\n    else:\n        parent.replace(old_module=self, new_module=actual_target)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.inject","title":"inject","text":"
inject(parent: Chain | None = None) -> TAdapter\n

Inject the adapter.

This method replaces the target of the adapter by the adapter inside the parent of the target.

Parameters:

parent (Chain | None, default: None): The parent to inject the adapter into, if the target doesn't have a parent.

Source code in src/refiners/fluxion/adapters/adapter.py
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:\n    \"\"\"Inject the adapter.\n\n    This method replaces the target of the adapter by the adapter inside the parent of the target.\n\n    Args:\n        parent: The parent to inject the adapter into, if the target doesn't have a parent.\n    \"\"\"\n    assert isinstance(self, fl.Chain)\n\n    if (parent is None) and isinstance(self.target, fl.ContextModule):\n        parent = self.target.parent\n        if parent is not None:\n            assert isinstance(parent, fl.Chain), f\"{self.target} has invalid parent {parent}\"\n\n    target_parent = self.find_parent(self.target)\n\n    if parent is None:\n        if isinstance(self.target, fl.ContextModule):\n            self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]\n        return self\n\n    # In general, `true_parent` is `parent`. We do this to support multiple adaptation,\n    # i.e. initializing two adapters before injecting them.\n    true_parent = parent.ensure_find_parent(self.target)\n    true_parent.replace(\n        old_module=self.target,\n        new_module=self,\n        old_module_parent=target_parent,\n    )\n    return self\n
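
The inject and eject round trip can be pictured with a toy tree. The sketch below is not the Refiners implementation: Leaf, ToyChain and ToyAdapter are hypothetical stand-ins built on plain Python lists. They only mirror the idea that inject swaps the target for the adapter inside the parent, and that eject swaps it back.

```python
class Leaf:
    '''Hypothetical leaf module, identified by a name.'''

    def __init__(self, name):
        self.name = name


class ToyChain:
    '''Hypothetical stand-in for a Chain: an ordered list of children.'''

    def __init__(self, *children):
        self.children = list(children)

    def replace(self, old_module, new_module):
        # Swap one child for another, keeping its position.
        index = next(i for i, child in enumerate(self.children) if child is old_module)
        self.children[index] = new_module


class ToyAdapter(ToyChain):
    '''Hypothetical adapter: wraps its target and can take its place.'''

    def __init__(self, target):
        super().__init__(target)
        self.target = target
        self.parent = None

    def inject(self, parent):
        parent.replace(old_module=self.target, new_module=self)
        self.parent = parent
        return self

    def eject(self):
        self.parent.replace(old_module=self, new_module=self.target)
        self.parent = None


linear = Leaf('linear')
model = ToyChain(Leaf('conv'), linear, Leaf('softmax'))

adapter = ToyAdapter(linear).inject(model)
injected = [type(child).__name__ for child in model.children]

adapter.eject()
ejected = [type(child).__name__ for child in model.children]
```

After inject, the adapter sits where the target used to be. After eject, the parent holds the original target object again, which is the property the real eject is documented to preserve.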
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.setup_adapter","title":"setup_adapter","text":"
setup_adapter(target: T) -> Iterator[None]\n

Setup the adapter.

This method should be called by the constructor of the adapter. It sets the target of the adapter and ensures that the adapter is not a submodule of the target.

Parameters:

target (T, required): The target of the adapter.

Source code in src/refiners/fluxion/adapters/adapter.py
@contextlib.contextmanager\ndef setup_adapter(self, target: T) -> Iterator[None]:\n    \"\"\"Setup the adapter.\n\n    This method should be called by the constructor of the adapter.\n    It sets the target of the adapter and ensures that the adapter\n    is not a submodule of the target.\n\n    Args:\n        target: The target of the adapter.\n    \"\"\"\n    assert isinstance(self, fl.Chain)\n    assert (not hasattr(self, \"_modules\")) or (\n        len(self) == 0\n    ), \"Call the Chain constructor in the setup_adapter context.\"\n    self._target = [target]\n\n    if isinstance(target, fl.ContextModule):\n        assert isinstance(target, fl.ContextModule)\n        with target.no_parent_refresh():\n            yield\n    else:\n        yield\n
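
As a rough illustration of the constructor pattern described above, here is a minimal sketch with a hypothetical ToyAdapter class. It assumes nothing about Refiners beyond what the source shows: the target is stored wrapped in a list, and the chain must still be empty when the context is entered.

```python
import contextlib


class ToyAdapter:
    '''Hypothetical adapter showing the constructor pattern only.'''

    def __init__(self, target):
        self.children = []
        with self.setup_adapter(target):
            # Stands in for the Chain constructor call made inside the context.
            self.children = [target, 'extra_layer']

    @contextlib.contextmanager
    def setup_adapter(self, target):
        assert len(self.children) == 0, 'Call the constructor inside the context.'
        # Stored wrapped in a list, as in the source above.
        self._target = [target]
        yield

    @property
    def target(self):
        return self._target[0]


adapter = ToyAdapter('linear')
```

The context manager gives one place to record the target before the children are built, so every adapter constructor follows the same order.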
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Conv2dLora","title":"Conv2dLora","text":"
Conv2dLora(\n    name: str,\n    /,\n    in_channels: int,\n    out_channels: int,\n    rank: int = 16,\n    scale: float = 1.0,\n    kernel_size: tuple[int, int] = (1, 3),\n    stride: tuple[int, int] = (1, 1),\n    padding: tuple[int, int] = (0, 1),\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Lora[Conv2d]

Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.

This layer uses two Conv2d layers as its down and up layers.

Parameters:

name (str, required): The name of the LoRA.
in_channels (int, required): The number of input channels.
out_channels (int, required): The number of output channels.
rank (int, default: 16): The rank of the LoRA.
scale (float, default: 1.0): The scale of the LoRA.
kernel_size (tuple[int, int], default: (1, 3)): The kernel size of the LoRA.
stride (tuple[int, int], default: (1, 1)): The stride of the LoRA.
padding (tuple[int, int], default: (0, 1)): The padding of the LoRA.
device (device | str | None, default: None): The device of the LoRA weights.
dtype (dtype | None, default: None): The dtype of the LoRA weights.

Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n    self,\n    name: str,\n    /,\n    in_channels: int,\n    out_channels: int,\n    rank: int = 16,\n    scale: float = 1.0,\n    kernel_size: tuple[int, int] = (1, 3),\n    stride: tuple[int, int] = (1, 1),\n    padding: tuple[int, int] = (0, 1),\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the LoRA layer.\n\n    Args:\n        name: The name of the LoRA.\n        in_channels: The number of input channels.\n        out_channels: The number of output channels.\n        rank: The rank of the LoRA.\n        scale: The scale of the LoRA.\n        kernel_size: The kernel size of the LoRA.\n        stride: The stride of the LoRA.\n        padding: The padding of the LoRA.\n        device: The device of the LoRA weights.\n        dtype: The dtype of the LoRA weights.\n    \"\"\"\n    self.in_channels = in_channels\n    self.out_channels = out_channels\n    self.kernel_size = kernel_size\n    self.stride = stride\n    self.padding = padding\n\n    super().__init__(\n        name,\n        rank=rank,\n        scale=scale,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LinearLora","title":"LinearLora","text":"
LinearLora(\n    name: str,\n    /,\n    in_features: int,\n    out_features: int,\n    rank: int = 16,\n    scale: float = 1.0,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Lora[Linear]

Low-Rank Adaptation (LoRA) layer for linear layers.

This layer uses two Linear layers as its down and up layers.

Parameters:

name (str, required): The name of the LoRA.
in_features (int, required): The number of input features.
out_features (int, required): The number of output features.
rank (int, default: 16): The rank of the LoRA.
scale (float, default: 1.0): The scale of the LoRA.
device (device | str | None, default: None): The device of the LoRA weights.
dtype (dtype | None, default: None): The dtype of the LoRA weights.

Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n    self,\n    name: str,\n    /,\n    in_features: int,\n    out_features: int,\n    rank: int = 16,\n    scale: float = 1.0,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the LoRA layer.\n\n    Args:\n        name: The name of the LoRA.\n        in_features: The number of input features.\n        out_features: The number of output features.\n        rank: The rank of the LoRA.\n        scale: The scale of the LoRA.\n        device: The device of the LoRA weights.\n        dtype: The dtype of the LoRA weights.\n    \"\"\"\n    self.in_features = in_features\n    self.out_features = out_features\n\n    super().__init__(\n        name,\n        rank=rank,\n        scale=scale,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora","title":"Lora","text":"
Lora(\n    name: str,\n    /,\n    rank: int = 16,\n    scale: float = 1.0,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Generic[T], Chain, ABC

Low-Rank Adaptation (LoRA) layer.

This layer's purpose is to approximate a given layer by two smaller layers: the down layer (aka A) and the up layer (aka B). See [arXiv:2106.09685] LoRA: Low-Rank Adaptation of Large Language Models for more details.
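
To make the size reduction concrete, the sketch below counts weights for a dense layer and for a down/up pair of a given rank. The 768 by 768 layer is an illustrative assumption, the rank of 16 is the documented default, and the helper names are hypothetical.

```python
def full_params(in_features, out_features):
    # Weight count of a dense layer, bias excluded.
    return in_features * out_features


def lora_params(in_features, out_features, rank):
    # Down layer maps in_features -> rank, up layer maps rank -> out_features.
    return rank * in_features + out_features * rank


full = full_params(768, 768)
low_rank = lora_params(768, 768, rank=16)
ratio = low_rank / full
```

With these numbers the pair holds 24,576 weights against 589,824 for the dense layer, about 4% of the size.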

Note

This layer is not meant to be used directly. Instead, use one of its subclasses: LinearLora or Conv2dLora.

Parameters:

name (str, required): The name of the LoRA.
rank (int, default: 16): The rank of the LoRA.
scale (float, default: 1.0): The scale of the LoRA.
device (device | str | None, default: None): The device of the LoRA weights.
dtype (dtype | None, default: None): The dtype of the LoRA weights.

Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n    self,\n    name: str,\n    /,\n    rank: int = 16,\n    scale: float = 1.0,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the LoRA layer.\n\n    Args:\n        name: The name of the LoRA.\n        rank: The rank of the LoRA.\n        scale: The scale of the LoRA.\n        device: The device of the LoRA weights.\n        dtype: The dtype of the LoRA weights.\n    \"\"\"\n    self.name = name\n    self._rank = rank\n    self._scale = scale\n\n    super().__init__(\n        *self.lora_layers(device=device, dtype=dtype),\n        fl.Multiply(scale),\n    )\n    self.reset_parameters()\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.down","title":"down property","text":"
down: T\n

The down layer.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.rank","title":"rank property","text":"
rank: int\n

The rank of the low-rank approximation.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.scale","title":"scale property writable","text":"
scale: float\n

The scale of the low-rank approximation.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.up","title":"up property","text":"
up: T\n

The up layer.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.from_dict","title":"from_dict classmethod","text":"
from_dict(\n    name: str, /, state_dict: dict[str, Tensor]\n) -> dict[str, Lora[Any]]\n

Create a dictionary of LoRA layers from a state dict.

Expects the state dict to be a succession of down and up weights.

Source code in src/refiners/fluxion/adapters/lora.py
@classmethod\ndef from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, \"Lora[Any]\"]:\n    \"\"\"\n    Create a dictionary of LoRA layers from a state dict.\n\n    Expects the state dict to be a succession of down and up weights.\n    \"\"\"\n    state_dict = {k: v for k, v in state_dict.items() if \".weight\" in k}\n    loras: dict[str, Lora[Any]] = {}\n    for down_key, down_tensor, up_tensor in zip(\n        list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]\n    ):\n        key = \".\".join(down_key.split(\".\")[:-2])\n        loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)\n    return loras\n
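
The pairing logic in from_dict relies on dictionary order: even positions are read as down weights and odd positions as up weights. The sketch below replays that logic on a made-up state dict. The key names are hypothetical, and strings stand in for tensors.

```python
state_dict = {
    'unet.attn.to_q.lora_down.weight': 'down_q',
    'unet.attn.to_q.lora_up.weight': 'up_q',
    'unet.attn.to_k.lora_down.weight': 'down_k',
    'unet.attn.to_k.lora_up.weight': 'up_k',
    'unet.attn.to_k.alpha': 'ignored',
}

# Keep weight entries only, as the method above does.
weights = {k: v for k, v in state_dict.items() if '.weight' in k}

keys = list(weights.keys())
values = list(weights.values())
pairs = {}
for down_key, down, up in zip(keys[::2], values[::2], values[1::2]):
    # Drop the last two dotted components (here lora_down and weight).
    key = '.'.join(down_key.split('.')[:-2])
    pairs[key] = (down, up)
```

Because nothing checks the key names, a state dict that lists an up weight before its down weight would be paired the wrong way round without any error. That is why the docstring asks for a strict succession of down and up weights.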
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.load_weights","title":"load_weights","text":"
load_weights(\n    down_weight: Tensor, up_weight: Tensor\n) -> None\n

Load the (pre-trained) weights of the LoRA.

Parameters:

down_weight (Tensor, required): The down weight.
up_weight (Tensor, required): The up weight.

Source code in src/refiners/fluxion/adapters/lora.py
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:\n    \"\"\"Load the (pre-trained) weights of the LoRA.\n\n    Args:\n        down_weight: The down weight.\n        up_weight: The up weight.\n    \"\"\"\n    assert down_weight.shape == self.down.weight.shape\n    assert up_weight.shape == self.up.weight.shape\n    self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))\n    self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.lora_layers","title":"lora_layers abstractmethod","text":"
lora_layers(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n) -> tuple[T, T]\n

Create the down and up layers of the LoRA.

Parameters:

device (device | str | None, default: None): The device of the LoRA weights.
dtype (dtype | None, default: None): The dtype of the LoRA weights.

Source code in src/refiners/fluxion/adapters/lora.py
@abstractmethod\ndef lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:\n    \"\"\"Create the down and up layers of the LoRA.\n\n    Args:\n        device: The device of the LoRA weights.\n        dtype: The dtype of the LoRA weights.\n    \"\"\"\n    ...\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.reset_parameters","title":"reset_parameters","text":"
reset_parameters() -> None\n

Reset the parameters of up and down layers.

Source code in src/refiners/fluxion/adapters/lora.py
def reset_parameters(self) -> None:\n    \"\"\"Reset the parameters of up and down layers.\"\"\"\n    normal_(tensor=self.down.weight, std=1 / self.rank)\n    zeros_(tensor=self.up.weight)\n
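
One consequence of this initialization is that a freshly reset LoRA contributes nothing: since the up weight is all zeros, the product of up and down is zero whatever the down weight holds. The sketch below checks this with plain Python lists. The shapes are arbitrary assumptions and matmul is a hypothetical helper.

```python
import random

rank, in_features, out_features = 4, 6, 5
rng = random.Random(0)

# Down weight: rank x in_features, drawn from a normal with std = 1 / rank.
down = [[rng.gauss(0.0, 1 / rank) for _ in range(in_features)] for _ in range(rank)]
# Up weight: out_features x rank, all zeros.
up = [[0.0] * rank for _ in range(out_features)]


def matmul(a, b):
    # Naive matrix product of two lists of rows.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Effective weight update applied by the pair: up @ down.
delta = matmul(up, down)
```

Every entry of delta is zero, so the adapted layer initially computes what the target computes, and the update grows from there as the up weight is trained.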
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter","title":"LoraAdapter","text":"
LoraAdapter(target: WeightedModule, /, *loras: Lora[Any])\n

Bases: Sum, Adapter[WeightedModule]

Adapter for LoRA layers.

This adapter simply sums the target layer with the given LoRA layers.

Parameters:

target (WeightedModule, required): The target layer.
loras (Lora[Any], default: ()): The LoRA layers.

Source code in src/refiners/fluxion/adapters/lora.py
def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:\n    \"\"\"Initialize the adapter.\n\n    Args:\n        target: The target layer.\n        loras: The LoRA layers.\n    \"\"\"\n    with self.setup_adapter(target):\n        super().__init__(target, *loras)\n
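
The summation can be sketched with scalars in place of tensors. Everything below is a hypothetical stand-in: target plays the adapted layer, and make_lora builds a branch that applies a down factor, an up factor, then the scale.

```python
def target(x):
    # Stand-in for the adapted layer.
    return 2.0 * x


def make_lora(down, up, scale):
    # Scalar stand-in for a LoRA branch: down, then up, then the scale.
    return lambda x: scale * (up * (down * x))


loras = [make_lora(down=0.5, up=4.0, scale=1.0), make_lora(down=1.0, up=1.0, scale=0.5)]


def adapted(x):
    # Sum of the target output and every LoRA branch output.
    return target(x) + sum(lora(x) for lora in loras)


y = adapted(3.0)
```

For an input of 3.0 the target yields 6.0 and the two branches yield 6.0 and 1.5, so the adapted output is 13.5. Setting a branch scale to 0 removes its contribution without touching the target.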
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.lora_layers","title":"lora_layers property","text":"
lora_layers: Iterator[Lora[Any]]\n

The LoRA layers.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.loras","title":"loras property","text":"
loras: dict[str, Lora[Any]]\n

The LoRA layers indexed by name.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.names","title":"names property","text":"
names: list[str]\n

The names of the LoRA layers.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.scales","title":"scales property","text":"
scales: dict[str, float]\n

The scales of the LoRA layers indexed by names.

"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.add_lora","title":"add_lora","text":"
add_lora(lora: Lora[Any]) -> None\n

Add a LoRA layer to the adapter.

Raises:

AssertionError: If the adapter already contains a LoRA layer with the same name.

Parameters:

lora (Lora[Any], required): The LoRA layer to add.

Source code in src/refiners/fluxion/adapters/lora.py
def add_lora(self, lora: Lora[Any], /) -> None:\n    \"\"\"Add a LoRA layer to the adapter.\n\n    Raises:\n        AssertionError: If the adapter already contains a LoRA layer with the same name.\n\n    Args:\n        lora: The LoRA layer to add.\n    \"\"\"\n    assert lora.name not in self.names, f\"LoRA layer with name {lora.name} already exists\"\n    self.append(lora)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.remove_lora","title":"remove_lora","text":"
remove_lora(name: str) -> Lora[Any] | None\n

Remove a LoRA layer from the adapter.

Note

If the adapter doesn't contain a LoRA layer with the given name, nothing happens and None is returned.

Parameters:

name (str, required): The name of the LoRA layer to remove.

Source code in src/refiners/fluxion/adapters/lora.py
def remove_lora(self, name: str, /) -> Lora[Any] | None:\n    \"\"\"Remove a LoRA layer from the adapter.\n\n    Note:\n        If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.\n\n    Args:\n        name: The name of the LoRA layer to remove.\n    \"\"\"\n    if name in self.names:\n        lora = self.loras[name]\n        self.remove(lora)\n        return lora\n
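
The contract of add_lora and remove_lora (unique names on add, a silent None on a missing name) can be restated with a dictionary. ToyLoraAdapter below is a hypothetical container, not the Refiners class, and it takes the name as an explicit argument for brevity.

```python
class ToyLoraAdapter:
    '''Hypothetical container mirroring the add/remove contract only.'''

    def __init__(self):
        self.loras = {}

    def add_lora(self, name, lora):
        assert name not in self.loras, f'LoRA layer with name {name} already exists'
        self.loras[name] = lora

    def remove_lora(self, name):
        # Returns the removed layer, or None when the name is unknown.
        return self.loras.pop(name, None)


adapter = ToyLoraAdapter()
adapter.add_lora('style', object())
removed = adapter.remove_lora('style')
missing = adapter.remove_lora('style')

try:
    adapter.add_lora('pose', 1)
    adapter.add_lora('pose', 2)
    duplicate_rejected = False
except AssertionError:
    duplicate_rejected = True
```

Since the duplicate check is an assert, it disappears when Python runs with assertions disabled, so callers that need a hard guarantee may want to check the names themselves first.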
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.auto_attach_loras","title":"auto_attach_loras","text":"
auto_attach_loras(\n    loras: dict[str, Lora[Any]],\n    target: Chain,\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    sanity_check: bool = True,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> list[str]\n

Auto-attach several LoRA layers to a Chain.

Parameters:

loras (dict[str, Lora[Any]], required): A dictionary of LoRA layers associated to their respective key. The keys are typically derived from the state dict and only used for debug_map and the return value.
target (Chain, required): The target Chain.
include (list[str] | None, default: None): A list of layer names, only layers with such a layer in their ancestors will be considered.
exclude (list[str] | None, default: None): A list of layer names, layers with such a layer in their ancestors will not be considered.
sanity_check (bool, default: True): Check that LoRAs passed are correctly attached.
debug_map (list[tuple[str, str]] | None, default: None): Pass a list to get a debug mapping of key - path pairs of attached points.

Returns: A list of keys of LoRA layers which failed to attach.

Source code in src/refiners/fluxion/adapters/lora.py
def auto_attach_loras(\n    loras: dict[str, Lora[Any]],\n    target: fl.Chain,\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    sanity_check: bool = True,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> list[str]:\n    \"\"\"Auto-attach several LoRA layers to a Chain.\n\n    Args:\n        loras: A dictionary of LoRA layers associated to their respective key. The keys are typically\n            derived from the state dict and only used for `debug_map` and the return value.\n        target: The target Chain.\n        include: A list of layer names, only layers with such a layer in their ancestors will be considered.\n        exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.\n        sanity_check: Check that LoRAs passed are correctly attached.\n        debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.\n    Returns:\n        A list of keys of LoRA layers which failed to attach.\n    \"\"\"\n\n    if not sanity_check:\n        return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)\n\n    loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}\n    debug_map_1: list[tuple[str, str]] = []\n    failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)\n    if debug_map is not None:\n        debug_map += debug_map_1\n    if len(debug_map_1) != len(loras) or failed_keys_1:\n        raise ValueError(\n            f\"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed\"\n        )\n\n    # Extra sanity check: if we re-run the attach, all layers should fail.\n    debug_map_2: list[tuple[str, str]] = []\n    failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)\n    if debug_map_2 or 
len(failed_keys_2) != len(loras):\n        raise ValueError(\n            f\"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped\"\n        )\n\n    return failed_keys_1\n
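
The two-pass sanity check can be sketched independently of how attachment really works. In the toy version below, toy_attach is a hypothetical attach step in which a key attaches once to a free slot of the same name. The actual matching rules of the private helper are not shown in this reference, so they are not reproduced here.

```python
def toy_attach(loras, slots, debug_map):
    '''Hypothetical attach step: a key attaches once to a free slot of the same name.'''
    failed = []
    for key in loras:
        if key in slots and slots[key] is None:
            slots[key] = loras[key]
            debug_map.append((key, f'model.{key}'))
        else:
            failed.append(key)
    return failed


def attach_with_sanity_check(loras, slots):
    copies = dict(loras)
    first_map = []
    failed_first = toy_attach(loras, slots, first_map)
    if len(first_map) != len(loras) or failed_first:
        raise ValueError('sanity check failed: some layers did not attach')
    # Re-running the attach on copies must attach nothing.
    second_map = []
    failed_second = toy_attach(copies, slots, second_map)
    if second_map or len(failed_second) != len(loras):
        raise ValueError('sanity check failed: some layers attached twice')
    return failed_first


slots = {'to_q': None, 'to_k': None}
ok = attach_with_sanity_check({'to_q': 'lora_q', 'to_k': 'lora_k'}, slots)

try:
    attach_with_sanity_check({'to_v': 'lora_v'}, {'to_q': None})
    raised = False
except ValueError:
    raised = True
```

The first pass verifies that every layer found a home, and the second verifies that no attachment point accepts the same layer twice. Note that with the check enabled, a partial failure raises instead of returning the failed keys.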
"},{"location":"reference/fluxion/context/","title":" Context","text":""},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider","title":"ContextProvider","text":"
ContextProvider()\n

A class that provides a context store.

Source code in src/refiners/fluxion/context.py
def __init__(self) -> None:\n    \"\"\"Initializes the ContextProvider.\"\"\"\n    self.contexts: Contexts = {}\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.create","title":"create staticmethod","text":"
create(contexts: Contexts) -> ContextProvider\n

Create a ContextProvider from a dict of contexts.

Parameters:

contexts (Contexts, required): The contexts.

Returns:

ContextProvider: A ContextProvider with the contexts.

Source code in src/refiners/fluxion/context.py
@staticmethod\ndef create(contexts: Contexts) -> \"ContextProvider\":\n    \"\"\"Create a ContextProvider from a dict of contexts.\n\n    Args:\n        contexts: The contexts.\n\n    Returns:\n        A ContextProvider with the contexts.\n    \"\"\"\n    provider = ContextProvider()\n    provider.update_contexts(contexts)\n    return provider\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.get_context","title":"get_context","text":"
get_context(key: str) -> Any\n

Retrieve a value from the context.

Parameters:

key (str, required): The key of the context.

Returns:

Any: The context value.

Source code in src/refiners/fluxion/context.py
def get_context(self, key: str) -> Any:\n    \"\"\"Retrieve a value from the context.\n\n    Args:\n        key: The key of the context.\n\n    Returns:\n        The context value.\n    \"\"\"\n    return self.contexts.get(key)\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.set_context","title":"set_context","text":"
set_context(key: str, value: Context) -> None\n

Store a value in the context.

Parameters:

key (str, required): The key of the context.
value (Context, required): The context.

Source code in src/refiners/fluxion/context.py
def set_context(self, key: str, value: Context) -> None:\n    \"\"\"Store a value in the context.\n\n    Args:\n        key: The key of the context.\n        value: The context.\n    \"\"\"\n    self.contexts[key] = value\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.update_contexts","title":"update_contexts","text":"
update_contexts(new_contexts: Contexts) -> None\n

Update or set the contexts with new contexts.

Parameters:

new_contexts (Contexts, required): The new contexts.

Source code in src/refiners/fluxion/context.py
def update_contexts(self, new_contexts: Contexts) -> None:\n    \"\"\"Update or set the contexts with new contexts.\n\n    Args:\n        new_contexts: The new contexts.\n    \"\"\"\n    for key, value in new_contexts.items():\n        if key not in self.contexts:\n            self.contexts[key] = value\n        else:\n            self.contexts[key].update(value)\n
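
The merge rule is shallow and works per key: an unknown key is stored as given, while a known key has its dictionary updated in place. The sketch below restates that rule in a hypothetical ToyContextProvider, with made-up context names and values.

```python
class ToyContextProvider:
    '''Hypothetical re-statement of the merge rule shown above.'''

    def __init__(self):
        self.contexts = {}

    def update_contexts(self, new_contexts):
        for key, value in new_contexts.items():
            if key not in self.contexts:
                self.contexts[key] = value
            else:
                self.contexts[key].update(value)

    def get_context(self, key):
        # Like the source, a missing key yields None rather than an error.
        return self.contexts.get(key)


provider = ToyContextProvider()
provider.update_contexts({'diffusion': {'timestep': 10}})
provider.update_contexts({'diffusion': {'guidance': 7.5}, 'sampling': {'steps': 30}})

diffusion = provider.get_context('diffusion')
missing = provider.get_context('unknown')
```

The second call keeps timestep and adds guidance to the same context. Because a new context is stored without copying, later updates also mutate the dictionary object that the caller passed in first.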
"},{"location":"reference/fluxion/layers/","title":" Layers","text":""},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Activation","title":"Activation","text":"
Activation()\n

Bases: Module, ABC

Base class for activation layers.

Activation layers are layers that apply a (non-linear) function to their input.

Receives:

x (Tensor)

Returns:

Tensor

Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n    super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Attention","title":"Attention","text":"
Attention(\n    embedding_dim: int,\n    num_heads: int = 1,\n    key_embedding_dim: int | None = None,\n    value_embedding_dim: int | None = None,\n    inner_dim: int | None = None,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Multi-Head Attention layer.

See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details

This layer simply chains

Receives:

Query (Float[Tensor, 'batch sequence_length embedding_dim'])
Key (Float[Tensor, 'batch sequence_length embedding_dim'])
Value (Float[Tensor, 'batch sequence_length embedding_dim'])

Returns:

Float[Tensor, 'batch sequence_length embedding_dim']

Example:
attention = fl.Attention(num_heads=8, embedding_dim=128)\n\ntensor = torch.randn(2, 10, 128)\noutput = attention(tensor, tensor, tensor)\n\nassert output.shape == (2, 10, 128)\n

Parameters:

Name Type Description Default embedding_dim int

The embedding dimension of the input and output tensors.

required num_heads int

The number of heads to use.

1 key_embedding_dim int | None

The embedding dimension of the key tensor.

None value_embedding_dim int | None

The embedding dimension of the value tensor.

None inner_dim int | None

The inner dimension of the linear layers.

None use_bias bool

Whether to use bias in the linear layers.

True is_causal bool

Whether to use causal attention.

False is_optimized bool

Whether to use optimized attention.

True device device | str | None

The device to use.

None dtype dtype | None

The dtype to use.

None Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n    self,\n    embedding_dim: int,\n    num_heads: int = 1,\n    key_embedding_dim: int | None = None,\n    value_embedding_dim: int | None = None,\n    inner_dim: int | None = None,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the Attention layer.\n\n    Args:\n        embedding_dim: The embedding dimension of the input and output tensors.\n        num_heads: The number of heads to use.\n        key_embedding_dim: The embedding dimension of the key tensor.\n        value_embedding_dim: The embedding dimension of the value tensor.\n        inner_dim: The inner dimension of the linear layers.\n        use_bias: Whether to use bias in the linear layers.\n        is_causal: Whether to use causal attention.\n        is_optimized: Whether to use optimized attention.\n        device: The device to use.\n        dtype: The dtype to use.\n    \"\"\"\n    assert (\n        embedding_dim % num_heads == 0\n    ), f\"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}\"\n    self.embedding_dim = embedding_dim\n    self.num_heads = num_heads\n    self.heads_dim = embedding_dim // num_heads\n    self.key_embedding_dim = key_embedding_dim or embedding_dim\n    self.value_embedding_dim = value_embedding_dim or embedding_dim\n    self.inner_dim = inner_dim or embedding_dim\n    self.use_bias = use_bias\n    self.is_causal = is_causal\n    self.is_optimized = is_optimized\n\n    super().__init__(\n        Distribute(\n            Linear(  # Query projection\n                in_features=self.embedding_dim,\n                out_features=self.inner_dim,\n                bias=self.use_bias,\n                device=device,\n                dtype=dtype,\n            ),\n            Linear(  # Key projection\n                in_features=self.key_embedding_dim,\n                
out_features=self.inner_dim,\n                bias=self.use_bias,\n                device=device,\n                dtype=dtype,\n            ),\n            Linear(  # Value projection\n                in_features=self.value_embedding_dim,\n                out_features=self.inner_dim,\n                bias=self.use_bias,\n                device=device,\n                dtype=dtype,\n            ),\n        ),\n        ScaledDotProductAttention(\n            num_heads=num_heads,\n            is_causal=is_causal,\n            is_optimized=is_optimized,\n        ),\n        Linear(  # Output projection\n            in_features=self.inner_dim,\n            out_features=self.embedding_dim,\n            bias=True,\n            device=device,\n            dtype=dtype,\n        ),\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Breakpoint","title":"Breakpoint","text":"
Breakpoint(vscode: bool = True)\n

Bases: ContextModule

Breakpoint layer.

This layer pauses the execution when encountered, and opens a debugger.

Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, vscode: bool = True):\n    super().__init__()\n    self.vscode = vscode\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain","title":"Chain","text":"
Chain(*args: Module | Iterable[Module])\n

Bases: ContextModule

Chain layer.

This layer is the main building block of Fluxion. It is used to compose other layers in a sequential manner. Similarly to torch.nn.Sequential, it calls each of its sub-layers in order, chaining their outputs as inputs to the next sublayer. However, it also provides additional methods to manipulate its sub-layers and their context.

Example
chain = fl.Chain(\n    fl.Linear(32, 64),\n    fl.ReLU(),\n    fl.Linear(64, 128),\n)\n\ntensor = torch.randn(2, 32)\noutput = chain(tensor)\n\nassert output.shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.device","title":"device property","text":"
device: device | None\n

The PyTorch device of the Chain's parameters.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.dtype","title":"dtype property","text":"
dtype: dtype | None\n

The PyTorch dtype of the Chain's parameters.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.provider","title":"provider property","text":"
provider: ContextProvider\n

The ContextProvider of the Chain.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.append","title":"append","text":"
append(module: Module) -> None\n

Append a new module to the chain.

Parameters:

Name Type Description Default module Module

The module to append.

required Source code in src/refiners/fluxion/layers/chain.py
def append(self, module: Module) -> None:\n    \"\"\"Append a new module to the chain.\n\n    Args:\n        module: The module to append.\n    \"\"\"\n    self.insert(-1, module)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.ensure_find","title":"ensure_find","text":"
ensure_find(layer_type: type[T]) -> T\n

Walk the Chain's sub-module tree and return the first layer of the given type.

Parameters:

Name Type Description Default layer_type type[T]

The type of layer to find.

required

Returns:

Type Description T

The first module of the given layer_type.

Raises:

Type Description AssertionError

If the module doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def ensure_find(self, layer_type: type[T]) -> T:\n    \"\"\"Walk the Chain's sub-module tree and return the first layer of the given type.\n\n    Args:\n        layer_type: The type of layer to find.\n\n    Returns:\n        The first module of the given layer_type.\n\n    Raises:\n        AssertionError: If the module doesn't exist.\n    \"\"\"\n    r = self.find(layer_type)\n    assert r is not None, f\"could not find {layer_type} in {self}\"\n    return r\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.ensure_find_parent","title":"ensure_find_parent","text":"
ensure_find_parent(module: Module) -> Chain\n

Walk the Chain's sub-module tree and return the parent of the given module.

Parameters:

Name Type Description Default module Module

The module whose parent to find.

required

Returns:

Type Description Chain

The parent of the given module.

Raises:

Type Description AssertionError

If the module doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def ensure_find_parent(self, module: Module) -> \"Chain\":\n    \"\"\"Walk the Chain's sub-module tree and return the parent of the given module.\n\n    Args:\n        module: The module whose parent to find.\n\n    Returns:\n        The parent of the given module.\n\n    Raises:\n        AssertionError: If the module doesn't exist.\n    \"\"\"\n    r = self.find_parent(module)\n    assert r is not None, f\"could not find {module} in {self}\"\n    return r\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.find","title":"find","text":"
find(layer_type: type[T]) -> T | None\n

Walk the Chain's sub-module tree and return the first layer of the given type.

Parameters:

Name Type Description Default layer_type type[T]

The type of layer to find.

required

Returns:

Type Description T | None

The first module of the given layer_type, or None if it doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def find(self, layer_type: type[T]) -> T | None:\n    \"\"\"Walk the Chain's sub-module tree and return the first layer of the given type.\n\n    Args:\n        layer_type: The type of layer to find.\n\n    Returns:\n        The first module of the given layer_type, or None if it doesn't exist.\n    \"\"\"\n    return next(self.layers(layer_type=layer_type), None)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.find_parent","title":"find_parent","text":"
find_parent(module: Module) -> Chain | None\n

Walk the Chain's sub-module tree and return the parent of the given module.

Parameters:

Name Type Description Default module Module

The module whose parent to find.

required

Returns:

Type Description Chain | None

The parent of the given module, or None if it doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def find_parent(self, module: Module) -> \"Chain | None\":\n    \"\"\"Walk the Chain's sub-module tree and return the parent of the given module.\n\n    Args:\n        module: The module whose parent to find.\n\n    Returns:\n        The parent of the given module, or None if it doesn't exist.\n    \"\"\"\n    if module in self:  # avoid DFS-crawling the whole tree\n        return self\n    for _, parent in self.walk(lambda m, _: m == module):\n        return parent\n    return None\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.init_context","title":"init_context","text":"
init_context() -> Contexts\n

Initialize the context provider with some default values.

This method is called when the Chain is created, and when it is reset. This method may be overridden by subclasses to provide default values for the context provider.

Source code in src/refiners/fluxion/layers/chain.py
def init_context(self) -> Contexts:\n    \"\"\"Initialize the context provider with some default values.\n\n    This method is called when the Chain is created, and when it is reset.\n    This method may be overridden by subclasses to provide default values for the context provider.\n    \"\"\"\n    return {}\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert","title":"insert","text":"
insert(index: int, module: Module) -> None\n

Insert a new module in the chain.

Parameters:

Name Type Description Default index int

The index at which to insert the module.

required module Module

The module to insert.

required

Raises:

Type Description IndexError

If the index is out of range.

Source code in src/refiners/fluxion/layers/chain.py
def insert(self, index: int, module: Module) -> None:\n    \"\"\"Insert a new module in the chain.\n\n    Args:\n        index: The index at which to insert the module.\n        module: The module to insert.\n\n    Raises:\n        IndexError: If the index is out of range.\n    \"\"\"\n    if index < 0:\n        index = max(0, len(self._modules) + index + 1)\n    modules = list(self)\n    modules.insert(index, module)\n    self._regenerate_keys(modules)\n    if isinstance(module, ContextModule):\n        module._set_parent(self)\n    self._register_provider()\n
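
The handling of negative indices in the source above can be illustrated on a plain list. This is a sketch only: normalize_insert_index is a hypothetical helper that mirrors the first two statements of insert, and the list of strings stands in for the chain's modules.

```python
def normalize_insert_index(index: int, length: int) -> int:
    # Hypothetical helper mirroring the negative-index branch of insert.
    if index < 0:
        index = max(0, length + index + 1)
    return index


layers = ['conv', 'relu', 'linear']

# -1 resolves to len(layers), which is why append can call insert(-1, module).
layers.insert(normalize_insert_index(-1, len(layers)), 'softmax')

# A very negative index is clamped to 0 instead of being rejected.
layers.insert(normalize_insert_index(-10, len(layers)), 'input')

print(layers)
```

In this sketch an out-of-range index is clamped rather than raising, which is also how list.insert behaves for large positive indices.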
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert_after_type","title":"insert_after_type","text":"
insert_after_type(\n    module_type: type[Module], new_module: Module\n) -> None\n

Insert a new module in the chain, right after the first module of the given type.

Parameters:

Name Type Description Default module_type type[Module]

The type of module to insert after.

required new_module Module

The module to insert.

required

Raises:

Type Description ValueError

If no module of the given type exists in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:\n    \"\"\"Insert a new module in the chain, right after the first module of the given type.\n\n    Args:\n        module_type: The type of module to insert after.\n        new_module: The module to insert.\n\n    Raises:\n        ValueError: If no module of the given type exists in the chain.\n    \"\"\"\n    for i, module in enumerate(self):\n        if isinstance(module, module_type):\n            self.insert(i + 1, new_module)\n            return\n    raise ValueError(f\"No module of type {module_type.__name__} found in the chain.\")\n
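
To make the first-match behaviour concrete, here is a minimal sketch on a plain list. The Linear, ReLU and Dropout classes are empty placeholders defined for this example, and insert_after_type is a hypothetical free function that follows the loop shown above.

```python
class Linear: ...
class ReLU: ...
class Dropout: ...


def insert_after_type(modules: list, module_type: type, new_module: object) -> None:
    # Insert right after the first element of the given type, then stop.
    for i, module in enumerate(modules):
        if isinstance(module, module_type):
            modules.insert(i + 1, new_module)
            return
    raise ValueError(f'No module of type {module_type.__name__} found.')


stack = [Linear(), ReLU(), Linear(), ReLU()]
insert_after_type(stack, ReLU, Dropout())
names = [type(m).__name__ for m in stack]
print(names)
```

Only the first ReLU gets a Dropout after it; the second one is left untouched.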
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert_before_type","title":"insert_before_type","text":"
insert_before_type(\n    module_type: type[Module], new_module: Module\n) -> None\n

Insert a new module in the chain, right before the first module of the given type.

Parameters:

Name Type Description Default module_type type[Module]

The type of module to insert before.

required new_module Module

The module to insert.

required

Raises:

Type Description ValueError

If no module of the given type exists in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:\n    \"\"\"Insert a new module in the chain, right before the first module of the given type.\n\n    Args:\n        module_type: The type of module to insert before.\n        new_module: The module to insert.\n\n    Raises:\n        ValueError: If no module of the given type exists in the chain.\n    \"\"\"\n    for i, module in enumerate(self):\n        if isinstance(module, module_type):\n            self.insert(i, new_module)\n            return\n    raise ValueError(f\"No module of type {module_type.__name__} found in the chain.\")\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.layer","title":"layer","text":"
layer(\n    key: str | int | Sequence[str | int],\n    layer_type: type[T] = Module,\n) -> T\n

Access a layer of the Chain by key or path, asserting that it has the given type.

Example
# same as my_chain[\"Linear_2\"], asserts it is a Linear\nmy_chain.layer(\"Linear_2\", fl.Linear)\n\n\n# same as my_chain[3], asserts it is a Linear\nmy_chain.layer(3, fl.Linear)\n\n# probably won't work\nmy_chain.layer(\"Conv2d\", fl.Linear)\n\n\n# same as my_chain[\"foo\"][42][\"bar\"],\n# assuming bar is a MyType and all parents are Chains\nmy_chain.layer((\"foo\", 42, \"bar\"), fl.MyType)\n

Parameters:

Name Type Description Default key str | int | Sequence[str | int]

The key or path of the layer.

required layer_type type[T]

The type of the layer.

Module

Returns:

Type Description T

The layer.

Raises:

Type Description AssertionError

If the layer doesn't exist or the type is invalid.

Source code in src/refiners/fluxion/layers/chain.py
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:\n    \"\"\"Access a layer of the Chain given its type.\n\n    Example:\n        ```py\n        # same as my_chain[\"Linear_2\"], asserts it is a Linear\n        my_chain.layer(\"Linear_2\", fl.Linear)\n\n\n        # same as my_chain[3], asserts it is a Linear\n        my_chain.layer(3, fl.Linear)\n\n        # probably won't work\n        my_chain.layer(\"Conv2d\", fl.Linear)\n\n\n        # same as my_chain[\"foo\"][42][\"bar\"],\n        # assuming bar is a MyType and all parents are Chains\n        my_chain.layer((\"foo\", 42, \"bar\"), fl.MyType)\n        ```\n\n    Args:\n        key: The key or path of the layer.\n        layer_type: The type of the layer.\n\n    Yields:\n        The layer.\n\n    Raises:\n        AssertionError: If the layer doesn't exist or the type is invalid.\n    \"\"\"\n    if isinstance(key, (str, int)):\n        r = self[key]\n        assert isinstance(r, layer_type), f\"layer {key} is {type(r)}, not {layer_type}\"\n        return r\n    if len(key) == 0:\n        assert isinstance(self, layer_type), f\"layer is {type(self)}, not {layer_type}\"\n        return self\n    if len(key) == 1:\n        return self.layer(key[0], layer_type)\n    return self.layer(key[0], Chain).layer(key[1:], layer_type)\n
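
The recursive path resolution can be followed on a toy container. This is a sketch under stated assumptions: ToyChain and Leaf are hypothetical stand-ins, with children addressable by name or by position, and only the layer method is modelled on the source above.

```python
class ToyChain:
    # Hypothetical stand-in for Chain: children addressable by name or position.
    def __init__(self, **children):
        self.children = dict(children)

    def __getitem__(self, key):
        if isinstance(key, int):
            return list(self.children.values())[key]
        return self.children[key]

    def layer(self, key, layer_type=object):
        if isinstance(key, (str, int)):
            r = self[key]
            assert isinstance(r, layer_type), f'layer {key} is {type(r)}, not {layer_type}'
            return r
        if len(key) == 0:
            assert isinstance(self, layer_type)
            return self
        if len(key) == 1:
            return self.layer(key[0], layer_type)
        # Every intermediate step must be a chain; only the last one is type-checked.
        return self.layer(key[0], ToyChain).layer(key[1:], layer_type)


class Leaf:
    pass


leaf = Leaf()
model = ToyChain(encoder=ToyChain(block=ToyChain(proj=leaf)))
found = model.layer(('encoder', 0, 'proj'), Leaf)

try:
    model.layer('encoder', Leaf)  # wrong type: the assertion fails
    type_checked = False
except AssertionError:
    type_checked = True
```

Names and positions can be mixed freely in one path, and a type mismatch surfaces as an AssertionError.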
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.layers","title":"layers","text":"
layers(\n    layer_type: type[T], recurse: bool = False\n) -> Iterator[T]\n

Walk the Chain's sub-module tree and yield each layer of the given type.

Parameters:

Name Type Description Default layer_type type[T]

The type of layer to yield.

required recurse bool

Whether to recurse into sub-Chains.

False

Yields:

Type Description T

Each module of the given layer_type.

Source code in src/refiners/fluxion/layers/chain.py
def layers(\n    self,\n    layer_type: type[T],\n    recurse: bool = False,\n) -> Iterator[T]:\n    \"\"\"Walk the Chain's sub-module tree and yield each layer of the given type.\n\n    Args:\n        layer_type: The type of layer to yield.\n        recurse: Whether to recurse into sub-Chains.\n\n    Yields:\n        Each module of the given layer_type.\n    \"\"\"\n    for module, _ in self.walk(layer_type, recurse):\n        yield module\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.pop","title":"pop","text":"
pop(index: int = -1) -> Module\n

Pop a module from the chain at the given index.

Parameters:

Name Type Description Default index int

The index of the module to pop.

-1

Returns:

Type Description Module

The popped module.

Raises:

Type Description IndexError

If the index is out of range.

Source code in src/refiners/fluxion/layers/chain.py
def pop(self, index: int = -1) -> Module:\n    \"\"\"Pop a module from the chain at the given index.\n\n    Args:\n        index: The index of the module to pop.\n\n    Returns:\n        The popped module.\n\n    Raises:\n        IndexError: If the index is out of range.\n    \"\"\"\n    modules = list(self)\n    if index < 0:\n        index = len(modules) + index\n    if index < 0 or index >= len(modules):\n        raise IndexError(\"Index out of range.\")\n    removed_module = modules.pop(index)\n    if isinstance(removed_module, ContextModule):\n        removed_module._set_parent(None)\n    self._regenerate_keys(modules)\n    return removed_module\n
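
Unlike insert, the source above rejects out-of-range indices. A minimal sketch on a plain list, where pop_at is a hypothetical free function that copies the index checks:

```python
def pop_at(modules: list, index: int = -1):
    # Normalize a negative index, then reject anything outside the list.
    if index < 0:
        index = len(modules) + index
    if index < 0 or index >= len(modules):
        raise IndexError('Index out of range.')
    return modules.pop(index)


stack = ['conv', 'relu', 'linear']
last = pop_at(stack)      # default index -1 removes the last element
first = pop_at(stack, 0)

try:
    pop_at(stack, 5)
    out_of_range = False
except IndexError:
    out_of_range = True
```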
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.remove","title":"remove","text":"
remove(module: Module) -> None\n

Remove a module from the chain.

Parameters:

Name Type Description Default module Module

The module to remove.

required

Raises:

Type Description ValueError

If the module is not in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def remove(self, module: Module) -> None:\n    \"\"\"Remove a module from the chain.\n\n    Args:\n        module: The module to remove.\n\n    Raises:\n        ValueError: If the module is not in the chain.\n    \"\"\"\n    modules = list(self)\n    try:\n        modules.remove(module)\n    except ValueError:\n        raise ValueError(f\"{module} is not in {self}\")\n    self._regenerate_keys(modules)\n    if isinstance(module, ContextModule):\n        module._set_parent(None)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.replace","title":"replace","text":"
replace(\n    old_module: Module,\n    new_module: Module,\n    old_module_parent: Chain | None = None,\n) -> None\n

Replace a module in the chain with a new module.

Parameters:

Name Type Description Default old_module Module

The module to replace.

required new_module Module

The module to replace with.

required old_module_parent Chain | None

The parent of the old module. If None, the old module is orphaned.

None

Raises:

Type Description ValueError

If the module is not in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def replace(\n    self,\n    old_module: Module,\n    new_module: Module,\n    old_module_parent: \"Chain | None\" = None,\n) -> None:\n    \"\"\"Replace a module in the chain with a new module.\n\n    Args:\n        old_module: The module to replace.\n        new_module: The module to replace with.\n        old_module_parent: The parent of the old module.\n            If None, the old module is orphanized.\n\n    Raises:\n        ValueError: If the module is not in the chain.\n    \"\"\"\n    modules = list(self)\n    try:\n        modules[modules.index(old_module)] = new_module\n    except ValueError:\n        raise ValueError(f\"{old_module} is not in {self}\")\n    self._regenerate_keys(modules)\n    if isinstance(new_module, ContextModule):\n        new_module._set_parent(self)\n    if isinstance(old_module, ContextModule):\n        old_module._set_parent(old_module_parent)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.set_context","title":"set_context","text":"
set_context(context: str, value: Any) -> None\n

Set a value in the context provider.

Parameters:

Name Type Description Default context str

The context to update.

required value Any

The value to set.

required Source code in src/refiners/fluxion/layers/chain.py
def set_context(self, context: str, value: Any) -> None:\n    \"\"\"Set a value in the context provider.\n\n    Args:\n        context: The context to update.\n        value: The value to set.\n    \"\"\"\n    self._provider.set_context(context, value)\n    self._register_provider()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.structural_copy","title":"structural_copy","text":"
structural_copy() -> TChain\n

Copy the structure of the Chain tree.

This method returns a recursive copy of the Chain tree where all inner nodes (instances of Chain and its subclasses) are duplicated and all leaves (regular Modules) are not.

Such copies can be adapted without disrupting the base model, but do not require extra GPU memory since the weights are in the leaves and hence not copied.

Source code in src/refiners/fluxion/layers/chain.py
def structural_copy(self: TChain) -> TChain:\n    \"\"\"Copy the structure of the Chain tree.\n\n    This method returns a recursive copy of the Chain tree where all inner nodes\n    (instances of Chain and its subclasses) are duplicated and all leaves\n    (regular Modules) are not.\n\n    Such copies can be adapted without disrupting the base model, but do not\n    require extra GPU memory since the weights are in the leaves and hence not copied.\n    \"\"\"\n    if hasattr(self, \"_pre_structural_copy\"):\n        assert callable(self._pre_structural_copy)\n        self._pre_structural_copy()\n\n    modules = [structural_copy(m) for m in self]\n    clone = super().structural_copy()\n    clone._provider = ContextProvider.create(clone.init_context())\n\n    for module in modules:\n        clone.append(module=module)\n\n    if hasattr(clone, \"_post_structural_copy\"):\n        assert callable(clone._post_structural_copy)\n        clone._post_structural_copy(self)\n\n    return clone\n
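
The idea of duplicating inner nodes while sharing leaves can be shown without PyTorch. The sketch below uses hypothetical Node and Leaf classes and a free structural_copy function; it does not reproduce the context handling or the pre and post hooks of the real method.

```python
class Leaf:
    # Stands in for a module that owns weights.
    def __init__(self, name: str) -> None:
        self.name = name


class Node:
    # Stands in for a Chain: it only holds children.
    def __init__(self, *children) -> None:
        self.children = list(children)


def structural_copy(module):
    if isinstance(module, Node):
        # Inner nodes are rebuilt, so the copy can be restructured freely.
        return Node(*(structural_copy(child) for child in module.children))
    # Leaves are returned as-is, so their weights are shared, not copied.
    return module


weights = Leaf('linear')
base = Node(Node(weights))
clone = structural_copy(base)

# Adapting the clone leaves the base tree untouched.
clone.children[0].children.append(Leaf('adapter'))
```

After the append, the base tree still has a single leaf, while both trees refer to the same weights object.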
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.walk","title":"walk","text":"
walk(\n    predicate: (\n        type[T] | Callable[[Module, Chain], bool] | None\n    ) = None,\n    recurse: bool = False,\n) -> (\n    Iterator[tuple[T, Chain]]\n    | Iterator[tuple[Module, Chain]]\n)\n

Walk the Chain's sub-module tree and yield each module that matches the predicate.

Parameters:

Name Type Description Default predicate type[T] | Callable[[Module, Chain], bool] | None

The predicate to match.

None recurse bool

Whether to recurse into sub-Chains.

False

Yields:

Type Description Iterator[tuple[T, Chain]] | Iterator[tuple[Module, Chain]]

Each module that matches the predicate, paired with its parent Chain.

Source code in src/refiners/fluxion/layers/chain.py
def walk(\n    self,\n    predicate: type[T] | Callable[[Module, \"Chain\"], bool] | None = None,\n    recurse: bool = False,\n) -> Iterator[tuple[T, \"Chain\"]] | Iterator[tuple[Module, \"Chain\"]]:\n    \"\"\"Walk the Chain's sub-module tree and yield each module that matches the predicate.\n\n    Args:\n        predicate: The predicate to match.\n        recurse: Whether to recurse into sub-Chains.\n\n    Yields:\n        Each module that matches the predicate.\n    \"\"\"\n\n    if get_origin(predicate) is not None:\n        raise ValueError(f\"subscripted generics cannot be used as predicates\")\n\n    if isinstance(predicate, type):\n        # if the predicate is a Module type\n        # build a predicate function that matches the type\n        return self._walk(\n            predicate=lambda m, _: isinstance(m, predicate),\n            recurse=recurse,\n        )\n    else:\n        return self._walk(\n            predicate=predicate,\n            recurse=recurse,\n        )\n
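
The private _walk method is not listed here, so the traversal below is an assumption rather than the library's implementation: it always descends into non-matching containers, and descends into matching ones only when recurse is true. Node, Leaf, walk and find_parent are hypothetical names for this sketch, which shows how (module, parent) pairs make a parent lookup straightforward.

```python
from typing import Callable, Iterator


class Leaf:
    def __init__(self, name: str) -> None:
        self.name = name


class Node:
    def __init__(self, name: str, *children) -> None:
        self.name = name
        self.children = list(children)


def walk(
    chain: Node,
    predicate: Callable[[object, Node], bool],
    recurse: bool = False,
) -> Iterator[tuple[object, Node]]:
    for module in chain.children:
        if predicate(module, chain):
            yield module, chain
            if not recurse:
                continue  # assumed: do not look inside a match
        if isinstance(module, Node):
            yield from walk(module, predicate, recurse)


def find_parent(chain: Node, module: object):
    # The first (module, parent) pair gives the parent directly.
    for _, parent in walk(chain, lambda m, _: m is module):
        return parent
    return None


target = Leaf('target')
inner = Node('inner', target)
root = Node('root', Leaf('first'), inner)
```

Calling find_parent(root, target) returns inner, and a module that is not in the tree yields None.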
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Concatenate","title":"Concatenate","text":"
Concatenate(*modules: Module, dim: int = 0)\n

Bases: Chain

Concatenation layer.

This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.

Example
concatenate = fl.Concatenate(\n    fl.Linear(32, 128),\n    fl.Linear(32, 128),\n    dim=1,\n)\n\ntensor = torch.randn(2, 32)\noutput = concatenate(tensor)\n\nassert output.shape == (2, 256)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *modules: Module, dim: int = 0) -> None:\n    super().__init__(*modules)\n    self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule","title":"ContextModule","text":"
ContextModule(*args: Any, **kwargs: Any)\n

Bases: Module

A module containing a ContextProvider.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, **kwargs)\n    self._parent = []\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.ensure_parent","title":"ensure_parent property","text":"
ensure_parent: Chain\n

Return the module's parent, or raise an error if module is an orphan.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.parent","title":"parent property","text":"
parent: Chain | None\n

Return the module's parent, or None if module is an orphan.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.provider","title":"provider property","text":"
provider: ContextProvider\n

Return the module's context provider.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.get_parents","title":"get_parents","text":"
get_parents() -> list[Chain]\n

Recursively retrieve the module's parents.

Source code in src/refiners/fluxion/layers/module.py
def get_parents(self) -> \"list[Chain]\":\n    \"\"\"Recursively retrieve the module's parents.\"\"\"\n    return self._parent + self._parent[0].get_parents() if self._parent else []\n
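
The one-line recursion above reads more easily on a toy class. In this sketch, Node is a hypothetical class that, like the listed __init__, keeps its parent in a list that is empty for an orphan.

```python
class Node:
    def __init__(self, name: str, parent=None) -> None:
        self.name = name
        # The parent is stored in a list: empty for an orphan, one item otherwise.
        self._parent = [parent] if parent is not None else []

    def get_parents(self) -> list:
        # Parsed as (own parent + the parent's parents) if there is a parent, else [].
        return self._parent + self._parent[0].get_parents() if self._parent else []


root = Node('root')
middle = Node('middle', parent=root)
leaf = Node('leaf', parent=middle)

ancestors = [p.name for p in leaf.get_parents()]
print(ancestors)
```

The result is ordered from the closest parent up to the root.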
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.get_path","title":"get_path","text":"
get_path(\n    parent: Chain | None = None, top: Module | None = None\n) -> str\n

Get the path of the module in the chain.

Parameters:

Name Type Description Default parent Chain | None

The parent of the module in the chain.

None top Module | None

The top module of the chain. If None, the path will be relative to the root of the chain.

None Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: \"Chain | None\" = None, top: \"Module | None\" = None) -> str:\n    \"\"\"Get the path of the module in the chain.\n\n    Args:\n        parent: The parent of the module in the chain.\n        top: The top module of the chain.\n            If None, the path will be relative to the root of the chain.\n    \"\"\"\n\n    return super().get_path(parent=parent or self.parent, top=top)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.use_context","title":"use_context","text":"
use_context(context_name: str) -> Context\n

Retrieve the context object from the module's context provider.

Source code in src/refiners/fluxion/layers/module.py
def use_context(self, context_name: str) -> Context:\n    \"\"\"Retrieve the context object from the module's context provider.\"\"\"\n    context = self.provider.get_context(context_name)\n    assert context is not None, f\"Context {context_name} not found.\"\n    return context\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Conv2d","title":"Conv2d","text":"
Conv2d(\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] = (1, 1),\n    padding: int | tuple[int, int] | str = (0, 0),\n    groups: int = 1,\n    use_bias: bool = True,\n    dilation: int | tuple[int, int] = (1, 1),\n    padding_mode: str = \"zeros\",\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Conv2d, WeightedModule

2D Convolutional layer.

This layer wraps torch.nn.Conv2d.

Receives:

Type Description Real[Tensor, 'batch in_channels in_height in_width']

Returns:

Type Description Real[Tensor, 'batch out_channels out_height out_width'] Example
conv2d = fl.Conv2d(\n    in_channels=3,\n    out_channels=32,\n    kernel_size=3,\n    stride=1,\n    padding=1,\n)\n\ntensor = torch.randn(2, 3, 128, 128)\noutput = conv2d(tensor)\n\nassert output.shape == (2, 32, 128, 128)\n
Source code in src/refiners/fluxion/layers/conv.py
def __init__(\n    self,\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] = (1, 1),\n    padding: int | tuple[int, int] | str = (0, 0),\n    groups: int = 1,\n    use_bias: bool = True,\n    dilation: int | tuple[int, int] = (1, 1),\n    padding_mode: str = \"zeros\",\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(  # type: ignore\n        in_channels=in_channels,\n        out_channels=out_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        groups=groups,\n        bias=use_bias,\n        padding_mode=padding_mode,\n        device=device,\n        dtype=dtype,\n    )\n    self.use_bias = use_bias\n
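
The output shape in the example can be checked by hand with the output-size formula documented for torch.nn.Conv2d. The helper below is a hypothetical convenience function for one spatial dimension, not part of Refiners.

```python
def conv2d_output_size(
    size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> int:
    # Output-size formula from the torch.nn.Conv2d documentation, for one dimension.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Same settings as the example: kernel_size=3, stride=1, padding=1 on a 128x128 input.
same_size = conv2d_output_size(128, kernel_size=3, stride=1, padding=1)

# Without padding, a 3x3 kernel trims one pixel on each side.
trimmed = conv2d_output_size(128, kernel_size=3)
```

With padding=1 the 3x3 kernel preserves the 128x128 resolution, which matches the assertion in the example.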
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ConvTranspose2d","title":"ConvTranspose2d","text":"
ConvTranspose2d(\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] = 1,\n    padding: int | tuple[int, int] = 0,\n    output_padding: int | tuple[int, int] = 0,\n    groups: int = 1,\n    use_bias: bool = True,\n    dilation: int | tuple[int, int] = 1,\n    padding_mode: str = \"zeros\",\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ConvTranspose2d, WeightedModule

2D Transposed Convolutional layer.

This layer wraps torch.nn.ConvTranspose2d.

Receives:

Type Description Real[Tensor, 'batch in_channels in_height in_width']

Returns:

Type Description Real[Tensor, 'batch out_channels out_height out_width'] Example
conv2d = fl.ConvTranspose2d(\n    in_channels=3,\n    out_channels=32,\n    kernel_size=3,\n    stride=1,\n    padding=1,\n)\n\ntensor = torch.randn(2, 3, 128, 128)\noutput = conv2d(tensor)\n\nassert output.shape == (2, 32, 128, 128)\n
Source code in src/refiners/fluxion/layers/conv.py
def __init__(\n    self,\n    in_channels: int,\n    out_channels: int,\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] = 1,\n    padding: int | tuple[int, int] = 0,\n    output_padding: int | tuple[int, int] = 0,\n    groups: int = 1,\n    use_bias: bool = True,\n    dilation: int | tuple[int, int] = 1,\n    padding_mode: str = \"zeros\",\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(  # type: ignore\n        in_channels=in_channels,\n        out_channels=out_channels,\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        output_padding=output_padding,\n        dilation=dilation,\n        groups=groups,\n        bias=use_bias,\n        padding_mode=padding_mode,\n        device=device,\n        dtype=dtype,\n    )\n    self.use_bias = use_bias\n
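For a transposed convolution the size arithmetic runs in the opposite direction. The sketch below uses the shape formula documented for torch.nn.ConvTranspose2d; the helper name conv_transpose2d_output_size is hypothetical and only illustrates the arithmetic for one spatial axis:

```python
def conv_transpose2d_output_size(
    in_size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1
):
    # Output size along one spatial axis, following the torch.nn.ConvTranspose2d formula:
    # (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


# Values from the example above: the spatial size is unchanged.
same_size = conv_transpose2d_output_size(128, kernel_size=3, stride=1, padding=1)

# A common upsampling setup: a 4x4 kernel with stride 2 and padding 1 doubles the size.
doubled_size = conv_transpose2d_output_size(128, kernel_size=4, stride=2, padding=1)
```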
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Converter","title":"Converter","text":"
Converter(set_device: bool = True, set_dtype: bool = True)\n

Bases: ContextModule

A Converter class that adjusts tensor properties based on a parent module's settings.

This class inherits from ContextModule and provides functionality to adjust the device and dtype of input tensor(s) to match the parent module's attributes.

Note

Ensure the parent module has device and dtype attributes if set_device or set_dtype are set to True.

Parameters:

set_device (bool, default: True): If True, matches the device of the input tensor(s) to the parent's device.

set_dtype (bool, default: True): If True, matches the dtype of the input tensor(s) to the parent's dtype.

Source code in src/refiners/fluxion/layers/converter.py
def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:\n    \"\"\"Initializes the Converter layer.\n\n    Args:\n        set_device: If True, matches the device of the input tensor(s) to the parent's device.\n        set_dtype: If True, matches the dtype of the input tensor(s) to the parent's dtype.\n    \"\"\"\n    super().__init__()\n    self.set_device = set_device\n    self.set_dtype = set_dtype\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Cos","title":"Cos","text":"
Cos(*args: Any, **kwargs: Any)\n

Bases: Module

Cosine operator layer.

This layer applies the cosine function to the input tensor. See also torch.cos.

Example
cos = fl.Cos()\n\ntensor = torch.tensor([0, torch.pi])\noutput = cos(tensor)\n\nexpected_output = torch.tensor([1.0, -1.0])\nassert torch.allclose(output, expected_output, atol=1e-6)\n
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Distribute","title":"Distribute","text":"
Distribute(*args: Module | Iterable[Module])\n

Bases: Chain

Distribute layer.

This layer calls its sub-modules in parallel, each with its respective input, and returns a tuple of their outputs.

Example
distribute = fl.Distribute(\n    fl.Linear(32, 128),\n    fl.Linear(64, 256),\n)\n\ntensor1 = torch.randn(2, 32)\ntensor2 = torch.randn(4, 64)\noutputs = distribute(tensor1, tensor2)\n\nassert len(outputs) == 2\nassert outputs[0].shape == (2, 128)\nassert outputs[1].shape == (4, 256)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
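The dispatch pattern can be sketched without any tensors. The following is a plain-Python illustration, not the Refiners implementation: the distribute helper is hypothetical, ordinary callables stand in for sub-modules, and the length check is an assumption about how mismatched inputs would be handled.

```python
def distribute(*functions):
    # Returns a callable that sends the i-th input to the i-th function.
    def call(*inputs):
        # Assumption: exactly one input per branch is expected.
        if len(inputs) != len(functions):
            raise ValueError('expected one input per function')
        return tuple(f(x) for f, x in zip(functions, inputs))

    return call


double_and_negate = distribute(lambda x: 2 * x, lambda x: -x)
outputs = double_and_negate(3, 5)
```

In this sketch, parallel means branch-wise pairing of inputs and sub-modules; nothing runs concurrently.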
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Downsample","title":"Downsample","text":"
Downsample(\n    channels: int,\n    scale_factor: int,\n    padding: int = 0,\n    register_shape: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Downsample layer.

This layer downsamples the input by the given scale factor.

Raises:

RuntimeError: If the 'sampling' context is not set or if it does not contain a list.

Parameters:

channels (int, required): The number of input and output channels.

scale_factor (int, required): The factor by which to downsample the input.

padding (int, default: 0): The amount of zero-padding added to both sides of the input.

register_shape (bool, default: True): If True, registers the input shape in the context.

device (device | str | None, default: None): The device to use for the convolutional layer.

dtype (dtype | None, default: None): The dtype to use for the convolutional layer.

Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n    self,\n    channels: int,\n    scale_factor: int,\n    padding: int = 0,\n    register_shape: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n):\n    \"\"\"Initializes the Downsample layer.\n\n    Args:\n        channels: The number of input and output channels.\n        scale_factor: The factor by which to downsample the input.\n        padding: The amount of zero-padding added to both sides of the input.\n        register_shape: If True, registers the input shape in the context.\n        device: The device to use for the convolutional layer.\n        dtype: The dtype to use for the convolutional layer.\n    \"\"\"\n    self.channels = channels\n    self.in_channels = channels\n    self.out_channels = channels\n    self.scale_factor = scale_factor\n    self.padding = padding\n\n    super().__init__(\n        Conv2d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=3,\n            stride=scale_factor,\n            padding=padding,\n            device=device,\n            dtype=dtype,\n        ),\n    )\n\n    if padding == 0:\n        zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))\n        self.insert(\n            index=0,\n            module=Lambda(func=zero_pad),\n        )\n\n    if register_shape:\n        self.insert(\n            index=0,\n            module=SetContext(\n                context=\"sampling\",\n                key=\"shapes\",\n                callback=self.register_shape,\n            ),\n        )\n
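The constructor above fixes the kernel size at 3, uses scale_factor as the stride, and pads only the right and bottom edges when padding is 0. A minimal sketch of the resulting size along one spatial axis follows; the helper name downsample_output_size is hypothetical, and the arithmetic is derived from the listing above together with the torch.nn.Conv2d shape formula.

```python
def downsample_output_size(in_size, scale_factor, padding=0):
    kernel_size = 3  # fixed in the constructor shown above
    if padding == 0:
        # The inserted Lambda pads one zero on the right/bottom only: pad(x, (0, 1, 0, 1)).
        padded = in_size + 1
    else:
        # Otherwise the convolution itself pads both sides symmetrically.
        padded = in_size + 2 * padding
    # Convolution with stride equal to scale_factor over the padded input.
    return (padded - kernel_size) // scale_factor + 1


asymmetric = downsample_output_size(64, scale_factor=2)
symmetric = downsample_output_size(64, scale_factor=2, padding=1)
```

Both settings halve a 64-pixel axis; they differ only in where the padding is applied.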
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Embedding","title":"Embedding","text":"
Embedding(\n    num_embeddings: int,\n    embedding_dim: int,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Embedding, WeightedModule

Embedding layer.

This layer wraps torch.nn.Embedding.

Receives:

Type Description Int[Tensor, 'batch length']

Returns:

Type Description Float[Tensor, 'batch length embedding_dim'] Example
embedding = fl.Embedding(\n    num_embeddings=10,\n    embedding_dim=128\n)\n\ntensor = torch.randint(0, 10, (2, 10))\noutput = embedding(tensor)\n\nassert output.shape == (2, 10, 128)\n

Parameters:

num_embeddings (int, required): The number of embeddings.

embedding_dim (int, required): The dimension of the embeddings.

device (device | str | None, default: None): The device to use for the embedding layer.

dtype (dtype | None, default: None): The dtype to use for the embedding layer.

Source code in src/refiners/fluxion/layers/embedding.py
def __init__(\n    self,\n    num_embeddings: int,\n    embedding_dim: int,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n):\n    \"\"\"Initializes the Embedding layer.\n\n    Args:\n        num_embeddings: The number of embeddings.\n        embedding_dim: The dimension of the embeddings.\n        device: The device to use for the embedding layer.\n        dtype: The dtype to use for the embedding layer.\n    \"\"\"\n    _Embedding.__init__(  # type: ignore\n        self,\n        num_embeddings=num_embeddings,\n        embedding_dim=embedding_dim,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Flatten","title":"Flatten","text":"
Flatten(start_dim: int = 0, end_dim: int = -1)\n

Bases: Module

Flatten operation layer.

This layer flattens the input tensor between the given dimensions. See also torch.flatten.

Example
flatten = fl.Flatten(start_dim=1)\n\ntensor = torch.randn(10, 10, 10)\noutput = flatten(tensor)\n\nassert output.shape == (10, 100)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n    self,\n    start_dim: int = 0,\n    end_dim: int = -1,\n) -> None:\n    super().__init__()\n    self.start_dim = start_dim\n    self.end_dim = end_dim\n
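To make the effect of start_dim and end_dim concrete, here is a small shape-only sketch in plain Python. The helper flattened_shape is hypothetical; it mirrors the behaviour documented for torch.flatten and assumes a non-empty shape with start_dim not after end_dim.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    ndim = len(shape)
    # Resolve negative indices the way tensor dimensions are usually indexed.
    start = start_dim % ndim
    end = end_dim % ndim
    # Dimensions start..end (inclusive) collapse into a single dimension.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


example_shape = flattened_shape((10, 10, 10), start_dim=1)
default_shape = flattened_shape((10, 10, 10))
```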
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GLU","title":"GLU","text":"
GLU(activation: Activation)\n

Bases: Activation

Gated Linear Unit activation function.

See [arXiv:2002.05202] GLU Variants Improve Transformer for more details.

Example
glu = fl.GLU(fl.ReLU())\ntensor = torch.tensor([[1.0, 0.0, -1.0, 1.0]])\noutput = glu(tensor)\nassert torch.allclose(output, torch.tensor([0.0, 0.0]))\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self, activation: Activation) -> None:\n    super().__init__()\n    self.activation = activation\n
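The example output can be reproduced by hand. The sketch below works on a single row in plain Python and assumes that the first half of the last dimension is the value and the second half is the gate, which is consistent with the example above; the glu and relu helpers are hypothetical stand-ins, not the Refiners classes.

```python
def relu(values):
    return [max(0.0, v) for v in values]


def glu(row, activation):
    # Assumption: first half is the value, second half is the gate.
    if len(row) % 2 != 0:
        raise ValueError('the last dimension must be even')
    half = len(row) // 2
    value, gate = row[:half], row[half:]
    # Each value is scaled by the activated gate at the same position.
    return [v * g for v, g in zip(value, activation(gate))]


# Same row as in the example above.
output = glu([1.0, 0.0, -1.0, 1.0], relu)
other = glu([2.0, 3.0, 1.0, -4.0], relu)
```

Note that the output has half as many features as the input.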
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GeLU","title":"GeLU","text":"
GeLU(\n    approximation: GeLUApproximation = GeLUApproximation.NONE,\n)\n

Bases: Activation

Gaussian Error Linear Unit activation function.

This activation can be quite expensive to compute; a few approximations are available, see GeLUApproximation.

See [arXiv:1606.08415] Gaussian Error Linear Units for more details.

Example
gelu = fl.GeLU()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = gelu(tensor)\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(\n    self,\n    approximation: GeLUApproximation = GeLUApproximation.NONE,\n) -> None:\n    super().__init__()\n    self.approximation = approximation\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GeLUApproximation","title":"GeLUApproximation","text":"

Bases: Enum

Approximation methods for the Gaussian Error Linear Unit activation function.

Attributes:

NONE: No approximation, use the original formula.

TANH: Use the tanh approximation.

SIGMOID: Use the sigmoid approximation.
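For reference, the three variants can be compared numerically with the formulas from the GELU paper. This is a plain-Python sketch on scalars; it assumes, without confirmation from this page, that the enum members correspond to these standard formulas, and the function names are hypothetical.

```python
import math


def gelu_exact(x):
    # Original formula: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation from the paper.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_sigmoid(x):
    # Sigmoid approximation from the paper: x * sigmoid(1.702 * x).
    return x / (1.0 + math.exp(-1.702 * x))


points = (-2.0, -1.0, 0.0, 1.0, 2.0)
tanh_error = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in points)
sigmoid_error = max(abs(gelu_sigmoid(x) - gelu_exact(x)) for x in points)
```

On these sample points the tanh variant tracks the exact formula more closely than the sigmoid variant.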

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GetArg","title":"GetArg","text":"
GetArg(index: int)\n

Bases: Module

GetArg operation layer.

This layer returns the nth tensor of the input arguments.

Example
get_arg = fl.GetArg(1)\n\ninputs = (\n    torch.randn(10, 10),\n    torch.randn(20, 20),\n    torch.randn(30, 30),\n)\noutput = get_arg(*inputs)\n\nassert id(inputs[1]) == id(output)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, index: int) -> None:\n    super().__init__()\n    self.index = index\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GroupNorm","title":"GroupNorm","text":"
GroupNorm(\n    channels: int,\n    num_groups: int,\n    eps: float = 1e-05,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: GroupNorm, WeightedModule

Group Normalization layer.

This layer wraps torch.nn.GroupNorm.

Receives:

Type Description Float[Tensor, 'batch channels *normalized_shape']

Returns:

Type Description Float[Tensor, 'batch channels *normalized_shape'] Example
groupnorm = fl.GroupNorm(channels=128, num_groups=8)\n\ntensor = torch.randn(2, 128, 8)\noutput = groupnorm(tensor)\n\nassert output.shape == (2, 128, 8)\n
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n    self,\n    channels: int,\n    num_groups: int,\n    eps: float = 1e-5,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(  # type: ignore\n        num_groups=num_groups,\n        num_channels=channels,\n        eps=eps,\n        affine=True,  # otherwise not a WeightedModule\n        device=device,\n        dtype=dtype,\n    )\n    self.channels = channels\n    self.num_groups = num_groups\n    self.eps = eps\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Identity","title":"Identity","text":"
Identity()\n

Bases: Module

Identity operator layer.

This layer simply returns the input tensor.

Example
identity = fl.Identity()\n\ntensor = torch.randn(10, 10)\noutput = identity(tensor)\n\nassert torch.equal(tensor, output)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self) -> None:\n    super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.InstanceNorm2d","title":"InstanceNorm2d","text":"
InstanceNorm2d(\n    num_features: int,\n    eps: float = 1e-05,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: InstanceNorm2d, Module

Instance Normalization layer.

This layer wraps torch.nn.InstanceNorm2d.

Receives:

Type Description Float[Tensor, 'batch channels height width']

Returns:

Type Description Float[Tensor, 'batch channels height width'] Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n    self,\n    num_features: int,\n    eps: float = 1e-05,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(  # type: ignore\n        num_features=num_features,\n        eps=eps,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Interpolate","title":"Interpolate","text":"
Interpolate(mode: str = 'nearest', antialias: bool = False)\n

Bases: Module

Interpolate layer.

This layer wraps torch.nn.functional.interpolate.

Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n    self,\n    mode: str = \"nearest\",\n    antialias: bool = False,\n) -> None:\n    super().__init__()\n    self.mode = mode\n    self.antialias = antialias\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Lambda","title":"Lambda","text":"
Lambda(func: Callable[..., Any])\n

Bases: Module

Lambda layer.

This layer wraps a Callable.

When called, it calls the wrapped Callable with the given arguments and returns its result. Example
lambda_layer = fl.Lambda(lambda x: x + 1)\n\ntensor = torch.tensor([1, 2, 3])\noutput = lambda_layer(tensor)\n\nexpected_output = torch.tensor([2, 3, 4])\nassert torch.allclose(output, expected_output)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, func: Callable[..., Any]) -> None:\n    super().__init__()\n    self.func = func\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.LayerNorm","title":"LayerNorm","text":"
LayerNorm(\n    normalized_shape: int | list[int],\n    eps: float = 1e-05,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: LayerNorm, WeightedModule

Layer Normalization layer.

This layer wraps torch.nn.LayerNorm.

Receives:

Type Description Float[Tensor, 'batch *normalized_shape']

Returns:

Type Description Float[Tensor, 'batch *normalized_shape'] Example
layernorm = fl.LayerNorm(normalized_shape=128)\n\ntensor = torch.randn(2, 128)\noutput = layernorm(tensor)\n\nassert output.shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n    self,\n    normalized_shape: int | list[int],\n    eps: float = 0.00001,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(  # type: ignore\n        normalized_shape=normalized_shape,\n        eps=eps,\n        elementwise_affine=True,  # otherwise not a WeightedModule\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.LayerNorm2d","title":"LayerNorm2d","text":"
LayerNorm2d(\n    channels: int,\n    eps: float = 1e-06,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: WeightedModule

2D Layer Normalization layer.

This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.

Receives:

Type Description Float[Tensor, 'batch channels height width']

Returns:

Type Description Float[Tensor, 'batch channels height width'] Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n    self,\n    channels: int,\n    eps: float = 1e-6,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__()\n    self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))\n    self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))\n    self.eps = eps\n
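The forward pass is not listed here, so the following is only a sketch of what normalizing along the channel dimension means for one (batch, row, column) location. It assumes the usual layer-normalization formula with a biased variance and eps added inside the square root; the helper name layer_norm_over_channels is hypothetical.

```python
import math


def layer_norm_over_channels(channel_values, weight, bias, eps=1e-6):
    # Normalizes the values found at one spatial location across its channels.
    n = len(channel_values)
    mean = sum(channel_values) / n
    variance = sum((v - mean) ** 2 for v in channel_values) / n  # biased variance (assumption)
    return [
        w * (v - mean) / math.sqrt(variance + eps) + b
        for v, w, b in zip(channel_values, weight, bias)
    ]


# Weight and bias start as ones and zeros, as in the constructor above.
pixel = [1.0, 2.0, 3.0]
normalized = layer_norm_over_channels(pixel, weight=[1.0, 1.0, 1.0], bias=[0.0, 0.0, 0.0])
```

Each spatial location is normalized independently of the others, with one learnable weight and bias per channel.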
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Linear","title":"Linear","text":"
Linear(\n    in_features: int,\n    out_features: int,\n    bias: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Linear, WeightedModule

Linear layer.

This layer wraps torch.nn.Linear.

Receives:

Name Type Description Input Float[Tensor, 'batch in_features']

Returns:

Name Type Description Output Float[Tensor, 'batch out_features'] Example
linear = fl.Linear(in_features=32, out_features=128)\n\ntensor = torch.randn(2, 32)\noutput = linear(tensor)\n\nassert output.shape == (2, 128)\n

Parameters:

in_features (int, required): The number of input features.

out_features (int, required): The number of output features.

bias (bool, default: True): If True, adds a learnable bias to the output.

device (device | str | None, default: None): The device to use for the linear layer.

dtype (dtype | None, default: None): The dtype to use for the linear layer.

Source code in src/refiners/fluxion/layers/linear.py
def __init__(\n    self,\n    in_features: int,\n    out_features: int,\n    bias: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initializes the Linear layer.\n\n    Args:\n        in_features: The number of input features.\n        out_features: The number of output features.\n        bias: If True, adds a learnable bias to the output.\n        device: The device to use for the linear layer.\n        dtype: The dtype to use for the linear layer.\n    \"\"\"\n    self.in_features = in_features\n    self.out_features = out_features\n    super().__init__(  # type: ignore\n        in_features=in_features,\n        out_features=out_features,\n        bias=bias,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Matmul","title":"Matmul","text":"
Matmul(input: Module, other: Module)\n

Bases: Chain

Matrix multiplication layer.

This layer returns the matrix multiplication of the outputs of its two sub-modules.

Example
matmul = fl.Matmul(\n    fl.Identity(),\n    fl.Multiply(scale=2),\n)\n\ntensor = torch.randn(10, 10)\noutput = matmul(tensor)\n\nexpected_output = tensor @ (2 * tensor)\nassert torch.allclose(output, expected_output)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, input: Module, other: Module) -> None:\n    super().__init__(\n        input,\n        other,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MaxPool1d","title":"MaxPool1d","text":"
MaxPool1d(\n    kernel_size: int,\n    stride: int | None = None,\n    padding: int = 0,\n    dilation: int = 1,\n    return_indices: bool = False,\n    ceil_mode: bool = False,\n)\n

Bases: MaxPool1d, Module

MaxPool1d layer.

This layer wraps torch.nn.MaxPool1d.

Receives:

Type Description Float[Tensor, 'batch channels in_length']

Returns:

Type Description Float[Tensor, 'batch channels out_length']

Parameters:

kernel_size (int, required): The size of the sliding window.

stride (int | None, default: None): The stride of the sliding window. If None, torch.nn.MaxPool1d uses kernel_size.

padding (int, default: 0): The amount of implicit negative-infinity padding added to both sides of the input.

dilation (int, default: 1): The spacing between kernel elements.

return_indices (bool, default: False): If True, returns the max indices along with the outputs.

ceil_mode (bool, default: False): If True, uses ceil instead of floor to compute the output shape.

Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(\n    self,\n    kernel_size: int,\n    stride: int | None = None,\n    padding: int = 0,\n    dilation: int = 1,\n    return_indices: bool = False,\n    ceil_mode: bool = False,\n) -> None:\n    \"\"\"Initializes the MaxPool1d layer.\n\n    Args:\n        kernel_size: The size of the sliding window.\n        stride: The stride of the sliding window.\n        padding: The amount of implicit negative-infinity padding added to both sides of the input.\n        dilation: The spacing between kernel elements.\n        return_indices: If True, returns the max indices along with the outputs.\n        ceil_mode: If True, uses ceil instead of floor to compute the output shape.\n    \"\"\"\n    super().__init__(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        return_indices=return_indices,\n        ceil_mode=ceil_mode,\n    )\n
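How stride, padding, dilation and ceil_mode combine into out_length can be sketched in plain Python. The helper maxpool_output_length is hypothetical; it follows the shape formula documented for torch.nn.MaxPool1d, including the default of stride equal to kernel_size.

```python
import math


def maxpool_output_length(in_length, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    # torch.nn.MaxPool1d falls back to kernel_size when stride is None.
    stride = kernel_size if stride is None else stride
    span = in_length + 2 * padding - dilation * (kernel_size - 1) - 1
    rounding = math.ceil if ceil_mode else math.floor
    out_length = rounding(span / stride) + 1
    # In ceil mode, a last window that would start in the right padding is dropped.
    if ceil_mode and (out_length - 1) * stride >= in_length + padding:
        out_length -= 1
    return out_length


floor_length = maxpool_output_length(10, kernel_size=3)
ceil_length = maxpool_output_length(10, kernel_size=3, ceil_mode=True)
```

The same arithmetic applies per axis for MaxPool2d.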
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MaxPool2d","title":"MaxPool2d","text":"
MaxPool2d(\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] | None = None,\n    padding: int | tuple[int, int] = (0, 0),\n    dilation: int | tuple[int, int] = (1, 1),\n    return_indices: bool = False,\n    ceil_mode: bool = False,\n)\n

Bases: MaxPool2d, Module

MaxPool2d layer.

This layer wraps torch.nn.MaxPool2d.

Receives:

Type Description Float[Tensor, 'batch channels in_height in_width']

Returns:

Type Description Float[Tensor, 'batch channels out_height out_width']

Parameters:

kernel_size (int | tuple[int, int], required): The size of the sliding window.

stride (int | tuple[int, int] | None, default: None): The stride of the sliding window. If None, torch.nn.MaxPool2d uses kernel_size.

padding (int | tuple[int, int], default: (0, 0)): The amount of implicit negative-infinity padding added to both sides of the input.

dilation (int | tuple[int, int], default: (1, 1)): The spacing between kernel elements.

return_indices (bool, default: False): If True, returns the max indices along with the outputs.

ceil_mode (bool, default: False): If True, uses ceil instead of floor to compute the output shape.

Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(\n    self,\n    kernel_size: int | tuple[int, int],\n    stride: int | tuple[int, int] | None = None,\n    padding: int | tuple[int, int] = (0, 0),\n    dilation: int | tuple[int, int] = (1, 1),\n    return_indices: bool = False,\n    ceil_mode: bool = False,\n) -> None:\n    \"\"\"Initializes the MaxPool2d layer.\n\n    Args:\n        kernel_size: The size of the sliding window.\n        stride: The stride of the sliding window.\n        padding: The amount of implicit negative-infinity padding added to both sides of the input.\n        dilation: The spacing between kernel elements.\n        return_indices: If True, returns the max indices along with the outputs.\n        ceil_mode: If True, uses ceil instead of floor to compute the output shape.\n    \"\"\"\n    super().__init__(\n        kernel_size=kernel_size,\n        stride=stride,\n        padding=padding,\n        dilation=dilation,\n        return_indices=return_indices,\n        ceil_mode=ceil_mode,\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module","title":"Module","text":"
Module(*args: Any, **kwargs: Any)\n

Bases: Module

A wrapper around torch.nn.Module.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.basic_attributes","title":"basic_attributes","text":"
basic_attributes(\n    init_attrs_only: bool = False,\n) -> dict[str, BasicType | Sequence[BasicType]]\n

Return a dictionary of basic attributes of the module.

Parameters:

init_attrs_only (bool, default: False): Whether to only return attributes that are passed to the module's constructor.

Source code in src/refiners/fluxion/layers/module.py
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:\n    \"\"\"Return a dictionary of basic attributes of the module.\n\n    Args:\n        init_attrs_only: Whether to only return attributes that are passed to the module's constructor.\n    \"\"\"\n    sig = signature(obj=self.__init__)\n    init_params = set(sig.parameters.keys()) - {\"self\"}\n    default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}\n\n    def is_basic_attribute(key: str, value: Any) -> bool:\n        if key.startswith(\"_\"):\n            return False\n\n        if isinstance(value, BasicType):\n            return True\n\n        if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):\n            return True\n\n        return False\n\n    return {\n        key: value\n        for key, value in self.__dict__.items()\n        if is_basic_attribute(key=key, value=value)\n        and (not init_attrs_only or (key in init_params and value != default_values.get(key)))\n    }\n
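The filtering logic above can be tried on a toy class without PyTorch. This is a simplified sketch: BASIC_TYPES is an assumed stand-in for the library's BasicType (which is not shown in this section), ToyLayer is hypothetical, and the branch handling sequences of basic values is left out.

```python
from inspect import Parameter, signature

# Assumption: stand-in for the library's BasicType.
BASIC_TYPES = (str, int, float, bool)


class ToyLayer:
    def __init__(self, channels, eps=1e-5, name='toy'):
        self.channels = channels
        self.eps = eps
        self.name = name
        self._cache = 0        # private attribute: always skipped
        self.training = True   # basic attribute, but not a constructor argument


def basic_attributes(obj, init_attrs_only=False):
    sig = signature(obj.__init__)
    init_params = set(sig.parameters)
    defaults = {k: p.default for k, p in sig.parameters.items() if p.default is not Parameter.empty}
    result = {}
    for key, value in vars(obj).items():
        if key.startswith('_') or not isinstance(value, BASIC_TYPES):
            continue
        # With init_attrs_only, keep only constructor arguments that differ from their default.
        if init_attrs_only and (key not in init_params or value == defaults.get(key)):
            continue
        result[key] = value
    return result


layer = ToyLayer(channels=64, eps=1e-6)
all_attrs = basic_attributes(layer)
init_attrs = basic_attributes(layer, init_attrs_only=True)
```

With init_attrs_only=True, name is dropped because it still has its default value, and training is dropped because it is not a constructor argument.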
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.get_path","title":"get_path","text":"
get_path(\n    parent: Chain | None = None, top: Module | None = None\n) -> str\n

Get the path of the module in the chain.

Parameters:

parent (Chain | None, default: None): The parent of the module in the chain.

top (Module | None, default: None): The top module of the chain. If None, the path will be relative to the root of the chain.

Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: \"Chain | None\" = None, top: \"Module | None\" = None) -> str:\n    \"\"\"Get the path of the module in the chain.\n\n    Args:\n        parent: The parent of the module in the chain.\n        top: The top module of the chain.\n            If None, the path will be relative to the root of the chain.\n    \"\"\"\n    if (parent is None) or (self == top):\n        return self.__class__.__name__\n    for k, m in parent._modules.items():  # type: ignore\n        if m is self:\n            return parent.get_path(parent=parent.parent, top=top) + \".\" + k\n    raise ValueError(f\"{self} not found in {parent}\")\n
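The recursion is easier to follow on a toy tree. The sketch below copies the logic of the listing above onto a hypothetical Node class, where parent and _modules stand in for the attributes a Chain would provide; none of these toy classes belong to Refiners.

```python
class Node:
    def __init__(self, **children):
        self.parent = None
        self._modules = dict(children)
        for child in children.values():
            child.parent = self

    def get_path(self, parent=None, top=None):
        # Stop at the root (no parent) or at the requested top module.
        if parent is None or self == top:
            return type(self).__name__
        for key, module in parent._modules.items():
            if module is self:
                return parent.get_path(parent=parent.parent, top=top) + '.' + key
        raise ValueError('module not found in parent')


class Encoder(Node):
    pass


class Block(Node):
    pass


class Leaf(Node):
    pass


leaf = Leaf()
block = Block(act=leaf)
encoder = Encoder(block_0=block)

path_from_root = leaf.get_path(parent=leaf.parent)
path_from_block = leaf.get_path(parent=leaf.parent, top=block)
```

The path starts with the class name of the root (or of top) and continues with the keys under which each module is registered in its parent.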
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.load_from_safetensors","title":"load_from_safetensors","text":"
load_from_safetensors(\n    tensors_path: str | Path, strict: bool = True\n) -> T\n

Load the module's state from a SafeTensors file.

Parameters:

tensors_path (str | Path, required): The path to the SafeTensors file.

strict (bool, default: True): Whether to raise an error if the SafeTensors's content doesn't map perfectly to the module's state.

Returns:

T: The module, with its state loaded from the SafeTensors file.

Source code in src/refiners/fluxion/layers/module.py
def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T:\n    \"\"\"Load the module's state from a SafeTensors file.\n\n    Args:\n        tensors_path: The path to the SafeTensors file.\n        strict: Whether to raise an error if the SafeTensors's\n            content doesn't map perfectly to the module's state.\n\n    Returns:\n        The module, with its state loaded from the SafeTensors file.\n    \"\"\"\n    state_dict = load_from_safetensors(tensors_path)\n    self.load_state_dict(state_dict, strict=strict)\n    return self\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.named_modules","title":"named_modules","text":"
named_modules(\n    *args: Any, **kwargs: Any\n) -> Generator[tuple[str, Module], None, None]\n

Get all the sub-modules of the module.

Returns:

Generator[tuple[str, Module], None, None]: An iterator over all the sub-modules of the module.

Source code in src/refiners/fluxion/layers/module.py
def named_modules(self, *args: Any, **kwargs: Any) -> \"Generator[tuple[str, Module], None, None]\":  # type: ignore\n    \"\"\"Get all the sub-modules of the module.\n\n    Returns:\n        An iterator over all the sub-modules of the module.\n    \"\"\"\n    return super().named_modules(*args)  # type: ignore\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.pretty_print","title":"pretty_print","text":"
pretty_print(depth: int = -1) -> None\n

Print the module in a tree-like format.

Parameters:

depth (int, default: -1): The maximum depth of the tree to print. If negative, the whole tree is printed.

Source code in src/refiners/fluxion/layers/module.py
def pretty_print(self, depth: int = -1) -> None:\n    \"\"\"Print the module in a tree-like format.\n\n    Args:\n        depth: The maximum depth of the tree to print.\n            If negative, the whole tree is printed.\n    \"\"\"\n    tree = ModuleTree(module=self)\n    print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth))  # type: ignore[reportPrivateUsage]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.to","title":"to","text":"
to(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n) -> T\n

Move the module to the given device and cast its parameters to the given dtype.

Parameters:

device (device | str | None, default: None): The device to move the module to.

dtype (dtype | None, default: None): The dtype to cast the module's parameters to.

Returns:

T: The module, moved to the given device and cast to the given dtype.

Source code in src/refiners/fluxion/layers/module.py
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore\n    \"\"\"Move the module to the given device and cast its parameters to the given dtype.\n\n    Args:\n        device: The device to move the module to.\n        dtype: The dtype to cast the module's parameters to.\n\n    Returns:\n        The module, moved to the given device and cast to the given dtype.\n    \"\"\"\n    return super().to(device=device, dtype=dtype)  # type: ignore\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MultiLinear","title":"MultiLinear","text":"
MultiLinear(\n    input_dim: int,\n    output_dim: int,\n    inner_dim: int,\n    num_layers: int,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Multi-layer linear network.

This layer wraps multiple torch.nn.Linear layers, with an Activation layer in between.

Receives:

Name Type Description Input Float[Tensor, 'batch input_dim']

Returns:

Name Type Description Output Float[Tensor, 'batch output_dim'] Example
linear = fl.MultiLinear(\n    input_dim=32,\n    output_dim=128,\n    inner_dim=64,\n    num_layers=3,\n)\n\ntensor = torch.randn(2, 32)\noutput = linear(tensor)\n\nassert output.shape == (2, 128)\n
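
To make the layer layout concrete, here is a minimal sketch that lists the input and output size of each Linear layer, following the loop in the constructor source below. The helper name multilinear_dims is hypothetical and is not part of Refiners.

```python
def multilinear_dims(input_dim, output_dim, inner_dim, num_layers):
    # Hypothetical helper (not part of Refiners): lists the
    # (in_features, out_features) pair of each Linear layer,
    # mirroring the loop in the MultiLinear constructor.
    dims = []
    for i in range(num_layers - 1):
        dims.append((input_dim if i == 0 else inner_dim, inner_dim))
    # The last Linear always maps inner_dim to output_dim.
    dims.append((inner_dim, output_dim))
    return dims


layer_dims = multilinear_dims(input_dim=32, output_dim=128, inner_dim=64, num_layers=3)
print(layer_dims)
```

Under this reading of the constructor, a single-layer network (num_layers=1) would take inner_dim rather than input_dim as its input size, so the two values should probably match in that case.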

Parameters:

Name Type Description Default input_dim int

The input dimension of the first linear layer.

required output_dim int

The output dimension of the last linear layer.

required inner_dim int

The output dimension of the inner linear layers.

required num_layers int

The number of linear layers.

required device device | str | None

The device to use for the linear layers.

None dtype dtype | None

The dtype to use for the linear layers.

None Source code in src/refiners/fluxion/layers/linear.py
def __init__(\n    self,\n    input_dim: int,\n    output_dim: int,\n    inner_dim: int,\n    num_layers: int,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initializes the MultiLinear layer.\n\n    Args:\n        input_dim: The input dimension of the first linear layer.\n        output_dim: The output dimension of the last linear layer.\n        inner_dim: The output dimension of the inner linear layers.\n        num_layers: The number of linear layers.\n        device: The device to use for the linear layers.\n        dtype: The dtype to use for the linear layers.\n    \"\"\"\n    layers: list[Module] = []\n    for i in range(num_layers - 1):\n        layers.append(\n            Linear(\n                in_features=input_dim if i == 0 else inner_dim,\n                out_features=inner_dim,\n                device=device,\n                dtype=dtype,\n            )\n        )\n        layers.append(\n            ReLU(),\n        )\n    layers.append(\n        Linear(\n            in_features=inner_dim,\n            out_features=output_dim,\n            device=device,\n            dtype=dtype,\n        )\n    )\n\n    super().__init__(layers)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Multiply","title":"Multiply","text":"
Multiply(scale: float = 1.0, bias: float = 0.0)\n

Bases: Module

Multiply operator layer.

This layer scales and shifts the input tensor by the given scale and bias.

Example
multiply = fl.Multiply(scale=2, bias=1)\n\ntensor = torch.ones(1)\noutput = multiply(tensor)\n\nassert torch.allclose(output, torch.tensor([3.0]))\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n    self,\n    scale: float = 1.0,\n    bias: float = 0.0,\n) -> None:\n    super().__init__()\n    self.scale = scale\n    self.bias = bias\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Parallel","title":"Parallel","text":"
Parallel(*args: Module | Iterable[Module])\n

Bases: Chain

Parallel layer.

This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.

Example
parallel = fl.Parallel(\n    fl.Linear(32, 64),\n    fl.Identity(),\n    fl.Linear(32, 128),\n)\n\ntensor = torch.randn(2, 32)\noutputs = parallel(tensor)\n\nassert len(outputs) == 3\nassert outputs[0].shape == (2, 64)\nassert torch.allclose(outputs[1], tensor)\nassert outputs[2].shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Parameter","title":"Parameter","text":"
Parameter(\n    *dims: int,\n    requires_grad: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None\n)\n

Bases: WeightedModule

Parameter layer.

This layer simply wraps a PyTorch Parameter. When called, it returns the Parameter Tensor.

Attributes:

Name Type Description weight Parameter

The parameter Tensor.

Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n    self,\n    *dims: int,\n    requires_grad: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__()\n    self.dims = dims\n    self.weight = TorchParameter(\n        requires_grad=requires_grad,\n        data=torch.randn(\n            *dims,\n            device=device,\n            dtype=dtype,\n        ),\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Passthrough","title":"Passthrough","text":"
Passthrough(*args: Module | Iterable[Module])\n

Bases: Chain

Passthrough layer.

This layer calls its sub-modules sequentially, and returns its original inputs, like an Identity layer.

Example
passthrough = fl.Passthrough(\n    fl.Linear(32, 128),\n    fl.ReLU(),\n    fl.Linear(128, 128),\n)\n\ntensor = torch.randn(2, 32)\noutput = passthrough(tensor)\n\nassert torch.allclose(output, tensor)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Permute","title":"Permute","text":"
Permute(*dims: int)\n

Bases: Module

Permute operation layer.

This layer permutes the input tensor according to the given dimensions. See also torch.permute.

Example
permute = fl.Permute(2, 0, 1)\n\ntensor = torch.randn(10, 20, 30)\noutput = permute(tensor)\n\nassert output.shape == (30, 10, 20)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *dims: int) -> None:\n    super().__init__()\n    self.dims = dims\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.PixelUnshuffle","title":"PixelUnshuffle","text":"
PixelUnshuffle(downscale_factor: int)\n

Bases: PixelUnshuffle, Module

Pixel Unshuffle layer.

This layer wraps torch.nn.PixelUnshuffle.

Receives:

Type Description Float[Tensor, 'batch in_channels in_height in_width']

Returns:

Type Description Float[Tensor, 'batch out_channels out_height out_width'] Source code in src/refiners/fluxion/layers/pixelshuffle.py
def __init__(self, downscale_factor: int):\n    _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ReLU","title":"ReLU","text":"
ReLU()\n

Bases: Activation

Rectified Linear Unit activation function.

See Rectified Linear Units Improve Restricted Boltzmann Machines and Cognitron: A self-organizing multilayered neural network

Example
relu = fl.ReLU()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = relu(tensor)\n\nexpected_output = torch.tensor([[0.0, 0.0, 1.0]])\nassert torch.equal(output, expected_output)\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n    super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ReflectionPad2d","title":"ReflectionPad2d","text":"
ReflectionPad2d(padding: int)\n

Bases: ReflectionPad2d, Module

Reflection padding layer.

This layer wraps torch.nn.ReflectionPad2d.

Receives:

Type Description Float[Tensor, 'batch channels in_height in_width']

Returns:

Type Description Float[Tensor, 'batch channels out_height out_width'] Source code in src/refiners/fluxion/layers/padding.py
def __init__(self, padding: int) -> None:\n    super().__init__(padding=padding)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Reshape","title":"Reshape","text":"
Reshape(*shape: int)\n

Bases: Module

Reshape operation layer.

This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape). See also torch.reshape.

Warning

The first dimension (batch dimension) is forcefully preserved.

Example
reshape = fl.Reshape(5, 2)\n\ntensor = torch.randn(2, 10, 1)\noutput = reshape(tensor)\n\nassert output.shape == (2, 5, 2)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *shape: int) -> None:\n    super().__init__()\n    self.shape = shape\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Residual","title":"Residual","text":"
Residual(*args: Module | Iterable[Module])\n

Bases: Chain

Residual layer.

This layer calls its sub-modules sequentially, and adds the original input to the output.

Example
residual = fl.Residual(\n    fl.Multiply(scale=10),\n)\n\ntensor = torch.ones(2, 32)\noutput = residual(tensor)\n\nassert output.shape == (2, 32)\nassert torch.allclose(output, 10 * tensor + tensor)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Return","title":"Return","text":"
Return(*args: Any, **kwargs: Any)\n

Bases: Module

Return layer.

This layer stops the execution of a Chain when encountered.
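
The early-exit idea can be pictured with a toy model in plain Python. This is only a sketch of the concept: how the real Chain detects and handles a Return layer is not shown in this reference, and the names Return and run_chain below are stand-ins written for this example.

```python
class Return:
    # Hypothetical sentinel standing in for fl.Return.
    pass


def run_chain(layers, x):
    # Call each layer in order and stop as soon as a Return marker is met.
    for layer in layers:
        if isinstance(layer, Return):
            break
        x = layer(x)
    return x


layers = [lambda v: v + 1, Return(), lambda v: v * 100]
result = run_chain(layers, 1)
print(result)
```

In this toy version, the last layer is never called because the marker comes before it.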

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ScaledDotProductAttention","title":"ScaledDotProductAttention","text":"
ScaledDotProductAttention(\n    num_heads: int = 1,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    slice_size: int | None = None,\n)\n

Bases: Module

Scaled Dot Product Attention.

See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details.

Note

This layer simply wraps scaled_dot_product_attention inside an fl.Module.

Receives:

Name Type Description Query Float[Tensor, 'batch num_queries embedding_dim'] Key Float[Tensor, 'batch num_keys embedding_dim'] Value Float[Tensor, 'batch num_values embedding_dim']

Returns:

Type Description Float[Tensor, 'batch num_queries embedding_dim'] Example
attention = fl.ScaledDotProductAttention(num_heads=8)\n\nquery = torch.randn(2, 10, 128)\nkey = torch.randn(2, 10, 128)\nvalue = torch.randn(2, 10, 128)\noutput = attention(query, key, value)\n\nassert output.shape == (2, 10, 128)\n
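
As a rough illustration of the underlying computation, the following pure-Python sketch evaluates softmax(Q K^T / sqrt(d)) V for a single head and a single batch element. It is not the Refiners implementation, which wraps scaled_dot_product_attention; the helper names softmax and toy_attention are made up for this example.

```python
import math


def softmax(row):
    # Numerically stable softmax over a list of floats.
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def toy_attention(query, key, value):
    # query: num_queries x dim, key: num_keys x dim, value: num_keys x dim
    dim = len(query[0])
    output = []
    for q in query:
        # Scaled dot product between this query and every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in key]
        weights = softmax(scores)
        # Weighted average of the value rows.
        output.append(
            [sum(w * v[j] for w, v in zip(weights, value)) for j in range(len(value[0]))]
        )
    return output


query = [[1.0, 0.0]]
key = [[1.0, 0.0], [1.0, 0.0]]
value = [[1.0, 2.0], [3.0, 4.0]]
attended = toy_attention(query, key, value)
print(attended)
```

Both keys are identical here, so each value row receives the same weight and the output is their average.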

Parameters:

Name Type Description Default num_heads int

The number of heads to use.

1 is_causal bool

Whether to use causal attention.

False is_optimized bool

Whether to use optimized attention.

True slice_size int | None

The slice size to use for the optimized attention.

None Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n    self,\n    num_heads: int = 1,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    slice_size: int | None = None,\n) -> None:\n    \"\"\"Initialize the Scaled Dot Product Attention layer.\n\n    Args:\n        num_heads: The number of heads to use.\n        is_causal: Whether to use causal attention.\n        is_optimized: Whether to use optimized attention.\n        slice_size: The slice size to use for the optimized attention.\n    \"\"\"\n    super().__init__()\n    self.num_heads = num_heads\n    self.is_causal = is_causal\n    self.is_optimized = is_optimized\n    self.slice_size = slice_size\n    self.dot_product = (\n        scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SelfAttention","title":"SelfAttention","text":"
SelfAttention(\n    embedding_dim: int,\n    inner_dim: int | None = None,\n    num_heads: int = 1,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Attention

Multi-Head Self-Attention layer.

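This layer simply chains a Parallel layer of three Identity modules, which duplicates the input tensor for the query, key, and value projections, and an Attention layer.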

Receives:

Type Description Float[Tensor, 'batch sequence_length embedding_dim']

Returns:

Type Description Float[Tensor, 'batch sequence_length embedding_dim'] Example
self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)\n\ntensor = torch.randn(2, 10, 128)\noutput = self_attention(tensor)\n\nassert output.shape == (2, 10, 128)\n

Parameters:

Name Type Description Default embedding_dim int

The embedding dimension of the input and output tensors.

required inner_dim int | None

The inner dimension of the linear layers.

None num_heads int

The number of heads to use.

1 use_bias bool

Whether to use bias in the linear layers.

True is_causal bool

Whether to use causal attention.

False is_optimized bool

Whether to use optimized attention.

True device device | str | None

The device to use.

None dtype dtype | None

The dtype to use.

None Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n    self,\n    embedding_dim: int,\n    inner_dim: int | None = None,\n    num_heads: int = 1,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the Self-Attention layer.\n\n    Args:\n        embedding_dim: The embedding dimension of the input and output tensors.\n        inner_dim: The inner dimension of the linear layers.\n        num_heads: The number of heads to use.\n        use_bias: Whether to use bias in the linear layers.\n        is_causal: Whether to use causal attention.\n        is_optimized: Whether to use optimized attention.\n        device: The device to use.\n        dtype: The dtype to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=embedding_dim,\n        inner_dim=inner_dim,\n        num_heads=num_heads,\n        use_bias=use_bias,\n        is_causal=is_causal,\n        is_optimized=is_optimized,\n        device=device,\n        dtype=dtype,\n    )\n    self.insert(\n        index=0,\n        module=Parallel(\n            Identity(),  # Query projection's input\n            Identity(),  # Key projection's input\n            Identity(),  # Value projection's input\n        ),\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SelfAttention2d","title":"SelfAttention2d","text":"
SelfAttention2d(\n    channels: int,\n    num_heads: int = 1,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: SelfAttention

Multi-Head 2D Self-Attention layer.

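This Module simply chains a Lambda layer that converts the 2D input tensor to a sequence, a SelfAttention layer, and a Lambda layer that converts the sequence back to a 2D tensor.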

Receives:

Type Description Float[Tensor, 'batch channels height width']

Returns:

Type Description Float[Tensor, 'batch channels height width'] Example
self_attention = fl.SelfAttention2d(channels=128, num_heads=8)\n\ntensor = torch.randn(2, 128, 64, 64)\noutput = self_attention(tensor)\n\nassert output.shape == (2, 128, 64, 64)\n

Parameters:

Name Type Description Default channels int

The number of channels of the input and output tensors.

required num_heads int

The number of heads to use.

1 use_bias bool

Whether to use bias in the linear layers.

True is_causal bool

Whether to use causal attention.

False is_optimized bool

Whether to use optimized attention.

True device device | str | None

The device to use.

None dtype dtype | None

The dtype to use.

None Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n    self,\n    channels: int,\n    num_heads: int = 1,\n    use_bias: bool = True,\n    is_causal: bool = False,\n    is_optimized: bool = True,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the 2D Self-Attention layer.\n\n    Args:\n        channels: The number of channels of the input and output tensors.\n        num_heads: The number of heads to use.\n        use_bias: Whether to use bias in the linear layers.\n        is_causal: Whether to use causal attention.\n        is_optimized: Whether to use optimized attention.\n        device: The device to use.\n        dtype: The dtype to use.\n    \"\"\"\n    assert channels % num_heads == 0, f\"channels {channels} must be divisible by num_heads {num_heads}\"\n    self.channels = channels\n\n    super().__init__(\n        embedding_dim=channels,\n        num_heads=num_heads,\n        use_bias=use_bias,\n        is_causal=is_causal,\n        is_optimized=is_optimized,\n        device=device,\n        dtype=dtype,\n    )\n\n    self.insert(0, Lambda(self._tensor_2d_to_sequence))\n    self.append(Lambda(self._sequence_to_tensor_2d))\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SetContext","title":"SetContext","text":"
SetContext(\n    context: str,\n    key: str,\n    callback: Callable[[Any, Any], Any] | None = None,\n)\n

Bases: ContextModule

SetContext layer.

This layer writes to the ContextProvider of its parent Chain.

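When called without a callback, it stores its input in the context under the given key. When called with a callback, it calls the callback with the value currently stored under that key and the input, leaving any update of the context to the callback.

Warning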

The context needs to already exist in the ContextProvider
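
The following toy model sketches the idea of writing to a shared context, with and without a callback. It is an assumption-laden illustration in plain Python, not the Refiners implementation: ToySetContext and the plain dictionary used as a provider are hypothetical, and the exact call behaviour of the real layer is not shown in this reference.

```python
class ToySetContext:
    # Hypothetical stand-in for fl.SetContext. The provider is a plain
    # dict mapping a context name to a dict of key/value pairs.
    def __init__(self, provider, context, key, callback=None):
        self.provider = provider
        self.context = context
        self.key = key
        self.callback = callback

    def __call__(self, x):
        # Raises KeyError when the context does not exist yet.
        ctx = self.provider[self.context]
        if self.callback is None:
            ctx[self.key] = x
        else:
            self.callback(ctx[self.key], x)
        return x


provider = {'cache': {'seen': []}}

set_latest = ToySetContext(provider, 'cache', 'latest')
returned = set_latest(3)

append_seen = ToySetContext(
    provider, 'cache', 'seen', callback=lambda current, x: current.append(x)
)
append_seen(5)
print(provider)
```

The callback variant is useful when the stored value should be accumulated rather than overwritten.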

Source code in src/refiners/fluxion/layers/chain.py
def __init__(\n    self,\n    context: str,\n    key: str,\n    callback: Callable[[Any, Any], Any] | None = None,\n) -> None:\n    super().__init__()\n    self.context = context\n    self.key = key\n    self.callback = callback\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SiLU","title":"SiLU","text":"
SiLU()\n

Bases: Activation

Sigmoid Linear Unit activation function.

See [arXiv:1702.03118] Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning for more details.

Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n    super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sigmoid","title":"Sigmoid","text":"
Sigmoid()\n

Bases: Activation

Sigmoid activation function.

Example
sigmoid = fl.Sigmoid()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = sigmoid(tensor)\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n    super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sin","title":"Sin","text":"
Sin(*args: Any, **kwargs: Any)\n

Bases: Module

Sine operator layer.

This layer applies the sine function to the input tensor. See also torch.sin.

Example
sin = fl.Sin()\n\ntensor = torch.tensor([0, torch.pi])\noutput = sin(tensor)\n\nexpected_output = torch.tensor([0.0, 0.0])\nassert torch.allclose(output, expected_output, atol=1e-6)\n
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Slicing","title":"Slicing","text":"
Slicing(\n    dim: int = 0,\n    start: int = 0,\n    end: int | None = None,\n    step: int = 1,\n)\n

Bases: Module

Slicing operation layer.

This layer slices the input tensor at the given dimension between the given start and end indices. See also torch.index_select.

Example
slicing = fl.Slicing(dim=1, start=50)\n\ntensor = torch.randn(10, 100)\noutput = slicing(tensor)\n\nassert output.shape == (10, 50)\nassert torch.allclose(output, tensor[:, 50:])\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n    self,\n    dim: int = 0,\n    start: int = 0,\n    end: int | None = None,\n    step: int = 1,\n) -> None:\n    super().__init__()\n    self.dim = dim\n    self.start = start\n    self.end = end\n    self.step = step\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Squeeze","title":"Squeeze","text":"
Squeeze(dim: int)\n

Bases: Module

Squeeze operation layer.

This layer squeezes the input tensor at the given dimension. See also torch.squeeze.

Example
squeeze = fl.Squeeze(dim=1)\n\ntensor = torch.randn(10, 1, 10)\noutput = squeeze(tensor)\n\nassert output.shape == (10, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n    super().__init__()\n    self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sum","title":"Sum","text":"
Sum(*args: Module | Iterable[Module])\n

Bases: Chain

Summation layer.

This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.

Example
summation = fl.Sum(\n    fl.Multiply(scale=2, bias=1),\n    fl.Multiply(scale=3, bias=0),\n)\n\ntensor = torch.ones(1)\noutput = summation(tensor)\n\nassert torch.allclose(output, torch.tensor([6.0]))\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n    super().__init__()\n    self._provider = ContextProvider()\n    modules = cast(\n        tuple[Module],\n        (\n            tuple(args[0])\n            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n            else tuple(args)\n        ),\n    )\n\n    for module in modules:\n        # Violating this would mean a ContextModule ends up in two chains,\n        # with a single one correctly set as its parent.\n        assert (\n            (not isinstance(module, ContextModule))\n            or (not module._can_refresh_parent)\n            or (module.parent is None)\n            or (module.parent == self)\n        ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n    self._regenerate_keys(modules)\n    self._reset_context()\n\n    for module in self:\n        if isinstance(module, ContextModule) and module.parent != self:\n            module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Transpose","title":"Transpose","text":"
Transpose(dim0: int, dim1: int)\n

Bases: Module

Transpose operation layer.

This layer transposes the input tensor between the two given dimensions. See also torch.transpose.

Example
transpose = fl.Transpose(dim0=1, dim1=2)\n\ntensor = torch.randn(10, 20, 30)\noutput = transpose(tensor)\n\nassert output.shape == (10, 30, 20)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim0: int, dim1: int) -> None:\n    super().__init__()\n    self.dim0 = dim0\n    self.dim1 = dim1\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Unflatten","title":"Unflatten","text":"
Unflatten(dim: int)\n

Bases: Module

Unflatten operation layer.

This layer unflattens the input tensor at the given dimension with the given sizes. See also torch.unflatten.

Example
unflatten = fl.Unflatten(dim=1)\n\ntensor = torch.randn(10, 100)\noutput = unflatten(tensor, sizes=(10, 10))\n\nassert output.shape == (10, 10, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n    super().__init__()\n    self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Unsqueeze","title":"Unsqueeze","text":"
Unsqueeze(dim: int)\n

Bases: Module

Unsqueeze operation layer.

This layer unsqueezes the input tensor at the given dimension. See also torch.unsqueeze.

Example
unsqueeze = fl.Unsqueeze(dim=1)\n\ntensor = torch.randn(10, 10)\noutput = unsqueeze(tensor)\n\nassert output.shape == (10, 1, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n    super().__init__()\n    self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Upsample","title":"Upsample","text":"
Upsample(\n    channels: int,\n    upsample_factor: int | None = None,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Upsample layer.

This layer upsamples the input by the given scale factor.

Raises:

Type Description RuntimeError

If the context sampling is not set or if the context is empty.

Parameters:

Name Type Description Default channels int

The number of input and output channels.

required upsample_factor int | None

The factor by which to upsample the input. If None, the input shape is taken from the context.

None device device | str | None

The device to use for the convolutional layer.

None dtype dtype | None

The dtype to use for the convolutional layer.

None Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n    self,\n    channels: int,\n    upsample_factor: int | None = None,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n):\n    \"\"\"Initializes the Upsample layer.\n\n    Args:\n        channels: The number of input and output channels.\n        upsample_factor: The factor by which to upsample the input.\n            If None, the input shape is taken from the context.\n        device: The device to use for the convolutional layer.\n        dtype: The dtype to use for the convolutional layer.\n    \"\"\"\n    self.channels = channels\n    self.upsample_factor = upsample_factor\n    super().__init__(\n        Parallel(\n            Identity(),\n            (\n                Lambda(self._get_static_shape)\n                if upsample_factor is not None\n                else UseContext(context=\"sampling\", key=\"shapes\").compose(lambda x: x.pop())\n            ),\n        ),\n        Interpolate(),\n        Conv2d(\n            in_channels=channels,\n            out_channels=channels,\n            kernel_size=3,\n            padding=1,\n            device=device,\n            dtype=dtype,\n        ),\n    )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.UseContext","title":"UseContext","text":"
UseContext(context: str, key: str)\n

Bases: ContextModule

UseContext layer.

This layer reads from the ContextProvider of its parent Chain.

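When called, it reads the value stored under the given key in the context, applies its function to it (the identity by default), and returns the result.

Source code in src/refiners/fluxion/layers/chain.py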
def __init__(self, context: str, key: str) -> None:\n    super().__init__()\n    self.context = context\n    self.key = key\n    self.func: Callable[[Any], Any] = lambda x: x\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule","title":"WeightedModule","text":"
WeightedModule(*args: Any, **kwargs: Any)\n

Bases: Module

A module with a weight (Tensor) attribute.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule.device","title":"device property","text":"
device: device\n

Return the device of the module's weight.

"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule.dtype","title":"dtype property","text":"
dtype: dtype\n

Return the dtype of the module's weight.

"},{"location":"reference/fluxion/model_converter/","title":" Model Converter","text":""},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ConversionStage","title":"ConversionStage","text":"

Bases: Enum

Represents the current stage of the conversion process.

Attributes:

Name Type Description INIT

The conversion process has not started.

BASIC_LAYERS_MATCH

The source and target models have the same number of basic layers.

SHAPE_AND_LAYERS_MATCH

The shapes of both models agree.

MODELS_OUTPUT_AGREE

The source and target models agree.

"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter","title":"ModelConverter","text":"
ModelConverter(\n    source_model: Module,\n    target_model: Module,\n    source_keys_to_skip: list[str] | None = None,\n    target_keys_to_skip: list[str] | None = None,\n    custom_layer_mapping: (\n        dict[type[Module], type[Module]] | None\n    ) = None,\n    threshold: float = 1e-05,\n    skip_output_check: bool = False,\n    skip_init_check: bool = False,\n    verbose: bool = True,\n)\n

Converts a model's state_dict to match another model's state_dict.

The conversion process consists of four steps:
  1. Verify that the source and target models have the same number of basic layers.
  2. Find matching shapes and layers between the source and target models.
  3. Convert the source model's state_dict to match the target model's state_dict.
  4. Compare the outputs of the source and target models.

The conversion process can be run multiple times, and will resume from the last stage.
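
The resume behaviour can be sketched with a small state machine. This is a hypothetical illustration, not the ModelConverter implementation: the Stage enum reuses the member names of ConversionStage, while ToyConverter and the check functions are invented for this example.

```python
from enum import Enum


class Stage(Enum):
    INIT = 0
    BASIC_LAYERS_MATCH = 1
    SHAPE_AND_LAYERS_MATCH = 2
    MODELS_OUTPUT_AGREE = 3


class ToyConverter:
    # Hypothetical sketch: each check is a callable returning True on success.
    def __init__(self, checks):
        self.checks = checks  # maps the stage reached on success to its check
        self.stage = Stage.INIT

    def run(self):
        order = [
            Stage.BASIC_LAYERS_MATCH,
            Stage.SHAPE_AND_LAYERS_MATCH,
            Stage.MODELS_OUTPUT_AGREE,
        ]
        for target in order:
            if self.stage.value >= target.value:
                continue  # already passed during a previous run
            if not self.checks[target]():
                return False
            self.stage = target
        return True


calls = []
state = {'shapes_ok': False}


def check_layers():
    calls.append('layers')
    return True


def check_shapes():
    calls.append('shapes')
    return state['shapes_ok']


def check_outputs():
    calls.append('outputs')
    return True


converter = ToyConverter(
    {
        Stage.BASIC_LAYERS_MATCH: check_layers,
        Stage.SHAPE_AND_LAYERS_MATCH: check_shapes,
        Stage.MODELS_OUTPUT_AGREE: check_outputs,
    }
)

first = converter.run()  # stops at the shape check
state['shapes_ok'] = True
second = converter.run()  # resumes without repeating the layer check
print(first, second, calls)
```

Storing the last stage reached on the converter is what lets the second run skip work that already succeeded.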

Example
source = ...\ntarget = ...\n\nconverter = ModelConverter(\n    source_model=source,\n    target_model=target,\n    threshold=0.1,\n    verbose=False\n)\n\nis_converted = converter(args)\nif is_converted:\n    converter.save_to_safetensors(path=\"converted_model.pt\")\n

Parameters:

Name Type Description Default source_model Module

The model to convert from.

required target_model Module

The model to convert to.

required source_keys_to_skip list[str] | None

A list of keys to skip when tracing the source model.

None target_keys_to_skip list[str] | None

A list of keys to skip when tracing the target model.

None custom_layer_mapping dict[type[Module], type[Module]] | None

A dictionary mapping custom layer types between the source and target models.

None threshold float

The threshold for comparing outputs between the source and target models.

1e-05 skip_output_check bool

Whether to skip comparing the outputs of the source and target models.

False skip_init_check bool

Whether to skip checking that the source and target models have the same number of basic layers.

False verbose bool

Whether to print messages during the conversion process.

True Source code in src/refiners/fluxion/model_converter.py
def __init__(\n    self,\n    source_model: nn.Module,\n    target_model: nn.Module,\n    source_keys_to_skip: list[str] | None = None,\n    target_keys_to_skip: list[str] | None = None,\n    custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,\n    threshold: float = 1e-5,\n    skip_output_check: bool = False,\n    skip_init_check: bool = False,\n    verbose: bool = True,\n) -> None:\n    \"\"\"Initializes the ModelConverter.\n\n    Args:\n        source_model: The model to convert from.\n        target_model: The model to convert to.\n        source_keys_to_skip: A list of keys to skip when tracing the source model.\n        target_keys_to_skip: A list of keys to skip when tracing the target model.\n        custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.\n        threshold: The threshold for comparing outputs between the source and target models.\n        skip_output_check: Whether to skip comparing the outputs of the source and target models.\n        skip_init_check: Whether to skip checking that the source and target models have the same number of basic\n            layers.\n        verbose: Whether to print messages during the conversion process.\n\n    \"\"\"\n    self.source_model = source_model\n    self.target_model = target_model\n    self.source_keys_to_skip = source_keys_to_skip or []\n    self.target_keys_to_skip = target_keys_to_skip or []\n    self.custom_layer_mapping = custom_layer_mapping or {}\n    self.threshold = threshold\n    self.skip_output_check = skip_output_check\n    self.skip_init_check = skip_init_check\n    self.verbose = verbose\n
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.compare_models","title":"compare_models","text":"
compare_models(\n    source_args: ModuleArgs,\n    target_args: ModuleArgs | None = None,\n    threshold: float = 1e-05,\n) -> bool\n

Compare the outputs of the source and target models.

Parameters:

Name Type Description Default source_args ModuleArgs

The arguments to pass to the source model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments are also passed to the target model.

required target_args ModuleArgs | None

The arguments to pass to the target model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys.

None threshold float

The threshold for comparing outputs between the source and target models.

1e-05

Returns:

Type Description bool

True if the outputs of the source and target models agree.

Source code in src/refiners/fluxion/model_converter.py
def compare_models(\n    self,\n    source_args: ModuleArgs,\n    target_args: ModuleArgs | None = None,\n    threshold: float = 1e-5,\n) -> bool:\n    \"\"\"Compare the outputs of the source and target models.\n\n    Args:\n        source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`\n            is not provided, these arguments will also be passed to the target model.\n        target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.\n        threshold: The threshold for comparing outputs between the source and target models.\n\n    Returns:\n        True if the outputs of the source and target models agree.\n    \"\"\"\n    if target_args is None:\n        target_args = source_args\n\n    source_outputs = self._collect_layers_outputs(\n        module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip\n    )\n    target_outputs = self._collect_layers_outputs(\n        module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip\n    )\n\n    diff, prev_source_key, prev_target_key = None, None, None\n    for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):\n        diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()\n        if diff > threshold:\n            self._log(\n                f\"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and\"\n                f\" {target_key}, difference in norm: {diff}\"\n            )\n            return False\n        prev_source_key, prev_target_key = source_key, target_key\n\n    self._log(message=f\"Models agree. 
Difference in norm: {diff}\")\n\n    return True\n
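The comparison loop above can be illustrated without PyTorch. The following is a minimal sketch, assuming each layer output is a flat list of floats; `l2_norm` and `first_divergence` are hypothetical names used for illustration and are not part of Refiners.

```python
import math


def l2_norm(values):
    # Euclidean norm of a flat list of floats.
    return math.sqrt(sum(v * v for v in values))


def first_divergence(source_outputs, target_outputs, threshold=1e-5):
    # Each argument is a list of (key, output) pairs in execution order.
    # Returns the pair of keys where the outputs first differ by more than
    # the threshold, or None when every compared pair agrees.
    for (source_key, source_out), (target_key, target_out) in zip(source_outputs, target_outputs):
        diff = l2_norm([s - t for s, t in zip(source_out, target_out)])
        if diff > threshold:
            return source_key, target_key
    return None


source = [('conv', [1.0, 2.0]), ('linear', [0.5, 0.25])]
target = [('Conv2d', [1.0, 2.0]), ('Linear', [0.5, 0.75])]
print(first_divergence(source, target))
```

As in the source above, `zip` stops at the shorter of the two sequences, so trailing layers that exist in only one model are not compared.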
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.get_mapping","title":"get_mapping","text":"
get_mapping() -> dict[str, str]\n

Get the mapping between the source and target models' state_dicts.

Source code in src/refiners/fluxion/model_converter.py
def get_mapping(self) -> dict[str, str]:\n    \"\"\"Get the mapping between the source and target models' state_dicts.\"\"\"\n    if not self:\n        raise ValueError(\"The conversion process is not done yet. Run `converter(args)` first.\")\n    assert self._stored_mapping is not None, \"Mapping is not stored\"\n    return self._stored_mapping\n
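The `if not self` guard suggests that a converter is truthy only once the conversion has finished. Below is a minimal sketch of that pattern, assuming a hypothetical `TinyConverter` class with a boolean `done` flag; how ModelConverter actually defines its truthiness is not shown in this reference.

```python
class TinyConverter:
    # Hypothetical stand-in used for illustration, not the Refiners class.
    def __init__(self):
        self.done = False
        self._stored_mapping = None

    def __bool__(self):
        # Truthiness reflects whether the conversion has completed.
        return self.done

    def run(self):
        # Pretend the conversion succeeded and store a one-entry mapping.
        self._stored_mapping = {'target.weight': 'source.weight'}
        self.done = True
        return True

    def get_mapping(self):
        if not self:
            raise ValueError('The conversion process is not done yet.')
        return self._stored_mapping


converter = TinyConverter()
try:
    converter.get_mapping()
except ValueError as error:
    print(error)
converter.run()
print(converter.get_mapping())
```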
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.get_module_signature","title":"get_module_signature","text":"
get_module_signature(module: Module) -> ModelTypeShape\n

Get the signature of a module.

Source code in src/refiners/fluxion/model_converter.py
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:\n    \"\"\"Get the signature of a module.\"\"\"\n    layer_type = self._infer_basic_layer_type(module=module)\n    assert layer_type is not None, f\"Module {module} is not a basic layer\"\n    param_shapes = [p.shape for p in module.parameters()]\n    return (layer_type, tuple(param_shapes))\n
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.get_state_dict","title":"get_state_dict","text":"
get_state_dict() -> dict[str, Tensor]\n

Get the converted state_dict.

Source code in src/refiners/fluxion/model_converter.py
def get_state_dict(self) -> dict[str, Tensor]:\n    \"\"\"Get the converted state_dict.\"\"\"\n    if not self:\n        raise ValueError(\"The conversion process is not done yet. Run `converter(args)` first.\")\n    return self.target_model.state_dict()\n
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.map_state_dicts","title":"map_state_dicts","text":"
map_state_dicts(\n    source_args: ModuleArgs,\n    target_args: ModuleArgs | None = None,\n) -> dict[str, str] | None\n

Find a mapping between the source and target models' state_dicts.

Parameters:

Name Type Description Default source_args ModuleArgs

The arguments to pass to the source model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments are also passed to the target model.

required target_args ModuleArgs | None

The arguments to pass to the target model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys.

None

Returns:

Type Description dict[str, str] | None

A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.

Source code in src/refiners/fluxion/model_converter.py
def map_state_dicts(\n    self,\n    source_args: ModuleArgs,\n    target_args: ModuleArgs | None = None,\n) -> dict[str, str] | None:\n    \"\"\"Find a mapping between the source and target models' state_dicts.\n\n    Args:\n        source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`\n            is not provided, these arguments will also be passed to the target model.\n        target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.\n\n    Returns:\n        A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.\n    \"\"\"\n    if target_args is None:\n        target_args = source_args\n\n    source_order = self._trace_module_execution_order(\n        module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip\n    )\n    target_order = self._trace_module_execution_order(\n        module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip\n    )\n\n    if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):\n        return None\n\n    mapping: dict[str, str] = {}\n    for source_type_shape in source_order:\n        source_keys = source_order[source_type_shape]\n        target_type_shape = source_type_shape\n        if not self._is_torch_basic_layer(module_type=source_type_shape[0]):\n            for source_custom_type, target_custom_type in self.custom_layer_mapping.items():\n                if source_custom_type == source_type_shape[0]:\n                    target_type_shape = (target_custom_type, source_type_shape[1])\n                    break\n\n        target_keys = target_order[target_type_shape]\n        
mapping.update(zip(target_keys, source_keys))\n\n    return mapping\n
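The pairing step above can be sketched with plain dictionaries. This is a simplified illustration, assuming layer types are represented as strings and that both execution orders are already traced; `map_keys` and the sample keys are hypothetical. Unlike the source above, it applies the custom mapping to any layer type instead of first checking whether the type is a basic PyTorch layer.

```python
def map_keys(source_order, target_order, custom_layer_mapping=None):
    # source_order and target_order map a (layer_type, shapes) signature to the
    # list of keys seen with that signature, in execution order.
    custom_layer_mapping = custom_layer_mapping or {}
    mapping = {}
    for (layer_type, shapes), source_keys in source_order.items():
        # Translate a custom source layer type to its target counterpart.
        target_type = custom_layer_mapping.get(layer_type, layer_type)
        target_keys = target_order[(target_type, shapes)]
        # Pair keys positionally: the n-th target layer gets the n-th source layer.
        mapping.update(zip(target_keys, source_keys))
    return mapping


source_order = {
    ('Linear', ((4, 8), (4,))): ['fc1', 'fc2'],
    ('MyNorm', ((4,),)): ['norm'],
}
target_order = {
    ('Linear', ((4, 8), (4,))): ['Chain.Linear_1', 'Chain.Linear_2'],
    ('LayerNorm', ((4,),)): ['Chain.LayerNorm'],
}
mapping = map_keys(source_order, target_order, {'MyNorm': 'LayerNorm'})
print(mapping)
```

The resulting dictionary goes from target keys to source keys, matching the documented return value.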
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.run","title":"run","text":"
run(\n    source_args: ModuleArgs,\n    target_args: ModuleArgs | None = None,\n) -> bool\n

Run the conversion process.

Parameters:

Name Type Description Default source_args ModuleArgs

The arguments to pass to the source model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments are also passed to the target model.

required target_args ModuleArgs | None

The arguments to pass to the target model. They can be given as a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys.

None

Returns:

Type Description bool

True if the conversion process is done and the models agree.

Source code in src/refiners/fluxion/model_converter.py
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:\n    \"\"\"Run the conversion process.\n\n    Args:\n        source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`\n            is not provided, these arguments will also be passed to the target model.\n        target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,\n            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.\n\n    Returns:\n        True if the conversion process is done and the models agree.\n    \"\"\"\n    if target_args is None:\n        target_args = source_args\n\n    match self.stage:\n        case ConversionStage.MODELS_OUTPUT_AGREE:\n            self._increment_stage()\n            return True\n\n        case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(\n            source_args=source_args, target_args=target_args\n        ):\n            self._increment_stage()\n            return True\n\n        case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(\n            source_args=source_args, target_args=target_args\n        ):\n            self._increment_stage()\n            return self.run(source_args=source_args, target_args=target_args)\n\n        case ConversionStage.INIT if self._run_init_stage():\n            self._increment_stage()\n            return self.run(source_args=source_args, target_args=target_args)\n\n        case _:\n            self._log(message=f\"Conversion failed at stage {self.stage.value}\")\n            return False\n
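The staged progression above can be sketched as a small state machine. This is a simplified illustration: the stage names follow the source, but the integer values, the `advance` helper, and the `checks` dictionary are assumptions, and the sketch loops where the original recurses.

```python
from enum import Enum


class Stage(Enum):
    # Names follow the source above; the integer values are assumed.
    INIT = 0
    BASIC_LAYERS_MATCH = 1
    SHAPE_AND_LAYERS_MATCH = 2
    MODELS_OUTPUT_AGREE = 3


def advance(stage, checks):
    # checks maps a stage to a zero-argument callable returning True on success.
    # Moves forward one stage per passing check and stops at the first failure.
    order = list(Stage)
    while stage is not Stage.MODELS_OUTPUT_AGREE:
        if not checks[stage]():
            break
        stage = order[order.index(stage) + 1]
    return stage


checks = {
    Stage.INIT: lambda: True,
    Stage.BASIC_LAYERS_MATCH: lambda: True,
    Stage.SHAPE_AND_LAYERS_MATCH: lambda: False,
}
print(advance(Stage.INIT, checks))
```

Because the stage is kept between calls, a failed conversion can be retried from the stage where it stopped instead of from the beginning.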
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModelConverter.save_to_safetensors","title":"save_to_safetensors","text":"
save_to_safetensors(\n    path: Path | str,\n    metadata: dict[str, str] | None = None,\n    half: bool = False,\n) -> None\n

Save the converted model to a SafeTensors file.

Warning

This method can only be called after the conversion process is done.

Parameters:

Name Type Description Default path Path | str

The path to save the converted model to.

required metadata dict[str, str] | None

Metadata to save with the converted model.

None half bool

Whether to save the converted model as half precision.

False

Raises:

Type Description ValueError

If the conversion process is not done yet. Run converter first.

Source code in src/refiners/fluxion/model_converter.py
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:\n    \"\"\"Save the converted model to a SafeTensors file.\n\n    Warning:\n        This method can only be called after the conversion process is done.\n\n    Args:\n        path: The path to save the converted model to.\n        metadata: Metadata to save with the converted model.\n        half: Whether to save the converted model as half precision.\n\n    Raises:\n        ValueError: If the conversion process is not done yet. Run `converter` first.\n    \"\"\"\n    if not self:\n        raise ValueError(\"The conversion process is not done yet. Run `converter(args)` first.\")\n    state_dict = self.get_state_dict()\n    if half:\n        state_dict = {key: value.half() for key, value in state_dict.items()}\n    save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)\n
"},{"location":"reference/fluxion/model_converter/#refiners.fluxion.model_converter.ModuleArgsDict","title":"ModuleArgsDict","text":"

Bases: TypedDict

Represents positional and keyword arguments passed to a module.
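The three argument forms accepted by the converter methods can be normalized into a single (positional, keyword) pair. The sketch below assumes a hypothetical `ArgsDict` type and `unpack_args` helper; the field types are guesses and this is not the library's own implementation.

```python
from typing import TypedDict


class ArgsDict(TypedDict):
    # Hypothetical mirror of the documented layout with positional and keyword keys.
    positional: tuple
    keyword: dict


def unpack_args(args):
    # Normalize the accepted forms into a (positional, keyword) pair.
    if isinstance(args, tuple):
        # A bare tuple holds positional arguments only.
        return args, {}
    if set(args) == {'positional', 'keyword'}:
        # A dictionary with both keys carries the two kinds of arguments.
        return tuple(args['positional']), dict(args['keyword'])
    # Any other dictionary is treated as keyword arguments.
    return (), dict(args)


print(unpack_args((1, 2)))
print(unpack_args({'x': 1}))
print(unpack_args({'positional': (1,), 'keyword': {'y': 2}}))
```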

"},{"location":"reference/fluxion/utils/","title":" Utils","text":""},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.image_to_tensor","title":"image_to_tensor","text":"
image_to_tensor(\n    image: Image,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n) -> Tensor\n

Convert a PIL Image to a Tensor.

Parameters:

Name Type Description Default image Image

The image to convert.

required device device | str | None

The device to use for the tensor.

None dtype dtype | None

The dtype to use for the tensor.

None

Returns:

Type Description Tensor

The converted tensor.

Note

The returned tensor has a leading batch dimension: its shape is [1, 3, H, W] for mode RGB, [1, 1, H, W] for mode L (grayscale), and [1, 4, H, W] for mode RGBA.

Values are normalized to the range [0, 1].

Source code in src/refiners/fluxion/utils.py
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:\n    \"\"\"Convert a PIL Image to a Tensor.\n\n    Args:\n        image: The image to convert.\n        device: The device to use for the tensor.\n        dtype: The dtype to use for the tensor.\n\n    Returns:\n        The converted tensor.\n\n    Note:\n        If the image is in mode `RGB` the tensor will have shape `[3, H, W]`,\n        otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.\n\n        Values are normalized to the range `[0, 1]`.\n    \"\"\"\n    image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)\n\n    assert isinstance(image.mode, str)  # type: ignore\n    match image.mode:\n        case \"L\":\n            image_tensor = image_tensor.unsqueeze(0)\n        case \"RGBA\" | \"RGB\":\n            image_tensor = image_tensor.permute(2, 0, 1)\n        case _:\n            raise ValueError(f\"Unsupported image mode: {image.mode}\")\n\n    return image_tensor.unsqueeze(0)\n
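The source above moves the channels first and ends with `image_tensor.unsqueeze(0)`, so the result carries a leading batch dimension. The sketch below makes that explicit without needing PIL or PyTorch; `output_shape` is a hypothetical helper written for illustration.

```python
def output_shape(mode, height, width):
    # Shape of the returned tensor for an image of the given PIL mode:
    # a batch dimension of 1, then channels, then height and width.
    channels = {'L': 1, 'RGB': 3, 'RGBA': 4}
    if mode not in channels:
        raise ValueError('Unsupported image mode: ' + mode)
    return (1, channels[mode], height, width)


print(output_shape('RGB', 512, 768))
print(output_shape('L', 28, 28))
```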
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.load_from_safetensors","title":"load_from_safetensors","text":"
load_from_safetensors(\n    path: Path | str, device: device | str = \"cpu\"\n) -> dict[str, Tensor]\n

Load tensors from a SafeTensors file on disk.

Parameters:

Name Type Description Default path Path | str

The path to the file.

required device device | str

The device to use for the tensors.

'cpu'

Returns:

Type Description dict[str, Tensor]

The loaded tensors.

Source code in src/refiners/fluxion/utils.py
def load_from_safetensors(path: Path | str, device: Device | str = \"cpu\") -> dict[str, Tensor]:\n    \"\"\"Load tensors from a SafeTensor file from disk.\n\n    Args:\n        path: The path to the file.\n        device: The device to use for the tensors.\n\n    Returns:\n        The loaded tensors.\n    \"\"\"\n    return _load_file(path, str(device))\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.load_tensors","title":"load_tensors","text":"
load_tensors(\n    path: Path | str, /, device: device | str = \"cpu\"\n) -> dict[str, Tensor]\n

Load tensors from a file on disk that was saved with torch.save.

Note

This function uses the weights_only mode of torch.load for additional safety.

Warning

Still, only load data you trust and favor using load_from_safetensors instead.

Source code in src/refiners/fluxion/utils.py
def load_tensors(path: Path | str, /, device: Device | str = \"cpu\") -> dict[str, Tensor]:\n    \"\"\"Load tensors from a file saved with `torch.save` from disk.\n\n    Note:\n        This function uses the `weights_only` mode of `torch.load` for additional safety.\n\n    Warning:\n        Still, **only load data you trust** and favor using\n        [`load_from_safetensors`][refiners.fluxion.utils.load_from_safetensors] instead.\n    \"\"\"\n    # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"TypedStorage is deprecated\")\n        tensors = torch.load(path, map_location=device, weights_only=True)  # type: ignore\n\n    assert isinstance(tensors, dict) and all(\n        isinstance(key, str) and isinstance(value, Tensor)\n        for key, value in tensors.items()  # type: ignore\n    ), \"Invalid tensor file, expected a dict[str, Tensor]\"\n\n    return cast(dict[str, Tensor], tensors)\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.save_to_safetensors","title":"save_to_safetensors","text":"
save_to_safetensors(\n    path: Path | str,\n    tensors: dict[str, Tensor],\n    metadata: dict[str, str] | None = None,\n) -> None\n

Save tensors to a SafeTensor file on disk.

Parameters:

Name Type Description Default path Path | str

The path to the file.

required tensors dict[str, Tensor]

The tensors to save.

required metadata dict[str, str] | None

The metadata to save.

None Source code in src/refiners/fluxion/utils.py
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:\n    \"\"\"Save tensors to a SafeTensor file on disk.\n\n    Args:\n        path: The path to the file.\n        tensors: The tensors to save.\n        metadata: The metadata to save.\n    \"\"\"\n    _save_file(tensors, path, metadata)  # type: ignore\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.summarize_tensor","title":"summarize_tensor","text":"
summarize_tensor(tensor: Tensor) -> str\n

Summarize a tensor.

This helper function returns a string describing the shape, dtype, device, min, max, mean, std, norm and requires_grad flag of a tensor. It does not print anything itself.

Parameters:

Name Type Description Default tensor Tensor

The tensor to summarize.

required

Returns:

Type Description str

The summary string.

Source code in src/refiners/fluxion/utils.py
def summarize_tensor(tensor: torch.Tensor, /) -> str:\n    \"\"\"Summarize a tensor.\n\n    This helper function prints the shape, dtype, device, min, max, mean, std, norm and grad of a tensor.\n\n    Args:\n        tensor: The tensor to summarize.\n\n    Returns:\n        The summary string.\n    \"\"\"\n    info_list = [\n        f\"shape=({', '.join(map(str, tensor.shape))})\",\n        f\"dtype={str(object=tensor.dtype).removeprefix('torch.')}\",\n        f\"device={tensor.device}\",\n    ]\n    if tensor.is_complex():\n        tensor_f = tensor.real.float()\n    else:\n        if tensor.numel() > 0:\n            info_list.extend(\n                [\n                    f\"min={tensor.min():.2f}\",  # type: ignore\n                    f\"max={tensor.max():.2f}\",  # type: ignore\n                ]\n            )\n        tensor_f = tensor.float()\n\n    info_list.extend(\n        [\n            f\"mean={tensor_f.mean():.2f}\",\n            f\"std={tensor_f.std():.2f}\",\n            f\"norm={norm(x=tensor_f):.2f}\",\n            f\"grad={tensor.requires_grad}\",\n        ]\n    )\n\n    return \"Tensor(\" + \", \".join(info_list) + \")\"\n
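The statistics in the summary can be reproduced with the standard library. This is a plain-Python analogue for a flat list of at least two floats, not the Refiners function; `summarize_values` is a hypothetical name, and the sample standard deviation is used on the assumption that it matches the default behaviour of `Tensor.std`.

```python
import math
import statistics


def summarize_values(values):
    # Build a one-line summary in the same spirit as summarize_tensor.
    info = [
        f'shape=({len(values)})',
        f'min={min(values):.2f}',
        f'max={max(values):.2f}',
        f'mean={statistics.fmean(values):.2f}',
        # Sample standard deviation (divides by n - 1).
        f'std={statistics.stdev(values):.2f}',
        # Euclidean norm of the values.
        f'norm={math.sqrt(sum(v * v for v in values)):.2f}',
    ]
    return 'Values(' + ', '.join(info) + ')'


print(summarize_values([1.0, 2.0, 3.0, 4.0]))
# prints: Values(shape=(4), min=1.00, max=4.00, mean=2.50, std=1.29, norm=5.48)
```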
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.tensor_to_image","title":"tensor_to_image","text":"
tensor_to_image(tensor: Tensor) -> Image\n

Convert a Tensor to a PIL Image.

Parameters:

Name Type Description Default tensor Tensor

The tensor to convert.

required

Returns:

Type Description Image

The converted image.

Note

The tensor must have shape [1, channels, height, width], where the number of channels is 1 (grayscale), 3 (RGB), or 4 (RGBA).

Expected values are in the range [0, 1] and are clamped to this range.

Source code in src/refiners/fluxion/utils.py
def tensor_to_image(tensor: Tensor) -> Image.Image:\n    \"\"\"Convert a Tensor to a PIL Image.\n\n    Args:\n        tensor: The tensor to convert.\n\n    Returns:\n        The converted image.\n\n    Note:\n        The tensor must have shape `[1, channels, height, width]` where the number of\n        channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).\n\n        Expected values are in the range `[0, 1]` and are clamped to this range.\n    \"\"\"\n    assert tensor.ndim == 4 and tensor.shape[0] == 1, f\"Unsupported tensor shape: {tensor.shape}\"\n    num_channels = tensor.shape[1]\n    tensor = tensor.clamp(0, 1).squeeze(0)\n    tensor = tensor.to(torch.float32)  # to avoid numpy error with bfloat16\n\n    match num_channels:\n        case 1:\n            tensor = tensor.squeeze(0)\n        case 3 | 4:\n            tensor = tensor.permute(1, 2, 0)\n        case _:\n            raise ValueError(f\"Unsupported number of channels: {num_channels}\")\n\n    return Image.fromarray((tensor.cpu().numpy() * 255).astype(\"uint8\"))  # type: ignore[reportUnknownType]\n
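The per-value conversion above clamps to [0, 1], scales by 255, and casts to `uint8`. A minimal sketch for a single float, assuming the cast truncates toward zero as NumPy does for in-range values; `to_uint8` is a hypothetical helper.

```python
def to_uint8(value):
    # Clamp to [0, 1], scale to [0, 255], then truncate toward zero.
    clamped = min(max(value, 0.0), 1.0)
    return int(clamped * 255)


for sample in (-0.3, 0.5, 0.999, 1.2):
    print(sample, to_uint8(sample))
```

Because the cast truncates instead of rounding, only values that clamp to exactly 1.0 map to 255.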
"},{"location":"reference/foundationals/clip/","title":" CLIP","text":""},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoder","title":"CLIPImageEncoder","text":"
CLIPImageEncoder(\n    image_size: int = 224,\n    embedding_dim: int = 768,\n    output_dim: int = 512,\n    patch_size: int = 32,\n    num_layers: int = 12,\n    num_attention_heads: int = 12,\n    feedforward_dim: int = 3072,\n    layer_norm_eps: float = 1e-05,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Contrastive Language-Image Pretraining (CLIP) image encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Parameters:

Name Type Description Default image_size int

The size of the input image.

224 embedding_dim int

The dimension of the embedding.

768 output_dim int

The dimension of the output.

512 patch_size int

The size of the patches.

32 num_layers int

The number of layers.

12 num_attention_heads int

The number of attention heads.

12 feedforward_dim int

The dimension of the feedforward layer.

3072 layer_norm_eps float

The epsilon value for normalization.

1e-05 device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(\n    self,\n    image_size: int = 224,\n    embedding_dim: int = 768,\n    output_dim: int = 512,\n    patch_size: int = 32,\n    num_layers: int = 12,\n    num_attention_heads: int = 12,\n    feedforward_dim: int = 3072,\n    layer_norm_eps: float = 1e-5,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize a CLIP image encoder.\n\n    Args:\n        image_size: The size of the input image.\n        embedding_dim: The dimension of the embedding.\n        output_dim: The dimension of the output.\n        patch_size: The size of the patches.\n        num_layers: The number of layers.\n        num_attention_heads: The number of attention heads.\n        feedforward_dim: The dimension of the feedforward layer.\n        layer_norm_eps: The epsilon value for normalization.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    self.image_size = image_size\n    self.embedding_dim = embedding_dim\n    self.output_dim = output_dim\n    self.patch_size = patch_size\n    self.num_layers = num_layers\n    self.num_attention_heads = num_attention_heads\n    self.feedforward_dim = feedforward_dim\n    cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]\n    super().__init__(\n        ViTEmbeddings(\n            image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype\n        ),\n        fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n        fl.Chain(\n            TransformerLayer(\n                embedding_dim=embedding_dim,\n                feedforward_dim=feedforward_dim,\n                num_attention_heads=num_attention_heads,\n                layer_norm_eps=layer_norm_eps,\n                device=device,\n                dtype=dtype,\n            )\n            for _ in range(num_layers)\n        ),\n        fl.Lambda(func=cls_token_pooling),\n 
       fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n        fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),\n    )\n
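The image and patch sizes determine how many tokens the transformer layers process. The sketch below assumes non-overlapping square patches plus one class token at index 0, which is consistent with the `cls_token_pooling` lambda above but is not confirmed by this reference; `vit_sequence_length` is a hypothetical helper.

```python
def vit_sequence_length(image_size, patch_size):
    # Patches per side times patches per side, plus one class token (assumed).
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


# Default encoder: 224 pixel images with 32 pixel patches.
print(vit_sequence_length(224, 32))
# Huge and giant variants: 224 pixel images with 14 pixel patches.
print(vit_sequence_length(224, 14))
```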
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoderG","title":"CLIPImageEncoderG","text":"
CLIPImageEncoderG(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: CLIPImageEncoder

CLIP giant image encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1664

output_dim int

1280

patch_size int

14

num_layers int

48

num_attention_heads int

16

feedforward_dim int

8192

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n    \"\"\"Initialize CLIP giant image encoder.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1664,\n        output_dim=1280,\n        patch_size=14,\n        num_layers=48,\n        num_attention_heads=16,\n        feedforward_dim=8192,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoderH","title":"CLIPImageEncoderH","text":"
CLIPImageEncoderH(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: CLIPImageEncoder

CLIP huge image encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1280

output_dim int

1024

patch_size int

14

num_layers int

32

num_attention_heads int

16

feedforward_dim int

5120

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n    \"\"\"Initialize CLIP huge image encoder.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1280,\n        output_dim=1024,\n        patch_size=14,\n        num_layers=32,\n        num_attention_heads=16,\n        feedforward_dim=5120,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoder","title":"CLIPTextEncoder","text":"
CLIPTextEncoder(\n    embedding_dim: int = 768,\n    max_sequence_length: int = 77,\n    vocabulary_size: int = 49408,\n    num_layers: int = 12,\n    num_attention_heads: int = 12,\n    feedforward_dim: int = 3072,\n    layer_norm_eps: float = 1e-05,\n    use_quick_gelu: bool = False,\n    tokenizer: CLIPTokenizer | None = None,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Contrastive Language-Image Pretraining (CLIP) text encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Parameters:

Name Type Description Default embedding_dim int

The embedding dimension.

768 max_sequence_length int

The maximum sequence length.

77 vocabulary_size int

The vocabulary size.

49408 num_layers int

The number of layers.

12 num_attention_heads int

The number of attention heads.

12 feedforward_dim int

The feedforward dimension.

3072 layer_norm_eps float

The epsilon value for layer normalization.

1e-05 use_quick_gelu bool

Whether to use the quick GeLU activation function.

False tokenizer CLIPTokenizer | None

The tokenizer.

None device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(\n    self,\n    embedding_dim: int = 768,\n    max_sequence_length: int = 77,\n    vocabulary_size: int = 49408,\n    num_layers: int = 12,\n    num_attention_heads: int = 12,\n    feedforward_dim: int = 3072,\n    layer_norm_eps: float = 1e-5,\n    use_quick_gelu: bool = False,\n    tokenizer: CLIPTokenizer | None = None,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize CLIP text encoder.\n\n    Args:\n        embedding_dim: The embedding dimension.\n        max_sequence_length: The maximum sequence length.\n        vocabulary_size: The vocabulary size.\n        num_layers: The number of layers.\n        num_attention_heads: The number of attention heads.\n        feedforward_dim: The feedforward dimension.\n        layer_norm_eps: The epsilon value for layer normalization.\n        use_quick_gelu: Whether to use the quick GeLU activation function.\n        tokenizer: The tokenizer.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    self.embedding_dim = embedding_dim\n    self.max_sequence_length = max_sequence_length\n    self.vocabulary_size = vocabulary_size\n    self.num_layers = num_layers\n    self.num_attention_heads = num_attention_heads\n    self.feedforward_dim = feedforward_dim\n    self.layer_norm_eps = layer_norm_eps\n    self.use_quick_gelu = use_quick_gelu\n    super().__init__(\n        tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),\n        fl.Converter(set_dtype=False),\n        fl.Sum(\n            TokenEncoder(\n                vocabulary_size=vocabulary_size,\n                embedding_dim=embedding_dim,\n                device=device,\n                dtype=dtype,\n            ),\n            PositionalEncoder(\n                max_sequence_length=max_sequence_length,\n                embedding_dim=embedding_dim,\n                device=device,\n                dtype=dtype,\n            ),\n        
),\n        *(\n            TransformerLayer(\n                embedding_dim=embedding_dim,\n                num_attention_heads=num_attention_heads,\n                feedforward_dim=feedforward_dim,\n                layer_norm_eps=layer_norm_eps,\n                device=device,\n                dtype=dtype,\n            )\n            for _ in range(num_layers)\n        ),\n        fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n    )\n    if use_quick_gelu:\n        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):\n            parent.replace(\n                old_module=gelu,\n                new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),\n            )\n
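When `use_quick_gelu` is set, every GeLU is replaced by its sigmoid approximation. The sketch below compares the two functions on scalars, assuming the usual quick GeLU definition x * sigmoid(1.702 * x); the function names are hypothetical and the exact formula used by `fl.GeLUApproximation.SIGMOID` should be checked against the library.

```python
import math


def gelu(x):
    # Exact GeLU: x times the standard normal cumulative distribution at x.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def quick_gelu(x):
    # Sigmoid approximation: x * sigmoid(1.702 * x).
    return x / (1.0 + math.exp(-1.702 * x))


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(x, round(gelu(x), 4), round(quick_gelu(x), 4))
```

The two curves stay close but are not identical, so weights trained with one activation should be run with the same one.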
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderG","title":"CLIPTextEncoderG","text":"
CLIPTextEncoderG(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: CLIPTextEncoder

CLIP giant text encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1280

num_layers int

32

num_attention_heads int

20

feedforward_dim int

5120

tokenizer CLIPTokenizer

CLIPTokenizer(pad_token_id=0)

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n    \"\"\"Initialize CLIP giant text encoder.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    tokenizer = CLIPTokenizer(pad_token_id=0)\n    super().__init__(\n        embedding_dim=1280,\n        num_layers=32,\n        num_attention_heads=20,\n        feedforward_dim=5120,\n        tokenizer=tokenizer,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderH","title":"CLIPTextEncoderH","text":"
CLIPTextEncoderH(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: CLIPTextEncoder

CLIP huge text encoder.

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1024

num_layers int

23

num_attention_heads int

16

feedforward_dim int

4096

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n    \"\"\"Initialize CLIP huge text encoder.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1024,\n        num_layers=23,\n        num_attention_heads=16,\n        feedforward_dim=4096,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderL","title":"CLIPTextEncoderL","text":"
CLIPTextEncoderL(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: CLIPTextEncoder

CLIP large text encoder.

Note

We replace the GeLU activation function with an approximate GeLU to match OpenAI's original CLIP implementation (https://github.com/openai/CLIP/blob/a1d0717/clip/model.py#L166).
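As a rough illustration of what the approximation changes, the sketch below compares the exact GeLU with the sigmoid form `x * sigmoid(1.702 * x)` used by OpenAI's CLIP code. It is plain Python rather than Refiners code, and it assumes that `fl.GeLUApproximation.SIGMOID` corresponds to this formula; the function names are hypothetical.

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GeLU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def quick_gelu(x: float) -> float:
    # Sigmoid approximation (assumed here): x * sigmoid(1.702 * x).
    return x / (1.0 + math.exp(-1.702 * x))


# Print both activations side by side for a few sample inputs.
for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(value, round(gelu_exact(value), 4), round(quick_gelu(value), 4))
```

The two curves stay close but are not identical, which is presumably why weights trained with one variant are best run with the same variant.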

See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.

Attributes:

Name Type Description embedding_dim int

768

num_layers int

12

num_attention_heads int

12

feedforward_dim int

3072

use_quick_gelu bool

True

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n    \"\"\"Initialize CLIP large text encoder.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=768,\n        num_layers=12,\n        num_attention_heads=12,\n        feedforward_dim=3072,\n        use_quick_gelu=True,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/","title":" DINOv2","text":""},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_base","title":"DINOv2_base","text":"
DINOv2_base(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 base model.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.

Attributes:

Name Type Description embedding_dim int

768

patch_size int

14

image_size int

518

num_layers int

12

num_heads int

12

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 base model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=768,\n        patch_size=14,\n        image_size=518,\n        num_layers=12,\n        num_heads=12,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_base_reg","title":"DINOv2_base_reg","text":"
DINOv2_base_reg(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 base model with registers.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.

Attributes:

Name Type Description embedding_dim int

768

patch_size int

14

image_size int

518

num_layers int

12

num_heads int

12

num_registers int

4

interpolate_antialias bool

True

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 base model with register.\n\n    Args:\n        device (torch.device | str | None): The PyTorch device to use.\n        dtype (torch.dtype | None): The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=768,\n        patch_size=14,\n        image_size=518,\n        num_layers=12,\n        num_heads=12,\n        num_registers=4,\n        interpolate_antialias=True,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_giant","title":"DINOv2_giant","text":"
DINOv2_giant(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 giant model.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1536

feedforward_dim int

4096

patch_size int

14

image_size int

518

num_layers int

40

num_heads int

24

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 giant model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1536,\n        feedforward_dim=4096,\n        patch_size=14,\n        image_size=518,\n        num_layers=40,\n        num_heads=24,\n        activation=GLU(SiLU()),\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_giant_reg","title":"DINOv2_giant_reg","text":"
DINOv2_giant_reg(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 giant model with registers.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.

Attributes:

Name Type Description embedding_dim int

1536

feedforward_dim int

4096

patch_size int

14

image_size int

518

num_layers int

40

num_heads int

24

num_registers int

4

interpolate_antialias bool

True

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 giant model with register.\n\n    Args:\n        device (torch.device | str | None): The PyTorch device to use.\n        dtype (torch.dtype | None): The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1536,\n        feedforward_dim=4096,\n        patch_size=14,\n        image_size=518,\n        num_layers=40,\n        num_heads=24,\n        num_registers=4,\n        interpolate_antialias=True,\n        activation=GLU(SiLU()),\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_large","title":"DINOv2_large","text":"
DINOv2_large(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 large model.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.

Attributes:

Name Type Description embedding_dim int

1024

patch_size int

14

image_size int

518

num_layers int

24

num_heads int

16

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 large model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1024,\n        patch_size=14,\n        image_size=518,\n        num_layers=24,\n        num_heads=16,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_large_reg","title":"DINOv2_large_reg","text":"
DINOv2_large_reg(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 large model with registers.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.

Attributes:

Name Type Description embedding_dim int

1024

patch_size int

14

image_size int

518

num_layers int

24

num_heads int

16

num_registers int

4

interpolate_antialias bool

True

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 large model with register.\n\n    Args:\n        device (torch.device | str | None): The PyTorch device to use.\n        dtype (torch.dtype | None): The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=1024,\n        patch_size=14,\n        image_size=518,\n        num_layers=24,\n        num_heads=16,\n        num_registers=4,\n        interpolate_antialias=True,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_small","title":"DINOv2_small","text":"
DINOv2_small(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 small model.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.

Attributes:

Name Type Description embedding_dim int

384

patch_size int

14

image_size int

518

num_layers int

12

num_heads int

6

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 small model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=384,\n        patch_size=14,\n        image_size=518,\n        num_layers=12,\n        num_heads=6,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_small_reg","title":"DINOv2_small_reg","text":"
DINOv2_small_reg(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: ViT

DINOv2 small model with registers.

See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.

Attributes:

Name Type Description embedding_dim int

384

patch_size int

14

image_size int

518

num_layers int

12

num_heads int

6

num_registers int

4

interpolate_antialias bool

True

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n    self,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize DINOv2 small model with register.\n\n    Args:\n        device (torch.device | str | None): The PyTorch device to use.\n        dtype (torch.dtype | None): The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        embedding_dim=384,\n        patch_size=14,\n        image_size=518,\n        num_layers=12,\n        num_heads=6,\n        num_registers=4,\n        interpolate_antialias=True,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.ViT","title":"ViT","text":"
ViT(\n    embedding_dim: int = 768,\n    patch_size: int = 16,\n    image_size: int = 224,\n    num_layers: int = 12,\n    num_heads: int = 12,\n    norm_eps: float = 1e-06,\n    mlp_ratio: int = 4,\n    num_registers: int = 0,\n    activation: Activation = fl.GeLU(),\n    feedforward_dim: int | None = None,\n    interpolate_antialias: bool = False,\n    interpolate_mode: str = \"bicubic\",\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Vision Transformer (ViT) model.

See [arXiv:2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale for more details.
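The constructor sizes the positional embedding as `num_patches**2 + 1`, with `num_patches = image_size // patch_size`. The sketch below works through that arithmetic for the default ViT and for the DINOv2 presets. The helper name is hypothetical, and the way registers are counted is an assumption: it supposes each register adds one token to the sequence before the Transformer.

```python
def vit_token_count(image_size: int, patch_size: int, num_registers: int = 0) -> int:
    # Patches per side, using the same integer division as the constructor.
    patches_per_side = image_size // patch_size
    # One class token plus one token per patch: the positional embedding length.
    positional_length = patches_per_side**2 + 1
    # Assumption: each register contributes one extra token to the sequence.
    return positional_length + num_registers


print(vit_token_count(224, 16))     # ViT defaults
print(vit_token_count(518, 14))     # DINOv2 presets
print(vit_token_count(518, 14, 4))  # DINOv2 presets with 4 registers
```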

Parameters:

Name Type Description Default embedding_dim int

The dimension of the embedding.

768 patch_size int

The size of the patches.

16 image_size int

The size of the input image.

224 num_layers int

The number of layers.

12 num_heads int

The number of heads.

12 norm_eps float

The epsilon value for normalization.

1e-06 mlp_ratio int

The ratio for the multi-layer perceptron (MLP).

4 num_registers int

The number of registers.

0 activation Activation

The activation function.

GeLU() feedforward_dim int | None

The dimension of the feedforward layer.

None interpolate_antialias bool

Whether to use antialiasing for interpolation.

False interpolate_mode str

The interpolation mode.

'bicubic' device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/dinov2/vit.py
def __init__(\n    self,\n    embedding_dim: int = 768,\n    patch_size: int = 16,\n    image_size: int = 224,\n    num_layers: int = 12,\n    num_heads: int = 12,\n    norm_eps: float = 1e-6,\n    mlp_ratio: int = 4,\n    num_registers: int = 0,\n    activation: Activation = fl.GeLU(),\n    feedforward_dim: int | None = None,\n    interpolate_antialias: bool = False,\n    interpolate_mode: str = \"bicubic\",\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n) -> None:\n    \"\"\"Initialize a Vision Transformer (ViT) model.\n\n    Args:\n        embedding_dim: The dimension of the embedding.\n        patch_size: The size of the patches.\n        image_size: The size of the input image.\n        num_layers: The number of layers.\n        num_heads: The number of heads.\n        norm_eps: The epsilon value for normalization.\n        mlp_ratio: The ratio for the multi-layer perceptron (MLP).\n        num_registers: The number of registers.\n        activation: The activation function.\n        feedforward_dim: The dimension of the feedforward layer.\n        interpolate_antialias: Whether to use antialiasing for interpolation.\n        interpolate_mode: The interpolation mode.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    num_patches = image_size // patch_size\n    self.embedding_dim = embedding_dim\n    self.patch_size = patch_size\n    self.image_size = image_size\n    self.num_layers = num_layers\n    self.num_heads = num_heads\n    self.norm_eps = norm_eps\n    self.mlp_ratio = mlp_ratio\n    self.num_registers = num_registers\n    self.feedforward_dim = feedforward_dim\n\n    super().__init__(\n        fl.Concatenate(\n            ClassToken(\n                embedding_dim=embedding_dim,\n                device=device,\n                dtype=dtype,\n            ),\n            PatchEncoder(\n                in_channels=3,\n                
out_channels=embedding_dim,\n                patch_size=patch_size,\n                device=device,\n                dtype=dtype,\n            ),\n            dim=1,\n        ),\n        PositionalEncoder(\n            PositionalEmbedding(\n                sequence_length=num_patches**2 + 1,\n                embedding_dim=embedding_dim,\n                patch_size=patch_size,\n                device=device,\n                dtype=dtype,\n            ),\n            fl.Chain(\n                fl.Parallel(\n                    fl.Identity(),\n                    fl.UseContext(context=\"dinov2_vit\", key=\"input\"),\n                ),\n                InterpolateEmbedding(\n                    mode=interpolate_mode,\n                    antialias=interpolate_antialias,\n                    patch_size=patch_size,\n                ),\n            ),\n        ),\n        Transformer(\n            TransformerLayer(\n                embedding_dim=embedding_dim,\n                feedforward_dim=feedforward_dim,\n                activation=activation,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                norm_eps=norm_eps,\n                device=device,\n                dtype=dtype,\n            )\n            for _ in range(num_layers)\n        ),\n        fl.LayerNorm(\n            normalized_shape=embedding_dim,\n            eps=norm_eps,\n            device=device,\n            dtype=dtype,\n        ),\n    )\n\n    if self.num_registers > 0:\n        registers = Registers(\n            num_registers=num_registers,\n            embedding_dim=embedding_dim,\n            device=device,\n            dtype=dtype,\n        )\n        self.insert_before_type(Transformer, registers)\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.preprocess","title":"preprocess","text":"
preprocess(img: Image, dim: int = 224) -> Tensor\n

Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation. Note that this only resizes and normalizes the image; there is no center crop.

Parameters:

Name Type Description Default img Image

The image.

required dim int

The square dimension to resize the image. Typically 224 or 518.

224

Returns:

Type Description Tensor

A float32 tensor with shape (3, dim, dim).
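To make the normalization step concrete, here is a minimal per-pixel sketch in plain Python using the same ImageNet constants. It assumes the channel values have already been scaled to [0, 1] before normalization; `normalize_pixel` is a hypothetical helper, not part of Refiners.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb: list[float]) -> list[float]:
    # rgb holds one pixel's channel values, assumed to lie in [0, 1].
    return [(c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]


print(normalize_pixel(IMAGENET_MEAN))    # the mean colour maps to zeros
print(normalize_pixel([1.0, 1.0, 1.0]))  # pure white
```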

Source code in src/refiners/foundationals/dinov2/dinov2.py
def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:\n    \"\"\"\n    Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.\n    Note that this only resizes and normalizes the image, there is no center crop.\n\n    Args:\n        img: The image.\n        dim: The square dimension to resize the image. Typically 224 or 518.\n\n    Returns:\n        A float32 tensor with shape (3, dim, dim).\n    \"\"\"\n    img = img.convert(\"RGB\").resize((dim, dim))  # type: ignore\n    t = image_to_tensor(img).squeeze()\n    return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n
"},{"location":"reference/foundationals/latent_diffusion/","title":" Latent Diffusion","text":""},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.FixedGroupNorm","title":"FixedGroupNorm","text":"
FixedGroupNorm(target: GroupNorm)\n

Bases: Chain, Adapter[GroupNorm]

Adapter for GroupNorm layers to fix the running mean and variance.

This is useful when running tiled inference with an autoencoder to ensure that the statistics of the GroupNorm layers are consistent across tiles.
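The following toy example, in plain Python, hints at why shared statistics matter. It is not the adapter's implementation: it normalizes a 1D row of values instead of grouped feature maps, and it computes the fixed statistics on the full row for simplicity. All names are hypothetical.

```python
from statistics import fmean, pvariance


def normalize(values, mean, var, eps=1e-5):
    # Standard normalization with externally supplied statistics.
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


row = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
tiles = [row[:4], row[4:]]

# Per-tile statistics: each tile is normalized independently.
per_tile = [normalize(t, fmean(t), pvariance(t)) for t in tiles]

# Fixed statistics: computed once, then reused for every tile.
mean, var = fmean(row), pvariance(row)
fixed = [normalize(t, mean, var) for t in tiles]
full = normalize(row, mean, var)

print(per_tile[0], per_tile[1])      # the tiles become indistinguishable
print(fixed[0] + fixed[1] == full)   # tiles agree with the untiled result
```

With per-tile statistics the dark and bright halves collapse to the same values, whereas fixed statistics reproduce the untiled result.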

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(self, target: fl.GroupNorm) -> None:\n    self.mean = None\n    self.var = None\n    with self.setup_adapter(target):\n        super().__init__(fl.Lambda(self.compute_group_norm))\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder","title":"LatentDiffusionAutoencoder","text":"
LatentDiffusionAutoencoder(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Latent diffusion autoencoder model.

Attributes:

Name Type Description encoder_scale

The encoder scale to use.

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n    self,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initializes the model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        Encoder(device=device, dtype=dtype),\n        Decoder(device=device, dtype=dtype),\n    )\n    self._tile_size = None\n    self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.decode","title":"decode","text":"
decode(x: Tensor) -> Tensor\n

Decode a latent tensor.

Parameters:

Name Type Description Default x Tensor

The latent to decode.

required

Returns:

Type Description Tensor

The decoded image tensor.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def decode(self, x: Tensor) -> Tensor:\n    \"\"\"Decode a latent tensor.\n\n    Args:\n        x: The latent to decode.\n\n    Returns:\n        The decoded image tensor.\n    \"\"\"\n    decoder = self[1]\n    x = decoder(x / self.encoder_scale)\n    return x\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.encode","title":"encode","text":"
encode(x: Tensor) -> Tensor\n

Encode an image.

Parameters:

Name Type Description Default x Tensor

The image tensor to encode.

required

Returns:

Type Description Tensor

The encoded tensor.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def encode(self, x: Tensor) -> Tensor:\n    \"\"\"Encode an image.\n\n    Args:\n        x: The image tensor to encode.\n\n    Returns:\n        The encoded tensor.\n    \"\"\"\n    encoder = self[0]\n    x = self.encoder_scale * encoder(x)\n    return x\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.image_to_latents","title":"image_to_latents","text":"
image_to_latents(image: Image) -> Tensor\n

Encode an image to latents.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def image_to_latents(self, image: Image.Image) -> Tensor:\n    \"\"\"\n    Encode an image to latents.\n    \"\"\"\n    return self.images_to_latents([image])\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.images_to_latents","title":"images_to_latents","text":"
images_to_latents(images: list[Image]) -> Tensor\n

Convert a list of images to latents.

Parameters:

Name Type Description Default images list[Image]

The list of images to convert.

required

Returns:

Type Description Tensor

A tensor containing the latents associated with the images.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def images_to_latents(self, images: list[Image.Image]) -> Tensor:\n    \"\"\"Convert a list of images to latents.\n\n    Args:\n        images: The list of images to convert.\n\n    Returns:\n        A tensor containing the latents associated with the images.\n    \"\"\"\n    x = images_to_tensor(images, device=self.device, dtype=self.dtype)\n    x = 2 * x - 1\n    return self.encode(x)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.latents_to_image","title":"latents_to_image","text":"
latents_to_image(x: Tensor) -> Image\n

Decode latents to an image.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_image(self, x: Tensor) -> Image.Image:\n    \"\"\"\n    Decode latents to an image.\n    \"\"\"\n    if x.shape[0] != 1:\n        raise ValueError(f\"Expected batch size of 1, got {x.shape[0]}\")\n\n    return self.latents_to_images(x)[0]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.latents_to_images","title":"latents_to_images","text":"
latents_to_images(x: Tensor) -> list[Image]\n

Convert a tensor of latents to images.

Parameters:

Name Type Description Default x Tensor

The tensor of latents to convert.

required

Returns:

Type Description list[Image]

A list of images associated with the latents.
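The conversions around `encode` and `decode` are simple affine maps, sketched below on scalars. Image values go from [0, 1] to [-1, 1] before encoding and back after decoding, and latents are multiplied by `encoder_scale` on the way in and divided by it on the way out. The helper names are hypothetical, and the scale of 0.5 is an arbitrary placeholder rather than the model's actual value.

```python
def to_model_range(x: float) -> float:
    # images_to_latents: map [0, 1] to [-1, 1] before encoding.
    return 2 * x - 1


def to_image_range(x: float) -> float:
    # latents_to_images: map [-1, 1] back to [0, 1] after decoding.
    return (x + 1) / 2


def scale_latent(z: float, encoder_scale: float) -> float:
    # encode multiplies the encoder output by encoder_scale.
    return encoder_scale * z


def unscale_latent(z: float, encoder_scale: float) -> float:
    # decode divides by encoder_scale before running the decoder.
    return z / encoder_scale


placeholder_scale = 0.5  # arbitrary value, for illustration only
print(to_image_range(to_model_range(0.25)))
print(unscale_latent(scale_latent(3.0, placeholder_scale), placeholder_scale))
```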

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_images(self, x: Tensor) -> list[Image.Image]:\n    \"\"\"Convert a tensor of latents to images.\n\n    Args:\n        x: The tensor of latents to convert.\n\n    Returns:\n        A list of images associated with the latents.\n    \"\"\"\n    x = self.decode(x)\n    x = (x + 1) / 2\n    return tensor_to_images(x)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_image_to_latents","title":"tiled_image_to_latents","text":"
tiled_image_to_latents(image: Image) -> Tensor\n

Convert an image to latents with gradient blending to smooth tile edges.

You need to activate the tiled inference context manager with the tiled_inference method to use this method.

with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n    latents = lda.tiled_image_to_latents(sample_image)\n

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_image_to_latents(self, image: Image.Image) -> Tensor:\n    \"\"\"\n    Convert an image to latents with gradient blending to smooth tile edges.\n\n    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.\n\n    ```python\n    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n        latents = lda.tiled_image_to_latents(sample_image)\n    ```\n    \"\"\"\n    if self._tile_size is None:\n        raise ValueError(\"Tiled inference context manager not active. Use `tiled_inference` method to activate.\")\n\n    assert self._tile_size is not None and self._blending is not None\n    image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)\n    image_tensor = 2 * image_tensor - 1\n    return self._tiled_encode(image_tensor, self._tile_size, self._blending)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_inference","title":"tiled_inference","text":"
tiled_inference(\n    image: Image,\n    tile_size: tuple[int, int] = (512, 512),\n    blending: int = 64,\n) -> Generator[None, None, None]\n

Context manager for tiled inference operations to save VRAM for large images.

This context manager sets up consistent GroupNorm statistics for performing tiled operations on the autoencoder, including setting and resetting group norm statistics. This ensures that the result is consistent across tiles by capturing the statistics of the GroupNorm layers on a downsampled version of the image.

Be careful not to use the normal image_to_latents and latents_to_image methods while this context manager is active, as this will fail silently and run the operation without tiling.

with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):\n    latents = lda.tiled_image_to_latents(sample_image)\n    decoded_image = lda.tiled_latents_to_image(latents)\n
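The set-then-reset behaviour described above follows the usual `contextlib.contextmanager` pattern with `try`/`finally`. The sketch below reproduces only that state handling on a hypothetical toy class; it performs no tiling and touches no GroupNorm layers.

```python
from contextlib import contextmanager


class ToyTiledCodec:
    # Hypothetical stand-in for the autoencoder; only the state handling is modelled.
    def __init__(self):
        self._tile_size = None
        self._blending = None

    @contextmanager
    def tiled_inference(self, tile_size=(512, 512), blending=64):
        try:
            self._tile_size = tile_size
            self._blending = blending
            yield
        finally:
            # Runs on normal exit and on exceptions, so the state never leaks.
            self._tile_size = None
            self._blending = None

    def tiled_op(self):
        # Tiled methods refuse to run outside the context manager.
        if self._tile_size is None:
            raise ValueError('Tiled inference context manager not active.')
        return self._tile_size, self._blending


codec = ToyTiledCodec()
with codec.tiled_inference(tile_size=(768, 1024), blending=32):
    inside = codec.tiled_op()
print(inside, codec._tile_size)
```

Because the reset lives in `finally`, an exception raised inside the `with` block still leaves the object in its non-tiled state.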

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
@contextmanager\ndef tiled_inference(\n    self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64\n) -> Generator[None, None, None]:\n    \"\"\"\n    Context manager for tiled inference operations to save VRAM for large images.\n\n    This context manager sets up consistent GroupNorm statistics for performing tiled operations on the\n    autoencoder, including setting and resetting group norm statistics. This ensures that the result is\n    consistent across tiles by capturing the statistics of the GroupNorm layers on a downsampled version of the\n    image.\n\n    Be careful not to use the normal `image_to_latents` and `latents_to_image` methods while this context manager is\n    active, as this will fail silently and run the operation without tiling.\n\n    ```python\n    with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):\n        latents = lda.tiled_image_to_latents(sample_image)\n        decoded_image = lda.tiled_latents_to_image(latents)\n    ```\n    \"\"\"\n    try:\n        self._blending = blending\n        self._tile_size = _ImageSize(width=tile_size[0], height=tile_size[1])\n        self._add_fixed_group_norm(image, inference_size=self._tile_size)\n        yield\n    finally:\n        self._remove_fixed_group_norm()\n        self._tile_size = None\n        self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_latents_to_image","title":"tiled_latents_to_image","text":"
tiled_latents_to_image(x: Tensor) -> Image\n

Convert latents to an image with gradient blending to smooth tile edges.

You need to activate the tiled inference context manager with the tiled_inference method to use this method.

```python
with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
    image = lda.tiled_latents_to_image(latents)
```

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_latents_to_image(self, x: Tensor) -> Image.Image:\n    \"\"\"\n    Convert latents to an image with gradient blending to smooth tile edges.\n\n    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.\n\n    ```python\n    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n        image = lda.tiled_latents_to_image(latents)\n    ```\n    \"\"\"\n    if self._tile_size is None:\n        raise ValueError(\"Tiled inference context manager not active. Use `tiled_inference` method to activate.\")\n\n    assert self._tile_size is not None and self._blending is not None\n    result = self._tiled_decode(x, self._tile_size, self._blending)\n    return tensor_to_image((result + 1) / 2)\n
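
To give an intuition for what gradient blending across tile edges means, here is a minimal, framework-free sketch in one dimension. It is not the Refiners implementation (that lives in the private `_tiled_decode` method, which is not shown here); the `blend_tiles` helper and its linear ramp are assumptions chosen purely for illustration.

```python
def blend_tiles(left, right, blending):
    '''Join two 1D tiles whose last/first `blending` samples overlap.

    Hypothetical helper: inside the overlap, the weight of `left` ramps
    linearly down while the weight of `right` ramps up.
    '''
    if blending <= 0:
        return list(left) + list(right)
    if blending > min(len(left), len(right)):
        raise ValueError('overlap is larger than a tile')
    head = list(left[:-blending])   # part of the left tile outside the overlap
    tail = list(right[blending:])   # part of the right tile outside the overlap
    overlap = []
    for i in range(blending):
        w = (i + 1) / (blending + 1)  # weight of the right tile, strictly between 0 and 1
        overlap.append((1 - w) * left[len(left) - blending + i] + w * right[i])
    return head + overlap + tail


left = [0.0] * 6
right = [1.0] * 6
merged = blend_tiles(left, right, blending=3)
print(merged)  # [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

With a hard cut the values would jump from 0 to 1 at the seam; with the ramp, the transition is spread over the overlapping samples, which is the effect a larger `blending` value is meant to produce.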
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel","title":"LatentDiffusionModel","text":"
LatentDiffusionModel(\n    unet: Chain,\n    lda: LatentDiffusionAutoencoder,\n    clip_text_encoder: Chain,\n    solver: Solver,\n    classifier_free_guidance: bool = True,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Module, ABC

Source code in src/refiners/foundationals/latent_diffusion/model.py
def __init__(\n    self,\n    unet: fl.Chain,\n    lda: LatentDiffusionAutoencoder,\n    clip_text_encoder: fl.Chain,\n    solver: Solver,\n    classifier_free_guidance: bool = True,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    super().__init__()\n    self.device: Device = device if isinstance(device, Device) else Device(device=device)\n    self.dtype = dtype\n    self.unet = unet.to(device=self.device, dtype=self.dtype)\n    self.lda = lda.to(device=self.device, dtype=self.dtype)\n    self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)\n    self.solver = solver.to(device=self.device, dtype=self.dtype)\n    self.classifier_free_guidance = classifier_free_guidance\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.init_latents","title":"init_latents","text":"
init_latents(\n    size: tuple[int, int],\n    init_image: Image | None = None,\n    noise: Tensor | None = None,\n) -> Tensor\n

Initialize the latents for the diffusion process.

Parameters:

size (tuple[int, int], required): The size of the latent, given in pixel space as (height, width).

init_image (Image | None, default: None): The image to use as initialization for the latents.

noise (Tensor | None, default: None): The noise to add to the latents.

Source code in src/refiners/foundationals/latent_diffusion/model.py
def init_latents(\n    self,\n    size: tuple[int, int],\n    init_image: Image.Image | None = None,\n    noise: Tensor | None = None,\n) -> Tensor:\n    \"\"\"Initialize the latents for the diffusion process.\n\n    Args:\n        size: The size of the latent (in pixel space).\n        init_image: The image to use as initialization for the latents.\n        noise: The noise to add to the latents.\n    \"\"\"\n    height, width = size\n    latent_height = height // 8\n    latent_width = width // 8\n\n    if noise is None:\n        noise = LatentDiffusionModel.sample_noise(\n            size=(1, 4, latent_height, latent_width),\n            device=self.device,\n            dtype=self.dtype,\n        )\n\n    assert list(noise.shape[2:]) == [\n        latent_height,\n        latent_width,\n    ], f\"noise shape is not compatible: {noise.shape}, with size: {size}\"\n\n    if init_image is None:\n        latent = noise\n    else:\n        resized = init_image.resize(size=(width, height))  # type: ignore\n        encoded_image = self.lda.image_to_latents(resized)\n        latent = self.solver.add_noise(\n            x=encoded_image,\n            noise=noise,\n            step=self.solver.first_inference_step,\n        )\n\n    return self.solver.scale_model_input(latent, step=-1)\n
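
The size arithmetic in the listing above can be checked without loading any model. The sketch below only mirrors the integer divisions and the default noise shape `(1, 4, latent_height, latent_width)`; the `latent_shape` helper and its keyword arguments are hypothetical names introduced here.

```python
def latent_shape(size, batch=1, channels=4, downscale=8):
    '''Return the (batch, channels, height, width) latent shape for a pixel-space size.'''
    height, width = size  # same (height, width) unpacking order as init_latents
    return (batch, channels, height // downscale, width // downscale)


print(latent_shape((1024, 768)))  # (1, 4, 128, 96)
print(latent_shape((1020, 770)))  # (1, 4, 127, 96)
```

Because the division is a floor division, sizes that are not multiples of 8 are rounded down, and a user-supplied `noise` tensor must match the rounded-down shape to pass the assertion.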
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.sample_noise","title":"sample_noise staticmethod","text":"
sample_noise(\n    size: tuple[int, ...],\n    device: device | None = None,\n    dtype: dtype | None = None,\n    offset_noise: float | None = None,\n) -> Tensor\n

Sample noise from a normal distribution with an optional offset.

Parameters:

size (tuple[int, ...], required): The size of the noise tensor.

device (device | None, default: None): The device to put the noise tensor on.

dtype (dtype | None, default: None): The data type of the noise tensor.

offset_noise (float | None, default: None): The offset of the noise tensor. Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.

Source code in src/refiners/foundationals/latent_diffusion/model.py
@staticmethod\ndef sample_noise(\n    size: tuple[int, ...],\n    device: Device | None = None,\n    dtype: DType | None = None,\n    offset_noise: float | None = None,\n) -> torch.Tensor:\n    \"\"\"Sample noise from a normal distribution with an optional offset.\n\n    Args:\n        size: The size of the noise tensor.\n        device: The device to put the noise tensor on.\n        dtype: The data type of the noise tensor.\n        offset_noise: The offset of the noise tensor.\n            Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.\n    \"\"\"\n    noise = torch.randn(size=size, device=device, dtype=dtype)\n    if offset_noise is not None:\n        noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)\n    return noise\n
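
In the listing above, the offset term has shape `(size[0], size[1], 1, 1)`, so a single extra Gaussian draw is shared by every pixel of a given (batch, channel) plane. The sketch below reproduces that behaviour with nested lists instead of tensors so that it runs without PyTorch; the `seed` argument is an addition made here for reproducibility and is not part of the library's signature.

```python
import random


def sample_noise(size, offset_noise=None, seed=0):
    '''Framework-free stand-in for the listing above, using nested lists.

    `size` is (batch, channels, height, width).
    '''
    rng = random.Random(seed)
    batch, channels, height, width = size
    noise = [
        [[[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(height)] for _ in range(channels)]
        for _ in range(batch)
    ]
    if offset_noise is not None:
        for b in range(batch):
            for c in range(channels):
                # One extra draw per (batch, channel) plane, added to every pixel
                # of that plane: the list equivalent of broadcasting a
                # (batch, channels, 1, 1) tensor.
                shift = offset_noise * rng.gauss(0.0, 1.0)
                for row in noise[b][c]:
                    for x in range(width):
                        row[x] += shift
    return noise


plain = sample_noise((1, 2, 2, 3))
shifted = sample_noise((1, 2, 2, 3), offset_noise=0.1)
# Same seed, so both calls share the base noise and differ only by the per-plane shift.
deltas = [shifted[0][0][y][x] - plain[0][0][y][x] for y in range(2) for x in range(3)]
print(deltas)
```

All entries of `deltas` are equal up to floating-point rounding, which shows that the offset moves the mean of a whole plane rather than perturbing individual pixels.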
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.set_inference_steps","title":"set_inference_steps","text":"
set_inference_steps(\n    num_steps: int, first_step: int = 0\n) -> None\n

Set the steps of the diffusion process.

Parameters:

num_steps (int, required): The number of inference steps.

first_step (int, default: 0): The first inference step, used for image-to-image diffusion. You may be used to setting a float in [0, 1] called strength instead, which is an abstraction for this. The first step is round((1 - strength) * (num_steps - 1)).

Source code in src/refiners/foundationals/latent_diffusion/model.py
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:\n    \"\"\"Set the steps of the diffusion process.\n\n    Args:\n        num_steps: The number of inference steps.\n        first_step: The first inference step, used for image-to-image diffusion.\n            You may be used to setting a float in `[0, 1]` called `strength` instead,\n            which is an abstraction for this. The first step is\n            `round((1 - strength) * (num_steps - 1))`.\n    \"\"\"\n    self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)\n
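
The conversion from a `strength` value to `first_step` quoted in the docstring can be written as a small helper. The function name `first_step_from_strength` and the range check are assumptions added here; only the formula comes from the documentation.

```python
def first_step_from_strength(strength, num_steps):
    '''Apply round((1 - strength) * (num_steps - 1)) from the docstring.'''
    if not 0.0 <= strength <= 1.0:
        raise ValueError('strength must be in [0, 1]')
    return round((1 - strength) * (num_steps - 1))


print(first_step_from_strength(1.0, 30))  # 0: start from the first step (full denoising)
print(first_step_from_strength(0.8, 30))  # 6
print(first_step_from_strength(0.0, 30))  # 29: start at the last step
```

The result would then be passed as `first_step` to `set_inference_steps`, for example `set_inference_steps(num_steps=30, first_step=6)`.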
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora","title":"ControlLora","text":"
ControlLora(\n    name: str,\n    unet: SDXLUNet,\n    scale: float = 1.0,\n    condition_channels: int = 3,\n)\n

Bases: Passthrough

ControlLora is a Half-UNet clone of the target UNet, patched with various LoRA layers, ZeroConvolution layers, and a ConditionEncoder.

Like ControlNet, it injects residual tensors into the target UNet. See https://github.com/HighCWu/control-lora-v2 for more details.

Gets context:

Float[Tensor, 'batch condition_channels width height']: The input image.

Sets context:

list[Tensor]: The residuals to be added to the target UNet's residuals. (context=\"unet\", key=\"residuals\")

Parameters:

name (str, required): The name of the ControlLora.

unet (SDXLUNet, required): The target UNet.

scale (float, default: 1.0): The scale to multiply the residuals by.

condition_channels (int, default: 3): The number of channels of the input condition tensor.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(\n    self,\n    name: str,\n    unet: SDXLUNet,\n    scale: float = 1.0,\n    condition_channels: int = 3,\n) -> None:\n    \"\"\"Initialize the ControlLora.\n\n    Args:\n        name: The name of the ControlLora.\n        unet: The target UNet.\n        scale: The scale to multiply the residuals by.\n        condition_channels: The number of channels of the input condition tensor.\n    \"\"\"\n    self.name = name\n\n    super().__init__(\n        timestep_encoder := unet.layer(\"TimestepEncoder\", Chain).structural_copy(),\n        downblocks := unet.layer(\"DownBlocks\", Chain).structural_copy(),\n        middle_block := unet.layer(\"MiddleBlock\", Chain).structural_copy(),\n    )\n\n    # modify the context_key of the copied TimestepEncoder to avoid conflicts\n    timestep_encoder.context_key = f\"timestep_embedding_control_lora_{name}\"\n\n    # modify the context_key of each RangeAdapter2d to avoid conflicts\n    for range_adapter in self.layers(RangeAdapter2d):\n        range_adapter.context_key = f\"timestep_embedding_control_lora_{name}\"\n\n    # insert the ConditionEncoder in the first DownBlock\n    first_downblock = downblocks.layer(0, Chain)\n    out_channels = first_downblock.layer(0, Conv2d).out_channels\n    first_downblock.append(\n        Residual(\n            UseContext(f\"control_lora_{name}\", f\"condition\"),\n            ConditionEncoder(\n                in_channels=condition_channels,\n                out_channels=out_channels,\n                device=unet.device,\n                dtype=unet.dtype,\n            ),\n        )\n    )\n\n    # replace each ResidualAccumulator by a ZeroConvolution\n    for residual_accumulator in self.layers(ResidualAccumulator):\n        downblock = self.ensure_find_parent(residual_accumulator)\n\n        first_layer = downblock[0]\n        assert hasattr(first_layer, \"out_channels\"), f\"{first_layer} has no out_channels attribute\"\n\n        block_channels = first_layer.out_channels\n      
  assert isinstance(block_channels, int)\n\n        downblock.replace(\n            residual_accumulator,\n            ZeroConvolution(\n                scale=scale,\n                residual_index=residual_accumulator.n,\n                in_channels=block_channels,\n                out_channels=block_channels,\n                device=unet.device,\n                dtype=unet.dtype,\n            ),\n        )\n\n    # append a ZeroConvolution to middle_block\n    middle_block_channels = middle_block.layer(0, ResidualBlock).out_channels\n    middle_block.append(\n        ZeroConvolution(\n            scale=scale,\n            residual_index=len(downblocks),\n            in_channels=middle_block_channels,\n            out_channels=middle_block_channels,\n            device=unet.device,\n            dtype=unet.dtype,\n        )\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora.scale","title":"scale property writable","text":"
scale: float\n

The scale of the residuals stored in the context.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter","title":"ControlLoraAdapter","text":"
ControlLoraAdapter(\n    name: str,\n    target: SDXLUNet,\n    scale: float = 1.0,\n    condition_channels: int = 3,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: Chain, Adapter[SDXLUNet]

Adapter for ControlLora.

This adapter simply prepends a ControlLora model inside the target SDXLUNet.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(\n    self,\n    name: str,\n    target: SDXLUNet,\n    scale: float = 1.0,\n    condition_channels: int = 3,\n    weights: dict[str, Tensor] | None = None,\n) -> None:\n    with self.setup_adapter(target):\n        self.name = name\n        self._control_lora = [\n            ControlLora(\n                name=name,\n                unet=target,\n                scale=scale,\n                condition_channels=condition_channels,\n            ),\n        ]\n\n        super().__init__(target)\n\n    if weights:\n        self.load_weights(weights)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.control_lora","title":"control_lora property","text":"
control_lora: ControlLora\n

The ControlLora model.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.scale","title":"scale property writable","text":"
scale: float\n

The scale of the injected residuals.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_condition_encoder","title":"load_condition_encoder staticmethod","text":"
load_condition_encoder(\n    state_dict: dict[str, Tensor], control_lora: ControlLora\n)\n

Load the ConditionEncoder's layers from the state_dict into the ControlLora.

Parameters:

state_dict (dict[str, Tensor], required): The state_dict containing the ConditionEncoder layers to load.

control_lora (ControlLora, required): The ControlLora to load the ConditionEncoder layers into.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_condition_encoder(\n    state_dict: dict[str, Tensor],\n    control_lora: ControlLora,\n):\n    \"\"\"Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.\n\n    Args:\n        state_dict: The state_dict containing the ConditionEncoder layers to load.\n        control_lora: The ControlLora to load the ConditionEncoder layers into.\n    \"\"\"\n    condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)\n    condition_encoder_state_dict = {\n        key.removeprefix(\"ConditionEncoder.\"): value\n        for key, value in state_dict.items()\n        if \"ConditionEncoder\" in key\n    }\n    condition_encoder_layer.load_state_dict(condition_encoder_state_dict)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_lora_layers","title":"load_lora_layers staticmethod","text":"
load_lora_layers(\n    name: str,\n    state_dict: dict[str, Tensor],\n    control_lora: ControlLora,\n) -> None\n

Load the LoRA layers from the state_dict into the ControlLora.

Parameters:

name (str, required): The name of the ControlLora.

state_dict (dict[str, Tensor], required): The state_dict containing the LoRA layers to load.

control_lora (ControlLora, required): The ControlLora to load the LoRA layers into.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_lora_layers(\n    name: str,\n    state_dict: dict[str, Tensor],\n    control_lora: ControlLora,\n) -> None:\n    \"\"\"Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.\n\n    Args:\n        name: The name of the ControlLora.\n        state_dict: The state_dict containing the LoRA layers to load.\n        control_lora: The ControlLora to load the LoRA layers into.\n    \"\"\"\n    # filter the LoraAdapters from the state_dict\n    lora_weights = {\n        key.removeprefix(\"ControlLora.\"): value for key, value in state_dict.items() if \"ControlLora\" in key\n    }\n    lora_weights = {f\"{key}.weight\": value for key, value in lora_weights.items()}\n\n    # move the tensors to the device and dtype of the ControlLora\n    lora_weights = {\n        key: value.to(\n            dtype=control_lora.dtype,\n            device=control_lora.device,\n        )\n        for key, value in lora_weights.items()\n    }\n\n    # load every LoRA layers from the filtered state_dict\n    loras = Lora.from_dict(name, state_dict=lora_weights)\n\n    # attach the LoRA layers to the ControlLora\n    adapters: list[LoraAdapter] = []\n    for key, lora in loras.items():\n        target = control_lora.layer(key.split(\".\"), WeightedModule)\n        assert lora.is_compatible(target)\n        adapter = LoraAdapter(target, lora)\n        adapters.append(adapter)\n\n    for adapter in adapters:\n        adapter.inject(control_lora)\n
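
The first two dictionary comprehensions in the listing above only rename keys, so their effect can be shown on a toy state dict. In the sketch below the key names and the string placeholders standing in for tensors are made up for illustration, and `filter_lora_keys` is a hypothetical helper; real checkpoints will use different names.

```python
def filter_lora_keys(state_dict):
    '''Keep the ControlLora entries, strip the prefix, and append a .weight suffix.'''
    lora_weights = {
        key.removeprefix('ControlLora.'): value
        for key, value in state_dict.items()
        if 'ControlLora' in key
    }
    return {f'{key}.weight': value for key, value in lora_weights.items()}


# Hypothetical keys; the values stand in for tensors.
toy_state_dict = {
    'ControlLora.some_layer.down': 'tensor_a',
    'ControlLora.some_layer.up': 'tensor_b',
    'ConditionEncoder.some_conv.weight': 'tensor_c',
}
print(filter_lora_keys(toy_state_dict))
# {'some_layer.down.weight': 'tensor_a', 'some_layer.up.weight': 'tensor_b'}
```

Entries belonging to other components, such as the condition encoder, are dropped at this stage and handled by their own loading methods. Note that `str.removeprefix` requires Python 3.9 or later.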
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_weights","title":"load_weights","text":"
load_weights(state_dict: dict[str, Tensor]) -> None\n

Load the weights from the state_dict into the ControlLora.

Parameters:

state_dict (dict[str, Tensor], required): The state_dict containing the weights to load.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def load_weights(\n    self,\n    state_dict: dict[str, Tensor],\n) -> None:\n    \"\"\"Load the weights from the state_dict into the `ControlLora`.\n\n    Args:\n        state_dict: The state_dict containing the weights to load.\n    \"\"\"\n    ControlLoraAdapter.load_lora_layers(self.name, state_dict, self.control_lora)\n    ControlLoraAdapter.load_zero_convolution_layers(state_dict, self.control_lora)\n    ControlLoraAdapter.load_condition_encoder(state_dict, self.control_lora)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_zero_convolution_layers","title":"load_zero_convolution_layers staticmethod","text":"
load_zero_convolution_layers(\n    state_dict: dict[str, Tensor], control_lora: ControlLora\n)\n

Load the ZeroConvolution layers from the state_dict into the ControlLora.

Parameters:

state_dict (dict[str, Tensor], required): The state_dict containing the ZeroConvolution layers to load.

control_lora (ControlLora, required): The ControlLora to load the ZeroConvolution layers into.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_zero_convolution_layers(\n    state_dict: dict[str, Tensor],\n    control_lora: ControlLora,\n):\n    \"\"\"Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.\n\n    Args:\n        state_dict: The state_dict containing the ZeroConvolution layers to load.\n        control_lora: The ControlLora to load the ZeroConvolution layers into.\n    \"\"\"\n    zero_convolution_layers = list(control_lora.layers(ZeroConvolution))\n    for i, zero_convolution_layer in enumerate(zero_convolution_layers):\n        zero_convolution_state_dict = {\n            key.removeprefix(f\"ZeroConvolution_{i+1:02d}.\"): value\n            for key, value in state_dict.items()\n            if f\"ZeroConvolution_{i+1:02d}\" in key\n        }\n        zero_convolution_layer.load_state_dict(zero_convolution_state_dict)\n
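
The loop in the listing above splits one flat state dict into one sub-dict per ZeroConvolution layer, using a one-based, zero-padded index in the key prefix. The following sketch isolates that grouping step; `split_zero_convolutions` is a hypothetical helper, and the inner key names and placeholder values are invented for the example.

```python
def split_zero_convolutions(state_dict, count):
    '''Return one sub-dict per layer, keyed without the ZeroConvolution_NN. prefix.'''
    groups = []
    for i in range(count):
        prefix = f'ZeroConvolution_{i + 1:02d}'  # ZeroConvolution_01, ZeroConvolution_02, ...
        groups.append(
            {
                key.removeprefix(prefix + '.'): value
                for key, value in state_dict.items()
                if prefix in key
            }
        )
    return groups


# Hypothetical keys; the values stand in for tensors.
toy_state_dict = {
    'ZeroConvolution_01.inner.weight': 'w1',
    'ZeroConvolution_01.inner.bias': 'b1',
    'ZeroConvolution_02.inner.weight': 'w2',
    'ConditionEncoder.inner.weight': 'ignored',
}
print(split_zero_convolutions(toy_state_dict, count=2))
# [{'inner.weight': 'w1', 'inner.bias': 'b1'}, {'inner.weight': 'w2'}]
```

The i-th layer returned by `control_lora.layers(ZeroConvolution)` receives the group numbered i + 1, so the order in which layers are enumerated has to match the numbering used when the weights were saved.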
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLAutoencoder","title":"SDXLAutoencoder","text":"
SDXLAutoencoder(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: LatentDiffusionAutoencoder

Stable Diffusion XL autoencoder model.

Attributes:

encoder_scale (float): The encoder scale to use.

Parameters:

device (device | str | None, default: None): The PyTorch device to use.

dtype (dtype | None, default: None): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n    self,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initializes the model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        Encoder(device=device, dtype=dtype),\n        Decoder(device=device, dtype=dtype),\n    )\n    self._tile_size = None\n    self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLIPAdapter","title":"SDXLIPAdapter","text":"
SDXLIPAdapter(\n    target: SDXLUNet,\n    clip_image_encoder: CLIPImageEncoderH | None = None,\n    image_proj: (\n        ImageProjection | PerceiverResampler | None\n    ) = None,\n    scale: float = 1.0,\n    fine_grained: bool = False,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: IPAdapter[SDXLUNet]

Image Prompt adapter for the Stable Diffusion XL U-Net model.

Parameters:

target (SDXLUNet, required): The SDXLUNet model to adapt.

clip_image_encoder (CLIPImageEncoderH | None, default: None): The CLIP image encoder to use.

image_proj (ImageProjection | PerceiverResampler | None, default: None): The image projection to use.

scale (float, default: 1.0): The scale to use for the image prompt.

fine_grained (bool, default: False): Whether to use fine-grained image prompt.

weights (dict[str, Tensor] | None, default: None): The weights of the IPAdapter.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py
def __init__(\n    self,\n    target: SDXLUNet,\n    clip_image_encoder: CLIPImageEncoderH | None = None,\n    image_proj: ImageProjection | PerceiverResampler | None = None,\n    scale: float = 1.0,\n    fine_grained: bool = False,\n    weights: dict[str, Tensor] | None = None,\n) -> None:\n    \"\"\"Initialize the adapter.\n\n    Args:\n        target: The SDXLUNet model to adapt.\n        clip_image_encoder: The CLIP image encoder to use.\n        image_proj: The image projection to use.\n        scale: The scale to use for the image prompt.\n        fine_grained: Whether to use fine-grained image prompt.\n        weights: The weights of the IPAdapter.\n    \"\"\"\n    clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)\n\n    if image_proj is None:\n        cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)\n        image_proj = (\n            ImageProjection(\n                clip_image_embedding_dim=clip_image_encoder.output_dim,\n                clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,\n                device=target.device,\n                dtype=target.dtype,\n            )\n            if not fine_grained\n            else PerceiverResampler(\n                latents_dim=1280,  # not `cross_attn_2d.context_embedding_dim` in this case\n                num_attention_layers=4,\n                num_attention_heads=20,\n                head_dim=64,\n                num_tokens=16,\n                input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection\n                output_dim=cross_attn_2d.context_embedding_dim,\n                device=target.device,\n                dtype=target.dtype,\n            )\n        )\n    elif fine_grained:\n        assert isinstance(image_proj, PerceiverResampler)\n\n    super().__init__(\n        target=target,\n        clip_image_encoder=clip_image_encoder,\n        image_proj=image_proj,\n        scale=scale,\n        
fine_grained=fine_grained,\n        weights=weights,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLLcmAdapter","title":"SDXLLcmAdapter","text":"
SDXLLcmAdapter(\n    target: SDXLUNet,\n    condition_scale_embedding_dim: int = 256,\n    condition_scale: float = 7.5,\n)\n

Bases: Chain, Adapter[SDXLUNet]

Note that LCM must be used without CFG. You can disable CFG on SD by setting the classifier_free_guidance attribute to False.

Parameters:

target (SDXLUNet, required): An SDXL UNet.

condition_scale_embedding_dim (int, default: 256): The dimension of the condition scale embedding used by LCM.

condition_scale (float, default: 7.5): Because of the embedding, the condition scale must be passed to this adapter instead of SD. The condition scale passed to SD will be ignored.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py
def __init__(\n    self,\n    target: SDXLUNet,\n    condition_scale_embedding_dim: int = 256,\n    condition_scale: float = 7.5,\n) -> None:\n    \"\"\"Adapt [the SDXl UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]\n    for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].\n\n    Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the\n    `classifier_free_guidance` attribute to `False`.\n\n    Args:\n        target: A SDXL UNet.\n        condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.\n        condition_scale: Because of the embedding, the condition scale must be passed to this adapter\n            instead of SD. The condition scale passed to SD will be ignored.\n    \"\"\"\n    assert condition_scale_embedding_dim % 2 == 0\n    self.condition_scale_embedding_dim = condition_scale_embedding_dim\n    self.condition_scale = condition_scale\n    with self.setup_adapter(target):\n        super().__init__(target)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet","title":"SDXLUNet","text":"
SDXLUNet(\n    in_channels: int,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Stable Diffusion XL U-Net.

See [arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis for more details.

Parameters:

in_channels (int, required): Number of input channels.

device (device | str | None, default: None): Device to use for computation.

dtype (dtype | None, default: None): Data type to use for computation.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def __init__(\n    self,\n    in_channels: int,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the U-Net.\n\n    Args:\n        in_channels: Number of input channels.\n        device: Device to use for computation.\n        dtype: Data type to use for computation.\n    \"\"\"\n    self.in_channels = in_channels\n    super().__init__(\n        TimestepEncoder(device=device, dtype=dtype),\n        DownBlocks(in_channels=in_channels, device=device, dtype=dtype),\n        MiddleBlock(device=device, dtype=dtype),\n        fl.Residual(fl.UseContext(context=\"unet\", key=\"residuals\").compose(lambda x: x[-1])),\n        UpBlocks(device=device, dtype=dtype),\n        OutputBlock(device=device, dtype=dtype),\n    )\n    for residual_block in self.layers(ResidualBlock):\n        chain = residual_block.layer(\"Chain\", fl.Chain)\n        RangeAdapter2d(\n            target=chain.layer(\"Conv2d_1\", fl.Conv2d),\n            channels=residual_block.out_channels,\n            embedding_dim=1280,\n            context_key=\"timestep_embedding\",\n            device=device,\n            dtype=dtype,\n        ).inject(chain)\n    for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):\n        block.append(module=ResidualAccumulator(n=n))\n    for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):\n        block.insert(index=0, module=ResidualConcatenator(n=-n - 2))\n
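
The last two loops of the listing above give each down block a `ResidualAccumulator(n=n)` and each up block a `ResidualConcatenator(n=-n - 2)`. The sketch below shows which stored residual each up block would read under one assumption: that the residuals list holds one slot per down block followed by one trailing slot (the one read through `x[-1]` after the middle block). That layout is inferred from the listing rather than confirmed, and the block count of 4 is arbitrary.

```python
def skip_connection_pairs(num_blocks):
    '''Pair each up block index with the residual slot it reads at index -n - 2.'''
    # Assumed layout: one slot per down block, then one trailing slot.
    residuals = [f'down_{n}' for n in range(num_blocks)] + ['trailing']
    return [(n, residuals[-n - 2]) for n in range(num_blocks)]


print(skip_connection_pairs(4))
# [(0, 'down_3'), (1, 'down_2'), (2, 'down_1'), (3, 'down_0')]
```

Under that assumption the negative index mirrors the down path: the first up block consumes the residual of the last down block, and so on, which is the usual U-Net skip-connection pattern.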
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_clip_text_embedding","title":"set_clip_text_embedding","text":"
set_clip_text_embedding(\n    clip_text_embedding: Tensor,\n) -> None\n

Set the CLIP text embedding context.

Note

This context is required by the SDXLCrossAttention blocks.

Parameters:

clip_text_embedding (Tensor, required): The CLIP text embedding tensor.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:\n    \"\"\"Set the clip text embedding context.\n\n    Note:\n        This context is required by the `SDXLCrossAttention` blocks.\n\n    Args:\n        clip_text_embedding: The CLIP text embedding tensor.\n    \"\"\"\n    self.set_context(context=\"cross_attention_block\", value={\"clip_text_embedding\": clip_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_pooled_text_embedding","title":"set_pooled_text_embedding","text":"
set_pooled_text_embedding(\n    pooled_text_embedding: Tensor,\n) -> None\n

Set the pooled text embedding context.

Note

This is required by TextTimeEmbedding.

Parameters:

pooled_text_embedding (Tensor, required): The pooled text embedding tensor.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:\n    \"\"\"Set the pooled text embedding context.\n\n    Note:\n        This is required by `TextTimeEmbedding`.\n\n    Args:\n        pooled_text_embedding: The pooled text embedding tensor.\n    \"\"\"\n    self.set_context(context=\"diffusion\", value={\"pooled_text_embedding\": pooled_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_time_ids","title":"set_time_ids","text":"
set_time_ids(time_ids: Tensor) -> None\n

Set the time IDs context.

Note

This is required by TextTimeEmbedding.

Parameters:

time_ids (Tensor, required): The time IDs tensor.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_time_ids(self, time_ids: Tensor) -> None:\n    \"\"\"Set the time IDs context.\n\n    Note:\n        This is required by `TextTimeEmbedding`.\n\n    Args:\n        time_ids: The time IDs tensor.\n    \"\"\"\n    self.set_context(context=\"diffusion\", value={\"time_ids\": time_ids})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_timestep","title":"set_timestep","text":"
set_timestep(timestep: Tensor) -> None\n

Set the timestep context.

Note

This is required by TimestepEncoder.

Parameters:

timestep (Tensor, required): The timestep tensor.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_timestep(self, timestep: Tensor) -> None:\n    \"\"\"Set the timestep context.\n\n    Note:\n        This is required by `TimestepEncoder`.\n\n    Args:\n        timestep: The timestep tensor.\n    \"\"\"\n    self.set_context(context=\"diffusion\", value={\"timestep\": timestep})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL","title":"StableDiffusion_XL","text":"
StableDiffusion_XL(\n    unet: SDXLUNet | None = None,\n    lda: SDXLAutoencoder | None = None,\n    clip_text_encoder: DoubleTextEncoder | None = None,\n    solver: Solver | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: LatentDiffusionModel

Stable Diffusion XL model.

Attributes:

unet (SDXLUNet): The U-Net model.

clip_text_encoder (DoubleTextEncoder): The text encoder.

lda (SDXLAutoencoder): The image autoencoder.

Parameters:

unet (SDXLUNet | None, default: None): The SDXLUNet U-Net model to use.

lda (SDXLAutoencoder | None, default: None): The SDXLAutoencoder image autoencoder to use.

clip_text_encoder (DoubleTextEncoder | None, default: None): The DoubleTextEncoder text encoder to use.

solver (Solver | None, default: None): The solver to use.

device (device | str, default: 'cpu'): The PyTorch device to use.

dtype (dtype, default: float32): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def __init__(\n    self,\n    unet: SDXLUNet | None = None,\n    lda: SDXLAutoencoder | None = None,\n    clip_text_encoder: DoubleTextEncoder | None = None,\n    solver: Solver | None = None,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    \"\"\"Initializes the model.\n\n    Args:\n        unet: The SDXLUNet U-Net model to use.\n        lda: The SDXLAutoencoder image autoencoder to use.\n        clip_text_encoder: The DoubleTextEncoder text encoder to use.\n        solver: The solver to use.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    unet = unet or SDXLUNet(in_channels=4)\n    lda = lda or SDXLAutoencoder()\n    clip_text_encoder = clip_text_encoder or DoubleTextEncoder()\n    solver = solver or DDIM(num_inference_steps=30)\n\n    super().__init__(\n        unet=unet,\n        lda=lda,\n        clip_text_encoder=clip_text_encoder,\n        solver=solver,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.default_time_ids","title":"default_time_ids property","text":"
default_time_ids: Tensor\n

The default time IDs to use.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.compute_clip_text_embedding","title":"compute_clip_text_embedding","text":"
compute_clip_text_embedding(\n    text: str | list[str],\n    negative_text: str | list[str] = \"\",\n) -> tuple[Tensor, Tensor]\n

Compute the CLIP text embedding associated with the given prompt and negative prompt.

Parameters:

Name Type Description Default text str | list[str]

The prompt to compute the CLIP text embedding of.

required negative_text str | list[str]

The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., \"\").

'' Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_clip_text_embedding(\n    self, text: str | list[str], negative_text: str | list[str] = \"\"\n) -> tuple[Tensor, Tensor]:\n    \"\"\"Compute the CLIP text embedding associated with the given prompt and negative prompt.\n\n    Args:\n        text: The prompt to compute the CLIP text embedding of.\n        negative_text: The negative prompt to compute the CLIP text embedding of.\n            If not provided, the negative prompt is assumed to be empty (i.e., `\"\"`).\n    \"\"\"\n\n    text = [text] if isinstance(text, str) else text\n\n    if not self.classifier_free_guidance:\n        return self.clip_text_encoder(text)\n\n    negative_text = [negative_text] if isinstance(negative_text, str) else negative_text\n    assert len(text) == len(negative_text), \"The length of the text list and negative_text should be the same\"\n\n    conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)\n    negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)\n\n    return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(\n        tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0\n    )\n
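With classifier-free guidance enabled, the method above concatenates the negative embeddings first and the conditional embeddings second along the batch dimension; compute_self_attention_guidance later recovers the negative half with chunk(2). A minimal sketch of that layout, using plain Python lists as stand-ins for tensors (cfg_batch and split_halves are hypothetical helper names, not part of Refiners):

```python
# Plain lists stand in for tensors; each inner list is one embedding row.

def cfg_batch(negative, conditional):
    # Mirrors torch.cat((negative, conditional), dim=0): negative rows come first.
    assert len(negative) == len(conditional), 'both halves need the same batch size'
    return negative + conditional


def split_halves(batch):
    # Mirrors tensor.chunk(2) on an even-sized batch: first half, then second half.
    half = len(batch) // 2
    return batch[:half], batch[half:]


negative = [[0.0, 0.0], [0.1, 0.1]]
conditional = [[1.0, 2.0], [3.0, 4.0]]

batch = cfg_batch(negative, conditional)
recovered_negative, recovered_conditional = split_halves(batch)
```

The ordering matters: code that takes the first chunk, as the self-attention guidance path does, relies on the negative embeddings being placed first.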
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"
compute_self_attention_guidance(\n    x: Tensor,\n    noise: Tensor,\n    step: int,\n    *,\n    clip_text_embedding: Tensor,\n    pooled_text_embedding: Tensor,\n    time_ids: Tensor,\n    **kwargs: Tensor\n) -> Tensor\n

Compute the self-attention guidance.

Parameters:

Name Type Description Default x Tensor

The input tensor.

required noise Tensor

The noise tensor.

required step int

The step to compute the self-attention guidance at.

required clip_text_embedding Tensor

The CLIP text embedding to compute the self-attention guidance with.

required pooled_text_embedding Tensor

The pooled CLIP text embedding to compute the self-attention guidance with.

required time_ids Tensor

The time IDs to compute the self-attention guidance with.

required

Returns:

Type Description Tensor

The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_self_attention_guidance(\n    self,\n    x: Tensor,\n    noise: Tensor,\n    step: int,\n    *,\n    clip_text_embedding: Tensor,\n    pooled_text_embedding: Tensor,\n    time_ids: Tensor,\n    **kwargs: Tensor,\n) -> Tensor:\n    \"\"\"Compute the self-attention guidance.\n\n    Args:\n        x: The input tensor.\n        noise: The noise tensor.\n        step: The step to compute the self-attention guidance at.\n        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n        pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with.\n        time_ids: The time IDs to compute the self-attention guidance with.\n\n    Returns:\n        The computed self-attention guidance.\n    \"\"\"\n    sag = self._find_sag_adapter()\n    assert sag is not None\n\n    degraded_latents = sag.compute_degraded_latents(\n        solver=self.solver,\n        latents=x,\n        noise=noise,\n        step=step,\n        classifier_free_guidance=True,\n    )\n\n    negative_text_embedding, _ = clip_text_embedding.chunk(2)\n    negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)\n    timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n    time_ids, _ = time_ids.chunk(2)\n\n    self.set_unet_context(\n        timestep=timestep,\n        clip_text_embedding=negative_text_embedding,\n        pooled_text_embedding=negative_pooled_embedding,\n        time_ids=time_ids,\n    )\n    if \"ip_adapter\" in self.unet.provider.contexts:\n        # this implementation is a bit hacky, it should be refactored in the future\n        ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n        image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n        ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n        degraded_noise = self.unet(degraded_latents)\n        ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n    else:\n        degraded_noise = self.unet(degraded_latents)\n\n    return sag.scale * (noise - degraded_noise)\n
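The value returned above is the adapter scale multiplied by the difference between the predicted noise and the noise predicted on the degraded latents. As a rough illustration of that last line only, here is a sketch on flat lists of floats (the function name and the sample numbers are made up for the example):

```python
def self_attention_guidance(noise, degraded_noise, scale):
    # Element-wise scale * (noise - degraded_noise), as in the return statement above.
    return [scale * (n - d) for n, d in zip(noise, degraded_noise)]


guidance = self_attention_guidance(
    noise=[1.0, 2.0, 3.0],
    degraded_noise=[0.5, 2.0, 4.0],
    scale=0.5,
)
```

Where the two predictions agree, the guidance term is zero; a larger scale amplifies the correction where they differ.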
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.has_self_attention_guidance","title":"has_self_attention_guidance","text":"
has_self_attention_guidance() -> bool\n

Whether the model has self-attention guidance or not.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def has_self_attention_guidance(self) -> bool:\n    \"\"\"Whether the model has self-attention guidance or not.\"\"\"\n    return self._find_sag_adapter() is not None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.set_self_attention_guidance","title":"set_self_attention_guidance","text":"
set_self_attention_guidance(\n    enable: bool, scale: float = 1.0\n) -> None\n

Sets the self-attention guidance.

See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.

Parameters:

Name Type Description Default enable bool

Whether to enable self-attention guidance or not.

required scale float

The scale to use.

1.0 Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:\n    \"\"\"Sets the self-attention guidance.\n\n    See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)\n    for more details.\n\n    Args:\n        enable: Whether to enable self-attention guidance or not.\n        scale: The scale to use.\n    \"\"\"\n    if enable:\n        if sag := self._find_sag_adapter():\n            sag.scale = scale\n        else:\n            SDXLSAGAdapter(target=self.unet, scale=scale).inject()\n    else:\n        if sag := self._find_sag_adapter():\n            sag.eject()\n
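The branching above has three outcomes: update the scale of an adapter that is already injected, inject a new adapter, or eject an existing one. The sketch below mimics that control flow with stand-in classes; FakeAdapter and FakeModel are hypothetical and only track state, they do not touch a real U-Net:

```python
class FakeAdapter:
    # Hypothetical stand-in for SDXLSAGAdapter: tracks scale and injection state only.
    def __init__(self, scale):
        self.scale = scale
        self.injected = False


class FakeModel:
    # Hypothetical stand-in for the model: holds at most one adapter.
    def __init__(self):
        self.adapter = None

    def has_self_attention_guidance(self):
        return self.adapter is not None

    def set_self_attention_guidance(self, enable, scale=1.0):
        if enable:
            if self.adapter is not None:
                self.adapter.scale = scale  # already injected: only update the scale
            else:
                self.adapter = FakeAdapter(scale)  # first call: create and inject
                self.adapter.injected = True
        elif self.adapter is not None:
            self.adapter.injected = False  # disabling ejects the adapter
            self.adapter = None


model = FakeModel()
model.set_self_attention_guidance(enable=True, scale=0.75)
first = model.adapter
model.set_self_attention_guidance(enable=True, scale=0.5)
same_adapter = model.adapter is first
model.set_self_attention_guidance(enable=False)
```

Calling the method twice with enable=True reuses the same adapter, so repeated calls are safe; disabling when no adapter is present is a no-op.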
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.set_unet_context","title":"set_unet_context","text":"
set_unet_context(\n    *,\n    timestep: Tensor,\n    clip_text_embedding: Tensor,\n    pooled_text_embedding: Tensor,\n    time_ids: Tensor,\n    **_: Tensor\n) -> None\n

Set the various context parameters required by the U-Net model.

Parameters:

Name Type Description Default timestep Tensor

The timestep to set.

required clip_text_embedding Tensor

The CLIP text embedding to set.

required pooled_text_embedding Tensor

The pooled CLIP text embedding to set.

required time_ids Tensor

The time IDs to set.

required Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_unet_context(\n    self,\n    *,\n    timestep: Tensor,\n    clip_text_embedding: Tensor,\n    pooled_text_embedding: Tensor,\n    time_ids: Tensor,\n    **_: Tensor,\n) -> None:\n    \"\"\"Set the various context parameters required by the U-Net model.\n\n    Args:\n        timestep: The timestep to set.\n        clip_text_embedding: The CLIP text embedding to set.\n        pooled_text_embedding: The pooled CLIP text embedding to set.\n        time_ids: The time IDs to set.\n    \"\"\"\n    self.unet.set_timestep(timestep=timestep)\n    self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)\n    self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)\n    self.unet.set_time_ids(time_ids=time_ids)\n
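The signature above combines a bare * (all parameters are keyword-only) with a catch-all **_ (extra keywords are accepted and ignored). A small sketch of that calling convention, returning a dict instead of writing to the U-Net contexts; the placeholder values and the condition_scale keyword are assumptions chosen for illustration:

```python
def set_unet_context(*, timestep, clip_text_embedding, pooled_text_embedding, time_ids, **_):
    # Keyword-only parameters; anything extra lands in **_ and is ignored.
    return {
        'timestep': timestep,
        'clip_text_embedding': clip_text_embedding,
        'pooled_text_embedding': pooled_text_embedding,
        'time_ids': time_ids,
    }


context = set_unet_context(
    timestep=10,
    clip_text_embedding='text',
    pooled_text_embedding='pooled',
    time_ids='ids',
    condition_scale=7.5,  # extra keyword, silently dropped
)

try:
    set_unet_context(10, 'text', 'pooled', 'ids')  # positional call is rejected
    positional_ok = True
except TypeError:
    positional_ok = False
```

This pattern lets a caller forward one shared set of keyword arguments to several methods, each picking only the names it needs.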
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.add_lcm_lora","title":"add_lcm_lora","text":"
add_lcm_lora(\n    manager: SDLoraManager,\n    tensors: dict[str, Tensor],\n    name: str = \"lcm\",\n    scale: float = 8.0 / 64.0,\n    check_validity: bool = True,\n) -> None\n

Add an LCM-LoRA, or a LoRA with a similar structure such as SDXL-Lightning, to SDXLUNet.

This is a complex LoRA, so SDLoraManager.add_loras() is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of auto_attach_loras.

LCM-LoRA can be used with or without CFG in SD. If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.

Parameters:

Name Type Description Default manager SDLoraManager

An SDLoraManager for SDXL.

required tensors dict[str, Tensor]

The state_dict of the LoRA.

required name str

The name of the LoRA.

'lcm' scale float

The scale to use for the LoRA. It should generally not be changed: these LoRAs must use alpha / rank.

8.0 / 64.0 check_validity bool

Perform additional checks, raise an exception if they fail.

True Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
def add_lcm_lora(\n    manager: SDLoraManager,\n    tensors: dict[str, torch.Tensor],\n    name: str = \"lcm\",\n    scale: float = 8.0 / 64.0,\n    check_validity: bool = True,\n) -> None:\n    \"\"\"Add a [LCM-LoRA](https://arxiv.org/abs/2311.05556) or a LoRA with similar structure\n    such as [SDXL-Lightning](https://arxiv.org/abs/2402.13929) to SDXLUNet.\n\n    This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]\n    is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of\n    [auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].\n\n    LCM-LoRA can be used with or without CFG in SD.\n    If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.\n\n    Args:\n        manager: A SDLoraManager for SDXL.\n        tensors: The `state_dict` of the LoRA.\n        name: The name of the LoRA.\n        scale: The scale to use for the LoRA (should generally not be changed, those LoRAs must use alpha / rank).\n        check_validity: Perform additional checks, raise an exception if they fail.\n    \"\"\"\n\n    assert isinstance(manager.target, StableDiffusion_XL)\n    unet = manager.target.unet\n\n    loras = Lora.from_dict(name, {k: v.to(unet.device, unet.dtype) for k, v in tensors.items()})\n    assert all(k.startswith(\"lora_unet_\") for k in loras.keys())\n    loras = {k: loras[k] for k in sorted(loras.keys(), key=SDLoraManager.sort_keys)}\n\n    debug_map: list[tuple[str, str]] | None = [] if check_validity else None\n\n    # Projections are in `SDXLCrossAttention` but not in `CrossAttentionBlock`.\n    loras_projs = {k: v for k, v in loras.items() if k.endswith(\"proj_in\") or k.endswith(\"proj_out\")}\n    auto_attach_loras(\n        loras_projs,\n        unet,\n        exclude=[\"CrossAttentionBlock\"],\n        include=[\"SDXLCrossAttention\"],\n        debug_map=debug_map,\n    )\n\n    
manager.add_loras_to_unet(\n        {k: v for k, v in loras.items() if k not in loras_projs},\n        debug_map=debug_map,\n    )\n\n    if debug_map is not None:\n        _check_validity(debug_map)\n\n    # LoRAs are finally injected, set the scale with the manager.\n    manager.set_scale(name, scale)\n
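The function above attaches the LoRAs in two passes, split by key suffix, and then sets the scale. The sketch below reproduces only the key partition and the default scale arithmetic; the key names are hypothetical, and reading 8.0 / 64.0 as alpha / rank follows the parameter description rather than anything checked against the LoRA files:

```python
# Hypothetical LoRA key names, chosen only to illustrate the suffix filter.
loras = {
    'lora_unet_block_0_proj_in': 'A',
    'lora_unet_block_0_attn_to_q': 'B',
    'lora_unet_block_0_proj_out': 'C',
}

# First pass: projection layers, matched by suffix as in the source above.
loras_projs = {k: v for k, v in loras.items() if k.endswith('proj_in') or k.endswith('proj_out')}

# Second pass: everything that was not handled in the first pass.
loras_rest = {k: v for k, v in loras.items() if k not in loras_projs}

# Default scale, read as alpha / rank with the values from the signature.
alpha, rank = 8.0, 64.0
scale = alpha / rank
```

Every key ends up in exactly one of the two passes, which is what allows the validity check to compare the attached layers against the full set.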
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight","title":"ICLight","text":"
ICLight(\n    patch_weights: dict[str, Tensor],\n    unet: SD1UNet,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: StableDiffusion_1

IC-Light is a Stable Diffusion model that can be used to relight a reference image.

At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.

Example
import torch\nfrom huggingface_hub import hf_hub_download\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.clip import CLIPTextEncoderL\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\ndtype = torch.float32\nno_grad().__enter__()\nmanual_seed(42)\n\nsd = ICLight(\n    patch_weights=load_from_safetensors(\n        path=hf_hub_download(\n            repo_id=\"refiners/ic_light.sd1_5.fc\",\n            filename=\"model.safetensors\",\n        ),\n        device=device,\n    ),\n    unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(\n        tensors_path=hf_hub_download(\n            repo_id=\"refiners/realistic_vision.v5_1.sd1_5.unet\",\n            filename=\"model.safetensors\",\n        )\n    ),\n    clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(\n        tensors_path=hf_hub_download(\n            repo_id=\"refiners/realistic_vision.v5_1.sd1_5.text_encoder\",\n            filename=\"model.safetensors\",\n        )\n    ),\n    lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(\n        tensors_path=hf_hub_download(\n            repo_id=\"refiners/realistic_vision.v5_1.sd1_5.autoencoder\",\n            filename=\"model.safetensors\",\n        )\n    ),\n    device=device,\n    dtype=dtype,\n)\n\nprompt = \"soft lighting, high-quality professional image\"\nnegative_prompt = \"lowres, bad anatomy, bad hands, cropped, worst quality\"\nclip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)\n\nimage = Image.open(\"reference-image.png\").resize((512, 512))\nsd.set_ic_light_condition(image)\n\nx = torch.randn(\n    size=(1, 4, 64, 64),\n    
device=device,\n    dtype=dtype,\n)\n\nfor step in sd.steps:\n    x = sd(\n        x=x,\n        step=step,\n        clip_text_embedding=clip_text_embedding,\n        condition_scale=1.5,\n    )\npredicted_image = sd.lda.latents_to_image(x)\n\npredicted_image.save(\"ic-light-output.png\")\n
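The example resizes the reference image to 512 by 512 and draws noise of size (1, 4, 64, 64). The helper below makes that relationship explicit for other resolutions; latent_shape is a hypothetical name, and the downscale factor of 8 is an assumption inferred from the 512 to 64 ratio in the example:

```python
def latent_shape(height, width, batch_size=1, channels=4, downscale=8):
    # downscale=8 is an assumption matching the 512 -> 64 ratio in the example above.
    assert height % downscale == 0 and width % downscale == 0, 'size must be a multiple of downscale'
    return (batch_size, channels, height // downscale, width // downscale)


shape = latent_shape(512, 512)
```

The resulting tuple can be passed as the size argument of torch.randn when the reference image is resized to something other than 512 by 512.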
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def __init__(\n    self,\n    patch_weights: dict[str, torch.Tensor],\n    unet: SD1UNet,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: torch.device | str = \"cpu\",\n    dtype: torch.dtype = torch.float32,\n) -> None:\n    super().__init__(\n        unet=unet,\n        lda=lda,\n        clip_text_encoder=clip_text_encoder,\n        solver=solver,\n        device=device,\n        dtype=dtype,\n    )\n    self._extend_conv_in()\n    self._apply_patch(weights=patch_weights)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight.compute_gray_composite","title":"compute_gray_composite staticmethod","text":"
compute_gray_composite(image: Image, mask: Image) -> Image\n

Compute a grayscale composite of an image and a mask.

IC-Light will recreate the image

Parameters:

Name Type Description Default image Image

The image to composite.

required mask Image

The mask to use for the composite.

required Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
@staticmethod\ndef compute_gray_composite(\n    image: Image.Image,\n    mask: Image.Image,\n) -> Image.Image:\n    \"\"\"Compute a grayscale composite of an image and a mask.\n\n    IC-Light will recreate the image\n\n    Args:\n        image: The image to composite.\n        mask: The mask to use for the composite.\n    \"\"\"\n    assert mask.mode == \"L\", \"Mask must be a grayscale image\"\n    assert image.size == mask.size, \"Image and mask must have the same size\"\n    background = Image.new(\"RGB\", image.size, (127, 127, 127))\n    return Image.composite(image, background, mask)\n
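The method above keeps the image where the mask is white and shows a flat gray (127) background where the mask is black. A per-pixel sketch of that behavior on single-channel values, without Pillow; the linear blend for intermediate mask values and its rounding are approximations, and composite_pixel is a hypothetical helper:

```python
GRAY = 127  # background value used by compute_gray_composite


def composite_pixel(pixel, mask_value, background=GRAY):
    # mask_value 255 keeps the image pixel, 0 keeps the background;
    # values in between blend linearly (the rounding is an approximation).
    weight = mask_value / 255
    return round(pixel * weight + background * (1 - weight))


row = [200, 50, 10]
mask = [255, 0, 255]
composited = [composite_pixel(p, m) for p, m in zip(row, mask)]
```

In this sample row, the middle pixel is masked out and therefore takes the gray background value.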
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight.set_ic_light_condition","title":"set_ic_light_condition","text":"
set_ic_light_condition(\n    image: Image, mask: Image | None = None\n) -> None\n

Set the IC-Light condition.

Parameters:

Name Type Description Default image Image

The reference image.

required mask Image | None

The mask to use for the reference image.

None

If a mask is provided, it is used to compute a grayscale composite of the image and the mask; otherwise, the image is used as is. Note that IC-Light requires a 127-valued gray background to work.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def set_ic_light_condition(\n    self,\n    image: Image.Image,\n    mask: Image.Image | None = None,\n) -> None:\n    \"\"\"Set the IC light condition.\n\n    Args:\n        image: The reference image.\n        mask: The mask to use for the reference image.\n\n    If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,\n    the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.\n    \"\"\"\n    if mask is not None:\n        image = self.compute_gray_composite(image=image, mask=mask)\n    latents = self.lda.image_to_latents(image)\n    self._ic_light_condition = latents\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1Autoencoder","title":"SD1Autoencoder","text":"
SD1Autoencoder(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: LatentDiffusionAutoencoder

Stable Diffusion 1.5 autoencoder model.

Attributes:

Name Type Description encoder_scale float

The encoder scale to use.

Parameters:

Name Type Description Default device device | str | None

The PyTorch device to use.

None dtype dtype | None

The PyTorch data type to use.

None Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n    self,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initializes the model.\n\n    Args:\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    super().__init__(\n        Encoder(device=device, dtype=dtype),\n        Decoder(device=device, dtype=dtype),\n    )\n    self._tile_size = None\n    self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1ELLAAdapter","title":"SD1ELLAAdapter","text":"
SD1ELLAAdapter(\n    target: SD1UNet,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: ELLAAdapter[SD1UNet]

ELLA adapter for Stable Diffusion 1.5.

Parameters:

Name Type Description Default target SD1UNet

The target model to adapt.

required weights dict[str, Tensor] | None

The weights of the ELLA adapter (see scripts/conversion/convert_ella_adapter.py).

None Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py
def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:\n    \"\"\"Initialize the adapter.\n\n    Args:\n        target: The target model to adapt.\n        weights: The weights of the ELLA adapter (see `scripts/conversion/convert_ella_adapter.py`).\n    \"\"\"\n    latents_encoder = ELLA(\n        time_channel=320,\n        timestep_embedding_dim=768,\n        width=768,\n        num_layers=6,\n        num_heads=8,\n        num_latents=64,\n        input_dim=2048,\n        device=target.device,\n        dtype=target.dtype,\n    )\n    super().__init__(target=target, latents_encoder=latents_encoder, weights=weights)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet","title":"SD1UNet","text":"
SD1UNet(\n    in_channels: int,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Chain

Stable Diffusion 1.5 U-Net.

See [arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models for more details.

Parameters:

Name Type Description Default in_channels int

The number of input channels.

required device device | str | None

The PyTorch device to use for computation.

None dtype dtype | None

The PyTorch dtype to use for computation.

None Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def __init__(\n    self,\n    in_channels: int,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    \"\"\"Initialize the U-Net.\n\n    Args:\n        in_channels: The number of input channels.\n        device: The PyTorch device to use for computation.\n        dtype: The PyTorch dtype to use for computation.\n    \"\"\"\n    self.in_channels = in_channels\n    super().__init__(\n        TimestepEncoder(device=device, dtype=dtype),\n        DownBlocks(in_channels=in_channels, device=device, dtype=dtype),\n        fl.Sum(\n            fl.UseContext(context=\"unet\", key=\"residuals\").compose(lambda x: x[-1]),\n            MiddleBlock(device=device, dtype=dtype),\n        ),\n        UpBlocks(device=device, dtype=dtype),\n        fl.Chain(\n            fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),\n            fl.SiLU(),\n            fl.Conv2d(\n                in_channels=320,\n                out_channels=4,\n                kernel_size=3,\n                stride=1,\n                padding=1,\n                device=device,\n                dtype=dtype,\n            ),\n        ),\n    )\n    for residual_block in self.layers(ResidualBlock):\n        chain = residual_block.layer(\"Chain\", fl.Chain)\n        RangeAdapter2d(\n            target=chain.layer(\"Conv2d_1\", fl.Conv2d),\n            channels=residual_block.out_channels,\n            embedding_dim=1280,\n            context_key=\"timestep_embedding\",\n            device=device,\n            dtype=dtype,\n        ).inject(chain)\n    for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):\n        block.append(ResidualAccumulator(n))\n    for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):\n        block.insert(0, ResidualConcatenator(-n - 2))\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet.set_clip_text_embedding","title":"set_clip_text_embedding","text":"
set_clip_text_embedding(\n    clip_text_embedding: Tensor,\n) -> None\n

Set the CLIP text embedding.

Note

This context is required by the CLIPLCrossAttention blocks.

Parameters:

Name Type Description Default clip_text_embedding Tensor

The CLIP text embedding.

required Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:\n    \"\"\"Set the CLIP text embedding.\n\n    Note:\n        This context is required by the `CLIPLCrossAttention` blocks.\n\n    Args:\n        clip_text_embedding: The CLIP text embedding.\n    \"\"\"\n    self.set_context(\"cross_attention_block\", {\"clip_text_embedding\": clip_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet.set_timestep","title":"set_timestep","text":"
set_timestep(timestep: Tensor) -> None\n

Set the timestep.

Note

This context is required by TimestepEncoder.

Parameters:

Name Type Description Default timestep Tensor

The timestep.

required Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_timestep(self, timestep: Tensor) -> None:\n    \"\"\"Set the timestep.\n\n    Note:\n        This context is required by `TimestepEncoder`.\n\n    Args:\n        timestep: The timestep.\n    \"\"\"\n    self.set_context(\"diffusion\", {\"timestep\": timestep})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1","title":"StableDiffusion_1","text":"
StableDiffusion_1(\n    unet: SD1UNet | None = None,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: LatentDiffusionModel

Stable Diffusion 1.5 model.

Attributes:

Name Type Description unet SD1UNet

The U-Net model.

clip_text_encoder CLIPTextEncoderL

The text encoder.

lda SD1Autoencoder

The image autoencoder.

Example:

import torch\n\nfrom refiners.fluxion.utils import manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1\n\n# Load SD\nsd15 = StableDiffusion_1(device=\"cuda\", dtype=torch.float16)\n\nsd15.clip_text_encoder.load_from_safetensors(\"sd1_5.text_encoder.safetensors\")\nsd15.unet.load_from_safetensors(\"sd1_5.unet.safetensors\")\nsd15.lda.load_from_safetensors(\"sd1_5.autoencoder.safetensors\")\n\n# Hyperparameters\nprompt = \"a cute cat, best quality, high quality\"\nnegative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\nseed = 42\n\nsd15.set_inference_steps(50)\n\nwith no_grad():  # Disable gradient calculation for memory-efficient inference\n    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)\n    manual_seed(seed)\n\n    x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)\n\n    # Diffusion process\n    for step in sd15.steps:\n        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)\n\n    predicted_image = sd15.lda.decode_latents(x)\n    predicted_image.save(\"output.png\")\n

Parameters:

Name Type Description Default unet SD1UNet | None

The SD1UNet U-Net model to use.

None lda SD1Autoencoder | None

The SD1Autoencoder image autoencoder to use.

None clip_text_encoder CLIPTextEncoderL | None

The CLIPTextEncoderL text encoder to use.

None solver Solver | None

The solver to use.

None device device | str

The PyTorch device to use.

'cpu' dtype dtype

The PyTorch data type to use.

float32 Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(\n    self,\n    unet: SD1UNet | None = None,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    \"\"\"Initializes the model.\n\n    Args:\n        unet: The SD1UNet U-Net model to use.\n        lda: The SD1Autoencoder image autoencoder to use.\n        clip_text_encoder: The CLIPTextEncoderL text encoder to use.\n        solver: The solver to use.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    unet = unet or SD1UNet(in_channels=4)\n    lda = lda or SD1Autoencoder()\n    clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()\n    solver = solver or DPMSolver(num_inference_steps=30)\n\n    super().__init__(\n        unet=unet,\n        lda=lda,\n        clip_text_encoder=clip_text_encoder,\n        solver=solver,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.compute_clip_text_embedding","title":"compute_clip_text_embedding","text":"
compute_clip_text_embedding(\n    text: str | list[str],\n    negative_text: str | list[str] = \"\",\n) -> Tensor\n

Compute the CLIP text embedding associated with the given prompt and negative prompt.

Parameters:

Name Type Description Default text str | list[str]

The prompt to compute the CLIP text embedding of.

required negative_text str | list[str]

The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., \"\").

'' Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = \"\") -> Tensor:\n    \"\"\"Compute the CLIP text embedding associated with the given prompt and negative prompt.\n\n    Args:\n        text: The prompt to compute the CLIP text embedding of.\n        negative_text: The negative prompt to compute the CLIP text embedding of.\n            If not provided, the negative prompt is assumed to be empty (i.e., `\"\"`).\n    \"\"\"\n    text = [text] if isinstance(text, str) else text\n\n    if not self.classifier_free_guidance:\n        return self.clip_text_encoder(text)\n\n    negative_text = [negative_text] if isinstance(negative_text, str) else negative_text\n    assert len(text) == len(negative_text), \"The length of the text list and negative_text should be the same\"\n\n    conditional_embedding = self.clip_text_encoder(text)\n    negative_embedding = self.clip_text_encoder(negative_text)\n\n    return torch.cat((negative_embedding, conditional_embedding))\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"
compute_self_attention_guidance(\n    x: Tensor,\n    noise: Tensor,\n    step: int,\n    *,\n    clip_text_embedding: Tensor,\n    **kwargs: Tensor\n) -> Tensor\n

Compute the self-attention guidance.

Parameters:

x (Tensor, required): The input tensor.

noise (Tensor, required): The noise tensor.

step (int, required): The step to compute the self-attention guidance at.

clip_text_embedding (Tensor, required): The CLIP text embedding to compute the self-attention guidance with.

Returns:

Tensor: The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(\n    self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor\n) -> Tensor:\n    \"\"\"Compute the self-attention guidance.\n\n    Args:\n        x: The input tensor.\n        noise: The noise tensor.\n        step: The step to compute the self-attention guidance at.\n        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n\n    Returns:\n        The computed self-attention guidance.\n    \"\"\"\n    sag = self._find_sag_adapter()\n    assert sag is not None\n\n    degraded_latents = sag.compute_degraded_latents(\n        solver=self.solver,\n        latents=x,\n        noise=noise,\n        step=step,\n        classifier_free_guidance=True,\n    )\n\n    timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n    negative_embedding, _ = clip_text_embedding.chunk(2)\n    self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)\n    if \"ip_adapter\" in self.unet.provider.contexts:\n        # this implementation is a bit hacky, it should be refactored in the future\n        ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n        image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n        ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n        degraded_noise = self.unet(degraded_latents)\n        ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n    else:\n        degraded_noise = self.unet(degraded_latents)\n\n    return sag.scale * (noise - degraded_noise)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.has_self_attention_guidance","title":"has_self_attention_guidance","text":"
has_self_attention_guidance() -> bool\n

Whether the model has self-attention guidance or not.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def has_self_attention_guidance(self) -> bool:\n    \"\"\"Whether the model has self-attention guidance or not.\"\"\"\n    return self._find_sag_adapter() is not None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.set_self_attention_guidance","title":"set_self_attention_guidance","text":"
set_self_attention_guidance(\n    enable: bool, scale: float = 1.0\n) -> None\n

Set whether to enable self-attention guidance.

See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.

Parameters:

enable (bool, required): Whether to enable self-attention guidance.

scale (float, default: 1.0): The scale to use.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:\n    \"\"\"Set whether to enable self-attention guidance.\n\n    See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)\n    for more details.\n\n    Args:\n        enable: Whether to enable self-attention guidance.\n        scale: The scale to use.\n    \"\"\"\n    if enable:\n        if sag := self._find_sag_adapter():\n            sag.scale = scale\n        else:\n            SD1SAGAdapter(target=self.unet, scale=scale).inject()\n    else:\n        if sag := self._find_sag_adapter():\n            sag.eject()\n
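The enable/disable logic above has three cases: update the scale of an adapter that is already injected, inject a new one, or eject an existing one. A minimal sketch of that control flow follows, assuming a hypothetical `ToyModel` in which a plain dictionary stands in for the SAG adapter; it does not use the Refiners adapter API.

```python
class ToyModel:
    # Hypothetical stand-in: 'adapter' plays the role of the injected SAG adapter.
    def __init__(self):
        self.adapter = None

    def has_self_attention_guidance(self):
        return self.adapter is not None

    def set_self_attention_guidance(self, enable, scale=1.0):
        if enable:
            if self.adapter is not None:
                # Already injected: only update the scale.
                self.adapter['scale'] = scale
            else:
                # Not injected yet: create the adapter with the requested scale.
                self.adapter = {'scale': scale}
        else:
            # Disabling removes the adapter if there is one; otherwise it is a no-op.
            self.adapter = None


model = ToyModel()
model.set_self_attention_guidance(enable=True, scale=0.75)
model.set_self_attention_guidance(enable=True, scale=0.5)
scale_after_update = model.adapter['scale']
model.set_self_attention_guidance(enable=False)
```

Calling the method twice with `enable=True` therefore changes the scale rather than stacking two adapters.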
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.set_unet_context","title":"set_unet_context","text":"
set_unet_context(\n    *,\n    timestep: Tensor,\n    clip_text_embedding: Tensor,\n    **_: Tensor\n) -> None\n

Set the various context parameters required by the U-Net model.

Parameters:

timestep (Tensor, required): The timestep tensor to use.

clip_text_embedding (Tensor, required): The CLIP text embedding tensor to use.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:\n    \"\"\"Set the various context parameters required by the U-Net model.\n\n    Args:\n        timestep: The timestep tensor to use.\n        clip_text_embedding: The CLIP text embedding tensor to use.\n    \"\"\"\n    self.unet.set_timestep(timestep=timestep)\n    self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting","title":"StableDiffusion_1_Inpainting","text":"
StableDiffusion_1_Inpainting(\n    unet: SD1UNet | None = None,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: StableDiffusion_1

Stable Diffusion 1.5 inpainting model.

Attributes:

unet: The U-Net model.

clip_text_encoder: The text encoder.

lda: The image autoencoder.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(\n    self,\n    unet: SD1UNet | None = None,\n    lda: SD1Autoencoder | None = None,\n    clip_text_encoder: CLIPTextEncoderL | None = None,\n    solver: Solver | None = None,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    self.mask_latents: Tensor | None = None\n    self.target_image_latents: Tensor | None = None\n    unet = unet or SD1UNet(in_channels=9)\n    super().__init__(\n        unet=unet,\n        lda=lda,\n        clip_text_encoder=clip_text_encoder,\n        solver=solver,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"
compute_self_attention_guidance(\n    x: Tensor,\n    noise: Tensor,\n    step: int,\n    *,\n    clip_text_embedding: Tensor,\n    **kwargs: Tensor\n) -> Tensor\n

Compute the self-attention guidance.

Parameters:

x (Tensor, required): The input tensor.

noise (Tensor, required): The noise tensor.

step (int, required): The step to compute the self-attention guidance at.

clip_text_embedding (Tensor, required): The CLIP text embedding to compute the self-attention guidance with.

Returns:

Tensor: The computed self-attention guidance.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(\n    self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor\n) -> Tensor:\n    \"\"\"Compute the self-attention guidance.\n\n    Args:\n        x: The input tensor.\n        noise: The noise tensor.\n        step: The step to compute the self-attention guidance at.\n        clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n\n    Returns:\n        The computed self-attention guidance.\n    \"\"\"\n    sag = self._find_sag_adapter()\n    assert sag is not None\n    assert self.mask_latents is not None\n    assert self.target_image_latents is not None\n\n    degraded_latents = sag.compute_degraded_latents(\n        solver=self.solver,\n        latents=x,\n        noise=noise,\n        step=step,\n        classifier_free_guidance=True,\n    )\n    x = torch.cat(\n        tensors=(degraded_latents, self.mask_latents, self.target_image_latents),\n        dim=1,\n    )\n\n    timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n    negative_embedding, _ = clip_text_embedding.chunk(2)\n    self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)\n\n    if \"ip_adapter\" in self.unet.provider.contexts:\n        # this implementation is a bit hacky, it should be refactored in the future\n        ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n        image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n        ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n        degraded_noise = self.unet(x)\n        ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n    else:\n        degraded_noise = self.unet(x)\n\n    return sag.scale * (noise - degraded_noise)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting.set_inpainting_conditions","title":"set_inpainting_conditions","text":"
set_inpainting_conditions(\n    target_image: Image,\n    mask: Image,\n    latents_size: tuple[int, int] = (64, 64),\n) -> tuple[Tensor, Tensor]\n

Set the inpainting conditions.

Parameters:

target_image (Image, required): The target image to inpaint.

mask (Image, required): The mask to use for inpainting.

latents_size (tuple[int, int], default: (64, 64)): The size of the latents to use.

Returns:

tuple[Tensor, Tensor]: The mask latents and the target image latents.

Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_inpainting_conditions(\n    self,\n    target_image: Image.Image,\n    mask: Image.Image,\n    latents_size: tuple[int, int] = (64, 64),\n) -> tuple[Tensor, Tensor]:\n    \"\"\"Set the inpainting conditions.\n\n    Args:\n        target_image: The target image to inpaint.\n        mask: The mask to use for inpainting.\n        latents_size: The size of the latents to use.\n\n    Returns:\n        The mask latents and the target image latents.\n    \"\"\"\n    target_image = target_image.convert(mode=\"RGB\")\n    mask = mask.convert(mode=\"L\")\n\n    mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)\n    mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)\n    self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))\n\n    init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1\n    masked_init_image = init_image_tensor * (1 - mask_tensor)\n    self.target_image_latents = self.lda.encode(x=masked_init_image)\n\n    return self.mask_latents, self.target_image_latents\n
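The pixel-level arithmetic in the method above (binarise the mask at 0.5, rescale the image from [0, 1] to [-1, 1], then zero out the masked region) can be sketched on nested lists. This is an illustration only: `prepare_inpainting_inputs` is a hypothetical helper, and the sketch leaves out the interpolation of the mask to `latents_size` and the autoencoder encoding step.

```python
def prepare_inpainting_inputs(image, mask):
    # image: rows of pixel values in [0, 1]; mask: rows of grey levels in [0, 255].
    binary_mask = [[1.0 if (m / 255.0) > 0.5 else 0.0 for m in row] for row in mask]
    # Rescale the image from [0, 1] to [-1, 1].
    scaled = [[p * 2 - 1 for p in row] for row in image]
    # Zero out the region to inpaint (mask == 1) and keep the rest.
    masked = [
        [p * (1 - m) for p, m in zip(image_row, mask_row)]
        for image_row, mask_row in zip(scaled, binary_mask)
    ]
    return binary_mask, masked


mask_out, masked_image = prepare_inpainting_inputs(
    image=[[0.0, 1.0], [0.5, 0.75]],
    mask=[[0, 255], [200, 100]],
)
```

Pixels whose mask value is above the threshold become 0 in the masked image, which in the [-1, 1] range corresponds to mid-grey rather than black.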
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DDIM","title":"DDIM","text":"
DDIM(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Solver

Denoising Diffusion Implicit Model (DDIM) solver.

See [arXiv:2010.02502] Denoising Diffusion Implicit Models for more details.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default: 0): The first inference step to perform.

params (BaseSolverParams | None, default: None): The common parameters for solvers.

device (device | str, default: 'cpu'): The PyTorch device to use.

dtype (dtype, default: float32): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/solvers/ddim.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: Device | str = \"cpu\",\n    dtype: Dtype = torch.float32,\n) -> None:\n    \"\"\"Initializes a new DDIM solver.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n        raise NotImplementedError\n    if params and params.sde_variance != 0.0:\n        raise NotImplementedError(\"DDIM does not support sde_variance != 0.0 yet\")\n\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        params=params,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DDPM","title":"DDPM","text":"
DDPM(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: device | str = \"cpu\",\n)\n

Bases: Solver

Denoising Diffusion Probabilistic Model (DDPM) solver.

Warning

Only used for training Latent Diffusion models. Cannot be called.

See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models for more details.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default: 0): The first inference step to perform.

params (BaseSolverParams | None, default: None): The common parameters for solvers.

device (device | str, default: 'cpu'): The PyTorch device to use.

Source code in src/refiners/foundationals/latent_diffusion/solvers/ddpm.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: Device | str = \"cpu\",\n) -> None:\n    \"\"\"Initializes a new DDPM solver.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        device: The PyTorch device to use.\n    \"\"\"\n\n    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n        raise NotImplementedError\n\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        params=params,\n        device=device,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver","title":"DPMSolver","text":"
DPMSolver(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    last_step_first_order: bool = False,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Solver

Diffusion probabilistic models (DPMs) solver.

See [arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models for more details.

Note

Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts when used with SDXL and few steps. This parameter is a way to mitigate that effect by using a first-order (Euler) update instead of a second-order update for the last step of the diffusion.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default: 0): The first inference step to perform.

params (BaseSolverParams | None, default: None): The common parameters for solvers.

last_step_first_order (bool, default: False): Use a first-order update for the last step.

device (device | str, default: 'cpu'): The PyTorch device to use.

dtype (dtype, default: float32): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    last_step_first_order: bool = False,\n    device: torch.device | str = \"cpu\",\n    dtype: torch.dtype = torch.float32,\n) -> None:\n    \"\"\"Initializes a new DPM solver.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        last_step_first_order: Use a first-order update for the last step.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n        raise NotImplementedError\n    if params and params.sde_variance not in (0.0, 1.0):\n        raise NotImplementedError(\"DPMSolver only supports sde_variance=0.0 or 1.0\")\n\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        params=params,\n        device=device,\n        dtype=dtype,\n    )\n    self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)\n    self.last_step_first_order = last_step_first_order\n    sigmas = self.noise_std / self.cumulative_scale_factors\n    self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)\n    sigma_min = sigmas[0:1]  # corresponds to `final_sigmas_type=\"sigma_min\"` in diffusers\n    self.sigmas = torch.cat([self.sigmas, sigma_min])\n    self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(\n        self.sigmas\n    )\n    self.timesteps = self._timesteps_from_sigmas(sigmas)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.dpm_solver_first_order_update","title":"dpm_solver_first_order_update","text":"
dpm_solver_first_order_update(\n    x: Tensor,\n    noise: Tensor,\n    step: int,\n    sde_noise: Tensor | None = None,\n) -> Tensor\n

Applies a first-order backward Euler update to the input data x.

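Parameters:

x (Tensor, required): The input data.

noise (Tensor, required): The predicted noise.

step (int, required): The current step.

sde_noise (Tensor | None, default: None): Optional noise tensor for the stochastic (SDE) form of the update. When None, the deterministic update is applied.

Returns:

Tensor: The denoised version of the input data x.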

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def dpm_solver_first_order_update(\n    self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None\n) -> torch.Tensor:\n    \"\"\"Applies a first-order backward Euler update to the input data `x`.\n\n    Args:\n        x: The input data.\n        noise: The predicted noise.\n        step: The current step.\n\n    Returns:\n        The denoised version of the input data `x`.\n    \"\"\"\n    current_ratio = self.signal_to_noise_ratios[step]\n    next_ratio = self.signal_to_noise_ratios[step + 1]\n\n    next_scale_factor = self.cumulative_scale_factors[step + 1]\n\n    next_noise_std = self.noise_std[step + 1]\n    current_noise_std = self.noise_std[step]\n\n    ratio_delta = current_ratio - next_ratio\n\n    if sde_noise is None:\n        return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise\n\n    factor = 1.0 - torch.exp(2.0 * ratio_delta)\n    return (\n        (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x\n        + next_scale_factor * factor * noise\n        + next_noise_std * safe_sqrt(factor) * sde_noise\n    )\n
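The deterministic branch of the update above (the one taken when `sde_noise` is None) can be checked on scalars. The sketch below rests on two assumptions that this section does not state: that `signal_to_noise_ratios` holds log(alpha / sigma), following the lambda_t convention of the DPM-Solver papers, and that the argument named `noise` in the source carries an estimate of the clean data. The names `first_order_update`, `x0` and `eps` are hypothetical.

```python
import math


def first_order_update(x, data_estimate, alpha_t, sigma_t, alpha_next, sigma_next):
    # lambda = log(alpha / sigma), assumed to match signal_to_noise_ratios.
    current_ratio = math.log(alpha_t / sigma_t)
    next_ratio = math.log(alpha_next / sigma_next)
    ratio_delta = current_ratio - next_ratio
    # Deterministic branch of the documented update.
    return (sigma_next / sigma_t) * x + (1.0 - math.exp(ratio_delta)) * alpha_next * data_estimate


# Hypothetical numbers: a clean value x0 and a noise sample eps.
x0, eps = 2.0, 0.5
alpha_t, sigma_t = 0.6, 0.8
alpha_next, sigma_next = 0.8, 0.6

# Noisy sample at the current level, then one update towards the next level.
x_t = alpha_t * x0 + sigma_t * eps
x_next = first_order_update(x_t, x0, alpha_t, sigma_t, alpha_next, sigma_next)
```

Under these assumptions, an exact data estimate moves the sample exactly onto the next noise level, that is `alpha_next * x0 + sigma_next * eps`.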
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.multistep_dpm_solver_second_order_update","title":"multistep_dpm_solver_second_order_update","text":"
multistep_dpm_solver_second_order_update(\n    x: Tensor, step: int, sde_noise: Tensor | None = None\n) -> Tensor\n

Applies a second-order backward Euler update to the input data x.

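Parameters:

x (Tensor, required): The input data.

step (int, required): The current step.

sde_noise (Tensor | None, default: None): Optional noise tensor for the stochastic (SDE) form of the update. When None, the deterministic update is applied.

Returns:

Tensor: The denoised version of the input data x.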

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def multistep_dpm_solver_second_order_update(\n    self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None\n) -> torch.Tensor:\n    \"\"\"Applies a second-order backward Euler update to the input data `x`.\n\n    Args:\n        x: The input data.\n        step: The current step.\n\n    Returns:\n        The denoised version of the input data `x`.\n    \"\"\"\n    current_data_estimation = self.estimated_data[-1]\n    previous_data_estimation = self.estimated_data[-2]\n\n    next_ratio = self.signal_to_noise_ratios[step + 1]\n    current_ratio = self.signal_to_noise_ratios[step]\n    previous_ratio = self.signal_to_noise_ratios[step - 1]\n\n    next_scale_factor = self.cumulative_scale_factors[step + 1]\n    next_noise_std = self.noise_std[step + 1]\n    current_noise_std = self.noise_std[step]\n\n    estimation_delta = (current_data_estimation - previous_data_estimation) / (\n        (current_ratio - previous_ratio) / (next_ratio - current_ratio)\n    )\n    ratio_delta = current_ratio - next_ratio\n\n    if sde_noise is None:\n        factor = 1.0 - torch.exp(ratio_delta)\n        return (\n            (next_noise_std / current_noise_std) * x\n            + next_scale_factor * factor * current_data_estimation\n            + 0.5 * next_scale_factor * factor * estimation_delta\n        )\n\n    factor = 1.0 - torch.exp(2.0 * ratio_delta)\n    return (\n        (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x\n        + next_scale_factor * factor * current_data_estimation\n        + 0.5 * next_scale_factor * factor * estimation_delta\n        + next_noise_std * safe_sqrt(factor) * sde_noise\n    )\n
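The second-order update above adds a finite-difference correction, built from the last two data estimates, to the first-order formula. A scalar sketch of the deterministic branch follows; the function names and numeric values are hypothetical, and the lambda arguments stand for entries of `signal_to_noise_ratios`.

```python
import math


def second_order_update(x, d_cur, d_prev, lam_prev, lam_cur, lam_next, alpha_next, sigma_cur, sigma_next):
    # Finite-difference correction built from the last two data estimates.
    estimation_delta = (d_cur - d_prev) / ((lam_cur - lam_prev) / (lam_next - lam_cur))
    factor = 1.0 - math.exp(lam_cur - lam_next)
    return (
        (sigma_next / sigma_cur) * x
        + alpha_next * factor * d_cur
        + 0.5 * alpha_next * factor * estimation_delta
    )


def first_order_update(x, d_cur, lam_cur, lam_next, alpha_next, sigma_cur, sigma_next):
    # Same formula without the correction term.
    factor = 1.0 - math.exp(lam_cur - lam_next)
    return (sigma_next / sigma_cur) * x + alpha_next * factor * d_cur


# Hypothetical values; the two data estimates are deliberately identical.
args = dict(lam_cur=-0.3, lam_next=0.3, alpha_next=0.8, sigma_cur=0.8, sigma_next=0.6)
second = second_order_update(1.5, 2.0, 2.0, lam_prev=-0.9, **args)
first = first_order_update(1.5, 2.0, **args)
```

When the two most recent data estimates agree, the correction term vanishes and the second-order result coincides with the first-order one.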
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.rebuild","title":"rebuild","text":"
rebuild(\n    num_inference_steps: int | None,\n    first_inference_step: int | None = None,\n) -> DPMSolver\n

Rebuilds the solver with new parameters.

Parameters:

num_inference_steps (int | None, required): The number of inference steps.

first_inference_step (int | None, default: None): The first inference step.

Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def rebuild(\n    self: \"DPMSolver\",\n    num_inference_steps: int | None,\n    first_inference_step: int | None = None,\n) -> \"DPMSolver\":\n    \"\"\"Rebuilds the solver with new parameters.\n\n    Args:\n        num_inference_steps: The number of inference steps.\n        first_inference_step: The first inference step.\n    \"\"\"\n    r = super().rebuild(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n    )\n    r.last_step_first_order = self.last_step_first_order\n    return r\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler","title":"Euler","text":"
Euler(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Solver

Euler solver.

See [arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models for more details.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default: 0): The first inference step to perform.

params (BaseSolverParams | None, default: None): The common parameters for solvers.

device (device | str, default: 'cpu'): The PyTorch device to use.

dtype (dtype, default: float32): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: Device | str = \"cpu\",\n    dtype: Dtype = torch.float32,\n):\n    \"\"\"Initializes a new Euler solver.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n    if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):\n        raise NotImplementedError\n    if params and params.sde_variance != 0.0:\n        raise NotImplementedError(\"Euler does not support sde_variance != 0.0 yet\")\n\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        params=params,\n        device=device,\n        dtype=dtype,\n    )\n    self.sigmas = self._generate_sigmas()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler.init_noise_sigma","title":"init_noise_sigma property","text":"
init_noise_sigma: Tensor\n

The initial noise sigma.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler.scale_model_input","title":"scale_model_input","text":"
scale_model_input(x: Tensor, step: int) -> Tensor\n

Scales the model input according to the current step.

Parameters:

x (Tensor, required): The model input.

step (int, required): The current step. This method is called with step=-1 in init_latents.

Returns:

Tensor: The scaled model input.

Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:\n    \"\"\"Scales the model input according to the current step.\n\n    Args:\n        x: The model input.\n        step: The current step. This method is called with `step=-1` in `init_latents`.\n\n    Returns:\n        The scaled model input.\n    \"\"\"\n\n    if step == -1:\n        return x * self.init_noise_sigma\n\n    sigma = self.sigmas[step]\n    return x / ((sigma**2 + 1) ** 0.5)\n
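The scaling rule above is easy to check on plain floats. The following is a minimal sketch, assuming a hypothetical two-entry sigma schedule and a hypothetical `init_noise_sigma` value; in the solver these come from its own attributes rather than from function arguments.

```python
def scale_model_input(x, step, sigmas, init_noise_sigma):
    # step == -1 is the special case used when initialising latents.
    if step == -1:
        return x * init_noise_sigma
    sigma = sigmas[step]
    return x / ((sigma**2 + 1) ** 0.5)


# Hypothetical sigma schedule, chosen so that the results are easy to check by hand.
sigmas = [3**0.5, 0.0]
scaled_first = scale_model_input(4.0, 0, sigmas, init_noise_sigma=2.0)
scaled_last = scale_model_input(4.0, 1, sigmas, init_noise_sigma=2.0)
scaled_init = scale_model_input(4.0, -1, sigmas, init_noise_sigma=2.0)
```

With sigma equal to the square root of 3 the input is halved, and with sigma equal to 0 it is left unchanged, so the scaling shrinks towards the identity as the noise level drops.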
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.FrankenSolver","title":"FrankenSolver","text":"
FrankenSolver(\n    get_diffusers_scheduler: Callable[[], SchedulerLike],\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n    **kwargs: Any\n)\n

Bases: Solver

Lets you use Diffusers Schedulers as Refiners Solvers.

For instance:
from diffusers import EulerDiscreteScheduler\nfrom refiners.foundationals.latent_diffusion.solvers import FrankenSolver\n\nscheduler = EulerDiscreteScheduler(...)\nsolver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)\n
Source code in src/refiners/foundationals/latent_diffusion/solvers/franken.py
def __init__(\n    self,\n    get_diffusers_scheduler: Callable[[], SchedulerLike],\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n    **kwargs: Any,  # for typing, ignored\n) -> None:\n    self.get_diffusers_scheduler = get_diffusers_scheduler\n    self.diffusers_scheduler = self.get_diffusers_scheduler()\n    self.diffusers_scheduler.set_timesteps(num_inference_steps)\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.LCMSolver","title":"LCMSolver","text":"
LCMSolver(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    num_orig_steps: int = 50,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Solver

Latent Consistency Model solver.

This solver is designed for use either with a specific base model or a specific LoRA.

See [arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference for details.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default: 0): The first inference step to perform.

params (BaseSolverParams | None, default: None): The common parameters for solvers.

num_orig_steps (int, default: 50): The number of inference steps of the emulated DPM solver.

device (device | str, default: 'cpu'): The PyTorch device to use.

dtype (dtype, default: float32): The PyTorch data type to use.

Source code in src/refiners/foundationals/latent_diffusion/solvers/lcm.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    num_orig_steps: int = 50,\n    device: torch.device | str = \"cpu\",\n    dtype: torch.dtype = torch.float32,\n):\n    \"\"\"Initializes a new LCM solver.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        num_orig_steps: The number of inference steps of the emulated DPM solver.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n    \"\"\"\n\n    assert (\n        num_orig_steps >= num_inference_steps\n    ), f\"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})\"\n\n    params = self.resolve_params(params)\n    if params.model_prediction_type != ModelPredictionType.NOISE:\n        raise NotImplementedError\n\n    self._dpm = [\n        DPMSolver(\n            num_inference_steps=num_orig_steps,\n            params=SolverParams(\n                num_train_timesteps=params.num_train_timesteps,\n                timesteps_spacing=params.timesteps_spacing,\n            ),\n            device=device,\n            dtype=dtype,\n        )\n    ]\n    super().__init__(\n        num_inference_steps=num_inference_steps,\n        first_inference_step=first_inference_step,\n        params=params,\n        device=device,\n        dtype=dtype,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.ModelPredictionType","title":"ModelPredictionType","text":"

Bases: str, Enum

An enumeration of possible outputs of the model.

Attributes:

NOISE: The model predicts the noise (epsilon).

SAMPLE: The model predicts the denoised sample (x0).

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.NoiseSchedule","title":"NoiseSchedule","text":"

Bases: str, Enum

An enumeration of schedules used to sample the noise.

Attributes:

UNIFORM: A uniform noise schedule.

QUADRATIC: A quadratic noise schedule. Corresponds to \"Stable Diffusion\" in [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed, table 1.

KARRAS: See [arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver","title":"Solver","text":"
Solver(\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Module, ABC

The base class for creating a diffusion model solver.

Solvers create a sequence of noise and scaling factors used in the diffusion process, which gradually transforms the original data distribution into a Gaussian one.

This process is described using several parameters such as initial and final diffusion rates, and is encapsulated into a __call__ method that applies a step of the diffusion process.

Attributes:

params (ResolvedSolverParams): The common parameters for solvers. See SolverParams.

num_inference_steps: The number of inference steps to perform.

first_inference_step: The step to start the inference process from.

scale_factors: The scale factors used to denoise the input. Because they are computed as 1 minus the sampled diffusion rates, they correspond to \"alphas\" in other implementations, and 1 - scale_factors corresponds to \"betas\".

cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called \"alpha_t\" in other implementations.

noise_std: The standard deviation of the noise used to denoise the input. This is called \"sigma_t\" in other implementations.

signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. These are called \"lambda_t\" in other implementations.

Parameters:

num_inference_steps (int, required): The number of inference steps to perform.

first_inference_step (int, default 0): The first inference step to perform.

params (BaseSolverParams | None, default None): The common parameters for solvers.

device (device | str, default 'cpu'): The PyTorch device to use for the solver's tensors.

dtype (dtype, default float32): The PyTorch data type to use for the solver's tensors.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def __init__(\n    self,\n    num_inference_steps: int,\n    first_inference_step: int = 0,\n    params: BaseSolverParams | None = None,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    \"\"\"Initializes a new `Solver` instance.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n        params: The common parameters for solvers.\n        device: The PyTorch device to use for the solver's tensors.\n        dtype: The PyTorch data type to use for the solver's tensors.\n    \"\"\"\n    super().__init__()\n\n    self.num_inference_steps = num_inference_steps\n    self.first_inference_step = first_inference_step\n    self.params = self.resolve_params(params)\n\n    self.scale_factors = self.sample_noise_schedule()\n    self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))\n    self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))\n    self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)\n    self.timesteps = self._generate_timesteps()\n\n    self.to(device=device, dtype=dtype)\n
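The derived tensors in the constructor above can be illustrated with a minimal pure-Python sketch. It mirrors the arithmetic only (running product, square roots, log difference); the function name and the three-step schedule are hypothetical and are not part of Refiners.

```python
import math


def derive_solver_tensors(scale_factors):
    # Running product of the per-step scale factors (cumprod in the source).
    running = []
    product = 1.0
    for factor in scale_factors:
        product *= factor
        running.append(product)

    # Same formulas as in the constructor, applied element-wise.
    cumulative_scale_factors = [math.sqrt(p) for p in running]
    noise_std = [math.sqrt(1.0 - p) for p in running]
    signal_to_noise_ratios = [
        math.log(c) - math.log(s) for c, s in zip(cumulative_scale_factors, noise_std)
    ]
    return cumulative_scale_factors, noise_std, signal_to_noise_ratios


# Hypothetical three-step schedule, for illustration only.
csf, std, snr = derive_solver_tensors([0.99, 0.98, 0.97])
```

In this sketch the squared cumulative scale factor and the squared noise standard deviation sum to 1 at every step, which is why the signal weight shrinks as the noise weight grows.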
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.all_steps","title":"all_steps property","text":"
all_steps: list[int]\n

Return a list of all inference steps.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.device","title":"device property writable","text":"
device: device\n

The PyTorch device used for the solver's tensors.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.dtype","title":"dtype property writable","text":"
dtype: dtype\n

The PyTorch data type used for the solver's tensors.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.inference_steps","title":"inference_steps property","text":"
inference_steps: list[int]\n

Return a list of inference steps to perform.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.add_noise","title":"add_noise","text":"
add_noise(\n    x: Tensor, noise: Tensor, step: int | list[int]\n) -> Tensor\n

Add noise to the input tensor using the solver's parameters.

Parameters:

x (Tensor, required): The input tensor to add noise to.

noise (Tensor, required): The noise tensor to add to the input tensor.

step (int | list[int], required): The current step(s) of the diffusion process.

Returns:

Tensor: The input tensor with added noise.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def add_noise(\n    self,\n    x: Tensor,\n    noise: Tensor,\n    step: int | list[int],\n) -> Tensor:\n    \"\"\"Add noise to the input tensor using the solver's parameters.\n\n    Args:\n        x: The input tensor to add noise to.\n        noise: The noise tensor to add to the input tensor.\n        step: The current step(s) of the diffusion process.\n\n    Returns:\n        The input tensor with added noise.\n    \"\"\"\n    if isinstance(step, list):\n        assert len(x) == len(noise) == len(step), \"x, noise, and step must have the same length\"\n        return torch.stack(\n            tensors=[\n                self._add_noise(\n                    x=x[i],\n                    noise=noise[i],\n                    step=step[i],\n                )\n                for i in range(x.shape[0])\n            ],\n            dim=0,\n        )\n\n    return self._add_noise(x=x, noise=noise, step=step)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.generate_timesteps","title":"generate_timesteps staticmethod","text":"
generate_timesteps(\n    spacing: TimestepSpacing,\n    num_inference_steps: int,\n    num_train_timesteps: int = 1000,\n    offset: int = 0,\n) -> Tensor\n

Generate a tensor of timesteps according to a given spacing.

Parameters:

spacing (TimestepSpacing, required): The spacing to use for the timesteps.

num_inference_steps (int, required): The number of inference steps to perform.

num_train_timesteps (int, default 1000): The number of timesteps used to train the diffusion process.

offset (int, default 0): The offset to use for the timesteps.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
@staticmethod\ndef generate_timesteps(\n    spacing: TimestepSpacing,\n    num_inference_steps: int,\n    num_train_timesteps: int = 1000,\n    offset: int = 0,\n) -> Tensor:\n    \"\"\"Generate a tensor of timesteps according to a given spacing.\n\n    Args:\n        spacing: The spacing to use for the timesteps.\n        num_inference_steps: The number of inference steps to perform.\n        num_train_timesteps: The number of timesteps used to train the diffusion process.\n        offset: The offset to use for the timesteps.\n    \"\"\"\n    max_timestep = num_train_timesteps - 1 + offset\n    match spacing:\n        case TimestepSpacing.LINSPACE:\n            return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)\n        case TimestepSpacing.LINSPACE_ROUNDED:\n            return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)\n        case TimestepSpacing.LEADING:\n            step_ratio = num_train_timesteps // num_inference_steps\n            return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)\n        case TimestepSpacing.TRAILING:\n            step_ratio = num_train_timesteps // num_inference_steps\n            max_timestep = num_train_timesteps - 1 + offset\n            return torch.arange(max_timestep, offset, -step_ratio)\n        case TimestepSpacing.CUSTOM:\n            raise RuntimeError(\"generate_timesteps called with custom spacing\")\n
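To make the spacings concrete, here is a minimal pure-Python sketch of three of the branches above (LEADING, TRAILING and LINSPACE_ROUNDED). It returns plain lists instead of tensors, the function names are hypothetical, and the rounded variant assumes at least two inference steps.

```python
def leading(num_inference_steps, num_train_timesteps=1000, offset=0):
    # Multiples of the step ratio, starting at the offset, highest first.
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + offset for i in range(num_inference_steps)][::-1]


def trailing(num_inference_steps, num_train_timesteps=1000, offset=0):
    # Count down from the last training timestep by the step ratio.
    step_ratio = num_train_timesteps // num_inference_steps
    max_timestep = num_train_timesteps - 1 + offset
    return list(range(max_timestep, offset, -step_ratio))


def linspace_rounded(num_inference_steps, num_train_timesteps=1000, offset=0):
    # Evenly spaced values between offset and the last timestep, rounded.
    max_timestep = num_train_timesteps - 1 + offset
    span = max_timestep - offset
    steps = [
        round(offset + i * span / (num_inference_steps - 1))
        for i in range(num_inference_steps)
    ]
    return steps[::-1]
```

With 10 inference steps and the defaults, `leading` ends at timestep 0 and never reaches 999, while `trailing` starts at 999 and stops at 99.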
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.rebuild","title":"rebuild","text":"
rebuild(\n    num_inference_steps: int | None,\n    first_inference_step: int | None = None,\n) -> T\n

Rebuild the solver with new parameters.

Parameters:

num_inference_steps (int | None, required): The number of inference steps to perform.

first_inference_step (int | None, default None): The first inference step to perform.

Returns:

T: A new solver instance with the specified parameters.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:\n    \"\"\"Rebuild the solver with new parameters.\n\n    Args:\n        num_inference_steps: The number of inference steps to perform.\n        first_inference_step: The first inference step to perform.\n\n    Returns:\n        A new solver instance with the specified parameters.\n    \"\"\"\n    return self.__class__(\n        num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,\n        first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,\n        params=dataclasses.replace(self.params),\n        device=self.device,\n        dtype=self.dtype,\n    )\n
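The rebuild pattern above (keep the current value when an argument is None, and copy the parameters with dataclasses.replace) can be sketched with toy classes. ToyParams and ToySolver are hypothetical stand-ins, not Refiners classes.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyParams:
    # Hypothetical stand-in for the resolved solver parameters.
    num_train_timesteps: int = 1000


class ToySolver:
    def __init__(self, num_inference_steps, first_inference_step=0, params=None):
        self.num_inference_steps = num_inference_steps
        self.first_inference_step = first_inference_step
        self.params = params if params is not None else ToyParams()

    def rebuild(self, num_inference_steps=None, first_inference_step=None):
        # None means: reuse the value of the current instance.
        return self.__class__(
            num_inference_steps=self.num_inference_steps
            if num_inference_steps is None
            else num_inference_steps,
            first_inference_step=self.first_inference_step
            if first_inference_step is None
            else first_inference_step,
            # dataclasses.replace with no changes returns a fresh copy.
            params=dataclasses.replace(self.params),
        )


solver = ToySolver(num_inference_steps=30, first_inference_step=5)
rebuilt = solver.rebuild(num_inference_steps=20)
```

Here `rebuilt` is a new object with 20 inference steps that still starts at step 5, and its parameters are an equal but distinct copy.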
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.remove_noise","title":"remove_noise","text":"
remove_noise(x: Tensor, noise: Tensor, step: int) -> Tensor\n

Remove noise from the input tensor using the current step of the diffusion process.

Note

See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15 and [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance.

Parameters:

x (Tensor, required): The input tensor to remove noise from.

noise (Tensor, required): The noise tensor to remove from the input tensor.

step (int, required): The current step of the diffusion process.

Returns:

Tensor: The denoised input tensor.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:\n    \"\"\"Remove noise from the input tensor using the current step of the diffusion process.\n\n    Note:\n        See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15](https://arxiv.org/abs/2006.11239)\n        and [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939).\n\n    Args:\n        x: The input tensor to remove noise from.\n        noise: The noise tensor to remove from the input tensor.\n        step: The current step of the diffusion process.\n\n    Returns:\n        The denoised input tensor.\n    \"\"\"\n    timestep = self.timesteps[step]\n    cumulative_scale_factors = self.cumulative_scale_factors[timestep]\n    noise_stds = self.noise_std[timestep]\n    denoised_x = (x - noise_stds * noise) / cumulative_scale_factors\n    return denoised_x\n
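As a rough illustration of the formula above, the sketch below applies it to plain lists and pairs it with a forward noising step. The forward step is an assumption: the private `_add_noise` method is not shown in this reference, so the sketch simply uses the algebraic inverse of `remove_noise`. Both function names and the sample values are hypothetical.

```python
def add_noise_sketch(x, noise, cumulative_scale_factor, noise_std):
    # Assumed forward step: scale the signal, then add scaled noise.
    return [cumulative_scale_factor * xi + noise_std * ni for xi, ni in zip(x, noise)]


def remove_noise_sketch(x, noise, cumulative_scale_factor, noise_std):
    # Same arithmetic as remove_noise above, element by element.
    return [(xi - noise_std * ni) / cumulative_scale_factor for xi, ni in zip(x, noise)]


# Hypothetical values for one timestep, chosen so that 0.8**2 + 0.6**2 == 1.
clean = [0.5, -1.0, 2.0]
noise = [0.1, 0.3, -0.2]
noisy = add_noise_sketch(clean, noise, 0.8, 0.6)
recovered = remove_noise_sketch(noisy, noise, 0.8, 0.6)
```

Under that assumption, removing the same noise that was added recovers the clean input up to floating-point error.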
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.sample_noise_schedule","title":"sample_noise_schedule","text":"
sample_noise_schedule() -> Tensor\n

Sample the noise schedule.

Returns:

Tensor: A tensor representing the noise schedule.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def sample_noise_schedule(self) -> Tensor:\n    \"\"\"Sample the noise schedule.\n\n    Returns:\n        A tensor representing the noise schedule.\n    \"\"\"\n    match self.params.noise_schedule:\n        case NoiseSchedule.UNIFORM:\n            return 1 - self.sample_power_distribution(1)\n        case NoiseSchedule.QUADRATIC:\n            return 1 - self.sample_power_distribution(2)\n        case NoiseSchedule.KARRAS:\n            return 1 - self.sample_power_distribution(7)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.sample_power_distribution","title":"sample_power_distribution","text":"
sample_power_distribution(power: float = 2) -> Tensor\n

Sample a power distribution.

Parameters:

power (float, default 2): The power to use for the distribution.

Returns:

Tensor: A tensor representing the power distribution between the initial and final diffusion rates of the solver.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def sample_power_distribution(self, power: float = 2, /) -> Tensor:\n    \"\"\"Sample a power distribution.\n\n    Args:\n        power: The power to use for the distribution.\n\n    Returns:\n        A tensor representing the power distribution between the initial and final diffusion rates of the solver.\n    \"\"\"\n    return (\n        torch.linspace(\n            start=self.params.initial_diffusion_rate ** (1 / power),\n            end=self.params.final_diffusion_rate ** (1 / power),\n            steps=self.params.num_train_timesteps,\n        )\n        ** power\n    )\n
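The power distribution above interpolates linearly between the power-th roots of the two diffusion rates and then raises the result back to that power. A minimal pure-Python sketch follows; the helper names and the example rates (0.00085 and 0.012) are assumptions chosen for illustration, not values taken from this reference.

```python
def linspace(start, end, steps):
    # Evenly spaced values from start to end, both included.
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def power_distribution(initial_rate, final_rate, num_steps, power=2.0):
    # Interpolate in root space, then raise back to the requested power.
    roots = linspace(initial_rate ** (1 / power), final_rate ** (1 / power), num_steps)
    return [value ** power for value in roots]


def noise_schedule(initial_rate, final_rate, num_steps, power):
    # Mirrors sample_noise_schedule: one minus the power distribution.
    return [1 - rate for rate in power_distribution(initial_rate, final_rate, num_steps, power)]


# Hypothetical diffusion rates, for illustration only.
rates = power_distribution(0.00085, 0.012, 1000, power=2.0)
factors = noise_schedule(0.00085, 0.012, 1000, power=2.0)
```

With power 1 the rates are evenly spaced; larger powers keep the rates low for longer before they rise towards the final rate.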
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.scale_model_input","title":"scale_model_input","text":"
scale_model_input(x: Tensor, step: int) -> Tensor\n

Scale the model's input according to the current timestep.

Note

This method should only be overridden by solvers that need to scale the input according to the current timestep.

By default, this method does not scale the input. (scale=1)

Parameters:

x (Tensor, required): The input tensor to scale.

step (int, required): The current step of the diffusion process.

Returns:

Tensor: The scaled input tensor.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:\n    \"\"\"Scale the model's input according to the current timestep.\n\n    Note:\n        This method should only be overridden by solvers that\n        need to scale the input according to the current timestep.\n\n        By default, this method does not scale the input.\n        (scale=1)\n\n    Args:\n        x: The input tensor to scale.\n        step: The current step of the diffusion process.\n\n    Returns:\n        The scaled input tensor.\n    \"\"\"\n    return x\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.to","title":"to","text":"
to(\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n) -> Solver\n

Move the solver to the specified device and data type.

Parameters:

device (device | str | None, default None): The PyTorch device to move the solver to.

dtype (dtype | None, default None): The PyTorch data type to move the solver to.

Returns:

Solver: The solver instance, moved to the specified device and data type.

Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> \"Solver\":\n    \"\"\"Move the solver to the specified device and data type.\n\n    Args:\n        device: The PyTorch device to move the solver to.\n        dtype: The PyTorch data type to move the solver to.\n\n    Returns:\n        The solver instance, moved to the specified device and data type.\n    \"\"\"\n    super().to(device=device, dtype=dtype)\n    for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:\n        match name:\n            case \"timesteps\":\n                setattr(self, name, attr.to(device=device))\n            case _:\n                setattr(self, name, attr.to(device=device, dtype=dtype))\n    return self\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.SolverParams","title":"SolverParams dataclass","text":"
SolverParams(\n    *,\n    num_train_timesteps: int | None = None,\n    timesteps_spacing: TimestepSpacing | None = None,\n    timesteps_offset: int | None = None,\n    initial_diffusion_rate: float | None = None,\n    final_diffusion_rate: float | None = None,\n    noise_schedule: NoiseSchedule | None = None,\n    sigma_schedule: NoiseSchedule | None = None,\n    model_prediction_type: (\n        ModelPredictionType | None\n    ) = None,\n    sde_variance: float = 0.0\n)\n

Bases: BaseSolverParams

Common parameters for solvers.

Parameters:

num_train_timesteps (int | None, default None): The number of timesteps used to train the diffusion process.

timesteps_spacing (TimestepSpacing | None, default None): The spacing to use for the timesteps.

timesteps_offset (int | None, default None): The offset to use for the timesteps.

initial_diffusion_rate (float | None, default None): The initial diffusion rate used to sample the noise schedule.

final_diffusion_rate (float | None, default None): The final diffusion rate used to sample the noise schedule.

noise_schedule (NoiseSchedule | None, default None): The schedule used to sample the noise. See NoiseSchedule.

model_prediction_type (ModelPredictionType | None, default None): Defines what the model predicts.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.TimestepSpacing","title":"TimestepSpacing","text":"

Bases: str, Enum

An enumeration of methods to space the timesteps.

See [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed table 2.

Attributes:

LINSPACE: Sample N steps with linear interpolation and return a floating-point tensor.

LINSPACE_ROUNDED: Same as LINSPACE, but return an integer tensor with rounded timesteps.

LEADING: Sample N+1 steps and do not include the last timestep. This is problematic because it leaves a non-zero SNR at the last timestep. Used in DDIM, with a mitigation for that issue.

TRAILING: Sample N+1 steps and do not include the first timestep.

CUSTOM: Use a custom timestep spacing defined by the solver (override _generate_timesteps; see DPM).

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager","title":"SDLoraManager","text":"
SDLoraManager(target: LatentDiffusionModel)\n

Manage LoRAs for a Stable Diffusion model.

Note

In the context of SDLoraManager, a \"LoRA\" is a set of \"LoRA layers\" that can be attached to a target model.

Parameters:

target (LatentDiffusionModel, required): The target model to manage the LoRAs for.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def __init__(\n    self,\n    target: LatentDiffusionModel,\n) -> None:\n    \"\"\"Initialize the LoRA manager.\n\n    Args:\n        target: The target model to manage the LoRAs for.\n    \"\"\"\n    self.target = target\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.clip_text_encoder","title":"clip_text_encoder property","text":"
clip_text_encoder: Chain\n

The Stable Diffusion text encoder.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.lora_adapters","title":"lora_adapters property","text":"
lora_adapters: list[LoraAdapter]\n

List of all the LoraAdapters managed by the SDLoraManager.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.loras","title":"loras property","text":"
loras: list[Lora[Any]]\n

List of all the LoRA layers managed by the SDLoraManager.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.names","title":"names property","text":"
names: list[str]\n

List of all the LoRA names managed by the SDLoraManager.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.scales","title":"scales property","text":"
scales: dict[str, float]\n

The scales of all the LoRAs managed by the SDLoraManager.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.unet","title":"unet property","text":"
unet: Chain\n

The Stable Diffusion U-Net model.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras","title":"add_loras","text":"
add_loras(\n    name: str,\n    /,\n    tensors: dict[str, Tensor],\n    scale: float = 1.0,\n    unet_inclusions: list[str] | None = None,\n    unet_exclusions: list[str] | None = None,\n    unet_preprocess: dict[str, str] | None = None,\n    text_encoder_inclusions: list[str] | None = None,\n    text_encoder_exclusions: list[str] | None = None,\n) -> None\n

Load a single LoRA from a state_dict.

Warning

This method expects the keys of the state_dict to be in one of the formats commonly found on CivitAI's hub.

Parameters:

name (str, required): The name of the LoRA.

tensors (dict[str, Tensor], required): The state_dict of the LoRA to load.

scale (float, default 1.0): The scale to use for the LoRA.

unet_inclusions (list[str] | None, default None): A list of layer names. Only layers with such a layer among their ancestors are considered when patching the UNet.

unet_exclusions (list[str] | None, default None): A list of layer names. Layers with such a layer among their ancestors are not considered when patching the UNet. If this is None, it defaults to [\"TimestepEncoder\"].

unet_preprocess (dict[str, str] | None, default None): A map between parts of state dict keys and layer names. This is used to attach some keys to specific parts of the UNet. You should leave it set to None (it has a default value); otherwise, read the source code to understand how it works.

text_encoder_inclusions (list[str] | None, default None): A list of layer names. Only layers with such a layer among their ancestors are considered when patching the text encoder.

text_encoder_exclusions (list[str] | None, default None): A list of layer names. Layers with such a layer among their ancestors are not considered when patching the text encoder.

Raises:

AssertionError: If the Manager already has a LoRA with the same name.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras(\n    self,\n    name: str,\n    /,\n    tensors: dict[str, Tensor],\n    scale: float = 1.0,\n    unet_inclusions: list[str] | None = None,\n    unet_exclusions: list[str] | None = None,\n    unet_preprocess: dict[str, str] | None = None,\n    text_encoder_inclusions: list[str] | None = None,\n    text_encoder_exclusions: list[str] | None = None,\n) -> None:\n    \"\"\"Load a single LoRA from a `state_dict`.\n\n    Warning:\n        This method expects the keys of the `state_dict` to be in the commonly found formats on CivitAI's hub.\n\n    Args:\n        name: The name of the LoRA.\n        tensors: The `state_dict` of the LoRA to load.\n        scale: The scale to use for the LoRA.\n        unet_inclusions: A list of layer names, only layers with such a layer\n            in their ancestors will be considered when patching the UNet.\n        unet_exclusions: A list of layer names, layers with such a layer in\n            their ancestors will not be considered when patching the UNet.\n            If this is `None` then it defaults to `[\"TimestepEncoder\"]`.\n        unet_preprocess: A map between parts of state dict keys and layer names.\n            This is used to attach some keys to specific parts of the UNet.\n            You should leave it set to `None` (it has a default value),\n            otherwise read the source code to understand how it works.\n        text_encoder_inclusions: A list of layer names, only layers with such a layer\n            in their ancestors will be considered when patching the text encoder.\n        text_encoder_exclusions: A list of layer names, layers with such a layer in\n            their ancestors will not be considered when patching the text encoder.\n\n    Raises:\n        AssertionError: If the Manager already has a LoRA with the same name.\n    \"\"\"\n    assert name not in self.names, f\"LoRA {name} already exists\"\n\n    # load the LoRA from the state_dict\n    loras = Lora.from_dict(\n        name,\n        
state_dict={\n            key: value.to(\n                device=self.target.device,\n                dtype=self.target.dtype,\n            )\n            for key, value in tensors.items()\n        },\n    )\n    # sort all the LoRA's keys using the `sort_keys` method\n    loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}\n\n    # if no key contains \"unet\" or \"text\", assume all keys are for the unet\n    if all(\"unet\" not in key and \"text\" not in key for key in loras.keys()):\n        loras = {f\"unet_{key}\": value for key, value in loras.items()}\n\n    # attach the LoRA to the target\n    self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)\n    self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)\n\n    # set the scale of the LoRA\n    self.set_scale(name, scale)\n
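The key routing in the source above can be sketched on plain strings: when no key mentions unet or text, every key is prefixed so that it is routed to the U-Net, and the two attach methods then filter on those substrings. The function name and the example keys are hypothetical.

```python
def route_lora_keys(keys):
    keys = list(keys)

    # If no key mentions the unet or the text encoder, assume all are for the unet.
    if all('unet' not in key and 'text' not in key for key in keys):
        keys = ['unet_' + key for key in keys]

    # Same substring filters as add_loras_to_unet and add_loras_to_text_encoder.
    unet_keys = [key for key in keys if 'unet' in key]
    text_keys = [key for key in keys if 'text' in key]
    return unet_keys, text_keys


# Hypothetical key names, for illustration only.
bare = route_lora_keys(['down_0_attn_q', 'down_0_attn_k'])
mixed = route_lora_keys(['lora_unet_attn_q', 'lora_text_encoder_attn_q'])
```

In this sketch, keys without either marker all end up on the U-Net side, while mixed state dicts are split between the two targets.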
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras_to_text_encoder","title":"add_loras_to_text_encoder","text":"
add_loras_to_text_encoder(\n    loras: dict[str, Lora[Any]],\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> None\n

Add multiple LoRAs to the text encoder. See add_loras for details about arguments.

Parameters:

loras (dict[str, Lora[Any]], required): The dictionary of LoRAs to add to the text encoder. (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_text_encoder(\n    self,\n    loras: dict[str, Lora[Any]],\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> None:\n    \"\"\"Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.\n\n    Args:\n        loras: The dictionary of LoRAs to add to the text encoder.\n            (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)\n    \"\"\"\n    text_encoder_loras = {key: loras[key] for key in loras.keys() if \"text\" in key}\n    auto_attach_loras(\n        text_encoder_loras,\n        self.clip_text_encoder,\n        exclude=exclude,\n        include=include,\n        debug_map=debug_map,\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras_to_unet","title":"add_loras_to_unet","text":"
add_loras_to_unet(\n    loras: dict[str, Lora[Any]],\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    preprocess: dict[str, str] | None = None,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> None\n

Add multiple LoRAs to the U-Net. See add_loras for details about arguments.

Parameters:

loras (dict[str, Lora[Any]], required): The dictionary of LoRAs to add to the U-Net. (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_unet(\n    self,\n    loras: dict[str, Lora[Any]],\n    /,\n    include: list[str] | None = None,\n    exclude: list[str] | None = None,\n    preprocess: dict[str, str] | None = None,\n    debug_map: list[tuple[str, str]] | None = None,\n) -> None:\n    \"\"\"Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.\n\n    Args:\n        loras: The dictionary of LoRAs to add to the U-Net.\n            (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)\n    \"\"\"\n    unet_loras = {key: loras[key] for key in loras.keys() if \"unet\" in key}\n\n    if exclude is None:\n        exclude = [\"TimestepEncoder\"]\n\n    if preprocess is None:\n        preprocess = {\n            \"res\": \"ResidualBlock\",\n            \"downsample\": \"Downsample\",\n            \"upsample\": \"Upsample\",\n        }\n\n    if include is not None:\n        preprocess = {k: v for k, v in preprocess.items() if v in include}\n\n    preprocess = {k: v for k, v in preprocess.items() if v not in exclude}\n\n    loras_excluded = {k: v for k, v in unet_loras.items() if any(x in k for x in preprocess.keys())}\n    loras_remaining = {k: v for k, v in unet_loras.items() if k not in loras_excluded}\n\n    for exc_k, exc_v in preprocess.items():\n        ls = {k: v for k, v in loras_excluded.items() if exc_k in k}\n        auto_attach_loras(ls, self.unet, include=[exc_v], exclude=exclude, debug_map=debug_map)\n\n    auto_attach_loras(\n        loras_remaining,\n        self.unet,\n        exclude=[*exclude, *preprocess.values()],\n        include=include,\n        debug_map=debug_map,\n    )\n
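The partitioning step in the source above splits the U-Net keys into groups that are attached to a specific layer type and a remainder that is attached everywhere else. Below is a minimal sketch of that bookkeeping on plain key strings; the function name and the example keys are hypothetical, and the actual attachment (auto_attach_loras) is left out.

```python
DEFAULT_PREPROCESS = {'res': 'ResidualBlock', 'downsample': 'Downsample', 'upsample': 'Upsample'}


def partition_unet_keys(keys, include=None, exclude=None, preprocess=None):
    if exclude is None:
        exclude = ['TimestepEncoder']
    if preprocess is None:
        preprocess = dict(DEFAULT_PREPROCESS)

    # Drop preprocess entries whose target layer is filtered out.
    if include is not None:
        preprocess = {k: v for k, v in preprocess.items() if v in include}
    preprocess = {k: v for k, v in preprocess.items() if v not in exclude}

    # Keys containing a fragment go to the layer type mapped to that fragment.
    targeted = {fragment: [k for k in keys if fragment in k] for fragment in preprocess}
    matched = {k for group in targeted.values() for k in group}
    remaining = [k for k in keys if k not in matched]
    return targeted, remaining


# Hypothetical key names, for illustration only.
example_keys = [
    'unet_down_0_resnets_0_conv1',
    'unet_down_0_downsamplers_0_conv',
    'unet_mid_attn_to_q',
]
targeted, remaining = partition_unet_keys(example_keys)
```

Note that the fragments are matched as plain substrings, so a short fragment such as res can match more keys than intended.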
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.get_loras_by_name","title":"get_loras_by_name","text":"
get_loras_by_name(name: str) -> list[Lora[Any]]\n

Get the LoRA layers with the given name.

Parameters:

name (str, required): The name of the LoRA.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:\n    \"\"\"Get the LoRA layers with the given name.\n\n    Args:\n        name: The name of the LoRA.\n    \"\"\"\n    return [lora for lora in self.loras if lora.name == name]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.get_scale","title":"get_scale","text":"
get_scale(name: str) -> float\n

Get the scale of the LoRA with the given name.

Parameters:

name (str, required): The name of the LoRA.

Returns:

float: The scale of the LoRA layers with the given name.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_scale(self, name: str, /) -> float:\n    \"\"\"Get the scale of the LoRA with the given name.\n\n    Args:\n        name: The name of the LoRA.\n\n    Returns:\n        The scale of the LoRA layers with the given name.\n    \"\"\"\n    loras = self.get_loras_by_name(name)\n    assert all([lora.scale == loras[0].scale for lora in loras]), \"lora scales are not all the same\"\n    return loras[0].scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.remove_all","title":"remove_all","text":"
remove_all() -> None\n

Remove all the LoRAs from the target.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_all(self) -> None:\n    \"\"\"Remove all the LoRAs from the target.\"\"\"\n    for lora_adapter in self.lora_adapters:\n        lora_adapter.eject()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.remove_loras","title":"remove_loras","text":"
remove_loras(*names: str) -> None\n

Remove multiple LoRAs from the target.

Parameters:

names (str, default ()): The names of the LoRAs to remove.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_loras(self, *names: str) -> None:\n    \"\"\"Remove multiple LoRAs from the target.\n\n    Args:\n        names: The names of the LoRAs to remove.\n    \"\"\"\n    for lora_adapter in self.lora_adapters:\n        for name in names:\n            lora_adapter.remove_lora(name)\n\n        if len(lora_adapter.loras) == 0:\n            lora_adapter.eject()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.set_scale","title":"set_scale","text":"
set_scale(name: str, scale: float) -> None\n

Set the scale of the LoRA with the given name.

Parameters:

Name Type Description Default name str

The name of the LoRA.

required scale float

The new scale to set.

required Source code in src/refiners/foundationals/latent_diffusion/lora.py
def set_scale(self, name: str, scale: float, /) -> None:\n    \"\"\"Set the scale of the LoRA with the given name.\n\n    Args:\n        name: The name of the LoRA.\n        scale: The new scale to set.\n    \"\"\"\n    self.update_scales({name: scale})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.sort_keys","title":"sort_keys staticmethod","text":"
sort_keys(key: str) -> tuple[str, int]\n

Compute the score of a key, relative to its suffix.

When used by sorted, the keys will only be sorted \"at the suffix level\". The idea is that sometimes closely related keys in the state dict are not in the same order as the one we expect, for instance q -> k -> v or in -> out. This attempts to fix that issue, not cases where distant layers are called in a different order.

Parameters:

Name Type Description Default key str

The key to sort.

required

Returns:

Type Description str

The padded prefix of the key.

int

A score depending on the key's suffix.

Source code in src/refiners/foundationals/latent_diffusion/lora.py
@staticmethod\ndef sort_keys(key: str, /) -> tuple[str, int]:\n    \"\"\"Compute the score of a key, relatively to its suffix.\n\n    When used by [`sorted`][sorted], the keys will only be sorted \"at the suffix level\".\n    The idea is that sometimes closely related keys in the state dict are not in the\n    same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This\n    attempts to fix that issue, not cases where distant layers are called in a different\n    order.\n\n    Args:\n        key: The key to sort.\n\n    Returns:\n        The padded prefix of the key.\n        A score depending on the key's suffix.\n    \"\"\"\n\n    # this dict might not be exhaustive\n    suffix_scores = {\"q\": 1, \"k\": 2, \"v\": 3, \"in\": 3, \"out\": 4, \"out0\": 4, \"out_0\": 4}\n    patterns = [\"_{}\", \"_{}_lora\"]\n\n    # apply patterns to the keys of suffix_scores\n    key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns}\n\n    # get the suffix and score for `key` (default: no suffix, highest score = 5)\n    (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), (\"\", 5))\n\n    padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))\n    return (padded_key_prefix, score)\n
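
The effect of this scoring is easiest to see on a handful of keys. The sketch below reimplements the logic as a standalone function. The `pad` helper is a hypothetical stand-in for `SDLoraManager._pad` (whose source is not shown here) and simply returns the prefix unchanged, and the sample keys are made up for illustration.

```python
def pad(prefix: str) -> str:
    # Hypothetical stand-in for SDLoraManager._pad: returns the prefix unchanged.
    return prefix


def sort_keys(key: str) -> tuple[str, int]:
    # Same suffix table as in the source above (it might not be exhaustive).
    suffix_scores = {'q': 1, 'k': 2, 'v': 3, 'in': 3, 'out': 4, 'out0': 4, 'out_0': 4}
    patterns = ['_{}', '_{}_lora']
    key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns}
    # First matching suffix wins; keys without a known suffix get the highest score (5).
    sfx, score = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), ('', 5))
    return (pad(key.removesuffix(sfx)), score)


# Made-up keys sharing one prefix, deliberately out of order.
keys = ['attn1_to_v', 'attn1_to_q', 'attn1_to_k']
ordered = sorted(keys, key=sort_keys)
print(ordered)
```

Because the three keys share the prefix `attn1_to`, only the suffix score decides their order, so they come out as q, k, v.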
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.update_scales","title":"update_scales","text":"
update_scales(scales: dict[str, float]) -> None\n

Update the scales of multiple LoRAs.

Parameters:

Name Type Description Default scales dict[str, float]

The scales to update. (keys are the names of the LoRAs, values are the new scales to set)

required Source code in src/refiners/foundationals/latent_diffusion/lora.py
def update_scales(self, scales: dict[str, float], /) -> None:\n    \"\"\"Update the scales of multiple LoRAs.\n\n    Args:\n        scales: The scales to update.\n            (keys are the names of the LoRAs, values are the new scales to set)\n    \"\"\"\n    assert all([name in self.names for name in scales]), f\"Scales keys must be a subset of {self.names}\"\n    for name, scale in scales.items():\n        for lora in self.get_loras_by_name(name):\n            lora.scale = scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter","title":"IPAdapter","text":"
IPAdapter(\n    target: T,\n    clip_image_encoder: CLIPImageEncoderH,\n    image_proj: Module,\n    scale: float = 1.0,\n    fine_grained: bool = False,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: Generic[T], Chain, Adapter[T]

Image Prompt adapter for a Stable Diffusion U-Net model.

See [arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models for more details.

Parameters:

Name Type Description Default target T

The target model to adapt.

required clip_image_encoder CLIPImageEncoderH

The CLIP image encoder to use.

required image_proj Module

The image projection to use.

required scale float

The scale to use for the image prompt.

1.0 fine_grained bool

Whether to use fine-grained image prompt.

False weights dict[str, Tensor] | None

The weights of the IPAdapter.

None Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def __init__(\n    self,\n    target: T,\n    clip_image_encoder: CLIPImageEncoderH,\n    image_proj: fl.Module,\n    scale: float = 1.0,\n    fine_grained: bool = False,\n    weights: dict[str, Tensor] | None = None,\n) -> None:\n    \"\"\"Initialize the adapter.\n\n    Args:\n        target: The target model to adapt.\n        clip_image_encoder: The CLIP image encoder to use.\n        image_proj: The image projection to use.\n        scale: The scale to use for the image prompt.\n        fine_grained: Whether to use fine-grained image prompt.\n        weights: The weights of the IPAdapter.\n    \"\"\"\n    with self.setup_adapter(target):\n        super().__init__(target)\n\n    self.fine_grained = fine_grained\n    self._clip_image_encoder = [clip_image_encoder]\n    if fine_grained:\n        self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]\n    self._image_proj = [image_proj]\n\n    self.sub_adapters = [\n        CrossAttentionAdapter(target=cross_attn, scale=scale)\n        for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))\n    ]\n\n    if weights is not None:\n        image_proj_state_dict: dict[str, Tensor] = {\n            k.removeprefix(\"image_proj.\"): v for k, v in weights.items() if k.startswith(\"image_proj.\")\n        }\n        self.image_proj.load_state_dict(image_proj_state_dict)\n\n        for i, cross_attn in enumerate(self.sub_adapters):\n            cross_attention_weights: list[Tensor] = []\n            for k, v in weights.items():\n                prefix = f\"ip_adapter.{i:03d}.\"\n                if not k.startswith(prefix):\n                    continue\n                cross_attention_weights.append(v)\n\n            assert len(cross_attention_weights) == 2\n            cross_attn.load_weights(*cross_attention_weights)\n
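
The `weights` dictionary is consumed by key prefix: entries under `image_proj.` go to the image projection, and entries under `ip_adapter.000.`, `ip_adapter.001.`, and so on go to the matching cross-attention sub-adapter (exactly two tensors each, per the assertion above). Below is a minimal sketch of that partitioning, using plain strings in place of tensors. The function name and the key names after the prefixes are assumptions made for illustration.

```python
def split_ip_adapter_weights(weights: dict[str, object], num_sub_adapters: int):
    # Strip the 'image_proj.' prefix, as done before load_state_dict in the source.
    image_proj = {
        k.removeprefix('image_proj.'): v for k, v in weights.items() if k.startswith('image_proj.')
    }
    # Collect the values for each sub-adapter index, in dictionary order.
    per_adapter = []
    for i in range(num_sub_adapters):
        prefix = f'ip_adapter.{i:03d}.'
        per_adapter.append([v for k, v in weights.items() if k.startswith(prefix)])
    return image_proj, per_adapter


# Hypothetical key names; real checkpoints may name the tensors differently.
weights = {
    'image_proj.proj.weight': 'W_proj',
    'ip_adapter.000.to_k': 'K0',
    'ip_adapter.000.to_v': 'V0',
    'ip_adapter.001.to_k': 'K1',
    'ip_adapter.001.to_v': 'V1',
}
image_proj, per_adapter = split_ip_adapter_weights(weights, num_sub_adapters=2)
print(image_proj)
print(per_adapter)
```

Note that the index is zero-padded to three digits (`{i:03d}`), so a key such as `ip_adapter.1.to_k` would not match any sub-adapter.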
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.clip_image_encoder","title":"clip_image_encoder property","text":"
clip_image_encoder: CLIPImageEncoderH\n

The CLIP image encoder of the adapter.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.scale","title":"scale property writable","text":"
scale: float\n

The scale of the adapter.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.compute_clip_image_embedding","title":"compute_clip_image_embedding","text":"
compute_clip_image_embedding(\n    image_prompt: Tensor | Image | list[Image],\n    weights: list[float] | None = None,\n    concat_batches: bool = True,\n) -> Tensor\n

Compute the CLIP image embedding.

Parameters:

Name Type Description Default image_prompt Tensor | Image | list[Image]

The image prompt to use.

required weights list[float] | None

The weights to apply to each image's conditional embedding (one weight per image).

None concat_batches bool

Whether to concatenate the batches.

True

Returns:

Type Description Tensor

The CLIP image embedding.

Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def compute_clip_image_embedding(\n    self,\n    image_prompt: Tensor | Image.Image | list[Image.Image],\n    weights: list[float] | None = None,\n    concat_batches: bool = True,\n) -> Tensor:\n    \"\"\"Compute the CLIP image embedding.\n\n    Args:\n        image_prompt: The image prompt to use.\n        weights: The scale to use for the image prompt.\n        concat_batches: Whether to concatenate the batches.\n\n    Returns:\n        The CLIP image embedding.\n    \"\"\"\n    if isinstance(image_prompt, Image.Image):\n        image_prompt = self.preprocess_image(image_prompt)\n    elif isinstance(image_prompt, list):\n        assert all(isinstance(image, Image.Image) for image in image_prompt)\n        image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])\n\n    negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)\n\n    batch_size = image_prompt.shape[0]\n    if weights is not None:\n        assert len(weights) == batch_size, f\"Got {len(weights)} weights for {batch_size} images\"\n        if any(weight != 1.0 for weight in weights):\n            conditional_embedding *= (\n                torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)\n                .unsqueeze(-1)\n                .unsqueeze(-1)\n            )\n\n    if batch_size > 1 and concat_batches:\n        # Create a longer image tokens sequence when a batch of images is given\n        # See https://github.com/tencent-ailab/IP-Adapter/issues/99\n        negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)\n        conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)\n\n    return torch.cat((negative_embedding, conditional_embedding))\n
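
The effect of `concat_batches` on the output shape can be worked out without running the model. The sketch below does the shape arithmetic only. It assumes that the negative and conditional embeddings each have shape `(batch_size, num_tokens, embedding_dim)`, which is consistent with the `chunk`, `cat` and `unsqueeze` calls above but is not stated in the source. The function name and the numbers are made up for illustration.

```python
def clip_embedding_shape(batch_size: int, num_tokens: int, dim: int, concat_batches: bool = True):
    # Assumed shape of each of the negative and conditional embeddings.
    shape = (batch_size, num_tokens, dim)
    if batch_size > 1 and concat_batches:
        # chunk(batch_size) then cat(dim=1): one item with a longer token sequence.
        shape = (1, batch_size * num_tokens, dim)
    # The final torch.cat stacks negative and conditional along dim 0.
    return (2 * shape[0], shape[1], shape[2])


print(clip_embedding_shape(3, 4, 8))                        # (2, 12, 8)
print(clip_embedding_shape(3, 4, 8, concat_batches=False))  # (6, 4, 8)
print(clip_embedding_shape(1, 4, 8))                        # (2, 4, 8)
```

With concatenation, several image prompts are merged into a single, longer token sequence. Without it, each image keeps its own batch entry.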
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.preprocess_image","title":"preprocess_image","text":"
preprocess_image(\n    image: Image,\n    size: tuple[int, int] = (224, 224),\n    mean: list[float] | None = None,\n    std: list[float] | None = None,\n) -> Tensor\n

Preprocess the image.

Note

The default mean and std are parameters from https://github.com/openai/CLIP

Parameters:

Name Type Description Default image Image

The image to preprocess.

required size tuple[int, int]

The size to resize the image to.

(224, 224) mean list[float] | None

The mean to use for normalization.

None std list[float] | None

The standard deviation to use for normalization.

None Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def preprocess_image(\n    self,\n    image: Image.Image,\n    size: tuple[int, int] = (224, 224),\n    mean: list[float] | None = None,\n    std: list[float] | None = None,\n) -> Tensor:\n    \"\"\"Preprocess the image.\n\n    Note:\n        The default mean and std are parameters from\n        https://github.com/openai/CLIP\n\n    Args:\n        image: The image to preprocess.\n        size: The size to resize the image to.\n        mean: The mean to use for normalization.\n        std: The standard deviation to use for normalization.\n    \"\"\"\n    resized = image.resize(size)  # type: ignore\n    return normalize(\n        image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype),\n        mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,\n        std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,\n    )\n
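
As a rough illustration of the normalization step, the sketch below applies the usual per-channel `(value - mean) / std` rule to a single RGB pixel. It assumes that `normalize` follows that rule and that pixel values are in the range 0 to 1 after `image_to_tensor`. Neither is confirmed by the source shown here, and `normalize_pixel` is a hypothetical helper.

```python
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def normalize_pixel(rgb, mean=None, std=None):
    # Same defaulting rule as preprocess_image: fall back to the CLIP statistics.
    mean = CLIP_MEAN if mean is None else mean
    std = CLIP_STD if std is None else std
    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]


# A black pixel maps to -mean / std on each channel.
print(normalize_pixel([0.0, 0.0, 0.0]))
# Custom statistics override the CLIP defaults.
print(normalize_pixel([0.5, 0.5, 0.5], mean=[0.5] * 3, std=[0.5] * 3))
```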
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.set_clip_image_embedding","title":"set_clip_image_embedding","text":"
set_clip_image_embedding(image_embedding: Tensor) -> None\n

Set the CLIP image embedding context.

Note

This is required by ImageCrossAttention.

Parameters:

Name Type Description Default image_embedding Tensor

The CLIP image embedding to set.

required Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:\n    \"\"\"Set the CLIP image embedding context.\n\n    Note:\n        This is required by `ImageCrossAttention`.\n\n    Args:\n        image_embedding: The CLIP image embedding to set.\n    \"\"\"\n    self.set_context(\"ip_adapter\", {\"clip_image_embedding\": image_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.AdaIN","title":"AdaIN","text":"
AdaIN(epsilon: float = 1e-08)\n

Bases: Module

Apply Adaptive Instance Normalization (AdaIN) to the target features.

See [arXiv:1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization for more details.

Receives:

Name Type Description reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features.

targets Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The target features.

Returns:

Name Type Description reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features (unchanged).

targets Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The target features, renormalized.

Parameters:

Name Type Description Default epsilon float

A small value to avoid division by zero.

1e-08 Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(self, epsilon: float = 1e-8) -> None:\n    \"\"\"Initialize the AdaIN module.\n\n    Args:\n        epsilon: A small value to avoid division by zero.\n    \"\"\"\n    super().__init__()\n    self.epsilon = epsilon\n
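
The forward pass is not shown above, so the following is only a sketch of the general AdaIN idea from the paper: the target features are standardized with their own statistics, then rescaled and shifted to match the reference statistics. It works on plain Python lists. The axis over which the statistics are taken and the exact placement of `epsilon` are assumptions and may differ from the Refiners implementation.

```python
from statistics import fmean, pstdev


def adain(reference, targets, epsilon=1e-8):
    # Statistics of the reference (style) and target (content) features.
    ref_mean, ref_std = fmean(reference), pstdev(reference)
    tgt_mean, tgt_std = fmean(targets), pstdev(targets)
    # Standardize the targets, then rescale and shift them with the reference statistics.
    return [(t - tgt_mean) / (tgt_std + epsilon) * ref_std + ref_mean for t in targets]


reference = [10.0, 12.0, 14.0]
targets = [0.0, 1.0, 2.0]
renormalized = adain(reference, targets)
print(renormalized)
```

After the call, the targets have (up to `epsilon`) the same mean and standard deviation as the reference, while keeping their own relative ordering.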
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.ExtractReferenceFeatures","title":"ExtractReferenceFeatures","text":"
ExtractReferenceFeatures(*args: Any, **kwargs: Any)\n

Bases: Module

Extract the reference features from the input features.

Note

This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using Classifier-free guidance (CFG).

The reference features are the first features of the conditional and unconditional input features. They are extracted, and repeated to match the batch size of the input features.

Receives:

Name Type Description features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The input features.

Returns:

Name Type Description reference Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The reference features.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.ScaleReferenceFeatures","title":"ScaleReferenceFeatures","text":"
ScaleReferenceFeatures(scale: float = 1.0)\n

Bases: Module

Scale the reference features.

Note

This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using Classifier-free guidance (CFG).

This layer scales the reference features which will later be used (in the attention dot product) with the target features.

Receives:

Name Type Description features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The input reference features.

Returns:

Name Type Description features Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']

The rescaled reference features.

Parameters:

Name Type Description Default scale float

The scaling factor.

1.0 Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n    self,\n    scale: float = 1.0,\n) -> None:\n    \"\"\"Initialize the ScaleReferenceFeatures module.\n\n    Args:\n        scale: The scaling factor.\n    \"\"\"\n    super().__init__()\n    self.scale = scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.SharedSelfAttentionAdapter","title":"SharedSelfAttentionAdapter","text":"
SharedSelfAttentionAdapter(\n    target: SelfAttention, scale: float = 1.0\n)\n

Bases: Chain, Adapter[SelfAttention]

Upgrades a SelfAttention layer into a SharedSelfAttention layer.

This adapter inserts 3 StyleAligned modules right after the original Q, K, V Linear-s (wrapped inside a fl.Distribute).

Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n    self,\n    target: fl.SelfAttention,\n    scale: float = 1.0,\n) -> None:\n    with self.setup_adapter(target):\n        super().__init__(target)\n\n    self._style_aligned_layers = [\n        StyleAligned(  # Query\n            adain=True,\n            concatenate=False,\n            scale=scale,\n        ),\n        StyleAligned(  # Key\n            adain=True,\n            concatenate=True,\n            scale=scale,\n        ),\n        StyleAligned(  # Value\n            adain=False,\n            concatenate=True,\n            scale=scale,\n        ),\n    ]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAligned","title":"StyleAligned","text":"
StyleAligned(\n    adain: bool, concatenate: bool, scale: float = 1.0\n)\n

Bases: Chain

StyleAligned module.

This layer encapsulates the logic of the StyleAligned method, as described in [arXiv:2312.02133] Style Aligned Image Generation via Shared Attention.

See also https://blog.finegrain.ai/posts/implementing-style-aligned/.

Receives:

Name Type Description features Float[Tensor, 'cfg_batch_size sequence_length_in embedding_dim']

The input features.

Returns:

Name Type Description shared_features Float[Tensor, 'cfg_batch_size sequence_length_out embedding_dim']

The transformed features.

Parameters:

Name Type Description Default adain bool

Whether to apply Adaptive Instance Normalization to the target features.

required scale float

The scaling factor for the reference features.

1.0 concatenate bool

Whether to concatenate the reference and target features.

required Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n    self,\n    adain: bool,\n    concatenate: bool,\n    scale: float = 1.0,\n) -> None:\n    \"\"\"Initialize the StyleAligned module.\n\n    Args:\n        adain: Whether to apply Adaptive Instance Normalization to the target features.\n        scale: The scaling factor for the reference features.\n        concatenate: Whether to concatenate the reference and target features.\n    \"\"\"\n    super().__init__(\n        # (features): (cfg_batch_size sequence_length embedding_dim)\n        fl.Parallel(\n            fl.Identity(),\n            ExtractReferenceFeatures(),\n        ),\n        # (targets, reference)\n        AdaIN(),\n        # (targets_renormalized, reference)\n        fl.Distribute(\n            fl.Identity(),\n            ScaleReferenceFeatures(scale=scale),\n        ),\n        # (targets_renormalized, reference_scaled)\n        fl.Concatenate(\n            fl.GetArg(index=0),  # targets\n            fl.GetArg(index=1),  # reference\n            dim=-2,  # sequence_length\n        ),\n        # (features_with_shared_reference)\n    )\n\n    if not adain:\n        adain_module = self.ensure_find(AdaIN)\n        self.remove(adain_module)\n\n    if not concatenate:\n        concatenate_module = self.ensure_find(fl.Concatenate)\n        self.replace(\n            old_module=concatenate_module,\n            new_module=fl.GetArg(index=0),  # targets\n        )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAligned.scale","title":"scale property writable","text":"
scale: float\n

The scaling factor for the reference features.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAlignedAdapter","title":"StyleAlignedAdapter","text":"
StyleAlignedAdapter(target: T, scale: float = 1.0)\n

Bases: Generic[T], Chain, Adapter[T]

Upgrade each SelfAttention layer of a UNet into a SharedSelfAttention layer.

Parameters:

Name Type Description Default target T

The target module.

required scale float

The scaling factor for the reference features.

1.0 Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n    self,\n    target: T,\n    scale: float = 1.0,\n) -> None:\n    \"\"\"Initialize the StyleAlignedAdapter.\n\n    Args:\n        target: The target module.\n        scale: The scaling factor for the reference features.\n    \"\"\"\n    with self.setup_adapter(target):\n        super().__init__(target)\n\n    # create a SharedSelfAttentionAdapter for each SelfAttention module\n    self.shared_self_attention_adapters = tuple(\n        SharedSelfAttentionAdapter(\n            target=self_attention,\n            scale=scale,\n        )\n        for self_attention in self.target.layers(fl.SelfAttention)\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAlignedAdapter.scale","title":"scale property writable","text":"
scale: float\n

The scaling factor for the reference features.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.multi_diffusion.DiffusionTarget","title":"DiffusionTarget dataclass","text":"
DiffusionTarget(\n    *,\n    tile: Tile,\n    solver: Solver,\n    init_latents: Tensor | None = None,\n    opacity_mask: Tensor | None = None,\n    weight: int = 1,\n    start_step: int = 0,\n    end_step: int = MAX_STEPS\n)\n

Represents a target for the tiled diffusion process.

This class encapsulates the parameters and properties needed to define a specific area (target) within a larger diffusion process, allowing for fine-grained control over different regions of the generated image.

Attributes:

Name Type Description tile Tile

The tile defining the area of the target within the latent image.

solver Solver

The solver to use for this target's diffusion process. This is useful because some solvers have an internal state that needs to be updated during the diffusion process. Using the same solver instance for multiple targets would interfere with this internal state.

init_latents Tensor | None

The initial latents for this target. If None, the target will be initialized with noise.

opacity_mask Tensor | None

Mask controlling the target's visibility in the final image. If None, the target will be fully visible. Otherwise, 1 means fully opaque and 0 means fully transparent (the target then has no influence).

weight int

The importance of this target in the final image. Higher values increase the target's influence.

start_step int

The diffusion step at which this target begins to influence the process.

end_step int

The diffusion step at which this target stops influencing the process.

size Size

The size of the target area.

offset tuple[int, int]

The top-left offset of the target area within the latent image.

The combination of opacity_mask and weight determines the target's overall contribution to the final generated image. The solver is responsible for the actual diffusion calculations for this target.

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.multi_diffusion.MultiDiffusion","title":"MultiDiffusion","text":"

Bases: ABC, Generic[T]

MultiDiffusion class for performing multi-target diffusion using tiled diffusion.

For more details, refer to the paper: MultiDiffusion

"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.multi_diffusion.MultiDiffusion.generate_latent_tiles","title":"generate_latent_tiles staticmethod","text":"
generate_latent_tiles(\n    size: Size, tile_size: Size, min_overlap: int = 8\n) -> list[Tile]\n

Generate tiles for a latent image with the given size and tile size.

If one dimension of the tile_size is larger than the corresponding dimension of the image size, a single tile is used to cover the entire image - and therefore tile_size is ignored. This algorithm ensures that the tile size is respected as much as possible, while still covering the entire image and respecting the minimum overlap.

Source code in src/refiners/foundationals/latent_diffusion/multi_diffusion.py
@staticmethod\ndef generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:\n    \"\"\"\n    Generate tiles for a latent image with the given size and tile size.\n\n    If one dimension of the `tile_size` is larger than the corresponding dimension of the image size, a single tile is\n    used to cover the entire image - and therefore `tile_size` is ignored. This algorithm ensures that the tile size\n    is respected as much as possible, while still covering the entire image and respecting the minimum overlap.\n    \"\"\"\n    assert (\n        0 <= min_overlap < min(tile_size.height, tile_size.width)\n    ), \"Overlap must be non-negative and less than the tile size\"\n\n    if tile_size.width > size.width or tile_size.height > size.height:\n        return [Tile(top=0, left=0, bottom=size.height, right=size.width)]\n\n    tiles: list[Tile] = []\n\n    def _compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:\n        if tile_length >= length:\n            return 1, 0\n        num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1\n        overlap = (num_tiles * tile_length - length) // (num_tiles - 1)\n        return num_tiles, overlap\n\n    num_tiles_x, overlap_x = _compute_tiles_and_overlap(\n        length=size.width, tile_length=tile_size.width, min_overlap=min_overlap\n    )\n    num_tiles_y, overlap_y = _compute_tiles_and_overlap(\n        length=size.height, tile_length=tile_size.height, min_overlap=min_overlap\n    )\n\n    for i in range(num_tiles_y):\n        for j in range(num_tiles_x):\n            x = j * (tile_size.width - overlap_x)\n            y = i * (tile_size.height - overlap_y)\n\n            # Adjust x and y coordinates to ensure full-sized tiles\n            if x + tile_size.width > size.width:\n                x = size.width - tile_size.width\n            if y + tile_size.height > size.height:\n                y = size.height - tile_size.height\n\n            tile_right = x + tile_size.width\n            tile_bottom = y + tile_size.height\n            tiles.append(Tile(top=y, left=x, bottom=tile_bottom, right=tile_right))\n\n    return tiles\n
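
To make the tiling arithmetic concrete, here is a one-dimensional sketch of the same computation. `compute_tiles_and_overlap` mirrors the inner helper of the source, while `tile_starts` is a hypothetical helper that returns only the start offsets along one axis. The lengths used are made up for illustration.

```python
import math


def compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:
    # Mirrors the inner helper of generate_latent_tiles.
    if tile_length >= length:
        return 1, 0
    num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
    overlap = (num_tiles * tile_length - length) // (num_tiles - 1)
    return num_tiles, overlap


def tile_starts(length: int, tile_length: int, min_overlap: int = 8) -> list[int]:
    # The source returns a single full-image tile when the tile is larger than the image.
    assert tile_length <= length, 'this sketch only covers tiles that fit in the image'
    num_tiles, overlap = compute_tiles_and_overlap(length, tile_length, min_overlap)
    # Clamp the last start so that every tile keeps its full length.
    return [min(i * (tile_length - overlap), length - tile_length) for i in range(num_tiles)]


print(compute_tiles_and_overlap(128, 64, 8))  # (3, 32)
print(tile_starts(128, 64))                   # [0, 32, 64]
print(tile_starts(100, 64))                   # [0, 36]
```

For a length of 128 and tiles of 64, a minimum overlap of 8 would leave a gap with two tiles, so three tiles are used and the overlap grows to 32. The actual overlap is therefore at least `min_overlap`, not exactly equal to it.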
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.ella_adapter.ELLA","title":"ELLA","text":"
ELLA(\n    time_channel: int,\n    timestep_embedding_dim: int,\n    width: int,\n    num_layers: int,\n    num_heads: int,\n    num_latents: int,\n    input_dim: int | None = None,\n    out_dim: int | None = None,\n    device: device | str | None = None,\n    dtype: dtype | None = None,\n)\n

Bases: Passthrough

ELLA latents encoder.

See [arXiv:2403.05135] ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment for more details.

Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(\n    self,\n    time_channel: int,\n    timestep_embedding_dim: int,\n    width: int,\n    num_layers: int,\n    num_heads: int,\n    num_latents: int,\n    input_dim: int | None = None,\n    out_dim: int | None = None,\n    device: Device | str | None = None,\n    dtype: DType | None = None,\n) -> None:\n    super().__init__(\n        TimestepEncoder(timestep_embedding_dim, time_channel, device=device, dtype=dtype),\n        fl.UseContext(\"adapted_cross_attention_block\", \"llm_text_embedding\"),\n        PerceiverResampler(\n            timestep_embedding_dim,\n            width,\n            num_layers,\n            num_heads,\n            num_latents,\n            out_dim,\n            input_dim,\n            device=device,\n            dtype=dtype,\n        ),\n        fl.SetContext(\"ella\", \"latents\"),\n    )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.ella_adapter.ELLAAdapter","title":"ELLAAdapter","text":"
ELLAAdapter(\n    target: T,\n    latents_encoder: ELLA,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: Generic[T], Chain, Adapter[T]

Adapter for ELLA.

Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:\n    if weights is not None:\n        latents_encoder.load_state_dict(weights)\n\n    self._latents_encoder = [latents_encoder]\n    with self.setup_adapter(target):\n        super().__init__(target)\n    self.sub_adapters = [\n        ELLACrossAttentionAdapter(use_context)\n        for cross_attn in target.layers(CrossAttentionBlock)\n        for use_context in cross_attn.layers(fl.UseContext)\n    ]\n
"},{"location":"reference/foundationals/segment_anything/","title":" Segment Anything","text":""},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.HQSAMAdapter","title":"HQSAMAdapter","text":"
HQSAMAdapter(\n    target: SegmentAnything,\n    hq_mask_only: bool = False,\n    weights: dict[str, Tensor] | None = None,\n)\n

Bases: Chain, Adapter[SegmentAnything]

Adapter for SAM introducing HQ features.

See [arXiv:2306.01567] Segment Anything in High Quality for details.

Example
from refiners.fluxion.utils import load_from_safetensors\nfrom refiners.foundationals.segment_anything import HQSAMAdapter\n\n# Tip: run scripts/prepare_test_weights.py to download the weights\ntensor_path = \"./tests/weights/refiners-sam-hq-vit-h.safetensors\"\nweights = load_from_safetensors(tensor_path)\n\n# sam_h is assumed to be an already-initialized SegmentAnythingH instance\nhq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)\nhq_sam_adapter.inject()  # then use SAM as usual\n

Parameters:

Name Type Description Default target SegmentAnything

The SegmentAnything model to adapt.

required hq_mask_only bool

Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).

False weights dict[str, Tensor] | None

The weights of the HQSAMAdapter.

None Source code in src/refiners/foundationals/segment_anything/hq_sam.py
def __init__(\n    self,\n    target: SegmentAnything,\n    hq_mask_only: bool = False,\n    weights: dict[str, torch.Tensor] | None = None,\n) -> None:\n    \"\"\"Initialize the adapter.\n\n    Args:\n        target: The SegmentAnything model to adapt.\n        hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).\n        weights: The weights of the HQSAMAdapter.\n    \"\"\"\n    self.vit_embedding_dim = target.image_encoder.embedding_dim\n    self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2\n\n    with self.setup_adapter(target):\n        super().__init__(target)\n\n    if target.mask_decoder.multimask_output:\n        raise NotImplementedError(\"Multi-mask mode is not supported in HQSAMAdapter.\")\n\n    mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)\n\n    self._mask_prediction_adapter = [\n        MaskPredictionAdapter(\n            mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype\n        )\n    ]\n    self._register_adapter_module(\"Chain.HQSAMMaskPrediction\", self.mask_prediction_adapter.hq_sam_mask_prediction)\n\n    self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]\n    self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]\n\n    mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)\n    self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]\n    self._register_adapter_module(\"MaskDecoderTokensExtender.hq_token\", self.mask_decoder_tokens_extender.hq_token)\n\n    if weights is not None:\n        self.load_weights(weights)\n\n    self.to(device=target.device, dtype=target.dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything","title":"SegmentAnything","text":"
SegmentAnything(\n    image_encoder: SAMViT,\n    point_encoder: PointEncoder,\n    mask_encoder: MaskEncoder,\n    mask_decoder: MaskDecoder,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: Chain

SegmentAnything model.

See [arXiv:2304.02643] Segment Anything

See SegmentAnythingH for a usage example.

Attributes:

Name Type Description mask_threshold float

0.0

Parameters:

Name Type Description Default image_encoder SAMViT

The image encoder to use.

required point_encoder PointEncoder

The point encoder to use.

required mask_encoder MaskEncoder

The mask encoder to use.

required mask_decoder MaskDecoder

The mask decoder to use.

required Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(\n    self,\n    image_encoder: SAMViT,\n    point_encoder: PointEncoder,\n    mask_encoder: MaskEncoder,\n    mask_decoder: MaskDecoder,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    \"\"\"Initialize SegmentAnything model.\n\n    Args:\n        image_encoder: The image encoder to use.\n        point_encoder: The point encoder to use.\n        mask_encoder: The mask encoder to use.\n        mask_decoder: The mask decoder to use.\n    \"\"\"\n    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)\n\n    self.to(device=device, dtype=dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.image_encoder","title":"image_encoder property","text":"
image_encoder: SAMViT\n

The image encoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.image_encoder_resolution","title":"image_encoder_resolution property","text":"
image_encoder_resolution: int\n

The resolution of the image encoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.mask_decoder","title":"mask_decoder property","text":"
mask_decoder: MaskDecoder\n

The mask decoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.mask_encoder","title":"mask_encoder property","text":"
mask_encoder: MaskEncoder\n

The mask encoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.point_encoder","title":"point_encoder property","text":"
point_encoder: PointEncoder\n

The point encoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.compute_image_embedding","title":"compute_image_embedding","text":"
compute_image_embedding(image: Image) -> ImageEmbedding\n

Compute the embedding of an image.

Parameters:

Name Type Description Default image Image

The image to compute the embedding of.

required

Returns:

Type Description ImageEmbedding

The computed image embedding.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()\ndef compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:\n    \"\"\"Compute the embedding of an image.\n\n    Args:\n        image: The image to compute the embedding of.\n\n    Returns:\n        The computed image embedding.\n    \"\"\"\n    original_size = (image.height, image.width)\n    return ImageEmbedding(\n        features=self.image_encoder(self.preprocess_image(image)),\n        original_image_size=original_size,\n    )\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.normalize","title":"normalize","text":"
normalize(\n    coordinates: Tensor, original_size: tuple[int, int]\n) -> Tensor\n

See normalize_coordinates.

Args: coordinates: a tensor of coordinates. original_size: (h, w), the original size of the image.

Returns: the [0, 1] normalized coordinates tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:\n    \"\"\"\n    See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]\n    Args:\n        coordinates: a tensor of coordinates.\n        original_size: (h, w) the original size of the image.\n    Returns:\n        The [0,1] normalized coordinates tensor.\n    \"\"\"\n    return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.postprocess_masks","title":"postprocess_masks","text":"
postprocess_masks(\n    low_res_masks: Tensor, original_size: tuple[int, int]\n) -> Tensor\n

See postprocess_masks.

Args: low_res_masks: a mask tensor of size (N, 1, 256, 256). original_size: (h, w), the original size of the image.

Returns: the mask of shape (N, 1, H, W).

Source code in src/refiners/foundationals/segment_anything/model.py
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:\n    \"\"\"\n    See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]\n    Args:\n        low_res_masks: a mask tensor of size (N, 1, 256, 256)\n        original_size: (h, w) the original size of the image.\n    Returns:\n        The mask of shape (N, 1, H, W)\n    \"\"\"\n    return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.predict","title":"predict","text":"
predict(\n    input: Image | ImageEmbedding,\n    foreground_points: (\n        Sequence[tuple[float, float]] | None\n    ) = None,\n    background_points: (\n        Sequence[tuple[float, float]] | None\n    ) = None,\n    box_points: (\n        Sequence[Sequence[tuple[float, float]]] | None\n    ) = None,\n    low_res_mask: (\n        Float[Tensor, \"1 1 256 256\"] | None\n    ) = None,\n    binarize: bool = True,\n) -> tuple[Tensor, Tensor, Tensor]\n

Predict the masks of the input image.

Parameters:

Name Type Description Default input Image | ImageEmbedding

The input image or its embedding.

required foreground_points Sequence[tuple[float, float]] | None

The points of the foreground.

None background_points Sequence[tuple[float, float]] | None

The points of the background.

None box_points Sequence[Sequence[tuple[float, float]]] | None

The points of the box.

None low_res_mask Float[Tensor, '1 1 256 256'] | None

The low resolution mask.

None binarize bool

Whether to binarize the masks.

True

Returns:

Type Description Tensor

The predicted masks.

Tensor

The IOU prediction.

Tensor

The low resolution masks.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()\ndef predict(\n    self,\n    input: Image.Image | ImageEmbedding,\n    foreground_points: Sequence[tuple[float, float]] | None = None,\n    background_points: Sequence[tuple[float, float]] | None = None,\n    box_points: Sequence[Sequence[tuple[float, float]]] | None = None,\n    low_res_mask: Float[Tensor, \"1 1 256 256\"] | None = None,\n    binarize: bool = True,\n) -> tuple[Tensor, Tensor, Tensor]:\n    \"\"\"Predict the masks of the input image.\n\n    Args:\n        input: The input image or its embedding.\n        foreground_points: The points of the foreground.\n        background_points: The points of the background.\n        box_points: The points of the box.\n        low_res_mask: The low resolution mask.\n        binarize: Whether to binarize the masks.\n\n    Returns:\n        The predicted masks.\n        The IOU prediction.\n        The low resolution masks.\n    \"\"\"\n    if isinstance(input, ImageEmbedding):\n        original_size = input.original_image_size\n        image_embedding = input.features\n    else:\n        original_size = (input.height, input.width)\n        image_embedding = self.image_encoder(self.preprocess_image(input))\n\n    coordinates, type_mask = self.point_encoder.points_to_tensor(\n        foreground_points=foreground_points,\n        background_points=background_points,\n        box_points=box_points,\n    )\n    self.point_encoder.set_type_mask(type_mask=type_mask)\n\n    if low_res_mask is not None:\n        mask_embedding = self.mask_encoder(low_res_mask)\n    else:\n        mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(\n            image_embedding_size=self.image_encoder.image_embedding_size\n        )\n\n    point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))\n    dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(\n        image_embedding_size=self.image_encoder.image_embedding_size\n    )\n\n    
self.mask_decoder.set_image_embedding(image_embedding=image_embedding)\n    self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)\n    self.mask_decoder.set_point_embedding(point_embedding=point_embedding)\n    self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)\n\n    low_res_masks, iou_predictions = self.mask_decoder()\n\n    high_res_masks = self.postprocess_masks(low_res_masks, original_size)\n\n    if binarize:\n        high_res_masks = high_res_masks > self.mask_threshold\n\n    return high_res_masks, iou_predictions, low_res_masks\n
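Because predict accepts either an image or an ImageEmbedding, the image encoder only needs to run once when several prompts target the same image. A minimal sketch, assuming sam is an already-loaded SegmentAnything instance; the helper name masks_for_points is hypothetical and not part of Refiners:

```python
from typing import Any, Sequence


def masks_for_points(
    sam: Any,
    image: Any,
    points: Sequence[tuple[float, float]],
) -> list[Any]:
    # Hypothetical helper (not part of Refiners): run the image encoder once,
    # then prompt the decoder with one foreground point at a time.
    embedding = sam.compute_image_embedding(image)
    results = []
    for point in points:
        # predict returns (masks, iou_predictions, low_res_masks)
        masks, _, _ = sam.predict(embedding, foreground_points=[point])
        results.append(masks)
    return results
```

Each call after the first skips the encoder, since predict reads features and original_image_size directly from the embedding.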
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.preprocess_image","title":"preprocess_image","text":"
preprocess_image(image: Image) -> Tensor\n

See preprocess_image.

Args: image: the image to preprocess.

Returns: the preprocessed tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def preprocess_image(self, image: Image.Image) -> Tensor:\n    \"\"\"\n    See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]\n    Args:\n        image: The image to preprocess.\n    Returns:\n        The preprocessed tensor.\n    \"\"\"\n    return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnythingH","title":"SegmentAnythingH","text":"
SegmentAnythingH(\n    image_encoder: SAMViTH | None = None,\n    point_encoder: PointEncoder | None = None,\n    mask_encoder: MaskEncoder | None = None,\n    mask_decoder: MaskDecoder | None = None,\n    multimask_output: bool | None = None,\n    device: device | str = \"cpu\",\n    dtype: dtype = torch.float32,\n)\n

Bases: SegmentAnything

SegmentAnything huge model.

Parameters:

Name Type Description Default image_encoder SAMViTH | None

The image encoder to use.

None point_encoder PointEncoder | None

The point encoder to use.

None mask_encoder MaskEncoder | None

The mask encoder to use.

None mask_decoder MaskDecoder | None

The mask decoder to use.

None multimask_output bool | None

Whether to use multimask output.

None device device | str

The PyTorch device to use.

'cpu' dtype dtype

The PyTorch data type to use.

float32 Example
device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n# multimask_output=True is recommended for ambiguous prompts such as a single point.\n# Below, a box prompt is passed, so just use multimask_output=False which will return a single mask\nsam_h = SegmentAnythingH(multimask_output=False, device=device)\n\n# Tips: run scripts/prepare_test_weights.py to download the weights\ntensors_path = \"./tests/weights/segment-anything-h.safetensors\"\nsam_h.load_from_safetensors(tensors_path=tensors_path)\n\nfrom PIL import Image\nimage = Image.open(\"image.png\")\n\nmasks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])\n\nassert masks.shape == (1, 1, image.height, image.width)\nassert masks.dtype == torch.bool\n\n# convert it to [0,255] uint8 ndarray of shape (H, W)\nmask = masks[0, 0].cpu().numpy().astype(\"uint8\") * 255\n\nImage.fromarray(mask).save(\"mask_image.png\")\n
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(\n    self,\n    image_encoder: SAMViTH | None = None,\n    point_encoder: PointEncoder | None = None,\n    mask_encoder: MaskEncoder | None = None,\n    mask_decoder: MaskDecoder | None = None,\n    multimask_output: bool | None = None,\n    device: Device | str = \"cpu\",\n    dtype: DType = torch.float32,\n) -> None:\n    \"\"\"Initialize SegmentAnything huge model.\n\n    Args:\n        image_encoder: The image encoder to use.\n        point_encoder: The point encoder to use.\n        mask_encoder: The mask encoder to use.\n        mask_decoder: The mask decoder to use.\n        multimask_output: Whether to use multimask output.\n        device: The PyTorch device to use.\n        dtype: The PyTorch data type to use.\n\n    Example:\n        ```py\n        device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # multimask_output=True is recommended for ambiguous prompts such as a single point.\n        # Below, a box prompt is passed, so just use multimask_output=False which will return a single mask\n        sam_h = SegmentAnythingH(multimask_output=False, device=device)\n\n        # Tips: run scripts/prepare_test_weights.py to download the weights\n        tensors_path = \"./tests/weights/segment-anything-h.safetensors\"\n        sam_h.load_from_safetensors(tensors_path=tensors_path)\n\n        from PIL import Image\n        image = Image.open(\"image.png\")\n\n        masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])\n\n        assert masks.shape == (1, 1, image.height, image.width)\n        assert masks.dtype == torch.bool\n\n        # convert it to [0,255] uint8 ndarray of shape (H, W)\n        mask = masks[0, 0].cpu().numpy().astype(\"uint8\") * 255\n\n        Image.fromarray(mask).save(\"mask_image.png\")\n        ```\n    \"\"\"\n    image_encoder = image_encoder or SAMViTH()\n    point_encoder = point_encoder or PointEncoder()\n    mask_encoder = mask_encoder or MaskEncoder()\n\n    if mask_decoder:\n        
assert (\n            multimask_output is None or mask_decoder.multimask_output == multimask_output\n        ), f\"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})\"\n    else:\n        mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()\n\n    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)\n\n    self.to(device=device, dtype=dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnythingH.image_encoder","title":"image_encoder property","text":"
image_encoder: SAMViTH\n

The image encoder.

"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.compute_scaled_size","title":"compute_scaled_size","text":"
compute_scaled_size(\n    size: tuple[int, int], image_encoder_resolution: int\n) -> tuple[int, int]\n

Compute the scaled size as expected by the image encoder. The computed size keeps the aspect ratio of the input image and scales it to fit inside the square (image_encoder_resolution, image_encoder_resolution) expected by the image encoder.

Parameters:

Name Type Description Default size tuple[int, int]

The size (h, w) of the input image.

required image_encoder_resolution int

Image encoder resolution.

required

Returns:

Type Description int

The target height.

int

The target width.

Source code in src/refiners/foundationals/segment_anything/utils.py
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:\n    \"\"\"Compute the scaled size as expected by the image encoder.\n    The computed size keeps the aspect ratio of the input image and scales it to fit inside the square (image_encoder_resolution, image_encoder_resolution) expected by the image encoder.\n\n    Args:\n        size: The size (h, w) of the input image.\n        image_encoder_resolution: Image encoder resolution.\n\n    Returns:\n        The target height.\n        The target width.\n    \"\"\"\n    oldh, oldw = size\n    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)\n    newh, neww = oldh * scale, oldw * scale\n    neww = int(neww + 0.5)\n    newh = int(newh + 0.5)\n    return (newh, neww)\n
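A short worked example may help. The function is copied from the listing above so the snippet runs without Refiners installed, and the encoder resolution of 1024 is an assumed value used only for illustration:

```python
def compute_scaled_size(size, image_encoder_resolution):
    # Same arithmetic as the listing above, copied so the snippet runs standalone.
    oldh, oldw = size
    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    return (int(newh + 0.5), int(neww + 0.5))


# Landscape image (h=480, w=640): the longer side (width) is mapped to 1024.
landscape = compute_scaled_size((480, 640), 1024)
# Portrait image (h=1000, w=500): the longer side (height) is mapped to 1024.
portrait = compute_scaled_size((1000, 500), 1024)
print(landscape, portrait)
```

In both cases the longer side ends up equal to the encoder resolution and the shorter side is scaled by the same factor, then rounded to the nearest integer.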
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.image_to_scaled_tensor","title":"image_to_scaled_tensor","text":"
image_to_scaled_tensor(\n    image: Image,\n    scaled_size: tuple[int, int],\n    device: device | None = None,\n    dtype: dtype | None = None,\n) -> Tensor\n

Resize the image to scaled_size and convert it to a tensor.

Parameters:

Name Type Description Default image Image

The image.

required scaled_size tuple[int, int]

The target size (h, w).

required device device | None

Tensor device.

None dtype dtype | None

Tensor dtype.

None

Returns: a Tensor of shape (1, c, h, w)

Source code in src/refiners/foundationals/segment_anything/utils.py
def image_to_scaled_tensor(\n    image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None\n) -> Tensor:\n    \"\"\"Resize the image to `scaled_size` and convert it to a tensor.\n\n    Args:\n        image: The image.\n        scaled_size: The target size (h, w).\n        device: Tensor device.\n        dtype: Tensor dtype.\n    Returns:\n        a Tensor of shape (1, c, h, w)\n    \"\"\"\n    h, w = scaled_size\n    resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore\n\n    return image_to_tensor(resized, device=device, dtype=dtype) * 255.0\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.normalize_coordinates","title":"normalize_coordinates","text":"
normalize_coordinates(\n    coordinates: Tensor,\n    original_size: tuple[int, int],\n    image_encoder_resolution: int,\n) -> Tensor\n

Normalize the coordinates to the [0, 1] range.

Parameters:

Name Type Description Default coordinates Tensor

The coordinates to normalize.

required original_size tuple[int, int]

The original image size.

required image_encoder_resolution int

Image encoder resolution.

required

Returns:

Type Description Tensor

The normalized coordinates.

Source code in src/refiners/foundationals/segment_anything/utils.py
def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n    \"\"\"Normalize the coordinates in the [0,1] range\n\n    Args:\n        coordinates: The coordinates to normalize.\n        original_size: The original image size.\n        image_encoder_resolution: Image encoder resolution.\n\n    Returns:\n        The normalized coordinates.\n    \"\"\"\n    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)\n    coordinates[:, :, 0] = (\n        (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5\n    ) / image_encoder_resolution\n    coordinates[:, :, 1] = (\n        (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5\n    ) / image_encoder_resolution\n    return coordinates\n
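The arithmetic for a single (x, y) point can be traced in plain Python. This is only a sketch: the helper normalize_point is hypothetical, the encoder resolution of 1024 is an assumed value, and the real function works on a 3-dimensional tensor whose last axis holds (x, y). As the listing shows, the real function also writes the result back into the input tensor, so callers who need the original pixel coordinates may want to pass a copy.

```python
def normalize_point(x, y, original_size, scaled_size, image_encoder_resolution):
    # Hypothetical scalar version of normalize_coordinates for one point.
    # original_size and scaled_size are (h, w); x scales with width, y with height.
    nx = (x * (scaled_size[1] / original_size[1]) + 0.5) / image_encoder_resolution
    ny = (y * (scaled_size[0] / original_size[0]) + 0.5) / image_encoder_resolution
    return nx, ny


# A 480x640 (h, w) image scaled to fit a 1024 square becomes 768x1024.
nx, ny = normalize_point(320, 240, (480, 640), (768, 1024), 1024)
print(round(nx, 4), round(ny, 4))
```

The center of the image lands near 0.5 horizontally but near 0.375 vertically, because the scaled image only occupies the top 768 rows of the 1024-row square.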
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.pad_image_tensor","title":"pad_image_tensor","text":"
pad_image_tensor(\n    image_tensor: Tensor,\n    scaled_size: tuple[int, int],\n    image_encoder_resolution: int,\n) -> Tensor\n

Pad an image with zeros to make it square.

Parameters:

Name Type Description Default image_tensor Tensor

The image tensor to pad.

required scaled_size tuple[int, int]

The scaled size (h, w).

required image_encoder_resolution int

Image encoder resolution.

required

Returns:

Type Description Tensor

The padded image.

Source code in src/refiners/foundationals/segment_anything/utils.py
def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n    \"\"\"Pad an image with zeros to make it square.\n\n    Args:\n        image_tensor: The image tensor to pad.\n        scaled_size: The scaled size (h, w).\n        image_encoder_resolution: Image encoder resolution.\n\n    Returns:\n        The padded image.\n    \"\"\"\n    assert len(image_tensor.shape) == 4\n    assert image_tensor.shape[2] <= image_encoder_resolution\n    assert image_tensor.shape[3] <= image_encoder_resolution\n\n    h, w = scaled_size\n    padh = image_encoder_resolution - h\n    padw = image_encoder_resolution - w\n    return pad(image_tensor, (0, padw, 0, padh))\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.postprocess_masks","title":"postprocess_masks","text":"
postprocess_masks(\n    low_res_masks: Tensor,\n    original_size: tuple[int, int],\n    image_encoder_resolution: int,\n) -> Tensor\n

Postprocess the masks to fit the original image size and remove zero-padding (if any).

Parameters:

Name Type Description Default low_res_masks Tensor

The masks to postprocess.

required original_size tuple[int, int]

The original size (h, w).

required image_encoder_resolution int

Image encoder resolution.

required

Returns:

Type Description Tensor

The postprocessed masks.

Source code in src/refiners/foundationals/segment_anything/utils.py
def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n    \"\"\"Postprocess the masks to fit the original image size and remove zero-padding (if any).\n\n    Args:\n        low_res_masks: The masks to postprocess.\n        original_size: The original size (h, w).\n        image_encoder_resolution: Image encoder resolution.\n\n    Returns:\n        The postprocessed masks.\n    \"\"\"\n    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)\n    masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode=\"bilinear\")\n    masks = masks[..., : scaled_size[0], : scaled_size[1]]  # remove padding added at `preprocess_image` time\n    masks = interpolate(masks, size=Size(original_size), mode=\"bilinear\")\n    return masks\n
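The three steps (upsample to the encoder square, crop the padding, resize to the original size) can be followed by tracking only the spatial shape. A sketch under stated assumptions: mask_shapes is a hypothetical helper, it repeats the compute_scaled_size arithmetic inline, and the resolution of 1024 is an assumed value.

```python
def mask_shapes(original_size, image_encoder_resolution):
    # Hypothetical helper: returns the (h, w) of the mask after each step
    # of postprocess_masks, without touching any tensor.
    oldh, oldw = original_size
    scale = image_encoder_resolution / max(oldh, oldw)
    scaled = (int(oldh * scale + 0.5), int(oldw * scale + 0.5))
    return [
        (image_encoder_resolution, image_encoder_resolution),  # upsampled to the encoder square
        scaled,  # padding cropped away
        (oldh, oldw),  # resized to the original image
    ]


steps = mask_shapes((480, 640), 1024)
print(steps)
```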
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.preprocess_image","title":"preprocess_image","text":"
preprocess_image(\n    image: Image,\n    image_encoder_resolution: int,\n    device: device | None = None,\n    dtype: dtype | None = None,\n) -> Tensor\n

Preprocess an image without distorting its aspect ratio.

Parameters:

Name Type Description Default image Image

The image to preprocess before calling the image encoder.

required image_encoder_resolution int

Image encoder resolution.

required device device | None

Tensor device (None by default).

None dtype dtype | None

Tensor dtype (None by default).

None

Returns:

Type Description Tensor

The preprocessed image.

Source code in src/refiners/foundationals/segment_anything/utils.py
def preprocess_image(\n    image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None\n) -> Tensor:\n    \"\"\"Preprocess an image without distorting its aspect ratio.\n\n    Args:\n        image: The image to preprocess before calling the image encoder.\n        image_encoder_resolution: Image encoder resolution.\n        device: Tensor device (None by default).\n        dtype: Tensor dtype (None by default).\n\n    Returns:\n        The preprocessed image.\n    \"\"\"\n\n    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)\n\n    image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype)\n\n    return pad_image_tensor(\n        normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),\n        scaled_size,\n        image_encoder_resolution,\n    )\n
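The per-channel normalization can be illustrated on single pixels. This is a sketch, assuming normalize computes the usual (value - mean) / std per channel and that channel values are in [0, 255] at that point (image_to_scaled_tensor multiplies by 255.0); normalize_pixel is a hypothetical helper.

```python
MEAN = [123.675, 116.28, 103.53]
STD = [58.395, 57.12, 57.375]


def normalize_pixel(rgb):
    # Hypothetical scalar version: rgb holds channel values in [0, 255].
    # Assumes normalize computes (value - mean) / std per channel.
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


black = normalize_pixel([0.0, 0.0, 0.0])
white = normalize_pixel([255.0, 255.0, 255.0])
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

Since pad_image_tensor is applied after normalization, the zero padding corresponds to the mean color in normalized space rather than to black pixels.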
"},{"location":"reference/foundationals/swin/","title":" Swin Transformers","text":""},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.swin_transformer.SwinTransformer","title":"SwinTransformer","text":"
SwinTransformer(\n    patch_size: tuple[int, int] = (4, 4),\n    in_chans: int = 3,\n    embedding_dim: int = 96,\n    depths: list[int] | None = None,\n    num_heads: list[int] | None = None,\n    window_size: int = 7,\n    mlp_ratio: float = 4.0,\n    device: device | None = None,\n)\n

Bases: Chain

Swin Transformer (arXiv:2103.14030)

Currently specific to MVANet; only square inputs are supported.

Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(\n    self,\n    patch_size: tuple[int, int] = (4, 4),\n    in_chans: int = 3,\n    embedding_dim: int = 96,\n    depths: list[int] | None = None,\n    num_heads: list[int] | None = None,\n    window_size: int = 7,  # image size is 32 * this\n    mlp_ratio: float = 4.0,\n    device: Device | None = None,\n) -> None:\n    if depths is None:\n        depths = [2, 2, 6, 2]\n\n    if num_heads is None:\n        num_heads = [3, 6, 12, 24]\n\n    self.num_layers = len(depths)\n    assert len(num_heads) == self.num_layers\n\n    super().__init__(\n        PatchEmbedding(\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embedding_dim=embedding_dim,\n            device=device,\n        ),\n        fl.Passthrough(\n            fl.Transpose(1, 2),\n            SquareUnflatten(2),\n            fl.SetContext(\"swin\", \"outputs\", callback=lambda t, x: t.append(x)),\n        ),\n        *(\n            fl.Chain(\n                BasicLayer(\n                    dim=int(embedding_dim * 2**i),\n                    depth=depths[i],\n                    num_heads=num_heads[i],\n                    window_size=window_size,\n                    mlp_ratio=mlp_ratio,\n                    device=device,\n                ),\n                fl.Passthrough(\n                    fl.LayerNorm(int(embedding_dim * 2**i), device=device),\n                    fl.Transpose(1, 2),\n                    SquareUnflatten(2),\n                    fl.SetContext(\"swin\", \"outputs\", callback=lambda t, x: t.insert(0, x)),\n                ),\n                PatchMerging(dim=int(embedding_dim * 2**i), device=device)\n                if i < self.num_layers - 1\n                else fl.UseContext(\"swin\", \"outputs\").compose(lambda t: tuple(t)),\n            )\n            for i in range(self.num_layers)\n        ),\n    )\n
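The channel width of each stage follows from the embedding_dim * 2**i expression in the listing. A minimal sketch; swin_stage_dims is a hypothetical helper, not part of Refiners:

```python
def swin_stage_dims(embedding_dim, depths):
    # Hypothetical helper: channel width of each BasicLayer, i.e. embedding_dim * 2**i.
    return [int(embedding_dim * 2**i) for i in range(len(depths))]


# SwinTransformer defaults: embedding_dim=96, depths=[2, 2, 6, 2], window_size=7.
default_dims = swin_stage_dims(96, [2, 2, 6, 2])
# Per the inline comment in the listing, the image side is 32 * window_size.
default_image_side = 32 * 7
print(default_dims, default_image_side)
```

With embedding_dim=128 and window_size=12, the same formulas give widths of 128, 256, 512 and 1024 and an input side of 384.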
"},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.swin_transformer.WindowAttention","title":"WindowAttention","text":"
WindowAttention(\n    dim: int,\n    window_size: int,\n    num_heads: int,\n    shift: bool = False,\n    device: device | None = None,\n)\n

Bases: Chain

Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).

It has a trainable relative position bias (RelativePositionBias).

The input projection is stored as a single Linear for q, k and v.

Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(\n    self,\n    dim: int,\n    window_size: int,\n    num_heads: int,\n    shift: bool = False,\n    device: Device | None = None,\n) -> None:\n    super().__init__(\n        fl.Linear(dim, dim * 3, bias=True, device=device),\n        WindowSDPA(window_size, num_heads, shift, device=device),\n        fl.Linear(dim, dim, device=device),\n    )\n
"},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.mvanet.MVANet","title":"MVANet","text":"
MVANet(\n    embedding_dim: int = 128,\n    n_logits: int = 1,\n    depths: list[int] | None = None,\n    num_heads: list[int] | None = None,\n    window_size: int = 12,\n    device: device | None = None,\n)\n

Bases: Chain

Multi-view Aggregation Network for Dichotomous Image Segmentation

See [arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation for more details.

Parameters:

Name Type Description Default embedding_dim int

embedding dimension

128 n_logits int

The number of output logits (defaults to 1). A single logit is used for alpha matting, foreground-background segmentation, or salient object detection (SOD) segmentation.

1 depths list[int]

see SwinTransformer

None num_heads list[int]

see SwinTransformer

None window_size int

Defaults to 12; see SwinTransformer.

12 device device | None

the device to use

None Source code in src/refiners/foundationals/swin/mvanet/mvanet.py
def __init__(\n    self,\n    embedding_dim: int = 128,\n    n_logits: int = 1,\n    depths: list[int] | None = None,\n    num_heads: list[int] | None = None,\n    window_size: int = 12,\n    device: Device | None = None,\n) -> None:\n    if depths is None:\n        depths = [2, 2, 18, 2]\n    if num_heads is None:\n        num_heads = [4, 8, 16, 32]\n\n    super().__init__(\n        ComputeShallow(embedding_dim=embedding_dim, device=device),\n        SplitMultiView(),\n        fl.Flatten(0, 1),\n        SwinTransformer(\n            embedding_dim=embedding_dim,\n            depths=depths,\n            num_heads=num_heads,\n            window_size=window_size,\n            device=device,\n        ),\n        fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),\n        Pyramid(embedding_dim=embedding_dim, device=device),\n        RearrangeMultiView(embedding_dim=embedding_dim, device=device),\n        ShallowUpscaler(embedding_dim, device=device),\n        fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),\n    )\n
"}]}