{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"A PyTorch microframework for foundation model adaptation","text":"
The simplest way to train and run adapters on top of foundation models.
In the era of foundation models, adaptation is quickly rising as the method of choice for bridging the last-mile quality gap. We couldn't find a framework with first-class citizen APIs for foundation model adaptation, so we created one. It's called Refiners, and we're building it on top of PyTorch, in the open, under the MIT License. Read our manifesto.
"},{"location":"concepts/chain/","title":"Chain","text":"When we say models are implemented in a declarative way in Refiners, what this means in practice is they are implemented as Chains. Chain
is a Python class to implement trees of modules. It is a subclass of Refiners' Module, which is in turn a subclass of PyTorch's Module. All inner nodes of a Chain are subclasses of Chain, and leaf nodes are subclasses of Refiners' Module.
To give you an idea of how it looks, let us take a simple convolutional network to classify MNIST as an example. First, let us define a few variables.
img_res = 28\nchannels = 128\nkernel_size = 3\nhidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels\nhidden_layer_out = 200\noutput_size = 10\n
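The value of hidden_layer_in follows from the layer shapes: a 3x3 convolution without padding shrinks each 28-pixel side to 26, 2x2 max pooling halves it to 13, and flattening 128 channels of 13x13 gives 21632 features (the in_features shown later in the model's representation). The small helpers below make this arithmetic explicit; their names are our own, and they assume stride 1 with no padding for the convolution and non-overlapping pooling:

```python
def conv2d_out(size: int, kernel_size: int) -> int:
    # Valid convolution (stride 1, no padding) removes kernel_size - 1 pixels per side length.
    return size - kernel_size + 1


def maxpool2d_out(size: int, pool: int) -> int:
    # Non-overlapping pooling (stride equal to the pool size) floors the division.
    return size // pool


img_res, channels, kernel_size = 28, 128, 3
side = maxpool2d_out(conv2d_out(img_res, kernel_size), 2)
hidden_layer_in = side * side * channels
print(side, hidden_layer_in)  # 13 21632
```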
Now, here is the model in PyTorch:
from torch import nn\n\nclass BasicModel(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv = nn.Conv2d(1, channels, kernel_size)\n self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)\n self.maxpool = nn.MaxPool2d(2)\n self.linear_2 = nn.Linear(hidden_layer_out, output_size)\n\n def forward(self, x):\n x = self.conv(x)\n x = nn.functional.relu(x)\n x = self.maxpool(x)\n x = x.flatten(start_dim=1)\n x = self.linear_1(x)\n x = nn.functional.relu(x)\n x = self.linear_2(x)\n # normalize over the class dimension (dim=1), not over the batch\n return nn.functional.softmax(x, dim=1)\n
And here is how we could implement the same model in Refiners:
import torch\nimport refiners.fluxion.layers as fl\n\nclass BasicModel(fl.Chain):\n def __init__(self):\n super().__init__(\n fl.Conv2d(1, channels, kernel_size),\n fl.ReLU(),\n fl.MaxPool2d(2),\n fl.Flatten(start_dim=1),\n fl.Linear(hidden_layer_in, hidden_layer_out),\n fl.ReLU(),\n fl.Linear(hidden_layer_out, output_size),\n fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=1)),\n )\n
Note
We often use the namespace fl, short for fluxion, which is the name of the part of Refiners that implements basic layers.
At the time of writing, Refiners does not include a Softmax layer by default, but as you can see, you can easily call arbitrary code using fl.Lambda. Alternatively, if you just wanted to write Softmax(), you could implement it like this:
class Softmax(fl.Module):\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return torch.nn.functional.softmax(x, dim=1)\n
Note
Notice the type hints here. All of Refiners' codebase is typed, which makes it a pleasure to use if your downstream code is typed too.
"},{"location":"concepts/chain/#inspecting-and-manipulating","title":"Inspecting and manipulating","text":"Let us instantiate the BasicModel
we just defined and inspect its representation in a Python REPL:
>>> m = BasicModel()\n>>> m\n(CHAIN) BasicModel()\n \u251c\u2500\u2500 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n \u251c\u2500\u2500 ReLU() #1\n \u251c\u2500\u2500 MaxPool2d(kernel_size=2, stride=2)\n \u251c\u2500\u2500 Flatten(start_dim=1)\n \u251c\u2500\u2500 Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1\n \u251c\u2500\u2500 ReLU() #2\n \u251c\u2500\u2500 Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2\n \u2514\u2500\u2500 Softmax()\n
The children of a Chain are stored in a dictionary and can be accessed by name or index. When layers of the same type appear in the Chain, distinct suffixed keys are automatically generated.
>>> m[0]\nConv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n>>> m.Conv2d\nConv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n>>> m[6]\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n>>> m.Linear_2\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n
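The suffixing rule visible above (Conv2d keeps its plain name, while the two Linear layers become Linear_1 and Linear_2) can be mimicked in a few lines of plain Python. This is a sketch of the observable behavior only, assuming a type gets a 1-based suffix exactly when it occurs more than once; make_keys is a hypothetical helper, not Refiners' implementation:

```python
from collections import Counter


def make_keys(type_names: list[str]) -> list[str]:
    # Count how many times each layer type appears in the chain.
    totals = Counter(type_names)
    seen = Counter()
    keys = []
    for name in type_names:
        if totals[name] == 1:
            # Unique types keep their bare class name.
            keys.append(name)
        else:
            # Repeated types get a 1-based suffix, in order of appearance.
            seen[name] += 1
            keys.append(f'{name}_{seen[name]}')
    return keys


layers = ['Conv2d', 'ReLU', 'MaxPool2d', 'Flatten', 'Linear', 'ReLU', 'Linear', 'Softmax']
# A dict preserves insertion order, so position still works for index access.
children = dict(zip(make_keys(layers), layers))
print(list(children))
```

Because the dictionary keeps insertion order, the seventh key (index 6) is Linear_2, which is why m[6] and m.Linear_2 above return the same layer.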
The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by wrapping each layer of the convnet in a subchain. Here is how I could do it:
class ConvLayer(fl.Chain):\n pass\n\nclass HiddenLayer(fl.Chain):\n pass\n\nclass OutputLayer(fl.Chain):\n pass\n\nm.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))\nm.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))\nm.append(OutputLayer(m.pop(2), m.pop(2)))\n
Did it work? Let's see:
>>> m\n(CHAIN) BasicModel()\n \u251c\u2500\u2500 (CHAIN) ConvLayer()\n \u2502 \u251c\u2500\u2500 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\n \u2502 \u251c\u2500\u2500 ReLU()\n \u2502 \u2514\u2500\u2500 MaxPool2d(kernel_size=2, stride=2)\n \u251c\u2500\u2500 (CHAIN) HiddenLayer()\n \u2502 \u251c\u2500\u2500 Flatten(start_dim=1)\n \u2502 \u251c\u2500\u2500 Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)\n \u2502 \u2514\u2500\u2500 ReLU()\n \u2514\u2500\u2500 (CHAIN) OutputLayer()\n \u251c\u2500\u2500 Linear(in_features=200, out_features=10, device=cpu, dtype=float32)\n \u2514\u2500\u2500 Softmax()\n
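The regrouping snippet relies on Python evaluating call arguments from left to right: each m.pop(0) removes the current first layer, so three consecutive calls collect the first three layers in their original order. The same bookkeeping on a plain list of layer names (a stand-in for the Chain, with insert(1, ...) playing the role of insert_after_type because ConvLayer sits at index 0) shows where each layer ends up:

```python
def regroup(layers: list[str]) -> list:
    # Work on a copy so the caller's list is left untouched.
    m = list(layers)
    # Arguments are evaluated left to right, so the three pop(0) calls
    # remove the first three layers in their original order.
    m.insert(0, ('ConvLayer', [m.pop(0), m.pop(0), m.pop(0)]))
    # ConvLayer is now at index 0, so inserting at 1 places HiddenLayer right after it.
    m.insert(1, ('HiddenLayer', [m.pop(1), m.pop(1), m.pop(1)]))
    m.append(('OutputLayer', [m.pop(2), m.pop(2)]))
    return m


grouped = regroup(['Conv2d', 'ReLU', 'MaxPool2d', 'Flatten', 'Linear', 'ReLU', 'Linear', 'Softmax'])
for name, members in grouped:
    print(name, members)
```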
Note
Organizing models like this is actually a good idea: it makes them easier to understand and adapt.
"},{"location":"concepts/chain/#accessing-and-iterating","title":"Accessing and iterating","text":"There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named walk
. However, most of the time, you can use simpler helpers. For instance, to iterate over all the modules in the tree that hold weights (the Conv2d and the Linears), we can just do:
for x in m.layers(fl.WeightedModule):\n print(x)\n
It prints:
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)\nLinear(in_features=21632, out_features=200, device=cpu, dtype=float32)\nLinear(in_features=200, out_features=10, device=cpu, dtype=float32)\n
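To give an intuition of what such an iterator does, here is a toy, self-contained tree in plain Python with a depth-first walk that yields (module, parent) pairs and a layers helper built on top of it. The class names mirror Refiners' but are hypothetical stand-ins; the real walk may accept other kinds of predicates and differ in its details:

```python
from collections.abc import Iterator


class Module:
    pass


class WeightedModule(Module):
    pass


class Conv2d(WeightedModule):
    pass


class Linear(WeightedModule):
    pass


class ReLU(Module):
    pass


class Chain(Module):
    def __init__(self, *children: Module) -> None:
        self.children = list(children)

    def walk(self, layer_type: type) -> Iterator[tuple[Module, 'Chain']]:
        # Depth-first traversal yielding each matching module with its parent.
        for child in self.children:
            if isinstance(child, layer_type):
                yield child, self
            if isinstance(child, Chain):
                yield from child.walk(layer_type)

    def layers(self, layer_type: type) -> Iterator[Module]:
        # Same traversal, with the parents dropped.
        for module, _parent in self.walk(layer_type):
            yield module


conv, lin1, lin2 = Conv2d(), Linear(), Linear()
model = Chain(Chain(conv, ReLU()), Chain(lin1, ReLU()), Chain(lin2))
for layer in model.layers(WeightedModule):
    print(type(layer).__name__)  # Conv2d, then Linear, then Linear
```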
"},{"location":"concepts/context/","title":"Context","text":""},{"location":"concepts/context/#motivation-avoiding-props-drilling","title":"Motivation: avoiding \"props drilling\"","text":"Chains are a powerful tool to represent computational graphs, but they are not always convenient.
Many adapters add extra inputs to the model. For instance, ControlNet and T2I-Adapter require a guide (condition image), inpainting adapters require a mask, Latent Consistency Models use a condition scale embedding, and other adapters may leverage time or context embeddings. Those inputs are often passed by the user in a high-level format (numbers, text...) and converted to embeddings by the model before being consumed in downstream layers.
Managing these extra inputs is inconvenient. Typically, you would add them to the inputs and outputs of each layer somehow. But if you add them as channels or concatenate them, you get composability issues, and if you pass them as extra arguments, you end up dealing with them in layers that should not be concerned with their presence.
The same problem of having to pass extra contextual information up and down a tree exists in other fields, in particular in JavaScript frameworks that use a Virtual DOM, such as React, where it is called \"props drilling\". To make it easier to manage, React introduced the Context API, and we went with a similar idea in Refiners.
"},{"location":"concepts/context/#a-simple-example","title":"A simple example","text":"Here is an example of how contexts work:
import refiners.fluxion.layers as fl\nfrom refiners.fluxion.context import Contexts\n\nclass MyProvider(fl.Chain):\n def init_context(self) -> Contexts:\n return {\"my context\": {\"my key\": None}}\n\nm = MyProvider(\n fl.Chain(\n fl.Sum(\n fl.UseContext(\"my context\", \"my key\"),\n fl.Lambda(lambda: 2),\n ),\n fl.SetContext(\"my context\", \"my key\"),\n ),\n fl.Chain(\n fl.UseContext(\"my context\", \"my key\"),\n fl.Lambda(print),\n ),\n)\n\nm.set_context(\"my context\", {\"my key\": 4})\nm() # prints 6\n
As you can see, to use the context, you define it by subclassing any Chain and defining init_context. You can set the context with the set_context method or the SetContext layer, and you can access it anywhere down the provider's tree with UseContext.
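Under the hood, the idea is that the provider owns a single dictionary of contexts that every descendant can read from and write to. The plain-Python sketch below reproduces the example above with a flat sequence of steps instead of nested Chains; Provider, use_context, set_context and add_two are hypothetical stand-ins for the Refiners layers, not their actual implementation:

```python
from typing import Any, Callable

Contexts = dict[str, dict[str, Any]]
Step = Callable[[Any, Contexts], Any]


class Provider:
    # Toy stand-in for a Chain subclass that defines init_context.
    def __init__(self, *steps: Step) -> None:
        self.steps = steps
        self.contexts: Contexts = {'my context': {'my key': None}}

    def set_context(self, name: str, values: dict[str, Any]) -> None:
        self.contexts[name].update(values)

    def __call__(self, x: Any = None) -> Any:
        # Every step receives the same contexts dict: no props drilling needed.
        for step in self.steps:
            x = step(x, self.contexts)
        return x


def use_context(name: str, key: str) -> Step:
    # Ignore the incoming value and read from the shared context instead.
    return lambda x, ctx: ctx[name][key]


def set_context(name: str, key: str) -> Step:
    def step(x: Any, ctx: Contexts) -> Any:
        ctx[name][key] = x
        return x  # pass the input through unchanged
    return step


def add_two(x: Any, ctx: Contexts) -> Any:
    return x + 2


m = Provider(
    use_context('my context', 'my key'),
    add_two,
    set_context('my context', 'my key'),
    use_context('my context', 'my key'),
)
m.set_context('my context', {'my key': 4})
print(m())  # 6
```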
Another use of the context is simplifying complex models, in particular those with long-range nested skip connections.
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:
square = lambda: fl.Lambda(lambda x: x ** 2)\nsqrt = lambda: fl.Lambda(lambda x: x ** 0.5)\n\nm1 = fl.Chain(\n fl.Residual(\n square(),\n fl.Residual(\n square(),\n fl.Residual(\n square(),\n ),\n sqrt(),\n ),\n sqrt(),\n ),\n sqrt(),\n)\n
You can see two problems here:
Let us solve those two issues using the context:
from refiners.fluxion.context import Contexts\n\nclass MyModel(fl.Chain):\n def init_context(self) -> Contexts:\n return {\"mymodel\": {\"residuals\": []}}\n\ndef push_residual():\n return fl.SetContext(\n \"mymodel\",\n \"residuals\",\n callback=lambda l, x: l.append(x),\n )\n\nclass ApplyResidual(fl.Sum):\n def __init__(self):\n super().__init__(\n fl.Identity(),\n fl.UseContext(\"mymodel\", \"residuals\").compose(lambda x: x.pop()),\n )\n\nsquares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))\nsqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))\nm2 = MyModel(squares, sqrts)\n
As you can see, even though squares and sqrts are completely independent chains, they can access the same context because they are nested under the same provider.
Does it work?
>>> m1(2.0)\n2.5547711633552384\n>>> m2(2.0)\n2.5547711633552384\n
Yes!\u2728
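Stripped of the layers, the context version is simply a stack: the down path pushes each intermediate value and the up path pops them in reverse order, which is exactly the pairing the nested residuals encode. A plain-Python sketch of both formulations (our own translation, using math.sqrt in place of x ** 0.5) makes the equivalence easy to check:

```python
import math


def m1(x: float) -> float:
    # Explicit nesting mirrors the Residual tree: every skip adds one level.
    r3 = lambda z: z + z ** 2
    r2 = lambda y: y + math.sqrt(r3(y ** 2))
    r1 = lambda w: w + math.sqrt(r2(w ** 2))
    return math.sqrt(r1(x))


def m2(x: float) -> float:
    residuals = []
    # Down path: push each input on the stack, then square it.
    for _ in range(3):
        residuals.append(x)
        x = x ** 2
    # Up path: add back the most recent residual, then take the square root.
    for _ in range(3):
        x = x + residuals.pop()
        x = math.sqrt(x)
    return x


print(m1(2.0), m2(2.0))
```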
"},{"location":"concepts/adapter/","title":"Adapter","text":"Adapters are the final and most high-level abstraction in Refiners. They are the concept of adaptation turned into code.
An Adapter is generally a Chain that replaces a Module (the target) in another Chain (the parent). Typically the target will become a child of the adapter.
In code terms, Adapter is a generic mixin. Adapters subclass type(parent) and Adapter[type(target)]. For instance, if you adapt a Conv2d in a Sum, the definition of the Adapter could look like:
class MyAdapter(fl.Sum, fl.Adapter[fl.Conv2d]):\n ...\n
"},{"location":"concepts/adapter/#a-simple-example-adapting-a-linear","title":"A simple example: adapting a Linear","text":"Let us take a simple example to see how this works. Consider this model:
In code, it could look like this:
my_model = MyModel(fl.Chain(fl.Linear(), fl.Chain(...)))\n
Suppose we want to adapt the Linear to sum its output with the result of another chain. We can define and initialize an adapter like this:
class MyAdapter(fl.Sum, fl.Adapter[fl.Linear]):\n def __init__(self, target: fl.Linear) -> None:\n with self.setup_adapter(target):\n super().__init__(fl.Chain(...), target)\n\n# Find the target and its parent in the chain.\n# For simplicity let us assume it is the only Linear.\nfor target, parent in my_model.walk(fl.Linear):\n break\n\nadapter = MyAdapter(target)\n
The result is now this:
Note that the original chain is unmodified. You can still run inference on it as if the adapter did not exist. To use the adapter, you must inject it into the chain:
adapter.inject(parent)\n
The result will be:
Now if you run inference, it will go through the Adapter. You can go back to the previous situation by calling adapter.eject().
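The inject and eject mechanics boil down to swapping one child in the parent's list of children and remembering where it was. The toy classes below illustrate this in plain Python; Module, Chain, Scale and SumAdapter are hypothetical stand-ins and do not reproduce Refiners' actual Adapter code:

```python
class Module:
    def __call__(self, x: float) -> float:
        return self.forward(x)

    def forward(self, x: float) -> float:
        raise NotImplementedError


class Chain(Module):
    def __init__(self, *children: Module) -> None:
        self.children = list(children)

    def forward(self, x: float) -> float:
        # Run the children sequentially.
        for child in self.children:
            x = child(x)
        return x


class Scale(Module):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def forward(self, x: float) -> float:
        return x * self.factor


class SumAdapter(Chain):
    # Toy adapter: output = target(x) + branch(x).
    def __init__(self, target: Module, branch: Module) -> None:
        super().__init__(target, branch)
        self.target = target
        self.parent = None  # set while the adapter is injected

    def forward(self, x: float) -> float:
        return sum(child(x) for child in self.children)

    def inject(self, parent: Chain) -> None:
        # Swap the target for the adapter in the parent's children.
        index = parent.children.index(self.target)
        parent.children[index] = self
        self.parent = parent

    def eject(self) -> None:
        # Put the original target back where the adapter was.
        assert self.parent is not None, 'adapter is not injected'
        index = self.parent.children.index(self)
        self.parent.children[index] = self.target
        self.parent = None


model = Chain(Scale(3.0), Scale(2.0))
adapter = SumAdapter(model.children[0], Scale(1.0))
before = model(1.0)  # 6.0: creating the adapter does not modify the model
adapter.inject(model)
injected = model(1.0)  # 8.0: (3.0 + 1.0) * 2.0
adapter.eject()
after = model(1.0)  # 6.0 again
```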
We are not limited to adapting base modules: we can also adapt Chains.
Starting from the same model as earlier, let us assume we want to:
This Adapter will perform a structural_copy of part of its target, which means it will duplicate all Chain nodes but keep pointers to the same WeightedModules, and hence will not use extra GPU memory.
class MyAdapter(fl.Chain, fl.Adapter[fl.Chain]):\n def __init__(self, target: fl.Chain) -> None:\n with self.setup_adapter(target):\n new_b = fl.Chain(target, target.Chain.Chain_2.structural_copy())\n super().__init__(new_b, target.Linear)\n\nadapter = MyAdapter(my_model.Chain_1) # Chain A in the diagram\n
We end up with this:
We can now inject it into the original graph. This time, we do not even need to pass the parent, since Chains know their parents.
adapter.inject()\n
We obtain this:
Note that the Linear is in the Chain twice now, but that does not matter as long as you really want it to be the same Linear layer with the same weights.
As before, we can eject the adapter to go back to the original model.
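The structural_copy used by this adapter can be pictured as a recursive copy that recreates every Chain node but reuses the leaf modules as-is. The plain-Python sketch below illustrates that idea with hypothetical toy classes; it is not Refiners' implementation:

```python
class Chain:
    def __init__(self, *children):
        self.children = list(children)


class Linear:
    # Stands in for a WeightedModule: its weights live on the instance.
    def __init__(self, weight):
        self.weight = weight


def structural_copy(node):
    # Recreate every Chain node, but return leaf modules unchanged,
    # so the copy shares weights with the original instead of duplicating them.
    if isinstance(node, Chain):
        return Chain(*(structural_copy(child) for child in node.children))
    return node


original = Chain(Linear([1.0, 2.0]), Chain(Linear([3.0])))
copied = structural_copy(original)
# Editing the copied tree structure leaves the original tree untouched.
copied.children.append(Linear([4.0]))
```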
"},{"location":"concepts/adapter/#a-real-world-example-loraadapter","title":"A real-world example: LoraAdapter","text":"A popular example of adaptation is LoRA. You can check out how we implement it in Refiners.
"},{"location":"concepts/adapter/#higher-level-adapters","title":"Higher-level adapters","text":"If you use Refiners, you will find Adapters that go beyond the simple definition given at the top of this page. Some adapters inject multiple smaller adapters in models, others implement helper methods to be used by their caller...
From a bird's eye view, you can just consider Adapters as things you inject into models to adapt them, and that can be ejected to return the model to its original state. You will get a better feel for what an adapter is and how to leverage adapters by actually using the framework.
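As a reminder of what LoRA, the real-world adapter mentioned above, computes: it keeps a frozen weight W and adds a scaled low-rank update B A x, where A projects down to a small rank and B projects back up. In the original LoRA formulation, B starts at zero, so the adapted layer initially behaves exactly like the base one. The plain-Python sketch below illustrates the math only; LoraLinear and matvec are our own names, the random init for A is an arbitrary choice, and this is not Refiners' LoraAdapter code:

```python
import random


def matvec(matrix: list[list[float]], vector: list[float]) -> list[float]:
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


class LoraLinear:
    # y = W x + scale * B (A x), with W frozen and A, B trainable.
    def __init__(self, weight: list[list[float]], rank: int, scale: float = 1.0, seed: int = 0) -> None:
        rng = random.Random(seed)
        out_features, in_features = len(weight), len(weight[0])
        self.weight = weight  # frozen base weight
        # A (down projection): rank x in_features, small random init.
        self.down = [[rng.gauss(0.0, 1.0 / rank) for _ in range(in_features)] for _ in range(rank)]
        # B (up projection): out_features x rank, zero init.
        self.up = [[0.0] * rank for _ in range(out_features)]
        self.scale = scale

    def __call__(self, x: list[float]) -> list[float]:
        base = matvec(self.weight, x)
        delta = matvec(self.up, matvec(self.down, x))
        return [b + self.scale * d for b, d in zip(base, delta)]


layer = LoraLinear([[1.0, 0.0], [0.0, 1.0]], rank=1)
print(layer([2.0, 3.0]))  # [2.0, 3.0]: B is zero, so the update is a no-op
```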
"},{"location":"getting-started/advanced/","title":"Advanced usage","text":""},{"location":"getting-started/advanced/#using-other-package-managers-pip-poetry","title":"Using other package managers (pip, Poetry...)","text":"We use Rye to maintain and release Refiners but it conforms to the standard Python packaging guidelines and can be used with other package managers. Please refer to their respective documentation to figure out how to install a package from Git if you intend to use the development branch, as well as how to specify features.
"},{"location":"getting-started/advanced/#using-stable-releases-from-pypi","title":"Using stable releases from PyPI","text":"Although we recommend using our development branch, we do publish more stable releases to PyPI and you are welcome to use them in your project. They are also available directly on the GitHub releases page. However, beware that the format of weights can be different from the current state of the development branch.
"},{"location":"getting-started/recommended/","title":"Recommended usage","text":"Refiners is still a young project and development is active, so to use the latest and greatest version of the framework we recommend you use the main
branch from our development repository.
Moreover, we recommend using Rye which simplifies several things related to Python package management, so start by following the instructions to install it on your system.
"},{"location":"getting-started/recommended/#installing","title":"Installing","text":"To try Refiners, clone the GitHub repository and install it with all optional features:
git clone git@github.com:finegrain-ai/refiners.git\ncd refiners\nrye sync --all-features\n
"},{"location":"getting-started/recommended/#converting-weights","title":"Converting weights","text":"The format of state dicts used by Refiners is custom, so to use pretrained models you will need to convert weights. We provide conversion tools and pre-converted weights on our HuggingFace organization for popular models.
For instance, to use the autoencoder from Stable Diffusion 1.5:
"},{"location":"getting-started/recommended/#use-pre-converted-weights","title":"Use pre-converted weights","text":"from huggingface_hub import hf_hub_download\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder\n\n# download the pre-converted weights from the hub\nsafetensors_path = hf_hub_download(\n repo_id=\"refiners/sd15.autoencoder\",\n filename=\"model.safetensors\",\n revision=\"9ce6af42e21fce64d74b1cab57a65aea82fd40ea\", # optional\n)\n\n# initialize the model\nmodel = SD1Autoencoder()\n\n# load the pre-converted weights\nmodel.load_from_safetensors(safetensors_path)\n
"},{"location":"getting-started/recommended/#convert-the-weights-yourself","title":"Convert the weights yourself","text":"If you want to convert the weights yourself, you can use the conversion tools we provide.
from refiners.conversion import autoencoder_sd15\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder\n\n# This function will:\n# - download the original weights from the internet, and save them to disk at a known location\n# (e.g. tests/weights/stable-diffusion-v1-5/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors)\n# - convert them to the refiners format, and save them to disk at a known location\n# (e.g. tests/weights/refiners/sd15.autoencoder/model.safetensors)\nautoencoder_sd15.runwayml.convert()\n\n# get the path to the converted weights\nsafetensors_path = autoencoder_sd15.runwayml.converted.local_path\n\n# initialize the model\nmodel = SD1Autoencoder()\n\n# load the converted weights\nmodel.load_from_safetensors(safetensors_path)\n
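Conceptually, converting a checkpoint for a model with a different module layout largely comes down to mapping each source tensor to its new key. The sketch below shows that idea on a plain dict with hypothetical key names and list values in place of tensors; it is a generic illustration, not Refiners' conversion code, which handles real checkpoints and more complex cases:

```python
from typing import Any


def convert_state_dict(source: dict[str, Any], key_map: dict[str, str]) -> dict[str, Any]:
    # Refuse silent partial conversions: every source key must be mapped, and vice versa.
    unmapped = sorted(set(source) - set(key_map))
    missing = sorted(set(key_map) - set(source))
    if unmapped or missing:
        raise KeyError(f'unmapped source keys: {unmapped}, missing source keys: {missing}')
    return {key_map[key]: value for key, value in source.items()}


# Hypothetical keys: an original-style name mapped to a Chain-style path.
source = {'encoder.conv_in.weight': [0.1, 0.2], 'encoder.conv_in.bias': [0.0]}
key_map = {
    'encoder.conv_in.weight': 'Encoder.Conv2d.weight',
    'encoder.conv_in.bias': 'Encoder.Conv2d.bias',
}
converted = convert_state_dict(source, key_map)
print(sorted(converted))
```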
Note
If you need to convert more model weights, or all of them, check out the refiners.conversion module.
Warning
Converting all the weights requires a lot of disk space and CPU time, so be prepared. Currently, downloading all the original weights takes around 100GB of disk space, and converting them all takes around 70GB of disk space.
Warning
Some conversion scripts may also require quite a bit of RAM, since they load the entire weights into memory. Around 16GB of RAM should be enough for most models, but some models may require more.
"},{"location":"getting-started/recommended/#testing-the-conversion","title":"Testing the conversion","text":"To quickly check that the weights you got from the hub or converted yourself are correct, you can run the following snippet:
from PIL import Image\nfrom refiners.fluxion.utils import no_grad\n\nimage = Image.open(\"input.png\")\n\nwith no_grad():\n latents = model.image_to_latents(image)\n decoded = model.latents_to_image(latents)\n\ndecoded.save(\"output.png\")\n
Inspect output.png: if the converted weights are correct, it should be similar to input.png (with a few small differences).
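To put a number on how similar the two images are, you can compute a simple metric such as PSNR (peak signal-to-noise ratio). This is our own suggestion rather than part of the Refiners workflow, sketched in plain Python over flat sequences of pixel values:

```python
import math
from collections.abc import Sequence


def psnr(reference: Sequence[float], decoded: Sequence[float], max_value: float = 255.0) -> float:
    # Peak signal-to-noise ratio in dB: higher means more similar, inf means identical.
    if len(reference) != len(decoded) or len(reference) == 0:
        raise ValueError('inputs must be non-empty and have the same length')
    mse = sum((a - b) ** 2 for a, b in zip(reference, decoded)) / len(reference)
    if mse == 0:
        return math.inf
    return 10 * math.log10(max_value ** 2 / mse)
```

With Pillow, for example, psnr(list(Image.open('input.png').convert('L').getdata()), list(Image.open('output.png').convert('L').getdata())) compares grayscale pixel values, provided both images have the same size. What counts as a good enough score is a judgment call; the sketch does not prescribe a threshold.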
So far, you have used Refiners as a standalone package. If you want to create your own project with Refiners as a dependency, here is how you can proceed:
rye init --py \"3.11\" myproject\ncd myproject\nrye add refiners@git+https://github.com/finegrain-ai/refiners\nrye sync\n
If you intend to use Refiners for training, you can install the training feature:
rye add refiners[training]@git+https://github.com/finegrain-ai/refiners\n
Similarly, if you need to use the conversion tools we provide, you can install the conversion feature:
rye add refiners[conversion]@git+https://github.com/finegrain-ai/refiners\n
Note
You can install multiple features at once by separating them with a comma:
rye add refiners[training,conversion]@git+https://github.com/finegrain-ai/refiners\n
"},{"location":"getting-started/recommended/#whats-next","title":"What's next?","text":"We suggest you check out the guides section to dive into the usage of Refiners, of the Key Concepts section for a better understanding of how the framework works.
"},{"location":"guides/adapting_sdxl/","title":"Adapting Stable Diffusion XL","text":"Stable Diffusion XL (SDXL) is a very popular text-to-image open source foundation model. This guide will show you how to boost its capabilities with Refiners, using iconic adapters the framework supports out-of-the-box (i.e. without the need for tedious prompt engineering). We'll follow a step by step approach, progressively increasing the number of adapters involved to showcase how simple adapter composition is using Refiners. Our use case will be the generation of an image with \"a futuristic castle surrounded by a forest, mountains in the background\".
"},{"location":"guides/adapting_sdxl/#baseline","title":"Baseline","text":"Make sure that Refiners is installed in your local environment (see Getting started), and that you have access to a decent GPU (~24 GB VRAM should be enough).
Before diving into the adapters themselves, let's establish a baseline by simply prompting SDXL with Refiners.
Reminder
A StableDiffusion model is composed of three modules:
Start by instantiating a StableDiffusion_XL model and loading the weights.
import torch\nfrom huggingface_hub import hf_hub_download\n\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n
Then, define the inference parameters by setting the appropriate prompt, seed and number of inference steps:
# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n
Finally, define and run the inference process:
from refiners.fluxion.utils import manual_seed, no_grad\nfrom tqdm import tqdm\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024, here we use a higher resolution\n x = sdxl.init_latents((2048, 2048))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"vanilla_sdxl.png\")\n
Here is the complete end-to-end code:
import torch\nfrom huggingface_hub import hf_hub_download\nfrom tqdm import tqdm\n\nfrom refiners.fluxion.utils import manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n\n# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024, here we use a higher resolution\n x = sdxl.init_latents((2048, 2048))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"vanilla_sdxl.png\")\n
The resulting image should look like this:
Generated image of a castle using default SDXL weights.
It is not really what we prompted the model for, unfortunately. To get a more futuristic-looking castle, you can either go for tedious prompt engineering, or use a pretrained LoRA tailored to our use case, like the Sci-fi Environments LoRA available on Civitai. We'll now show you how the LoRA option works with Refiners.
"},{"location":"guides/adapting_sdxl/#single-lora","title":"Single LoRA","text":"Let's use the Sci-fi Environments LoRA. LoRas don't need to be converted, all you have to do is download the safetensors file from the internet.
You can easily download the LoRA by doing:
curl -L -o scifi.safetensors 'https://civitai.com/api/download/models/140624?type=Model&format=SafeTensor'\n
Inject the LoRA into SDXL using SDLoraManager right after instantiating StableDiffusion_XL:
from refiners.fluxion.utils import load_from_safetensors\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\n\n# Load LoRA weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"scifi.safetensors\")\nmanager.add_loras(\"scifi\", tensors=scifi_lora_weights)\n
Here is the complete end-to-end code:
import torch\nfrom huggingface_hub import hf_hub_download\nfrom tqdm import tqdm\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n\n# add Sci-Fi LoRA\nmanager = SDLoraManager(sdxl)\nscifi_lora_weights = load_from_safetensors(\"scifi.safetensors\")\nmanager.add_loras(\"scifi\", tensors=scifi_lora_weights)\n\n# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024, here we use a higher resolution\n x = sdxl.init_latents((2048, 2048))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n\n # decode the latents to an image\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"scifi_sdxl.png\")\n
You should get something like this. Pretty neat, isn't it?
Generated image of a castle in sci-fi style."},{"location":"guides/adapting_sdxl/#multiple-loras","title":"Multiple LoRAs","text":"Continuing with our futuristic castle example, we might want to turn it into pixel art, for instance.
Again, we could either try some tedious prompt engineering, or instead use another LoRA found on the web, such as the Pixel Art LoRA available on Civitai.
You can easily download the LoRA by doing:
curl -L -o pixelart.safetensors 'https://civitai.com/api/download/models/135931?type=Model&format=SafeTensor'\n
Injecting a second LoRA into the current SDXL model is dead simple, as SDLoraManager allows loading multiple LoRAs:
# load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"scifi.safetensors\"))\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixelart.safetensors\"))\n
Adapters such as LoRAs also have a scale that (roughly) quantifies their effect. Refiners allows setting a different scale for each Adapter, letting the user balance their respective effects:
# load LoRAs weights from disk and inject them into target\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"scifi.safetensors\"), scale=1.0)\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixelart.safetensors\"), scale=1.4)\n
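Because each LoRA adds its own scaled low-rank term to the same frozen weight, scales combine additively: with two LoRAs, the adapted layer computes W x + s1 B1 A1 x + s2 B2 A2 x. The plain-Python sketch below (our own helper names and toy matrices, not the SDLoraManager internals) checks that applying the scaled updates one by one matches a single merged weight W + s1 B1 A1 + s2 B2 A2:

```python
def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m), with plain nested lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def lora_forward(weight, loras, x):
    # Base output plus each LoRA's scaled low-rank contribution.
    y = matvec(weight, x)
    for up, down, scale in loras:
        delta = matvec(up, matvec(down, x))
        y = [a + scale * d for a, d in zip(y, delta)]
    return y


def merged_weight(weight, loras):
    # W_eff = W + sum of scale_i * (B_i @ A_i); loras holds (up, down, scale) triples.
    result = [row[:] for row in weight]
    for up, down, scale in loras:
        for i, row in enumerate(matmul(up, down)):
            for j, value in enumerate(row):
                result[i][j] += scale * value
    return result


W = [[1.0, 0.0], [0.0, 1.0]]
scifi = ([[1.0], [0.0]], [[1.0, 1.0]], 1.0)  # (B, A, scale), rank 1
pixel = ([[0.0], [1.0]], [[0.5, -0.5]], 1.4)
x = [2.0, 3.0]
y_separate = lora_forward(W, [scifi, pixel], x)
y_merged = matvec(merged_weight(W, [scifi, pixel]), x)
print(y_separate, y_merged)
```

Merging is a common way to avoid the extra matrix products at inference time, at the cost of no longer being able to change the scales without recomputing the merged weight.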
Here is the complete end-to-end code:
import torch\nfrom huggingface_hub import hf_hub_download\nfrom tqdm import tqdm\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n\n# add Sci-Fi and Pixel-Art LoRAs\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"scifi.safetensors\"), scale=1.0)\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixelart.safetensors\"), scale=1.4)\n\n# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024, here we use a higher resolution\n x = sdxl.init_latents((2048, 2048))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"scifi_pixel_sdxl.png\")\n
The results are looking great:
Generated image of a castle in sci-fi, pixel art style."},{"location":"guides/adapting_sdxl/#multiple-loras-ip-adapter","title":"Multiple LoRAs + IP-Adapter","text":"Refiners really shines when it comes to composing different Adapters to fully exploit the possibilities of foundation models.
For instance, IP-Adapter (covered in a previous blog post) is a common choice for practitioners who want to guide the diffusion process towards a specific prompt image.
In our example, we would like to guide the diffusion process to align with this image of the Neuschwanstein Castle:
Credits: Bayerische Schl\u00f6sserverwaltung, Anton Brandl
You can easily download the above image by doing:
curl -O https://refine.rs/guides/adapting_sdxl/german-castle.jpg\n
Instantiate an SDXLIPAdapter
targeting our sdxl.unet
, and inject it using a simple .inject()
call:
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter\n\n# load IP-Adapter\nip_adapter = SDXLIPAdapter(\n target=sdxl.unet,\n weights=load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.ip_adapter.plus\",\n filename=\"model.safetensors\",\n ),\n ),\n scale=1.0,\n fine_grained=True, # Use fine-grained IP-Adapter (i.e IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sd21.unclip.image_encoder\",\n filename=\"model.safetensors\",\n )\n)\nip_adapter.inject()\n
Then, at runtime, we simply compute the embedding of the image prompt through the ip_adapter
object, and set it by calling .set_clip_image_embedding()
:
from PIL import Image\nimage_prompt = Image.open(\"german-castle.jpg\")\n\nwith torch.no_grad():\n clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n# And start the diffusion process\n
Note
Be wary that composing Adapters (especially ones of different natures, such as LoRAs and IP-Adapter) can be tricky, as their respective effects can work against each other; our example below illustrates this. In the code below, we tuned the LoRA scales to 1.5
and 1.55
, respectively. We invite you to try different seeds and scales to find the best combination! The order in which you inject Adapters can also affect the final result.
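One way to follow this advice is to enumerate a small grid of seeds and LoRA scales and render one image per combination. The sketch below only builds the run plan with itertools.product. The actual rendering is left as a comment because it would reuse the end-to-end code below. The grid values, the SweepRun and plan_sweep names, and the file naming scheme are all arbitrary, hypothetical choices made for illustration.

```python
from itertools import product
from typing import NamedTuple


class SweepRun(NamedTuple):
    seed: int
    scifi_scale: float
    pixel_art_scale: float

    def filename(self) -> str:
        # one output file per combination, e.g. sweep_s42_1.5_1.55.png
        return f'sweep_s{self.seed}_{self.scifi_scale}_{self.pixel_art_scale}.png'


def plan_sweep(seeds: list[int], scifi_scales: list[float], pixel_art_scales: list[float]) -> list[SweepRun]:
    # Cartesian product of every seed with every pair of LoRA scales
    return [SweepRun(*combo) for combo in product(seeds, scifi_scales, pixel_art_scales)]


runs = plan_sweep(seeds=[42, 7], scifi_scales=[1.0, 1.5], pixel_art_scales=[1.4, 1.55])
for run in runs:
    # Hypothetical render step: apply run.seed and the two LoRA scales as in the
    # end-to-end code, run the denoising loop, then save to run.filename().
    print(run.filename())
```

Because the grid grows multiplicatively, keep it small: two seeds and two values per scale already mean eight full diffusion runs.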
import torch\nfrom huggingface_hub import hf_hub_download\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n\n# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n\n# add Sci-Fi and Pixel-Art LoRAs\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"scifi.safetensors\"), scale=1.5)\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixelart.safetensors\"), scale=1.55)\n\n# Instantiate the IP-Adapter\nip_adapter = SDXLIPAdapter(\n target=sdxl.unet,\n weights=load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.ip_adapter.plus\",\n filename=\"model.safetensors\",\n ),\n ),\n scale=1.0,\n fine_grained=True, # Use fine-grained IP-Adapter (i.e. IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sd21.unclip.image_encoder\",\n filename=\"model.safetensors\",\n )\n)\nip_adapter.inject()\n\n# load image prompt\nimage_prompt = Image.open(\"german-castle.jpg\")\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # compute image prompt embeddings\n clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024\n x = sdxl.init_latents((1024, 1024))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"scifi_pixel_IP_sdxl.png\")\n
The result looks convincing: we do get a pixel-art, futuristic-looking Neuschwanstein castle!
Generated image in sci-fi, pixel art style, using IP-Adapter."},{"location":"guides/adapting_sdxl/#multiple-loras-ip-adapter-t2i-adapter","title":"Multiple LoRAs + IP-Adapter + T2I-Adapter","text":"T2I-Adapters are a powerful class of Adapters designed to control the Text-to-Image (T2I) diffusion process with external control signals, such as Canny edges or pose estimation inputs. In this section, we will compose our previous example with the Depth-Zoe Adapter, providing a depth condition to the diffusion process using the following depth map as the input signal:
Input depth map of the initial castle image.
You can easily download the above image by doing:
curl -O https://refine.rs/guides/adapting_sdxl/zoe-depth-map-german-castle.png\n
Then, just inject it as usual:
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter\n\n# Load T2I-Adapter\nt2i_adapter = SDXLT2IAdapter(\n target=sdxl.unet,\n name=\"zoe-depth\",\n weights=load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.t2i_adapter.depth.zoe\",\n filename=\"model.safetensors\",\n ),\n ),\n scale=0.72,\n).inject()\n
Finally, at runtime, compute the condition features of the input depth map through the t2i_adapter
object, and set them by calling .set_condition_features()
:
from refiners.fluxion.utils import image_to_tensor, interpolate\n\nimage_depth_condition = Image.open(\"zoe-depth-map-german-castle.png\")\n\nwith torch.no_grad():\n condition = image_to_tensor(image_depth_condition.convert(\"RGB\"), device=sdxl.device, dtype=sdxl.dtype)\n # Spatial dimensions should be divisible by default downscale factor (=16 for T2IAdapter ConditionEncoder)\n condition = interpolate(condition, torch.Size((1024, 1024)))\n t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))\n
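The snippet above resizes the depth map to exactly 1024x1024, which matches the size of the generated image. If you target another resolution, the code comment says both sides must be divisible by 16. A small helper can round a requested size up to the nearest valid multiple before you resize. The helper snap_to_multiple is hypothetical and not part of Refiners.

```python
def snap_to_multiple(width: int, height: int, factor: int = 16) -> tuple[int, int]:
    # Round each side up to the next multiple of factor. 16 matches the downscale factor
    # mentioned in the snippet above; rounding up avoids cropping any content away.
    if factor <= 0:
        raise ValueError('factor must be positive')
    return ((width + factor - 1) // factor * factor, (height + factor - 1) // factor * factor)


print(snap_to_multiple(1000, 750))   # (1008, 752)
print(snap_to_multiple(1024, 1024))  # (1024, 1024)
```

The helper returns (width, height), as PIL does. When resizing the tensor, we assume the size is given in (height, width) order, as with torch.nn.functional.interpolate. In this guide the condition and the latents are both sized for 1024x1024, and we assume the two should stay in sync.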
Here is the entire end-to-end code:
import torch\nfrom huggingface_hub import hf_hub_download\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom refiners.fluxion.utils import image_to_tensor, interpolate, load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.lora import SDLoraManager\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter\n\n# instantiate SDXL model\nsdxl = StableDiffusion_XL(\n device=\"cuda\", # use GPU\n dtype=torch.float16 # use half-precision for memory efficiency\n)\n\n# Load the weights\nsdxl.clip_text_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.text_encoder\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.unet.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.unet\",\n filename=\"model.safetensors\",\n )\n)\nsdxl.lda.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.autoencoder_fp16fix\",\n filename=\"model.safetensors\",\n )\n)\n\n# hyperparameters\nseed = 42\nnum_inference_steps = 50\nprompt = \"a futuristic castle surrounded by a forest, mountains in the background\"\nsdxl.set_inference_steps(num_inference_steps, first_step=0)\n\n# enable self-attention guidance to enhance the quality of the generated images\nsag_scale = 0.75\nsdxl.set_self_attention_guidance(enable=True, scale=sag_scale)\n\n# add Sci-Fi and Pixel-Art LoRAs\nmanager = SDLoraManager(sdxl)\nmanager.add_loras(\"scifi-lora\", load_from_safetensors(\"scifi.safetensors\"), scale=1.5)\nmanager.add_loras(\"pixel-art-lora\", load_from_safetensors(\"pixelart.safetensors\"), scale=1.55)\n\n# Instantiate the IP-Adapter\nip_adapter = SDXLIPAdapter(\n target=sdxl.unet,\n weights=load_from_safetensors(\n hf_hub_download(\n 
repo_id=\"refiners/sdxl.ip_adapter.plus\",\n filename=\"model.safetensors\",\n ),\n ),\n scale=1.0,\n fine_grained=True, # Use fine-grained IP-Adapter (i.e IP-Adapter Plus)\n)\nip_adapter.clip_image_encoder.load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sd21.unclip.image_encoder\",\n filename=\"model.safetensors\",\n )\n)\nip_adapter.inject()\n\n# Load T2I-Adapter\nt2i_adapter = SDXLT2IAdapter(\n target=sdxl.unet,\n name=\"zoe-depth\",\n weights=load_from_safetensors(\n hf_hub_download(\n repo_id=\"refiners/sdxl.t2i_adapter.depth.zoe\",\n filename=\"model.safetensors\",\n ),\n ),\n scale=0.72,\n).inject()\n\n# load image prompt and image depth condition\nimage_prompt = Image.open(\"german-castle.jpg\")\nimage_depth_condition = Image.open(\"zoe-depth-map-german-castle.png\")\n\nwith no_grad(): # disable gradient calculation for memory-efficient inference\n # encode the text prompts to embeddings, and get the time_ids\n clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(\n text=prompt + \", best quality, high quality\",\n negative_text=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n )\n time_ids = sdxl.default_time_ids\n\n # compute and set image prompt embeddings\n clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))\n ip_adapter.set_clip_image_embedding(clip_image_embedding)\n\n # compute and set the T2I features\n condition = image_to_tensor(image_depth_condition.convert(\"RGB\"), device=sdxl.device, dtype=sdxl.dtype)\n condition = interpolate(condition, torch.Size((1024, 1024)))\n t2i_features = t2i_adapter.compute_condition_features(condition)\n t2i_adapter.set_condition_features(features=t2i_features)\n\n # seed the random number generator, for reproducibility\n manual_seed(seed)\n\n # SDXL typically generates 1024x1024\n x = sdxl.init_latents((1024, 1024))\n\n # diffusion denoising process\n for step in tqdm(sdxl.steps):\n x = sdxl(\n x,\n 
step=step,\n clip_text_embedding=clip_text_embedding,\n pooled_text_embedding=pooled_text_embedding,\n time_ids=time_ids,\n )\n predicted_image = sdxl.lda.latents_to_image(x)\n\npredicted_image.save(\"scifi_pixel_IP_T2I_sdxl.png\")\n
The results look convincing: the depth and proportions of the initial castle are more faithful, while preserving our futuristic, pixel-art style!
Generated image in sci-fi, pixel art style, using IP and T2I Adapters."},{"location":"guides/adapting_sdxl/#wrap-up","title":"Wrap up","text":"As you can see in this guide, composing Adapters on top of foundation models is pretty seamless in Refiners, allowing practitioners to quickly test out different combinations of Adapters for their needs. We encourage you to try out different ones, and even train some yourselves!
"},{"location":"guides/training_101/","title":"Training 101","text":"This guide will walk you through training a model using Refiners. We built the training_utils
module to provide a simple, flexible, statically type-safe interface.
We will use a simple model and a toy dataset for demonstration purposes. The model will be a simple autoencoder, and the dataset will be a synthetic dataset of rectangles of different sizes.
"},{"location":"guides/training_101/#pre-requisites","title":"Pre-requisites","text":"We recommend installing Refiners targeting a specific commit hash to avoid unexpected changes in the API. You also get the benefit of having a perfectly reproducible environment.
with rye (recommended):
rye add refiners[training] --git=https://github.com/finegrain-ai/refiners.git --rev=<insert-latest-commit-hash>\n
with pip:
pip install \"git+https://github.com/finegrain-ai/refiners.git@<insert-latest-commit-hash>#egg=refiners[training]\"\n
Let's start by building our autoencoder using Refiners.
from refiners.fluxion import layers as fl\n\n\nclass ConvBlock(fl.Chain):\n def __init__(self, in_channels: int, out_channels: int) -> None:\n super().__init__(\n fl.Conv2d(\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=3,\n padding=1,\n groups=min(in_channels, out_channels),\n ),\n fl.LayerNorm2d(out_channels),\n fl.SiLU(),\n fl.Conv2d(\n in_channels=out_channels,\n out_channels=out_channels,\n kernel_size=1,\n padding=0,\n ),\n fl.LayerNorm2d(out_channels),\n fl.SiLU(),\n )\n\n\nclass ResidualBlock(fl.Sum):\n def __init__(self, in_channels: int, out_channels: int) -> None:\n super().__init__(\n ConvBlock(in_channels=in_channels, out_channels=out_channels),\n fl.Conv2d(\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=3,\n padding=1,\n ),\n )\n\n\nclass Encoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n ResidualBlock(in_channels=1, out_channels=8),\n fl.Downsample(channels=8, scale_factor=2, register_shape=False),\n ResidualBlock(in_channels=8, out_channels=16),\n fl.Downsample(channels=16, scale_factor=2, register_shape=False),\n ResidualBlock(in_channels=16, out_channels=32),\n fl.Downsample(channels=32, scale_factor=2, register_shape=False),\n fl.Reshape(2048),\n fl.Linear(in_features=2048, out_features=256),\n fl.SiLU(),\n fl.Linear(in_features=256, out_features=256),\n )\n\n\nclass Decoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n fl.Linear(in_features=256, out_features=256),\n fl.SiLU(),\n fl.Linear(in_features=256, out_features=2048),\n fl.Reshape(32, 8, 8),\n ResidualBlock(in_channels=32, out_channels=32),\n ResidualBlock(in_channels=32, out_channels=32),\n fl.Upsample(channels=32, upsample_factor=2),\n ResidualBlock(in_channels=32, out_channels=16),\n ResidualBlock(in_channels=16, out_channels=16),\n fl.Upsample(channels=16, upsample_factor=2),\n ResidualBlock(in_channels=16, out_channels=8),\n ResidualBlock(in_channels=8, 
out_channels=8),\n fl.Upsample(channels=8, upsample_factor=2),\n ResidualBlock(in_channels=8, out_channels=8),\n ResidualBlock(in_channels=8, out_channels=1),\n fl.Sigmoid(),\n )\n\n\nclass Autoencoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n Encoder(),\n Decoder(),\n )\n\n @property\n def encoder(self) -> Encoder:\n return self.ensure_find(Encoder)\n\n @property\n def decoder(self) -> Decoder:\n return self.ensure_find(Decoder)\n
We now have a fully functional autoencoder that takes a single-channel 64x64 image and compresses it into a vector of size 256 (a 16x compression). The decoder then takes this vector and reconstructs the original image.
import torch\n\nautoencoder = Autoencoder()\n\nx = torch.randn(2, 1, 64, 64) # batch of 2 images\n\nz = autoencoder.encoder(x) # [2, 256]\n\nx_reconstructed = autoencoder.decoder(z) # [2, 1, 64, 64]\n
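The sizes hard-coded in the model follow from simple arithmetic, spelled out below in plain Python. Three Downsample layers with scale_factor=2 take 64x64 down to 8x8, and the last encoder ResidualBlock outputs 32 channels. The flattened feature vector therefore has 32 * 8 * 8 = 2048 entries, which explains fl.Reshape(2048) in the encoder and fl.Reshape(32, 8, 8) in the decoder. Squeezing 64 * 64 = 4096 pixels into 256 latent values gives the 16x compression ratio. We assume each Downsample halves height and width, which is consistent with those Reshape sizes.

```python
# Shape bookkeeping for the toy autoencoder (pure arithmetic, no Refiners needed).
image_size = 64
channels_per_stage = [8, 16, 32]  # output channels of the three encoder ResidualBlocks
latent_dim = 256

size = image_size
for _ in channels_per_stage:
    size //= 2  # each Downsample(scale_factor=2) halves height and width

flattened = channels_per_stage[-1] * size * size  # input size of the first encoder Linear
compression = (image_size * image_size) // latent_dim

print(size, flattened, compression)  # 8 2048 16
```

If you change image_size or the number of stages, update the two Reshape calls and the 2048-wide Linear layers together.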
"},{"location":"guides/training_101/#dataset","title":"Dataset","text":"We will use a synthetic dataset of rectangles of different sizes. The dataset will be generated on the fly using this simple function:
import random\nfrom typing import Generator\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import image_to_tensor\n\ndef generate_mask(size: int, seed: int | None = None) -> Generator[torch.Tensor, None, None]:\n \"\"\"Generate a tensor of a binary mask of size `size` using random rectangles.\"\"\"\n if seed is None:\n seed = random.randint(0, 2**32 - 1)\n random.seed(seed)\n\n while True:\n rectangle = Image.new(\n \"L\", (random.randint(1, size), random.randint(1, size)), color=255\n )\n mask = Image.new(\"L\", (size, size))\n mask.paste(\n rectangle,\n (\n random.randint(0, size - rectangle.width),\n random.randint(0, size - rectangle.height),\n ),\n )\n tensor = image_to_tensor(mask)\n\n if random.random() > 0.5:\n tensor = 1 - tensor\n\n yield tensor\n
To generate a mask, do:
from refiners.fluxion.utils import tensor_to_image\n\nmask = next(generate_mask(64, seed=42))\ntensor_to_image(mask).save(\"mask.png\")\n
Here are two examples of generated masks:
"},{"location":"guides/training_101/#trainer","title":"Trainer","text":"We will now create a Trainer class to handle the training loop. This class will manage the model, the optimizer, the loss function, and the dataset. It will also orchestrate the training loop and the evaluation loop.
But first, we need to define two things: the batch type, which represents a batch for the forward and backward passes, and the configuration associated with the trainer.
"},{"location":"guides/training_101/#batch","title":"Batch","text":"Our batches are composed of a single tensor representing the images. We will define a simple Batch
type to implement this.
from dataclasses import dataclass\n\n@dataclass\nclass Batch:\n image: torch.Tensor\n
"},{"location":"guides/training_101/#config","title":"Config","text":"We will now define the configuration for the autoencoder. It holds the configuration for the training loop, the optimizer, and the learning rate scheduler. It should inherit from refiners.training_utils.BaseConfig
and have the following mandatory attributes:
training (a TrainingConfig): The configuration for the training loop, including the duration of the training, the batch size, device, dtype, etc.
optimizer (an OptimizerConfig): The configuration for the optimizer, including the learning rate, weight decay, etc.
lr_scheduler (an LRSchedulerConfig): The configuration for the learning rate scheduler, including the scheduler type, parameters, etc.
Example:
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch\n\nclass AutoencoderConfig(BaseConfig):\n ...\n\ntraining = TrainingConfig(\n duration=Epoch(1000),\n device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n dtype=\"float32\"\n)\n\noptimizer = OptimizerConfig(\n optimizer=Optimizers.AdamW,\n learning_rate=1e-4,\n)\n\nlr_scheduler = LRSchedulerConfig(\n type=LRSchedulerType.CONSTANT_LR\n)\n\nconfig = AutoencoderConfig(\n training=training,\n optimizer=optimizer,\n lr_scheduler=lr_scheduler,\n)\n
"},{"location":"guides/training_101/#subclass","title":"Subclass","text":"We can now define the Trainer subclass. It should inherit from refiners.training_utils.Trainer
and implement the following methods:
create_data_iterable
: The Trainer
will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the compute_loss
method. Every time the iterable is exhausted, an epoch ends.
compute_loss
: This method should take a Batch and return the loss tensor.
Here is a simple implementation of the create_data_iterable
method. For this toy example, we will generate a simple list of Batch
objects containing random masks. Later you can replace this with torch.utils.data.DataLoader
or any other data loader with more complex features that support shuffling, parallel loading, etc.
from functools import cached_property\nfrom refiners.training_utils import Trainer\n\n\nclass AutoencoderConfig(BaseConfig):\n num_images: int = 2048\n batch_size: int = 32\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n def create_data_iterable(self) -> list[Batch]:\n dataset: list[Batch] = []\n generator = generate_mask(size=64)\n\n for _ in range(self.config.num_images // self.config.batch_size):\n masks = [next(generator) for _ in range(self.config.batch_size)]\n dataset.append(Batch(image=torch.cat(masks, dim=0)))\n\n return dataset\n\n def compute_loss(self, batch: Batch) -> torch.Tensor:\n raise NotImplementedError(\"We'll implement this later\")\n\n\ntrainer = AutoencoderTrainer(config)\n
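With the defaults above, create_data_iterable builds num_images // batch_size = 2048 // 32 = 64 batches. One epoch is therefore 64 calls to compute_loss, assuming one call per batch. With duration=Epoch(1000) from the earlier config, that is 64,000 calls over the whole run. Any remainder images would be silently dropped by the integer division. The plain-Python sketch below mirrors that behavior for any iterable, which can be handy when you later swap in a different data source. The chunk helper is hypothetical and not part of Refiners.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def chunk(items: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    # Yield consecutive full batches; like num_images // batch_size,
    # an incomplete final batch is dropped.
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


print(len(list(chunk(range(2048), 32))))  # 64 batches per epoch
print(len(list(chunk(range(100), 32))))   # 3 batches, the last 4 items are dropped
```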
"},{"location":"guides/training_101/#model-registration","title":"Model registration","text":"For the Trainer to be able to handle the model, we need to register it.
We need two things to do so:
Add a refiners.training_utils.ModelConfig
attribute to the Config named autoencoder.
Add a method to the Trainer subclass that builds and returns the model, and decorate it with the @register_model
decorator. This method should take the ModelConfig
as an argument. The Trainer's __init__
will register the models and add any parameters to the optimizer that have requires_grad
enabled.
After registering the model, the self.autoencoder
attribute will be available in the Trainer.
from torch.nn import functional as F\n\nfrom refiners.training_utils import ModelConfig, register_model\n\n\nclass AutoencoderModelConfig(ModelConfig):\n pass\n\n\nclass AutoencoderConfig(BaseConfig):\n num_images: int = 2048\n batch_size: int = 32\n autoencoder: AutoencoderModelConfig\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n # ... other methods\n\n @register_model()\n def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:\n return Autoencoder()\n\n def compute_loss(self, batch: Batch) -> torch.Tensor:\n batch.image = batch.image.to(self.device, self.dtype)\n x_reconstructed = self.autoencoder.decoder(\n self.autoencoder.encoder(batch.image)\n )\n return F.binary_cross_entropy(x_reconstructed, batch.image)\n
We now have a fully functional Trainer that can train our autoencoder. We can now call the train
method to start the training loop.
trainer.train()\n
"},{"location":"guides/training_101/#logging","title":"Logging","text":"Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from refiners.training_utils.Callback
and implements any of the following methods:
on_init_begin
on_init_end
on_train_begin
on_train_end
on_epoch_begin
on_epoch_end
on_step_begin
on_step_end
on_backward_begin
on_backward_end
on_optimizer_step_begin
on_optimizer_step_end
on_compute_loss_begin
on_compute_loss_end
on_evaluate_begin
on_evaluate_end
on_lr_scheduler_step_begin
on_lr_scheduler_step_end
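Conceptually, the trainer fires each of these hooks at the matching point of the loop and passes itself to every registered callback. This is why the hooks in the full code at the end of this guide take a trainer argument. The plain-Python sketch below illustrates that dispatch with a hypothetical ToyLoop that fires only three of the hooks, using the batch mean as a stand-in loss. It is not how Refiners is implemented internally.

```python
from typing import Any


class ToyCallback:
    # Every hook is a no-op by default; subclasses override only what they need.
    def on_epoch_begin(self, trainer: Any) -> None: ...
    def on_compute_loss_end(self, trainer: Any) -> None: ...
    def on_epoch_end(self, trainer: Any) -> None: ...


class ToyLoop:
    def __init__(self, data: list[list[float]], callbacks: list[ToyCallback], num_epochs: int) -> None:
        self.data, self.callbacks, self.num_epochs = data, callbacks, num_epochs
        self.epoch = 0
        self.loss = 0.0

    def _fire(self, hook: str) -> None:
        # Call the hook with the same name on every callback, passing the loop itself.
        for callback in self.callbacks:
            getattr(callback, hook)(self)

    def train(self) -> None:
        for epoch in range(self.num_epochs):
            self.epoch = epoch
            self._fire('on_epoch_begin')
            for batch in self.data:
                self.loss = sum(batch) / len(batch)  # stand-in for compute_loss
                self._fire('on_compute_loss_end')
            self._fire('on_epoch_end')


class MeanLossRecorder(ToyCallback):
    def __init__(self) -> None:
        self.losses: list[float] = []  # per-instance list, reset every epoch
        self.epoch_means: list[float] = []

    def on_compute_loss_end(self, trainer: Any) -> None:
        self.losses.append(trainer.loss)

    def on_epoch_end(self, trainer: Any) -> None:
        self.epoch_means.append(sum(self.losses) / len(self.losses))
        self.losses = []


recorder = MeanLossRecorder()
ToyLoop(data=[[1.0, 2.0], [3.0, 5.0]], callbacks=[recorder], num_epochs=2).train()
print(recorder.epoch_means)  # [2.75, 2.75]
```

Creating the losses list in __init__, rather than as a class attribute as in the LoggingCallback below, keeps separate callback instances from sharing one list.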
We will implement the on_epoch_end
method to log the mean loss of each epoch, and the on_compute_loss_end
method to store the loss in a list.
from typing import Any\n\nfrom loguru import logger\n\nfrom refiners.training_utils import Callback, Trainer\n\n\nclass LoggingCallback(Callback[Trainer[Any, Any]]):\n losses: list[float] = []\n\n # hooks receive the trainer; the latest loss is available as trainer.loss\n def on_compute_loss_end(self, trainer: Trainer[Any, Any]) -> None:\n self.losses.append(trainer.loss.detach().cpu().item())\n\n def on_epoch_end(self, trainer: Trainer[Any, Any]) -> None:\n mean_loss = sum(self.losses) / len(self.losses)\n logger.info(f\"Mean loss: {mean_loss}, epoch: {trainer.clock.epoch}\")\n self.losses = []\n
As with models, we need to register the callback with the Trainer. We can do so by adding a CallbackConfig
attribute to the config named logging
and adding a method to the Trainer class that returns the callback, decorated with the @register_callback
decorator.
from refiners.training_utils import CallbackConfig, register_callback\n\nclass AutoencoderConfig(BaseConfig):\n # ... other properties\n logging: CallbackConfig = CallbackConfig()\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n # ... other methods\n\n @register_callback()\n def logging(self, config: CallbackConfig) -> LoggingCallback:\n return LoggingCallback()\n
"},{"location":"guides/training_101/#evaluation","title":"Evaluation","text":"Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a compute_evaluation
method, then we register a callback to call this method at regular intervals.
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n # ... other methods\n\n def compute_evaluation(self) -> None:\n generator = generate_mask(size=64, seed=0)\n\n grid: list[tuple[Image.Image, Image.Image]] = []\n for _ in range(4):\n mask = next(generator).to(self.device, self.dtype)\n x_reconstructed = self.autoencoder.decoder(\n self.autoencoder.encoder(mask)\n )\n loss = F.mse_loss(x_reconstructed, mask)\n logger.info(f\"Validation loss: {loss.detach().cpu().item()}\")\n grid.append(\n (tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))\n )\n\n import matplotlib.pyplot as plt\n\n _, axes = plt.subplots(4, 2, figsize=(8, 16))\n\n for i, (mask, reconstructed) in enumerate(grid):\n axes[i, 0].imshow(mask, cmap='gray')\n axes[i, 0].axis('off')\n axes[i, 0].set_title('Mask')\n\n axes[i, 1].imshow(reconstructed, cmap='gray')\n axes[i, 1].axis('off')\n axes[i, 1].set_title('Reconstructed')\n\n plt.tight_layout()\n plt.savefig(f\"result_{self.clock.epoch}.png\")\n plt.close()\n
We start by implementing an EvaluationConfig
that controls the evaluation interval and the seed for the random generator.
from refiners.training_utils.config import TimeValueField\n\nclass EvaluationConfig(CallbackConfig):\n interval: TimeValueField\n seed: int\n
The TimeValueField
is a custom field that allows Pydantic to parse a string representing a time value (e.g., \"50:epochs\"
) into a TimeValue
object. This is useful to specify the evaluation interval in the configuration file.
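To make the interval semantics concrete, here is a small hypothetical sketch in plain Python. It shows what parsing a value such as 50:epochs and checking whether it is due could look like. It is not Refiners' TimeValueField or clock code. The accepted unit names are assumptions, and epochs are counted from 1 here for readability.

```python
from dataclasses import dataclass

UNITS = {'step', 'epoch', 'iteration'}  # assumed unit names, singular form


@dataclass(frozen=True)
class ToyTimeValue:
    number: int
    unit: str


def parse_time_value(text: str) -> ToyTimeValue:
    # Accept strings such as 50:epochs or 1000:step (plural or singular unit).
    number, separator, unit = text.strip().partition(':')
    unit = unit.removesuffix('s')
    if not separator or unit not in UNITS:
        raise ValueError(f'cannot parse time value: {text!r}')
    value = int(number)  # raises ValueError for a non-numeric prefix
    if value <= 0:
        raise ValueError(f'interval must be positive: {text!r}')
    return ToyTimeValue(value, unit)


def is_due(current: int, interval: ToyTimeValue, unit: str = 'epoch') -> bool:
    # Due when the counter for the matching unit is a multiple of the interval.
    return interval.unit == unit and current % interval.number == 0


interval = parse_time_value('50:epochs')
print([epoch for epoch in range(1, 201) if is_due(epoch, interval)])  # [50, 100, 150, 200]
```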
from refiners.training_utils import Callback\nfrom refiners.training_utils.common import scoped_seed\n\nclass EvaluationCallback(Callback[Any]):\n def __init__(self, config: EvaluationConfig) -> None:\n self.config = config\n\n def on_epoch_end(self, trainer: Trainer) -> None:\n # The `is_due` method checks if the current epoch is a multiple of the interval.\n if not trainer.clock.is_due(self.config.interval):\n return\n\n # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the\n # evaluation.\n with scoped_seed(self.config.seed):\n trainer.compute_evaluation()\n
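The idea behind scoped_seed can be illustrated with the standard library alone. Save the generator state, seed it, run the block, and restore the saved state afterwards, even if the block raises. The sketch below, a hypothetical scoped_random_seed, only covers Python's random module. The Refiners helper presumably also manages the PyTorch generators (and possibly NumPy and CUDA ones). Treat this as an illustration of the pattern, not a drop-in replacement.

```python
import random
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def scoped_random_seed(seed: int) -> Iterator[None]:
    # Save the global state, reseed for the duration of the block, then restore it.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)


random.seed(1)
expected = [random.random() for _ in range(3)]

random.seed(1)
first = random.random()
with scoped_random_seed(0):
    evaluation_draw = random.random()  # identical at every evaluation
rest = [random.random() for _ in range(2)]

print([first] + rest == expected)  # True: the outer random sequence is unaffected
```

This is why evaluations are reproducible across epochs while the training data order and augmentations continue their own random sequence.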
We can now register the callback to the Trainer.
class AutoencoderConfig(BaseConfig):\n # ... other properties\n evaluation: EvaluationConfig\n
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n # ... other methods\n\n @register_callback()\n def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:\n return EvaluationCallback(config) \n
We can now train the model and see the results in the result_{epoch}.png
files.
You can train this toy model using the code below:
import random\nfrom dataclasses import dataclass\nfrom typing import Any, Generator\n\nimport torch\nfrom loguru import logger\nfrom PIL import Image\nfrom torch.nn import functional as F\n\nfrom refiners.fluxion import layers as fl\nfrom refiners.fluxion.utils import image_to_tensor, tensor_to_image\nfrom refiners.training_utils import (\n BaseConfig,\n Callback,\n CallbackConfig,\n ClockConfig,\n Epoch,\n LRSchedulerConfig,\n LRSchedulerType,\n ModelConfig,\n OptimizerConfig,\n Optimizers,\n Trainer,\n TrainingConfig,\n register_callback,\n register_model,\n)\nfrom refiners.training_utils.common import scoped_seed\nfrom refiners.training_utils.config import TimeValueField\n\n\nclass ConvBlock(fl.Chain):\n def __init__(self, in_channels: int, out_channels: int) -> None:\n super().__init__(\n fl.Conv2d(\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=3,\n padding=1,\n groups=min(in_channels, out_channels),\n ),\n fl.LayerNorm2d(out_channels),\n fl.SiLU(),\n fl.Conv2d(\n in_channels=out_channels,\n out_channels=out_channels,\n kernel_size=1,\n padding=0,\n ),\n fl.LayerNorm2d(out_channels),\n fl.SiLU(),\n )\n\n\nclass ResidualBlock(fl.Sum):\n def __init__(self, in_channels: int, out_channels: int) -> None:\n super().__init__(\n ConvBlock(in_channels=in_channels, out_channels=out_channels),\n fl.Conv2d(\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=3,\n padding=1,\n ),\n )\n\n\nclass Encoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n ResidualBlock(in_channels=1, out_channels=8),\n fl.Downsample(channels=8, scale_factor=2, register_shape=False),\n ResidualBlock(in_channels=8, out_channels=16),\n fl.Downsample(channels=16, scale_factor=2, register_shape=False),\n ResidualBlock(in_channels=16, out_channels=32),\n fl.Downsample(channels=32, scale_factor=2, register_shape=False),\n fl.Reshape(2048),\n fl.Linear(in_features=2048, out_features=256),\n fl.SiLU(),\n 
fl.Linear(in_features=256, out_features=256),\n )\n\n\nclass Decoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n fl.Linear(in_features=256, out_features=256),\n fl.SiLU(),\n fl.Linear(in_features=256, out_features=2048),\n fl.Reshape(32, 8, 8),\n ResidualBlock(in_channels=32, out_channels=32),\n ResidualBlock(in_channels=32, out_channels=32),\n fl.Upsample(channels=32, upsample_factor=2),\n ResidualBlock(in_channels=32, out_channels=16),\n ResidualBlock(in_channels=16, out_channels=16),\n fl.Upsample(channels=16, upsample_factor=2),\n ResidualBlock(in_channels=16, out_channels=8),\n ResidualBlock(in_channels=8, out_channels=8),\n fl.Upsample(channels=8, upsample_factor=2),\n ResidualBlock(in_channels=8, out_channels=8),\n ResidualBlock(in_channels=8, out_channels=1),\n fl.Sigmoid(),\n )\n\n\nclass Autoencoder(fl.Chain):\n def __init__(self) -> None:\n super().__init__(\n Encoder(),\n Decoder(),\n )\n\n @property\n def encoder(self) -> Encoder:\n return self.ensure_find(Encoder)\n\n @property\n def decoder(self) -> Decoder:\n return self.ensure_find(Decoder)\n\n\ndef generate_mask(size: int, seed: int | None = None) -> Generator[torch.Tensor, None, None]:\n \"\"\"Generate a tensor of a binary mask of size `size` using random rectangles.\"\"\"\n if seed is None:\n seed = random.randint(0, 2**32 - 1)\n random.seed(seed)\n\n while True:\n rectangle = Image.new(\"L\", (random.randint(1, size), random.randint(1, size)), color=255)\n mask = Image.new(\"L\", (size, size))\n mask.paste(\n rectangle,\n (\n random.randint(0, size - rectangle.width),\n random.randint(0, size - rectangle.height),\n ),\n )\n tensor = image_to_tensor(mask)\n\n if random.random() > 0.5:\n tensor = 1 - tensor\n\n yield tensor\n\n\n@dataclass\nclass Batch:\n image: torch.Tensor\n\n\nclass AutoencoderModelConfig(ModelConfig):\n pass\n\n\nclass LoggingCallback(Callback[Trainer[Any, Any]]):\n losses: list[float] = []\n\n def on_compute_loss_end(self, trainer: Trainer[Any, Any]) -> 
None:\n self.losses.append(trainer.loss.detach().cpu().item())\n\n def on_epoch_end(self, trainer: Trainer[Any, Any]) -> None:\n mean_loss = sum(self.losses) / len(self.losses)\n logger.info(f\"Mean loss: {mean_loss}, epoch: {trainer.clock.epoch}\")\n self.losses = []\n\n\nclass EvaluationConfig(CallbackConfig):\n interval: TimeValueField\n seed: int\n\n\nclass EvaluationCallback(Callback[\"AutoencoderTrainer\"]):\n def __init__(self, config: EvaluationConfig) -> None:\n self.config = config\n\n def on_epoch_end(self, trainer: \"AutoencoderTrainer\") -> None:\n # The `is_due` method checks if the current epoch is a multiple of the interval.\n if not trainer.clock.is_due(self.config.interval):\n return\n\n # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the\n # evaluation.\n with scoped_seed(self.config.seed):\n trainer.compute_evaluation()\n\n\nclass AutoencoderConfig(BaseConfig):\n num_images: int = 2048\n batch_size: int = 32\n autoencoder: AutoencoderModelConfig\n evaluation: EvaluationConfig\n logging: CallbackConfig = CallbackConfig()\n\n\nautoencoder_config = AutoencoderModelConfig(\n requires_grad=True, # set during registration to set the requires_grad attribute of the model.\n)\n\ntraining = TrainingConfig(\n duration=Epoch(200),\n device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n dtype=\"float32\",\n)\n\noptimizer = OptimizerConfig(\n optimizer=Optimizers.AdamW,\n learning_rate=1e-4,\n)\n\nlr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)\n\nconfig = AutoencoderConfig(\n training=training,\n optimizer=optimizer,\n lr_scheduler=lr_scheduler,\n autoencoder=autoencoder_config,\n evaluation=EvaluationConfig(interval=Epoch(50), seed=0),\n clock=ClockConfig(verbose=False), # to disable the default clock logging\n)\n\n\nclass AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):\n def create_data_iterable(self) -> list[Batch]:\n dataset: list[Batch] = []\n generator = 
generate_mask(size=64)\n\n for _ in range(self.config.num_images // self.config.batch_size):\n masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)]\n dataset.append(Batch(image=torch.cat(masks, dim=0)))\n\n return dataset\n\n @register_model()\n def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:\n return Autoencoder()\n\n def compute_loss(self, batch: Batch) -> torch.Tensor:\n batch.image = batch.image.to(self.device, self.dtype)\n x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))\n return F.binary_cross_entropy(x_reconstructed, batch.image)\n\n def compute_evaluation(self) -> None:\n generator = generate_mask(size=64, seed=0)\n\n grid: list[tuple[Image.Image, Image.Image]] = []\n validation_losses: list[float] = []\n for _ in range(4):\n mask = next(generator).to(self.device, self.dtype)\n x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(mask))\n loss = F.mse_loss(x_reconstructed, mask)\n validation_losses.append(loss.detach().cpu().item())\n grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed > 0.5).float())))\n\n mean_loss = sum(validation_losses) / len(validation_losses)\n logger.info(f\"Mean validation loss: {mean_loss}, epoch: {self.clock.epoch}\")\n\n import matplotlib.pyplot as plt\n\n _, axes = plt.subplots(4, 2, figsize=(8, 16)) # type: ignore\n\n for i, (mask, reconstructed) in enumerate(grid):\n axes[i, 0].imshow(mask, cmap=\"gray\")\n axes[i, 0].axis(\"off\")\n axes[i, 0].set_title(\"Mask\")\n\n axes[i, 1].imshow(reconstructed, cmap=\"gray\")\n axes[i, 1].axis(\"off\")\n axes[i, 1].set_title(\"Reconstructed\")\n\n plt.tight_layout() # type: ignore\n plt.savefig(f\"result_{self.clock.epoch}.png\") # type: ignore\n plt.close() # type: ignore\n\n @register_callback()\n def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:\n return EvaluationCallback(config)\n\n @register_callback()\n def logging(self, config: 
CallbackConfig) -> LoggingCallback:\n return LoggingCallback()\n\n\ntrainer = AutoencoderTrainer(config)\n\ntrainer.train()\n
"},{"location":"home/why/","title":"Why Refiners?","text":""},{"location":"home/why/#pytorch-an-imperative-framework","title":"PyTorch: an imperative framework","text":"PyTorch is a great framework to implement deep learning models, widely adopted in academia and industry around the globe. A core design principle of PyTorch is that users write imperative Python code that manipulates Tensors1. This code can be organized in Modules, which are just Python classes whose constructors typically initialize parameters and load weights, and which implement a forward
method that computes the forward pass. Building the computation graph, backpropagation and so on are left to the framework.
This approach works very well in general, as demonstrated by the popularity of PyTorch. However, the growing importance of the Adaptation pattern is challenging it.
"},{"location":"home/why/#adaptation-patching-foundation-models","title":"Adaptation: patching foundation models","text":"Adaptation is the idea of patching existing powerful models to implement new capabilities. Those models are called foundation models; they are typically trained from scratch on amounts of data inaccessible to most individuals, small companies or research labs, and exhibit emergent properties. Examples of such models are LLMs (GPT, LLaMa, Mistral), image generation models (Stable Diffusion, Muse), vision models (BLIP-2, LLaVA 1.5, Fuyu-8B) but also models trained on more specific tasks such as embedding extraction (CLIP, DINOv2) or image segmentation (SAM).
Adaptation of foundation models can take many forms. One of the simplest yet most powerful derives from fine-tuning: re-training a subset of the weights of the model on a specific task, then distributing only those weights. Add to this a trick to significantly reduce the size of the fine-tuned weights and you get LoRA2, which is probably the best-known adaptation method. However, adaptation can go beyond that and change the shape of the model or its inputs.
"},{"location":"home/why/#imperative-code-is-hard-to-patch-cleanly","title":"Imperative code is hard to patch cleanly","text":"There are several approaches to patch the code of a foundation model implemented in typical PyTorch imperative style to support adaptation, including:
As believers in adaptation, we found none of those approaches appealing, so we designed Refiners as a better option. Refiners is a microframework built on top of PyTorch that does away with its imperative style. In Refiners, models are implemented in a declarative way instead, which makes them by nature easier to manipulate and patch.
"},{"location":"home/why/#whats-next","title":"What's next?","text":"Now you know why we wrote a declarative framework, you can check out how. It's not that complicated, we promise!
Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.
Hu et al., 2022. LoRA: Low-Rank Adaptation of Large Language Models.
Adapters","text":""},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter","title":"Adapter","text":" Bases: Generic[T]
Base class for adapters.
An Adapter modifies the structure of a Module (typically by adding, removing or replacing layers) to adapt it to a new task.
property
","text":"target: T\n
The target of the adapter.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.eject","title":"eject","text":"eject() -> None\n
Eject the adapter.
This method is the inverse of inject and should leave the target in the same state as before the injection.
Source code in src/refiners/fluxion/adapters/adapter.py
def eject(self) -> None:\n \"\"\"Eject the adapter.\n\n This method is the inverse of [`inject`][refiners.fluxion.adapters.Adapter.inject],\n and should leave the target in the same state as before the injection.\n \"\"\"\n assert isinstance(self, fl.Chain)\n\n # In general, the \"actual target\" is the target.\n # Here we deal with the edge case where the target\n # is part of the replacement block and has been adapted by\n # another adapter after this one. For instance, this is the\n # case when stacking Controlnets.\n actual_target = lookup_top_adapter(self, self.target)\n\n if (parent := self.parent) is None:\n if isinstance(actual_target, fl.ContextModule):\n actual_target._set_parent(None) # type: ignore[reportPrivateUsage]\n else:\n parent.replace(old_module=self, new_module=actual_target)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.inject","title":"inject","text":"inject(parent: Chain | None = None) -> TAdapter\n
Inject the adapter.
This method replaces the target of the adapter with the adapter inside the parent of the target.
Parameters:
parent (Chain | None): The parent to inject the adapter into, if the target doesn't have a parent. Default: None.
Source code in src/refiners/fluxion/adapters/adapter.py
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:\n \"\"\"Inject the adapter.\n\n This method replaces the target of the adapter by the adapter inside the parent of the target.\n\n Args:\n parent: The parent to inject the adapter into, if the target doesn't have a parent.\n \"\"\"\n assert isinstance(self, fl.Chain)\n\n if (parent is None) and isinstance(self.target, fl.ContextModule):\n parent = self.target.parent\n if parent is not None:\n assert isinstance(parent, fl.Chain), f\"{self.target} has invalid parent {parent}\"\n\n target_parent = self.find_parent(self.target)\n\n if parent is None:\n if isinstance(self.target, fl.ContextModule):\n self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]\n return self\n\n # In general, `true_parent` is `parent`. We do this to support multiple adaptation,\n # i.e. initializing two adapters before injecting them.\n true_parent = parent.ensure_find_parent(self.target)\n true_parent.replace(\n old_module=self.target,\n new_module=self,\n old_module_parent=target_parent,\n )\n return self\n
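The inject/eject contract can be illustrated without Refiners at all. The following is a minimal pure-Python sketch of the replace-and-restore idea, where ToyChain and ToyAdapter are hypothetical stand-ins rather than Refiners classes: the adapter wraps its target, inject swaps the target for the adapter inside the parent, and eject swaps it back. It deliberately ignores what the real methods also handle (ContextModule parents, stacked adapters, injecting before the target has a parent).

```python
from typing import Any, Callable, List, Optional


class ToyChain:
    # Hypothetical stand-in for fl.Chain: an ordered list of callables.
    def __init__(self, *modules: Any) -> None:
        self.modules: List[Any] = list(modules)

    def replace(self, old_module: Any, new_module: Any) -> None:
        # Swap one child for another at the same position.
        index = self.modules.index(old_module)
        self.modules[index] = new_module

    def __call__(self, x: float) -> float:
        for module in self.modules:
            x = module(x)
        return x


class ToyAdapter:
    # Hypothetical adapter: wraps its target and shifts the target's output.
    def __init__(self, target: Callable[[float], float], offset: float) -> None:
        self.target = target
        self.offset = offset
        self.parent: Optional[ToyChain] = None

    def __call__(self, x: float) -> float:
        return self.target(x) + self.offset

    def inject(self, parent: ToyChain) -> 'ToyAdapter':
        # Put the adapter where the target used to be.
        parent.replace(old_module=self.target, new_module=self)
        self.parent = parent
        return self

    def eject(self) -> None:
        # Inverse of inject: put the original target back in the parent.
        assert self.parent is not None, 'adapter is not injected'
        self.parent.replace(old_module=self, new_module=self.target)
        self.parent = None


def double(x: float) -> float:
    return 2 * x


def square(x: float) -> float:
    return x * x


chain = ToyChain(double, square)
before = chain(3.0)  # square(double(3.0)) = 36.0
adapter = ToyAdapter(double, offset=1.0).inject(chain)
during = chain(3.0)  # square(double(3.0) + 1.0) = 49.0
adapter.eject()
after = chain(3.0)  # back to 36.0
```

After eject, the parent holds the very same target object as before injection, which is the property the real eject promises.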
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Adapter.setup_adapter","title":"setup_adapter","text":"setup_adapter(target: T) -> Iterator[None]\n
Set up the adapter.
This context manager should be used in the constructor of the adapter, around the call to the Chain constructor. It sets the target of the adapter and ensures that the adapter is not a submodule of the target.
Parameters:
target (T): The target of the adapter. Required.
Source code in src/refiners/fluxion/adapters/adapter.py
@contextlib.contextmanager\ndef setup_adapter(self, target: T) -> Iterator[None]:\n \"\"\"Setup the adapter.\n\n This method should be called by the constructor of the adapter.\n It sets the target of the adapter and ensures that the adapter\n is not a submodule of the target.\n\n Args:\n target: The target of the adapter.\n \"\"\"\n assert isinstance(self, fl.Chain)\n assert (not hasattr(self, \"_modules\")) or (\n len(self) == 0\n ), \"Call the Chain constructor in the setup_adapter context.\"\n self._target = [target]\n\n if isinstance(target, fl.ContextModule):\n assert isinstance(target, fl.ContextModule)\n with target.no_parent_refresh():\n yield\n else:\n yield\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Conv2dLora","title":"Conv2dLora","text":"Conv2dLora(\n name: str,\n /,\n in_channels: int,\n out_channels: int,\n rank: int = 16,\n scale: float = 1.0,\n kernel_size: tuple[int, int] = (1, 3),\n stride: tuple[int, int] = (1, 1),\n padding: tuple[int, int] = (0, 1),\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Lora[Conv2d]
Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
This layer uses two Conv2d layers as its down and up layers.
Parameters:
name (str): The name of the LoRA. Required.
in_channels (int): The number of input channels. Required.
out_channels (int): The number of output channels. Required.
rank (int): The rank of the LoRA. Default: 16.
scale (float): The scale of the LoRA. Default: 1.0.
kernel_size (tuple[int, int]): The kernel size of the LoRA. Default: (1, 3).
stride (tuple[int, int]): The stride of the LoRA. Default: (1, 1).
padding (tuple[int, int]): The padding of the LoRA. Default: (0, 1).
device (device | str | None): The device of the LoRA weights. Default: None.
dtype (dtype | None): The dtype of the LoRA weights. Default: None.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n self,\n name: str,\n /,\n in_channels: int,\n out_channels: int,\n rank: int = 16,\n scale: float = 1.0,\n kernel_size: tuple[int, int] = (1, 3),\n stride: tuple[int, int] = (1, 1),\n padding: tuple[int, int] = (0, 1),\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the LoRA layer.\n\n Args:\n name: The name of the LoRA.\n in_channels: The number of input channels.\n out_channels: The number of output channels.\n rank: The rank of the LoRA.\n scale: The scale of the LoRA.\n kernel_size: The kernel size of the LoRA.\n stride: The stride of the LoRA.\n padding: The padding of the LoRA.\n device: The device of the LoRA weights.\n dtype: The dtype of the LoRA weights.\n \"\"\"\n self.in_channels = in_channels\n self.out_channels = out_channels\n self.kernel_size = kernel_size\n self.stride = stride\n self.padding = padding\n\n super().__init__(\n name,\n rank=rank,\n scale=scale,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LinearLora","title":"LinearLora","text":"LinearLora(\n name: str,\n /,\n in_features: int,\n out_features: int,\n rank: int = 16,\n scale: float = 1.0,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Lora[Linear]
Low-Rank Adaptation (LoRA) layer for linear layers.
This layer uses two Linear layers as its down and up layers.
Parameters:
name (str): The name of the LoRA. Required.
in_features (int): The number of input features. Required.
out_features (int): The number of output features. Required.
rank (int): The rank of the LoRA. Default: 16.
scale (float): The scale of the LoRA. Default: 1.0.
device (device | str | None): The device of the LoRA weights. Default: None.
dtype (dtype | None): The dtype of the LoRA weights. Default: None.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n self,\n name: str,\n /,\n in_features: int,\n out_features: int,\n rank: int = 16,\n scale: float = 1.0,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the LoRA layer.\n\n Args:\n name: The name of the LoRA.\n in_features: The number of input features.\n out_features: The number of output features.\n rank: The rank of the LoRA.\n scale: The scale of the LoRA.\n device: The device of the LoRA weights.\n dtype: The dtype of the LoRA weights.\n \"\"\"\n self.in_features = in_features\n self.out_features = out_features\n\n super().__init__(\n name,\n rank=rank,\n scale=scale,\n device=device,\n dtype=dtype,\n )\n
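Why a LinearLora is cheap to store follows from the shapes of its two layers: assuming the down layer maps in_features to rank and the up layer maps rank back to out_features, it holds rank * (in_features + out_features) weights instead of the in_features * out_features of a dense Linear. The back-of-the-envelope helpers below are hypothetical and count weight matrices only, ignoring any bias.

```python
def full_linear_weights(in_features: int, out_features: int) -> int:
    # A dense Linear weight matrix has shape (out_features, in_features).
    return in_features * out_features


def lora_weights(in_features: int, out_features: int, rank: int) -> int:
    # down: (rank, in_features), up: (out_features, rank); biases ignored.
    return rank * in_features + out_features * rank


full = full_linear_weights(256, 256)
low_rank = lora_weights(256, 256, rank=16)
print(full, low_rank, low_rank / full)  # 65536 8192 0.125
```

With rank 16 on a 256 x 256 layer, the LoRA stores one eighth of the weights, and the ratio keeps shrinking as the layer grows while the rank stays fixed.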
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora","title":"Lora","text":"Lora(\n name: str,\n /,\n rank: int = 16,\n scale: float = 1.0,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Generic[T], Chain, ABC
Low-Rank Adaptation (LoRA) layer.
This layer's purpose is to approximate the update of a given layer's weights by two smaller layers: the down layer (aka A) and the up layer (aka B). See [arXiv:2106.09685] LoRA: Low-Rank Adaptation of Large Language Models for more details.
This layer is not meant to be used directly. Instead, use one of its subclasses: LinearLora or Conv2dLora.
Parameters:
name (str): The name of the LoRA. Required.
rank (int): The rank of the LoRA. Default: 16.
scale (float): The scale of the LoRA. Default: 1.0.
device (device | str | None): The device of the LoRA weights. Default: None.
dtype (dtype | None): The dtype of the LoRA weights. Default: None.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(\n self,\n name: str,\n /,\n rank: int = 16,\n scale: float = 1.0,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the LoRA layer.\n\n Args:\n name: The name of the LoRA.\n rank: The rank of the LoRA.\n scale: The scale of the LoRA.\n device: The device of the LoRA weights.\n dtype: The dtype of the LoRA weights.\n \"\"\"\n self.name = name\n self._rank = rank\n self._scale = scale\n\n super().__init__(\n *self.lora_layers(device=device, dtype=dtype),\n fl.Multiply(scale),\n )\n self.reset_parameters()\n
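The chain built above runs the down layer, then the up layer, then multiplies by scale. Here is a dependency-free sketch of that computation on plain Python lists (lora_forward and matvec are hypothetical helpers, and the layers are assumed to have no bias). It also shows the effect of reset_parameters: with the up weights initialized to zero, a fresh LoRA outputs zeros, so summing it with a target layer initially leaves the target's output unchanged.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(weight: Matrix, x: Vector) -> Vector:
    # weight has shape (out_features, in_features), like a Linear weight.
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def lora_forward(down: Matrix, up: Matrix, scale: float, x: Vector) -> Vector:
    # down projects to rank dimensions, up projects back, then scale is applied.
    return [scale * v for v in matvec(up, matvec(down, x))]


in_features, out_features, rank = 4, 3, 2
random.seed(0)
# Mirrors reset_parameters: normal init for down (std=1/rank), zeros for up.
down = [[random.gauss(0.0, 1.0 / rank) for _ in range(in_features)] for _ in range(rank)]
up = [[0.0] * rank for _ in range(out_features)]

x = [1.0, 2.0, 3.0, 4.0]
print(lora_forward(down, up, scale=1.0, x=x))  # [0.0, 0.0, 0.0]
```

Once training moves the up weights away from zero, the same forward pass produces a non-zero, rank-limited correction whose magnitude is controlled by scale.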
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.down","title":"down property
","text":"down: T\n
The down layer.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.rank","title":"rank property
","text":"rank: int\n
The rank of the low-rank approximation.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.scale","title":"scaleproperty
writable
","text":"scale: float\n
The scale of the low-rank approximation.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.up","title":"upproperty
","text":"up: T\n
The up layer.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.from_dict","title":"from_dictclassmethod
","text":"from_dict(\n name: str, /, state_dict: dict[str, Tensor]\n) -> dict[str, Lora[Any]]\n
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
Source code in src/refiners/fluxion/adapters/lora.py
@classmethod\ndef from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, \"Lora[Any]\"]:\n \"\"\"\n Create a dictionary of LoRA layers from a state dict.\n\n Expects the state dict to be a succession of down and up weights.\n \"\"\"\n state_dict = {k: v for k, v in state_dict.items() if \".weight\" in k}\n loras: dict[str, Lora[Any]] = {}\n for down_key, down_tensor, up_tensor in zip(\n list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]\n ):\n key = \".\".join(down_key.split(\".\")[:-2])\n loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)\n return loras\n
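from_dict relies on ordering: after keeping only the entries whose key contains .weight, even positions are read as down weights, odd positions as up weights, and each layer key is the down key with its last two dotted components dropped. The sketch below replays that key logic on a hypothetical state dict, with strings in place of tensors (pair_lora_weights and the key names are illustrative). A state dict where down and up weights do not alternate, down first, would be paired incorrectly without any error.

```python
from typing import Dict, Tuple


def pair_lora_weights(state_dict: Dict[str, str]) -> Dict[str, Tuple[str, str]]:
    # Same filtering and pairing logic as Lora.from_dict, without building layers.
    weights = {k: v for k, v in state_dict.items() if '.weight' in k}
    keys = list(weights.keys())
    values = list(weights.values())
    pairs: Dict[str, Tuple[str, str]] = {}
    for down_key, down_value, up_value in zip(keys[::2], values[::2], values[1::2]):
        # 'blocks.0.attn.down.weight' -> 'blocks.0.attn'
        key = '.'.join(down_key.split('.')[:-2])
        pairs[key] = (down_value, up_value)
    return pairs


# Hypothetical state dict: down and up weights alternate, down first.
state_dict = {
    'blocks.0.attn.down.weight': 'A0',
    'blocks.0.attn.up.weight': 'B0',
    'blocks.0.attn.alpha': 'ignored',
    'blocks.1.mlp.down.weight': 'A1',
    'blocks.1.mlp.up.weight': 'B1',
}
print(pair_lora_weights(state_dict))
# {'blocks.0.attn': ('A0', 'B0'), 'blocks.1.mlp': ('A1', 'B1')}
```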
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.load_weights","title":"load_weights","text":"load_weights(\n down_weight: Tensor, up_weight: Tensor\n) -> None\n
Load the (pre-trained) weights of the LoRA.
Parameters:
down_weight (Tensor): The down weight; its shape must match that of the down layer's weight. Required.
up_weight (Tensor): The up weight; its shape must match that of the up layer's weight. Required.
Source code in src/refiners/fluxion/adapters/lora.py
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:\n \"\"\"Load the (pre-trained) weights of the LoRA.\n\n Args:\n down_weight: The down weight.\n up_weight: The up weight.\n \"\"\"\n assert down_weight.shape == self.down.weight.shape\n assert up_weight.shape == self.up.weight.shape\n self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))\n self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.lora_layers","title":"lora_layers abstractmethod
","text":"lora_layers(\n device: device | str | None = None,\n dtype: dtype | None = None,\n) -> tuple[T, T]\n
Create the down and up layers of the LoRA.
Parameters:
device (device | str | None): The device of the LoRA weights. Default: None.
dtype (dtype | None): The dtype of the LoRA weights. Default: None.
Source code in src/refiners/fluxion/adapters/lora.py
@abstractmethod\ndef lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:\n \"\"\"Create the down and up layers of the LoRA.\n\n Args:\n device: The device of the LoRA weights.\n dtype: The dtype of the LoRA weights.\n \"\"\"\n ...\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.Lora.reset_parameters","title":"reset_parameters","text":"reset_parameters() -> None\n
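Reset the parameters of the up and down layers: the down weights are drawn from a normal distribution with standard deviation 1 / rank, and the up weights are set to zero.
Source code in src/refiners/fluxion/adapters/lora.py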
def reset_parameters(self) -> None:\n \"\"\"Reset the parameters of up and down layers.\"\"\"\n normal_(tensor=self.down.weight, std=1 / self.rank)\n zeros_(tensor=self.up.weight)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter","title":"LoraAdapter","text":"LoraAdapter(target: WeightedModule, /, *loras: Lora[Any])\n
Bases: Sum, Adapter[WeightedModule]
Adapter for LoRA layers.
This adapter simply sums the target layer with the given LoRA layers.
Parameters:
target (WeightedModule): The target layer. Required.
loras (Lora[Any]): The LoRA layers, passed as variadic positional arguments. Default: ().
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:\n \"\"\"Initialize the adapter.\n\n Args:\n target: The target layer.\n loras: The LoRA layers.\n \"\"\"\n with self.setup_adapter(target):\n super().__init__(target, *loras)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.lora_layers","title":"lora_layers property
","text":"lora_layers: Iterator[Lora[Any]]\n
The LoRA layers.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.loras","title":"lorasproperty
","text":"loras: dict[str, Lora[Any]]\n
The LoRA layers indexed by name.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.names","title":"namesproperty
","text":"names: list[str]\n
The names of the LoRA layers.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.scales","title":"scalesproperty
","text":"scales: dict[str, float]\n
The scales of the LoRA layers indexed by names.
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.add_lora","title":"add_lora","text":"add_lora(lora: Lora[Any]) -> None\n
Add a LoRA layer to the adapter.
Raises:
AssertionError: If the adapter already contains a LoRA layer with the same name.
Parameters:
lora (Lora[Any]): The LoRA layer to add. Required.
Source code in src/refiners/fluxion/adapters/lora.py
def add_lora(self, lora: Lora[Any], /) -> None:\n \"\"\"Add a LoRA layer to the adapter.\n\n Raises:\n AssertionError: If the adapter already contains a LoRA layer with the same name.\n\n Args:\n lora: The LoRA layer to add.\n \"\"\"\n assert lora.name not in self.names, f\"LoRA layer with name {lora.name} already exists\"\n self.append(lora)\n
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.LoraAdapter.remove_lora","title":"remove_lora","text":"remove_lora(name: str) -> Lora[Any] | None\n
Remove a LoRA layer from the adapter.
Note: If the adapter doesn't contain a LoRA layer with the given name, nothing happens and None is returned. Otherwise, the removed LoRA layer is returned.
Parameters:
name (str): The name of the LoRA layer to remove. Required.
Source code in src/refiners/fluxion/adapters/lora.py
def remove_lora(self, name: str, /) -> Lora[Any] | None:\n \"\"\"Remove a LoRA layer from the adapter.\n\n Note:\n If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.\n\n Args:\n name: The name of the LoRA layer to remove.\n \"\"\"\n if name in self.names:\n lora = self.loras[name]\n self.remove(lora)\n return lora\n
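LoraAdapter sums the output of its target with the outputs of all its LoRA layers, and add_lora and remove_lora manage those layers by name. The bookkeeping can be sketched in pure Python with scalar functions standing in for layers; ToyLora and ToyLoraAdapter are hypothetical and only mirror the contracts documented above (unique names on add, None for unknown names on remove).

```python
from typing import Callable, Dict, Optional


class ToyLora:
    # Hypothetical LoRA stand-in: a named scalar function multiplied by scale.
    def __init__(self, name: str, fn: Callable[[float], float], scale: float = 1.0) -> None:
        self.name = name
        self.fn = fn
        self.scale = scale

    def __call__(self, x: float) -> float:
        return self.scale * self.fn(x)


class ToyLoraAdapter:
    # Sums the target output with the output of every attached LoRA.
    def __init__(self, target: Callable[[float], float], *loras: ToyLora) -> None:
        self.target = target
        self.loras: Dict[str, ToyLora] = {}
        for lora in loras:
            self.add_lora(lora)

    def add_lora(self, lora: ToyLora) -> None:
        # Same contract as LoraAdapter.add_lora: names must be unique.
        assert lora.name not in self.loras, f'LoRA layer with name {lora.name} already exists'
        self.loras[lora.name] = lora

    def remove_lora(self, name: str) -> Optional[ToyLora]:
        # Unknown names are ignored and None is returned.
        return self.loras.pop(name, None)

    def __call__(self, x: float) -> float:
        return self.target(x) + sum(lora(x) for lora in self.loras.values())


adapter = ToyLoraAdapter(lambda x: 2 * x, ToyLora('style', lambda x: x, scale=0.5))
adapter.add_lora(ToyLora('subject', lambda x: 10.0, scale=0.1))
print(adapter(4.0))  # 8 + 0.5 * 4 + 0.1 * 10 = 11.0
print(adapter.remove_lora('missing'))  # None
```

Keying the LoRAs by name is what makes per-LoRA operations such as changing one scale or removing one LoRA straightforward while the others keep contributing to the sum.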
"},{"location":"reference/fluxion/adapters/#refiners.fluxion.adapters.auto_attach_loras","title":"auto_attach_loras","text":"auto_attach_loras(\n loras: dict[str, Lora[Any]],\n target: Chain,\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n sanity_check: bool = True,\n debug_map: list[tuple[str, str]] | None = None,\n) -> list[str]\n
Auto-attach several LoRA layers to a Chain.
Parameters:
loras (dict[str, Lora[Any]]): A dictionary of LoRA layers associated to their respective keys. The keys are typically derived from the state dict and only used for debug_map and the return value. Required.
target (Chain): The target Chain. Required.
include (list[str] | None): A list of layer names; only layers with such a layer among their ancestors will be considered. Default: None.
exclude (list[str] | None): A list of layer names; layers with such a layer among their ancestors will not be considered. Default: None.
sanity_check (bool): Check that the LoRAs passed are correctly attached. Default: True.
debug_map (list[tuple[str, str]] | None): Pass a list to get a debug mapping of (key, path) pairs for the attachment points. Default: None.
Returns: A list of keys of LoRA layers which failed to attach.
Raises:
ValueError: If sanity_check is True and the LoRA layers are not all attached exactly once.
Source code in src/refiners/fluxion/adapters/lora.py
def auto_attach_loras(\n loras: dict[str, Lora[Any]],\n target: fl.Chain,\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n sanity_check: bool = True,\n debug_map: list[tuple[str, str]] | None = None,\n) -> list[str]:\n \"\"\"Auto-attach several LoRA layers to a Chain.\n\n Args:\n loras: A dictionary of LoRA layers associated to their respective key. The keys are typically\n derived from the state dict and only used for `debug_map` and the return value.\n target: The target Chain.\n include: A list of layer names, only layers with such a layer in their ancestors will be considered.\n exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.\n sanity_check: Check that LoRAs passed are correctly attached.\n debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.\n Returns:\n A list of keys of LoRA layers which failed to attach.\n \"\"\"\n\n if not sanity_check:\n return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)\n\n loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}\n debug_map_1: list[tuple[str, str]] = []\n failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)\n if debug_map is not None:\n debug_map += debug_map_1\n if len(debug_map_1) != len(loras) or failed_keys_1:\n raise ValueError(\n f\"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed\"\n )\n\n # Extra sanity check: if we re-run the attach, all layers should fail.\n debug_map_2: list[tuple[str, str]] = []\n failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)\n if debug_map_2 or len(failed_keys_2) != len(loras):\n raise ValueError(\n f\"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} 
skipped\"\n )\n\n return failed_keys_1\n
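The include and exclude filters are only described by their docstrings. One reading of them, sketched below under the assumption that each candidate layer is identified by the names of its ancestors and that both filters apply together, is the hypothetical is_candidate helper; it is not the actual matching logic of _auto_attach_loras, which is not shown here.

```python
from typing import List, Optional, Sequence


def is_candidate(
    ancestors: Sequence[str],
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
) -> bool:
    # include: keep the layer only if one of these names is among its ancestors.
    if include is not None and not any(name in ancestors for name in include):
        return False
    # exclude: drop the layer if one of these names is among its ancestors.
    if exclude is not None and any(name in ancestors for name in exclude):
        return False
    return True


# Hypothetical layer paths, each mapped to the names of its ancestors.
layers = {
    'down_block.attn.linear': ['UNet', 'DownBlock', 'Attention'],
    'mid_block.attn.linear': ['UNet', 'MidBlock', 'Attention'],
    'text_encoder.linear': ['TextEncoder', 'Attention'],
}
selected = [
    path
    for path, ancestors in layers.items()
    if is_candidate(ancestors, include=['UNet'], exclude=['MidBlock'])
]
print(selected)  # ['down_block.attn.linear']
```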
"},{"location":"reference/fluxion/context/","title":"
Context","text":""},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider","title":"ContextProvider","text":"ContextProvider()\n
A class that provides a context store.
Source code in src/refiners/fluxion/context.py
def __init__(self) -> None:\n \"\"\"Initializes the ContextProvider.\"\"\"\n self.contexts: Contexts = {}\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.create","title":"create staticmethod
","text":"create(contexts: Contexts) -> ContextProvider\n
Create a ContextProvider from a dict of contexts.
Parameters:
contexts (Contexts): The contexts. Required.
Returns:
ContextProvider: A ContextProvider with the contexts.
Source code in src/refiners/fluxion/context.py
@staticmethod\ndef create(contexts: Contexts) -> \"ContextProvider\":\n \"\"\"Create a ContextProvider from a dict of contexts.\n\n Args:\n contexts: The contexts.\n\n Returns:\n A ContextProvider with the contexts.\n \"\"\"\n provider = ContextProvider()\n provider.update_contexts(contexts)\n return provider\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.get_context","title":"get_context","text":"get_context(key: str) -> Any\n
Retrieve a value from the context.
Parameters:
key (str): The key of the context. Required.
Returns:
Any: The context value, or None if no context is stored under key.
Source code in src/refiners/fluxion/context.py
def get_context(self, key: str) -> Any:\n \"\"\"Retrieve a value from the context.\n\n Args:\n key: The key of the context.\n\n Returns:\n The context value.\n \"\"\"\n return self.contexts.get(key)\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.set_context","title":"set_context","text":"set_context(key: str, value: Context) -> None\n
Store a value in the context.
Parameters:
key (str): The key of the context. Required.
value (Context): The context. Required.
Source code in src/refiners/fluxion/context.py
def set_context(self, key: str, value: Context) -> None:\n \"\"\"Store a value in the context.\n\n Args:\n key: The key of the context.\n value: The context.\n \"\"\"\n self.contexts[key] = value\n
"},{"location":"reference/fluxion/context/#refiners.fluxion.context.ContextProvider.update_contexts","title":"update_contexts","text":"update_contexts(new_contexts: Contexts) -> None\n
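Update or set the contexts with new contexts: keys that are not present yet are added, while contexts stored under existing keys are updated in place with the new values.
Parameters:
new_contexts (Contexts): The new contexts. Required.
Source code in src/refiners/fluxion/context.py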
def update_contexts(self, new_contexts: Contexts) -> None:\n \"\"\"Update or set the contexts with new contexts.\n\n Args:\n new_contexts: The new contexts.\n \"\"\"\n for key, value in new_contexts.items():\n if key not in self.contexts:\n self.contexts[key] = value\n else:\n self.contexts[key].update(value)\n
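The point worth noticing in update_contexts is that it merges rather than replaces: a key already present keeps its other entries and only sees the new values updated. A minimal sketch with a hypothetical ToyContextProvider that copies the methods shown above, assuming (as the update call implies) that each context is itself a dict:

```python
from typing import Any, Dict

Contexts = Dict[str, Dict[str, Any]]


class ToyContextProvider:
    # Hypothetical stand-in mirroring the ContextProvider methods shown above.
    def __init__(self) -> None:
        self.contexts: Contexts = {}

    def set_context(self, key: str, value: Dict[str, Any]) -> None:
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        # Missing keys return None instead of raising.
        return self.contexts.get(key)

    def update_contexts(self, new_contexts: Contexts) -> None:
        # New keys are added; existing keys are merged entry by entry.
        for key, value in new_contexts.items():
            if key not in self.contexts:
                self.contexts[key] = value
            else:
                self.contexts[key].update(value)


provider = ToyContextProvider()
provider.set_context('diffusion', {'step': 0, 'guidance': 7.5})
provider.update_contexts({'diffusion': {'step': 1}, 'sampling': {'seed': 42}})
print(provider.get_context('diffusion'))  # {'step': 1, 'guidance': 7.5}
print(provider.get_context('missing'))  # None
```

By contrast, set_context replaces the whole context stored under a key, so the two methods are not interchangeable.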
"},{"location":"reference/fluxion/layers/","title":"
Layers","text":""},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Activation","title":"Activation","text":"Activation()\n
Bases: Module, ABC
Base class for activation layers.
Activation layers are layers that apply a (non-linear) function to their input.
Receives:
x (Tensor)
Returns:
Tensor
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Attention","title":"Attention","text":"Attention(\n embedding_dim: int,\n num_heads: int = 1,\n key_embedding_dim: int | None = None,\n value_embedding_dim: int | None = None,\n inner_dim: int | None = None,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Multi-Head Attention layer.
See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details.
This layer simply chains:
a Distribute layer, containing 3 Linear layers, which transforms the 3 inputs into Query, Key and Value;
a ScaledDotProductAttention layer;
a Linear layer, which projects the output of the ScaledDotProductAttention layer.
Receives:
Query (Float[Tensor, 'batch sequence_length embedding_dim'])
Key (Float[Tensor, 'batch sequence_length embedding_dim'])
Value (Float[Tensor, 'batch sequence_length embedding_dim'])
Returns:
Float[Tensor, 'batch sequence_length embedding_dim']
Example:
import torch\nimport refiners.fluxion.layers as fl\n\nattention = fl.Attention(num_heads=8, embedding_dim=128)\n\ntensor = torch.randn(2, 10, 128)\noutput = attention(tensor, tensor, tensor)\n\nassert output.shape == (2, 10, 128)\n
Parameters:
embedding_dim (int): The embedding dimension of the input and output tensors. Must be divisible by num_heads. Required.
num_heads (int): The number of heads to use. Default: 1.
key_embedding_dim (int | None): The embedding dimension of the key tensor. If None, embedding_dim is used. Default: None.
value_embedding_dim (int | None): The embedding dimension of the value tensor. If None, embedding_dim is used. Default: None.
inner_dim (int | None): The inner dimension of the linear layers. If None, embedding_dim is used. Default: None.
use_bias (bool): Whether to use bias in the query, key and value projections (the output projection always has a bias). Default: True.
is_causal (bool): Whether to use causal attention. Default: False.
is_optimized (bool): Whether to use optimized attention. Default: True.
device (device | str | None): The device to use. Default: None.
dtype (dtype | None): The dtype to use. Default: None.
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n self,\n embedding_dim: int,\n num_heads: int = 1,\n key_embedding_dim: int | None = None,\n value_embedding_dim: int | None = None,\n inner_dim: int | None = None,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the Attention layer.\n\n Args:\n embedding_dim: The embedding dimension of the input and output tensors.\n num_heads: The number of heads to use.\n key_embedding_dim: The embedding dimension of the key tensor.\n value_embedding_dim: The embedding dimension of the value tensor.\n inner_dim: The inner dimension of the linear layers.\n use_bias: Whether to use bias in the linear layers.\n is_causal: Whether to use causal attention.\n is_optimized: Whether to use optimized attention.\n device: The device to use.\n dtype: The dtype to use.\n \"\"\"\n assert (\n embedding_dim % num_heads == 0\n ), f\"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}\"\n self.embedding_dim = embedding_dim\n self.num_heads = num_heads\n self.heads_dim = embedding_dim // num_heads\n self.key_embedding_dim = key_embedding_dim or embedding_dim\n self.value_embedding_dim = value_embedding_dim or embedding_dim\n self.inner_dim = inner_dim or embedding_dim\n self.use_bias = use_bias\n self.is_causal = is_causal\n self.is_optimized = is_optimized\n\n super().__init__(\n Distribute(\n Linear( # Query projection\n in_features=self.embedding_dim,\n out_features=self.inner_dim,\n bias=self.use_bias,\n device=device,\n dtype=dtype,\n ),\n Linear( # Key projection\n in_features=self.key_embedding_dim,\n out_features=self.inner_dim,\n bias=self.use_bias,\n device=device,\n dtype=dtype,\n ),\n Linear( # Value projection\n in_features=self.value_embedding_dim,\n out_features=self.inner_dim,\n bias=self.use_bias,\n device=device,\n dtype=dtype,\n ),\n ),\n ScaledDotProductAttention(\n num_heads=num_heads,\n is_causal=is_causal,\n 
is_optimized=is_optimized,\n ),\n Linear( # Output projection\n in_features=self.inner_dim,\n out_features=self.embedding_dim,\n bias=True,\n device=device,\n dtype=dtype,\n ),\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Breakpoint","title":"Breakpoint","text":"Breakpoint(vscode: bool = True)\n
Bases: ContextModule
Breakpoint layer.
This layer pauses the execution when encountered, and opens a debugger.
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, vscode: bool = True):\n super().__init__()\n self.vscode = vscode\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain","title":"Chain","text":"Chain(*args: Module | Iterable[Module])\n
Bases: ContextModule
Chain layer.
This layer is the main building block of Fluxion. It composes other layers sequentially: like torch.nn.Sequential, it calls each of its sub-layers in order, feeding the output of each one as the input of the next. Unlike torch.nn.Sequential, it also provides methods to manipulate its sub-layers and their context.
chain = fl.Chain(\n fl.Linear(32, 64),\n fl.ReLU(),\n fl.Linear(64, 128),\n)\n\ntensor = torch.randn(2, 32)\noutput = chain(tensor)\n\nassert output.shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
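Because of the argument handling in this constructor, `fl.Chain(a, b)` and `fl.Chain([a, b])` should build the same chain, while a single Chain argument is nested rather than unpacked. The sketch below models only that rule in plain Python; `Module` and `Chain` here are hypothetical stand-ins, not the Refiners classes.

```python
from collections.abc import Iterable


class Module:
    # Hypothetical stand-in for a leaf module (not iterable).
    pass


class Chain(Module):
    # Hypothetical stand-in that only models argument normalization.
    def __init__(self, *args) -> None:
        # A single iterable argument that is not itself a Chain is unpacked;
        # anything else (several modules, or one Chain) is kept as given.
        if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain):
            self.modules = tuple(args[0])
        else:
            self.modules = tuple(args)

    def __iter__(self):
        return iter(self.modules)


a, b = Module(), Module()
from_varargs = Chain(a, b)
from_list = Chain([a, b])
nested = Chain(from_varargs)  # a Chain is iterable, but it is nested, not unpacked
```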
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.device","title":"device property
","text":"device: device | None\n
The PyTorch device of the Chain's parameters.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.dtype","title":"dtypeproperty
","text":"dtype: dtype | None\n
The PyTorch dtype of the Chain's parameters.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.provider","title":"providerproperty
","text":"provider: ContextProvider\n
The ContextProvider of the Chain.
append(module: Module) -> None\n
Append a new module to the chain.
Parameters:
module (Module, required): The module to append.
Source code in src/refiners/fluxion/layers/chain.py
def append(self, module: Module) -> None:\n \"\"\"Append a new module to the chain.\n\n Args:\n module: The module to append.\n \"\"\"\n self.insert(-1, module)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.ensure_find","title":"ensure_find","text":"ensure_find(layer_type: type[T]) -> T\n
Walk the Chain's sub-module tree and return the first layer of the given type.
Parameters:
layer_type (type[T], required): The type of layer to find.
Returns:
T: The first module of the given layer_type.
Raises:
AssertionError: If the module doesn't exist.
Source code in src/refiners/fluxion/layers/chain.py
def ensure_find(self, layer_type: type[T]) -> T:\n \"\"\"Walk the Chain's sub-module tree and return the first layer of the given type.\n\n Args:\n layer_type: The type of layer to find.\n\n Returns:\n The first module of the given layer_type.\n\n Raises:\n AssertionError: If the module doesn't exist.\n \"\"\"\n r = self.find(layer_type)\n assert r is not None, f\"could not find {layer_type} in {self}\"\n return r\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.ensure_find_parent","title":"ensure_find_parent","text":"ensure_find_parent(module: Module) -> Chain\n
Walk the Chain's sub-module tree and return the parent of the given module.
Parameters:
module (Module, required): The module whose parent to find.
Returns:
Chain: The parent of the given module.
Raises:
AssertionError: If the module is not in the Chain's sub-module tree.
Source code in src/refiners/fluxion/layers/chain.py
def ensure_find_parent(self, module: Module) -> \"Chain\":\n \"\"\"Walk the Chain's sub-module tree and return the parent of the given module.\n\n Args:\n module: The module whose parent to find.\n\n Returns:\n The parent of the given module.\n\n Raises:\n AssertionError: If the module doesn't exist.\n \"\"\"\n r = self.find_parent(module)\n assert r is not None, f\"could not find {module} in {self}\"\n return r\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.find","title":"find","text":"find(layer_type: type[T]) -> T | None\n
Walk the Chain's sub-module tree and return the first layer of the given type.
Parameters:
layer_type (type[T], required): The type of layer to find.
Returns:
T | None: The first module of the given layer_type, or None if it doesn't exist.
Source code in src/refiners/fluxion/layers/chain.py
def find(self, layer_type: type[T]) -> T | None:\n \"\"\"Walk the Chain's sub-module tree and return the first layer of the given type.\n\n Args:\n layer_type: The type of layer to find.\n\n Returns:\n The first module of the given layer_type, or None if it doesn't exist.\n \"\"\"\n return next(self.layers(layer_type=layer_type), None)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.find_parent","title":"find_parent","text":"find_parent(module: Module) -> Chain | None\n
Walk the Chain's sub-module tree and return the parent of the given module.
Parameters:
module (Module, required): The module whose parent to find.
Returns:
Chain | None: The parent of the given module, or None if it doesn't exist.
Source code in src/refiners/fluxion/layers/chain.py
def find_parent(self, module: Module) -> \"Chain | None\":\n \"\"\"Walk the Chain's sub-module tree and return the parent of the given module.\n\n Args:\n module: The module whose parent to find.\n\n Returns:\n The parent of the given module, or None if it doesn't exist.\n \"\"\"\n if module in self: # avoid DFS-crawling the whole tree\n return self\n for _, parent in self.walk(lambda m, _: m == module):\n return parent\n return None\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.init_context","title":"init_context","text":"init_context() -> Contexts\n
Initialize the context provider with some default values.
This method is called when the Chain is created, and when it is reset. This method may be overridden by subclasses to provide default values for the context provider.
Source code in src/refiners/fluxion/layers/chain.py
def init_context(self) -> Contexts:\n \"\"\"Initialize the context provider with some default values.\n\n This method is called when the Chain is created, and when it is reset.\n This method may be overridden by subclasses to provide default values for the context provider.\n \"\"\"\n return {}\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert","title":"insert","text":"insert(index: int, module: Module) -> None\n
Insert a new module in the chain.
Parameters:
index (int, required): The index at which to insert the module.
module (Module, required): The module to insert.
Raises:
IndexError: If the index is out of range.
Source code in src/refiners/fluxion/layers/chain.py
def insert(self, index: int, module: Module) -> None:\n \"\"\"Insert a new module in the chain.\n\n Args:\n index: The index at which to insert the module.\n module: The module to insert.\n\n Raises:\n IndexError: If the index is out of range.\n \"\"\"\n if index < 0:\n index = max(0, len(self._modules) + index + 1)\n modules = list(self)\n modules.insert(index, module)\n self._regenerate_keys(modules)\n if isinstance(module, ContextModule):\n module._set_parent(self)\n self._register_provider()\n
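Note the negative-index rule in `insert`: `index = max(0, len(self._modules) + index + 1)` maps `-1` to the end of the chain, which is why `append` can simply call `self.insert(-1, module)`. This differs from Python's `list.insert`, where `-1` means before the last item. A plain-Python sketch of the rule (the helper names are hypothetical):

```python
def normalize_insert_index(index: int, length: int) -> int:
    # Mirrors Chain.insert: a negative index counts from one past the end,
    # so -1 means after the last module; very negative values clamp to 0.
    if index < 0:
        index = max(0, length + index + 1)
    return index


def chain_style_insert(items: list, index: int, item) -> list:
    # Returns a new list with item inserted using the Chain.insert rule.
    result = list(items)
    result.insert(normalize_insert_index(index, len(result)), item)
    return result


modules = ['conv', 'relu']
appended = chain_style_insert(modules, -1, 'linear')  # what Chain.append does

python_style = list(modules)
python_style.insert(-1, 'linear')  # plain list.insert: before the last item
```

Since `list.insert` clamps positions past the end, this sketch appends for an index beyond the length instead of raising; from the source shown, the same appears to hold for `Chain.insert`, so the documented IndexError is worth checking before relying on it.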
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert_after_type","title":"insert_after_type","text":"insert_after_type(\n module_type: type[Module], new_module: Module\n) -> None\n
Insert a new module in the chain, right after the first module of the given type.
Parameters:
module_type (type[Module], required): The type of module to insert after.
new_module (Module, required): The module to insert.
Raises:
ValueError: If no module of the given type exists in the chain.
Source code in src/refiners/fluxion/layers/chain.py
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:\n \"\"\"Insert a new module in the chain, right after the first module of the given type.\n\n Args:\n module_type: The type of module to insert after.\n new_module: The module to insert.\n\n Raises:\n ValueError: If no module of the given type exists in the chain.\n \"\"\"\n for i, module in enumerate(self):\n if isinstance(module, module_type):\n self.insert(i + 1, new_module)\n return\n raise ValueError(f\"No module of type {module_type.__name__} found in the chain.\")\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.insert_before_type","title":"insert_before_type","text":"insert_before_type(\n module_type: type[Module], new_module: Module\n) -> None\n
Insert a new module in the chain, right before the first module of the given type.
Parameters:
module_type (type[Module], required): The type of module to insert before.
new_module (Module, required): The module to insert.
Raises:
ValueError: If no module of the given type exists in the chain.
Source code in src/refiners/fluxion/layers/chain.py
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:\n \"\"\"Insert a new module in the chain, right before the first module of the given type.\n\n Args:\n module_type: The type of module to insert before.\n new_module: The module to insert.\n\n Raises:\n ValueError: If no module of the given type exists in the chain.\n \"\"\"\n for i, module in enumerate(self):\n if isinstance(module, module_type):\n self.insert(i, new_module)\n return\n raise ValueError(f\"No module of type {module_type.__name__} found in the chain.\")\n
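`insert_after_type` and `insert_before_type` share the same scan: find the first child that is an instance of `module_type`, then insert at `i + 1` or at `i`. As written, both iterate over `self`, so they only consider direct children, not the whole sub-module tree. A plain-Python sketch over a list, with hypothetical placeholder layer classes:

```python
class Linear:
    pass


class ReLU:
    pass


class Dropout:
    pass


def insert_relative_to_type(items: list, module_type: type, new_item, after: bool) -> None:
    # Scan direct children only, stopping at the first instance of module_type.
    for i, item in enumerate(items):
        if isinstance(item, module_type):
            # i + 1 mirrors insert_after_type, i mirrors insert_before_type.
            items.insert(i + 1 if after else i, new_item)
            return
    raise ValueError(f'No module of type {module_type.__name__} found in the chain.')


layers = [Linear(), ReLU(), Linear()]
insert_relative_to_type(layers, ReLU, Dropout(), after=True)  # like insert_after_type

head = [Linear(), ReLU()]
insert_relative_to_type(head, ReLU, Dropout(), after=False)  # like insert_before_type
```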
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.layer","title":"layer","text":"layer(\n key: str | int | Sequence[str | int],\n layer_type: type[T] = Module,\n) -> T\n
Access a layer of the Chain by key or path, asserting its type.
Example:
# same as my_chain[\"Linear_2\"], asserts it is a Linear\nmy_chain.layer(\"Linear_2\", fl.Linear)\n\n# same as my_chain[3], asserts it is a Linear\nmy_chain.layer(3, fl.Linear)\n\n# fails: either no \"Conv2d\" key exists, or that layer is not a Linear\nmy_chain.layer(\"Conv2d\", fl.Linear)\n\n# same as my_chain[\"foo\"][42][\"bar\"],\n# assuming bar is a MyType and all parents are Chains\nmy_chain.layer((\"foo\", 42, \"bar\"), fl.MyType)\n
Parameters:
key (str | int | Sequence[str | int], required): The key or path of the layer.
layer_type (type[T], default Module): The type of the layer.
Returns:
T: The layer.
Raises:
AssertionError: If the layer doesn't exist or the type is invalid.
Source code in src/refiners/fluxion/layers/chain.py
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:\n \"\"\"Access a layer of the Chain by key or path, asserting its type.\n\n Example:\n ```py\n # same as my_chain[\"Linear_2\"], asserts it is a Linear\n my_chain.layer(\"Linear_2\", fl.Linear)\n\n\n # same as my_chain[3], asserts it is a Linear\n my_chain.layer(3, fl.Linear)\n\n # fails: either no \"Conv2d\" key exists, or that layer is not a Linear\n my_chain.layer(\"Conv2d\", fl.Linear)\n\n\n # same as my_chain[\"foo\"][42][\"bar\"],\n # assuming bar is a MyType and all parents are Chains\n my_chain.layer((\"foo\", 42, \"bar\"), fl.MyType)\n ```\n\n Args:\n key: The key or path of the layer.\n layer_type: The type of the layer.\n\n Returns:\n The layer.\n\n Raises:\n AssertionError: If the layer doesn't exist or the type is invalid.\n \"\"\"\n if isinstance(key, (str, int)):\n r = self[key]\n assert isinstance(r, layer_type), f\"layer {key} is {type(r)}, not {layer_type}\"\n return r\n if len(key) == 0:\n assert isinstance(self, layer_type), f\"layer is {type(self)}, not {layer_type}\"\n return self\n if len(key) == 1:\n return self.layer(key[0], layer_type)\n return self.layer(key[0], Chain).layer(key[1:], layer_type)\n
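The recursive branch of `layer` resolves a path one key at a time, so `my_chain.layer(('foo', 42, 'bar'), T)` amounts to `my_chain['foo'][42]['bar']` plus a type check on the result. The sketch below reproduces that resolution over plain nested dicts and lists standing in for nested Chains; it is a hypothetical model and, unlike the real method, skips the intermediate Chain type checks.

```python
def resolve(node, key, expected_type=object):
    # A single str/int key is looked up directly.
    if isinstance(key, (str, int)):
        found = node[key]
    elif len(key) == 0:
        # An empty path designates the node itself.
        found = node
    else:
        # Descend one level, then resolve the rest of the path.
        return resolve(node[key[0]], key[1:], expected_type)
    assert isinstance(found, expected_type), f'{key!r} is {type(found)}, not {expected_type}'
    return found


tree = {'foo': [None] * 43}
tree['foo'][42] = {'bar': 'leaf'}

leaf = resolve(tree, ('foo', 42, 'bar'), str)
```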
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.layers","title":"layers","text":"layers(\n layer_type: type[T], recurse: bool = False\n) -> Iterator[T]\n
Walk the Chain's sub-module tree and yield each layer of the given type.
Parameters:
layer_type (type[T], required): The type of layer to yield.
recurse (bool, default False): Whether to recurse into sub-Chains.
Yields:
T: Each module of the given layer_type.
Source code in src/refiners/fluxion/layers/chain.py
def layers(\n self,\n layer_type: type[T],\n recurse: bool = False,\n) -> Iterator[T]:\n \"\"\"Walk the Chain's sub-module tree and yield each layer of the given type.\n\n Args:\n layer_type: The type of layer to yield.\n recurse: Whether to recurse into sub-Chains.\n\n Yields:\n Each module of the given layer_type.\n \"\"\"\n for module, _ in self.walk(layer_type, recurse):\n yield module\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.pop","title":"pop","text":"pop(index: int = -1) -> Module\n
Pop a module from the chain at the given index.
Parameters:
index (int, default -1): The index of the module to pop.
Returns:
Module: The popped module.
Raises:
IndexError: If the index is out of range.
Source code in src/refiners/fluxion/layers/chain.py
def pop(self, index: int = -1) -> Module:\n \"\"\"Pop a module from the chain at the given index.\n\n Args:\n index: The index of the module to pop.\n\n Returns:\n The popped module.\n\n Raises:\n IndexError: If the index is out of range.\n \"\"\"\n modules = list(self)\n if index < 0:\n index = len(modules) + index\n if index < 0 or index >= len(modules):\n raise IndexError(\"Index out of range.\")\n removed_module = modules.pop(index)\n if isinstance(removed_module, ContextModule):\n removed_module._set_parent(None)\n self._regenerate_keys(modules)\n return removed_module\n
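Unlike `insert`, `pop` normalizes a negative index the usual Python way (`-1` is the last module) and then bounds-checks it explicitly. A plain-Python sketch of the same logic on a list (the function name is hypothetical):

```python
def chain_style_pop(items: list, index: int = -1):
    # Normalize a negative index once, the usual Python way (-1 is the last item).
    if index < 0:
        index = len(items) + index
    # Then reject anything still outside [0, len(items)).
    if index < 0 or index >= len(items):
        raise IndexError('Index out of range.')
    return items.pop(index)


items = ['conv', 'relu', 'linear']
last = chain_style_pop(items)  # removes 'linear'
first = chain_style_pop(items, 0)  # removes 'conv'
```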
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.remove","title":"remove","text":"remove(module: Module) -> None\n
Remove a module from the chain.
Parameters:
module (Module, required): The module to remove.
Raises:
ValueError: If the module is not in the chain.
Source code in src/refiners/fluxion/layers/chain.py
def remove(self, module: Module) -> None:\n \"\"\"Remove a module from the chain.\n\n Args:\n module: The module to remove.\n\n Raises:\n ValueError: If the module is not in the chain.\n \"\"\"\n modules = list(self)\n try:\n modules.remove(module)\n except ValueError:\n raise ValueError(f\"{module} is not in {self}\")\n self._regenerate_keys(modules)\n if isinstance(module, ContextModule):\n module._set_parent(None)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.replace","title":"replace","text":"replace(\n old_module: Module,\n new_module: Module,\n old_module_parent: Chain | None = None,\n) -> None\n
Replace a module in the chain with a new module.
Parameters:
old_module (Module, required): The module to replace.
new_module (Module, required): The module to replace with.
old_module_parent (Chain | None, default None): The parent to assign to the old module once it is replaced. If None, the old module is orphaned.
Raises:
ValueError: If the module is not in the chain.
Source code in src/refiners/fluxion/layers/chain.py
def replace(\n self,\n old_module: Module,\n new_module: Module,\n old_module_parent: \"Chain | None\" = None,\n) -> None:\n \"\"\"Replace a module in the chain with a new module.\n\n Args:\n old_module: The module to replace.\n new_module: The module to replace with.\n old_module_parent: The parent of the old module.\n If None, the old module is orphaned.\n\n Raises:\n ValueError: If the module is not in the chain.\n \"\"\"\n modules = list(self)\n try:\n modules[modules.index(old_module)] = new_module\n except ValueError:\n raise ValueError(f\"{old_module} is not in {self}\")\n self._regenerate_keys(modules)\n if isinstance(new_module, ContextModule):\n new_module._set_parent(self)\n if isinstance(old_module, ContextModule):\n old_module._set_parent(old_module_parent)\n self._register_provider()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.set_context","title":"set_context","text":"set_context(context: str, value: Any) -> None\n
Set a value in the context provider.
Parameters:
context (str, required): The context to update.
value (Any, required): The value to set.
Source code in src/refiners/fluxion/layers/chain.py
def set_context(self, context: str, value: Any) -> None:\n \"\"\"Set a value in the context provider.\n\n Args:\n context: The context to update.\n value: The value to set.\n \"\"\"\n self._provider.set_context(context, value)\n self._register_provider()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.structural_copy","title":"structural_copy","text":"structural_copy() -> TChain\n
Copy the structure of the Chain tree.
This method returns a recursive copy of the Chain tree where all inner nodes (instances of Chain and its subclasses) are duplicated and all leaves (regular Modules) are not.
Such copies can be adapted without disrupting the base model, but do not require extra GPU memory since the weights are in the leaves and hence not copied.
Source code in src/refiners/fluxion/layers/chain.py
def structural_copy(self: TChain) -> TChain:\n \"\"\"Copy the structure of the Chain tree.\n\n This method returns a recursive copy of the Chain tree where all inner nodes\n (instances of Chain and its subclasses) are duplicated and all leaves\n (regular Modules) are not.\n\n Such copies can be adapted without disrupting the base model, but do not\n require extra GPU memory since the weights are in the leaves and hence not copied.\n \"\"\"\n if hasattr(self, \"_pre_structural_copy\"):\n assert callable(self._pre_structural_copy)\n self._pre_structural_copy()\n\n modules = [structural_copy(m) for m in self]\n clone = super().structural_copy()\n clone._provider = ContextProvider.create(clone.init_context())\n\n for module in modules:\n clone.append(module=module)\n\n if hasattr(clone, \"_post_structural_copy\"):\n assert callable(clone._post_structural_copy)\n clone._post_structural_copy(self)\n\n return clone\n
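The sharing behavior described above can be modeled in plain Python: inner nodes are rebuilt recursively while leaves are returned as-is, so the copy can be restructured (for example, by appending an adapter) without touching the original, yet both trees point at the same leaf objects. `Leaf`, `Node` and `structural_copy` below are hypothetical stand-ins, not the Refiners implementation, which also runs the optional `_pre_structural_copy`/`_post_structural_copy` hooks and resets the context provider.

```python
class Leaf:
    # Stand-in for a weighted leaf module, e.g. a Linear layer.
    def __init__(self, name: str) -> None:
        self.name = name


class Node:
    # Stand-in for an inner Chain node.
    def __init__(self, *children) -> None:
        self.children = list(children)


def structural_copy(module):
    # Inner nodes are rebuilt recursively; leaves are returned as-is (shared).
    if isinstance(module, Node):
        return Node(*(structural_copy(child) for child in module.children))
    return module


weights = Leaf('linear')
original = Node(Node(weights), Leaf('relu'))
clone = structural_copy(original)
clone.children.append(Leaf('adapter'))  # restructure the copy only
```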
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Chain.walk","title":"walk","text":"walk(\n predicate: (\n Callable[[Module, Chain], bool] | None\n ) = None,\n recurse: bool = False,\n) -> Iterator[tuple[Module, Chain]]\n
walk(\n predicate: type[T], recurse: bool = False\n) -> Iterator[tuple[T, Chain]]\n
walk(\n predicate: (\n type[T] | Callable[[Module, Chain], bool] | None\n ) = None,\n recurse: bool = False,\n) -> (\n Iterator[tuple[T, Chain]]\n | Iterator[tuple[Module, Chain]]\n)\n
Walk the Chain's sub-module tree and yield each module that matches the predicate.
Parameters:
predicate (type[T] | Callable[[Module, Chain], bool] | None, default None): The predicate to match.
recurse (bool, default False): Whether to recurse into sub-Chains.
Yields:
Iterator[tuple[T, Chain]] | Iterator[tuple[Module, Chain]]: Each (module, parent) pair whose module matches the predicate.
Source code in src/refiners/fluxion/layers/chain.py
def walk(\n self,\n predicate: type[T] | Callable[[Module, \"Chain\"], bool] | None = None,\n recurse: bool = False,\n) -> Iterator[tuple[T, \"Chain\"]] | Iterator[tuple[Module, \"Chain\"]]:\n \"\"\"Walk the Chain's sub-module tree and yield each module that matches the predicate.\n\n Args:\n predicate: The predicate to match.\n recurse: Whether to recurse into sub-Chains.\n\n Yields:\n Each module that matches the predicate.\n \"\"\"\n\n if get_origin(predicate) is not None:\n raise ValueError(f\"subscripted generics cannot be used as predicates\")\n\n if isinstance(predicate, type):\n # if the predicate is a Module type\n # build a predicate function that matches the type\n return self._walk(\n predicate=lambda m, _: isinstance(m, predicate),\n recurse=recurse,\n )\n else:\n return self._walk(\n predicate=predicate,\n recurse=recurse,\n )\n
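The `_walk` helper that performs the traversal is not shown in this reference. The sketch below is a plain-Python model under an assumption: the walk is depth-first, yields `(module, parent)` pairs, always descends into non-matching sub-Chains, and descends into a matching sub-Chain only when `recurse=True`. Under that assumption, `find` and `layers` search the whole tree even with the default `recurse=False`. `Leaf`, `Node` and both functions are hypothetical stand-ins.

```python
class Leaf:
    pass


class Node:
    def __init__(self, *children) -> None:
        self.children = list(children)


def _walk(node, predicate, recurse):
    for child in node.children:
        if predicate(child, node):
            yield child, node
            if not recurse:
                # Assumed: a matching sub-tree is explored only if recurse=True.
                continue
        if isinstance(child, Node):
            # Assumed: non-matching inner nodes are always explored.
            yield from _walk(child, predicate, recurse)


def walk(node, predicate=None, recurse=False):
    # A type predicate becomes an isinstance check, as in Chain.walk.
    if isinstance(predicate, type):
        wanted = predicate
        predicate = lambda module, parent: isinstance(module, wanted)
    elif predicate is None:
        predicate = lambda module, parent: True
    return _walk(node, predicate, recurse)


leaf_a, leaf_b = Leaf(), Leaf()
deep = Node()
inner = Node(leaf_b, deep)
root = Node(leaf_a, inner)
```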
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Concatenate","title":"Concatenate","text":"Concatenate(*modules: Module, dim: int = 0)\n
Bases: Chain
Concatenation layer.
This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.
Example:
concatenate = fl.Concatenate(\n fl.Linear(32, 128),\n fl.Linear(32, 128),\n dim=1,\n)\n\ntensor = torch.randn(2, 32)\noutput = concatenate(tensor)\n\nassert output.shape == (2, 256)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *modules: Module, dim: int = 0) -> None:\n super().__init__(*modules)\n self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule","title":"ContextModule","text":"ContextModule(*args: Any, **kwargs: Any)\n
Bases: Module
A module containing a ContextProvider.
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs)\n self._parent = []\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.ensure_parent","title":"ensure_parent property
","text":"ensure_parent: Chain\n
Return the module's parent, or raise an error if the module is an orphan.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.parent","title":"parentproperty
","text":"parent: Chain | None\n
Return the module's parent, or None if the module is an orphan.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.provider","title":"providerproperty
","text":"provider: ContextProvider\n
Return the module's context provider.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.get_parents","title":"get_parents","text":"get_parents() -> list[Chain]\n
Recursively retrieve the module's parents.
Source code in src/refiners/fluxion/layers/module.py
def get_parents(self) -> \"list[Chain]\":\n \"\"\"Recursively retrieve the module's parents.\"\"\"\n return self._parent + self._parent[0].get_parents() if self._parent else []\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.get_path","title":"get_path","text":"get_path(\n parent: Chain | None = None, top: Module | None = None\n) -> str\n
Get the path of the module in the chain.
Parameters:
parent (Chain | None, default None): The parent of the module in the chain. If None, the module's own parent is used.
top (Module | None, default None): The top module of the chain. If None, the path will be relative to the root of the chain.
Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: \"Chain | None\" = None, top: \"Module | None\" = None) -> str:\n \"\"\"Get the path of the module in the chain.\n\n Args:\n parent: The parent of the module in the chain.\n top: The top module of the chain.\n If None, the path will be relative to the root of the chain.\n \"\"\"\n\n return super().get_path(parent=parent or self.parent, top=top)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ContextModule.use_context","title":"use_context","text":"use_context(context_name: str) -> Context\n
Retrieve the context object from the module's context provider.
Source code in src/refiners/fluxion/layers/module.py
def use_context(self, context_name: str) -> Context:\n \"\"\"Retrieve the context object from the module's context provider.\"\"\"\n context = self.provider.get_context(context_name)\n assert context is not None, f\"Context {context_name} not found.\"\n return context\n
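The contract of `use_context` is: look the name up in the provider and fail loudly if nothing is registered under it. The sketch below models this with a hypothetical, minimal provider that stores each named context in a dict; it is not the Refiners `ContextProvider`, whose internals are not shown here.

```python
from typing import Any


class MinimalContextProvider:
    # Hypothetical stand-in: maps a context name to its value (often a dict).
    def __init__(self) -> None:
        self.contexts: dict[str, Any] = {}

    def set_context(self, key: str, value: Any) -> None:
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        # Returns None for an unknown context name.
        return self.contexts.get(key)


def use_context(provider: MinimalContextProvider, context_name: str) -> Any:
    # Same check as ContextModule.use_context: a missing context is an error.
    context = provider.get_context(context_name)
    assert context is not None, f'Context {context_name} not found.'
    return context


provider = MinimalContextProvider()
provider.set_context('sampling', {'shapes': []})
use_context(provider, 'sampling')['shapes'].append((64, 64))
```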
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Conv2d","title":"Conv2d","text":"Conv2d(\n in_channels: int,\n out_channels: int,\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] = (1, 1),\n padding: int | tuple[int, int] | str = (0, 0),\n groups: int = 1,\n use_bias: bool = True,\n dilation: int | tuple[int, int] = (1, 1),\n padding_mode: str = \"zeros\",\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Conv2d, WeightedModule
2D Convolutional layer.
This layer wraps torch.nn.Conv2d.
Receives:
Real[Tensor, 'batch in_channels in_height in_width']
Returns:
Real[Tensor, 'batch out_channels out_height out_width']
Example:
conv2d = fl.Conv2d(\n in_channels=3,\n out_channels=32,\n kernel_size=3,\n stride=1,\n padding=1,\n)\n\ntensor = torch.randn(2, 3, 128, 128)\noutput = conv2d(tensor)\n\nassert output.shape == (2, 32, 128, 128)\n
Source code in src/refiners/fluxion/layers/conv.py
def __init__(\n self,\n in_channels: int,\n out_channels: int,\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] = (1, 1),\n padding: int | tuple[int, int] | str = (0, 0),\n groups: int = 1,\n use_bias: bool = True,\n dilation: int | tuple[int, int] = (1, 1),\n padding_mode: str = \"zeros\",\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__( # type: ignore\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n dilation=dilation,\n groups=groups,\n bias=use_bias,\n padding_mode=padding_mode,\n device=device,\n dtype=dtype,\n )\n self.use_bias = use_bias\n
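For integer padding, the spatial output size of this layer follows the standard `torch.nn.Conv2d` formula, applied independently to height and width: `out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1`. The helper below (a hypothetical name) computes it for one dimension; it reproduces the 128 to 128 example above and the 28 to 26 reduction of an unpadded 3x3 convolution. String padding such as `'same'` is not handled.

```python
def conv2d_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # Standard convolution output size for one spatial dimension (integer padding only).
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


same_size = conv2d_out_size(128, kernel_size=3, stride=1, padding=1)
unpadded = conv2d_out_size(28, kernel_size=3)
halved = conv2d_out_size(128, kernel_size=3, stride=2, padding=1)
```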
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ConvTranspose2d","title":"ConvTranspose2d","text":"ConvTranspose2d(\n in_channels: int,\n out_channels: int,\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] = 1,\n padding: int | tuple[int, int] = 0,\n output_padding: int | tuple[int, int] = 0,\n groups: int = 1,\n use_bias: bool = True,\n dilation: int | tuple[int, int] = 1,\n padding_mode: str = \"zeros\",\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ConvTranspose2d, WeightedModule
2D Transposed Convolutional layer.
This layer wraps torch.nn.ConvTranspose2d.
Receives:
Real[Tensor, 'batch in_channels in_height in_width']
Returns:
Real[Tensor, 'batch out_channels out_height out_width']
Example:
conv_transpose2d = fl.ConvTranspose2d(\n in_channels=3,\n out_channels=32,\n kernel_size=3,\n stride=1,\n padding=1,\n)\n\ntensor = torch.randn(2, 3, 128, 128)\noutput = conv_transpose2d(tensor)\n\nassert output.shape == (2, 32, 128, 128)\n
Source code in src/refiners/fluxion/layers/conv.py
def __init__(\n self,\n in_channels: int,\n out_channels: int,\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] = 1,\n padding: int | tuple[int, int] = 0,\n output_padding: int | tuple[int, int] = 0,\n groups: int = 1,\n use_bias: bool = True,\n dilation: int | tuple[int, int] = 1,\n padding_mode: str = \"zeros\",\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__( # type: ignore\n in_channels=in_channels,\n out_channels=out_channels,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n output_padding=output_padding,\n dilation=dilation,\n groups=groups,\n bias=use_bias,\n padding_mode=padding_mode,\n device=device,\n dtype=dtype,\n )\n self.use_bias = use_bias\n
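The spatial output size of a transposed convolution follows the formula given in the `torch.nn.ConvTranspose2d` documentation: `out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`. The helper below (a hypothetical name) checks the 128 to 128 example above and a typical 2x upsampling configuration.

```python
def conv_transpose2d_out_size(
    size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
    dilation: int = 1,
) -> int:
    # Transposed convolution output size for one spatial dimension.
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


same_size = conv_transpose2d_out_size(128, kernel_size=3, stride=1, padding=1)
doubled = conv_transpose2d_out_size(64, kernel_size=4, stride=2, padding=1)
doubled_odd_kernel = conv_transpose2d_out_size(64, kernel_size=3, stride=2, padding=1, output_padding=1)
```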
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Converter","title":"Converter","text":"Converter(set_device: bool = True, set_dtype: bool = True)\n
Bases: ContextModule
A Converter class that adjusts tensor properties based on a parent module's settings.
This class inherits from ContextModule and adjusts the device and dtype of its input tensor(s) to match those of its parent module.
If set_device (respectively set_dtype) is True, the parent module must have a device (respectively dtype) attribute.
Parameters:
set_device (bool, default True): If True, matches the device of the input tensor(s) to the parent's device.
set_dtype (bool, default True): If True, matches the dtype of the input tensor(s) to the parent's dtype.
Source code in src/refiners/fluxion/layers/converter.py
def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:\n \"\"\"Initializes the Converter layer.\n\n Args:\n set_device: If True, matches the device of the input tensor(s) to the parent's device.\n set_dtype: If True, matches the dtype of the input tensor(s) to the parent's dtype.\n \"\"\"\n super().__init__()\n self.set_device = set_device\n self.set_dtype = set_dtype\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Cos","title":"Cos","text":"Cos(*args: Any, **kwargs: Any)\n
Bases: Module
Cosine operator layer.
This layer applies the cosine function to the input tensor. See also torch.cos.
cos = fl.Cos()\n\ntensor = torch.tensor([0, torch.pi])\noutput = cos(tensor)\n\nexpected_output = torch.tensor([1.0, -1.0])\nassert torch.allclose(output, expected_output, atol=1e-6)\n
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Distribute","title":"Distribute","text":"Distribute(*args: Module | Iterable[Module])\n
Bases: Chain
Distribute layer.
This layer calls its sub-modules in parallel with their respective inputs, and returns a tuple of their outputs.
Example:
distribute = fl.Distribute(\n fl.Linear(32, 128),\n fl.Linear(64, 256),\n)\n\ntensor1 = torch.randn(2, 32)\ntensor2 = torch.randn(4, 64)\noutputs = distribute(tensor1, tensor2)\n\nassert len(outputs) == 2\nassert outputs[0].shape == (2, 128)\nassert outputs[1].shape == (4, 256)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Downsample","title":"Downsample","text":"Downsample(\n channels: int,\n scale_factor: int,\n padding: int = 0,\n register_shape: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Downsample layer.
This layer downsamples the input by the given scale factor.
Raises:
RuntimeError: If the sampling context is not set, or if it does not contain a list.
Parameters:
channels (int, required): The number of input and output channels.
scale_factor (int, required): The factor by which to downsample the input.
padding (int, default 0): The amount of zero-padding added to both sides of the input. If 0, the input is instead zero-padded by one row and one column on the bottom and right before the convolution.
register_shape (bool, default True): If True, registers the input shape in the sampling context.
device (device | str | None, default None): The device to use for the convolutional layer.
dtype (dtype | None, default None): The dtype to use for the convolutional layer.
Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n self,\n channels: int,\n scale_factor: int,\n padding: int = 0,\n register_shape: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n):\n \"\"\"Initializes the Downsample layer.\n\n Args:\n channels: The number of input and output channels.\n scale_factor: The factor by which to downsample the input.\n padding: The amount of zero-padding added to both sides of the input.\n register_shape: If True, registers the input shape in the context.\n device: The device to use for the convolutional layer.\n dtype: The dtype to use for the convolutional layer.\n \"\"\"\n self.channels = channels\n self.in_channels = channels\n self.out_channels = channels\n self.scale_factor = scale_factor\n self.padding = padding\n\n super().__init__(\n Conv2d(\n in_channels=channels,\n out_channels=channels,\n kernel_size=3,\n stride=scale_factor,\n padding=padding,\n device=device,\n dtype=dtype,\n ),\n )\n\n if padding == 0:\n zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))\n self.insert(\n index=0,\n module=Lambda(func=zero_pad),\n )\n\n if register_shape:\n self.insert(\n index=0,\n module=SetContext(\n context=\"sampling\",\n key=\"shapes\",\n callback=self.register_shape,\n ),\n )\n
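As a rough plain-PyTorch sketch of what the chain above computes when padding is 0, the layer pads one pixel on the right and bottom, then applies a 3x3 convolution whose stride equals scale_factor. The SetContext shape-registration step is deliberately left out (it needs a sampling context), and the function name downsample_sketch is hypothetical:

```python
import torch
from torch import nn
from torch.nn.functional import pad


def downsample_sketch(x: torch.Tensor, conv: nn.Conv2d, padding: int = 0) -> torch.Tensor:
    # Mirrors the source above: with padding == 0, pad right and bottom by one pixel
    if padding == 0:
        x = pad(x, (0, 1, 0, 1))
    # Strided 3x3 convolution: the stride plays the role of scale_factor
    return conv(x)


channels, scale_factor = 4, 2
conv = nn.Conv2d(channels, channels, kernel_size=3, stride=scale_factor, padding=0)
x = torch.randn(1, channels, 32, 32)
y = downsample_sketch(x, conv)
print(y.shape)  # torch.Size([1, 4, 16, 16])
```

Without the extra pixel, a 32x32 input would give (32 - 3) // 2 + 1 = 15 rows and columns; the asymmetric pad keeps the output at exactly 32 / 2 = 16.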
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Embedding","title":"Embedding","text":"Embedding(\n num_embeddings: int,\n embedding_dim: int,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Embedding, WeightedModule
Embedding layer.
This layer wraps torch.nn.Embedding.
Receives: Int[Tensor, 'batch length']
Returns: Float[Tensor, 'batch length embedding_dim']
Example embedding = fl.Embedding(\n num_embeddings=10,\n embedding_dim=128\n)\n\ntensor = torch.randint(0, 10, (2, 10))\noutput = embedding(tensor)\n\nassert output.shape == (2, 10, 128)\n
Parameters:
num_embeddings (int, required): The number of embeddings.
embedding_dim (int, required): The dimension of the embeddings.
device (device | str | None, default None): The device to use for the embedding layer.
dtype (dtype | None, default None): The dtype to use for the embedding layer.
Source code in src/refiners/fluxion/layers/embedding.py
def __init__(\n self,\n num_embeddings: int,\n embedding_dim: int,\n device: Device | str | None = None,\n dtype: DType | None = None,\n):\n \"\"\"Initializes the Embedding layer.\n\n Args:\n num_embeddings: The number of embeddings.\n embedding_dim: The dimension of the embeddings.\n device: The device to use for the embedding layer.\n dtype: The dtype to use for the embedding layer.\n \"\"\"\n _Embedding.__init__( # type: ignore\n self,\n num_embeddings=num_embeddings,\n embedding_dim=embedding_dim,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Flatten","title":"Flatten","text":"Flatten(start_dim: int = 0, end_dim: int = -1)\n
Bases: Module
Flatten operation layer.
This layer flattens the input tensor between the given dimensions. See also torch.flatten.
flatten = fl.Flatten(start_dim=1)\n\ntensor = torch.randn(10, 10, 10)\noutput = flatten(tensor)\n\nassert output.shape == (10, 100)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n self,\n start_dim: int = 0,\n end_dim: int = -1,\n) -> None:\n super().__init__()\n self.start_dim = start_dim\n self.end_dim = end_dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GLU","title":"GLU","text":"GLU(activation: Activation)\n
Bases: Activation
Gated Linear Unit activation function.
See [arXiv:2002.05202] GLU Variants Improve Transformer for more details.
Example
glu = fl.GLU(fl.ReLU())\ntensor = torch.tensor([[1.0, 0.0, -1.0, 1.0]])\noutput = glu(tensor)\nassert torch.allclose(output, torch.tensor([0.0, 0.0]))\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self, activation: Activation) -> None:\n super().__init__()\n self.activation = activation\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GeLU","title":"GeLU","text":"GeLU(approximation: GeLUApproximation = NONE)\n
Bases: Activation
Gaussian Error Linear Unit activation function.
This activation can be expensive to compute; a few approximations are available (see GeLUApproximation).
See [arXiv:1606.08415] Gaussian Error Linear Units for more details.
Example
gelu = fl.GeLU()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = gelu(tensor)\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(\n self,\n approximation: GeLUApproximation = GeLUApproximation.NONE,\n) -> None:\n super().__init__()\n self.approximation = approximation\n
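The three GeLUApproximation options trade accuracy for speed. As a plain-Python sketch, assuming the standard formulas from the GELU paper (exact x * Phi(x) via the error function, the tanh approximation, and x * sigmoid(1.702 * x)), the snippet below measures how far each approximation drifts from the exact value on [-4, 4]. Whether the GeLU layer uses exactly these constants internally is an assumption, and the function names are hypothetical:

```python
import math


def gelu_exact(x: float) -> float:
    # x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation of the normal CDF
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def gelu_sigmoid(x: float) -> float:
    # Sigmoid approximation: x * sigmoid(1.702 * x)
    return x / (1.0 + math.exp(-1.702 * x))


xs = [i / 10 for i in range(-40, 41)]
max_tanh_err = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in xs)
max_sigmoid_err = max(abs(gelu_sigmoid(x) - gelu_exact(x)) for x in xs)
print(f'tanh: {max_tanh_err:.5f}, sigmoid: {max_sigmoid_err:.5f}')
```

On this grid the tanh form stays within 1e-3 of the exact value, while the sigmoid form is off by up to about 0.02 near |x| = 2.3, which is the price of its cheaper evaluation.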
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GeLUApproximation","title":"GeLUApproximation","text":" Bases: Enum
Approximation methods for the Gaussian Error Linear Unit activation function.
Attributes:
NONE: No approximation; use the original formula.
TANH: Use the tanh approximation.
SIGMOID: Use the sigmoid approximation.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GetArg","title":"GetArg","text":"GetArg(index: int)\n
Bases: Module
GetArg operation layer.
This layer returns the nth tensor of the input arguments.
Example
get_arg = fl.GetArg(1)\n\ninputs = (\n torch.randn(10, 10),\n torch.randn(20, 20),\n torch.randn(30, 30),\n)\noutput = get_arg(*inputs)\n\nassert id(inputs[1]) == id(output)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, index: int) -> None:\n super().__init__()\n self.index = index\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.GroupNorm","title":"GroupNorm","text":"GroupNorm(\n channels: int,\n num_groups: int,\n eps: float = 1e-05,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: GroupNorm, WeightedModule
Group Normalization layer.
This layer wraps torch.nn.GroupNorm.
Receives: Float[Tensor, 'batch channels *normalized_shape']
Returns: Float[Tensor, 'batch channels *normalized_shape']
Example groupnorm = fl.GroupNorm(channels=128, num_groups=8)\n\ntensor = torch.randn(2, 128, 8)\noutput = groupnorm(tensor)\n\nassert output.shape == (2, 128, 8)\n
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n self,\n channels: int,\n num_groups: int,\n eps: float = 1e-5,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__( # type: ignore\n num_groups=num_groups,\n num_channels=channels,\n eps=eps,\n affine=True, # otherwise not a WeightedModule\n device=device,\n dtype=dtype,\n )\n self.channels = channels\n self.num_groups = num_groups\n self.eps = eps\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Identity","title":"Identity","text":"Identity()\n
Bases: Module
Identity operator layer.
This layer simply returns the input tensor.
Example
identity = fl.Identity()\n\ntensor = torch.randn(10, 10)\noutput = identity(tensor)\n\nassert torch.equal(tensor, output)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self) -> None:\n super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.InstanceNorm2d","title":"InstanceNorm2d","text":"InstanceNorm2d(\n num_features: int,\n eps: float = 1e-05,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: InstanceNorm2d, Module
Instance Normalization layer.
This layer wraps torch.nn.InstanceNorm2d.
Receives: Float[Tensor, 'batch channels height width']
Returns: Float[Tensor, 'batch channels height width']
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n self,\n num_features: int,\n eps: float = 1e-05,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__( # type: ignore\n num_features=num_features,\n eps=eps,\n device=device,\n dtype=dtype,\n )\n
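The constructor above passes neither affine nor track_running_stats, so the wrapped torch.nn.InstanceNorm2d keeps its defaults (both False): it normalizes each sample and channel over its spatial dimensions, with no learned scale or shift. A minimal sketch checking that against a manual computation, in plain PyTorch:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
eps = 1e-5

# Per-sample, per-channel statistics over height and width (biased variance)
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

reference = nn.InstanceNorm2d(num_features=3, eps=eps)(x)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```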
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Interpolate","title":"Interpolate","text":"Interpolate(mode: str = 'nearest', antialias: bool = False)\n
Bases: Module
Interpolate layer.
This layer wraps torch.nn.functional.interpolate.
Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n self,\n mode: str = \"nearest\",\n antialias: bool = False,\n) -> None:\n super().__init__()\n self.mode = mode\n self.antialias = antialias\n
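The forward signature of Interpolate is not shown above. As an assumption-laden sketch, the two stored attributes (mode and antialias) map onto the keyword arguments of the same name in torch.nn.functional.interpolate; the target sizes used here are purely illustrative:

```python
import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 2, 2)

# mode='nearest' with antialias=False: the defaults stored by the layer
up = F.interpolate(x, size=(4, 4), mode='nearest', antialias=False)
# Each input pixel is repeated into a 2x2 block
print(up[0, 0])

# antialias only matters for modes such as 'bilinear' and 'bicubic' when downsampling
down = F.interpolate(torch.randn(1, 3, 64, 64), size=(16, 16), mode='bilinear', antialias=True)
print(down.shape)  # torch.Size([1, 3, 16, 16])
```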
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Lambda","title":"Lambda","text":"Lambda(func: Callable[..., Any])\n
Bases: Module
Lambda layer.
This layer wraps a Callable.
When called, it calls the Callable with the given arguments and returns its output.
Example
lambda_layer = fl.Lambda(lambda x: x + 1)\n\ntensor = torch.tensor([1, 2, 3])\noutput = lambda_layer(tensor)\n\nexpected_output = torch.tensor([2, 3, 4])\nassert torch.allclose(output, expected_output)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, func: Callable[..., Any]) -> None:\n super().__init__()\n self.func = func\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.LayerNorm","title":"LayerNorm","text":"LayerNorm(\n normalized_shape: int | list[int],\n eps: float = 1e-05,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: LayerNorm, WeightedModule
Layer Normalization layer.
This layer wraps torch.nn.LayerNorm.
Receives: Float[Tensor, 'batch *normalized_shape']
Returns: Float[Tensor, 'batch *normalized_shape']
Example layernorm = fl.LayerNorm(normalized_shape=128)\n\ntensor = torch.randn(2, 128)\noutput = layernorm(tensor)\n\nassert output.shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n self,\n normalized_shape: int | list[int],\n eps: float = 0.00001,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__( # type: ignore\n normalized_shape=normalized_shape,\n eps=eps,\n elementwise_affine=True, # otherwise not a WeightedModule\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.LayerNorm2d","title":"LayerNorm2d","text":"LayerNorm2d(\n channels: int,\n eps: float = 1e-06,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: WeightedModule
2D Layer Normalization layer.
This layer applies Layer Normalization along the channel dimension (the 2nd dimension) of a 4D tensor.
Receives: Float[Tensor, 'batch channels height width']
Returns: Float[Tensor, 'batch channels height width']
Source code in src/refiners/fluxion/layers/norm.py
def __init__(\n self,\n channels: int,\n eps: float = 1e-6,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__()\n self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))\n self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))\n self.eps = eps\n
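The forward pass is not shown above. Assuming the common formulation of this layer (normalize each spatial position across channels, then apply the per-channel weight and bias created in the constructor), the sketch below checks it against torch.nn.functional.layer_norm applied in channels-last layout. The function name layer_norm_2d is hypothetical:

```python
import torch
import torch.nn.functional as F


def layer_norm_2d(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Statistics over the channel dimension only, separately for each (batch, h, w) position
    mean = x.mean(dim=1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=1, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Per-channel affine parameters, broadcast over height and width
    return weight[:, None, None] * x_hat + bias[:, None, None]


torch.manual_seed(0)
x = torch.randn(2, 8, 5, 5)
weight, bias = torch.rand(8), torch.rand(8)

out = layer_norm_2d(x, weight, bias)
# Same computation: move channels last, apply standard LayerNorm over C, move back
reference = F.layer_norm(x.permute(0, 2, 3, 1), (8,), weight, bias, eps=1e-6).permute(0, 3, 1, 2)
print(torch.allclose(out, reference, atol=1e-5))  # True
```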
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Linear","title":"Linear","text":"Linear(\n in_features: int,\n out_features: int,\n bias: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Linear, WeightedModule
Linear layer.
This layer wraps torch.nn.Linear.
Receives: Float[Tensor, 'batch in_features']
Returns: Float[Tensor, 'batch out_features']
Example linear = fl.Linear(in_features=32, out_features=128)\n\ntensor = torch.randn(2, 32)\noutput = linear(tensor)\n\nassert output.shape == (2, 128)\n
Parameters:
in_features (int, required): The number of input features.
out_features (int, required): The number of output features.
bias (bool, default True): If True, adds a learnable bias to the output.
device (device | str | None, default None): The device to use for the linear layer.
dtype (dtype | None, default None): The dtype to use for the linear layer.
Source code in src/refiners/fluxion/layers/linear.py
def __init__(\n self,\n in_features: int,\n out_features: int,\n bias: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initializes the Linear layer.\n\n Args:\n in_features: The number of input features.\n out_features: The number of output features.\n bias: If True, adds a learnable bias to the output.\n device: The device to use for the linear layer.\n dtype: The dtype to use for the linear layer.\n \"\"\"\n self.in_features = in_features\n self.out_features = out_features\n super().__init__( # type: ignore\n in_features=in_features,\n out_features=out_features,\n bias=bias,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Matmul","title":"Matmul","text":"Matmul(input: Module, other: Module)\n
Bases: Chain
Matrix multiplication layer.
This layer returns the matrix multiplication of the outputs of its two sub-modules.
Example
matmul = fl.Matmul(\n fl.Identity(),\n fl.Multiply(scale=2),\n)\n\ntensor = torch.randn(10, 10)\noutput = matmul(tensor)\n\nexpected_output = tensor @ (2 * tensor)\nassert torch.allclose(output, expected_output)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, input: Module, other: Module) -> None:\n super().__init__(\n input,\n other,\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MaxPool1d","title":"MaxPool1d","text":"MaxPool1d(\n kernel_size: int,\n stride: int | None = None,\n padding: int = 0,\n dilation: int = 1,\n return_indices: bool = False,\n ceil_mode: bool = False,\n)\n
Bases: MaxPool1d, Module
MaxPool1d layer.
This layer wraps torch.nn.MaxPool1d.
Receives: Float[Tensor, 'batch channels in_length']
Returns: Float[Tensor, 'batch channels out_length']
Parameters:
kernel_size (int, required): The size of the sliding window.
stride (int | None, default None): The stride of the sliding window. If None, it defaults to kernel_size.
padding (int, default 0): The amount of implicit negative-infinity padding added to both sides of the input.
dilation (int, default 1): The spacing between kernel elements.
return_indices (bool, default False): If True, returns the max indices along with the outputs.
ceil_mode (bool, default False): If True, uses ceil instead of floor to compute the output shape.
Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(\n self,\n kernel_size: int,\n stride: int | None = None,\n padding: int = 0,\n dilation: int = 1,\n return_indices: bool = False,\n ceil_mode: bool = False,\n) -> None:\n \"\"\"Initializes the MaxPool1d layer.\n\n Args:\n kernel_size: The size of the sliding window.\n stride: The stride of the sliding window.\n padding: The amount of zero-padding added to both sides of the input.\n dilation: The spacing between kernel elements.\n return_indices: If True, returns the max indices along with the outputs.\n ceil_mode: If True, uses ceil instead of floor to compute the output shape.\n \"\"\"\n super().__init__(\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n dilation=dilation,\n return_indices=return_indices,\n ceil_mode=ceil_mode,\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MaxPool2d","title":"MaxPool2d","text":"MaxPool2d(\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] | None = None,\n padding: int | tuple[int, int] = (0, 0),\n dilation: int | tuple[int, int] = (1, 1),\n return_indices: bool = False,\n ceil_mode: bool = False,\n)\n
Bases: MaxPool2d, Module
MaxPool2d layer.
This layer wraps torch.nn.MaxPool2d.
Receives: Float[Tensor, 'batch channels in_height in_width']
Returns: Float[Tensor, 'batch channels out_height out_width']
Parameters:
kernel_size (int | tuple[int, int], required): The size of the sliding window.
stride (int | tuple[int, int] | None, default None): The stride of the sliding window. If None, it defaults to kernel_size.
padding (int | tuple[int, int], default (0, 0)): The amount of implicit negative-infinity padding added to both sides of the input.
dilation (int | tuple[int, int], default (1, 1)): The spacing between kernel elements.
return_indices (bool, default False): If True, returns the max indices along with the outputs.
ceil_mode (bool, default False): If True, uses ceil instead of floor to compute the output shape.
Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(\n self,\n kernel_size: int | tuple[int, int],\n stride: int | tuple[int, int] | None = None,\n padding: int | tuple[int, int] = (0, 0),\n dilation: int | tuple[int, int] = (1, 1),\n return_indices: bool = False,\n ceil_mode: bool = False,\n) -> None:\n \"\"\"Initializes the MaxPool2d layer.\n\n Args:\n kernel_size: The size of the sliding window.\n stride: The stride of the sliding window.\n padding: The amount of zero-padding added to both sides of the input.\n dilation: The spacing between kernel elements.\n return_indices: If True, returns the max indices along with the outputs.\n ceil_mode: If True, uses ceil instead of floor to compute the output shape.\n \"\"\"\n super().__init__(\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n dilation=dilation,\n return_indices=return_indices,\n ceil_mode=ceil_mode,\n )\n
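The output shapes above depend on the pooling arguments. Per spatial dimension, PyTorch documents the output size as floor((L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1), with ceil in place of floor when ceil_mode=True, plus a guard that the last window must start inside the input or the left padding. A pure-Python sketch of that formula (the function name pool_out_len is hypothetical), which applies equally to MaxPool1d and to each dimension of MaxPool2d:

```python
from __future__ import annotations

import math


def pool_out_len(
    length: int,
    kernel_size: int,
    stride: int | None = None,
    padding: int = 0,
    dilation: int = 1,
    ceil_mode: bool = False,
) -> int:
    # stride defaults to kernel_size, as in torch.nn.MaxPool1d / MaxPool2d
    stride = kernel_size if stride is None else stride
    numerator = length + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        out = math.ceil(numerator / stride) + 1
        # A window that would start in the right padding is dropped
        if (out - 1) * stride >= length + padding:
            out -= 1
    else:
        out = numerator // stride + 1
    return out


# MaxPool2d applies the formula independently to height and width
print(pool_out_len(26, kernel_size=2))                  # 13
print(pool_out_len(10, kernel_size=3))                  # 3
print(pool_out_len(10, kernel_size=3, ceil_mode=True))  # 4
```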
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module","title":"Module","text":"Module(*args: Any, **kwargs: Any)\n
Bases: Module
A wrapper around torch.nn.Module.
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.basic_attributes","title":"basic_attributes","text":"basic_attributes(\n init_attrs_only: bool = False,\n) -> dict[str, BasicType | Sequence[BasicType]]\n
Return a dictionary of basic attributes of the module.
Parameters:
init_attrs_only (bool, default False): If True, only return attributes that correspond to constructor parameters and whose values differ from the parameters' defaults.
Source code in src/refiners/fluxion/layers/module.py
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:\n \"\"\"Return a dictionary of basic attributes of the module.\n\n Args:\n init_attrs_only: Whether to only return attributes that are passed to the module's constructor.\n \"\"\"\n sig = signature(obj=self.__init__)\n init_params = set(sig.parameters.keys()) - {\"self\"}\n default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}\n\n def is_basic_attribute(key: str, value: Any) -> bool:\n if key.startswith(\"_\"):\n return False\n\n if isinstance(value, BasicType):\n return True\n\n if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):\n return True\n\n return False\n\n return {\n key: value\n for key, value in self.__dict__.items()\n if is_basic_attribute(key=key, value=value)\n and (not init_attrs_only or (key in init_params and value != default_values.get(key)))\n }\n
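To illustrate the init_attrs_only filter, here is a self-contained re-implementation of the same logic on a hypothetical ToyModule class (plain Python, no torch). The BasicType tuple below is an assumption standing in for the alias used in the source, whose definition is not shown:

```python
from collections.abc import Sequence
from inspect import Parameter, signature
from typing import Any

# Assumption: stand-in for the BasicType alias referenced in the source above
BasicType = (str, int, float, bool)


class ToyModule:
    # Hypothetical class used only to exercise the filtering logic
    def __init__(self, channels: int, eps: float = 1e-5, name: str = 'toy') -> None:
        self.channels = channels
        self.eps = eps
        self.name = name
        self.extra = 3  # not a constructor parameter
        self._cache = 'hidden'  # private: always skipped

    def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, Any]:
        sig = signature(self.__init__)
        init_params = set(sig.parameters) - {'self'}
        defaults = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}

        def is_basic(key: str, value: Any) -> bool:
            if key.startswith('_'):
                return False
            if isinstance(value, BasicType):
                return True
            return isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in value)

        return {
            key: value
            for key, value in self.__dict__.items()
            if is_basic(key, value)
            and (not init_attrs_only or (key in init_params and value != defaults.get(key)))
        }


print(ToyModule(8).basic_attributes())
# {'channels': 8, 'eps': 1e-05, 'name': 'toy', 'extra': 3}
print(ToyModule(8, eps=1e-6).basic_attributes(init_attrs_only=True))
# {'channels': 8, 'eps': 1e-06}
```

With init_attrs_only=True, attributes left at their default (name here) and attributes not in the constructor signature (extra) are dropped, which keeps the result to what is needed to rebuild the module.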
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.get_path","title":"get_path","text":"get_path(\n parent: Chain | None = None, top: Module | None = None\n) -> str\n
Get the path of the module in the chain.
Parameters:
parent (Chain | None, default None): The parent of the module in the chain.
top (Module | None, default None): The top module of the chain. If None, the path will be relative to the root of the chain.
Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: \"Chain | None\" = None, top: \"Module | None\" = None) -> str:\n \"\"\"Get the path of the module in the chain.\n\n Args:\n parent: The parent of the module in the chain.\n top: The top module of the chain.\n If None, the path will be relative to the root of the chain.\n \"\"\"\n if (parent is None) or (self == top):\n return self.__class__.__name__\n for k, m in parent._modules.items(): # type: ignore\n if m is self:\n return parent.get_path(parent=parent.parent, top=top) + \".\" + k\n raise ValueError(f\"{self} not found in {parent}\")\n
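get_path builds a dotted path by walking up through each parent and looking up the key under which the module is registered. A self-contained sketch of the same recursion on a hypothetical Node class (plain Python). The class names and keys are illustrative only, since real Chain keys are generated by Refiners itself:

```python
from __future__ import annotations


class Node:
    # Hypothetical stand-in for a Chain/Module: children are stored by key, like _modules
    def __init__(self, **children: Node) -> None:
        self.parent: Node | None = None
        self._modules: dict[str, Node] = dict(children)
        for child in children.values():
            child.parent = self

    def get_path(self, parent: Node | None = None, top: Node | None = None) -> str:
        if parent is None or self is top:
            return type(self).__name__
        for key, module in parent._modules.items():
            if module is self:
                return parent.get_path(parent=parent.parent, top=top) + '.' + key
        raise ValueError(f'{self} not found in {parent}')


class Encoder(Node): ...
class Block(Node): ...
class Linear(Node): ...


linear = Linear()
block = Block(Linear_1=linear)
encoder = Encoder(Block=block)

print(linear.get_path(parent=linear.parent))             # Encoder.Block.Linear_1
print(linear.get_path(parent=linear.parent, top=block))  # Block.Linear_1
```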
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.load_from_safetensors","title":"load_from_safetensors","text":"load_from_safetensors(\n tensors_path: str | Path, strict: bool = True\n) -> T\n
Load the module's state from a SafeTensors file.
Parameters:
tensors_path (str | Path, required): The path to the SafeTensors file.
strict (bool, default True): Whether to raise an error if the content of the SafeTensors file doesn't map perfectly to the module's state.
Returns:
T: The module, with its state loaded from the SafeTensors file.
Source code in src/refiners/fluxion/layers/module.py
def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T:\n \"\"\"Load the module's state from a SafeTensors file.\n\n Args:\n tensors_path: The path to the SafeTensors file.\n strict: Whether to raise an error if the SafeTensors's\n content doesn't map perfectly to the module's state.\n\n Returns:\n The module, with its state loaded from the SafeTensors file.\n \"\"\"\n state_dict = load_from_safetensors(tensors_path)\n self.load_state_dict(state_dict, strict=strict)\n return self\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.named_modules","title":"named_modules","text":"named_modules(\n *args: Any, **kwargs: Any\n) -> Generator[tuple[str, Module], None, None]\n
Get all the sub-modules of the module.
Returns:
Generator[tuple[str, Module], None, None]: An iterator over all the sub-modules of the module, as (name, module) pairs.
Source code in src/refiners/fluxion/layers/module.py
def named_modules(self, *args: Any, **kwargs: Any) -> \"Generator[tuple[str, Module], None, None]\": # type: ignore\n \"\"\"Get all the sub-modules of the module.\n\n Returns:\n An iterator over all the sub-modules of the module.\n \"\"\"\n return super().named_modules(*args) # type: ignore\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.pretty_print","title":"pretty_print","text":"pretty_print(depth: int = -1) -> None\n
Print the module in a tree-like format.
Parameters:
depth (int, default -1): The maximum depth of the tree to print. If negative, the whole tree is printed.
Source code in src/refiners/fluxion/layers/module.py
def pretty_print(self, depth: int = -1) -> None:\n \"\"\"Print the module in a tree-like format.\n\n Args:\n depth: The maximum depth of the tree to print.\n If negative, the whole tree is printed.\n \"\"\"\n tree = ModuleTree(module=self)\n print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Module.to","title":"to","text":"to(\n device: device | str | None = None,\n dtype: dtype | None = None,\n) -> T\n
Move the module to the given device and cast its parameters to the given dtype.
Parameters:
device (device | str | None, default None): The device to move the module to.
dtype (dtype | None, default None): The dtype to cast the module's parameters to.
Returns:
T: The module, moved to the given device and cast to the given dtype.
Source code in src/refiners/fluxion/layers/module.py
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore\n \"\"\"Move the module to the given device and cast its parameters to the given dtype.\n\n Args:\n device: The device to move the module to.\n dtype: The dtype to cast the module's parameters to.\n\n Returns:\n The module, moved to the given device and cast to the given dtype.\n \"\"\"\n return super().to(device=device, dtype=dtype) # type: ignore\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.MultiLinear","title":"MultiLinear","text":"MultiLinear(\n input_dim: int,\n output_dim: int,\n inner_dim: int,\n num_layers: int,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Multi-layer linear network.
This layer wraps multiple torch.nn.Linear layers, with a ReLU activation between consecutive layers.
Receives: Float[Tensor, 'batch input_dim']
Returns: Float[Tensor, 'batch output_dim']
Example linear = fl.MultiLinear(\n input_dim=32,\n output_dim=128,\n inner_dim=64,\n num_layers=3,\n)\n\ntensor = torch.randn(2, 32)\noutput = linear(tensor)\n\nassert output.shape == (2, 128)\n
Parameters:
input_dim (int, required): The input dimension of the first linear layer.
output_dim (int, required): The output dimension of the last linear layer.
inner_dim (int, required): The output dimension of the inner linear layers.
num_layers (int, required): The number of linear layers.
device (device | str | None, default None): The device to use for the linear layers.
dtype (dtype | None, default None): The dtype to use for the linear layers.
Source code in src/refiners/fluxion/layers/linear.py
def __init__(\n self,\n input_dim: int,\n output_dim: int,\n inner_dim: int,\n num_layers: int,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initializes the MultiLinear layer.\n\n Args:\n input_dim: The input dimension of the first linear layer.\n output_dim: The output dimension of the last linear layer.\n inner_dim: The output dimension of the inner linear layers.\n num_layers: The number of linear layers.\n device: The device to use for the linear layers.\n dtype: The dtype to use for the linear layers.\n \"\"\"\n layers: list[Module] = []\n for i in range(num_layers - 1):\n layers.append(\n Linear(\n in_features=input_dim if i == 0 else inner_dim,\n out_features=inner_dim,\n device=device,\n dtype=dtype,\n )\n )\n layers.append(\n ReLU(),\n )\n layers.append(\n Linear(\n in_features=inner_dim,\n out_features=output_dim,\n device=device,\n dtype=dtype,\n )\n )\n\n super().__init__(layers)\n
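In plain PyTorch terms, the loop above produces num_layers Linear layers separated by ReLU activations. The sketch below rebuilds the same stack with torch.nn.Sequential to make the resulting shapes explicit; the helper name multilinear_sketch is hypothetical. Note that, as written above, the final layer always takes inner_dim inputs, so with num_layers=1 the input_dim argument is not used:

```python
import torch
from torch import nn


def multilinear_sketch(input_dim: int, output_dim: int, inner_dim: int, num_layers: int) -> nn.Sequential:
    # Same construction as the source above: (num_layers - 1) Linear+ReLU pairs, then a final Linear
    layers: list[nn.Module] = []
    for i in range(num_layers - 1):
        layers.append(nn.Linear(input_dim if i == 0 else inner_dim, inner_dim))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(inner_dim, output_dim))
    return nn.Sequential(*layers)


model = multilinear_sketch(input_dim=32, output_dim=128, inner_dim=64, num_layers=3)
shapes = [(m.in_features, m.out_features) for m in model if isinstance(m, nn.Linear)]
print(shapes)  # [(32, 64), (64, 64), (64, 128)]
print(model(torch.randn(2, 32)).shape)  # torch.Size([2, 128])
```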
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Multiply","title":"Multiply","text":"Multiply(scale: float = 1.0, bias: float = 0.0)\n
Bases: Module
Multiply operator layer.
This layer scales and shifts the input tensor by the given scale and bias.
Example
multiply = fl.Multiply(scale=2, bias=1)\n\ntensor = torch.ones(1)\noutput = multiply(tensor)\n\nassert torch.allclose(output, torch.tensor([3.0]))\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n self,\n scale: float = 1.0,\n bias: float = 0.0,\n) -> None:\n super().__init__()\n self.scale = scale\n self.bias = bias\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Parallel","title":"Parallel","text":"Parallel(*args: Module | Iterable[Module])\n
Bases: Chain
Parallel layer.
This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.
Example
parallel = fl.Parallel(\n fl.Linear(32, 64),\n fl.Identity(),\n fl.Linear(32, 128),\n)\n\ntensor = torch.randn(2, 32)\noutputs = parallel(tensor)\n\nassert len(outputs) == 3\nassert outputs[0].shape == (2, 64)\nassert torch.allclose(outputs[1], tensor)\nassert outputs[2].shape == (2, 128)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Parameter","title":"Parameter","text":"Parameter(\n *dims: int,\n requires_grad: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None\n)\n
Bases: WeightedModule
Parameter layer.
This layer wraps a PyTorch Parameter of shape dims, initialized from a standard normal distribution (torch.randn). When called, it returns the Parameter tensor.
Attributes:
weight (Parameter): The parameter tensor.
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n self,\n *dims: int,\n requires_grad: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__()\n self.dims = dims\n self.weight = TorchParameter(\n requires_grad=requires_grad,\n data=torch.randn(\n *dims,\n device=device,\n dtype=dtype,\n ),\n )\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Passthrough","title":"Passthrough","text":"Passthrough(*args: Module | Iterable[Module])\n
Bases: Chain
Passthrough layer.
This layer calls its sub-modules sequentially, and returns its original inputs, like an Identity layer.
passthrough = fl.Passthrough(\n fl.Linear(32, 128),\n fl.ReLU(),\n fl.Linear(128, 128),\n)\n\ntensor = torch.randn(2, 32)\noutput = passthrough(tensor)\n\nassert torch.allclose(output, tensor)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Permute","title":"Permute","text":"Permute(*dims: int)\n
Bases: Module
Permute operation layer.
This layer permutes the input tensor according to the given dimensions. See also torch.permute.
permute = fl.Permute(2, 0, 1)\n\ntensor = torch.randn(10, 20, 30)\noutput = permute(tensor)\n\nassert output.shape == (30, 10, 20)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *dims: int) -> None:\n super().__init__()\n self.dims = dims\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.PixelUnshuffle","title":"PixelUnshuffle","text":"PixelUnshuffle(downscale_factor: int)\n
Bases: PixelUnshuffle, Module
Pixel Unshuffle layer.
This layer wraps torch.nn.PixelUnshuffle.
Receives: Float[Tensor, 'batch in_channels in_height in_width']
Returns: Float[Tensor, 'batch out_channels out_height out_width']
Source code in src/refiners/fluxion/layers/pixelshuffle.py
def __init__(self, downscale_factor: int):\n _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)\n
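Concretely, with downscale_factor r, the wrapped torch.nn.PixelUnshuffle moves each r x r spatial block into channels: a (batch, C, H, W) input becomes (batch, C * r * r, H / r, W / r), with H and W divisible by r. torch.nn.PixelShuffle with the same factor inverts it. A short sketch in plain PyTorch:

```python
import torch
from torch import nn

r = 2
x = torch.randn(1, 3, 8, 8)

unshuffled = nn.PixelUnshuffle(downscale_factor=r)(x)
print(unshuffled.shape)  # torch.Size([1, 12, 4, 4])

# PixelShuffle with the same factor restores the original layout exactly
restored = nn.PixelShuffle(upscale_factor=r)(unshuffled)
print(torch.equal(restored, x))  # True
```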
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ReLU","title":"ReLU","text":"ReLU()\n
Bases: Activation
Rectified Linear Unit activation function.
See Rectified Linear Units Improve Restricted Boltzmann Machines and Cognitron: A self-organizing multilayered neural network for more details.
Example
relu = fl.ReLU()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = relu(tensor)\n\nexpected_output = torch.tensor([[0.0, 0.0, 1.0]])\nassert torch.equal(output, expected_output)\n
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ReflectionPad2d","title":"ReflectionPad2d","text":"ReflectionPad2d(padding: int)\n
Bases: ReflectionPad2d, Module
Reflection padding layer.
This layer wraps torch.nn.ReflectionPad2d.
Receives: Float[Tensor, 'batch channels in_height in_width']
Returns: Float[Tensor, 'batch channels out_height out_width']
Source code in src/refiners/fluxion/layers/padding.py
def __init__(self, padding: int) -> None:\n super().__init__(padding=padding)\n
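As a quick illustration of what the wrapped torch.nn.ReflectionPad2d does, the sketch below (plain PyTorch, values chosen for illustration) pads a 2 x 2 input by 1. Each spatial dimension grows by 2 * padding, and the border mirrors the input without repeating the edge value.

```python
import torch

pad = torch.nn.ReflectionPad2d(1)  # the PyTorch layer that fl.ReflectionPad2d wraps
x = torch.tensor([[0.0, 1.0], [2.0, 3.0]]).reshape(1, 1, 2, 2)
out = pad(x)

# Spatial size: 2 + 2 * 1 = 4 in each direction.
assert out.shape == (1, 1, 4, 4)
# Output row 1 is input row 0 ([0, 1]) with one mirrored value added on each side.
assert out[0, 0, 1].tolist() == [1.0, 0.0, 1.0, 0.0]
```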
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Reshape","title":"Reshape","text":"Reshape(*shape: int)\n
Bases: Module
Reshape operation layer.
This layer reshapes the input tensor to a specific shape, which must be compatible with the original shape. See also torch.reshape.
Warning: The first dimension (batch dimension) is always preserved; the given shape applies only to the remaining dimensions.
Example
reshape = fl.Reshape(5, 2)\n\ntensor = torch.randn(2, 10, 1)\noutput = reshape(tensor)\n\nassert output.shape == (2, 5, 2)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *shape: int) -> None:\n super().__init__()\n self.shape = shape\n
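The batch-preserving behavior described in the Warning can be sketched with a plain torch.nn.Module. ReshapeKeepBatch is a hypothetical stand-in, assuming fl.Reshape amounts to reshaping everything after the first dimension; it is not the Refiners source.

```python
import torch


class ReshapeKeepBatch(torch.nn.Module):
    # Hypothetical sketch: keep dimension 0, reshape the rest to self.shape.
    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The batch size is taken from the input, so it never has to be passed to the constructor.
        return x.reshape(x.shape[0], *self.shape)


reshape = ReshapeKeepBatch(5, 2)
out = reshape(torch.randn(2, 10, 1))
assert out.shape == (2, 5, 2)
```

Because the batch size is read at call time, the same layer works for any batch size.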
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Residual","title":"Residual","text":"Residual(*args: Module | Iterable[Module])\n
Bases: Chain
Residual layer.
This layer calls its sub-modules sequentially, and adds the original input to the output.
Example
residual = fl.Residual(\n fl.Multiply(scale=10),\n)\n\ntensor = torch.ones(2, 32)\noutput = residual(tensor)\n\nassert output.shape == (2, 32)\nassert torch.allclose(output, 10 * tensor + tensor)\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Return","title":"Return","text":"Return(*args: Any, **kwargs: Any)\n
Bases: Module
Return layer.
This layer stops the execution of a Chain when encountered.
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.ScaledDotProductAttention","title":"ScaledDotProductAttention","text":"ScaledDotProductAttention(\n num_heads: int = 1,\n is_causal: bool = False,\n is_optimized: bool = True,\n slice_size: int | None = None,\n)\n
Bases: Module
Scaled Dot Product Attention.
See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details.
Note: This layer simply wraps scaled_dot_product_attention inside an fl.Module.
Receives:
Query: Float[Tensor, 'batch num_queries embedding_dim']
Key: Float[Tensor, 'batch num_keys embedding_dim']
Value: Float[Tensor, 'batch num_values embedding_dim']
Returns: Float[Tensor, 'batch num_queries embedding_dim']
Example attention = fl.ScaledDotProductAttention(num_heads=8)\n\nquery = torch.randn(2, 10, 128)\nkey = torch.randn(2, 10, 128)\nvalue = torch.randn(2, 10, 128)\noutput = attention(query, key, value)\n\nassert output.shape == (2, 10, 128)\n
Parameters:
num_heads (int, default 1): The number of heads to use.
is_causal (bool, default False): Whether to use causal attention.
is_optimized (bool, default True): Whether to use the optimized scaled_dot_product_attention function; if False, a non-optimized implementation is used.
slice_size (int | None, default None): The slice size to use for the optimized attention.
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n self,\n num_heads: int = 1,\n is_causal: bool = False,\n is_optimized: bool = True,\n slice_size: int | None = None,\n) -> None:\n \"\"\"Initialize the Scaled Dot Product Attention layer.\n\n Args:\n num_heads: The number of heads to use.\n is_causal: Whether to use causal attention.\n is_optimized: Whether to use optimized attention.\n slice_size: The slice size to use for the optimized attention.\n \"\"\"\n super().__init__()\n self.num_heads = num_heads\n self.is_causal = is_causal\n self.is_optimized = is_optimized\n self.slice_size = slice_size\n self.dot_product = (\n scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized\n )\n
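The num_heads parameter splits embedding_dim into equal heads, and each head attends independently. The sketch below shows that bookkeeping with torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 or later) and checks it against the explicit softmax(Q K^T / sqrt(head_dim)) V formula from the paper. The head layout (contiguous chunks of the embedding) and the helper names are assumptions for illustration, not necessarily how Refiners arranges heads internally.

```python
import math

import torch
import torch.nn.functional as F


def split_heads(t: torch.Tensor, num_heads: int) -> torch.Tensor:
    # (batch, seq, embedding_dim) -> (batch, num_heads, seq, head_dim)
    b, n, d = t.shape
    assert d % num_heads == 0, 'embedding_dim must be divisible by num_heads'
    return t.reshape(b, n, num_heads, d // num_heads).transpose(1, 2)


def merge_heads(t: torch.Tensor) -> torch.Tensor:
    # (batch, num_heads, seq, head_dim) -> (batch, seq, embedding_dim)
    b, h, n, hd = t.shape
    return t.transpose(1, 2).reshape(b, n, h * hd)


def multi_head_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, num_heads: int, is_causal: bool = False) -> torch.Tensor:
    q, k, v = (split_heads(t, num_heads) for t in (query, key, value))
    return merge_heads(F.scaled_dot_product_attention(q, k, v, is_causal=is_causal))


def naive_multi_head(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Explicit softmax(Q K^T / sqrt(head_dim)) V, computed independently for every head.
    q, k, v = (split_heads(t, num_heads) for t in (query, key, value))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return merge_heads(torch.softmax(scores, dim=-1) @ v)


torch.manual_seed(0)
query, key, value = (torch.randn(2, 10, 128) for _ in range(3))
out = multi_head_sdpa(query, key, value, num_heads=8)
assert out.shape == (2, 10, 128)
assert torch.allclose(out, naive_multi_head(query, key, value, num_heads=8), atol=1e-5)
```

With embedding_dim=128 and num_heads=8, each head works on a 16-dimensional slice, so the scaling factor is 1 / sqrt(16).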
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SelfAttention","title":"SelfAttention","text":"SelfAttention(\n embedding_dim: int,\n inner_dim: int | None = None,\n num_heads: int = 1,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Attention
Multi-Head Self-Attention layer.
This layer simply chains a Parallel layer, which duplicates the input Tensor (one copy for each Linear layer in the Attention layer), and an Attention layer.
Receives: Float[Tensor, 'batch sequence_length embedding_dim']
Returns: Float[Tensor, 'batch sequence_length embedding_dim']
Example self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)\n\ntensor = torch.randn(2, 10, 128)\noutput = self_attention(tensor)\n\nassert output.shape == (2, 10, 128)\n
Parameters:
embedding_dim (int, required): The embedding dimension of the input and output tensors.
inner_dim (int | None, default None): The inner dimension of the linear layers.
num_heads (int, default 1): The number of heads to use.
use_bias (bool, default True): Whether to use bias in the linear layers.
is_causal (bool, default False): Whether to use causal attention.
is_optimized (bool, default True): Whether to use optimized attention.
device (device | str | None, default None): The device to use.
dtype (dtype | None, default None): The dtype to use.
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n self,\n embedding_dim: int,\n inner_dim: int | None = None,\n num_heads: int = 1,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the Self-Attention layer.\n\n Args:\n embedding_dim: The embedding dimension of the input and output tensors.\n inner_dim: The inner dimension of the linear layers.\n num_heads: The number of heads to use.\n use_bias: Whether to use bias in the linear layers.\n is_causal: Whether to use causal attention.\n is_optimized: Whether to use optimized attention.\n device: The device to use.\n dtype: The dtype to use.\n \"\"\"\n super().__init__(\n embedding_dim=embedding_dim,\n inner_dim=inner_dim,\n num_heads=num_heads,\n use_bias=use_bias,\n is_causal=is_causal,\n is_optimized=is_optimized,\n device=device,\n dtype=dtype,\n )\n self.insert(\n index=0,\n module=Parallel(\n Identity(), # Query projection's input\n Identity(), # Key projection's input\n Identity(), # Value projection's input\n ),\n )\n
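Self-attention is attention where query, key and value all come from the same input, which is what the Parallel(Identity(), Identity(), Identity()) inserted at index 0 in the source above provides. The class below is a hypothetical, self-contained sketch of that idea; the separate query/key/value projections and the output projection follow the common multi-head attention layout and are assumed here rather than read from the Refiners Attention source.

```python
import torch
import torch.nn.functional as F


class TinySelfAttention(torch.nn.Module):
    # Hypothetical sketch, not the Refiners implementation.
    def __init__(self, embedding_dim: int, num_heads: int = 1) -> None:
        super().__init__()
        assert embedding_dim % num_heads == 0, 'embedding_dim must be divisible by num_heads'
        self.num_heads = num_heads
        self.to_query = torch.nn.Linear(embedding_dim, embedding_dim)
        self.to_key = torch.nn.Linear(embedding_dim, embedding_dim)
        self.to_value = torch.nn.Linear(embedding_dim, embedding_dim)
        self.to_out = torch.nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        # Duplication step: the same input feeds the query, key and value projections.
        query_in, key_in, value_in = x, x, x

        def split(t: torch.Tensor) -> torch.Tensor:
            return t.reshape(b, n, self.num_heads, d // self.num_heads).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            split(self.to_query(query_in)), split(self.to_key(key_in)), split(self.to_value(value_in))
        )
        return self.to_out(attended.transpose(1, 2).reshape(b, n, d))


torch.manual_seed(0)
attention = TinySelfAttention(embedding_dim=128, num_heads=8)
x = torch.randn(2, 10, 128)
assert attention(x).shape == (2, 10, 128)
```

Without positional information or a causal mask, reordering the input tokens just reorders the output tokens the same way.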
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SelfAttention2d","title":"SelfAttention2d","text":"SelfAttention2d(\n channels: int,\n num_heads: int = 1,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: SelfAttention
Multi-Head 2D Self-Attention layer.
This Module simply chains a Lambda layer, which transforms the input Tensor into a sequence, a SelfAttention layer, and a Lambda layer, which transforms the output sequence back into a 2D Tensor.
Receives: Float[Tensor, 'batch channels height width']
Returns: Float[Tensor, 'batch channels height width']
Example self_attention = fl.SelfAttention2d(channels=128, num_heads=8)\n\ntensor = torch.randn(2, 128, 64, 64)\noutput = self_attention(tensor)\n\nassert output.shape == (2, 128, 64, 64)\n
Parameters:
channels (int, required): The number of channels of the input and output tensors; must be divisible by num_heads.
num_heads (int, default 1): The number of heads to use.
use_bias (bool, default True): Whether to use bias in the linear layers.
is_causal (bool, default False): Whether to use causal attention.
is_optimized (bool, default True): Whether to use optimized attention.
device (device | str | None, default None): The device to use.
dtype (dtype | None, default None): The dtype to use.
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(\n self,\n channels: int,\n num_heads: int = 1,\n use_bias: bool = True,\n is_causal: bool = False,\n is_optimized: bool = True,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the 2D Self-Attention layer.\n\n Args:\n channels: The number of channels of the input and output tensors.\n num_heads: The number of heads to use.\n use_bias: Whether to use bias in the linear layers.\n is_causal: Whether to use causal attention.\n is_optimized: Whether to use optimized attention.\n device: The device to use.\n dtype: The dtype to use.\n \"\"\"\n assert channels % num_heads == 0, f\"channels {channels} must be divisible by num_heads {num_heads}\"\n self.channels = channels\n\n super().__init__(\n embedding_dim=channels,\n num_heads=num_heads,\n use_bias=use_bias,\n is_causal=is_causal,\n is_optimized=is_optimized,\n device=device,\n dtype=dtype,\n )\n\n self.insert(0, Lambda(self._tensor_2d_to_sequence))\n self.append(Lambda(self._sequence_to_tensor_2d))\n
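The two Lambda layers inserted in the source above convert between image and sequence layouts, with channels playing the role of embedding_dim (hence embedding_dim=channels). A minimal sketch of such conversions in plain PyTorch: the function names echo _tensor_2d_to_sequence and _sequence_to_tensor_2d but are hypothetical, and the real methods may obtain height and width differently (here they are passed explicitly).

```python
import torch


def tensor_2d_to_sequence(x: torch.Tensor) -> torch.Tensor:
    # (batch, channels, height, width) -> (batch, height * width, channels):
    # every pixel becomes one token whose features are its channel values.
    b, c, h, w = x.shape
    return x.reshape(b, c, h * w).transpose(1, 2)


def sequence_to_tensor_2d(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # (batch, height * width, channels) -> (batch, channels, height, width)
    b, n, c = x.shape
    assert n == height * width, 'sequence length must equal height * width'
    return x.transpose(1, 2).reshape(b, c, height, width)


x = torch.randn(2, 128, 64, 64)
seq = tensor_2d_to_sequence(x)
assert seq.shape == (2, 64 * 64, 128)
# Token index i * width + j holds the channels of pixel (i, j).
assert torch.equal(seq[:, 1 * 64 + 2], x[:, :, 1, 2])
assert torch.equal(sequence_to_tensor_2d(seq, 64, 64), x)
```

For the 64 x 64 example above, attention therefore runs over 4096 tokens, which is why attention cost grows quickly with resolution.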
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SetContext","title":"SetContext","text":"SetContext(\n context: str,\n key: str,\n callback: Callable[[Any, Any], Any] | None = None,\n)\n
Bases: ContextModule
SetContext layer.
This layer writes to the ContextProvider of its parent Chain.
The context needs to already exist in the ContextProvider.
Source code in src/refiners/fluxion/layers/chain.py
def __init__(\n self,\n context: str,\n key: str,\n callback: Callable[[Any, Any], Any] | None = None,\n) -> None:\n super().__init__()\n self.context = context\n self.key = key\n self.callback = callback\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.SiLU","title":"SiLU","text":"SiLU()\n
Bases: Activation
Sigmoid Linear Unit activation function.
See [arXiv:1702.03118] Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning for more details.
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n super().__init__()\n
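For reference, the paper defines SiLU as x * sigmoid(x), where sigmoid(x) = 1 / (1 + exp(-x)). The sketch below (plain PyTorch) checks that formula against torch.nn.functional.silu and torch.sigmoid.

```python
import torch

x = torch.linspace(-4.0, 4.0, steps=9)
# sigmoid(x) = 1 / (1 + exp(-x)); SiLU(x) = x * sigmoid(x)
sigmoid = 1.0 / (1.0 + torch.exp(-x))
silu_manual = x * sigmoid

assert torch.allclose(sigmoid, torch.sigmoid(x))
assert torch.allclose(silu_manual, torch.nn.functional.silu(x))
```

Unlike ReLU, SiLU is smooth and slightly negative for negative inputs, with a minimum near x = -1.28.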
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sigmoid","title":"Sigmoid","text":"Sigmoid()\n
Bases: Activation
Sigmoid activation function.
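Example
sigmoid = fl.Sigmoid()\n\ntensor = torch.tensor([[-1.0, 0.0, 1.0]])\noutput = sigmoid(tensor)\n\nexpected_output = torch.tensor([[0.2689, 0.5, 0.7311]])\nassert torch.allclose(output, expected_output, atol=1e-4)\n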
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:\n super().__init__()\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sin","title":"Sin","text":"Sin(*args: Any, **kwargs: Any)\n
Bases: Module
Sine operator layer.
This layer applies the sine function to the input tensor. See also torch.sin.
sin = fl.Sin()\n\ntensor = torch.tensor([0, torch.pi])\noutput = sin(tensor)\n\nexpected_output = torch.tensor([0.0, 0.0])\nassert torch.allclose(output, expected_output, atol=1e-6)\n
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Slicing","title":"Slicing","text":"Slicing(\n dim: int = 0,\n start: int = 0,\n end: int | None = None,\n step: int = 1,\n)\n
Bases: Module
Slicing operation layer.
This layer slices the input tensor along the given dimension, from the start index to the end index with the given step. See also torch.index_select.
slicing = fl.Slicing(dim=1, start=50)\n\ntensor = torch.randn(10, 100)\noutput = slicing(tensor)\n\nassert output.shape == (10, 50)\nassert torch.allclose(output, tensor[:, 50:])\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(\n self,\n dim: int = 0,\n start: int = 0,\n end: int | None = None,\n step: int = 1,\n) -> None:\n super().__init__()\n self.dim = dim\n self.start = start\n self.end = end\n self.step = step\n
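To make dim, start, end and step concrete, the sketch below (plain PyTorch) assumes that end is exclusive, as in Python slicing and consistent with the Example above (start=50 matching tensor[:, 50:]). It shows that the same selection can be written with Python slicing or with torch.index_select over an index range.

```python
import torch

tensor = torch.arange(20).reshape(2, 10)
dim, start, end, step = 1, 2, 9, 3

# Python slicing along dimension 1: indices 2, 5, 8 (end is excluded).
sliced = tensor[:, start:end:step]
# Same selection with torch.index_select and an explicit index range.
indices = torch.arange(start, end, step)
selected = torch.index_select(tensor, dim, indices)

assert torch.equal(sliced, selected)
assert sliced[0].tolist() == [2, 5, 8]
```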
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Squeeze","title":"Squeeze","text":"Squeeze(dim: int)\n
Bases: Module
Squeeze operation layer.
This layer squeezes (removes) the given dimension of the input tensor, which should have size 1. See also torch.squeeze.
squeeze = fl.Squeeze(dim=1)\n\ntensor = torch.randn(10, 1, 10)\noutput = squeeze(tensor)\n\nassert output.shape == (10, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n super().__init__()\n self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Sum","title":"Sum","text":"Sum(*args: Module | Iterable[Module])\n
Bases: Chain
Summation layer.
This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.
Example
summation = fl.Sum(\n fl.Multiply(scale=2, bias=1),\n fl.Multiply(scale=3, bias=0),\n)\n\ntensor = torch.ones(1)\noutput = summation(tensor)\n\nassert torch.allclose(output, torch.tensor([6.0]))\n
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:\n super().__init__()\n self._provider = ContextProvider()\n modules = cast(\n tuple[Module],\n (\n tuple(args[0])\n if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)\n else tuple(args)\n ),\n )\n\n for module in modules:\n # Violating this would mean a ContextModule ends up in two chains,\n # with a single one correctly set as its parent.\n assert (\n (not isinstance(module, ContextModule))\n or (not module._can_refresh_parent)\n or (module.parent is None)\n or (module.parent == self)\n ), f\"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}\"\n\n self._regenerate_keys(modules)\n self._reset_context()\n\n for module in self:\n if isinstance(module, ContextModule) and module.parent != self:\n module._set_parent(self)\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Transpose","title":"Transpose","text":"Transpose(dim0: int, dim1: int)\n
Bases: Module
Transpose operation layer.
This layer transposes the input tensor between the two given dimensions. See also torch.transpose.
transpose = fl.Transpose(dim0=1, dim1=2)\n\ntensor = torch.randn(10, 20, 30)\noutput = transpose(tensor)\n\nassert output.shape == (10, 30, 20)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim0: int, dim1: int) -> None:\n super().__init__()\n self.dim0 = dim0\n self.dim1 = dim1\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Unflatten","title":"Unflatten","text":"Unflatten(dim: int)\n
Bases: Module
Unflatten operation layer.
This layer unflattens the given dimension of the input tensor into the sizes passed at call time. See also torch.unflatten.
unflatten = fl.Unflatten(dim=1)\n\ntensor = torch.randn(10, 100)\noutput = unflatten(tensor, sizes=(10, 10))\n\nassert output.shape == (10, 10, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n super().__init__()\n self.dim = dim\n
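Because the sizes are passed at call time rather than to the constructor, the underlying operation can be illustrated with torch.Tensor.unflatten(dim, sizes). The sketch below (plain PyTorch, independent of Refiners) shows that it inverts flattening the same dimensions, and that one entry of sizes may be -1 to be inferred.

```python
import torch

tensor = torch.randn(10, 100)
unflattened = tensor.unflatten(1, (10, 10))  # dim=1 split into 10 x 10
assert unflattened.shape == (10, 10, 10)

# Flattening the same dimensions again gives back the original tensor.
assert torch.equal(unflattened.flatten(start_dim=1), tensor)

# One size may be -1; it is inferred from the remaining elements (100 / 25 = 4).
assert tensor.unflatten(1, (-1, 25)).shape == (10, 4, 25)
```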
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Unsqueeze","title":"Unsqueeze","text":"Unsqueeze(dim: int)\n
Bases: Module
Unsqueeze operation layer.
This layer unsqueezes the input tensor at the given dimension. See also torch.unsqueeze.
unsqueeze = fl.Unsqueeze(dim=1)\n\ntensor = torch.randn(10, 10)\noutput = unsqueeze(tensor)\n\nassert output.shape == (10, 1, 10)\n
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:\n super().__init__()\n self.dim = dim\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.Upsample","title":"Upsample","text":"Upsample(\n channels: int,\n upsample_factor: int | None = None,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Upsample layer.
This layer upsamples the input, either by upsample_factor or to a target shape read from the sampling context, and then applies a 3x3 convolution that keeps the number of channels unchanged.
Raises:
RuntimeError: If the sampling context is not set or is empty.
Parameters:
channels (int, required): The number of input and output channels.
upsample_factor (int | None, default None): The factor by which to upsample the input. If None, the target shape is popped from the shapes key of the sampling context.
device (device | str | None, default None): The device to use for the convolutional layer.
dtype (dtype | None, default None): The dtype to use for the convolutional layer.
Source code in src/refiners/fluxion/layers/sampling.py
def __init__(\n self,\n channels: int,\n upsample_factor: int | None = None,\n device: Device | str | None = None,\n dtype: DType | None = None,\n):\n \"\"\"Initializes the Upsample layer.\n\n Args:\n channels: The number of input and output channels.\n upsample_factor: The factor by which to upsample the input.\n If None, the input shape is taken from the context.\n device: The device to use for the convolutional layer.\n dtype: The dtype to use for the convolutional layer.\n \"\"\"\n self.channels = channels\n self.upsample_factor = upsample_factor\n super().__init__(\n Parallel(\n Identity(),\n (\n Lambda(self._get_static_shape)\n if upsample_factor is not None\n else UseContext(context=\"sampling\", key=\"shapes\").compose(lambda x: x.pop())\n ),\n ),\n Interpolate(),\n Conv2d(\n in_channels=channels,\n out_channels=channels,\n kernel_size=3,\n padding=1,\n device=device,\n dtype=dtype,\n ),\n )\n
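When upsample_factor is None, the source above pops a shape from the shapes key of the sampling context, so each Upsample expects an earlier step to have pushed the shape it should restore. The sketch below imitates that stack discipline with a plain dict standing in for the ContextProvider. The dict, the downsample helper, average pooling and nearest-neighbor interpolation are all assumptions for illustration, and the 3x3 convolution is omitted.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for a ContextProvider: named contexts, each mapping keys to values.
contexts: dict[str, dict[str, list[tuple[int, ...]]]] = {'sampling': {'shapes': []}}


def downsample(x: torch.Tensor) -> torch.Tensor:
    # Record the spatial size before halving it, so a matching upsample can restore it exactly.
    contexts['sampling']['shapes'].append(tuple(x.shape[2:]))
    return F.avg_pool2d(x, kernel_size=2)


def upsample(x: torch.Tensor) -> torch.Tensor:
    shapes = contexts.get('sampling', {}).get('shapes')
    if not shapes:
        raise RuntimeError('the sampling context is not set or is empty')
    # Pop the most recent shape: upsamples undo downsamples in reverse (stack) order.
    return F.interpolate(x, size=shapes.pop(), mode='nearest')


x = torch.randn(1, 4, 15, 15)
y = upsample(downsample(x))  # 15 -> 7 -> 15, which a fixed factor of 2 could not recover
assert y.shape == x.shape
```

Storing shapes rather than a fixed factor is what lets odd spatial sizes round-trip through a down/up pair.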
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.UseContext","title":"UseContext","text":"UseContext(context: str, key: str)\n
Bases: ContextModule
UseContext layer.
This layer reads from the ContextProvider of its parent Chain.
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, context: str, key: str) -> None:\n super().__init__()\n self.context = context\n self.key = key\n self.func: Callable[[Any], Any] = lambda x: x\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule","title":"WeightedModule","text":"WeightedModule(*args: Any, **kwargs: Any)\n
Bases: Module
A module with a weight (Tensor) attribute.
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule.device","title":"device property
","text":"device: device\n
Return the device of the module's weight.
"},{"location":"reference/fluxion/layers/#refiners.fluxion.layers.WeightedModule.dtype","title":"dtypeproperty
","text":"dtype: dtype\n
Return the dtype of the module's weight.
"},{"location":"reference/fluxion/utils/","title":"
Utils","text":""},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.image_to_tensor","title":"image_to_tensor","text":"image_to_tensor(\n image: Image,\n device: device | str | None = None,\n dtype: dtype | None = None,\n) -> Tensor\n
Convert a PIL Image to a Tensor.
Parameters:
image (Image, required): The image to convert.
device (device | str | None, default None): The device to use for the tensor.
dtype (dtype | None, default None): The dtype to use for the tensor.
Returns:
Tensor: The converted tensor.
Note: The returned tensor has a leading batch dimension: its shape is [1, 3, H, W] for mode RGB, [1, 1, H, W] for mode L (grayscale) and [1, 4, H, W] for mode RGBA. Other modes raise a ValueError.
Values are normalized to the range [0, 1].
Source code in src/refiners/fluxion/utils.py
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:\n \"\"\"Convert a PIL Image to a Tensor.\n\n Args:\n image: The image to convert.\n device: The device to use for the tensor.\n dtype: The dtype to use for the tensor.\n\n Returns:\n The converted tensor.\n\n Note:\n If the image is in mode `RGB` the tensor will have shape `[1, 3, H, W]`,\n otherwise `[1, 1, H, W]` for mode `L` (grayscale) or `[1, 4, H, W]` for mode `RGBA`.\n\n Values are normalized to the range `[0, 1]`.\n \"\"\"\n image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)\n\n assert isinstance(image.mode, str) # type: ignore\n match image.mode:\n case \"L\":\n image_tensor = image_tensor.unsqueeze(0)\n case \"RGBA\" | \"RGB\":\n image_tensor = image_tensor.permute(2, 0, 1)\n case _:\n raise ValueError(f\"Unsupported image mode: {image.mode}\")\n\n return image_tensor.unsqueeze(0)\n
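Since the function ends with image_tensor.unsqueeze(0), its result always carries a leading batch dimension. The sketch below mirrors the same steps on a NumPy array (the H x W or H x W x C uint8 layout that array(image) yields for PIL images in modes L, RGB and RGBA), so it runs without PIL. hwc_array_to_tensor is a hypothetical helper, not part of Refiners.

```python
import numpy as np
import torch


def hwc_array_to_tensor(array: np.ndarray) -> torch.Tensor:
    # Hypothetical helper mirroring image_to_tensor without PIL:
    # uint8 values -> float in [0, 1], channels first, then a leading batch dimension.
    t = torch.tensor(array.astype(np.float32) / 255.0)
    if t.ndim == 2:
        t = t.unsqueeze(0)  # grayscale (mode L): add a channel dimension
    else:
        t = t.permute(2, 0, 1)  # RGB / RGBA: (H, W, C) -> (C, H, W)
    return t.unsqueeze(0)  # batch dimension


rgb = np.full((4, 6, 3), 255, dtype=np.uint8)
t = hwc_array_to_tensor(rgb)
assert t.shape == (1, 3, 4, 6)
assert float(t.max()) == 1.0
```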
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.load_from_safetensors","title":"load_from_safetensors","text":"load_from_safetensors(\n path: Path | str, device: device | str = \"cpu\"\n) -> dict[str, Tensor]\n
Load tensors from a SafeTensor file from disk.
Parameters:
path (Path | str, required): The path to the file.
device (device | str, default 'cpu'): The device to use for the tensors.
Returns:
dict[str, Tensor]: The loaded tensors.
Source code in src/refiners/fluxion/utils.py
def load_from_safetensors(path: Path | str, device: Device | str = \"cpu\") -> dict[str, Tensor]:\n \"\"\"Load tensors from a SafeTensor file from disk.\n\n Args:\n path: The path to the file.\n device: The device to use for the tensors.\n\n Returns:\n The loaded tensors.\n \"\"\"\n return _load_file(path, str(device))\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.load_tensors","title":"load_tensors","text":"load_tensors(\n path: Path | str, /, device: device | str = \"cpu\"\n) -> dict[str, Tensor]\n
Load tensors from a file on disk that was saved with torch.save. The file must contain a dict[str, Tensor]; anything else fails an assertion.
Note: This function uses the weights_only mode of torch.load for additional safety.
Warning: Still, only load data you trust, and favor using load_from_safetensors instead.
Source code in src/refiners/fluxion/utils.py
def load_tensors(path: Path | str, /, device: Device | str = \"cpu\") -> dict[str, Tensor]:\n \"\"\"Load tensors from a file saved with `torch.save` from disk.\n\n Note:\n This function uses the `weights_only` mode of `torch.load` for additional safety.\n\n Warning:\n Still, **only load data you trust** and favor using\n [`load_from_safetensors`][refiners.fluxion.utils.load_from_safetensors] instead.\n \"\"\"\n # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560\n with warnings.catch_warnings():\n warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"TypedStorage is deprecated\")\n tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore\n\n assert isinstance(tensors, dict) and all(\n isinstance(key, str) and isinstance(value, Tensor)\n for key, value in tensors.items() # type: ignore\n ), \"Invalid tensor file, expected a dict[str, Tensor]\"\n\n return cast(dict[str, Tensor], tensors)\n
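The safety mechanism relied on above is the weights_only flag of torch.load. A minimal round trip in plain PyTorch (file name and tensor names are arbitrary), producing the dict[str, Tensor] layout that load_tensors asserts:

```python
import os
import tempfile

import torch

tensors = {'weight': torch.randn(3, 4), 'bias': torch.zeros(4)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'tensors.pt')
    torch.save(tensors, path)
    # weights_only=True restricts unpickling to tensors and basic containers.
    loaded = torch.load(path, map_location='cpu', weights_only=True)

assert set(loaded) == {'weight', 'bias'}
assert torch.equal(loaded['weight'], tensors['weight'])
```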
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.save_to_safetensors","title":"save_to_safetensors","text":"save_to_safetensors(\n path: Path | str,\n tensors: dict[str, Tensor],\n metadata: dict[str, str] | None = None,\n) -> None\n
Save tensors to a SafeTensor file on disk.
Parameters:
path (Path | str, required): The path to the file.
tensors (dict[str, Tensor], required): The tensors to save.
metadata (dict[str, str] | None, default None): The metadata to save.
Source code in src/refiners/fluxion/utils.py
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:\n \"\"\"Save tensors to a SafeTensor file on disk.\n\n Args:\n path: The path to the file.\n tensors: The tensors to save.\n metadata: The metadata to save.\n \"\"\"\n _save_file(tensors, path, metadata) # type: ignore\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.str_to_dtype","title":"str_to_dtype","text":"str_to_dtype(dtype: str) -> dtype\n
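Converts a string dtype name to a torch.dtype. Matching is case-insensitive and accepts aliases such as float, half or long; unknown names raise a ValueError.
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
Source code in src/refiners/fluxion/utils.py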
def str_to_dtype(dtype: str) -> torch.dtype:\n \"\"\"Converts a string dtype to a torch.dtype.\n\n See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype\n \"\"\"\n match dtype.lower():\n case \"float32\" | \"float\":\n return torch.float32\n case \"float64\" | \"double\":\n return torch.float64\n case \"complex64\" | \"cfloat\":\n return torch.complex64\n case \"complex128\" | \"cdouble\":\n return torch.complex128\n case \"float16\" | \"half\":\n return torch.float16\n case \"bfloat16\":\n return torch.bfloat16\n case \"uint8\":\n return torch.uint8\n case \"int8\":\n return torch.int8\n case \"int16\" | \"short\":\n return torch.int16\n case \"int32\" | \"int\":\n return torch.int32\n case \"int64\" | \"long\":\n return torch.int64\n case \"bool\":\n return torch.bool\n case _:\n raise ValueError(f\"Unknown dtype: {dtype}\")\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.summarize_tensor","title":"summarize_tensor","text":"summarize_tensor(tensor: Tensor) -> str\n
Summarize a tensor.
This helper function returns a one-line summary string with the shape, dtype, device, min, max, mean, std, norm and requires_grad flag of a tensor. Min and max are omitted for complex and empty tensors; for complex tensors, mean, std and norm are computed on the real part.
Parameters:
tensor (Tensor, required): The tensor to summarize.
Returns:
str: The summary string.
Source code in src/refiners/fluxion/utils.py
def summarize_tensor(tensor: torch.Tensor, /) -> str:\n \"\"\"Summarize a tensor.\n\n This helper function prints the shape, dtype, device, min, max, mean, std, norm and grad of a tensor.\n\n Args:\n tensor: The tensor to summarize.\n\n Returns:\n The summary string.\n \"\"\"\n info_list = [\n f\"shape=({', '.join(map(str, tensor.shape))})\",\n f\"dtype={str(object=tensor.dtype).removeprefix('torch.')}\",\n f\"device={tensor.device}\",\n ]\n if tensor.is_complex():\n tensor_f = tensor.real.float()\n else:\n if tensor.numel() > 0:\n info_list.extend(\n [\n f\"min={tensor.min():.2f}\", # type: ignore\n f\"max={tensor.max():.2f}\", # type: ignore\n ]\n )\n tensor_f = tensor.float()\n\n info_list.extend(\n [\n f\"mean={tensor_f.mean():.2f}\",\n f\"std={tensor_f.std():.2f}\",\n f\"norm={norm(x=tensor_f):.2f}\",\n f\"grad={tensor.requires_grad}\",\n ]\n )\n\n return \"Tensor(\" + \", \".join(info_list) + \")\"\n
"},{"location":"reference/fluxion/utils/#refiners.fluxion.utils.tensor_to_image","title":"tensor_to_image","text":"tensor_to_image(tensor: Tensor) -> Image\n
Convert a Tensor to a PIL Image.
Parameters:
tensor (Tensor, required): The tensor to convert.
Returns:
Image: The converted image.
Note: The tensor must have shape [1, channels, height, width], where the number of channels is 1 (grayscale), 3 (RGB) or 4 (RGBA).
Expected values are in the range [0, 1]; values outside this range are clamped.
Source code in src/refiners/fluxion/utils.py
def tensor_to_image(tensor: Tensor) -> Image.Image:\n \"\"\"Convert a Tensor to a PIL Image.\n\n Args:\n tensor: The tensor to convert.\n\n Returns:\n The converted image.\n\n Note:\n The tensor must have shape `[1, channels, height, width]` where the number of\n channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).\n\n Expected values are in the range `[0, 1]` and are clamped to this range.\n \"\"\"\n assert tensor.ndim == 4 and tensor.shape[0] == 1, f\"Unsupported tensor shape: {tensor.shape}\"\n num_channels = tensor.shape[1]\n tensor = tensor.clamp(0, 1).squeeze(0)\n tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16\n\n match num_channels:\n case 1:\n tensor = tensor.squeeze(0)\n case 3 | 4:\n tensor = tensor.permute(1, 2, 0)\n case _:\n raise ValueError(f\"Unsupported number of channels: {num_channels}\")\n\n return Image.fromarray((tensor.cpu().numpy() * 255).astype(\"uint8\")) # type: ignore[reportUnknownType]\n
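The conversion steps in the source above, minus the final Image.fromarray call, can be sketched on their own. tensor_to_hwc_array is a hypothetical helper. Note that multiplying by 255 and casting to uint8 truncates rather than rounds, so 0.5 maps to 127.

```python
import numpy as np
import torch


def tensor_to_hwc_array(tensor: torch.Tensor) -> np.ndarray:
    # Hypothetical helper: same steps as tensor_to_image, stopping before Image.fromarray.
    assert tensor.ndim == 4 and tensor.shape[0] == 1, f'Unsupported tensor shape: {tensor.shape}'
    t = tensor.clamp(0, 1).squeeze(0).to(torch.float32)
    if t.shape[0] == 1:
        t = t.squeeze(0)  # grayscale -> (H, W)
    elif t.shape[0] in (3, 4):
        t = t.permute(1, 2, 0)  # RGB / RGBA -> (H, W, C)
    else:
        raise ValueError(f'Unsupported number of channels: {t.shape[0]}')
    return (t.cpu().numpy() * 255).astype('uint8')


rgb = tensor_to_hwc_array(torch.full((1, 3, 2, 5), 2.0))  # out-of-range values are clamped to 1
assert rgb.shape == (2, 5, 3) and rgb.max() == 255
```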
"},{"location":"reference/foundationals/clip/","title":"
CLIP","text":""},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoder","title":"CLIPImageEncoder","text":"CLIPImageEncoder(\n image_size: int = 224,\n embedding_dim: int = 768,\n output_dim: int = 512,\n patch_size: int = 32,\n num_layers: int = 12,\n num_attention_heads: int = 12,\n feedforward_dim: int = 3072,\n layer_norm_eps: float = 1e-05,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Contrastive Language-Image Pretraining (CLIP) image encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Parameters:
image_size (int, default 224): The size of the input image.
embedding_dim (int, default 768): The dimension of the embedding.
output_dim (int, default 512): The dimension of the output, produced by the final linear projection.
patch_size (int, default 32): The size of the patches.
num_layers (int, default 12): The number of transformer layers.
num_attention_heads (int, default 12): The number of attention heads.
feedforward_dim (int, default 3072): The dimension of the feedforward layer.
layer_norm_eps (float, default 1e-05): The epsilon value for layer normalization.
device (device | str | None, default None): The PyTorch device to use.
dtype (dtype | None, default None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(\n self,\n image_size: int = 224,\n embedding_dim: int = 768,\n output_dim: int = 512,\n patch_size: int = 32,\n num_layers: int = 12,\n num_attention_heads: int = 12,\n feedforward_dim: int = 3072,\n layer_norm_eps: float = 1e-5,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize a CLIP image encoder.\n\n Args:\n image_size: The size of the input image.\n embedding_dim: The dimension of the embedding.\n output_dim: The dimension of the output.\n patch_size: The size of the patches.\n num_layers: The number of layers.\n num_attention_heads: The number of attention heads.\n feedforward_dim: The dimension of the feedforward layer.\n layer_norm_eps: The epsilon value for normalization.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n self.image_size = image_size\n self.embedding_dim = embedding_dim\n self.output_dim = output_dim\n self.patch_size = patch_size\n self.num_layers = num_layers\n self.num_attention_heads = num_attention_heads\n self.feedforward_dim = feedforward_dim\n cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]\n super().__init__(\n ViTEmbeddings(\n image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype\n ),\n fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n fl.Chain(\n TransformerLayer(\n embedding_dim=embedding_dim,\n feedforward_dim=feedforward_dim,\n num_attention_heads=num_attention_heads,\n layer_norm_eps=layer_norm_eps,\n device=device,\n dtype=dtype,\n )\n for _ in range(num_layers)\n ),\n fl.Lambda(func=cls_token_pooling),\n fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),\n )\n
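A quick way to see what patch_size means for sequence length: a ViT-style encoder cuts the image into (image_size / patch_size)^2 non-overlapping patches, and the cls_token_pooling lambda above reads position 0, which suggests one extra class token is prepended. The sketch below assumes exactly that (one class token, square images). It also computes the per-head dimension embedding_dim / num_attention_heads for the default, huge and giant configurations listed on this page.

```python
def clip_vit_token_count(image_size: int, patch_size: int) -> int:
    # Assumption: square images, non-overlapping patches, plus one prepended class token.
    assert image_size % patch_size == 0, 'image_size must be a multiple of patch_size'
    return (image_size // patch_size) ** 2 + 1


def head_dim(embedding_dim: int, num_attention_heads: int) -> int:
    assert embedding_dim % num_attention_heads == 0
    return embedding_dim // num_attention_heads


# Default CLIPImageEncoder: 224 / 32 = 7 patches per side -> 49 patches + 1 class token.
assert clip_vit_token_count(224, 32) == 50
# Huge and giant variants keep image_size=224 but use patch_size=14: 16 * 16 + 1 tokens.
assert clip_vit_token_count(224, 14) == 257
assert head_dim(768, 12) == 64    # default
assert head_dim(1280, 16) == 80   # huge
assert head_dim(1664, 16) == 104  # giant
```

The smaller patch size of the huge and giant variants multiplies the sequence length by about five, on top of their larger embedding and layer counts.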
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoderG","title":"CLIPImageEncoderG","text":"CLIPImageEncoderG(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: CLIPImageEncoder
CLIP giant image encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Attributes:
embedding_dim (int): 1664
output_dim (int): 1280
patch_size (int): 14
num_layers (int): 48
num_attention_heads (int): 16
feedforward_dim (int): 8192
Parameters:
device (device | str | None, default None): The PyTorch device to use.
dtype (dtype | None, default None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n \"\"\"Initialize CLIP giant image encoder.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1664,\n output_dim=1280,\n patch_size=14,\n num_layers=48,\n num_attention_heads=16,\n feedforward_dim=8192,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPImageEncoderH","title":"CLIPImageEncoderH","text":"CLIPImageEncoderH(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: CLIPImageEncoder
CLIP huge image encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Attributes:
- embedding_dim (int): 1280
- output_dim (int): 1024
- patch_size (int): 14
- num_layers (int): 32
- num_attention_heads (int): 16
- feedforward_dim (int): 5120
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/image_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n \"\"\"Initialize CLIP huge image encoder.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1280,\n output_dim=1024,\n patch_size=14,\n num_layers=32,\n num_attention_heads=16,\n feedforward_dim=5120,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoder","title":"CLIPTextEncoder","text":"CLIPTextEncoder(\n embedding_dim: int = 768,\n max_sequence_length: int = 77,\n vocabulary_size: int = 49408,\n num_layers: int = 12,\n num_attention_heads: int = 12,\n feedforward_dim: int = 3072,\n layer_norm_eps: float = 1e-05,\n use_quick_gelu: bool = False,\n tokenizer: CLIPTokenizer | None = None,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Contrastive Language-Image Pretraining (CLIP) text encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Parameters:
- embedding_dim (int, default: 768): The embedding dimension.
- max_sequence_length (int, default: 77): The maximum sequence length.
- vocabulary_size (int, default: 49408): The vocabulary size.
- num_layers (int, default: 12): The number of transformer layers.
- num_attention_heads (int, default: 12): The number of attention heads.
- feedforward_dim (int, default: 3072): The feedforward dimension.
- layer_norm_eps (float, default: 1e-05): The epsilon value for layer normalization.
- use_quick_gelu (bool, default: False): Whether to replace the GeLU activations with their sigmoid approximation (quick GeLU).
- tokenizer (CLIPTokenizer | None, default: None): The tokenizer. If None, a CLIPTokenizer with sequence_length=max_sequence_length is created.
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(\n self,\n embedding_dim: int = 768,\n max_sequence_length: int = 77,\n vocabulary_size: int = 49408,\n num_layers: int = 12,\n num_attention_heads: int = 12,\n feedforward_dim: int = 3072,\n layer_norm_eps: float = 1e-5,\n use_quick_gelu: bool = False,\n tokenizer: CLIPTokenizer | None = None,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize CLIP text encoder.\n\n Args:\n embedding_dim: The embedding dimension.\n max_sequence_length: The maximum sequence length.\n vocabulary_size: The vocabulary size.\n num_layers: The number of layers.\n num_attention_heads: The number of attention heads.\n feedforward_dim: The feedforward dimension.\n layer_norm_eps: The epsilon value for layer normalization.\n use_quick_gelu: Whether to use the quick GeLU activation function.\n tokenizer: The tokenizer.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n self.embedding_dim = embedding_dim\n self.max_sequence_length = max_sequence_length\n self.vocabulary_size = vocabulary_size\n self.num_layers = num_layers\n self.num_attention_heads = num_attention_heads\n self.feedforward_dim = feedforward_dim\n self.layer_norm_eps = layer_norm_eps\n self.use_quick_gelu = use_quick_gelu\n super().__init__(\n tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),\n fl.Converter(set_dtype=False),\n fl.Sum(\n TokenEncoder(\n vocabulary_size=vocabulary_size,\n embedding_dim=embedding_dim,\n device=device,\n dtype=dtype,\n ),\n PositionalEncoder(\n max_sequence_length=max_sequence_length,\n embedding_dim=embedding_dim,\n device=device,\n dtype=dtype,\n ),\n ),\n *(\n TransformerLayer(\n embedding_dim=embedding_dim,\n num_attention_heads=num_attention_heads,\n feedforward_dim=feedforward_dim,\n layer_norm_eps=layer_norm_eps,\n device=device,\n dtype=dtype,\n )\n for _ in range(num_layers)\n ),\n fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),\n 
)\n if use_quick_gelu:\n for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):\n parent.replace(\n old_module=gelu,\n new_module=fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID),\n )\n
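The input stage above wraps TokenEncoder and PositionalEncoder in fl.Sum. Each token therefore receives the sum of a per-token embedding and a per-position embedding. The sketch below shows that arithmetic on toy integer tables in plain Python. The sizes and table values are made up for illustration, and in the real model both tables are learned weights.

```python
def embed_tokens(token_ids, token_table, position_table):
    # Mimics fl.Sum(TokenEncoder, PositionalEncoder): both embeddings are looked up
    # for every position and added element-wise.
    if len(token_ids) > len(position_table):
        raise ValueError('sequence is longer than max_sequence_length')
    embeddings = []
    for position, token_id in enumerate(token_ids):
        token_vec = token_table[token_id]
        position_vec = position_table[position]
        embeddings.append([t + p for t, p in zip(token_vec, position_vec)])
    return embeddings

# Toy sizes: vocabulary_size=5, max_sequence_length=4, embedding_dim=3.
token_table = [[v] * 3 for v in range(5)]
position_table = [[10 * p] * 3 for p in range(4)]

print(embed_tokens([2, 4], token_table, position_table))  # [[2, 2, 2], [14, 14, 14]]
```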
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderG","title":"CLIPTextEncoderG","text":"CLIPTextEncoderG(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: CLIPTextEncoder
CLIP giant text encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Attributes:
- embedding_dim (int): 1280
- num_layers (int): 32
- num_attention_heads (int): 20
- feedforward_dim (int): 5120
- tokenizer (CLIPTokenizer): CLIPTokenizer(pad_token_id=0)
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n \"\"\"Initialize CLIP giant text encoder.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n tokenizer = CLIPTokenizer(pad_token_id=0)\n super().__init__(\n embedding_dim=1280,\n num_layers=32,\n num_attention_heads=20,\n feedforward_dim=5120,\n tokenizer=tokenizer,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderH","title":"CLIPTextEncoderH","text":"CLIPTextEncoderH(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: CLIPTextEncoder
CLIP huge text encoder.
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Attributes:
- embedding_dim (int): 1024
- num_layers (int): 23
- num_attention_heads (int): 16
- feedforward_dim (int): 4096
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n \"\"\"Initialize CLIP huge text encoder.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1024,\n num_layers=23,\n num_attention_heads=16,\n feedforward_dim=4096,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/clip/#refiners.foundationals.clip.CLIPTextEncoderL","title":"CLIPTextEncoderL","text":"CLIPTextEncoderL(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: CLIPTextEncoder
CLIP large text encoder.
Note: We replace the GeLU activation function with its sigmoid approximation (quick GeLU) to match OpenAI's original CLIP implementation (https://github.com/openai/CLIP/blob/a1d0717/clip/model.py#L166).
See [arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision for more details.
Attributes:
- embedding_dim (int): 768
- num_layers (int): 12
- num_attention_heads (int): 12
- feedforward_dim (int): 3072
- use_quick_gelu (bool): True
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/clip/text_encoder.py
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:\n \"\"\"Initialize CLIP large text encoder.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=768,\n num_layers=12,\n num_attention_heads=12,\n feedforward_dim=3072,\n use_quick_gelu=True,\n device=device,\n dtype=dtype,\n )\n
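CLIPTextEncoderL sets use_quick_gelu=True. That swaps every GeLU for fl.GeLU(approximation=fl.GeLUApproximation.SIGMOID). The sketch below compares exact GeLU with the sigmoid form x * sigmoid(1.702 * x) used by OpenAI's QuickGELU. It assumes, without confirmation from this reference, that Refiners' SIGMOID approximation uses the same 1.702 constant.

```python
import math

def gelu(x):
    # Exact GeLU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def quick_gelu(x):
    # Sigmoid approximation (assumed constant 1.702): x * sigmoid(1.702 * x).
    return x / (1.0 + math.exp(-1.702 * x))

xs = [i / 100 for i in range(-600, 601)]
max_gap = max(abs(gelu(x) - quick_gelu(x)) for x in xs)
print(f'max |gelu - quick_gelu| on [-6, 6]: {max_gap:.4f}')
```

The two curves stay close, but pretrained CLIP weights were fit with the approximate form. Using the same activation at inference avoids a small, systematic drift in the text embeddings.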
"},{"location":"reference/foundationals/dinov2/","title":"
DINOv2","text":""},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_base","title":"DINOv2_base","text":"DINOv2_base(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 base model.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.
Attributes:
- embedding_dim (int): 768
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 12
- num_heads (int): 12
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 base model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=768,\n patch_size=14,\n image_size=518,\n num_layers=12,\n num_heads=12,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_base_reg","title":"DINOv2_base_reg","text":"DINOv2_base_reg(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 base model with registers.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.
Attributes:
- embedding_dim (int): 768
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 12
- num_heads (int): 12
- num_registers (int): 4
- interpolate_antialias (bool): True
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 base model with register.\n\n Args:\n device (torch.device | str | None): The PyTorch device to use.\n dtype (torch.dtype | None): The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=768,\n patch_size=14,\n image_size=518,\n num_layers=12,\n num_heads=12,\n num_registers=4,\n interpolate_antialias=True,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_giant","title":"DINOv2_giant","text":"DINOv2_giant(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 giant model.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.
Attributes:
- embedding_dim (int): 1536
- feedforward_dim (int): 4096
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 40
- num_heads (int): 24
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 giant model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1536,\n feedforward_dim=4096,\n patch_size=14,\n image_size=518,\n num_layers=40,\n num_heads=24,\n activation=GLU(SiLU()),\n device=device,\n dtype=dtype,\n )\n
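DINOv2 giant passes activation=GLU(SiLU()) instead of the default GeLU, which gives a SwiGLU-style feed-forward. The sketch below shows the gating on a plain Python list. It splits the last dimension in two and gates one half with the SiLU of the other. Which half goes through the activation, and how the feed-forward widths are arranged around the gate, are assumptions here and are not taken from Refiners' GLU.

```python
import math

def silu(x):
    # SiLU (also called swish): x * sigmoid(x).
    return x / (1.0 + math.exp(-x))

def glu(values, activation=silu):
    # Gated linear unit over the last dimension (assumed layout: [x, gate]).
    if len(values) % 2 != 0:
        raise ValueError('GLU needs an even-sized last dimension')
    half = len(values) // 2
    x, gate = values[:half], values[half:]
    # Output is half the input width: each x element is scaled by activation(gate).
    return [a * activation(g) for a, g in zip(x, gate)]

print(glu([1.0, 2.0, 0.0, 1.0]))
```

Because the gate halves the width, a GLU feed-forward typically needs its first projection to be twice as wide as the activation's output.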
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_giant_reg","title":"DINOv2_giant_reg","text":"DINOv2_giant_reg(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 giant model with registers.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.
Attributes:
- embedding_dim (int): 1536
- feedforward_dim (int): 4096
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 40
- num_heads (int): 24
- num_registers (int): 4
- interpolate_antialias (bool): True
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 giant model with register.\n\n Args:\n device (torch.device | str | None): The PyTorch device to use.\n dtype (torch.dtype | None): The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1536,\n feedforward_dim=4096,\n patch_size=14,\n image_size=518,\n num_layers=40,\n num_heads=24,\n num_registers=4,\n interpolate_antialias=True,\n activation=GLU(SiLU()),\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_large","title":"DINOv2_large","text":"DINOv2_large(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 large model.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.
Attributes:
- embedding_dim (int): 1024
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 24
- num_heads (int): 16
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 large model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1024,\n patch_size=14,\n image_size=518,\n num_layers=24,\n num_heads=16,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_large_reg","title":"DINOv2_large_reg","text":"DINOv2_large_reg(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 large model with registers.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.
Attributes:
- embedding_dim (int): 1024
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 24
- num_heads (int): 16
- num_registers (int): 4
- interpolate_antialias (bool): True
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 large model with register.\n\n Args:\n device (torch.device | str | None): The PyTorch device to use.\n dtype (torch.dtype | None): The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=1024,\n patch_size=14,\n image_size=518,\n num_layers=24,\n num_heads=16,\n num_registers=4,\n interpolate_antialias=True,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_small","title":"DINOv2_small","text":"DINOv2_small(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 small model.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision for more details.
Attributes:
- embedding_dim (int): 384
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 12
- num_heads (int): 6
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 small model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=384,\n patch_size=14,\n image_size=518,\n num_layers=12,\n num_heads=6,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.DINOv2_small_reg","title":"DINOv2_small_reg","text":"DINOv2_small_reg(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: ViT
DINOv2 small model with registers.
See [arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision and [arXiv:2309.16588] Vision Transformers Need Registers for more details.
Attributes:
- embedding_dim (int): 384
- patch_size (int): 14
- image_size (int): 518
- num_layers (int): 12
- num_heads (int): 6
- num_registers (int): 4
- interpolate_antialias (bool): True
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/dinov2.py
def __init__(\n self,\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize DINOv2 small model with register.\n\n Args:\n device (torch.device | str | None): The PyTorch device to use.\n dtype (torch.dtype | None): The PyTorch data type to use.\n \"\"\"\n super().__init__(\n embedding_dim=384,\n patch_size=14,\n image_size=518,\n num_layers=12,\n num_heads=6,\n num_registers=4,\n interpolate_antialias=True,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.ViT","title":"ViT","text":"ViT(\n embedding_dim: int = 768,\n patch_size: int = 16,\n image_size: int = 224,\n num_layers: int = 12,\n num_heads: int = 12,\n norm_eps: float = 1e-06,\n mlp_ratio: int = 4,\n num_registers: int = 0,\n activation: Activation = GeLU(),\n feedforward_dim: int | None = None,\n interpolate_antialias: bool = False,\n interpolate_mode: str = \"bicubic\",\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Vision Transformer (ViT) model.
See [arXiv:2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale for more details.
Parameters:
- embedding_dim (int, default: 768): The dimension of the embedding.
- patch_size (int, default: 16): The size of the patches.
- image_size (int, default: 224): The size of the input image.
- num_layers (int, default: 12): The number of transformer layers.
- num_heads (int, default: 12): The number of attention heads.
- norm_eps (float, default: 1e-06): The epsilon value for layer normalization.
- mlp_ratio (int, default: 4): The ratio of the MLP hidden dimension to the embedding dimension.
- num_registers (int, default: 0): The number of register tokens.
- activation (Activation, default: GeLU()): The activation function.
- feedforward_dim (int | None, default: None): The dimension of the feedforward layer.
- interpolate_antialias (bool, default: False): Whether to use antialiasing when interpolating the positional embeddings.
- interpolate_mode (str, default: 'bicubic'): The interpolation mode used for the positional embeddings.
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/dinov2/vit.py
def __init__(\n self,\n embedding_dim: int = 768,\n patch_size: int = 16,\n image_size: int = 224,\n num_layers: int = 12,\n num_heads: int = 12,\n norm_eps: float = 1e-6,\n mlp_ratio: int = 4,\n num_registers: int = 0,\n activation: Activation = fl.GeLU(),\n feedforward_dim: int | None = None,\n interpolate_antialias: bool = False,\n interpolate_mode: str = \"bicubic\",\n device: torch.device | str | None = None,\n dtype: torch.dtype | None = None,\n) -> None:\n \"\"\"Initialize a Vision Transformer (ViT) model.\n\n Args:\n embedding_dim: The dimension of the embedding.\n patch_size: The size of the patches.\n image_size: The size of the input image.\n num_layers: The number of layers.\n num_heads: The number of heads.\n norm_eps: The epsilon value for normalization.\n mlp_ratio: The ratio for the multi-layer perceptron (MLP).\n num_registers: The number of registers.\n activation: The activation function.\n feedforward_dim: The dimension of the feedforward layer.\n interpolate_antialias: Whether to use antialiasing for interpolation.\n interpolate_mode: The interpolation mode.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n num_patches = image_size // patch_size\n self.embedding_dim = embedding_dim\n self.patch_size = patch_size\n self.image_size = image_size\n self.num_layers = num_layers\n self.num_heads = num_heads\n self.norm_eps = norm_eps\n self.mlp_ratio = mlp_ratio\n self.num_registers = num_registers\n self.feedforward_dim = feedforward_dim\n\n super().__init__(\n fl.Concatenate(\n ClassToken(\n embedding_dim=embedding_dim,\n device=device,\n dtype=dtype,\n ),\n PatchEncoder(\n in_channels=3,\n out_channels=embedding_dim,\n patch_size=patch_size,\n device=device,\n dtype=dtype,\n ),\n dim=1,\n ),\n PositionalEncoder(\n PositionalEmbedding(\n sequence_length=num_patches**2 + 1,\n embedding_dim=embedding_dim,\n patch_size=patch_size,\n device=device,\n dtype=dtype,\n ),\n fl.Chain(\n fl.Parallel(\n fl.Identity(),\n 
fl.UseContext(context=\"dinov2_vit\", key=\"input\"),\n ),\n InterpolateEmbedding(\n mode=interpolate_mode,\n antialias=interpolate_antialias,\n patch_size=patch_size,\n ),\n ),\n ),\n Transformer(\n TransformerLayer(\n embedding_dim=embedding_dim,\n feedforward_dim=feedforward_dim,\n activation=activation,\n num_heads=num_heads,\n mlp_ratio=mlp_ratio,\n norm_eps=norm_eps,\n device=device,\n dtype=dtype,\n )\n for _ in range(num_layers)\n ),\n fl.LayerNorm(\n normalized_shape=embedding_dim,\n eps=norm_eps,\n device=device,\n dtype=dtype,\n ),\n )\n\n if self.num_registers > 0:\n registers = Registers(\n num_registers=num_registers,\n embedding_dim=embedding_dim,\n device=device,\n dtype=dtype,\n )\n self.insert_before_type(Transformer, registers)\n
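Two size calculations follow from ViT.__init__: the token sequence length and the feedforward width. The sketch below reproduces the sequence length from num_patches = image_size // patch_size and PositionalEmbedding(sequence_length=num_patches**2 + 1). It rests on two assumptions not confirmed by this reference. First, Registers adds num_registers extra tokens. Second, a feedforward_dim of None falls back to embedding_dim * mlp_ratio. At input resolutions other than image_size, the positional embeddings are interpolated (InterpolateEmbedding), so the token count changes accordingly.

```python
def vit_sequence_length(image_size, patch_size, num_registers=0):
    # Patch grid side length, as in ViT.__init__ (num_patches = image_size // patch_size).
    patches_per_side = image_size // patch_size
    # Class token plus patch tokens, as in PositionalEmbedding(sequence_length=num_patches**2 + 1).
    length = patches_per_side ** 2 + 1
    # Assumption: Registers appends num_registers extra tokens to the sequence.
    return length + num_registers

def feedforward_width(embedding_dim, mlp_ratio=4, feedforward_dim=None):
    # Assumption: with feedforward_dim=None the MLP width is embedding_dim * mlp_ratio.
    return feedforward_dim if feedforward_dim is not None else embedding_dim * mlp_ratio

print(vit_sequence_length(224, 16))                   # ViT defaults: 197
print(vit_sequence_length(518, 14))                   # DINOv2 presets: 1370
print(vit_sequence_length(518, 14, num_registers=4))  # *_reg presets: 1374
```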
"},{"location":"reference/foundationals/dinov2/#refiners.foundationals.dinov2.preprocess","title":"preprocess","text":"preprocess(img: Image, dim: int = 224) -> Tensor\n
Preprocess an image for use with DINOv2. The image is resized and then normalized with the ImageNet mean and standard deviation; no center crop is applied.
Parameters:
- img (Image, required): The image to preprocess.
- dim (int, default: 224): The side length, in pixels, of the square the image is resized to (typically 224 or 518).
Returns:
- Tensor: A float32 tensor with shape (3, dim, dim).
Source code in src/refiners/foundationals/dinov2/dinov2.py
def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:\n \"\"\"\n Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.\n Note that this only resizes and normalizes the image, there is no center crop.\n\n Args:\n img: The image.\n dim: The square dimension to resize the image. Typically 224 or 518.\n\n Returns:\n A float32 tensor with shape (3, dim, dim).\n \"\"\"\n img = img.convert(\"RGB\").resize((dim, dim)) # type: ignore\n t = image_to_tensor(img).squeeze()\n return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n
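The normalization step in preprocess applies (value - mean) / std per channel, using the ImageNet constants shown above. The sketch below applies the same arithmetic to a single 8-bit RGB pixel. It assumes image_to_tensor maps 8-bit values to [0, 1], which the 2 * x - 1 rescaling used elsewhere in these modules suggests but this reference does not state. The helper normalize_pixel is hypothetical.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    # Hypothetical helper. Assumption: 8-bit values are first scaled to [0, 1].
    scaled = [channel / 255 for channel in rgb]
    # Per-channel standardization with ImageNet statistics.
    return [(v - m) / s for v, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

print(normalize_pixel((0, 0, 0)))        # black pixel
print(normalize_pixel((255, 255, 255)))  # white pixel
```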
"},{"location":"reference/foundationals/latent_diffusion/","title":"
Latent Diffusion","text":""},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.FixedGroupNorm","title":"FixedGroupNorm","text":"FixedGroupNorm(target: GroupNorm)\n
Bases: Chain, Adapter[GroupNorm]
Adapter for GroupNorm layers that normalizes with fixed, precomputed mean and variance instead of per-input statistics.
This is useful when running tiled inference with an autoencoder: it keeps the GroupNorm statistics consistent across tiles.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(self, target: fl.GroupNorm) -> None:\n self.mean = None\n self.var = None\n with self.setup_adapter(target):\n super().__init__(fl.Lambda(self.compute_group_norm))\n
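GroupNorm statistics depend on the whole input, so normalizing each tile with its own statistics makes neighbouring tiles disagree. The stdlib sketch below contrasts per-tile normalization with normalization that reuses one fixed mean and variance, which is the idea behind FixedGroupNorm. It treats a single group as a flat list and omits the affine weight and bias. It also captures the statistics on the full toy signal, whereas tiled_inference captures them on a downsampled version of the image.

```python
import statistics

def normalize(values, mean, var, eps=1e-5):
    # Group-norm style normalization of one group with the given statistics.
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

def stats(values):
    # Population mean and variance, as used by normalization layers.
    return statistics.fmean(values), statistics.pvariance(values)

signal = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
left, right = signal[:4], signal[4:]

# Per-tile statistics: each tile is normalized independently.
per_tile = normalize(left, *stats(left)) + normalize(right, *stats(right))

# Fixed statistics captured once and reused for every tile.
mean, var = stats(signal)
fixed = normalize(left, mean, var) + normalize(right, mean, var)
full = normalize(signal, mean, var)

print(per_tile == full, fixed == full)
```

With fixed statistics, stitching the normalized tiles reproduces the untiled result. Per-tile statistics give both tiles the same values even though their inputs differ.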
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder","title":"LatentDiffusionAutoencoder","text":"LatentDiffusionAutoencoder(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Latent diffusion autoencoder model.
Attributes:
- encoder_scale: The scale factor applied to the encoder output in encode and undone in decode.
Parameters:
- device (device | str | None, default: None): The PyTorch device to use.
- dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n self,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initializes the model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n Encoder(device=device, dtype=dtype),\n Decoder(device=device, dtype=dtype),\n )\n self._tile_size = None\n self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.decode","title":"decode","text":"decode(x: Tensor) -> Tensor\n
Decode a latent tensor.
Parameters:
- x (Tensor, required): The latent tensor to decode.
Returns:
- Tensor: The decoded image tensor.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def decode(self, x: Tensor) -> Tensor:\n \"\"\"Decode a latent tensor.\n\n Args:\n x: The latent to decode.\n\n Returns:\n The decoded image tensor.\n \"\"\"\n decoder = self[1]\n x = decoder(x / self.encoder_scale)\n return x\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.encode","title":"encode","text":"encode(x: Tensor) -> Tensor\n
Encode an image.
Parameters:
- x (Tensor, required): The image tensor to encode.
Returns:
- Tensor: The encoded tensor.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def encode(self, x: Tensor) -> Tensor:\n \"\"\"Encode an image.\n\n Args:\n x: The image tensor to encode.\n\n Returns:\n The encoded tensor.\n \"\"\"\n encoder = self[0]\n x = self.encoder_scale * encoder(x)\n return x\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.image_to_latents","title":"image_to_latents","text":"image_to_latents(image: Image) -> Tensor\n
Encode an image to latents.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def image_to_latents(self, image: Image.Image) -> Tensor:\n \"\"\"\n Encode an image to latents.\n \"\"\"\n return self.images_to_latents([image])\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.images_to_latents","title":"images_to_latents","text":"images_to_latents(images: list[Image]) -> Tensor\n
Convert a list of images to latents.
Parameters:
- images (list[Image], required): The list of images to convert.
Returns:
- Tensor: A tensor containing the latents associated with the images.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def images_to_latents(self, images: list[Image.Image]) -> Tensor:\n \"\"\"Convert a list of images to latents.\n\n Args:\n images: The list of images to convert.\n\n Returns:\n A tensor containing the latents associated with the images.\n \"\"\"\n x = images_to_tensor(images, device=self.device, dtype=self.dtype)\n x = 2 * x - 1\n return self.encode(x)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.latents_to_image","title":"latents_to_image","text":"latents_to_image(x: Tensor) -> Image\n
Decode latents to an image. The batch size must be 1; otherwise a ValueError is raised.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_image(self, x: Tensor) -> Image.Image:\n \"\"\"\n Decode latents to an image.\n \"\"\"\n if x.shape[0] != 1:\n raise ValueError(f\"Expected batch size of 1, got {x.shape[0]}\")\n\n return self.latents_to_images(x)[0]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.latents_to_images","title":"latents_to_images","text":"latents_to_images(x: Tensor) -> list[Image]\n
Convert a tensor of latents to images.
Parameters:
- x (Tensor, required): The tensor of latents to convert.
Returns:
- list[Image]: A list of images associated with the latents.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def latents_to_images(self, x: Tensor) -> list[Image.Image]:\n \"\"\"Convert a tensor of latents to images.\n\n Args:\n x: The tensor of latents to convert.\n\n Returns:\n A list of images associated with the latents.\n \"\"\"\n x = self.decode(x)\n x = (x + 1) / 2\n return tensor_to_images(x)\n
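The value conventions of the methods above form a round trip. images_to_latents maps [0, 1] pixels to [-1, 1] with 2 * x - 1, and encode multiplies by encoder_scale. decode divides by encoder_scale, and latents_to_images maps back with (x + 1) / 2. The sketch below checks that round trip in plain Python with identity stand-ins for the encoder and decoder networks. ENCODER_SCALE = 0.18215 is only a placeholder (the scaling factor commonly associated with Stable Diffusion 1.x VAEs); the class's actual encoder_scale value is not shown in this reference.

```python
ENCODER_SCALE = 0.18215  # placeholder value, not read from Refiners

def to_model_range(pixels):
    # images_to_latents: [0, 1] pixel values -> [-1, 1] before encoding.
    return [2 * p - 1 for p in pixels]

def to_pixel_range(values):
    # latents_to_images: [-1, 1] decoder output -> [0, 1] pixel values.
    return [(v + 1) / 2 for v in values]

def encode(x, encoder=lambda v: v):
    # encode: the encoder output is multiplied by encoder_scale (identity encoder here).
    return [ENCODER_SCALE * v for v in encoder(x)]

def decode(latents, decoder=lambda v: v):
    # decode: latents are divided by encoder_scale before decoding (identity decoder here).
    return decoder([z / ENCODER_SCALE for z in latents])

pixels = [0.0, 0.25, 0.5, 1.0]
roundtrip = to_pixel_range(decode(encode(to_model_range(pixels))))
print(roundtrip)
```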
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_image_to_latents","title":"tiled_image_to_latents","text":"tiled_image_to_latents(image: Image) -> Tensor\n
Convert an image to latents with gradient blending to smooth tile edges.
This method must be called inside the `tiled_inference` context manager; otherwise it raises a ValueError.
```python
with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
    latents = lda.tiled_image_to_latents(sample_image)
```
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_image_to_latents(self, image: Image.Image) -> Tensor:\n \"\"\"\n Convert an image to latents with gradient blending to smooth tile edges.\n\n You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.\n\n ```python\n with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n latents = lda.tiled_image_to_latents(sample_image)\n \"\"\"\n if self._tile_size is None:\n raise ValueError(\"Tiled inference context manager not active. Use `tiled_inference` method to activate.\")\n\n assert self._tile_size is not None and self._blending is not None\n image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)\n image_tensor = 2 * image_tensor - 1\n return self._tiled_encode(image_tensor, self._tile_size, self._blending)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_inference","title":"tiled_inference","text":"tiled_inference(\n image: Image,\n tile_size: tuple[int, int] = (512, 512),\n blending: int = 64,\n) -> Generator[None, None, None]\n
Context manager for tiled inference operations to save VRAM for large images.
This context manager sets up a consistent GroupNorm statistics for performing tiled operations on the autoencoder, including setting and resetting group norm statistics. This allow to make sure that the result is consistent across tiles by capturing the statistics of the GroupNorm layers on a downsampled version of the image.
Do not call the regular image_to_latents and latents_to_image methods while this context manager is active: they fail silently and run the operation without tiling.
with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):\n    latents = lda.tiled_image_to_latents(sample_image)\n    decoded_image = lda.tiled_latents_to_image(latents)\n
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
@contextmanager\ndef tiled_inference(\n self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64\n) -> Generator[None, None, None]:\n \"\"\"\n Context manager for tiled inference operations to save VRAM for large images.\n\n This context manager sets up a consistent GroupNorm statistics for performing tiled operations on the\n autoencoder, including setting and resetting group norm statistics. This allow to make sure that the result is\n consistent across tiles by capturing the statistics of the GroupNorm layers on a downsampled version of the\n image.\n\n Be careful not to use the normal `image_to_latents` and `latents_to_image` methods while this context manager is\n active, as this will fail silently and run the operation without tiling.\n\n ```python\n with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):\n latents = lda.tiled_image_to_latents(sample_image)\n decoded_image = lda.tiled_latents_to_image(latents)\n \"\"\"\n try:\n self._blending = blending\n self._tile_size = _ImageSize(width=tile_size[0], height=tile_size[1])\n self._add_fixed_group_norm(image, inference_size=self._tile_size)\n yield\n finally:\n self._remove_fixed_group_norm()\n self._tile_size = None\n self._blending = None\n
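To make the tile_size and blending parameters more concrete, here is a minimal 1-D sketch of how overlapping tiles and linear blending weights can be combined. The helpers tile_starts, blend_weights and blend_1d are hypothetical and only illustrate the general gradient-blending technique; they are not the private _tiled_encode and _tiled_decode methods of refiners. In 2-D, the same logic would be applied along both axes.

```python
def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    # Start offsets of tiles of size `tile` covering [0, length), with
    # consecutive tiles overlapping by at least `overlap` pixels.
    assert 0 < tile <= length, 'assumes the tile fits in the image'
    stride = tile - overlap
    assert stride > 0, 'overlap must be smaller than the tile size'
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile is flush with the border
    return starts


def blend_weights(tile: int, overlap: int, left: bool, right: bool) -> list[float]:
    # Linear ramps over the borders shared with a neighbouring tile, so
    # that overlapping tiles cross-fade instead of leaving visible seams.
    weights = [1.0] * tile
    for i in range(overlap):
        ramp = (i + 1) / (overlap + 1)
        if left:
            weights[i] = min(weights[i], ramp)
        if right:
            weights[tile - 1 - i] = min(weights[tile - 1 - i], ramp)
    return weights


def blend_1d(length: int, tile: int, overlap: int, tile_outputs: list[list[float]]) -> list[float]:
    # Weighted average of the per-tile outputs at every position.
    starts = tile_starts(length, tile, overlap)
    assert len(tile_outputs) == len(starts)
    acc = [0.0] * length
    total = [0.0] * length
    for k, (start, values) in enumerate(zip(starts, tile_outputs)):
        w = blend_weights(tile, overlap, left=k > 0, right=k < len(starts) - 1)
        for i in range(tile):
            acc[start + i] += w[i] * values[i]
            total[start + i] += w[i]
    return [a / t for a, t in zip(acc, total)]
```

Because every weight is strictly positive, dividing by the accumulated weight always yields a proper weighted average, so tiles that agree on a region reproduce it exactly while disagreeing tiles are cross-faded.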
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.auto_encoder.LatentDiffusionAutoencoder.tiled_latents_to_image","title":"tiled_latents_to_image","text":"tiled_latents_to_image(x: Tensor) -> Image\n
Convert latents to an image with gradient blending to smooth tile edges.
This method requires an active tiled inference context: enter it with the tiled_inference method first.
with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n    image = lda.tiled_latents_to_image(latents)\n
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def tiled_latents_to_image(self, x: Tensor) -> Image.Image:\n \"\"\"\n Convert latents to an image with gradient blending to smooth tile edges.\n\n You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.\n\n ```python\n with lda.tiled_inference(sample_image, tile_size=(768, 1024)):\n image = lda.tiled_latents_to_image(latents)\n \"\"\"\n if self._tile_size is None:\n raise ValueError(\"Tiled inference context manager not active. Use `tiled_inference` method to activate.\")\n\n assert self._tile_size is not None and self._blending is not None\n result = self._tiled_decode(x, self._tile_size, self._blending)\n return tensor_to_image((result + 1) / 2)\n
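Both tiled methods wrap the autoencoder with the same pixel-range conversion visible in their source: images are mapped from [0, 1] to [-1, 1] before encoding (2 * image_tensor - 1), and decoded outputs are mapped back with (result + 1) / 2. The following stdlib sketch of that round trip uses hypothetical function names and plain lists instead of tensors.

```python
def to_model_range(pixels: list[float]) -> list[float]:
    # [0, 1] -> [-1, 1], applied before encoding
    return [2 * p - 1 for p in pixels]


def to_image_range(values: list[float]) -> list[float]:
    # [-1, 1] -> [0, 1], applied after decoding
    return [(v + 1) / 2 for v in values]
```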
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel","title":"LatentDiffusionModel","text":"LatentDiffusionModel(\n unet: Chain,\n lda: LatentDiffusionAutoencoder,\n clip_text_encoder: Chain,\n solver: Solver,\n classifier_free_guidance: bool = True,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Module, ABC
Source code in src/refiners/foundationals/latent_diffusion/model.py
def __init__(\n self,\n unet: fl.Chain,\n lda: LatentDiffusionAutoencoder,\n clip_text_encoder: fl.Chain,\n solver: Solver,\n classifier_free_guidance: bool = True,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n super().__init__()\n self.device: Device = device if isinstance(device, Device) else Device(device=device)\n self.dtype = dtype\n self.unet = unet.to(device=self.device, dtype=self.dtype)\n self.lda = lda.to(device=self.device, dtype=self.dtype)\n self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)\n self.solver = solver.to(device=self.device, dtype=self.dtype)\n self.classifier_free_guidance = classifier_free_guidance\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.init_latents","title":"init_latents","text":"init_latents(\n size: tuple[int, int],\n init_image: Image | None = None,\n noise: Tensor | None = None,\n) -> Tensor\n
Initialize the latents for the diffusion process.
Parameters:
size (tuple[int, int], required): The image size in pixels, given as (height, width). The latents are 8 times smaller in each spatial dimension.
init_image (Image | None, default: None): The image to use as initialization for the latents.
noise (Tensor | None, default: None): The noise to add to the latents. If not provided, it is sampled with sample_noise.
Source code in src/refiners/foundationals/latent_diffusion/model.py
def init_latents(\n self,\n size: tuple[int, int],\n init_image: Image.Image | None = None,\n noise: Tensor | None = None,\n) -> Tensor:\n \"\"\"Initialize the latents for the diffusion process.\n\n Args:\n size: The size of the latent (in pixel space).\n init_image: The image to use as initialization for the latents.\n noise: The noise to add to the latents.\n \"\"\"\n height, width = size\n latent_height = height // 8\n latent_width = width // 8\n\n if noise is None:\n noise = LatentDiffusionModel.sample_noise(\n size=(1, 4, latent_height, latent_width),\n device=self.device,\n dtype=self.dtype,\n )\n\n assert list(noise.shape[2:]) == [\n latent_height,\n latent_width,\n ], f\"noise shape is not compatible: {noise.shape}, with size: {size}\"\n\n if init_image is None:\n latent = noise\n else:\n resized = init_image.resize(size=(width, height)) # type: ignore\n encoded_image = self.lda.image_to_latents(resized)\n latent = self.solver.add_noise(\n x=encoded_image,\n noise=noise,\n step=self.solver.first_inference_step,\n )\n\n return self.solver.scale_model_input(latent, step=-1)\n
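The shape arithmetic in init_latents can be checked without loading a model: the code unpacks size as (height, width), divides both by 8, and samples noise with 4 latent channels. The hypothetical helper below reproduces only that computation; it also shows why pixel sizes should be multiples of 8, since integer division silently drops any remainder.

```python
def latent_noise_shape(size: tuple[int, int], batch: int = 1) -> tuple[int, int, int, int]:
    # `size` is (height, width) in pixel space, as unpacked by init_latents
    height, width = size
    # the autoencoder downsamples by 8 and the latents have 4 channels
    return (batch, 4, height // 8, width // 8)
```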
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.sample_noise","title":"sample_noise staticmethod
","text":"sample_noise(\n size: tuple[int, ...],\n device: device | None = None,\n dtype: dtype | None = None,\n offset_noise: float | None = None,\n) -> Tensor\n
Sample noise from a normal distribution with an optional offset.
Parameters:
size (tuple[int, ...], required): The size of the noise tensor.
device (device | None, default: None): The device to put the noise tensor on.
dtype (dtype | None, default: None): The data type of the noise tensor.
offset_noise (float | None, default: None): The offset of the noise tensor. Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
Source code in src/refiners/foundationals/latent_diffusion/model.py
@staticmethod\ndef sample_noise(\n size: tuple[int, ...],\n device: Device | None = None,\n dtype: DType | None = None,\n offset_noise: float | None = None,\n) -> torch.Tensor:\n \"\"\"Sample noise from a normal distribution with an optional offset.\n\n Args:\n size: The size of the noise tensor.\n device: The device to put the noise tensor on.\n dtype: The data type of the noise tensor.\n offset_noise: The offset of the noise tensor.\n Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.\n \"\"\"\n noise = torch.randn(size=size, device=device, dtype=dtype)\n if offset_noise is not None:\n noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)\n return noise\n
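Offset noise adds, on top of standard Gaussian noise, one extra Gaussian value per (batch, channel) pair that is shared by every spatial position: this is the (size[0], size[1], 1, 1) tensor in the source. The stdlib-only sketch below mirrors that logic on nested lists for a 4-D size; sample_noise_lists is a hypothetical name, and with real tensors you would call LatentDiffusionModel.sample_noise instead.

```python
import random


def sample_noise_lists(size, offset_noise=None, rng=None):
    # size = (batch, channels, height, width); returns nested lists
    rng = rng or random.Random()
    batch, channels, height, width = size
    noise = [[[[rng.gauss(0.0, 1.0) for _ in range(width)]
               for _ in range(height)]
              for _ in range(channels)]
             for _ in range(batch)]
    if offset_noise is not None:
        for b in range(batch):
            for c in range(channels):
                # one shared shift per (batch, channel), like a (B, C, 1, 1) tensor
                shift = offset_noise * rng.gauss(0.0, 1.0)
                for row in noise[b][c]:
                    for x in range(width):
                        row[x] += shift
    return noise
```

With the same seed, the offset version differs from the plain one by a constant within each channel, which shifts the per-channel mean of the noise.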
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.model.LatentDiffusionModel.set_inference_steps","title":"set_inference_steps","text":"set_inference_steps(\n num_steps: int, first_step: int = 0\n) -> None\n
Set the steps of the diffusion process.
Parameters:
num_steps (int, required): The number of inference steps.
first_step (int, default: 0): The first inference step, used for image-to-image diffusion. You may be used to setting a float in [0, 1] called strength instead, which is an abstraction for this. The first step is round((1 - strength) * (num_steps - 1)).
Source code in src/refiners/foundationals/latent_diffusion/model.py
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:\n \"\"\"Set the steps of the diffusion process.\n\n Args:\n num_steps: The number of inference steps.\n first_step: The first inference step, used for image-to-image diffusion.\n You may be used to setting a float in `[0, 1]` called `strength` instead,\n which is an abstraction for this. The first step is\n `round((1 - strength) * (num_steps - 1))`.\n \"\"\"\n self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)\n
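The strength-to-step conversion given in the parameter description fits in a small helper. first_step_from_strength is a hypothetical name, not part of refiners; only the formula comes from the docstring.

```python
def first_step_from_strength(strength: float, num_steps: int) -> int:
    # strength = 1.0 starts from pure noise (first step 0); lower values
    # skip early steps and keep more of the init image
    assert 0.0 <= strength <= 1.0
    return round((1 - strength) * (num_steps - 1))


# e.g. sd.set_inference_steps(30, first_step=first_step_from_strength(0.75, 30))
```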
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora","title":"ControlLora","text":"ControlLora(\n name: str,\n unet: SDXLUNet,\n scale: float = 1.0,\n condition_channels: int = 3,\n)\n
Bases: Passthrough
ControlLora is a half-UNet clone of the target UNet (structural copies of its TimestepEncoder, DownBlocks and MiddleBlock), patched with various LoRA layers, ZeroConvolution layers, and a ConditionEncoder.
Like ControlNet, it injects residual tensors into the target UNet. See https://github.com/HighCWu/control-lora-v2 for more details.
Gets context:
Float[Tensor, 'batch condition_channels width height']: The input image (context=\"control_lora_{name}\", key=\"condition\").
Sets context:
list[Tensor]: The residuals to be added to the target UNet's residuals (context=\"unet\", key=\"residuals\").
Parameters:
name (str, required): The name of the ControlLora.
unet (SDXLUNet, required): The target UNet.
scale (float, default: 1.0): The scale to multiply the residuals by.
condition_channels (int, default: 3): The number of channels of the input condition tensor.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(\n self,\n name: str,\n unet: SDXLUNet,\n scale: float = 1.0,\n condition_channels: int = 3,\n) -> None:\n \"\"\"Initialize the ControlLora.\n\n Args:\n name: The name of the ControlLora.\n unet: The target UNet.\n scale: The scale to multiply the residuals by.\n condition_channels: The number of channels of the input condition tensor.\n \"\"\"\n self.name = name\n\n super().__init__(\n timestep_encoder := unet.layer(\"TimestepEncoder\", Chain).structural_copy(),\n downblocks := unet.layer(\"DownBlocks\", Chain).structural_copy(),\n middle_block := unet.layer(\"MiddleBlock\", Chain).structural_copy(),\n )\n\n # modify the context_key of the copied TimestepEncoder to avoid conflicts\n timestep_encoder.context_key = f\"timestep_embedding_control_lora_{name}\"\n\n # modify the context_key of each RangeAdapter2d to avoid conflicts\n for range_adapter in self.layers(RangeAdapter2d):\n range_adapter.context_key = f\"timestep_embedding_control_lora_{name}\"\n\n # insert the ConditionEncoder in the first DownBlock\n first_downblock = downblocks.layer(0, Chain)\n out_channels = first_downblock.layer(0, Conv2d).out_channels\n first_downblock.append(\n Residual(\n UseContext(f\"control_lora_{name}\", f\"condition\"),\n ConditionEncoder(\n in_channels=condition_channels,\n out_channels=out_channels,\n device=unet.device,\n dtype=unet.dtype,\n ),\n )\n )\n\n # replace each ResidualAccumulator by a ZeroConvolution\n for residual_accumulator in self.layers(ResidualAccumulator):\n downblock = self.ensure_find_parent(residual_accumulator)\n\n first_layer = downblock[0]\n assert hasattr(first_layer, \"out_channels\"), f\"{first_layer} has no out_channels attribute\"\n\n block_channels = first_layer.out_channels\n assert isinstance(block_channels, int)\n\n downblock.replace(\n residual_accumulator,\n ZeroConvolution(\n scale=scale,\n residual_index=residual_accumulator.n,\n in_channels=block_channels,\n out_channels=block_channels,\n device=unet.device,\n 
dtype=unet.dtype,\n ),\n )\n\n # append a ZeroConvolution to middle_block\n middle_block_channels = middle_block.layer(0, ResidualBlock).out_channels\n middle_block.append(\n ZeroConvolution(\n scale=scale,\n residual_index=len(downblocks),\n in_channels=middle_block_channels,\n out_channels=middle_block_channels,\n device=unet.device,\n dtype=unet.dtype,\n )\n )\n
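ControlLora produces one residual per ZeroConvolution: the down-block ZeroConvolutions reuse the index of the ResidualAccumulator they replace, and the middle-block one uses index len(downblocks). Each is multiplied by scale. The sketch below shows, with plain floats, how such scaled residuals could be added to the target UNet's own residuals; inject_residuals is hypothetical and abstracts away the context mechanism. It also shows why zero-initialized convolutions are used: while their outputs are zero, the target UNet is unchanged.

```python
def inject_residuals(unet_residuals, control_outputs, scale):
    # unet_residuals[i] and control_outputs[i] are same-length feature
    # vectors; control_outputs[i] stands in for ZeroConvolution i's output
    assert len(unet_residuals) == len(control_outputs)
    return [
        [r + scale * c for r, c in zip(res, ctrl)]
        for res, ctrl in zip(unet_residuals, control_outputs)
    ]
```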
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLora.scale","title":"scale property
writable
","text":"scale: float\n
The scale of the residuals stored in the context.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter","title":"ControlLoraAdapter","text":"ControlLoraAdapter(\n name: str,\n target: SDXLUNet,\n scale: float = 1.0,\n condition_channels: int = 3,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: Chain, Adapter[SDXLUNet]
Adapter for ControlLora.
This adapter simply prepends a ControlLora model inside the target SDXLUNet.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def __init__(\n self,\n name: str,\n target: SDXLUNet,\n scale: float = 1.0,\n condition_channels: int = 3,\n weights: dict[str, Tensor] | None = None,\n) -> None:\n with self.setup_adapter(target):\n self.name = name\n self._control_lora = [\n ControlLora(\n name=name,\n unet=target,\n scale=scale,\n condition_channels=condition_channels,\n ),\n ]\n\n super().__init__(target)\n\n if weights:\n self.load_weights(weights)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.control_lora","title":"control_lora property
","text":"control_lora: ControlLora\n
The ControlLora model.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.scale","title":"scale property
writable
","text":"scale: float\n
The scale of the injected residuals.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_condition_encoder","title":"load_condition_encoder staticmethod
","text":"load_condition_encoder(\n state_dict: dict[str, Tensor], control_lora: ControlLora\n)\n
Load the ConditionEncoder's layers from the state_dict into the ControlLora.
Parameters:
state_dict (dict[str, Tensor], required): The state_dict containing the ConditionEncoder layers to load.
control_lora (ControlLora, required): The ControlLora to load the ConditionEncoder layers into.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_condition_encoder(\n state_dict: dict[str, Tensor],\n control_lora: ControlLora,\n):\n \"\"\"Load the `ConditionEncoder`'s layers from the state_dict into the `ControlLora`.\n\n Args:\n state_dict: The state_dict containing the ConditionEncoder layers to load.\n control_lora: The ControlLora to load the ConditionEncoder layers into.\n \"\"\"\n condition_encoder_layer = control_lora.ensure_find(ConditionEncoder)\n condition_encoder_state_dict = {\n key.removeprefix(\"ConditionEncoder.\"): value\n for key, value in state_dict.items()\n if \"ConditionEncoder\" in key\n }\n condition_encoder_layer.load_state_dict(condition_encoder_state_dict)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_lora_layers","title":"load_lora_layers staticmethod
","text":"load_lora_layers(\n name: str,\n state_dict: dict[str, Tensor],\n control_lora: ControlLora,\n) -> None\n
Load the LoRA layers from the state_dict into the ControlLora.
Parameters:
name (str, required): The name of the ControlLora.
state_dict (dict[str, Tensor], required): The state_dict containing the LoRA layers to load.
control_lora (ControlLora, required): The ControlLora to load the LoRA layers into.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_lora_layers(\n name: str,\n state_dict: dict[str, Tensor],\n control_lora: ControlLora,\n) -> None:\n \"\"\"Load the [`LoRA`][refiners.fluxion.adapters.lora.Lora] layers from the state_dict into the `ControlLora`.\n\n Args:\n name: The name of the ControlLora.\n state_dict: The state_dict containing the LoRA layers to load.\n control_lora: The ControlLora to load the LoRA layers into.\n \"\"\"\n # filter the LoraAdapters from the state_dict\n lora_weights = {\n key.removeprefix(\"ControlLora.\"): value for key, value in state_dict.items() if \"ControlLora\" in key\n }\n lora_weights = {f\"{key}.weight\": value for key, value in lora_weights.items()}\n\n # move the tensors to the device and dtype of the ControlLora\n lora_weights = {\n key: value.to(\n dtype=control_lora.dtype,\n device=control_lora.device,\n )\n for key, value in lora_weights.items()\n }\n\n # load every LoRA layers from the filtered state_dict\n loras = Lora.from_dict(name, state_dict=lora_weights)\n\n # attach the LoRA layers to the ControlLora\n adapters: list[LoraAdapter] = []\n for key, lora in loras.items():\n target = control_lora.layer(key.split(\".\"), WeightedModule)\n assert lora.is_compatible(target)\n adapter = LoraAdapter(target, lora)\n adapters.append(adapter)\n\n for adapter in adapters:\n adapter.inject(control_lora)\n
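The key rewriting that load_lora_layers performs before calling Lora.from_dict can be reproduced on plain strings. The function name lora_keys and the sample keys in the usage check are hypothetical; real ControlLora checkpoints may use different layer paths.

```python
def lora_keys(state_dict: dict) -> dict:
    # keep only ControlLora entries and drop their prefix ...
    filtered = {
        key.removeprefix('ControlLora.'): value
        for key, value in state_dict.items()
        if 'ControlLora' in key
    }
    # ... then append the .weight suffix, as load_lora_layers does
    return {f'{key}.weight': value for key, value in filtered.items()}
```

Entries belonging to the ZeroConvolution and ConditionEncoder layers are simply ignored here; they are loaded by the dedicated static methods.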
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_weights","title":"load_weights","text":"load_weights(state_dict: dict[str, Tensor]) -> None\n
Load the weights from the state_dict into the ControlLora: its LoRA layers, its ZeroConvolution layers, and its ConditionEncoder.
Parameters:
state_dict (dict[str, Tensor], required): The state_dict containing the weights to load.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
def load_weights(\n self,\n state_dict: dict[str, Tensor],\n) -> None:\n \"\"\"Load the weights from the state_dict into the `ControlLora`.\n\n Args:\n state_dict: The state_dict containing the weights to load.\n \"\"\"\n ControlLoraAdapter.load_lora_layers(self.name, state_dict, self.control_lora)\n ControlLoraAdapter.load_zero_convolution_layers(state_dict, self.control_lora)\n ControlLoraAdapter.load_condition_encoder(state_dict, self.control_lora)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter.load_zero_convolution_layers","title":"load_zero_convolution_layers staticmethod
","text":"load_zero_convolution_layers(\n state_dict: dict[str, Tensor], control_lora: ControlLora\n)\n
Load the ZeroConvolution layers from the state_dict into the ControlLora.
Parameters:
state_dict (dict[str, Tensor], required): The state_dict containing the ZeroConvolution layers to load.
control_lora (ControlLora, required): The ControlLora to load the ZeroConvolution layers into.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py
@staticmethod\ndef load_zero_convolution_layers(\n state_dict: dict[str, Tensor],\n control_lora: ControlLora,\n):\n \"\"\"Load the `ZeroConvolution` layers from the state_dict into the `ControlLora`.\n\n Args:\n state_dict: The state_dict containing the ZeroConvolution layers to load.\n control_lora: The ControlLora to load the ZeroConvolution layers into.\n \"\"\"\n zero_convolution_layers = list(control_lora.layers(ZeroConvolution))\n for i, zero_convolution_layer in enumerate(zero_convolution_layers):\n zero_convolution_state_dict = {\n key.removeprefix(f\"ZeroConvolution_{i+1:02d}.\"): value\n for key, value in state_dict.items()\n if f\"ZeroConvolution_{i+1:02d}\" in key\n }\n zero_convolution_layer.load_state_dict(zero_convolution_state_dict)\n
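Checkpoint keys are matched to the i-th layer returned by control_lora.layers(ZeroConvolution) through a two-digit, 1-based suffix (ZeroConvolution_01, ZeroConvolution_02, and so on). This stdlib sketch reproduces the grouping; the function name and sample keys are hypothetical. The scheme relies on the layer traversal order matching the checkpoint numbering.

```python
def zero_convolution_groups(state_dict: dict, count: int) -> list[dict]:
    # one sub-dict per ZeroConvolution layer, numbered from 1 with two digits
    groups = []
    for i in range(count):
        prefix = f'ZeroConvolution_{i + 1:02d}'
        groups.append({
            key.removeprefix(prefix + '.'): value
            for key, value in state_dict.items()
            if prefix in key
        })
    return groups
```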
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLAutoencoder","title":"SDXLAutoencoder","text":"SDXLAutoencoder(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: LatentDiffusionAutoencoder
Stable Diffusion XL autoencoder model.
Attributes:
encoder_scale (float): The encoder scale to use.
Parameters:
device (device | str | None, default: None): The PyTorch device to use.
dtype (dtype | None, default: None): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n self,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initializes the model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n Encoder(device=device, dtype=dtype),\n Decoder(device=device, dtype=dtype),\n )\n self._tile_size = None\n self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLIPAdapter","title":"SDXLIPAdapter","text":"SDXLIPAdapter(\n target: SDXLUNet,\n clip_image_encoder: CLIPImageEncoderH | None = None,\n image_proj: (\n ImageProjection | PerceiverResampler | None\n ) = None,\n scale: float = 1.0,\n fine_grained: bool = False,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: IPAdapter[SDXLUNet]
Image Prompt adapter for the Stable Diffusion XL U-Net model.
Parameters:
target (SDXLUNet, required): The SDXLUNet model to adapt.
clip_image_encoder (CLIPImageEncoderH | None, default: None): The CLIP image encoder to use.
image_proj (ImageProjection | PerceiverResampler | None, default: None): The image projection to use.
scale (float, default: 1.0): The scale to use for the image prompt.
fine_grained (bool, default: False): Whether to use fine-grained image prompt.
weights (dict[str, Tensor] | None, default: None): The weights of the IPAdapter.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py
def __init__(\n self,\n target: SDXLUNet,\n clip_image_encoder: CLIPImageEncoderH | None = None,\n image_proj: ImageProjection | PerceiverResampler | None = None,\n scale: float = 1.0,\n fine_grained: bool = False,\n weights: dict[str, Tensor] | None = None,\n) -> None:\n \"\"\"Initialize the adapter.\n\n Args:\n target: The SDXLUNet model to adapt.\n clip_image_encoder: The CLIP image encoder to use.\n image_proj: The image projection to use.\n scale: The scale to use for the image prompt.\n fine_grained: Whether to use fine-grained image prompt.\n weights: The weights of the IPAdapter.\n \"\"\"\n clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)\n\n if image_proj is None:\n cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)\n image_proj = (\n ImageProjection(\n clip_image_embedding_dim=clip_image_encoder.output_dim,\n clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,\n device=target.device,\n dtype=target.dtype,\n )\n if not fine_grained\n else PerceiverResampler(\n latents_dim=1280, # not `cross_attn_2d.context_embedding_dim` in this case\n num_attention_layers=4,\n num_attention_heads=20,\n head_dim=64,\n num_tokens=16,\n input_dim=clip_image_encoder.embedding_dim, # = dim before final projection\n output_dim=cross_attn_2d.context_embedding_dim,\n device=target.device,\n dtype=target.dtype,\n )\n )\n elif fine_grained:\n assert isinstance(image_proj, PerceiverResampler)\n\n super().__init__(\n target=target,\n clip_image_encoder=clip_image_encoder,\n image_proj=image_proj,\n scale=scale,\n fine_grained=fine_grained,\n weights=weights,\n )\n
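The scale parameter weights the image-prompt branch. In the decoupled cross-attention described in the IP-Adapter paper, each query attends separately to the text tokens and to the image tokens, and the image result is added with a scale factor. The sketch below illustrates that formula with single-head attention on plain lists; it is a simplified illustration of the technique (no projections, no multiple heads), not refiners' implementation, and the function names are hypothetical.

```python
import math


def attention(query, keys, values):
    # single-head scaled dot-product attention for one query vector
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract max for stability
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def decoupled_cross_attention(query, text_kv, image_kv, scale):
    # text branch + scale * image branch
    text_out = attention(query, *text_kv)
    image_out = attention(query, *image_kv)
    return [t + scale * i for t, i in zip(text_out, image_out)]
```

Setting scale to 0 recovers plain text cross-attention, which is why the image prompt can be faded in or out without touching the text conditioning.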
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLLcmAdapter","title":"SDXLLcmAdapter","text":"SDXLLcmAdapter(\n target: SDXLUNet,\n condition_scale_embedding_dim: int = 256,\n condition_scale: float = 7.5,\n)\n
Bases: Chain, Adapter[SDXLUNet]
Adapt the SDXL UNet for use with LCMSolver.
Note that LCM must be used without CFG. You can disable CFG on SD by setting the classifier_free_guidance attribute to False.
Parameters:
target (SDXLUNet, required): An SDXL UNet.
condition_scale_embedding_dim (int, default: 256): LCM uses a condition scale embedding; this is its dimension. It must be even.
condition_scale (float, default: 7.5): Because of the embedding, the condition scale must be passed to this adapter instead of SD. The condition scale passed to SD will be ignored.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py
def __init__(\n self,\n target: SDXLUNet,\n condition_scale_embedding_dim: int = 256,\n condition_scale: float = 7.5,\n) -> None:\n \"\"\"Adapt [the SDXl UNet][refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet.SDXLUNet]\n for use with [LCMSolver][refiners.foundationals.latent_diffusion.solvers.lcm.LCMSolver].\n\n Note that LCM must be used *without* CFG. You can disable CFG on SD by setting the\n `classifier_free_guidance` attribute to `False`.\n\n Args:\n target: A SDXL UNet.\n condition_scale_embedding_dim: LCM uses a condition scale embedding, this is its dimension.\n condition_scale: Because of the embedding, the condition scale must be passed to this adapter\n instead of SD. The condition scale passed to SD will be ignored.\n \"\"\"\n assert condition_scale_embedding_dim % 2 == 0\n self.condition_scale_embedding_dim = condition_scale_embedding_dim\n self.condition_scale = condition_scale\n with self.setup_adapter(target):\n super().__init__(target)\n
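The assert condition_scale_embedding_dim % 2 == 0 in the constructor is consistent with a sinusoidal embedding whose first half holds sines and second half cosines. The exact formula used by refiners is not shown in this section; the sketch below is a generic sinusoidal embedding of a scalar, and both the max_period of 10000 and the absence of any pre-scaling of the condition scale are assumptions.

```python
import math


def sinusoidal_embedding(value: float, dim: int, max_period: float = 10000.0) -> list[float]:
    # dim must be even: half of the slots hold sines, the other half cosines
    if dim % 2 != 0:
        raise ValueError('dim must be even')
    half = dim // 2
    # geometric progression of frequencies from 1 down towards 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [value * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```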
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet","title":"SDXLUNet","text":"SDXLUNet(\n in_channels: int,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Stable Diffusion XL U-Net.
See [arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis for more details.
Parameters:
in_channels (int, required): Number of input channels.
device (device | str | None, default: None): Device to use for computation.
dtype (dtype | None, default: None): Data type to use for computation.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def __init__(\n self,\n in_channels: int,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the U-Net.\n\n Args:\n in_channels: Number of input channels.\n device: Device to use for computation.\n dtype: Data type to use for computation.\n \"\"\"\n self.in_channels = in_channels\n super().__init__(\n TimestepEncoder(device=device, dtype=dtype),\n DownBlocks(in_channels=in_channels, device=device, dtype=dtype),\n MiddleBlock(device=device, dtype=dtype),\n fl.Residual(fl.UseContext(context=\"unet\", key=\"residuals\").compose(lambda x: x[-1])),\n UpBlocks(device=device, dtype=dtype),\n OutputBlock(device=device, dtype=dtype),\n )\n for residual_block in self.layers(ResidualBlock):\n chain = residual_block.layer(\"Chain\", fl.Chain)\n RangeAdapter2d(\n target=chain.layer(\"Conv2d_1\", fl.Conv2d),\n channels=residual_block.out_channels,\n embedding_dim=1280,\n context_key=\"timestep_embedding\",\n device=device,\n dtype=dtype,\n ).inject(chain)\n for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):\n block.append(module=ResidualAccumulator(n=n))\n for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):\n block.insert(index=0, module=ResidualConcatenator(n=-n - 2))\n
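The last two loops of the constructor wire the skip connections: down block n stores its output at residual index n (ResidualAccumulator(n=n)), the fl.Residual after the middle block adds residuals[-1], and up block n concatenates residual -n - 2 (ResidualConcatenator(n=-n - 2)). The helper below computes only that index pairing. It assumes a residual list with one slot per down block plus one final slot for the middle block, as suggested by ControlLora's middle-block residual_index=len(downblocks); the block counts in the checks are illustrative, not necessarily SDXL's.

```python
def skip_pairs(num_down_blocks: int, num_up_blocks: int) -> list[tuple[int, int]]:
    # Assumed layout: slots 0..D-1 filled by the down blocks, slot D by the
    # middle block, i.e. D + 1 slots in total.
    assert num_up_blocks <= num_down_blocks
    slots = num_down_blocks + 1
    pairs = []
    for n in range(num_up_blocks):
        index = -n - 2  # ResidualConcatenator(n=-n - 2)
        pairs.append((n, slots + index))  # (up block, non-negative slot index)
    return pairs
```

Under this layout, up block n reads the output of down block D - 1 - n, which is the usual mirrored U-Net pairing.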
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_clip_text_embedding","title":"set_clip_text_embedding","text":"set_clip_text_embedding(\n clip_text_embedding: Tensor,\n) -> None\n
Set the CLIP text embedding context.
Note: This context is required by the SDXLCrossAttention blocks.
Parameters:
clip_text_embedding (Tensor, required): The CLIP text embedding tensor.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:\n \"\"\"Set the clip text embedding context.\n\n Note:\n This context is required by the `SDXLCrossAttention` blocks.\n\n Args:\n clip_text_embedding: The CLIP text embedding tensor.\n \"\"\"\n self.set_context(context=\"cross_attention_block\", value={\"clip_text_embedding\": clip_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_pooled_text_embedding","title":"set_pooled_text_embedding","text":"set_pooled_text_embedding(\n pooled_text_embedding: Tensor,\n) -> None\n
Set the pooled text embedding context.
Note: This is required by TextTimeEmbedding.
Parameters:
pooled_text_embedding (Tensor, required): The pooled text embedding tensor.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:\n \"\"\"Set the pooled text embedding context.\n\n Note:\n This is required by `TextTimeEmbedding`.\n\n Args:\n pooled_text_embedding: The pooled text embedding tensor.\n \"\"\"\n self.set_context(context=\"diffusion\", value={\"pooled_text_embedding\": pooled_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_time_ids","title":"set_time_ids","text":"set_time_ids(time_ids: Tensor) -> None\n
Set the time IDs context.
Note: This is required by TextTimeEmbedding.
Parameters:
time_ids (Tensor, required): The time IDs tensor.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_time_ids(self, time_ids: Tensor) -> None:\n \"\"\"Set the time IDs context.\n\n Note:\n This is required by `TextTimeEmbedding`.\n\n Args:\n time_ids: The time IDs tensor.\n \"\"\"\n self.set_context(context=\"diffusion\", value={\"time_ids\": time_ids})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.SDXLUNet.set_timestep","title":"set_timestep","text":"set_timestep(timestep: Tensor) -> None\n
Set the timestep context.
Note: This is required by TimestepEncoder.
Parameters:
timestep (Tensor, required): The timestep tensor.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
def set_timestep(self, timestep: Tensor) -> None:\n \"\"\"Set the timestep context.\n\n Note:\n This is required by `TimestepEncoder`.\n\n Args:\n timestep: The timestep tensor.\n \"\"\"\n self.set_context(context=\"diffusion\", value={\"timestep\": timestep})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL","title":"StableDiffusion_XL","text":"StableDiffusion_XL(\n unet: SDXLUNet | None = None,\n lda: SDXLAutoencoder | None = None,\n clip_text_encoder: DoubleTextEncoder | None = None,\n solver: Solver | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: LatentDiffusionModel
Stable Diffusion XL model.
Attributes:
unet (SDXLUNet): The U-Net model.
clip_text_encoder (DoubleTextEncoder): The text encoder.
lda (SDXLAutoencoder): The image autoencoder.
Parameters:
unet (SDXLUNet | None, default: None): The SDXLUNet U-Net model to use.
lda (SDXLAutoencoder | None, default: None): The SDXLAutoencoder image autoencoder to use.
clip_text_encoder (DoubleTextEncoder | None, default: None): The DoubleTextEncoder text encoder to use.
solver (Solver | None, default: None): The solver to use. Defaults to DDIM with 30 inference steps.
device (device | str, default: 'cpu'): The PyTorch device to use.
dtype (dtype, default: float32): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def __init__(\n self,\n unet: SDXLUNet | None = None,\n lda: SDXLAutoencoder | None = None,\n clip_text_encoder: DoubleTextEncoder | None = None,\n solver: Solver | None = None,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n \"\"\"Initializes the model.\n\n Args:\n unet: The SDXLUNet U-Net model to use.\n lda: The SDXLAutoencoder image autoencoder to use.\n clip_text_encoder: The DoubleTextEncoder text encoder to use.\n solver: The solver to use.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n unet = unet or SDXLUNet(in_channels=4)\n lda = lda or SDXLAutoencoder()\n clip_text_encoder = clip_text_encoder or DoubleTextEncoder()\n solver = solver or DDIM(num_inference_steps=30)\n\n super().__init__(\n unet=unet,\n lda=lda,\n clip_text_encoder=clip_text_encoder,\n solver=solver,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.default_time_ids","title":"default_time_ids property
","text":"default_time_ids: Tensor\n
The default time IDs to use.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.compute_clip_text_embedding","title":"compute_clip_text_embedding","text":"compute_clip_text_embedding(\n text: str | list[str],\n negative_text: str | list[str] = \"\",\n) -> tuple[Tensor, Tensor]\n
Compute the CLIP text embedding associated with the given prompt and negative prompt.
Parameters:
Name Type Description Default
text
str | list[str]
The prompt to compute the CLIP text embedding of.
required
negative_text
str | list[str]
The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., \"\").
''
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_clip_text_embedding(\n self, text: str | list[str], negative_text: str | list[str] = \"\"\n) -> tuple[Tensor, Tensor]:\n \"\"\"Compute the CLIP text embedding associated with the given prompt and negative prompt.\n\n Args:\n text: The prompt to compute the CLIP text embedding of.\n negative_text: The negative prompt to compute the CLIP text embedding of.\n If not provided, the negative prompt is assumed to be empty (i.e., `\"\"`).\n \"\"\"\n\n text = [text] if isinstance(text, str) else text\n\n if not self.classifier_free_guidance:\n return self.clip_text_encoder(text)\n\n negative_text = [negative_text] if isinstance(negative_text, str) else negative_text\n assert len(text) == len(negative_text), \"The length of the text list and negative_text should be the same\"\n\n conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)\n negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)\n\n return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(\n tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0\n )\n
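With classifier-free guidance enabled, the negative embeddings are stacked before the conditional ones, for both the sequence embedding and the pooled embedding, so a later chunk(2) returns the negative half first. The sketch below mimics torch.cat along dim 0 and Tensor.chunk(2) with plain Python lists; cfg_concat and chunk2 are hypothetical helpers, not Refiners APIs.

```python
# Minimal sketch: plain lists stand in for tensors whose first axis is the batch.
# cfg_concat and chunk2 are hypothetical helpers, not part of Refiners.

def cfg_concat(negative, conditional):
    # Mirrors torch.cat((negative, conditional), dim=0): negative rows come first.
    assert len(negative) == len(conditional), 'batch sizes must match'
    return negative + conditional

def chunk2(batch):
    # Mirrors Tensor.chunk(2) along the batch axis for an even batch size.
    half = len(batch) // 2
    return batch[:half], batch[half:]

negative_embedding = [['neg_a'], ['neg_b']]
conditional_embedding = [['cond_a'], ['cond_b']]
negative_pooled = ['neg_pool_a', 'neg_pool_b']
conditional_pooled = ['cond_pool_a', 'cond_pool_b']

text_embedding = cfg_concat(negative_embedding, conditional_embedding)
pooled_embedding = cfg_concat(negative_pooled, conditional_pooled)

# Downstream code (e.g. self-attention guidance) recovers the negative half with chunk(2).
negative_half, conditional_half = chunk2(text_embedding)
print(negative_half)  # [['neg_a'], ['neg_b']]
```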
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"compute_self_attention_guidance(\n x: Tensor,\n noise: Tensor,\n step: int,\n *,\n clip_text_embedding: Tensor,\n pooled_text_embedding: Tensor,\n time_ids: Tensor,\n **kwargs: Tensor\n) -> Tensor\n
Compute the self-attention guidance.
Parameters:
Name Type Description Default
x
Tensor
The input tensor.
required
noise
Tensor
The noise tensor.
required
step
int
The step to compute the self-attention guidance at.
required
clip_text_embedding
Tensor
The CLIP text embedding to compute the self-attention guidance with.
required
pooled_text_embedding
Tensor
The pooled CLIP text embedding to compute the self-attention guidance with.
required
time_ids
Tensor
The time IDs to compute the self-attention guidance with.
required
Returns:
Type Description
Tensor
The computed self-attention guidance.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def compute_self_attention_guidance(\n self,\n x: Tensor,\n noise: Tensor,\n step: int,\n *,\n clip_text_embedding: Tensor,\n pooled_text_embedding: Tensor,\n time_ids: Tensor,\n **kwargs: Tensor,\n) -> Tensor:\n \"\"\"Compute the self-attention guidance.\n\n Args:\n x: The input tensor.\n noise: The noise tensor.\n step: The step to compute the self-attention guidance at.\n clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with.\n time_ids: The time IDs to compute the self-attention guidance with.\n\n Returns:\n The computed self-attention guidance.\n \"\"\"\n sag = self._find_sag_adapter()\n assert sag is not None\n\n degraded_latents = sag.compute_degraded_latents(\n solver=self.solver,\n latents=x,\n noise=noise,\n step=step,\n classifier_free_guidance=True,\n )\n\n negative_text_embedding, _ = clip_text_embedding.chunk(2)\n negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)\n timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n time_ids, _ = time_ids.chunk(2)\n\n self.set_unet_context(\n timestep=timestep,\n clip_text_embedding=negative_text_embedding,\n pooled_text_embedding=negative_pooled_embedding,\n time_ids=time_ids,\n )\n if \"ip_adapter\" in self.unet.provider.contexts:\n # this implementation is a bit hacky, it should be refactored in the future\n ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n degraded_noise = self.unet(degraded_latents)\n ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n else:\n degraded_noise = self.unet(degraded_latents)\n\n return sag.scale * (noise - degraded_noise)\n
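The returned value is a scaled difference between the noise predicted for the regular latents and the noise the U-Net predicts for the degraded latents, with the U-Net conditioned only on the negative half of the embeddings. The final element-wise step can be sketched with plain floats; sag_guidance is a hypothetical name, and the lists stand in for noise tensors.

```python
# Minimal sketch of the last line of compute_self_attention_guidance:
# guidance = scale * (noise - degraded_noise), computed element by element.
# Plain float lists stand in for tensors; sag_guidance is a hypothetical helper.

def sag_guidance(noise, degraded_noise, scale=1.0):
    assert len(noise) == len(degraded_noise), 'shapes must match'
    return [scale * (n - d) for n, d in zip(noise, degraded_noise)]

noise = [0.5, -0.25, 1.0]
degraded_noise = [0.25, -0.25, 0.5]
print(sag_guidance(noise, degraded_noise, scale=0.75))  # [0.1875, 0.0, 0.375]
```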
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.has_self_attention_guidance","title":"has_self_attention_guidance","text":"has_self_attention_guidance() -> bool\n
Whether the model has self-attention guidance or not.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def has_self_attention_guidance(self) -> bool:\n \"\"\"Whether the model has self-attention guidance or not.\"\"\"\n return self._find_sag_adapter() is not None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.set_self_attention_guidance","title":"set_self_attention_guidance","text":"set_self_attention_guidance(\n enable: bool, scale: float = 1.0\n) -> None\n
Sets the self-attention guidance.
See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.
Parameters:
Name Type Description Default
enable
bool
Whether to enable self-attention guidance or not.
required
scale
float
The scale to use.
1.0
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:\n \"\"\"Sets the self-attention guidance.\n\n See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)\n for more details.\n\n Args:\n enable: Whether to enable self-attention guidance or not.\n scale: The scale to use.\n \"\"\"\n if enable:\n if sag := self._find_sag_adapter():\n sag.scale = scale\n else:\n SDXLSAGAdapter(target=self.unet, scale=scale).inject()\n else:\n if sag := self._find_sag_adapter():\n sag.eject()\n
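set_self_attention_guidance is idempotent: enabling it when an adapter is already injected only updates the scale, and disabling it when no adapter is present does nothing. The control flow can be sketched with toy stand-ins; ToyModel and ToySAGAdapter are hypothetical classes that only record state, whereas the real adapter injects itself into, and ejects itself from, the U-Net.

```python
# Minimal sketch of the enable/update/disable logic of set_self_attention_guidance.
# ToyModel and ToySAGAdapter are hypothetical stand-ins; nothing here touches a U-Net.

class ToySAGAdapter:
    def __init__(self, model, scale):
        self.model = model
        self.scale = scale

    def inject(self):
        self.model.sag = self

    def eject(self):
        self.model.sag = None


class ToyModel:
    def __init__(self):
        self.sag = None

    def _find_sag_adapter(self):
        return self.sag

    def set_self_attention_guidance(self, enable, scale=1.0):
        if enable:
            if sag := self._find_sag_adapter():
                sag.scale = scale  # already injected: only update the scale
            else:
                ToySAGAdapter(self, scale=scale).inject()
        else:
            if sag := self._find_sag_adapter():
                sag.eject()  # disabling twice is a no-op

    def has_self_attention_guidance(self):
        return self._find_sag_adapter() is not None


model = ToyModel()
model.set_self_attention_guidance(True, scale=0.75)
first_adapter = model.sag
model.set_self_attention_guidance(True, scale=0.5)
assert model.sag is first_adapter and model.sag.scale == 0.5
model.set_self_attention_guidance(False)
model.set_self_attention_guidance(False)
print(model.has_self_attention_guidance())  # False
```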
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.StableDiffusion_XL.set_unet_context","title":"set_unet_context","text":"set_unet_context(\n *,\n timestep: Tensor,\n clip_text_embedding: Tensor,\n pooled_text_embedding: Tensor,\n time_ids: Tensor,\n **_: Tensor\n) -> None\n
Set the various context parameters required by the U-Net model.
Parameters:
Name Type Description Default
timestep
Tensor
The timestep to set.
required
clip_text_embedding
Tensor
The CLIP text embedding to set.
required
pooled_text_embedding
Tensor
The pooled CLIP text embedding to set.
required
time_ids
Tensor
The time IDs to set.
required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
def set_unet_context(\n self,\n *,\n timestep: Tensor,\n clip_text_embedding: Tensor,\n pooled_text_embedding: Tensor,\n time_ids: Tensor,\n **_: Tensor,\n) -> None:\n \"\"\"Set the various context parameters required by the U-Net model.\n\n Args:\n timestep: The timestep to set.\n clip_text_embedding: The CLIP text embedding to set.\n pooled_text_embedding: The pooled CLIP text embedding to set.\n time_ids: The time IDs to set.\n \"\"\"\n self.unet.set_timestep(timestep=timestep)\n self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)\n self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)\n self.unet.set_time_ids(time_ids=time_ids)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.add_lcm_lora","title":"add_lcm_lora","text":"add_lcm_lora(\n manager: SDLoraManager,\n tensors: dict[str, Tensor],\n name: str = \"lcm\",\n scale: float = 8.0 / 64.0,\n check_validity: bool = True,\n) -> None\n
Add an LCM-LoRA, or a LoRA with a similar structure such as SDXL-Lightning, to SDXLUNet.
This is a complex LoRA, so SDLoraManager.add_loras() is not sufficient. Instead, the LoRAs are added to the UNet in several passes, using the filtering mechanism of auto_attach_loras.
LCM-LoRA can be used with or without CFG in SD. If you use CFG, typical guidance scale values range from 1.0 (equivalent to no CFG) to 2.0.
Parameters:
Name Type Description Default
manager
SDLoraManager
An SDLoraManager for SDXL.
required
tensors
dict[str, Tensor]
The state_dict of the LoRA.
required
name
str
The name of the LoRA.
'lcm'
scale
float
The scale to use for the LoRA (should generally not be changed: these LoRAs must use alpha / rank).
8.0 / 64.0
check_validity
bool
Whether to perform additional checks and raise an exception if they fail.
True
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
def add_lcm_lora(\n manager: SDLoraManager,\n tensors: dict[str, torch.Tensor],\n name: str = \"lcm\",\n scale: float = 8.0 / 64.0,\n check_validity: bool = True,\n) -> None:\n \"\"\"Add a [LCM-LoRA](https://arxiv.org/abs/2311.05556) or a LoRA with similar structure\n such as [SDXL-Lightning](https://arxiv.org/abs/2402.13929) to SDXLUNet.\n\n This is a complex LoRA so [SDLoraManager.add_loras()][refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras]\n is not enough. Instead, we add the LoRAs to the UNet in several iterations, using the filtering mechanism of\n [auto_attach_loras][refiners.fluxion.adapters.lora.auto_attach_loras].\n\n LCM-LoRA can be used with or without CFG in SD.\n If you use CFG, typical values range from 1.0 (same as no CFG) to 2.0.\n\n Args:\n manager: A SDLoraManager for SDXL.\n tensors: The `state_dict` of the LoRA.\n name: The name of the LoRA.\n scale: The scale to use for the LoRA (should generally not be changed, those LoRAs must use alpha / rank).\n check_validity: Perform additional checks, raise an exception if they fail.\n \"\"\"\n\n assert isinstance(manager.target, StableDiffusion_XL)\n unet = manager.target.unet\n\n loras = Lora.from_dict(name, {k: v.to(unet.device, unet.dtype) for k, v in tensors.items()})\n assert all(k.startswith(\"lora_unet_\") for k in loras.keys())\n loras = {k: loras[k] for k in sorted(loras.keys(), key=SDLoraManager.sort_keys)}\n\n debug_map: list[tuple[str, str]] | None = [] if check_validity else None\n\n # Projections are in `SDXLCrossAttention` but not in `CrossAttentionBlock`.\n loras_projs = {k: v for k, v in loras.items() if k.endswith(\"proj_in\") or k.endswith(\"proj_out\")}\n auto_attach_loras(\n loras_projs,\n unet,\n exclude=[\"CrossAttentionBlock\"],\n include=[\"SDXLCrossAttention\"],\n debug_map=debug_map,\n )\n\n manager.add_loras_to_unet(\n {k: v for k, v in loras.items() if k not in loras_projs},\n debug_map=debug_map,\n )\n\n if debug_map is not None:\n _check_validity(debug_map)\n\n # LoRAs are finally injected, set the scale 
with the manager.\n manager.set_scale(name, scale)\n
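Two details of add_lcm_lora can be checked by hand. First, the keys are split into projection LoRAs (names ending in proj_in or proj_out, attached inside SDXLCrossAttention) and everything else (passed to add_loras_to_unet). Second, the default scale of 8.0 / 64.0 is 0.125, which matches alpha / rank for a LoRA with alpha 8 and rank 64; that pairing is an inference from the default value, not something the source states. The partition can be sketched on key names alone; the keys below are made-up examples that only follow the lora_unet_ prefix the function asserts.

```python
# Minimal sketch of how add_lcm_lora partitions LoRA keys before attaching them.
# The key names are hypothetical examples; only the lora_unet_ prefix and the
# proj_in / proj_out suffix test come from the source above.

def split_projection_loras(loras):
    for key in loras:
        assert key.startswith('lora_unet_'), key
    projections = {k: v for k, v in loras.items() if k.endswith('proj_in') or k.endswith('proj_out')}
    others = {k: v for k, v in loras.items() if k not in projections}
    return projections, others


loras = {
    'lora_unet_down_blocks_1_attentions_0_proj_in': 'lora_a',
    'lora_unet_down_blocks_1_attentions_0_proj_out': 'lora_b',
    'lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q': 'lora_c',
}
projections, others = split_projection_loras(loras)
print(len(projections), len(others))  # 2 1

# The default scale is alpha / rank, e.g. alpha = 8 and rank = 64.
scale = 8.0 / 64.0
print(scale)  # 0.125
```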
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight","title":"ICLight","text":"ICLight(\n patch_weights: dict[str, Tensor],\n unet: SD1UNet,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: StableDiffusion_1
IC-Light is a Stable Diffusion model that can be used to relight a reference image.
At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
Example
import torch\nfrom huggingface_hub import hf_hub_download\nfrom PIL import Image\n\nfrom refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad\nfrom refiners.foundationals.clip import CLIPTextEncoderL\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\ndtype = torch.float32\nno_grad().__enter__()\nmanual_seed(42)\n\nsd = ICLight(\n patch_weights=load_from_safetensors(\n path=hf_hub_download(\n repo_id=\"refiners/ic_light.sd1_5.fc\",\n filename=\"model.safetensors\",\n ),\n device=device,\n ),\n unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(\n tensors_path=hf_hub_download(\n repo_id=\"refiners/realistic_vision.v5_1.sd1_5.unet\",\n filename=\"model.safetensors\",\n )\n ),\n clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(\n tensors_path=hf_hub_download(\n repo_id=\"refiners/realistic_vision.v5_1.sd1_5.text_encoder\",\n filename=\"model.safetensors\",\n )\n ),\n lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(\n tensors_path=hf_hub_download(\n repo_id=\"refiners/realistic_vision.v5_1.sd1_5.autoencoder\",\n filename=\"model.safetensors\",\n )\n ),\n device=device,\n dtype=dtype,\n)\n\nprompt = \"soft lighting, high-quality professional image\"\nnegative_prompt = \"lowres, bad anatomy, bad hands, cropped, worst quality\"\nclip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)\n\nimage = Image.open(\"reference-image.png\").resize((512, 512))\nsd.set_ic_light_condition(image)\n\nx = torch.randn(\n size=(1, 4, 64, 64),\n device=device,\n dtype=dtype,\n)\n\nfor step in sd.steps:\n x = sd(\n x=x,\n step=step,\n clip_text_embedding=clip_text_embedding,\n condition_scale=1.5,\n )\npredicted_image = 
sd.lda.latents_to_image(x)\n\npredicted_image.save(\"ic-light-output.png\")\n
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def __init__(\n self,\n patch_weights: dict[str, torch.Tensor],\n unet: SD1UNet,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: torch.device | str = \"cpu\",\n dtype: torch.dtype = torch.float32,\n) -> None:\n super().__init__(\n unet=unet,\n lda=lda,\n clip_text_encoder=clip_text_encoder,\n solver=solver,\n device=device,\n dtype=dtype,\n )\n self._extend_conv_in()\n self._apply_patch(weights=patch_weights)\n
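The class description says the U-Net is patched to accept four additional input channels. With the usual Stable Diffusion 1.5 latent layout (4 channels at one eighth of the image resolution, consistent with the 1 x 4 x 64 x 64 noise used for a 512 x 512 image in the example above), the shape bookkeeping can be sketched as follows. Concatenating the condition latents after the noisy latents along the channel axis is an assumption made for illustration; the actual layout is defined by the patch, which is not shown here.

```python
# Minimal sketch of the channel bookkeeping behind the IC-Light conv_in extension.
# Assumptions: 4 latent channels, 8x spatial downsampling, and the condition
# latents concatenated after the noisy latents along the channel axis.

LATENT_CHANNELS = 4
DOWNSAMPLING = 8

def latent_shape(width, height, batch=1):
    assert width % DOWNSAMPLING == 0 and height % DOWNSAMPLING == 0
    return (batch, LATENT_CHANNELS, height // DOWNSAMPLING, width // DOWNSAMPLING)

def concat_channels(a, b):
    # Shape of torch.cat((a, b), dim=1) for NCHW tensors.
    assert a[0] == b[0] and a[2:] == b[2:], 'batch and spatial dims must match'
    return (a[0], a[1] + b[1], a[2], a[3])

noisy = latent_shape(512, 512)       # (1, 4, 64, 64)
condition = latent_shape(512, 512)   # latents of the reference image
print(concat_channels(noisy, condition))  # (1, 8, 64, 64)
```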
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight.compute_gray_composite","title":"compute_gray_composite staticmethod
","text":"compute_gray_composite(image: Image, mask: Image) -> Image\n
Compute a grayscale composite of an image and a mask.
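Where the mask is black, the image is replaced by a 127-valued gray background, which is the input format IC-Light expects; IC-Light then recreates the image from this composite.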
Parameters:
Name Type Description Default
image
Image
The image to composite.
required
mask
Image
The mask to use for the composite (a grayscale, L-mode image with the same size as the image).
required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
@staticmethod\ndef compute_gray_composite(\n image: Image.Image,\n mask: Image.Image,\n) -> Image.Image:\n \"\"\"Compute a grayscale composite of an image and a mask.\n\n IC-Light will recreate the image\n\n Args:\n image: The image to composite.\n mask: The mask to use for the composite.\n \"\"\"\n assert mask.mode == \"L\", \"Mask must be a grayscale image\"\n assert image.size == mask.size, \"Image and mask must have the same size\"\n background = Image.new(\"RGB\", image.size, (127, 127, 127))\n return Image.composite(image, background, mask)\n
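Image.composite(image, background, mask) takes each pixel from image where the mask is 255, from the (127, 127, 127) background where it is 0, and blends the two in between. The per-pixel rule can be sketched without Pillow; gray_composite_pixel is a hypothetical helper, and its rounding only approximates Pillow's integer arithmetic for intermediate mask values.

```python
# Minimal sketch of the per-pixel rule behind compute_gray_composite.
# gray_composite_pixel is hypothetical; Pillow's own rounding may differ by 1.

GRAY = (127, 127, 127)

def gray_composite_pixel(rgb, mask_value):
    # mask_value is an L-mode value in [0, 255]: 255 keeps the image pixel,
    # 0 keeps the gray background, and values in between blend linearly.
    assert 0 <= mask_value <= 255
    return tuple(
        round((c * mask_value + g * (255 - mask_value)) / 255)
        for c, g in zip(rgb, GRAY)
    )

print(gray_composite_pixel((200, 10, 50), 255))  # (200, 10, 50)
print(gray_composite_pixel((200, 10, 50), 0))    # (127, 127, 127)
```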
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.ICLight.set_ic_light_condition","title":"set_ic_light_condition","text":"set_ic_light_condition(\n image: Image, mask: Image | None = None\n) -> None\n
Set the IC-Light condition by encoding the reference image into latents with the autoencoder.
Parameters:
Name Type Description Default
image
Image
The reference image.
required
mask
Image | None
The mask to use for the reference image.
None
If a mask is provided, it is used to compute a grayscale composite of the image and the mask; otherwise, the image is used as is. Note that IC-Light requires a 127-valued gray background to work.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
def set_ic_light_condition(\n self,\n image: Image.Image,\n mask: Image.Image | None = None,\n) -> None:\n \"\"\"Set the IC light condition.\n\n Args:\n image: The reference image.\n mask: The mask to use for the reference image.\n\n If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,\n the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.\n \"\"\"\n if mask is not None:\n image = self.compute_gray_composite(image=image, mask=mask)\n latents = self.lda.image_to_latents(image)\n self._ic_light_condition = latents\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1Autoencoder","title":"SD1Autoencoder","text":"SD1Autoencoder(\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: LatentDiffusionAutoencoder
Stable Diffusion 1.5 autoencoder model.
Attributes:
Name Type Description
encoder_scale
float
The encoder scale to use.
Parameters:
Name Type Description Default
device
device | str | None
The PyTorch device to use.
None
dtype
dtype | None
The PyTorch data type to use.
None
Source code in src/refiners/foundationals/latent_diffusion/auto_encoder.py
def __init__(\n self,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initializes the model.\n\n Args:\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n super().__init__(\n Encoder(device=device, dtype=dtype),\n Decoder(device=device, dtype=dtype),\n )\n self._tile_size = None\n self._blending = None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1ELLAAdapter","title":"SD1ELLAAdapter","text":"SD1ELLAAdapter(\n target: SD1UNet,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: ELLAAdapter[SD1UNet]
ELLA adapter for Stable Diffusion 1.5.
Parameters:
Name Type Description Default
target
SD1UNet
The target model to adapt.
required
weights
dict[str, Tensor] | None
The weights of the ELLA adapter (see scripts/conversion/convert_ella_adapter.py).
None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py
def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None:\n \"\"\"Initialize the adapter.\n\n Args:\n target: The target model to adapt.\n weights: The weights of the ELLA adapter (see `scripts/conversion/convert_ella_adapter.py`).\n \"\"\"\n latents_encoder = ELLA(\n time_channel=320,\n timestep_embedding_dim=768,\n width=768,\n num_layers=6,\n num_heads=8,\n num_latents=64,\n input_dim=2048,\n device=target.device,\n dtype=target.dtype,\n )\n super().__init__(target=target, latents_encoder=latents_encoder, weights=weights)\n
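Reading the constructor arguments, the ELLA encoder here consumes text features of width input_dim=2048. Assuming it behaves like a resampler that emits a fixed set of num_latents tokens of dimension width, it produces 64 tokens of width 768, the same width as the CLIP ViT-L text embedding that SD 1.5 cross-attention consumes. The shape-only sketch below encodes that assumption; ella_output_shape and head_dim are hypothetical helpers, and head_dim further assumes the heads split the width evenly as in standard multi-head attention.

```python
# Shape-only sketch of the ELLA configuration used by SD1ELLAAdapter.
# Assumption: the encoder maps (batch, seq_len, input_dim) text features to a
# fixed (batch, num_latents, width) set of tokens, whatever seq_len is.

ELLA_SD1 = {'width': 768, 'num_layers': 6, 'num_heads': 8, 'num_latents': 64, 'input_dim': 2048}

def ella_output_shape(input_shape, config=ELLA_SD1):
    batch, seq_len, features = input_shape
    assert features == config['input_dim'], 'text features must have input_dim channels'
    assert seq_len > 0
    return (batch, config['num_latents'], config['width'])

def head_dim(config=ELLA_SD1):
    assert config['width'] % config['num_heads'] == 0, 'width must split evenly across heads'
    return config['width'] // config['num_heads']

print(ella_output_shape((2, 128, 2048)))  # (2, 64, 768)
print(head_dim())  # 96
```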
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet","title":"SD1UNet","text":"SD1UNet(\n in_channels: int,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Chain
Stable Diffusion 1.5 U-Net.
See [arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models for more details.
Parameters:
Name Type Description Default
in_channels
int
The number of input channels.
required
device
device | str | None
The PyTorch device to use for computation.
None
dtype
dtype | None
The PyTorch dtype to use for computation.
None
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def __init__(\n self,\n in_channels: int,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n \"\"\"Initialize the U-Net.\n\n Args:\n in_channels: The number of input channels.\n device: The PyTorch device to use for computation.\n dtype: The PyTorch dtype to use for computation.\n \"\"\"\n self.in_channels = in_channels\n super().__init__(\n TimestepEncoder(device=device, dtype=dtype),\n DownBlocks(in_channels=in_channels, device=device, dtype=dtype),\n fl.Sum(\n fl.UseContext(context=\"unet\", key=\"residuals\").compose(lambda x: x[-1]),\n MiddleBlock(device=device, dtype=dtype),\n ),\n UpBlocks(device=device, dtype=dtype),\n fl.Chain(\n fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),\n fl.SiLU(),\n fl.Conv2d(\n in_channels=320,\n out_channels=4,\n kernel_size=3,\n stride=1,\n padding=1,\n device=device,\n dtype=dtype,\n ),\n ),\n )\n for residual_block in self.layers(ResidualBlock):\n chain = residual_block.layer(\"Chain\", fl.Chain)\n RangeAdapter2d(\n target=chain.layer(\"Conv2d_1\", fl.Conv2d),\n channels=residual_block.out_channels,\n embedding_dim=1280,\n context_key=\"timestep_embedding\",\n device=device,\n dtype=dtype,\n ).inject(chain)\n for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):\n block.append(ResidualAccumulator(n))\n for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):\n block.insert(0, ResidualConcatenator(-n - 2))\n
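The last two loops wire the skip connections: each down block ends with ResidualAccumulator(n), and up block n starts with ResidualConcatenator(-n - 2), which reads an entry counted from the end of the shared residuals list. The sketch below only converts those negative indices to positions, under two assumptions not shown in the source above: a list of 13 slots, with 12 down blocks writing slots 0 to 11. Under them, the up blocks consume the down-block outputs in reverse order, as in a standard U-Net, and slot 12 is the one the middle block reads with x[-1].

```python
# Illustrative sketch of the index arithmetic in SD1UNet's skip connections.
# Assumptions (hypothetical, not confirmed by the source above): a residuals
# list of 13 slots, with the 12 down blocks writing slots 0..11.

NUM_DOWN_BLOCKS = 12
NUM_SLOTS = 13

def up_block_slot(n, num_slots=NUM_SLOTS):
    # ResidualConcatenator(-n - 2) reads residuals[-n - 2]; convert to a positive index.
    return num_slots + (-n - 2)

slots = [up_block_slot(n) for n in range(NUM_DOWN_BLOCKS)]
print(slots)  # [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
```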
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet.set_clip_text_embedding","title":"set_clip_text_embedding","text":"set_clip_text_embedding(\n clip_text_embedding: Tensor,\n) -> None\n
Set the CLIP text embedding.
Note
This context is required by the CLIPLCrossAttention blocks.
Parameters:
Name Type Description Default
clip_text_embedding
Tensor
The CLIP text embedding.
required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:\n \"\"\"Set the CLIP text embedding.\n\n Note:\n This context is required by the `CLIPLCrossAttention` blocks.\n\n Args:\n clip_text_embedding: The CLIP text embedding.\n \"\"\"\n self.set_context(\"cross_attention_block\", {\"clip_text_embedding\": clip_text_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.SD1UNet.set_timestep","title":"set_timestep","text":"set_timestep(timestep: Tensor) -> None\n
Set the timestep.
Note
This context is required by TimestepEncoder.
Parameters:
Name Type Description Default
timestep
Tensor
The timestep.
required
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
def set_timestep(self, timestep: Tensor) -> None:\n \"\"\"Set the timestep.\n\n Note:\n This context is required by `TimestepEncoder`.\n\n Args:\n timestep: The timestep.\n \"\"\"\n self.set_context(\"diffusion\", {\"timestep\": timestep})\n
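set_timestep and set_clip_text_embedding do not run the network; they store a value under a context name and key (diffusion / timestep and cross_attention_block / clip_text_embedding) that layers deeper in the chain read later. A toy context store illustrates the pattern; ToyContextProvider is hypothetical and much simpler than the Refiners context machinery, and merging keys into an existing context is a choice made for this sketch.

```python
# Minimal sketch of the set-then-read context pattern used by SD1UNet.
# ToyContextProvider is a hypothetical stand-in for the Refiners context provider.

class ToyContextProvider:
    def __init__(self):
        self.contexts = {}

    def set_context(self, context, value):
        # In this toy version, new keys are merged into the named context.
        self.contexts.setdefault(context, {}).update(value)

    def use_context(self, context):
        return self.contexts[context]


provider = ToyContextProvider()
# What set_timestep and set_clip_text_embedding store, with placeholder values:
provider.set_context('diffusion', {'timestep': 981})
provider.set_context('cross_attention_block', {'clip_text_embedding': [[0.0] * 4]})

# A layer such as the timestep encoder would later read its input like this:
print(provider.use_context('diffusion')['timestep'])  # 981
```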
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1","title":"StableDiffusion_1","text":"StableDiffusion_1(\n unet: SD1UNet | None = None,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: LatentDiffusionModel
Stable Diffusion 1.5 model.
Attributes:
Name Type Description
unet
SD1UNet
The U-Net model.
clip_text_encoder
CLIPTextEncoderL
The text encoder.
lda
SD1Autoencoder
The image autoencoder.
Example:
import torch\n\nfrom refiners.fluxion.utils import manual_seed, no_grad\nfrom refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1\n\n# Load SD\nsd15 = StableDiffusion_1(device=\"cuda\", dtype=torch.float16)\n\nsd15.clip_text_encoder.load_from_safetensors(\"sd1_5.text_encoder.safetensors\")\nsd15.unet.load_from_safetensors(\"sd1_5.unet.safetensors\")\nsd15.lda.load_from_safetensors(\"sd1_5.autoencoder.safetensors\")\n\n# Hyperparameters\nprompt = \"a cute cat, best quality, high quality\"\nnegative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\nseed = 42\n\nsd15.set_inference_steps(50)\n\nwith no_grad(): # Disable gradient calculation for memory-efficient inference\n clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)\n manual_seed(seed)\n\n x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)\n\n # Diffusion process\n for step in sd15.steps:\n x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)\n\n predicted_image = sd15.lda.latents_to_image(x)\n predicted_image.save(\"output.png\")\n
Parameters:
Name Type Description Default
unet
SD1UNet | None
The SD1UNet U-Net model to use.
None
lda
SD1Autoencoder | None
The SD1Autoencoder image autoencoder to use.
None
clip_text_encoder
CLIPTextEncoderL | None
The CLIPTextEncoderL text encoder to use.
None
solver
Solver | None
The solver to use.
None
device
device | str
The PyTorch device to use.
'cpu'
dtype
dtype
The PyTorch data type to use.
float32
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(\n self,\n unet: SD1UNet | None = None,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n \"\"\"Initializes the model.\n\n Args:\n unet: The SD1UNet U-Net model to use.\n lda: The SD1Autoencoder image autoencoder to use.\n clip_text_encoder: The CLIPTextEncoderL text encoder to use.\n solver: The solver to use.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n unet = unet or SD1UNet(in_channels=4)\n lda = lda or SD1Autoencoder()\n clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()\n solver = solver or DPMSolver(num_inference_steps=30)\n\n super().__init__(\n unet=unet,\n lda=lda,\n clip_text_encoder=clip_text_encoder,\n solver=solver,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.compute_clip_text_embedding","title":"compute_clip_text_embedding","text":"compute_clip_text_embedding(\n text: str | list[str],\n negative_text: str | list[str] = \"\",\n) -> Tensor\n
Compute the CLIP text embedding associated with the given prompt and negative prompt.
Parameters:
Name Type Description Default
text
str | list[str]
The prompt to compute the CLIP text embedding of.
required
negative_text
str | list[str]
The negative prompt to compute the CLIP text embedding of. If not provided, the negative prompt is assumed to be empty (i.e., \"\").
''
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = \"\") -> Tensor:\n \"\"\"Compute the CLIP text embedding associated with the given prompt and negative prompt.\n\n Args:\n text: The prompt to compute the CLIP text embedding of.\n negative_text: The negative prompt to compute the CLIP text embedding of.\n If not provided, the negative prompt is assumed to be empty (i.e., `\"\"`).\n \"\"\"\n text = [text] if isinstance(text, str) else text\n\n if not self.classifier_free_guidance:\n return self.clip_text_encoder(text)\n\n negative_text = [negative_text] if isinstance(negative_text, str) else negative_text\n assert len(text) == len(negative_text), \"The length of the text list and negative_text should be the same\"\n\n conditional_embedding = self.clip_text_encoder(text)\n negative_embedding = self.clip_text_encoder(negative_text)\n\n return torch.cat((negative_embedding, conditional_embedding))\n
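Because a string prompt is wrapped into a one-element list and the default negative prompt is a single empty string, passing a list of several prompts without a matching list of negative prompts trips the length assertion when classifier-free guidance is on. The normalization can be reproduced in isolation; normalize_prompts is a hypothetical helper that copies only the list handling and the length check, not the CLIP encoding.

```python
# Minimal sketch of the prompt normalization in compute_clip_text_embedding.
# normalize_prompts is hypothetical; it reproduces only the list handling and
# the length check, not the CLIP encoding.

def normalize_prompts(text, negative_text=''):
    text = [text] if isinstance(text, str) else text
    negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
    assert len(text) == len(negative_text), 'The length of the text list and negative_text should be the same'
    return text, negative_text

print(normalize_prompts('a cute cat'))  # (['a cute cat'], [''])

try:
    normalize_prompts(['a cat', 'a dog'])  # default negative prompt has length 1
except AssertionError as error:
    print('AssertionError:', error)

# Passing one negative prompt per positive prompt avoids the error:
print(normalize_prompts(['a cat', 'a dog'], ['', 'blurry']))
```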
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"compute_self_attention_guidance(\n x: Tensor,\n noise: Tensor,\n step: int,\n *,\n clip_text_embedding: Tensor,\n **kwargs: Tensor\n) -> Tensor\n
Compute the self-attention guidance.
Parameters:
Name Type Description Default
x
Tensor
The input tensor.
required
noise
Tensor
The noise tensor.
required
step
int
The step to compute the self-attention guidance at.
required
clip_text_embedding
Tensor
The CLIP text embedding to compute the self-attention guidance with.
required
Returns:
Type Description
Tensor
The computed self-attention guidance.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(\n self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor\n) -> Tensor:\n \"\"\"Compute the self-attention guidance.\n\n Args:\n x: The input tensor.\n noise: The noise tensor.\n step: The step to compute the self-attention guidance at.\n clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n\n Returns:\n The computed self-attention guidance.\n \"\"\"\n sag = self._find_sag_adapter()\n assert sag is not None\n\n degraded_latents = sag.compute_degraded_latents(\n solver=self.solver,\n latents=x,\n noise=noise,\n step=step,\n classifier_free_guidance=True,\n )\n\n timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n negative_embedding, _ = clip_text_embedding.chunk(2)\n self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)\n if \"ip_adapter\" in self.unet.provider.contexts:\n # this implementation is a bit hacky, it should be refactored in the future\n ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n degraded_noise = self.unet(degraded_latents)\n ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n else:\n degraded_noise = self.unet(degraded_latents)\n\n return sag.scale * (noise - degraded_noise)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.has_self_attention_guidance","title":"has_self_attention_guidance","text":"has_self_attention_guidance() -> bool\n
Whether the model has self-attention guidance or not.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def has_self_attention_guidance(self) -> bool:\n \"\"\"Whether the model has self-attention guidance or not.\"\"\"\n return self._find_sag_adapter() is not None\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.set_self_attention_guidance","title":"set_self_attention_guidance","text":"set_self_attention_guidance(\n enable: bool, scale: float = 1.0\n) -> None\n
Set whether to enable self-attention guidance.
See [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance for more details.
Parameters:
Name Type Description Default
enable
bool
Whether to enable self-attention guidance.
required
scale
float
The scale to use.
1.0
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:\n \"\"\"Set whether to enable self-attention guidance.\n\n See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)\n for more details.\n\n Args:\n enable: Whether to enable self-attention guidance.\n scale: The scale to use.\n \"\"\"\n if enable:\n if sag := self._find_sag_adapter():\n sag.scale = scale\n else:\n SD1SAGAdapter(target=self.unet, scale=scale).inject()\n else:\n if sag := self._find_sag_adapter():\n sag.eject()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1.set_unet_context","title":"set_unet_context","text":"set_unet_context(\n *,\n timestep: Tensor,\n clip_text_embedding: Tensor,\n **_: Tensor\n) -> None\n
Set the various context parameters required by the U-Net model.
Parameters:
timestep (Tensor, required): The timestep tensor to use.
clip_text_embedding (Tensor, required): The CLIP text embedding tensor to use.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:\n \"\"\"Set the various context parameters required by the U-Net model.\n\n Args:\n timestep: The timestep tensor to use.\n clip_text_embedding: The CLIP text embedding tensor to use.\n \"\"\"\n self.unet.set_timestep(timestep=timestep)\n self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting","title":"StableDiffusion_1_Inpainting","text":"StableDiffusion_1_Inpainting(\n unet: SD1UNet | None = None,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: StableDiffusion_1
Stable Diffusion 1.5 inpainting model.
Attributes:
unet: The U-Net model.
clip_text_encoder: The text encoder.
lda: The image autoencoder.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def __init__(\n self,\n unet: SD1UNet | None = None,\n lda: SD1Autoencoder | None = None,\n clip_text_encoder: CLIPTextEncoderL | None = None,\n solver: Solver | None = None,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n self.mask_latents: Tensor | None = None\n self.target_image_latents: Tensor | None = None\n unet = unet or SD1UNet(in_channels=9)\n super().__init__(\n unet=unet,\n lda=lda,\n clip_text_encoder=clip_text_encoder,\n solver=solver,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting.compute_self_attention_guidance","title":"compute_self_attention_guidance","text":"compute_self_attention_guidance(\n x: Tensor,\n noise: Tensor,\n step: int,\n *,\n clip_text_embedding: Tensor,\n **kwargs: Tensor\n) -> Tensor\n
Compute the self-attention guidance.
Parameters:
x (Tensor, required): The input tensor.
noise (Tensor, required): The noise tensor.
step (int, required): The step to compute the self-attention guidance at.
clip_text_embedding (Tensor, required): The CLIP text embedding to compute the self-attention guidance with.
Returns:
Tensor: The computed self-attention guidance.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def compute_self_attention_guidance(\n self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor\n) -> Tensor:\n \"\"\"Compute the self-attention guidance.\n\n Args:\n x: The input tensor.\n noise: The noise tensor.\n step: The step to compute the self-attention guidance at.\n clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.\n\n Returns:\n The computed self-attention guidance.\n \"\"\"\n sag = self._find_sag_adapter()\n assert sag is not None\n assert self.mask_latents is not None\n assert self.target_image_latents is not None\n\n degraded_latents = sag.compute_degraded_latents(\n solver=self.solver,\n latents=x,\n noise=noise,\n step=step,\n classifier_free_guidance=True,\n )\n x = torch.cat(\n tensors=(degraded_latents, self.mask_latents, self.target_image_latents),\n dim=1,\n )\n\n timestep = self.solver.timesteps[step].unsqueeze(dim=0)\n negative_embedding, _ = clip_text_embedding.chunk(2)\n self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)\n\n if \"ip_adapter\" in self.unet.provider.contexts:\n # this implementation is a bit hacky, it should be refactored in the future\n ip_adapter_context = self.unet.use_context(\"ip_adapter\")\n image_embedding_copy = ip_adapter_context[\"clip_image_embedding\"].clone()\n ip_adapter_context[\"clip_image_embedding\"], _ = ip_adapter_context[\"clip_image_embedding\"].chunk(2)\n degraded_noise = self.unet(x)\n ip_adapter_context[\"clip_image_embedding\"] = image_embedding_copy\n else:\n degraded_noise = self.unet(x)\n\n return sag.scale * (noise - degraded_noise)\n
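The inpainting variant mainly differs in what it feeds the U-Net. The degraded latents, the mask latents and the masked-image latents are concatenated along the channel axis, which is why the default U-Net is built with in_channels=9. The shape check below is a sketch. It assumes the usual 4-channel SD 1.x latents at 64x64, and the tensors are random placeholders.

```python
import torch

# Assumed SD 1.x latent layout: 4 latent channels on a 64x64 grid (512x512 images).
batch, latent_channels, height, width = 1, 4, 64, 64

degraded_latents = torch.randn(batch, latent_channels, height, width)
mask_latents = torch.zeros(batch, 1, height, width)  # single-channel downsampled mask
target_image_latents = torch.randn(batch, latent_channels, height, width)  # encoded masked image

# Concatenate along the channel axis, in the same order as the method above: 4 + 1 + 4 = 9.
unet_input = torch.cat((degraded_latents, mask_latents, target_image_latents), dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
```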
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_1.StableDiffusion_1_Inpainting.set_inpainting_conditions","title":"set_inpainting_conditions","text":"set_inpainting_conditions(\n target_image: Image,\n mask: Image,\n latents_size: tuple[int, int] = (64, 64),\n) -> tuple[Tensor, Tensor]\n
Set the inpainting conditions.
Parameters:
target_image (Image, required): The target image to inpaint.
mask (Image, required): The mask to use for inpainting.
latents_size (tuple[int, int], default (64, 64)): The size of the latents to use.
Returns:
tuple[Tensor, Tensor]: The mask latents and the target image latents.
Source code in src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
def set_inpainting_conditions(\n self,\n target_image: Image.Image,\n mask: Image.Image,\n latents_size: tuple[int, int] = (64, 64),\n) -> tuple[Tensor, Tensor]:\n \"\"\"Set the inpainting conditions.\n\n Args:\n target_image: The target image to inpaint.\n mask: The mask to use for inpainting.\n latents_size: The size of the latents to use.\n\n Returns:\n The mask latents and the target image latents.\n \"\"\"\n target_image = target_image.convert(mode=\"RGB\")\n mask = mask.convert(mode=\"L\")\n\n mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)\n mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)\n self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))\n\n init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1\n masked_init_image = init_image_tensor * (1 - mask_tensor)\n self.target_image_latents = self.lda.encode(x=masked_init_image)\n\n return self.mask_latents, self.target_image_latents\n
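The mask handling in set_inpainting_conditions can be roughly reproduced with plain PyTorch. This sketch uses toy tensors instead of PIL images. It also uses torch.nn.functional.interpolate with nearest-neighbour resampling as a stand-in for the Refiners interpolate helper, and that resampling mode is an assumption.

```python
import torch
import torch.nn.functional as F

# Toy 512x512 grayscale mask with values in [0, 255]: the right half should be inpainted.
mask = torch.zeros(512, 512)
mask[:, 256:] = 255.0

# Normalize to [0, 1], binarize at 0.5, and add batch and channel axes -> (1, 1, 512, 512).
mask_tensor = ((mask / 255.0) > 0.5).unsqueeze(0).unsqueeze(0).to(torch.float32)

# Downsample to the latent grid. The nearest mode is an assumption here; the method
# above calls the Refiners interpolate helper instead.
mask_latents = F.interpolate(mask_tensor, size=(64, 64), mode='nearest')

# Zero out the region to inpaint in a toy image scaled to [-1, 1].
image = torch.rand(1, 3, 512, 512) * 2 - 1
masked_image = image * (1 - mask_tensor)

print(mask_latents.shape, mask_latents[0, 0, 0, 0].item(), mask_latents[0, 0, 0, 63].item())
# torch.Size([1, 1, 64, 64]) 0.0 1.0
```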
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DDIM","title":"DDIM","text":"DDIM(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Solver
Denoising Diffusion Implicit Model (DDIM) solver.
See [arXiv:2010.02502] Denoising Diffusion Implicit Models for more details.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
device (device | str, default 'cpu'): The PyTorch device to use.
dtype (dtype, default float32): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/solvers/ddim.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: Device | str = \"cpu\",\n dtype: Dtype = torch.float32,\n) -> None:\n \"\"\"Initializes a new DDIM solver.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n raise NotImplementedError\n if params and params.sde_variance != 0.0:\n raise NotImplementedError(\"DDIM does not support sde_variance != 0.0 yet\")\n\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n params=params,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DDPM","title":"DDPM","text":"DDPM(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: device | str = \"cpu\",\n)\n
Bases: Solver
Denoising Diffusion Probabilistic Model (DDPM) solver.
Warning: Only used for training Latent Diffusion models. It cannot be called.
See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models for more details.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
device (device | str, default 'cpu'): The PyTorch device to use.
Source code in src/refiners/foundationals/latent_diffusion/solvers/ddpm.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: Device | str = \"cpu\",\n) -> None:\n \"\"\"Initializes a new DDPM solver.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n device: The PyTorch device to use.\n \"\"\"\n\n if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n raise NotImplementedError\n\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n params=params,\n device=device,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver","title":"DPMSolver","text":"DPMSolver(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n last_step_first_order: bool = False,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Solver
Diffusion probabilistic models (DPMs) solver.
See [arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models for more details.
Note: DPM-Solver++ is known to introduce artifacts when used with SDXL and few steps. The last_step_first_order parameter mitigates this by using a first-order (Euler) update instead of a second-order update for the last step of the diffusion.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
last_step_first_order (bool, default False): Use a first-order update for the last step.
device (device | str, default 'cpu'): The PyTorch device to use.
dtype (dtype, default float32): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n last_step_first_order: bool = False,\n device: torch.device | str = \"cpu\",\n dtype: torch.dtype = torch.float32,\n) -> None:\n \"\"\"Initializes a new DPM solver.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n last_step_first_order: Use a first-order update for the last step.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):\n raise NotImplementedError\n if params and params.sde_variance not in (0.0, 1.0):\n raise NotImplementedError(\"DPMSolver only supports sde_variance=0.0 or 1.0\")\n\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n params=params,\n device=device,\n dtype=torch.float64, # compute constants precisely\n )\n self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)\n self.last_step_first_order = last_step_first_order\n sigmas = self.noise_std / self.cumulative_scale_factors\n self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)\n sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type=\"sigma_min\" in diffusers`\n self.sigmas = torch.cat([self.sigmas, sigma_min])\n self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(\n self.sigmas\n )\n self.timesteps = self._timesteps_from_sigmas(sigmas)\n self.to(dtype=dtype)\n
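DPMSolver works on sigmas = noise_std / cumulative_scale_factors and then rebuilds the other tensors with _solver_tensors_from_sigmas, whose body is not shown here. Under the variance-preserving constraint alpha_t**2 + sigma_t**2 = 1, this inversion has a closed form. The sketch below assumes that constraint and the Stable Diffusion quadratic schedule (betas from 0.00085 to 0.012 over 1000 steps). The tensors_from_sigmas function is a hypothetical helper.

```python
import torch

# Toy variance-preserving schedule (assumed quadratic SD betas), in float64 for precision.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alpha = torch.sqrt(alphas_cumprod)        # plays the role of cumulative_scale_factors
sigma = torch.sqrt(1.0 - alphas_cumprod)  # plays the role of noise_std

# Noise-to-signal ratio, as computed in DPMSolver.__init__.
sigmas = sigma / alpha


def tensors_from_sigmas(s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Hypothetical re-derivation, not the library's _solver_tensors_from_sigmas:
    # invert s = sigma / alpha under alpha**2 + sigma**2 == 1.
    a = 1.0 / torch.sqrt(1.0 + s**2)
    return a, s * a, -torch.log(s)  # alpha_t, sigma_t, lambda_t = log(alpha_t / sigma_t)


a, s_t, lam = tensors_from_sigmas(sigmas)
print(torch.allclose(a, alpha), torch.allclose(s_t, sigma))  # True True
```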
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.dpm_solver_first_order_update","title":"dpm_solver_first_order_update","text":"dpm_solver_first_order_update(\n x: Tensor,\n noise: Tensor,\n step: int,\n sde_noise: Tensor | None = None,\n) -> Tensor\n
Applies a first-order backward Euler update to the input data x.
Parameters:
x (Tensor, required): The input data.
noise (Tensor, required): The predicted noise.
step (int, required): The current step.
Returns:
Tensor: The denoised version of the input data x.
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def dpm_solver_first_order_update(\n self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None\n) -> torch.Tensor:\n \"\"\"Applies a first-order backward Euler update to the input data `x`.\n\n Args:\n x: The input data.\n noise: The predicted noise.\n step: The current step.\n\n Returns:\n The denoised version of the input data `x`.\n \"\"\"\n current_ratio = self.signal_to_noise_ratios[step]\n next_ratio = self.signal_to_noise_ratios[step + 1]\n\n next_scale_factor = self.cumulative_scale_factors[step + 1]\n\n next_noise_std = self.noise_std[step + 1]\n current_noise_std = self.noise_std[step]\n\n ratio_delta = current_ratio - next_ratio\n\n if sde_noise is None:\n return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise\n\n factor = 1.0 - torch.exp(2.0 * ratio_delta)\n return (\n (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x\n + next_scale_factor * factor * noise\n + next_noise_std * safe_sqrt(factor) * sde_noise\n )\n
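In this formula, the noise argument is multiplied by next_scale_factor in the same way as a denoised-data estimate (compare current_data_estimation in the second-order update). The sketch below therefore passes it an x0 estimate. Under that reading, the deterministic branch matches a deterministic DDIM step. The scalar values are arbitrary and serve only as an illustration.

```python
import torch

# Toy per-step quantities (assumed values) for a variance-preserving schedule.
alpha_cur = torch.tensor(0.6, dtype=torch.float64)
alpha_next = torch.tensor(0.8, dtype=torch.float64)
sigma_cur, sigma_next = torch.sqrt(1 - alpha_cur**2), torch.sqrt(1 - alpha_next**2)
lambda_cur, lambda_next = torch.log(alpha_cur / sigma_cur), torch.log(alpha_next / sigma_next)

x = torch.randn(4, dtype=torch.float64)   # current noisy sample
x0 = torch.randn(4, dtype=torch.float64)  # denoised-data estimate passed as the noise argument

# Deterministic branch (sde_noise is None), mirroring the method above.
ratio_delta = lambda_cur - lambda_next
dpm = (sigma_next / sigma_cur) * x + (1.0 - torch.exp(ratio_delta)) * alpha_next * x0

# Deterministic DDIM step: re-noise x0 with the epsilon implied by x and x0.
eps = (x - alpha_cur * x0) / sigma_cur
ddim = alpha_next * x0 + sigma_next * eps

print(torch.allclose(dpm, ddim))  # True
```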
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.multistep_dpm_solver_second_order_update","title":"multistep_dpm_solver_second_order_update","text":"multistep_dpm_solver_second_order_update(\n x: Tensor, step: int, sde_noise: Tensor | None = None\n) -> Tensor\n
Applies a second-order backward Euler update to the input data x.
Parameters:
x (Tensor, required): The input data.
step (int, required): The current step.
Returns:
Tensor: The denoised version of the input data x.
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def multistep_dpm_solver_second_order_update(\n self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None\n) -> torch.Tensor:\n \"\"\"Applies a second-order backward Euler update to the input data `x`.\n\n Args:\n x: The input data.\n step: The current step.\n\n Returns:\n The denoised version of the input data `x`.\n \"\"\"\n current_data_estimation = self.estimated_data[-1]\n previous_data_estimation = self.estimated_data[-2]\n\n next_ratio = self.signal_to_noise_ratios[step + 1]\n current_ratio = self.signal_to_noise_ratios[step]\n previous_ratio = self.signal_to_noise_ratios[step - 1]\n\n next_scale_factor = self.cumulative_scale_factors[step + 1]\n next_noise_std = self.noise_std[step + 1]\n current_noise_std = self.noise_std[step]\n\n estimation_delta = (current_data_estimation - previous_data_estimation) / (\n (current_ratio - previous_ratio) / (next_ratio - current_ratio)\n )\n ratio_delta = current_ratio - next_ratio\n\n if sde_noise is None:\n factor = 1.0 - torch.exp(ratio_delta)\n return (\n (next_noise_std / current_noise_std) * x\n + next_scale_factor * factor * current_data_estimation\n + 0.5 * next_scale_factor * factor * estimation_delta\n )\n\n factor = 1.0 - torch.exp(2.0 * ratio_delta)\n return (\n (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x\n + next_scale_factor * factor * current_data_estimation\n + 0.5 * next_scale_factor * factor * estimation_delta\n + next_noise_std * safe_sqrt(factor) * sde_noise\n )\n
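The second-order update admits a simple sanity check. When the two most recent data estimates are equal, estimation_delta vanishes and the update reduces to the first-order one. The sketch below re-implements both deterministic branches as hypothetical standalone functions with explicit arguments and toy schedule values, then verifies that reduction.

```python
import torch
from torch import Tensor


def dpm_2m_update(
    x: Tensor, d_cur: Tensor, d_prev: Tensor,
    lam_prev: Tensor, lam_cur: Tensor, lam_next: Tensor,
    alpha_next: Tensor, sigma_cur: Tensor, sigma_next: Tensor,
) -> Tensor:
    # Deterministic branch of multistep_dpm_solver_second_order_update.
    r = (lam_cur - lam_prev) / (lam_next - lam_cur)  # previous step size over current step size
    estimation_delta = (d_cur - d_prev) / r
    factor = 1.0 - torch.exp(lam_cur - lam_next)
    return (sigma_next / sigma_cur) * x + alpha_next * factor * d_cur + 0.5 * alpha_next * factor * estimation_delta


def dpm_1_update(
    x: Tensor, d_cur: Tensor, lam_cur: Tensor, lam_next: Tensor,
    alpha_next: Tensor, sigma_cur: Tensor, sigma_next: Tensor,
) -> Tensor:
    # Deterministic branch of dpm_solver_first_order_update.
    return (sigma_next / sigma_cur) * x + (1.0 - torch.exp(lam_cur - lam_next)) * alpha_next * d_cur


# Toy variance-preserving values for the previous, current and next steps.
alphas = torch.tensor([0.4, 0.6, 0.8], dtype=torch.float64)
sigmas = torch.sqrt(1 - alphas**2)
lams = torch.log(alphas / sigmas)

x = torch.randn(4, dtype=torch.float64)
d = torch.randn(4, dtype=torch.float64)

second = dpm_2m_update(x, d, d, lams[0], lams[1], lams[2], alphas[2], sigmas[1], sigmas[2])
first = dpm_1_update(x, d, lams[1], lams[2], alphas[2], sigmas[1], sigmas[2])
print(torch.allclose(second, first))  # True
```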
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.rebuild","title":"rebuild","text":"rebuild(\n num_inference_steps: int | None,\n first_inference_step: int | None = None,\n) -> DPMSolver\n
Rebuilds the solver with new parameters.
Parameters:
num_inference_steps (int | None, required): The number of inference steps.
first_inference_step (int | None, default None): The first inference step.
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def rebuild(\n self: \"DPMSolver\",\n num_inference_steps: int | None,\n first_inference_step: int | None = None,\n) -> \"DPMSolver\":\n \"\"\"Rebuilds the solver with new parameters.\n\n Args:\n num_inference_steps: The number of inference steps.\n first_inference_step: The first inference step.\n \"\"\"\n r = super().rebuild(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n )\n r.last_step_first_order = self.last_step_first_order\n return r\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.DPMSolver.remove_noise","title":"remove_noise","text":"remove_noise(x: Tensor, noise: Tensor, step: int) -> Tensor\n
Remove noise from the input tensor using the current step of the diffusion process.
See Solver.remove_noise for more details.
Source code in src/refiners/foundationals/latent_diffusion/solvers/dpm.py
def remove_noise(self, x: torch.Tensor, noise: torch.Tensor, step: int) -> torch.Tensor:\n \"\"\"Remove noise from the input tensor using the current step of the diffusion process.\n\n See [`Solver.remove_noise`][refiners.foundationals.latent_diffusion.solvers.solver.Solver.remove_noise] for more details.\n \"\"\"\n cumulative_scale_factors = self.cumulative_scale_factors[step]\n noise_stds = self.noise_std[step]\n denoised_x = (x - noise_stds * noise) / cumulative_scale_factors\n return denoised_x\n
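remove_noise solves the forward relation x_t = alpha_t * x_0 + sigma_t * noise for x_0. The sketch below assumes that Solver._add_noise, which is not shown here, applies exactly this relation. For simplicity it indexes by training timestep, whereas the solver may index its own per-step tensors.

```python
import torch

# Assumed quadratic Stable Diffusion beta schedule over 1000 training steps.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
cumulative_scale_factors = torch.sqrt(alphas_cumprod)  # alpha_t
noise_std = torch.sqrt(1.0 - alphas_cumprod)           # sigma_t


def add_noise(x: torch.Tensor, noise: torch.Tensor, t: int) -> torch.Tensor:
    # Forward process x_t = alpha_t * x_0 + sigma_t * eps (assumed to match the solver's _add_noise).
    return cumulative_scale_factors[t] * x + noise_std[t] * noise


def remove_noise(x: torch.Tensor, noise: torch.Tensor, t: int) -> torch.Tensor:
    # Same formula as DPMSolver.remove_noise: solve the forward relation for x_0.
    return (x - noise_std[t] * noise) / cumulative_scale_factors[t]


x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn(1, 4, 8, 8)
x_t = add_noise(x0, eps, t=500)
print(torch.allclose(remove_noise(x_t, eps, t=500), x0, atol=1e-5))  # True
```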
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler","title":"Euler","text":"Euler(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Solver
Euler solver.
See [arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models for more details.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
device (device | str, default 'cpu'): The PyTorch device to use.
dtype (dtype, default float32): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: Device | str = \"cpu\",\n dtype: Dtype = torch.float32,\n):\n \"\"\"Initializes a new Euler solver.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):\n raise NotImplementedError\n if params and params.sde_variance != 0.0:\n raise NotImplementedError(\"Euler does not support sde_variance != 0.0 yet\")\n\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n params=params,\n device=device,\n dtype=dtype,\n )\n self.sigmas = self._generate_sigmas()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler.init_noise_sigma","title":"init_noise_sigma property
","text":"init_noise_sigma: Tensor\n
The initial noise sigma.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Euler.scale_model_input","title":"scale_model_input","text":"scale_model_input(x: Tensor, step: int) -> Tensor\n
Scales the model input according to the current step.
Parameters:
x (Tensor, required): The model input.
step (int, required): The current step. This method is called with step=-1 in init_latents.
Returns:
Tensor: The scaled model input.
Source code in src/refiners/foundationals/latent_diffusion/solvers/euler.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:\n \"\"\"Scales the model input according to the current step.\n\n Args:\n x: The model input.\n step: The current step. This method is called with `step=-1` in `init_latents`.\n\n Returns:\n The scaled model input.\n \"\"\"\n\n if step == -1:\n return x * self.init_noise_sigma\n\n sigma = self.sigmas[step]\n return x / ((sigma**2 + 1) ** 0.5)\n
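Why divide by sqrt(sigma**2 + 1)? Suppose Euler's sigmas follow the usual convention sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t). This is an assumption, since _generate_sigmas is not shown. Then 1 / sqrt(sigma_t**2 + 1) equals sqrt(alpha_bar_t), so the scaling maps the sigma-space sample back to the variance-preserving scale the U-Net was trained on. A quick numerical check, again assuming the quadratic Stable Diffusion schedule:

```python
import torch

# Assumed quadratic SD schedule and k-diffusion style sigmas (hypothetical stand-in for _generate_sigmas).
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = torch.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)

x = torch.randn(2, 4, 8, 8, dtype=torch.float64)
t = 999
scaled = x / ((sigmas[t] ** 2 + 1) ** 0.5)  # same expression as scale_model_input

# 1 / sqrt(sigma**2 + 1) == sqrt(alpha_bar), i.e. back to the variance-preserving scale.
print(torch.allclose(scaled, x * torch.sqrt(alphas_cumprod[t])))  # True
```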
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.FrankenSolver","title":"FrankenSolver","text":"FrankenSolver(\n get_diffusers_scheduler: Callable[[], SchedulerLike],\n num_inference_steps: int,\n first_inference_step: int = 0,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n **kwargs: Any\n)\n
Bases: Solver
Lets you use Diffusers Schedulers as Refiners Solvers.
For instance:
from diffusers import EulerDiscreteScheduler\nfrom refiners.foundationals.latent_diffusion.solvers import FrankenSolver\n\nscheduler = EulerDiscreteScheduler(...)\nsolver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)\n
Source code in src/refiners/foundationals/latent_diffusion/solvers/franken.py
def __init__(\n self,\n get_diffusers_scheduler: Callable[[], SchedulerLike],\n num_inference_steps: int,\n first_inference_step: int = 0,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n **kwargs: Any, # for typing, ignored\n) -> None:\n self.get_diffusers_scheduler = get_diffusers_scheduler\n self.diffusers_scheduler = self.get_diffusers_scheduler()\n self.diffusers_scheduler.set_timesteps(num_inference_steps)\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.LCMSolver","title":"LCMSolver","text":"LCMSolver(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n num_orig_steps: int = 50,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Solver
Latent Consistency Model solver.
This solver is designed for use either with a specific base model or a specific LoRA.
See [arXiv:2310.04378] Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference for details.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
num_orig_steps (int, default 50): The number of inference steps of the emulated DPM solver.
device (device | str, default 'cpu'): The PyTorch device to use.
dtype (dtype, default float32): The PyTorch data type to use.
Source code in src/refiners/foundationals/latent_diffusion/solvers/lcm.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n num_orig_steps: int = 50,\n device: torch.device | str = \"cpu\",\n dtype: torch.dtype = torch.float32,\n):\n \"\"\"Initializes a new LCM solver.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n num_orig_steps: The number of inference steps of the emulated DPM solver.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n \"\"\"\n\n assert (\n num_orig_steps >= num_inference_steps\n ), f\"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})\"\n\n params = self.resolve_params(params)\n if params.model_prediction_type != ModelPredictionType.NOISE:\n raise NotImplementedError\n\n self._dpm = [\n DPMSolver(\n num_inference_steps=num_orig_steps,\n params=SolverParams(\n num_train_timesteps=params.num_train_timesteps,\n timesteps_spacing=params.timesteps_spacing,\n ),\n device=device,\n dtype=dtype,\n )\n ]\n super().__init__(\n num_inference_steps=num_inference_steps,\n first_inference_step=first_inference_step,\n params=params,\n device=device,\n dtype=dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.ModelPredictionType","title":"ModelPredictionType","text":" Bases: str
, Enum
An enumeration of possible outputs of the model.
Attributes:
NOISE: The model predicts the noise (epsilon).
SAMPLE: The model predicts the denoised sample (x0).
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.NoiseSchedule","title":"NoiseSchedule","text":" Bases: str
, Enum
An enumeration of schedules used to sample the noise.
Attributes:
UNIFORM: A uniform noise schedule.
QUADRATIC: A quadratic noise schedule. Corresponds to \"Stable Diffusion\" in [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed, Table 1.
KARRAS: See [arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5.
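For reference, the Karras schedule cited above (Equation 5 of arXiv:2206.00364) interpolates linearly between sigma_max**(1/rho) and sigma_min**(1/rho), then raises the result back to the power rho. The paper uses rho = 7. The minimal sketch below is independent of how Refiners wires the KARRAS option into its solvers. The sigma bounds are example values, not Refiners defaults.

```python
import torch


def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    # Equation 5: linear interpolation in sigma**(1/rho) space, from sigma_max down to sigma_min.
    ramp = torch.linspace(0, 1, n, dtype=torch.float64)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


sigmas = karras_sigmas(10, sigma_min=0.03, sigma_max=14.6)  # example bounds, not library defaults
print(sigmas[0].item(), sigmas[-1].item())  # approximately 14.6 and 0.03
```

Raising to rho = 7 concentrates the steps near sigma_min, where the paper argues fine steps matter most.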
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver","title":"Solver","text":"Solver(\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Module, ABC
The base class for creating a diffusion model solver.
Solvers create a sequence of noise and scaling factors used in the diffusion process, which gradually transforms the original data distribution into a Gaussian one.
This process is described by several parameters, such as the initial and final diffusion rates, and is encapsulated in a __call__ method that applies one step of the diffusion process.
Attributes:
params (ResolvedSolverParams): The common parameters for solvers. See SolverParams.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The step to start the inference process from.
scale_factors: The scale factors used to denoise the input. These are called \"alphas\" in other implementations, and 1 - scale_factors is called \"betas\".
cumulative_scale_factors: The cumulative scale factors used to denoise the input, i.e. the square root of the cumulative product of scale_factors. These are called \"alpha_t\" in other implementations.
noise_std: The standard deviation of the noise used to denoise the input. This is called \"sigma_t\" in other implementations.
signal_to_noise_ratios: The log signal-to-noise ratios log(alpha_t / sigma_t) used to denoise the input. These are called \"lambda_t\" in other implementations.
Parameters:
num_inference_steps (int, required): The number of inference steps to perform.
first_inference_step (int, default 0): The first inference step to perform.
params (BaseSolverParams | None, default None): The common parameters for solvers.
device (device | str, default 'cpu'): The PyTorch device to use for the solver's tensors.
dtype (dtype, default float32): The PyTorch data type to use for the solver's tensors.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def __init__(\n self,\n num_inference_steps: int,\n first_inference_step: int = 0,\n params: BaseSolverParams | None = None,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n \"\"\"Initializes a new `Solver` instance.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n params: The common parameters for solvers.\n device: The PyTorch device to use for the solver's tensors.\n dtype: The PyTorch data type to use for the solver's tensors.\n \"\"\"\n super().__init__()\n\n self.num_inference_steps = num_inference_steps\n self.first_inference_step = first_inference_step\n self.params = self.resolve_params(params)\n\n self.scale_factors = self.sample_noise_schedule()\n self.cumulative_scale_factors = torch.sqrt(self.scale_factors.cumprod(dim=0))\n self.noise_std = torch.sqrt(1.0 - self.scale_factors.cumprod(dim=0))\n self.signal_to_noise_ratios = torch.log(self.cumulative_scale_factors) - torch.log(self.noise_std)\n self.timesteps = self._generate_timesteps()\n\n self.to(device=device, dtype=dtype)\n
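The tensors built in __init__ can be reproduced outside the class. This sketch assumes sample_noise_schedule returns 1 - betas for the quadratic Stable Diffusion schedule (0.00085 to 0.012 over 1000 steps). Under that reading, cumulative_scale_factors and noise_std satisfy alpha_t**2 + sigma_t**2 = 1.

```python
import torch

# Assumed per-step scale factors: 1 - betas for the quadratic SD schedule.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
scale_factors = 1.0 - betas

# Same expressions as Solver.__init__.
cumulative_scale_factors = torch.sqrt(scale_factors.cumprod(dim=0))  # alpha_t
noise_std = torch.sqrt(1.0 - scale_factors.cumprod(dim=0))           # sigma_t
signal_to_noise_ratios = torch.log(cumulative_scale_factors) - torch.log(noise_std)  # lambda_t

# Variance preservation: alpha_t**2 + sigma_t**2 == 1 at every training step.
print(torch.allclose(cumulative_scale_factors**2 + noise_std**2, torch.ones(1000)))  # True
# lambda_t decreases as noise is added, so it is highest at t = 0.
print(bool((signal_to_noise_ratios[1:] < signal_to_noise_ratios[:-1]).all()))  # True
```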
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.all_steps","title":"all_steps property
","text":"all_steps: list[int]\n
Return a list of all inference steps.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.device","title":"deviceproperty
writable
","text":"device: device\n
The PyTorch device used for the solver's tensors.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.dtype","title":"dtypeproperty
writable
","text":"dtype: dtype\n
The PyTorch data type used for the solver's tensors.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.inference_steps","title":"inference_stepsproperty
","text":"inference_steps: list[int]\n
Return a list of inference steps to perform.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.add_noise","title":"add_noise","text":"add_noise(\n x: Tensor, noise: Tensor, step: int | list[int]\n) -> Tensor\n
Add noise to the input tensor using the solver's parameters.
Parameters:
x (Tensor, required): The input tensor to add noise to.
noise (Tensor, required): The noise tensor to add to the input tensor.
step (int | list[int], required): The current step(s) of the diffusion process.
Returns:
Tensor: The input tensor with added noise.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def add_noise(\n self,\n x: Tensor,\n noise: Tensor,\n step: int | list[int],\n) -> Tensor:\n \"\"\"Add noise to the input tensor using the solver's parameters.\n\n Args:\n x: The input tensor to add noise to.\n noise: The noise tensor to add to the input tensor.\n step: The current step(s) of the diffusion process.\n\n Returns:\n The input tensor with added noise.\n \"\"\"\n if isinstance(step, list):\n assert len(x) == len(noise) == len(step), \"x, noise, and step must have the same length\"\n return torch.stack(\n tensors=[\n self._add_noise(\n x=x[i],\n noise=noise[i],\n step=step[i],\n )\n for i in range(x.shape[0])\n ],\n dim=0,\n )\n\n return self._add_noise(x=x, noise=noise, step=step)\n
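When step is a list, add_noise noises each batch element at its own step and stacks the results. This is convenient, for example, during training when a timestep is sampled for each element of the batch. The sketch below compares that loop with an equivalent broadcast version. Both helpers are hypothetical, and both assume the forward relation x_t = alpha_t * x + sigma_t * noise with a quadratic schedule.

```python
import torch

# Assumed quadratic SD schedule.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alpha, sigma = torch.sqrt(alphas_cumprod), torch.sqrt(1.0 - alphas_cumprod)


def add_noise_per_sample(x: torch.Tensor, noise: torch.Tensor, steps: list[int]) -> torch.Tensor:
    # Loop and stack, like add_noise does for a list of steps.
    return torch.stack([alpha[s] * x[i] + sigma[s] * noise[i] for i, s in enumerate(steps)], dim=0)


def add_noise_vectorized(x: torch.Tensor, noise: torch.Tensor, steps: list[int]) -> torch.Tensor:
    # Broadcast version: gather one alpha/sigma per sample and reshape to (B, 1, 1, 1).
    idx = torch.tensor(steps)
    shape = (-1, *([1] * (x.dim() - 1)))
    return alpha[idx].view(shape) * x + sigma[idx].view(shape) * noise


x, noise = torch.randn(3, 4, 8, 8), torch.randn(3, 4, 8, 8)
steps = [0, 250, 999]
print(torch.allclose(add_noise_per_sample(x, noise, steps), add_noise_vectorized(x, noise, steps)))  # True
```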
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.generate_timesteps","title":"generate_timesteps staticmethod
","text":"generate_timesteps(\n spacing: TimestepSpacing,\n num_inference_steps: int,\n num_train_timesteps: int = 1000,\n offset: int = 0,\n) -> Tensor\n
Generate a tensor of timesteps according to a given spacing.
Parameters:
spacing (TimestepSpacing, required): The spacing to use for the timesteps.
num_inference_steps (int, required): The number of inference steps to perform.
num_train_timesteps (int, default 1000): The number of timesteps used to train the diffusion process.
offset (int, default 0): The offset to use for the timesteps.
Returns:
Tensor: The timesteps, in descending order.
Raises:
RuntimeError: If spacing is TimestepSpacing.CUSTOM.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
@staticmethod\ndef generate_timesteps(\n spacing: TimestepSpacing,\n num_inference_steps: int,\n num_train_timesteps: int = 1000,\n offset: int = 0,\n) -> Tensor:\n \"\"\"Generate a tensor of timesteps according to a given spacing.\n\n Args:\n spacing: The spacing to use for the timesteps.\n num_inference_steps: The number of inference steps to perform.\n num_train_timesteps: The number of timesteps used to train the diffusion process.\n offset: The offset to use for the timesteps.\n \"\"\"\n max_timestep = num_train_timesteps - 1 + offset\n match spacing:\n case TimestepSpacing.LINSPACE:\n return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=torch.float32).flip(0)\n case TimestepSpacing.LINSPACE_ROUNDED:\n return torch.tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)\n case TimestepSpacing.LEADING:\n step_ratio = num_train_timesteps // num_inference_steps\n return (torch.arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)\n case TimestepSpacing.TRAILING:\n step_ratio = num_train_timesteps // num_inference_steps\n max_timestep = num_train_timesteps - 1 + offset\n return torch.arange(max_timestep, offset, -step_ratio)\n case TimestepSpacing.CUSTOM:\n raise RuntimeError(\"generate_timesteps called with custom spacing\")\n
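To make the spacings concrete, here is a plain-Python sketch that mirrors the branches of generate_timesteps above. It uses lists instead of tensors, and plain strings stand in for the TimestepSpacing members. It is an illustration, not the library code, and float rounding in the linspace branch may differ slightly from NumPy.

```python
def generate_timesteps(spacing: str, num_inference_steps: int, num_train_timesteps: int = 1000, offset: int = 0) -> list:
    # Hypothetical stand-ins for TimestepSpacing members:
    # 'linspace', 'linspace_rounded', 'leading', 'trailing'.
    max_timestep = num_train_timesteps - 1 + offset
    if spacing in ('linspace', 'linspace_rounded'):
        # Like np.linspace(offset, max_timestep, num_inference_steps), then flipped.
        if num_inference_steps == 1:
            values = [float(offset)]
        else:
            delta = (max_timestep - offset) / (num_inference_steps - 1)
            values = [offset + i * delta for i in range(num_inference_steps)]
        if spacing == 'linspace_rounded':
            values = [round(v) for v in values]
        return values[::-1]
    if spacing == 'leading':
        step_ratio = num_train_timesteps // num_inference_steps
        return [i * step_ratio + offset for i in range(num_inference_steps)][::-1]
    if spacing == 'trailing':
        step_ratio = num_train_timesteps // num_inference_steps
        # Like torch.arange(max_timestep, offset, -step_ratio): the stop value is excluded.
        return list(range(max_timestep, offset, -step_ratio))
    raise RuntimeError('generate_timesteps called with custom spacing')
```

With 10 inference steps and 1000 training timesteps, the two spacings miss opposite ends of the range. LEADING ends at 0 but never reaches 999. TRAILING starts at 999 but never reaches 0.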
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.rebuild","title":"rebuild","text":"rebuild(\n num_inference_steps: int | None,\n first_inference_step: int | None = None,\n) -> T\n
Rebuild the solver with new parameters.
Parameters:
num_inference_steps (int | None, required): The number of inference steps to perform. If None, the current value is kept.
first_inference_step (int | None, default None): The first inference step to perform. If None, the current value is kept.
Returns:
T: A new solver instance with the specified parameters.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:\n \"\"\"Rebuild the solver with new parameters.\n\n Args:\n num_inference_steps: The number of inference steps to perform.\n first_inference_step: The first inference step to perform.\n\n Returns:\n A new solver instance with the specified parameters.\n \"\"\"\n return self.__class__(\n num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,\n first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,\n params=dataclasses.replace(self.params),\n device=self.device,\n dtype=self.dtype,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.remove_noise","title":"remove_noise","text":"remove_noise(x: Tensor, noise: Tensor, step: int) -> Tensor\n
Remove noise from the input tensor using the current step of the diffusion process.
Note: See [arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15, and [arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance.
Parameters:
x (Tensor, required): The input tensor to remove noise from.
noise (Tensor, required): The noise tensor to remove from the input tensor.
step (int, required): The current step of the diffusion process.
Returns:
Tensor: The denoised input tensor.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:\n \"\"\"Remove noise from the input tensor using the current step of the diffusion process.\n\n Note:\n See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15](https://arxiv.org/abs/2006.11239)\n and [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939).\n\n Args:\n x: The input tensor to remove noise from.\n noise: The noise tensor to remove from the input tensor.\n step: The current step of the diffusion process.\n\n Returns:\n The denoised input tensor.\n \"\"\"\n timestep = self.timesteps[step]\n cumulative_scale_factors = self.cumulative_scale_factors[timestep]\n noise_stds = self.noise_std[timestep]\n denoised_x = (x - noise_stds * noise) / cumulative_scale_factors\n return denoised_x\n
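The formula in remove_noise is the forward noising step solved for the clean sample. The plain-Python sketch below uses a single hypothetical value of the cumulative product alpha_bar. It assumes cumulative_scale_factors = sqrt(alpha_bar) and noise_std = sqrt(1 - alpha_bar), the usual DDPM parameterization; those attributes are not defined in this section. The sketch checks that removing the exact noise that was added recovers the input.

```python
import math


def add_noise(x: float, noise: float, alpha_bar: float) -> float:
    # Assumed forward process: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
    return math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * noise


def remove_noise(x_t: float, noise: float, alpha_bar: float) -> float:
    # Same arithmetic as Solver.remove_noise, with scalars instead of tensors.
    cumulative_scale_factor = math.sqrt(alpha_bar)
    noise_std = math.sqrt(1.0 - alpha_bar)
    return (x_t - noise_std * noise) / cumulative_scale_factor


alpha_bar = 0.3  # hypothetical value at some timestep
x_0, eps = 0.7, -1.2
x_t = add_noise(x_0, eps, alpha_bar)
recovered = remove_noise(x_t, eps, alpha_bar)
```

During sampling, the noise passed in is typically the model's noise prediction. In that case the result is an estimate of the clean sample rather than an exact reconstruction.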
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.sample_noise_schedule","title":"sample_noise_schedule","text":"sample_noise_schedule() -> Tensor\n
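Sample the noise schedule: 1 minus a power distribution between the initial and final diffusion rates, with power 1 for UNIFORM, 2 for QUADRATIC and 7 for KARRAS.
Returns:
Tensor: A tensor representing the noise schedule.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py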
def sample_noise_schedule(self) -> Tensor:\n \"\"\"Sample the noise schedule.\n\n Returns:\n A tensor representing the noise schedule.\n \"\"\"\n match self.params.noise_schedule:\n case NoiseSchedule.UNIFORM:\n return 1 - self.sample_power_distribution(1)\n case NoiseSchedule.QUADRATIC:\n return 1 - self.sample_power_distribution(2)\n case NoiseSchedule.KARRAS:\n return 1 - self.sample_power_distribution(7)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.sample_power_distribution","title":"sample_power_distribution","text":"sample_power_distribution(power: float = 2) -> Tensor\n
Sample a power distribution: num_train_timesteps values spaced linearly between initial_diffusion_rate ** (1 / power) and final_diffusion_rate ** (1 / power), then raised to power.
Parameters:
power (float, default 2): The power to use for the distribution.
Returns:
Tensor: A tensor representing the power distribution between the initial and final diffusion rates of the solver.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def sample_power_distribution(self, power: float = 2, /) -> Tensor:\n \"\"\"Sample a power distribution.\n\n Args:\n power: The power to use for the distribution.\n\n Returns:\n A tensor representing the power distribution between the initial and final diffusion rates of the solver.\n \"\"\"\n return (\n torch.linspace(\n start=self.params.initial_diffusion_rate ** (1 / power),\n end=self.params.final_diffusion_rate ** (1 / power),\n steps=self.params.num_train_timesteps,\n )\n ** power\n )\n
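Below is a plain-Python sketch of sample_power_distribution and of how sample_noise_schedule turns it into a schedule. The schedule is 1 minus the distribution, with power 1, 2 or 7 for UNIFORM, QUADRATIC and KARRAS. The rates 0.00085 and 0.012 and the 1000 train timesteps are hypothetical example values, not defaults taken from this page.

```python
def linspace(start: float, end: float, steps: int) -> list[float]:
    # Minimal stand-in for torch.linspace (assumes steps >= 2).
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def sample_power_distribution(initial_rate: float, final_rate: float, num_train_timesteps: int, power: float = 2) -> list[float]:
    # Interpolate linearly between the rate ** (1 / power) endpoints, then raise back to power.
    return [v**power for v in linspace(initial_rate ** (1 / power), final_rate ** (1 / power), num_train_timesteps)]


# Powers used by sample_noise_schedule for each NoiseSchedule member.
POWERS = {'uniform': 1, 'quadratic': 2, 'karras': 7}


def sample_noise_schedule(kind: str, initial_rate: float, final_rate: float, num_train_timesteps: int) -> list[float]:
    # sample_noise_schedule returns 1 minus the power distribution.
    rates = sample_power_distribution(initial_rate, final_rate, num_train_timesteps, POWERS[kind])
    return [1 - v for v in rates]


schedule = sample_noise_schedule('quadratic', 0.00085, 0.012, 1000)
```

All powers share the same endpoints. A larger power keeps the intermediate rates closer to the initial rate for longer, so the midpoint for power 7 lies below the arithmetic mean obtained with power 1.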
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.scale_model_input","title":"scale_model_input","text":"scale_model_input(x: Tensor, step: int) -> Tensor\n
Scale the model's input according to the current timestep.
Note: This method should only be overridden by solvers that need to scale the input according to the current timestep. By default, it does not scale the input (scale = 1).
Parameters:
x (Tensor, required): The input tensor to scale.
step (int, required): The current step of the diffusion process.
Returns:
Tensor: The scaled input tensor.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def scale_model_input(self, x: Tensor, step: int) -> Tensor:\n \"\"\"Scale the model's input according to the current timestep.\n\n Note:\n This method should only be overridden by solvers that\n need to scale the input according to the current timestep.\n\n By default, this method does not scale the input.\n (scale=1)\n\n Args:\n x: The input tensor to scale.\n step: The current step of the diffusion process.\n\n Returns:\n The scaled input tensor.\n \"\"\"\n return x\n
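The sketch below illustrates the kind of override this note refers to. It contrasts the default identity with a hypothetical sigma-based solver that divides the input by sqrt(sigma ** 2 + 1), a scaling used by Karras-style (k-diffusion) samplers. The classes, the sigmas list and the choice of scaling are assumptions for illustration, not solvers from this library.

```python
import math


class IdentityScaling:
    # Mirrors the default Solver.scale_model_input: scale = 1.
    def scale_model_input(self, x: list[float], step: int) -> list[float]:
        return x


class SigmaScaling(IdentityScaling):
    # Hypothetical solver that keeps one noise level (sigma) per inference step.
    def __init__(self, sigmas: list[float]) -> None:
        self.sigmas = sigmas

    def scale_model_input(self, x: list[float], step: int) -> list[float]:
        # If x = signal + sigma * noise with unit-variance signal and noise,
        # dividing by sqrt(sigma ** 2 + 1) brings it back to roughly unit variance.
        factor = 1.0 / math.sqrt(self.sigmas[step] ** 2 + 1.0)
        return [v * factor for v in x]
```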
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.Solver.to","title":"to","text":"to(\n device: device | str | None = None,\n dtype: dtype | None = None,\n) -> Solver\n
Move the solver to the specified device and data type. Tensor attributes are moved to the device and cast to the data type, except timesteps, which keeps its data type.
Parameters:
device (device | str | None, default None): The PyTorch device to move the solver to.
dtype (dtype | None, default None): The PyTorch data type to move the solver to.
Returns:
Solver: The solver instance, moved to the specified device and data type.
Source code in src/refiners/foundationals/latent_diffusion/solvers/solver.py
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> \"Solver\":\n \"\"\"Move the solver to the specified device and data type.\n\n Args:\n device: The PyTorch device to move the solver to.\n dtype: The PyTorch data type to move the solver to.\n\n Returns:\n The solver instance, moved to the specified device and data type.\n \"\"\"\n super().to(device=device, dtype=dtype)\n for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:\n match name:\n case \"timesteps\":\n setattr(self, name, attr.to(device=device))\n case _:\n setattr(self, name, attr.to(device=device, dtype=dtype))\n return self\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.SolverParams","title":"SolverParams dataclass
","text":"SolverParams(\n *,\n num_train_timesteps: int | None = None,\n timesteps_spacing: TimestepSpacing | None = None,\n timesteps_offset: int | None = None,\n initial_diffusion_rate: float | None = None,\n final_diffusion_rate: float | None = None,\n noise_schedule: NoiseSchedule | None = None,\n sigma_schedule: NoiseSchedule | None = None,\n model_prediction_type: (\n ModelPredictionType | None\n ) = None,\n sde_variance: float = 0.0\n)\n
Bases: BaseSolverParams
Common parameters for solvers.
Parameters:
num_train_timesteps (int | None, default None): The number of timesteps used to train the diffusion process.
timesteps_spacing (TimestepSpacing | None, default None): The spacing to use for the timesteps.
timesteps_offset (int | None, default None): The offset to use for the timesteps.
initial_diffusion_rate (float | None, default None): The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate (float | None, default None): The final diffusion rate used to sample the noise schedule.
noise_schedule (NoiseSchedule | None, default None): The type of noise schedule (e.g. uniform, quadratic or Karras) used to interpolate between the initial and final diffusion rates.
model_prediction_type (ModelPredictionType | None, default None): Defines what the model predicts.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.solvers.TimestepSpacing","title":"TimestepSpacing","text":" Bases: str, Enum
An enumeration of methods to space the timesteps.
See [arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed, Table 2.
Attributes:
LINSPACE: Sample N steps with linear interpolation and return a floating-point tensor.
LINSPACE_ROUNDED: Same as LINSPACE, but return an integer tensor with rounded timesteps.
LEADING: Sample N+1 steps and do not include the last timestep (i.e. bad: non-zero SNR). Used in DDIM, with a mitigation for that issue.
TRAILING: Sample N+1 steps and do not include the first timestep.
CUSTOM: Use a custom timestep spacing in the solver (override _generate_timesteps, see DPM).
SDLoraManager(target: LatentDiffusionModel)\n
Manage LoRAs for a Stable Diffusion model.
Note: In the context of SDLoraManager, a \"LoRA\" is a set of \"LoRA layers\" that can be attached to a target model.
Parameters:
target (LatentDiffusionModel, required): The target model to manage the LoRAs for.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def __init__(\n self,\n target: LatentDiffusionModel,\n) -> None:\n \"\"\"Initialize the LoRA manager.\n\n Args:\n target: The target model to manage the LoRAs for.\n \"\"\"\n self.target = target\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.clip_text_encoder","title":"clip_text_encoder property
","text":"clip_text_encoder: Chain\n
The text encoder of the Stable Diffusion model.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.lora_adapters","title":"lora_adapters property
","text":"lora_adapters: list[LoraAdapter]\n
List of all the LoraAdapters managed by the SDLoraManager.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.loras","title":"loras property
","text":"loras: list[Lora[Any]]\n
List of all the LoRA layers managed by the SDLoraManager.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.names","title":"names property
","text":"names: list[str]\n
List of all the LoRA names managed by the SDLoraManager.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.scales","title":"scales property
","text":"scales: dict[str, float]\n
The scales of all the LoRAs managed by the SDLoraManager.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.unet","title":"unet property
","text":"unet: Chain\n
The U-Net of the Stable Diffusion model.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras","title":"add_loras","text":"add_loras(\n name: str,\n /,\n tensors: dict[str, Tensor],\n scale: float = 1.0,\n unet_inclusions: list[str] | None = None,\n unet_exclusions: list[str] | None = None,\n unet_preprocess: dict[str, str] | None = None,\n text_encoder_inclusions: list[str] | None = None,\n text_encoder_exclusions: list[str] | None = None,\n) -> None\n
Load a single LoRA from a state_dict.
Warning: This method expects the keys of the state_dict to be in the formats commonly found on CivitAI's hub.
Parameters:
name (str, required): The name of the LoRA.
tensors (dict[str, Tensor], required): The state_dict of the LoRA to load.
scale (float, default 1.0): The scale to use for the LoRA.
unet_inclusions (list[str] | None, default None): A list of layer names; only layers with one of these layers among their ancestors are considered when patching the UNet.
unet_exclusions (list[str] | None, default None): A list of layer names; layers with one of these layers among their ancestors are not considered when patching the UNet. If None, it defaults to [\"TimestepEncoder\"].
unet_preprocess (dict[str, str] | None, default None): A map between parts of state dict keys and layer names, used to attach some keys to specific parts of the UNet. You should leave it set to None (a default mapping is used); otherwise, read the source code to understand how it works.
text_encoder_inclusions (list[str] | None, default None): A list of layer names; only layers with one of these layers among their ancestors are considered when patching the text encoder.
text_encoder_exclusions (list[str] | None, default None): A list of layer names; layers with one of these layers among their ancestors are not considered when patching the text encoder.
Raises:
AssertionError: If the manager already has a LoRA with the same name.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras(\n self,\n name: str,\n /,\n tensors: dict[str, Tensor],\n scale: float = 1.0,\n unet_inclusions: list[str] | None = None,\n unet_exclusions: list[str] | None = None,\n unet_preprocess: dict[str, str] | None = None,\n text_encoder_inclusions: list[str] | None = None,\n text_encoder_exclusions: list[str] | None = None,\n) -> None:\n \"\"\"Load a single LoRA from a `state_dict`.\n\n Warning:\n This method expects the keys of the `state_dict` to be in the commonly found formats on CivitAI's hub.\n\n Args:\n name: The name of the LoRA.\n tensors: The `state_dict` of the LoRA to load.\n scale: The scale to use for the LoRA.\n unet_inclusions: A list of layer names, only layers with such a layer\n in their ancestors will be considered when patching the UNet.\n unet_exclusions: A list of layer names, layers with such a layer in\n their ancestors will not be considered when patching the UNet.\n If this is `None` then it defaults to `[\"TimestepEncoder\"]`.\n unet_preprocess: A map between parts of state dict keys and layer names.\n This is used to attach some keys to specific parts of the UNet.\n You should leave it set to `None` (it has a default value),\n otherwise read the source code to understand how it works.\n text_encoder_inclusions: A list of layer names, only layers with such a layer\n in their ancestors will be considered when patching the text encoder.\n text_encoder_exclusions: A list of layer names, layers with such a layer in\n their ancestors will not be considered when patching the text encoder.\n\n Raises:\n AssertionError: If the Manager already has a LoRA with the same name.\n \"\"\"\n assert name not in self.names, f\"LoRA {name} already exists\"\n\n # load the LoRA state_dict\n loras = Lora.from_dict(\n name,\n state_dict={\n key: value.to(\n device=self.target.device,\n dtype=self.target.dtype,\n )\n for key, value in tensors.items()\n },\n )\n # sort all the LoRA's keys using the `sort_keys` method\n loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}\n\n # if no key contains \"unet\" or \"text\", assume all keys are for the unet\n if all(\"unet\" not in key and \"text\" not in key for key in loras.keys()):\n loras = {f\"unet_{key}\": value for key, value in loras.items()}\n\n # attach the LoRA to the target\n self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)\n self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)\n\n # set the scale of the LoRA\n self.set_scale(name, scale)\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras_to_text_encoder","title":"add_loras_to_text_encoder","text":"add_loras_to_text_encoder(\n loras: dict[str, Lora[Any]],\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n debug_map: list[tuple[str, str]] | None = None,\n) -> None\n
Add multiple LoRAs to the text encoder. Only the LoRAs whose key contains \"text\" are attached. See add_loras for details about arguments.
Parameters:
loras (dict[str, Lora[Any]], required): The dictionary of LoRAs to add to the text encoder (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder).
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_text_encoder(\n self,\n loras: dict[str, Lora[Any]],\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n debug_map: list[tuple[str, str]] | None = None,\n) -> None:\n \"\"\"Add multiple LoRAs to the text encoder. See `add_loras` for details about arguments.\n\n Args:\n loras: The dictionary of LoRAs to add to the text encoder.\n (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)\n \"\"\"\n text_encoder_loras = {key: loras[key] for key in loras.keys() if \"text\" in key}\n auto_attach_loras(\n text_encoder_loras,\n self.clip_text_encoder,\n exclude=exclude,\n include=include,\n debug_map=debug_map,\n )\n
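The way add_loras and its two helpers split keys between the U-Net and the text encoder can be sketched without a model. In the plain-Python sketch below, route_lora_keys is a hypothetical helper and the keys are made-up examples. If no key mentions either target, all keys are sent to the U-Net. Otherwise, a plain substring test on unet and text decides where each key goes.

```python
def route_lora_keys(keys: list[str]) -> tuple[list[str], list[str]]:
    # As in add_loras: if no key contains 'unet' or 'text', prefix every key with 'unet_'.
    if all('unet' not in key and 'text' not in key for key in keys):
        keys = [f'unet_{key}' for key in keys]
    # add_loras_to_unet keeps keys containing 'unet'.
    unet_keys = [key for key in keys if 'unet' in key]
    # add_loras_to_text_encoder keeps keys containing 'text'.
    text_keys = [key for key in keys if 'text' in key]
    return unet_keys, text_keys
```

Because the check is a substring test, a key that contains both unet and text would be selected by both helpers.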
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.add_loras_to_unet","title":"add_loras_to_unet","text":"add_loras_to_unet(\n loras: dict[str, Lora[Any]],\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n preprocess: dict[str, str] | None = None,\n debug_map: list[tuple[str, str]] | None = None,\n) -> None\n
Add multiple LoRAs to the U-Net. Only the LoRAs whose key contains \"unet\" are attached. See add_loras for details about arguments.
Parameters:
loras (dict[str, Lora[Any]], required): The dictionary of LoRAs to add to the U-Net (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net).
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def add_loras_to_unet(\n self,\n loras: dict[str, Lora[Any]],\n /,\n include: list[str] | None = None,\n exclude: list[str] | None = None,\n preprocess: dict[str, str] | None = None,\n debug_map: list[tuple[str, str]] | None = None,\n) -> None:\n \"\"\"Add multiple LoRAs to the U-Net. See `add_loras` for details about arguments.\n\n Args:\n loras: The dictionary of LoRAs to add to the U-Net.\n (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)\n \"\"\"\n unet_loras = {key: loras[key] for key in loras.keys() if \"unet\" in key}\n\n if exclude is None:\n exclude = [\"TimestepEncoder\"]\n\n if preprocess is None:\n preprocess = {\n \"res\": \"ResidualBlock\",\n \"downsample\": \"Downsample\",\n \"upsample\": \"Upsample\",\n }\n\n if include is not None:\n preprocess = {k: v for k, v in preprocess.items() if v in include}\n\n preprocess = {k: v for k, v in preprocess.items() if v not in exclude}\n\n loras_excluded = {k: v for k, v in unet_loras.items() if any(x in k for x in preprocess.keys())}\n loras_remaining = {k: v for k, v in unet_loras.items() if k not in loras_excluded}\n\n for exc_k, exc_v in preprocess.items():\n ls = {k: v for k, v in loras_excluded.items() if exc_k in k}\n auto_attach_loras(ls, self.unet, include=[exc_v], exclude=exclude, debug_map=debug_map)\n\n auto_attach_loras(\n loras_remaining,\n self.unet,\n exclude=[*exclude, *preprocess.values()],\n include=include,\n debug_map=debug_map,\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.get_loras_by_name","title":"get_loras_by_name","text":"get_loras_by_name(name: str) -> list[Lora[Any]]\n
Get the LoRA layers with the given name.
Parameters:
name (str, required): The name of the LoRA.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:\n \"\"\"Get the LoRA layers with the given name.\n\n Args:\n name: The name of the LoRA.\n \"\"\"\n return [lora for lora in self.loras if lora.name == name]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.get_scale","title":"get_scale","text":"get_scale(name: str) -> float\n
Get the scale of the LoRA with the given name.
Parameters:
name (str, required): The name of the LoRA.
Returns:
float: The scale of the LoRA layers with the given name.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def get_scale(self, name: str, /) -> float:\n \"\"\"Get the scale of the LoRA with the given name.\n\n Args:\n name: The name of the LoRA.\n\n Returns:\n The scale of the LoRA layers with the given name.\n \"\"\"\n loras = self.get_loras_by_name(name)\n assert all([lora.scale == loras[0].scale for lora in loras]), \"lora scales are not all the same\"\n return loras[0].scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.remove_all","title":"remove_all","text":"remove_all() -> None\n
Remove all the LoRAs from the target.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_all(self) -> None:\n \"\"\"Remove all the LoRAs from the target.\"\"\"\n for lora_adapter in self.lora_adapters:\n lora_adapter.eject()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.remove_loras","title":"remove_loras","text":"remove_loras(*names: str) -> None\n
Remove multiple LoRAs from the target. LoRA adapters left without any LoRA are ejected.
Parameters:
*names (str, default ()): The names of the LoRAs to remove.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def remove_loras(self, *names: str) -> None:\n \"\"\"Remove multiple LoRAs from the target.\n\n Args:\n names: The names of the LoRAs to remove.\n \"\"\"\n for lora_adapter in self.lora_adapters:\n for name in names:\n lora_adapter.remove_lora(name)\n\n if len(lora_adapter.loras) == 0:\n lora_adapter.eject()\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.set_scale","title":"set_scale","text":"set_scale(name: str, scale: float) -> None\n
Set the scale of the LoRA with the given name.
Parameters:
name (str, required): The name of the LoRA.
scale (float, required): The new scale to set.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def set_scale(self, name: str, scale: float, /) -> None:\n \"\"\"Set the scale of the LoRA with the given name.\n\n Args:\n name: The name of the LoRA.\n scale: The new scale to set.\n \"\"\"\n self.update_scales({name: scale})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.sort_keys","title":"sort_keys staticmethod
","text":"sort_keys(key: str) -> tuple[str, int]\n
Compute the score of a key relative to its suffix.
When used by sorted, the keys will only be sorted \"at the suffix level\". The idea is that closely related keys in the state dict are sometimes not in the expected order, for instance q -> k -> v or in -> out. This attempts to fix that issue, not cases where distant layers are called in a different order.
Parameters:
key (str, required): The key to sort.
Returns:
str: The padded prefix of the key.
int: A score depending on the key's suffix.
Source code in src/refiners/foundationals/latent_diffusion/lora.py
@staticmethod\ndef sort_keys(key: str, /) -> tuple[str, int]:\n \"\"\"Compute the score of a key, relatively to its suffix.\n\n When used by [`sorted`][sorted], the keys will only be sorted \"at the suffix level\".\n The idea is that sometimes closely related keys in the state dict are not in the\n same order as the one we expect, for instance `q -> k -> v` or `in -> out`. This\n attempts to fix that issue, not cases where distant layers are called in a different\n order.\n\n Args:\n key: The key to sort.\n\n Returns:\n The padded prefix of the key.\n A score depending on the key's suffix.\n \"\"\"\n\n # this dict might not be exhaustive\n suffix_scores = {\"q\": 1, \"k\": 2, \"v\": 3, \"in\": 3, \"out\": 4, \"out0\": 4, \"out_0\": 4}\n patterns = [\"_{}\", \"_{}_lora\"]\n\n # apply patterns to the keys of suffix_scores\n key_char_order = {f.format(k): v for k, v in suffix_scores.items() for f in patterns}\n\n # get the suffix and score for `key` (default: no suffix, highest score = 5)\n (sfx, score) = next(((k, v) for k, v in key_char_order.items() if key.endswith(k)), (\"\", 5))\n\n padded_key_prefix = SDLoraManager._pad(key.removesuffix(sfx))\n return (padded_key_prefix, score)\n
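Below is a standalone sketch of sort_keys that shows the combined effect of suffix scores and prefix padding. SDLoraManager._pad is not shown in this section, so the pad function below is a hypothetical stand-in. It zero-pads every run of digits so that numeric parts compare numerically.

```python
import re

SUFFIX_SCORES = {'q': 1, 'k': 2, 'v': 3, 'in': 3, 'out': 4, 'out0': 4, 'out_0': 4}
PATTERNS = ['_{}', '_{}_lora']
KEY_CHAR_ORDER = {f.format(k): v for k, v in SUFFIX_SCORES.items() for f in PATTERNS}


def pad(prefix: str) -> str:
    # Hypothetical stand-in for SDLoraManager._pad: zero-pad digit runs to width 4.
    return re.sub('[0-9]+', lambda m: m.group(0).zfill(4), prefix)


def sort_keys(key: str) -> tuple[str, int]:
    # The first matching suffix wins; keys without a known suffix get the highest score, 5.
    sfx, score = next(((k, v) for k, v in KEY_CHAR_ORDER.items() if key.endswith(k)), ('', 5))
    return pad(key.removesuffix(sfx)), score
```

Without padding, plain string comparison would place blk_10 before blk_2.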
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.lora.SDLoraManager.update_scales","title":"update_scales","text":"update_scales(scales: dict[str, float]) -> None\n
Update the scales of multiple LoRAs.
Parameters:
scales (dict[str, float], required): The scales to update (keys are the names of the LoRAs, values are the new scales to set).
Source code in src/refiners/foundationals/latent_diffusion/lora.py
def update_scales(self, scales: dict[str, float], /) -> None:\n \"\"\"Update the scales of multiple LoRAs.\n\n Args:\n scales: The scales to update.\n (keys are the names of the LoRAs, values are the new scales to set)\n \"\"\"\n assert all([name in self.names for name in scales]), f\"Scales keys must be a subset of {self.names}\"\n for name, scale in scales.items():\n for lora in self.get_loras_by_name(name):\n lora.scale = scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter","title":"IPAdapter","text":"IPAdapter(\n target: T,\n clip_image_encoder: CLIPImageEncoderH,\n image_proj: Module,\n scale: float = 1.0,\n fine_grained: bool = False,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: Generic[T], Chain, Adapter[T]
Image Prompt adapter for a Stable Diffusion U-Net model.
See [arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models for more details.
Parameters:
target (T, required): The target model to adapt.
clip_image_encoder (CLIPImageEncoderH, required): The CLIP image encoder to use.
image_proj (Module, required): The image projection to use.
scale (float, default 1.0): The scale to use for the image prompt.
fine_grained (bool, default False): Whether to use a fine-grained image prompt.
weights (dict[str, Tensor] | None, default None): The weights of the IPAdapter.
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def __init__(\n self,\n target: T,\n clip_image_encoder: CLIPImageEncoderH,\n image_proj: fl.Module,\n scale: float = 1.0,\n fine_grained: bool = False,\n weights: dict[str, Tensor] | None = None,\n) -> None:\n \"\"\"Initialize the adapter.\n\n Args:\n target: The target model to adapt.\n clip_image_encoder: The CLIP image encoder to use.\n image_proj: The image projection to use.\n scale: The scale to use for the image prompt.\n fine_grained: Whether to use fine-grained image prompt.\n weights: The weights of the IPAdapter.\n \"\"\"\n with self.setup_adapter(target):\n super().__init__(target)\n\n self.fine_grained = fine_grained\n self._clip_image_encoder = [clip_image_encoder]\n if fine_grained:\n self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]\n self._image_proj = [image_proj]\n\n self.sub_adapters = [\n CrossAttentionAdapter(target=cross_attn, scale=scale)\n for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))\n ]\n\n if weights is not None:\n image_proj_state_dict: dict[str, Tensor] = {\n k.removeprefix(\"image_proj.\"): v for k, v in weights.items() if k.startswith(\"image_proj.\")\n }\n self.image_proj.load_state_dict(image_proj_state_dict)\n\n for i, cross_attn in enumerate(self.sub_adapters):\n cross_attention_weights: list[Tensor] = []\n for k, v in weights.items():\n prefix = f\"ip_adapter.{i:03d}.\"\n if not k.startswith(prefix):\n continue\n cross_attention_weights.append(v)\n\n assert len(cross_attention_weights) == 2\n cross_attn.load_weights(*cross_attention_weights)\n
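The weight loading at the end of __init__ expects a flat state dict. Keys starting with image_proj. go to the image projection, with the prefix stripped. Each cross-attention sub-adapter i takes exactly two tensors whose keys start with ip_adapter.{i:03d}., in dict iteration order. Below is a plain-Python sketch of that split; split_ip_adapter_weights is a hypothetical helper, and strings stand in for tensors and made-up key suffixes.

```python
def split_ip_adapter_weights(weights: dict[str, object], num_cross_attns: int) -> tuple[dict[str, object], list[list[object]]]:
    # Keys under 'image_proj.' feed the image projection, with the prefix removed.
    image_proj_state_dict = {k.removeprefix('image_proj.'): v for k, v in weights.items() if k.startswith('image_proj.')}
    per_cross_attn = []
    for i in range(num_cross_attns):
        prefix = f'ip_adapter.{i:03d}.'
        # Values are collected in the dict's iteration order, as in IPAdapter.__init__.
        tensors = [v for k, v in weights.items() if k.startswith(prefix)]
        assert len(tensors) == 2, f'expected 2 tensors under {prefix}, got {len(tensors)}'
        per_cross_attn.append(tensors)
    return image_proj_state_dict, per_cross_attn
```

The two tensors are passed positionally to load_weights, so their order in the state dict matters.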
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.clip_image_encoder","title":"clip_image_encoder property
","text":"clip_image_encoder: CLIPImageEncoderH\n
The CLIP image encoder of the adapter.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.scale","title":"scale property writable
","text":"scale: float\n
The scale of the adapter.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.compute_clip_image_embedding","title":"compute_clip_image_embedding","text":"compute_clip_image_embedding(\n image_prompt: Image | list[Image] | Tensor,\n weights: list[float] | None = None,\n concat_batches: bool = True,\n) -> Tensor\n
Compute CLIP image embeddings from the provided image prompts.
Parameters:
Name Type Description Defaultimage_prompt
Image | list[Image] | Tensor
A single image or a list of images to compute embeddings for. This can be a PIL Image, a list of PIL Images, or a Tensor.
requiredweights
list[float] | None
An optional list of scaling factors for the conditional embeddings. If provided, it must have the same length as the number of images in image_prompt
. Each weight scales the corresponding image's conditional embedding, allowing you to adjust the influence of each image. Defaults to uniform weights of 1.0.
None
concat_batches
bool
Determines how embeddings are combined when multiple images are provided:
- If True, embeddings from multiple images are concatenated along the token (sequence) dimension to form a longer sequence of image tokens. This is useful when you want to treat multiple images as a single combined input.
- If False, embeddings are kept separate along the batch dimension, treating each image independently.
True
Returns:
Type Description
Tensor
A Tensor containing the negative embeddings followed by the conditional embeddings, concatenated along the batch dimension. Its structure depends on the concat_batches parameter:
- If concat_batches is True and multiple images are provided, the embeddings of all images are concatenated along the token (sequence) dimension.
- If concat_batches is False or a single image is provided, the embeddings are returned as a batch, with one embedding per image.
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def compute_clip_image_embedding(\n self,\n image_prompt: Image.Image | list[Image.Image] | Tensor,\n weights: list[float] | None = None,\n concat_batches: bool = True,\n) -> Tensor:\n \"\"\"Compute CLIP image embeddings from the provided image prompts.\n\n Args:\n image_prompt: A single image or a list of images to compute embeddings for.\n This can be a PIL Image, a list of PIL Images, or a Tensor.\n weights: An optional list of scaling factors for the conditional embeddings.\n If provided, it must have the same length as the number of images in `image_prompt`.\n Each weight scales the corresponding image's conditional embedding, allowing you to\n adjust the influence of each image. Defaults to uniform weights of 1.0.\n concat_batches: Determines how embeddings are concatenated when multiple images are provided:\n - If `True`, embeddings from multiple images are concatenated along the feature\n dimension to form a longer sequence of image tokens. This is useful when you want to\n treat multiple images as a single combined input.\n - If `False`, embeddings are kept separate along the batch dimension, treating each image\n independently.\n\n Returns:\n A Tensor containing the CLIP image embeddings.\n The structure of the returned Tensor depends on the `concat_batches` parameter:\n - If `concat_batches` is `True` and multiple images are provided, the embeddings are\n concatenated along the feature dimension.\n - If `concat_batches` is `False` or a single image is provided, the embeddings are returned\n as a batch, with one embedding per image.\n \"\"\"\n if isinstance(image_prompt, Image.Image):\n image_prompt = self.preprocess_image(image_prompt)\n elif isinstance(image_prompt, list):\n assert all(\n isinstance(image, Image.Image) for image in image_prompt\n ), \"All elements of `image_prompt` must be of PIL Images.\"\n image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])\n\n negative_embedding, conditional_embedding = 
self._compute_clip_image_embedding(image_prompt)\n\n batch_size = image_prompt.shape[0]\n if weights is not None:\n assert len(weights) == batch_size, f\"Got {len(weights)} weights for {batch_size} images\"\n if any(weight != 1.0 for weight in weights):\n conditional_embedding *= (\n torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)\n .unsqueeze(-1)\n .unsqueeze(-1)\n )\n\n if batch_size > 1 and concat_batches:\n # Create a longer image tokens sequence when a batch of images is given\n # See https://github.com/tencent-ailab/IP-Adapter/issues/99\n negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)\n conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)\n\n return torch.cat((negative_embedding, conditional_embedding))\n
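To make the effect of weights and concat_batches concrete, here is a minimal pure-Python sketch of the tail of compute_clip_image_embedding, with nested lists indexed [batch][token][feature] standing in for tensors. The helper names are hypothetical, and the sketch assumes the negative and conditional embeddings have already been computed.

```python
from __future__ import annotations

# Nested lists [batch][token][feature] stand in for tensors (sketch only).
Embedding = list[list[list[float]]]


def shape(emb: Embedding) -> tuple[int, int, int]:
    return len(emb), len(emb[0]), len(emb[0][0])


def combine_clip_embeddings(
    negative: Embedding,
    conditional: Embedding,
    weights: list[float] | None = None,
    concat_batches: bool = True,
) -> Embedding:
    batch_size = len(conditional)
    if weights is not None:
        assert len(weights) == batch_size, f'Got {len(weights)} weights for {batch_size} images'
        # One weight per image, broadcast over its tokens and features.
        conditional = [
            [[w * x for x in token] for token in item] for w, item in zip(weights, conditional)
        ]
    if batch_size > 1 and concat_batches:
        # Like torch.cat(emb.chunk(batch_size), dim=1): one longer token sequence.
        negative = [[token for item in negative for token in item]]
        conditional = [[token for item in conditional for token in item]]
    # Final concatenation along the batch dimension: negative first, then conditional.
    return negative + conditional


neg = [[[0.0] * 3 for _ in range(4)] for _ in range(2)]  # 2 images, 4 tokens, 3 features
cond = [[[1.0] * 3 for _ in range(4)] for _ in range(2)]
print(shape(combine_clip_embeddings(neg, cond)))  # (2, 8, 3)
print(shape(combine_clip_embeddings(neg, cond, concat_batches=False)))  # (4, 4, 3)
```

With concat_batches=True the two images end up as a single row of 8 image tokens per branch; with concat_batches=False each image keeps its own batch row.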
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.preprocess_image","title":"preprocess_image","text":"preprocess_image(\n image: Image,\n size: tuple[int, int] = (224, 224),\n mean: list[float] | None = None,\n std: list[float] | None = None,\n) -> Tensor\n
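Preprocess the image: resize it to size, convert it to a tensor on the target model's device and dtype, and normalize it with mean and std.
Note: The default mean and std are the normalization parameters from https://github.com/openai/CLIP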
Parameters:
Name Type Description Default
image
Image
The image to preprocess.
required
size
tuple[int, int]
The size to resize the image to.
(224, 224)
mean
list[float] | None
The mean to use for normalization.
None
std
list[float] | None
The standard deviation to use for normalization.
None
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def preprocess_image(\n self,\n image: Image.Image,\n size: tuple[int, int] = (224, 224),\n mean: list[float] | None = None,\n std: list[float] | None = None,\n) -> Tensor:\n \"\"\"Preprocess the image.\n\n Note:\n The default mean and std are parameters from\n https://github.com/openai/CLIP\n\n Args:\n image: The image to preprocess.\n size: The size to resize the image to.\n mean: The mean to use for normalization.\n std: The standard deviation to use for normalization.\n \"\"\"\n resized = image.resize(size) # type: ignore\n return normalize(\n image_to_tensor(resized, device=self.target.device, dtype=self.target.dtype),\n mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,\n std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,\n )\n
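The normalization step is a per-channel standardization, (value - mean) / std, using the CLIP statistics above. The sketch below applies it to a single RGB pixel. It assumes the channel values are already scaled to [0, 1] before normalization (an assumption about image_to_tensor), and normalize_rgb is a hypothetical helper.

```python
# Default CLIP normalization statistics, as used by preprocess_image.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def normalize_rgb(rgb, mean=None, std=None):
    # rgb: one pixel's (R, G, B) values, assumed already scaled to [0, 1].
    mean = CLIP_MEAN if mean is None else mean
    std = CLIP_STD if std is None else std
    # Per-channel standardization: (value - mean) / std.
    return [(c - m) / s for c, m, s in zip(rgb, mean, std)]


print(normalize_rgb(CLIP_MEAN))  # [0.0, 0.0, 0.0]
print(normalize_rgb([1.0, 1.0, 1.0]))  # a white pixel lands around +1.9 to +2.2 per channel
```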
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.image_prompt.IPAdapter.set_clip_image_embedding","title":"set_clip_image_embedding","text":"set_clip_image_embedding(image_embedding: Tensor) -> None\n
Set the CLIP image embedding context.
Note: This is required by ImageCrossAttention.
Parameters:
Name Type Description Default
image_embedding
Tensor
The CLIP image embedding to set.
required
Source code in src/refiners/foundationals/latent_diffusion/image_prompt.py
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:\n \"\"\"Set the CLIP image embedding context.\n\n Note:\n This is required by `ImageCrossAttention`.\n\n Args:\n image_embedding: The CLIP image embedding to set.\n \"\"\"\n self.set_context(\"ip_adapter\", {\"clip_image_embedding\": image_embedding})\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.AdaIN","title":"AdaIN","text":"AdaIN(epsilon: float = 1e-08)\n
Bases: Module
Apply Adaptive Instance Normalization (AdaIN) to the target features.
See [arXiv:1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization for more details.
Receives:
Name Type Description
reference
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The reference features.
targets
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The target features.
Returns:
Name Type Description
reference
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The reference features (unchanged).
targets
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The target features, renormalized.
Parameters:
Name Type Description Default
epsilon
float
A small value to avoid division by zero.
1e-08
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(self, epsilon: float = 1e-8) -> None:\n \"\"\"Initialize the AdaIN module.\n\n Args:\n epsilon: A small value to avoid division by zero.\n \"\"\"\n super().__init__()\n self.epsilon = epsilon\n
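For reference, the AdaIN formula from the cited paper renormalizes the target so that it takes on the reference's mean and standard deviation. The sketch below applies it to a single channel given as a list of floats. Computing the statistics over the sequence dimension, using the population standard deviation, and adding epsilon to the standard deviation are assumptions about how this module applies the formula, not a copy of its forward pass.

```python
from statistics import fmean, pstdev


def adain(reference: list[float], target: list[float], epsilon: float = 1e-8) -> list[float]:
    # Statistics of one channel over the sequence dimension.
    t_mean, t_std = fmean(target), pstdev(target)
    r_mean, r_std = fmean(reference), pstdev(reference)
    # AdaIN(t, r) = std(r) * (t - mean(t)) / (std(t) + eps) + mean(r)
    return [r_std * (x - t_mean) / (t_std + epsilon) + r_mean for x in target]


out = adain(reference=[10.0, 20.0, 30.0], target=[1.0, 2.0, 3.0])
print([round(v, 4) for v in out])  # [10.0, 20.0, 30.0]
print(adain(reference=[1.0, 2.0, 3.0], target=[5.0, 5.0, 5.0]))  # [2.0, 2.0, 2.0]
```

The second call shows the role of epsilon: a constant target has zero spread, and epsilon keeps the division finite, so the output simply collapses to the reference mean.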
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.ExtractReferenceFeatures","title":"ExtractReferenceFeatures","text":"ExtractReferenceFeatures(*args: Any, **kwargs: Any)\n
Bases: Module
Extract the reference features from the input features.
Note: This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using classifier-free guidance (CFG).
The reference features are the first item of the conditional half and the first item of the unconditional half of the input features. They are extracted and repeated to match the batch size of the input features.
Receives:
Name Type Description
features
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The input features.
Returns:
Name Type Description
reference
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The reference features.
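A minimal list-based sketch of the extraction described above, with strings standing in for per-sample feature tensors; the helper name is hypothetical. It only assumes the batch splits into two equal halves. Which half is the conditional one does not matter here, since both halves are treated the same way.

```python
def extract_reference_features(features: list) -> list:
    # features: a CFG batch made of two equal halves, one per guidance branch.
    assert len(features) % 2 == 0, 'expected two equal CFG halves'
    half = len(features) // 2
    first, second = features[:half], features[half:]
    # The first item of each half is the reference; repeat it across its half.
    return [first[0]] * half + [second[0]] * half


print(extract_reference_features(['c0', 'c1', 'c2', 'u0', 'u1', 'u2']))
# ['c0', 'c0', 'c0', 'u0', 'u0', 'u0']
```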
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:\n super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.ScaleReferenceFeatures","title":"ScaleReferenceFeatures","text":"ScaleReferenceFeatures(scale: float = 1.0)\n
Bases: Module
Scale the reference features.
Note: This layer expects the input features to be a concatenation of conditional and unconditional features, as done when using classifier-free guidance (CFG).
This layer scales the reference features, which are later used (in the attention dot product) with the target features.
Receives:
Name Type Description
features
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The input reference features.
Returns:
Name Type Description
features
Float[Tensor, 'cfg_batch_size sequence_length embedding_dim']
The rescaled reference features.
Parameters:
Name Type Description Default
scale
float
The scaling factor.
1.0
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n self,\n scale: float = 1.0,\n) -> None:\n \"\"\"Initialize the ScaleReferenceFeatures module.\n\n Args:\n scale: The scaling factor.\n \"\"\"\n super().__init__()\n self.scale = scale\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.SharedSelfAttentionAdapter","title":"SharedSelfAttentionAdapter","text":"SharedSelfAttentionAdapter(\n target: SelfAttention, scale: float = 1.0\n)\n
Bases: Chain, Adapter[SelfAttention]
Upgrades a SelfAttention layer into a SharedSelfAttention layer.
This adapter inserts three StyleAligned modules right after the original Q, K, V Linear layers (wrapped inside a fl.Distribute).
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n self,\n target: fl.SelfAttention,\n scale: float = 1.0,\n) -> None:\n with self.setup_adapter(target):\n super().__init__(target)\n\n self._style_aligned_layers = [\n StyleAligned( # Query\n adain=True,\n concatenate=False,\n scale=scale,\n ),\n StyleAligned( # Key\n adain=True,\n concatenate=True,\n scale=scale,\n ),\n StyleAligned( # Value\n adain=False,\n concatenate=True,\n scale=scale,\n ),\n ]\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAligned","title":"StyleAligned","text":"StyleAligned(\n adain: bool, concatenate: bool, scale: float = 1.0\n)\n
Bases: Chain
StyleAligned module.
This layer encapsulates the logic of the StyleAligned method, as described in [arXiv:2312.02133] Style Aligned Image Generation via Shared Attention.
See also https://blog.finegrain.ai/posts/implementing-style-aligned/.
Receives:
Name Type Description
features
Float[Tensor, 'cfg_batch_size sequence_length_in embedding_dim']
The input features.
Returns:
Name Type Description
shared_features
Float[Tensor, 'cfg_batch_size sequence_length_out embedding_dim']
The transformed features.
Parameters:
Name Type Description Default
adain
bool
Whether to apply Adaptive Instance Normalization to the target features.
required
scale
float
The scaling factor for the reference features.
1.0
concatenate
bool
Whether to concatenate the reference and target features.
required
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n self,\n adain: bool,\n concatenate: bool,\n scale: float = 1.0,\n) -> None:\n \"\"\"Initialize the StyleAligned module.\n\n Args:\n adain: Whether to apply Adaptive Instance Normalization to the target features.\n scale: The scaling factor for the reference features.\n concatenate: Whether to concatenate the reference and target features.\n \"\"\"\n super().__init__(\n # (features): (cfg_batch_size sequence_length embedding_dim)\n fl.Parallel(\n fl.Identity(),\n ExtractReferenceFeatures(),\n ),\n # (targets, reference)\n AdaIN(),\n # (targets_renormalized, reference)\n fl.Distribute(\n fl.Identity(),\n ScaleReferenceFeatures(scale=scale),\n ),\n # (targets_renormalized, reference_scaled)\n fl.Concatenate(\n fl.GetArg(index=0), # targets\n fl.GetArg(index=1), # reference\n dim=-2, # sequence_length\n ),\n # (features_with_shared_reference)\n )\n\n if not adain:\n adain_module = self.ensure_find(AdaIN)\n self.remove(adain_module)\n\n if not concatenate:\n concatenate_module = self.ensure_find(fl.Concatenate)\n self.replace(\n old_module=concatenate_module,\n new_module=fl.GetArg(index=0), # targets\n )\n
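The adain and concatenate flags only reshape the chain built above. Without adain, the AdaIN stage is removed; without concatenate, the final Concatenate is replaced by GetArg(index=0). The sketch below tracks this with stage labels (plain strings, not refiners modules) and the resulting output sequence length, using the query, key and value settings from SharedSelfAttentionAdapter. Both helper names are hypothetical.

```python
def style_aligned_stages(adain: bool, concatenate: bool) -> list[str]:
    # Stage labels mirror the Chain built in StyleAligned.__init__.
    stages = [
        'Parallel(Identity, ExtractReferenceFeatures)',  # -> (targets, reference)
        'AdaIN',  # renormalize targets with reference statistics
        'Distribute(Identity, ScaleReferenceFeatures)',  # scale the reference
        'Concatenate(targets, reference, dim=-2)',  # join along sequence_length
    ]
    if not adain:
        stages.remove('AdaIN')
    if not concatenate:
        stages[-1] = 'GetArg(index=0)'  # keep the targets only
    return stages


def sequence_length_out(sequence_length_in: int, concatenate: bool) -> int:
    # The reference has the same shape as the targets, so concatenation doubles the length.
    return 2 * sequence_length_in if concatenate else sequence_length_in


# Flags used by SharedSelfAttentionAdapter for the Q, K and V projections.
for name, use_adain, use_concat in [('query', True, False), ('key', True, True), ('value', False, True)]:
    print(name, style_aligned_stages(use_adain, use_concat), sequence_length_out(64, use_concat))
```

Only the keys and values are extended with reference tokens. Each image can therefore attend to the reference, while the query, and hence the attention output, keeps its original sequence length.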
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAligned.scale","title":"scale property
writable
","text":"scale: float\n
The scaling factor for the reference features.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAlignedAdapter","title":"StyleAlignedAdapter","text":"StyleAlignedAdapter(target: T, scale: float = 1.0)\n
Bases: Generic[T], Chain, Adapter[T]
Upgrades each SelfAttention layer of a UNet into a SharedSelfAttention layer.
Parameters:
Name Type Description Default
target
T
The target module.
required
scale
float
The scaling factor for the reference features.
1.0
Source code in src/refiners/foundationals/latent_diffusion/style_aligned.py
def __init__(\n self,\n target: T,\n scale: float = 1.0,\n) -> None:\n \"\"\"Initialize the StyleAlignedAdapter.\n\n Args:\n target: The target module.\n scale: The scaling factor for the reference features.\n \"\"\"\n with self.setup_adapter(target):\n super().__init__(target)\n\n # create a SharedSelfAttentionAdapter for each SelfAttention module\n self.shared_self_attention_adapters = tuple(\n SharedSelfAttentionAdapter(\n target=self_attention,\n scale=scale,\n )\n for self_attention in self.target.layers(fl.SelfAttention)\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAlignedAdapter.scale","title":"scale property
writable
","text":"scale: float\n
The scaling factor for the reference features.
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.multi_diffusion.DiffusionTarget","title":"DiffusionTarget dataclass
","text":"DiffusionTarget(\n *,\n tile: Tile,\n solver: Solver,\n init_latents: Tensor | None = None,\n opacity_mask: Tensor | None = None,\n weight: int = 1,\n start_step: int = 0,\n end_step: int = MAX_STEPS\n)\n
Represents a target for the tiled diffusion process.
This class encapsulates the parameters and properties needed to define a specific area (target) within a larger diffusion process, allowing for fine-grained control over different regions of the generated image.
Attributes:
Name Type Description
tile
Tile
The tile defining the area of the target within the latent image.
solver
Solver
The solver to use for this target's diffusion process. Each target should have its own solver instance: some solvers keep an internal state that is updated during the diffusion process, and sharing one instance across multiple targets would interfere with that state.
init_latents
Tensor | None
The initial latents for this target. If None, the target will be initialized with noise.
opacity_mask
Tensor | None
Mask controlling the target's visibility in the final image. If None, the target will be fully visible. Otherwise, 1 means fully opaque and 0 means fully transparent, in which case the target has no influence.
weight
int
The importance of this target in the final image. Higher values increase the target's influence.
start_step
int
The diffusion step at which this target begins to influence the process.
end_step
int
The diffusion step at which this target stops influencing the process.
size
Size
The size of the target area.
offset
tuple[int, int]
The top-left offset of the target area within the latent image.
The combination of opacity_mask and weight determines the target's overall contribution to the final generated image. The solver is responsible for the actual diffusion calculations for this target.
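MultiDiffusion merges overlapping tiles by averaging, at each latent position, the predictions of all targets that cover it. The sketch below shows such a per-position weighted average on a 1-D latent. Treating each target's contribution as opacity_mask times weight is an assumption about how the two are combined, and blend_targets is a hypothetical helper, not part of the library.

```python
def blend_targets(
    length: int, targets: list[tuple[int, list[float], list[float], int]]
) -> list[float]:
    # Each target: (offset, values, opacity, weight); values and opacity span its tile.
    acc = [0.0] * length
    total = [0.0] * length
    for offset, values, opacity, weight in targets:
        for i, (value, alpha) in enumerate(zip(values, opacity)):
            contribution = alpha * weight  # assumed: opacity times weight
            acc[offset + i] += value * contribution
            total[offset + i] += contribution
    # Weighted average per position; positions with no visible target stay at 0.0.
    return [a / t if t > 0 else 0.0 for a, t in zip(acc, total)]


tile_a = (0, [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], 1)
tile_b = (1, [3.0, 3.0, 3.0], [1.0, 1.0, 1.0], 3)
print(blend_targets(4, [tile_a, tile_b]))  # [1.0, 2.5, 2.5, 3.0]
```

Where the two tiles overlap, the target with weight 3 pulls the average toward its own value; a zero opacity removes a target's influence entirely.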
Bases: ABC, Generic[T]
MultiDiffusion class for performing multi-target diffusion using tiled diffusion.
For more details, refer to the paper: MultiDiffusion
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.multi_diffusion.MultiDiffusion.generate_latent_tiles","title":"generate_latent_tiles staticmethod
","text":"generate_latent_tiles(\n size: Size, tile_size: Size, min_overlap: int = 8\n) -> list[Tile]\n
Generate tiles for a latent image with the given size and tile size.
If one dimension of tile_size is larger than the corresponding dimension of the image size, a single tile covering the entire image is returned and tile_size is ignored. Otherwise, every tile has exactly tile_size, and the overlap between neighboring tiles is adjusted so that the tiles cover the entire image while overlapping by at least min_overlap.
Source code in src/refiners/foundationals/latent_diffusion/multi_diffusion.py
@staticmethod\ndef generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:\n \"\"\"\n Generate tiles for a latent image with the given size and tile size.\n\n If one dimension of the `tile_size` is larger than the corresponding dimension of the image size, a single tile is\n used to cover the entire image - and therefore `tile_size` is ignored. This algorithm ensures that the tile size\n is respected as much as possible, while still covering the entire image and respecting the minimum overlap.\n \"\"\"\n assert (\n 0 <= min_overlap < min(tile_size.height, tile_size.width)\n ), \"Overlap must be non-negative and less than the tile size\"\n\n if tile_size.width > size.width or tile_size.height > size.height:\n return [Tile(top=0, left=0, bottom=size.height, right=size.width)]\n\n tiles: list[Tile] = []\n\n def _compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:\n if tile_length >= length:\n return 1, 0\n num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1\n overlap = (num_tiles * tile_length - length) // (num_tiles - 1)\n return num_tiles, overlap\n\n num_tiles_x, overlap_x = _compute_tiles_and_overlap(\n length=size.width, tile_length=tile_size.width, min_overlap=min_overlap\n )\n num_tiles_y, overlap_y = _compute_tiles_and_overlap(\n length=size.height, tile_length=tile_size.height, min_overlap=min_overlap\n )\n\n for i in range(num_tiles_y):\n for j in range(num_tiles_x):\n x = j * (tile_size.width - overlap_x)\n y = i * (tile_size.height - overlap_y)\n\n # Adjust x and y coordinates to ensure full-sized tiles\n if x + tile_size.width > size.width:\n x = size.width - tile_size.width\n if y + tile_size.height > size.height:\n y = size.height - tile_size.height\n\n tile_right = x + tile_size.width\n tile_bottom = y + tile_size.height\n tiles.append(Tile(top=y, left=x, bottom=tile_bottom, right=tile_right))\n\n return tiles\n
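The per-axis arithmetic of generate_latent_tiles can be checked by hand with the small re-implementation below. It uses plain tuples instead of the library's Size and Tile types, and the helper names are hypothetical.

```python
import math


def axis_offsets(length: int, tile_length: int, min_overlap: int = 8) -> list[int]:
    # Same arithmetic as _compute_tiles_and_overlap in generate_latent_tiles.
    if tile_length >= length:
        return [0]
    num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
    overlap = (num_tiles * tile_length - length) // (num_tiles - 1)
    # Step by (tile_length - overlap); clamp so the last tile ends at the border.
    return [min(i * (tile_length - overlap), length - tile_length) for i in range(num_tiles)]


def latent_tiles(
    height: int, width: int, tile_height: int, tile_width: int, min_overlap: int = 8
) -> list[tuple[int, int, int, int]]:
    # Returns (top, left, bottom, right) tuples in row-major order.
    assert min(tile_height, tile_width) > min_overlap >= 0
    if tile_width > width or tile_height > height:
        return [(0, 0, height, width)]  # one tile covering the whole latent
    ys = axis_offsets(height, tile_height, min_overlap)
    xs = axis_offsets(width, tile_width, min_overlap)
    return [(y, x, y + tile_height, x + tile_width) for y in ys for x in xs]


print(axis_offsets(128, 64))  # [0, 32, 64]
print(len(latent_tiles(128, 128, 64, 64)))  # 9
```

For a 128 x 128 latent and 64 x 64 tiles, three tiles per axis are needed. The computed overlap is 32, well above the default minimum of 8, which gives nine tiles in total.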
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.ella_adapter.ELLA","title":"ELLA","text":"ELLA(\n time_channel: int,\n timestep_embedding_dim: int,\n width: int,\n num_layers: int,\n num_heads: int,\n num_latents: int,\n input_dim: int | None = None,\n out_dim: int | None = None,\n device: device | str | None = None,\n dtype: dtype | None = None,\n)\n
Bases: Passthrough
ELLA latents encoder.
See [arXiv:2403.05135] ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment for more details.
Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(\n self,\n time_channel: int,\n timestep_embedding_dim: int,\n width: int,\n num_layers: int,\n num_heads: int,\n num_latents: int,\n input_dim: int | None = None,\n out_dim: int | None = None,\n device: Device | str | None = None,\n dtype: DType | None = None,\n) -> None:\n super().__init__(\n TimestepEncoder(timestep_embedding_dim, time_channel, device=device, dtype=dtype),\n fl.UseContext(\"adapted_cross_attention_block\", \"llm_text_embedding\"),\n PerceiverResampler(\n timestep_embedding_dim,\n width,\n num_layers,\n num_heads,\n num_latents,\n out_dim,\n input_dim,\n device=device,\n dtype=dtype,\n ),\n fl.SetContext(\"ella\", \"latents\"),\n )\n
"},{"location":"reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.ella_adapter.ELLAAdapter","title":"ELLAAdapter","text":"ELLAAdapter(\n target: T,\n latents_encoder: ELLA,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: Generic[T], Chain, Adapter[T]
Adapter for ELLA.
Source code in src/refiners/foundationals/latent_diffusion/ella_adapter.py
def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None:\n if weights is not None:\n latents_encoder.load_state_dict(weights)\n\n self._latents_encoder = [latents_encoder]\n with self.setup_adapter(target):\n super().__init__(target)\n self.sub_adapters = [\n ELLACrossAttentionAdapter(use_context)\n for cross_attn in target.layers(CrossAttentionBlock)\n for use_context in cross_attn.layers(fl.UseContext)\n ]\n
"},{"location":"reference/foundationals/segment_anything/","title":"
Segment Anything","text":""},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.HQSAMAdapter","title":"HQSAMAdapter","text":"HQSAMAdapter(\n target: SegmentAnything,\n hq_mask_only: bool = False,\n weights: dict[str, Tensor] | None = None,\n)\n
Bases: Chain, Adapter[SegmentAnything]
Adapter for SAM introducing HQ features.
See [arXiv:2306.01567] Segment Anything in High Quality for details.
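Example
from refiners.fluxion.utils import load_from_safetensors\nfrom refiners.foundationals.segment_anything import HQSAMAdapter, SegmentAnythingH\n\n# HQSAMAdapter does not support multi-mask mode\nsam_h = SegmentAnythingH(multimask_output=False)\n# load the base SAM weights into sam_h before adapting it\n\n# Tip: run scripts/prepare_test_weights.py to download the weights\ntensor_path = \"./tests/weights/refiners-sam-hq-vit-h.safetensors\"\nweights = load_from_safetensors(tensor_path)\n\nhq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)\nhq_sam_adapter.inject()  # then use SAM as usual\n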
Parameters:
Name Type Description Default
target
SegmentAnything
The SegmentAnything model to adapt.
required
hq_mask_only
bool
Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
False
weights
dict[str, Tensor] | None
The weights of the HQSAMAdapter.
None
Source code in src/refiners/foundationals/segment_anything/hq_sam.py
def __init__(\n self,\n target: SegmentAnything,\n hq_mask_only: bool = False,\n weights: dict[str, torch.Tensor] | None = None,\n) -> None:\n \"\"\"Initialize the adapter.\n\n Args:\n target: The SegmentAnything model to adapt.\n hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).\n weights: The weights of the HQSAMAdapter.\n \"\"\"\n self.vit_embedding_dim = target.image_encoder.embedding_dim\n self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2\n\n with self.setup_adapter(target):\n super().__init__(target)\n\n if target.mask_decoder.multimask_output:\n raise NotImplementedError(\"Multi-mask mode is not supported in HQSAMAdapter.\")\n\n mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)\n\n self._mask_prediction_adapter = [\n MaskPredictionAdapter(\n mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype\n )\n ]\n self._register_adapter_module(\"Chain.HQSAMMaskPrediction\", self.mask_prediction_adapter.hq_sam_mask_prediction)\n\n self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]\n self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]\n\n mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)\n self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]\n self._register_adapter_module(\"MaskDecoderTokensExtender.hq_token\", self.mask_decoder_tokens_extender.hq_token)\n\n if weights is not None:\n self.load_weights(weights)\n\n self.to(device=target.device, dtype=target.dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything","title":"SegmentAnything","text":"SegmentAnything(\n image_encoder: SAMViT,\n point_encoder: PointEncoder,\n mask_encoder: MaskEncoder,\n mask_decoder: MaskDecoder,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: Chain
SegmentAnything model.
See [arXiv:2304.02643] Segment Anything for more details.
See SegmentAnythingH for example usage.
Attributes:
Name Type Description
mask_threshold
float
Threshold applied to the predicted masks when predict is called with binarize=True. Defaults to 0.0.
Parameters:
Name Type Description Default
image_encoder
SAMViT
The image encoder to use.
required
point_encoder
PointEncoder
The point encoder to use.
required
mask_encoder
MaskEncoder
The mask encoder to use.
required
mask_decoder
MaskDecoder
The mask decoder to use.
required
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(\n self,\n image_encoder: SAMViT,\n point_encoder: PointEncoder,\n mask_encoder: MaskEncoder,\n mask_decoder: MaskDecoder,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n \"\"\"Initialize SegmentAnything model.\n\n Args:\n image_encoder: The image encoder to use.\n point_encoder: The point encoder to use.\n mask_encoder: The mask encoder to use.\n mask_decoder: The mask decoder to use.\n \"\"\"\n super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)\n\n self.to(device=device, dtype=dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.image_encoder","title":"image_encoder property
","text":"image_encoder: SAMViT\n
The image encoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.image_encoder_resolution","title":"image_encoder_resolution property
","text":"image_encoder_resolution: int\n
The resolution of the image encoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.mask_decoder","title":"mask_decoder property
","text":"mask_decoder: MaskDecoder\n
The mask decoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.mask_encoder","title":"mask_encoder property
","text":"mask_encoder: MaskEncoder\n
The mask encoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.point_encoder","title":"point_encoder property
","text":"point_encoder: PointEncoder\n
The point encoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.compute_image_embedding","title":"compute_image_embedding","text":"compute_image_embedding(image: Image) -> ImageEmbedding\n
Compute the embedding of an image.
Parameters:
Name Type Description Default
image
Image
The image to compute the embedding of.
required
Returns:
Type Description
ImageEmbedding
The computed image embedding.
Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()\ndef compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:\n \"\"\"Compute the embedding of an image.\n\n Args:\n image: The image to compute the embedding of.\n\n Returns:\n The computed image embedding.\n \"\"\"\n original_size = (image.height, image.width)\n return ImageEmbedding(\n features=self.image_encoder(self.preprocess_image(image)),\n original_image_size=original_size,\n )\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.normalize","title":"normalize","text":"normalize(\n coordinates: Tensor, original_size: tuple[int, int]\n) -> Tensor\n
See normalize_coordinates.
Parameters:
coordinates: a tensor of coordinates.
original_size: (h, w), the original size of the image.
Returns:
The [0, 1] normalized coordinates tensor.
Source code in src/refiners/foundationals/segment_anything/model.py
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:\n \"\"\"\n See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]\n Args:\n coordinates: a tensor of coordinates.\n original_size: (h, w) the original size of the image.\n Returns:\n The [0,1] normalized coordinates tensor.\n \"\"\"\n return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.postprocess_masks","title":"postprocess_masks","text":"postprocess_masks(\n low_res_masks: Tensor, original_size: tuple[int, int]\n) -> Tensor\n
See postprocess_masks.
Parameters:
low_res_masks: a mask tensor of shape (N, 1, 256, 256).
original_size: (h, w), the original size of the image.
Returns:
The mask, of shape (N, 1, H, W).
Source code in src/refiners/foundationals/segment_anything/model.py
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:\n \"\"\"\n See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]\n Args:\n low_res_masks: a mask tensor of size (N, 1, 256, 256)\n original_size: (h, w) the original size of the image.\n Returns:\n The mask of shape (N, 1, H, W)\n \"\"\"\n return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.predict","title":"predict","text":"predict(\n input: Image | ImageEmbedding,\n foreground_points: (\n Sequence[tuple[float, float]] | None\n ) = None,\n background_points: (\n Sequence[tuple[float, float]] | None\n ) = None,\n box_points: (\n Sequence[Sequence[tuple[float, float]]] | None\n ) = None,\n low_res_mask: (\n Float[Tensor, \"1 1 256 256\"] | None\n ) = None,\n binarize: bool = True,\n) -> tuple[Tensor, Tensor, Tensor]\n
Predict the masks of the input image.
Parameters:
Name Type Description Default
input
Image | ImageEmbedding
The input image or its embedding.
required
foreground_points
Sequence[tuple[float, float]] | None
The points of the foreground.
None
background_points
Sequence[tuple[float, float]] | None
The points of the background.
None
box_points
Sequence[Sequence[tuple[float, float]]] | None
The points of the box.
None
low_res_mask
Float[Tensor, '1 1 256 256'] | None
The low resolution mask.
None
binarize
bool
Whether to binarize the masks.
True
Returns:
Type Description
Tensor
The predicted masks.
Tensor
The IOU prediction.
Tensor
The low resolution masks.
Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()\ndef predict(\n self,\n input: Image.Image | ImageEmbedding,\n foreground_points: Sequence[tuple[float, float]] | None = None,\n background_points: Sequence[tuple[float, float]] | None = None,\n box_points: Sequence[Sequence[tuple[float, float]]] | None = None,\n low_res_mask: Float[Tensor, \"1 1 256 256\"] | None = None,\n binarize: bool = True,\n) -> tuple[Tensor, Tensor, Tensor]:\n \"\"\"Predict the masks of the input image.\n\n Args:\n input: The input image or its embedding.\n foreground_points: The points of the foreground.\n background_points: The points of the background.\n box_points: The points of the box.\n low_res_mask: The low resolution mask.\n binarize: Whether to binarize the masks.\n\n Returns:\n The predicted masks.\n The IOU prediction.\n The low resolution masks.\n \"\"\"\n if isinstance(input, ImageEmbedding):\n original_size = input.original_image_size\n image_embedding = input.features\n else:\n original_size = (input.height, input.width)\n image_embedding = self.image_encoder(self.preprocess_image(input))\n\n coordinates, type_mask = self.point_encoder.points_to_tensor(\n foreground_points=foreground_points,\n background_points=background_points,\n box_points=box_points,\n )\n self.point_encoder.set_type_mask(type_mask=type_mask)\n\n if low_res_mask is not None:\n mask_embedding = self.mask_encoder(low_res_mask)\n else:\n mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(\n image_embedding_size=self.image_encoder.image_embedding_size\n )\n\n point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))\n dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(\n image_embedding_size=self.image_encoder.image_embedding_size\n )\n\n self.mask_decoder.set_image_embedding(image_embedding=image_embedding)\n self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)\n self.mask_decoder.set_point_embedding(point_embedding=point_embedding)\n 
self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)\n\n low_res_masks, iou_predictions = self.mask_decoder()\n\n high_res_masks = self.postprocess_masks(low_res_masks, original_size)\n\n if binarize:\n high_res_masks = high_res_masks > self.mask_threshold\n\n return high_res_masks, iou_predictions, low_res_masks\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnything.preprocess_image","title":"preprocess_image","text":"preprocess_image(image: Image) -> Tensor\n
See preprocess_image
Args:
image: The image to preprocess.
Returns: The preprocessed tensor.
Source code in src/refiners/foundationals/segment_anything/model.py
def preprocess_image(self, image: Image.Image) -> Tensor:\n \"\"\"\n See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]\n Args:\n image: The image to preprocess.\n Returns:\n The preprocessed tensor.\n \"\"\"\n return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnythingH","title":"SegmentAnythingH","text":"SegmentAnythingH(\n image_encoder: SAMViTH | None = None,\n point_encoder: PointEncoder | None = None,\n mask_encoder: MaskEncoder | None = None,\n mask_decoder: MaskDecoder | None = None,\n multimask_output: bool | None = None,\n device: device | str = \"cpu\",\n dtype: dtype = float32,\n)\n
Bases: SegmentAnything
SegmentAnything huge model.
Parameters:
image_encoder (SAMViTH | None, default None): The image encoder to use.
point_encoder (PointEncoder | None, default None): The point encoder to use.
mask_encoder (MaskEncoder | None, default None): The mask encoder to use.
mask_decoder (MaskDecoder | None, default None): The mask decoder to use.
multimask_output (bool | None, default None): Whether to use multimask output.
device (device | str, default 'cpu'): The PyTorch device to use.
dtype (dtype, default float32): The PyTorch data type to use.
Example:\nimport torch\nfrom PIL import Image\n\nfrom refiners.foundationals.segment_anything import SegmentAnythingH\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n# multimask_output=True is recommended for ambiguous prompts such as a single point.\n# Below, a box prompt is passed, so multimask_output=False is used to return a single mask.\nsam_h = SegmentAnythingH(multimask_output=False, device=device)\n\n# Tip: run scripts/prepare_test_weights.py to download the weights.\ntensors_path = \"./tests/weights/segment-anything-h.safetensors\"\nsam_h.load_from_safetensors(tensors_path=tensors_path)\n\nimage = Image.open(\"image.png\")\n\n# Box corners (x, y) in pixels of the input image: placeholder values, adapt them to your image.\nx1, y1, x2, y2 = 100, 100, 400, 300\nmasks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])\n\nassert masks.shape == (1, 1, image.height, image.width)\nassert masks.dtype == torch.bool\n\n# Convert the boolean mask to a uint8 ndarray of shape (H, W) with values in {0, 255}.\nmask = masks[0, 0].cpu().numpy().astype(\"uint8\") * 255\n\nImage.fromarray(mask).save(\"mask_image.png\")\n
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(\n self,\n image_encoder: SAMViTH | None = None,\n point_encoder: PointEncoder | None = None,\n mask_encoder: MaskEncoder | None = None,\n mask_decoder: MaskDecoder | None = None,\n multimask_output: bool | None = None,\n device: Device | str = \"cpu\",\n dtype: DType = torch.float32,\n) -> None:\n \"\"\"Initialize SegmentAnything huge model.\n\n Args:\n image_encoder: The image encoder to use.\n point_encoder: The point encoder to use.\n mask_encoder: The mask encoder to use.\n mask_decoder: The mask decoder to use.\n multimask_output: Whether to use multimask output.\n device: The PyTorch device to use.\n dtype: The PyTorch data type to use.\n\n Example:\n ```py\n device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n # multimask_output=True is recommended for ambiguous prompts such as a single point.\n # Below, a box prompt is passed, so just use multimask_output=False which will return a single mask\n sam_h = SegmentAnythingH(multimask_output=False, device=device)\n\n # Tips: run scripts/prepare_test_weights.py to download the weights\n tensors_path = \"./tests/weights/segment-anything-h.safetensors\"\n sam_h.load_from_safetensors(tensors_path=tensors_path)\n\n from PIL import Image\n image = Image.open(\"image.png\")\n\n masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])\n\n assert masks.shape == (1, 1, image.height, image.width)\n assert masks.dtype == torch.bool\n\n # convert it to [0,255] uint8 ndarray of shape (H, W)\n mask = masks[0, 0].cpu().numpy().astype(\"uint8\") * 255\n\n Image.fromarray(mask).save(\"mask_image.png\")\n ```\n \"\"\"\n image_encoder = image_encoder or SAMViTH()\n point_encoder = point_encoder or PointEncoder()\n mask_encoder = mask_encoder or MaskEncoder()\n\n if mask_decoder:\n assert (\n multimask_output is None or mask_decoder.multimask_output == multimask_output\n ), f\"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})\"\n else:\n 
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()\n\n super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)\n\n self.to(device=device, dtype=dtype)\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.SegmentAnythingH.image_encoder","title":"image_encoder property
","text":"image_encoder: SAMViTH\n
The image encoder.
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.compute_scaled_size","title":"compute_scaled_size","text":"compute_scaled_size(\n size: tuple[int, int], image_encoder_resolution: int\n) -> tuple[int, int]\n
Compute the scaled size expected by the image encoder. The computed size keeps the aspect ratio of the input image and scales it so that its longest side equals image_encoder_resolution, i.e. it fits inside the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.
Parameters:
size (tuple[int, int], required): The size (h, w) of the input image.
image_encoder_resolution (int, required): Image encoder resolution.
Returns:
int: The target height.
int: The target width.
Source code in src/refiners/foundationals/segment_anything/utils.py
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:\n \"\"\"Compute the scaled size as expected by the image encoder.\n This computed size keep the ratio of the input image, and scale it to fit inside the square (image_encoder_resolution, image_encoder_resolution) of image encoder.\n\n Args:\n size: The size (h, w) of the input image.\n image_encoder_resolution: Image encoder resolution.\n\n Returns:\n The target height.\n The target width.\n \"\"\"\n oldh, oldw = size\n scale = image_encoder_resolution * 1.0 / max(oldh, oldw)\n newh, neww = oldh * scale, oldw * scale\n neww = int(neww + 0.5)\n newh = int(newh + 0.5)\n return (newh, neww)\n
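A quick worked example helps check the rounding and the (h, w) order. The snippet below reproduces the same arithmetic in a standalone function (compute_scaled_size_sketch is a hypothetical name, not part of Refiners) and applies it with an illustrative encoder resolution of 1024.

```python
def compute_scaled_size_sketch(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
    # Same arithmetic as compute_scaled_size: scale the longest side to the
    # encoder resolution, then round both sides to the nearest integer.
    old_h, old_w = size
    scale = image_encoder_resolution / max(old_h, old_w)
    new_h = int(old_h * scale + 0.5)
    new_w = int(old_w * scale + 0.5)
    return new_h, new_w


# Landscape image (h=480, w=640): the width becomes 1024, the height keeps the 3:4 ratio.
print(compute_scaled_size_sketch((480, 640), 1024))  # (768, 1024)
# Portrait image (h=1000, w=500): the height becomes 1024 instead.
print(compute_scaled_size_sketch((1000, 500), 1024))  # (1024, 512)
```

Because only the longest side is matched, the result always fits in the encoder square, and the remaining area is filled later by padding.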
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.image_to_scaled_tensor","title":"image_to_scaled_tensor","text":"image_to_scaled_tensor(\n image: Image,\n scaled_size: tuple[int, int],\n device: device | None = None,\n dtype: dtype | None = None,\n) -> Tensor\n
Resize the image to scaled_size and convert it to a tensor.
Parameters:
image (Image, required): The image.
scaled_size (tuple[int, int], required): The target size (h, w).
device (device | None, default None): Tensor device.
dtype (dtype | None, default None): Tensor dtype.
Returns: A tensor of shape (1, c, h, w).
Source code in src/refiners/foundationals/segment_anything/utils.py
def image_to_scaled_tensor(\n image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None\n) -> Tensor:\n \"\"\"Resize the image to `scaled_size` and convert it to a tensor.\n\n Args:\n image: The image.\n scaled_size: The target size (h, w).\n device: Tensor device.\n dtype: Tensor dtype.\n Returns:\n a Tensor of shape (1, c, h, w)\n \"\"\"\n h, w = scaled_size\n resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore\n\n return image_to_tensor(resized, device=device, dtype=dtype) * 255.0\n
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.normalize_coordinates","title":"normalize_coordinates","text":"normalize_coordinates(\n coordinates: Tensor,\n original_size: tuple[int, int],\n image_encoder_resolution: int,\n) -> Tensor\n
Normalize pixel coordinates of the original image to the [0, 1] range of the padded image encoder input.
Parameters:
coordinates (Tensor, required): The coordinates to normalize, as a 3D tensor whose last dimension holds (x, y) pairs. The tensor is modified in place and returned.
original_size (tuple[int, int], required): The original image size (h, w).
image_encoder_resolution (int, required): Image encoder resolution.
Returns:
Tensor: The normalized coordinates.
Source code in src/refiners/foundationals/segment_anything/utils.py
def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n \"\"\"Normalize the coordinates in the [0,1] range\n\n Args:\n coordinates: The coordinates to normalize.\n original_size: The original image size.\n image_encoder_resolution: Image encoder resolution.\n\n Returns:\n The normalized coordinates.\n \"\"\"\n scaled_size = compute_scaled_size(original_size, image_encoder_resolution)\n coordinates[:, :, 0] = (\n (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5\n ) / image_encoder_resolution\n coordinates[:, :, 1] = (\n (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5\n ) / image_encoder_resolution\n return coordinates\n
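To make the formula concrete: for a 480x640 (h, w) image and an encoder resolution of 1024, the scaled size is (768, 1024), so both axes are multiplied by 1.6, shifted by half a pixel, then divided by 1024. The hypothetical helper below reproduces that arithmetic on a (batch, points, 2) tensor of (x, y) pairs; it takes the scaled size as an argument instead of recomputing it, and unlike the source function it works on a copy, since the original writes into coordinates in place.

```python
import torch


def normalize_coordinates_sketch(
    coordinates: torch.Tensor,
    original_size: tuple[int, int],
    scaled_size: tuple[int, int],
    image_encoder_resolution: int,
) -> torch.Tensor:
    out = coordinates.clone()  # avoid mutating the caller's tensor
    # x is scaled by the width ratio, y by the height ratio; +0.5 targets the pixel center.
    out[:, :, 0] = (out[:, :, 0] * (scaled_size[1] / original_size[1]) + 0.5) / image_encoder_resolution
    out[:, :, 1] = (out[:, :, 1] * (scaled_size[0] / original_size[0]) + 0.5) / image_encoder_resolution
    return out


# One point at the center of a 480x640 (h, w) image, stored as (x, y).
points = torch.tensor([[[320.0, 240.0]]])
normalized = normalize_coordinates_sketch(points, (480, 640), (768, 1024), 1024)
# x -> (320 * 1.6 + 0.5) / 1024, y -> (240 * 1.6 + 0.5) / 1024
print(normalized)
```

The center point does not land at (0.5, 0.5): since padding is added at the bottom, the image only covers the top 768 / 1024 of the encoder input, so y ends up close to 0.375.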
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.pad_image_tensor","title":"pad_image_tensor","text":"pad_image_tensor(\n image_tensor: Tensor,\n scaled_size: tuple[int, int],\n image_encoder_resolution: int,\n) -> Tensor\n
Pad an image tensor with zeros on the right and bottom so that it becomes a square of side image_encoder_resolution.
Parameters:
image_tensor (Tensor, required): The image tensor to pad.
scaled_size (tuple[int, int], required): The scaled size (h, w).
image_encoder_resolution (int, required): Image encoder resolution.
Returns:
Tensor: The padded image.
Source code in src/refiners/foundationals/segment_anything/utils.py
def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n \"\"\"Pad an image with zeros to make it square.\n\n Args:\n image_tensor: The image tensor to pad.\n scaled_size: The scaled size (h, w).\n image_encoder_resolution: Image encoder resolution.\n\n Returns:\n The padded image.\n \"\"\"\n assert len(image_tensor.shape) == 4\n assert image_tensor.shape[2] <= image_encoder_resolution\n assert image_tensor.shape[3] <= image_encoder_resolution\n\n h, w = scaled_size\n padh = image_encoder_resolution - h\n padw = image_encoder_resolution - w\n return pad(image_tensor, (0, padw, 0, padh))\n
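Assuming pad here is torch.nn.functional.pad, the padding tuple is read from the last dimension backwards, so (0, padw, 0, padh) means: nothing on the left, padw columns on the right, nothing on top, padh rows at the bottom. The content therefore stays anchored at the top-left corner. A small check with made-up sizes:

```python
import torch
import torch.nn.functional as F

resolution = 8
scaled_size = (6, 8)  # (h, w) after resizing, e.g. a 3:4 landscape image
image_tensor = torch.ones(1, 3, *scaled_size)

h, w = scaled_size
# (left, right, top, bottom): pad the width on the right and the height at the bottom.
padded = F.pad(image_tensor, (0, resolution - w, 0, resolution - h))

print(padded.shape)  # torch.Size([1, 3, 8, 8])
print(padded[0, 0, :, 0])  # ones on the first 6 rows, zeros on the last 2
```

This top-left anchoring is what lets normalize_coordinates skip any offset and lets postprocess_masks drop the padding with a plain slice.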
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.postprocess_masks","title":"postprocess_masks","text":"postprocess_masks(\n low_res_masks: Tensor,\n original_size: tuple[int, int],\n image_encoder_resolution: int,\n) -> Tensor\n
Postprocess the masks to fit the original image size and remove zero-padding (if any).
Parameters:
low_res_masks (Tensor, required): The masks to postprocess.
original_size (tuple[int, int], required): The original size (h, w).
image_encoder_resolution (int, required): Image encoder resolution.
Returns:
Tensor: The postprocessed masks.
Source code in src/refiners/foundationals/segment_anything/utils.py
def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:\n \"\"\"Postprocess the masks to fit the original image size and remove zero-padding (if any).\n\n Args:\n low_res_masks: The masks to postprocess.\n original_size: The original size (h, w).\n image_encoder_resolution: Image encoder resolution.\n\n Returns:\n The postprocessed masks.\n \"\"\"\n scaled_size = compute_scaled_size(original_size, image_encoder_resolution)\n masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode=\"bilinear\")\n masks = masks[..., : scaled_size[0], : scaled_size[1]] # remove padding added at `preprocess_image` time\n masks = interpolate(masks, size=Size(original_size), mode=\"bilinear\")\n return masks\n
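The three steps can be traced on shapes alone. The sketch below follows the same sequence with torch.nn.functional.interpolate: upsample to the encoder resolution, crop the bottom/right padding, resize to the original size. The 256x256 low-resolution size matches the predict signature; the 480x640 image and the 1024 resolution are example values.

```python
import torch
import torch.nn.functional as F

low_res_masks = torch.randn(1, 1, 256, 256)  # mask logits from the decoder
original_size = (480, 640)  # (h, w)
resolution = 1024
scaled_size = (768, 1024)  # compute_scaled_size((480, 640), 1024)

# 1. Upsample to the square encoder input: (1, 1, 1024, 1024).
masks = F.interpolate(low_res_masks, size=(resolution, resolution), mode='bilinear')
# 2. Drop the rows and columns that correspond to zero-padding: (1, 1, 768, 1024).
masks = masks[..., : scaled_size[0], : scaled_size[1]]
# 3. Resize back to the original image: (1, 1, 480, 640).
masks = F.interpolate(masks, size=original_size, mode='bilinear')
print(masks.shape)  # torch.Size([1, 1, 480, 640])
```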
"},{"location":"reference/foundationals/segment_anything/#refiners.foundationals.segment_anything.utils.preprocess_image","title":"preprocess_image","text":"preprocess_image(\n image: Image,\n image_encoder_resolution: int,\n device: device | None = None,\n dtype: dtype | None = None,\n) -> Tensor\n
Preprocess an image without distorting its aspect ratio.
Parameters:
image (Image, required): The image to preprocess before calling the image encoder.
image_encoder_resolution (int, required): Image encoder resolution.
device (device | None, default None): Tensor device.
dtype (dtype | None, default None): Tensor dtype.
Returns:
Tensor: The preprocessed image.
Source code in src/refiners/foundationals/segment_anything/utils.py
def preprocess_image(\n image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None\n) -> Tensor:\n \"\"\"Preprocess an image without distorting its aspect ratio.\n\n Args:\n image: The image to preprocess before calling the image encoder.\n image_encoder_resolution: Image encoder resolution.\n device: Tensor device (None by default).\n dtype: Tensor dtype (None by default).\n\n Returns:\n The preprocessed image.\n \"\"\"\n\n scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)\n\n image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype)\n\n return pad_image_tensor(\n normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),\n scaled_size,\n image_encoder_resolution,\n )\n
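The order of operations matters: normalization runs before padding, so the padded area holds zeros in normalized space, i.e. it corresponds to the mean color rather than to black. The mean and std values are the usual ImageNet statistics expressed on a 0 to 255 scale. A minimal sketch of the normalize-then-pad step, with a hand-written per-channel normalization standing in for the normalize helper used in the source and a uniform gray tensor standing in for a real image:

```python
import torch
import torch.nn.functional as F

mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)

resolution = 1024
scaled_size = (768, 1024)  # (h, w) from compute_scaled_size((480, 640), 1024)
# Stand-in for image_to_scaled_tensor: a uniform mid-gray image in [0, 255].
image_tensor = torch.full((1, 3, *scaled_size), 128.0)

# Per-channel normalization first, then zero-padding on the bottom/right.
normalized = (image_tensor - mean) / std
h, w = scaled_size
padded = F.pad(normalized, (0, resolution - w, 0, resolution - h))

print(padded.shape)  # torch.Size([1, 3, 1024, 1024])
# Un-normalizing a padded pixel gives back the mean color, not black.
print(padded[0, :, -1, -1] * std.flatten() + mean.flatten())
```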
"},{"location":"reference/foundationals/swin/","title":"
Swin Transformers","text":""},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.swin_transformer.SwinTransformer","title":"SwinTransformer","text":"SwinTransformer(\n patch_size: tuple[int, int] = (4, 4),\n in_chans: int = 3,\n embedding_dim: int = 96,\n depths: list[int] | None = None,\n num_heads: list[int] | None = None,\n window_size: int = 7,\n mlp_ratio: float = 4.0,\n device: device | None = None,\n)\n
Bases: Chain
Swin Transformer (arXiv:2103.14030)
Currently specific to MVANet; only square inputs are supported.
Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(\n self,\n patch_size: tuple[int, int] = (4, 4),\n in_chans: int = 3,\n embedding_dim: int = 96,\n depths: list[int] | None = None,\n num_heads: list[int] | None = None,\n window_size: int = 7, # image size is 32 * this\n mlp_ratio: float = 4.0,\n device: Device | None = None,\n) -> None:\n if depths is None:\n depths = [2, 2, 6, 2]\n\n if num_heads is None:\n num_heads = [3, 6, 12, 24]\n\n self.num_layers = len(depths)\n assert len(num_heads) == self.num_layers\n\n super().__init__(\n PatchEmbedding(\n patch_size=patch_size,\n in_chans=in_chans,\n embedding_dim=embedding_dim,\n device=device,\n ),\n fl.Passthrough(\n fl.Transpose(1, 2),\n SquareUnflatten(2),\n fl.SetContext(\"swin\", \"outputs\", callback=lambda t, x: t.append(x)),\n ),\n *(\n fl.Chain(\n BasicLayer(\n dim=int(embedding_dim * 2**i),\n depth=depths[i],\n num_heads=num_heads[i],\n window_size=window_size,\n mlp_ratio=mlp_ratio,\n device=device,\n ),\n fl.Passthrough(\n fl.LayerNorm(int(embedding_dim * 2**i), device=device),\n fl.Transpose(1, 2),\n SquareUnflatten(2),\n fl.SetContext(\"swin\", \"outputs\", callback=lambda t, x: t.insert(0, x)),\n ),\n PatchMerging(dim=int(embedding_dim * 2**i), device=device)\n if i < self.num_layers - 1\n else fl.UseContext(\"swin\", \"outputs\").compose(lambda t: tuple(t)),\n )\n for i in range(self.num_layers)\n ),\n )\n
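Reading the chain above, the backbone returns a tuple of five feature maps: the patch embedding output is appended first, and each stage's normalized output is then inserted at the front, so the deepest stage comes first. Assuming the usual Swin behavior (a 4x4 patch embedding, and PatchMerging halving the spatial size while doubling the channels), the shapes can be tabulated without running the model. swin_feature_shapes is a hypothetical helper, and 224 = 32 * 7 follows the image size is 32 * this comment on window_size.

```python
def swin_feature_shapes(
    image_size: int, embedding_dim: int = 96, num_layers: int = 4, patch: int = 4
) -> list[tuple[int, int, int]]:
    # Patch embedding: each patch x patch block of pixels becomes one token.
    side = image_size // patch
    outputs = [(embedding_dim, side, side)]  # appended first
    for i in range(num_layers):
        dim = embedding_dim * 2**i
        # Stage i output (after LayerNorm) is inserted at the front of the list.
        outputs.insert(0, (dim, side, side))
        if i < num_layers - 1:
            side //= 2  # PatchMerging before the next stage
    return outputs


for shape in swin_feature_shapes(224):
    print(shape)
# (768, 7, 7)
# (384, 14, 14)
# (192, 28, 28)
# (96, 56, 56)
# (96, 56, 56)
```

The deepest map is 7x7, exactly one attention window, which is consistent with that comment.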
"},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.swin_transformer.WindowAttention","title":"WindowAttention","text":"WindowAttention(\n dim: int,\n window_size: int,\n num_heads: int,\n shift: bool = False,\n device: device | None = None,\n)\n
Bases: Chain
Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).
It has a trainable relative position bias (RelativePositionBias).
The input projection is stored as a single Linear for q, k and v.
Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(\n self,\n dim: int,\n window_size: int,\n num_heads: int,\n shift: bool = False,\n device: Device | None = None,\n) -> None:\n super().__init__(\n fl.Linear(dim, dim * 3, bias=True, device=device),\n WindowSDPA(window_size, num_heads, shift, device=device),\n fl.Linear(dim, dim, device=device),\n )\n
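A simplified, plain-PyTorch picture of what happens between the two Linear layers: the fused projection yields q, k and v, tokens are grouped into non-overlapping window_size x window_size windows, and attention is computed inside each window only. This sketch leaves out the shift and the relative position bias, and the exact q/k/v layout inside Refiners' WindowSDPA is an assumption here; window_attention_sketch is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def window_attention_sketch(
    x: torch.Tensor, qkv: nn.Linear, proj: nn.Linear, window_size: int, num_heads: int
) -> torch.Tensor:
    # x: (B, H, W, C) feature map, with H and W multiples of window_size.
    b, h, w, c = x.shape
    ws, hd = window_size, c // num_heads
    # Partition into windows: (B * num_windows, ws * ws, C).
    windows = x.view(b, h // ws, ws, w // ws, ws, c).permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, c)
    # Single fused projection, split into q, k, v, each (B * num_windows, heads, ws * ws, hd).
    q, k, v = qkv(windows).view(-1, ws * ws, 3, num_heads, hd).permute(2, 0, 3, 1, 4)
    out = F.scaled_dot_product_attention(q, k, v)  # attention restricted to each window
    out = proj(out.transpose(1, 2).reshape(-1, ws * ws, c))
    # Reverse the partition back to (B, H, W, C).
    return out.view(b, h // ws, w // ws, ws, ws, c).permute(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)


dim, heads, ws = 32, 4, 7
x = torch.randn(2, 14, 14, dim)
y = window_attention_sketch(x, nn.Linear(dim, dim * 3), nn.Linear(dim, dim), ws, heads)
print(y.shape)  # torch.Size([2, 14, 14, 32])
```

Restricting attention to windows keeps the cost linear in the number of windows; the shifted variant (SW-MSA) is what lets information cross window borders between consecutive blocks.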
"},{"location":"reference/foundationals/swin/#refiners.foundationals.swin.mvanet.MVANet","title":"MVANet","text":"MVANet(\n embedding_dim: int = 128,\n n_logits: int = 1,\n depths: list[int] | None = None,\n num_heads: list[int] | None = None,\n window_size: int = 12,\n device: device | None = None,\n)\n
Bases: Chain
Multi-view Aggregation Network for Dichotomous Image Segmentation
See [arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation for more details.
Parameters:
embedding_dim (int, default 128): Embedding dimension.
n_logits (int, default 1): The number of output logits. A single logit is used for alpha matting, foreground/background segmentation or salient object detection (SOD).
depths (list[int] | None, default None): See SwinTransformer. None uses [2, 2, 18, 2].
num_heads (list[int] | None, default None): See SwinTransformer. None uses [4, 8, 16, 32].
window_size (int, default 12): See SwinTransformer.
device (device | None, default None): The device to use.
Source code in src/refiners/foundationals/swin/mvanet/mvanet.py
def __init__(\n self,\n embedding_dim: int = 128,\n n_logits: int = 1,\n depths: list[int] | None = None,\n num_heads: list[int] | None = None,\n window_size: int = 12,\n device: Device | None = None,\n) -> None:\n if depths is None:\n depths = [2, 2, 18, 2]\n if num_heads is None:\n num_heads = [4, 8, 16, 32]\n\n super().__init__(\n ComputeShallow(embedding_dim=embedding_dim, device=device),\n SplitMultiView(),\n fl.Flatten(0, 1),\n SwinTransformer(\n embedding_dim=embedding_dim,\n depths=depths,\n num_heads=num_heads,\n window_size=window_size,\n device=device,\n ),\n fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),\n Pyramid(embedding_dim=embedding_dim, device=device),\n RearrangeMultiView(embedding_dim=embedding_dim, device=device),\n ShallowUpscaler(embedding_dim, device=device),\n fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),\n )\n
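The Flatten(0, 1) before the backbone and the Unflatten(0, (-1, 5)) applied to each of the five Swin outputs suggest that batch and view axes are merged so the backbone processes every view as an independent image, then separated again. The sketch below illustrates only that shape bookkeeping. How the views are built here (one downsampled global view plus four quadrant crops, loosely following the paper's distant and close-up views) and the helper name split_views_sketch are assumptions, not a transcription of SplitMultiView.

```python
import torch
import torch.nn.functional as F


def split_views_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: (B, C, H, W) -> (B, 5, C, H / 2, W / 2): one global view + four quadrants.
    _, _, h, w = x.shape
    global_view = F.interpolate(x, size=(h // 2, w // 2), mode='bilinear')
    quadrants = [
        x[..., : h // 2, : w // 2],
        x[..., : h // 2, w // 2 :],
        x[..., h // 2 :, : w // 2],
        x[..., h // 2 :, w // 2 :],
    ]
    return torch.stack([global_view, *quadrants], dim=1)


images = torch.randn(2, 3, 64, 64)
views = split_views_sketch(images)  # (2, 5, 3, 32, 32)
flat = views.flatten(0, 1)  # (10, 3, 32, 32): what the backbone would see
# After the backbone, each feature map is split back into (batch, views, ...).
restored = flat.view(-1, 5, *flat.shape[1:])
print(views.shape, flat.shape, restored.shape)
```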
"}]}