initial commit

This commit is contained in:
Cédric Deltheil 2023-08-04 15:28:41 +02:00
commit 48f674c433
109 changed files with 12003 additions and 0 deletions

.github/workflows/ci.yml vendored Normal file

@ -0,0 +1,31 @@
name: CI
on: push
jobs:
lint_and_typecheck:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
- name: Set up python
id: setup-python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: poetry install
run: poetry install --no-interaction --extras=training
- name: lint
run: poetry run ruff check .
- name: typecheck
run: poetry run pyright

.gitignore vendored Normal file

@ -0,0 +1,28 @@
# compilation and distribution
__pycache__/
*.py[cod]
dist/
# virtual environments
venv/
# unit tests
.pytest_cache/
# tests' model weights
tests/weights/
# ruff
.ruff_cache
# vscode
.vscode
# Weights & Biases (offline trainings)
wandb/
# macos
.DS_Store
# model weights
*.safetensors

LICENSE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Lagon Technologies
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md Normal file

@ -0,0 +1,304 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="assets/logo_dark.png">
<source media="(prefers-color-scheme: light)" srcset="assets/logo_light.png">
<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
</picture>
**The simplest way to train and run adapters on top of foundational models**
______________________________________________________________________
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/refiners)](https://pypi.org/project/refiners/)
[![PyPI Status](https://badge.fury.io/py/refiners.svg)](https://badge.fury.io/py/refiners)
[![license](https://img.shields.io/badge/license-MIT-blue)](/LICENSE)
</div>
- [Motivation](#motivation)
- [Design](#design)
- [Downsides](#downsides)
- [Overview](#overview)
- [Key Concepts](#key-concepts)
- [The Chain class](#the-chain-class)
- [The Context API](#the-context-api)
- [The Adapter API](#the-adapter-api)
- [Getting Started](#getting-started)
- [Install](#install)
- [Hello World](#hello-world)
- [Training](#training)
- [Credits](#credits)
- [Citation](#citation)
## Motivation
At [Finegrain](https://finegrain.ai), we're on a mission to automate product photography. Given our "no human in the loop" approach, nailing the quality of the outputs we generate is paramount to our success.
That's why we're building Refiners.
It's a framework to easily bridge the last-mile quality gap of foundational models like Stable Diffusion or the Segment Anything Model (SAM), by adapting them to specific tasks with lightweight, trainable and composable patches.
We decided to build Refiners in the open.
Model adaptation is a new paradigm that goes beyond our specific use cases, and our hope is to help anyone looking to create their own adapters save time, whatever foundation model they're using.
## Design
We are huge fans of PyTorch (we actually were core committers to [Torch](http://torch.ch/) in another life), but we felt it was too low-level for the specific task of model adaptation: PyTorch models are generally hard to understand, and adapting them requires intricate ad hoc code.
Instead, we needed:
- A model structure that's human readable so that you know what models do and how they work right here, right now
- A mechanism to easily inject parameters in some target layers, or between them
- A way to easily pass data (like a conditioning input) between layers even when deeply nested
- Native support for iconic adapter types like LoRAs and their community trained incarnations (hosted on [Civitai](http://civitai.com/) and the likes)
Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.
## Downsides
As they say, there is no free lunch. Since Refiners comes with a new model structure, it can only work with models implemented that way. For now, we support Stable Diffusion 1.5, but more are in the making (SDXL, SAM, ...) - stay tuned.
## Overview
The Refiners library is made of:
1. An abstraction layer (called Fluxion) on top of [PyTorch](https://pytorch.org/) to easily build models
2. A zoo of compatible foundational models
3. Adapter APIs to easily patch supported foundational models
4. Training utils to train concrete adapters
5. Conversion scripts to easily use existing community adapters
## Key Concepts
### The Chain class
The `Chain` class is at the core of Refiners. It lets you express models as compositions of basic layers (linear, convolution, attention, etc.) in a **declarative way**.
For example, this is what a Vision Transformer (ViT) looks like in Refiners:
```python
import torch
import refiners.fluxion.layers as fl
class ViT(fl.Chain):
    # The Vision Transformer model structure is entirely defined in the constructor. It is
    # ready to use right after instantiation, i.e. there is no need to implement a forward function or add extra logic
def __init__(
self,
embedding_dim: int = 512,
patch_size: int = 16,
image_size: int = 384,
num_layers: int = 12,
num_heads: int = 8,
):
num_patches = (image_size // patch_size)
super().__init__(
fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
fl.Reshape(num_patches**2, embedding_dim),
# The Residual layer implements the so-called skip-connection, i.e. x + F(x).
# Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
# weights are stored in the Parameter layer (note: there is no extra classification
# token in this toy example)
fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
# These are the transformer encoders:
*(
fl.Chain(
fl.LayerNorm(embedding_dim),
fl.Residual(
# The Parallel layer is used to pass multiple inputs to a downstream
# layer, here multiheaded self-attention
fl.Parallel(
fl.Identity(),
fl.Identity(),
fl.Identity()
),
fl.Attention(
embedding_dim=embedding_dim,
num_heads=num_heads,
key_embedding_dim=embedding_dim,
value_embedding_dim=embedding_dim,
),
),
fl.LayerNorm(embedding_dim),
fl.Residual(
fl.Linear(embedding_dim, embedding_dim * 4),
fl.GeLU(),
fl.Linear(embedding_dim * 4, embedding_dim),
),
fl.Chain(
fl.Linear(embedding_dim, embedding_dim * 4),
fl.GeLU(),
fl.Linear(embedding_dim * 4, embedding_dim),
),
)
for _ in range(num_layers)
),
fl.Reshape(embedding_dim, num_patches, num_patches),
)
vit = ViT(embedding_dim=768, image_size=224, num_heads=12) # ~ViT-B/16 like
x = torch.randn(2, 3, 224, 224)
y = vit(x)
```
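Because the structure is declarative, you can also traverse it programmatically, e.g. to look up sub-layers by type (this is the same `layers` iterator used in the Adapter example further below):
```python
# Each transformer block of the ViT defined above contains exactly one Attention layer.
attentions = list(vit.layers(fl.Attention))
print(len(attentions))  # 12 blocks -> 12 attention layers
```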
### The Context API
The `Chain` class has a context provider that allows you to **pass data to layers even when deeply nested**.
For example, to implement cross-attention you just need to modify the ViT model as in the toy example below:
```diff
@@ -21,8 +21,8 @@
fl.Residual(
fl.Parallel(
fl.Identity(),
- fl.Identity(),
- fl.Identity()
+ fl.UseContext(context="cross_attention", key="my_embed"),
+ fl.UseContext(context="cross_attention", key="my_embed"),
), # used to pass multiple inputs to a layer
fl.Attention(
embedding_dim=embedding_dim,
@@ -49,5 +49,6 @@
)
vit = ViT(embedding_dim=768, image_size=224, num_heads=12) # ~ViT-B/16 like
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
x = torch.randn(2, 3, 224, 224)
y = vit(x)
```
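For reference, here is a minimal self-contained sketch of the same mechanism outside the ViT example (assuming `set_context` and `UseContext` behave as in the diff above):
```python
import torch
import refiners.fluxion.layers as fl
# A toy chain whose second step adds a tensor read from the "toy" context.
tiny = fl.Chain(
    fl.Linear(8, 8),
    fl.Residual(fl.UseContext(context="toy", key="shift")),
)
tiny.set_context("toy", {"shift": torch.ones(1, 8)})
y = tiny(torch.randn(1, 8))  # output = Linear(x) + shift
```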
### The Adapter API
The `Adapter` API lets you **easily patch models** by injecting parameters in targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also implement your custom adapters with it.
For example, to inject LoRA layers into the linear layers of every attention block:
```python
from refiners.adapters.lora import LoraAdapter
for layer in vit.layers(fl.Attention):
for linear, parent in layer.walk(fl.Linear):
adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
adapter.inject(parent)
# ... and load existing weights if the LoRAs are pretrained ...
```
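You can follow the same pattern to write your own adapters. Here is a rough sketch (the `ScaleAdapter` name and behavior are made up for illustration, and `some_linear` / `parent` stand for an existing linear layer and its parent chain; only the `Adapter` base class and its `inject`/`eject` API are assumed):
```python
from refiners.adapters.adapter import Adapter
import refiners.fluxion.layers as fl
class ScaleAdapter(fl.Chain, Adapter[fl.Linear]):
    # Wrap a Linear layer and rescale its output, without touching its weights.
    def __init__(self, target: fl.Linear, scale: float = 0.5) -> None:
        with self.setup_adapter(target):
            super().__init__(target, fl.Lambda(lambda x: x * scale))
# adapter = ScaleAdapter(target=some_linear)
# adapter.inject(parent)  # replaces `some_linear` with the adapter inside `parent`
# adapter.eject()         # puts the original layer back
```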
## Getting Started
### Install
```bash
# inference only
pip install refiners
```
Or:
```bash
# inference + training
pip install 'refiners[training]'
```
### Hello World
Here is how to perform text-to-image inference using the Stable Diffusion 1.5 foundational model patched with a Pokemon LoRA:
Step 1: prepare the model weights in refiners' format:
```bash
python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors
```
> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5, which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.
Step 2: download and convert a community Pokemon LoRA, e.g. [this one](https://huggingface.co/pcuenq/pokemon-lora):
```bash
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
python scripts/convert-lora-weights.py \
--from pytorch_lora_weights.bin \
--output-file pokemon_lora.safetensors
```
Step 3: run inference using the GPU:
```python
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.fluxion.utils import load_from_safetensors, manual_seed
import torch
sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
# This uses the LoraAdapter internally and takes care to inject it where it should
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
lora_weights.patch(sd15, scale=1.0)
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
sd15.set_num_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=sd15.device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image.save("pokemon_cat.png")
```
You should get:
![pokemon cat output](assets/pokemon_cat.png)
## Training
Refiners has a built-in training utilities library and provides scripts that can be used as a starting point.
For example, to train a LoRA on top of Stable Diffusion, copy `configs/finetune-lora.toml`, edit it to suit your needs, and launch the training as follows:
```bash
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
```
## Credits
We took inspiration from these great projects:
- [tinygrad](https://github.com/tinygrad/tinygrad) - For something between PyTorch and [karpathy/micrograd](https://github.com/karpathy/micrograd)
- [Composer](https://github.com/mosaicml/composer) - A PyTorch Library for Efficient Neural Network Training
- [Keras](https://github.com/keras-team/keras) - Deep Learning for humans
## Citation
```bibtex
@misc{the-finegrain-team-2023-refiners,
author = {Benjamin Trom and Pierre Chapuis and Cédric Deltheil},
title = {Refiners: The simplest way to train and run adapters on top of foundational models},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/finegrain-ai/refiners}}
}
```

assets/dropy.png Normal file
Binary file added (4.5 KiB)

assets/logo_dark.png Normal file
Binary file added (27 KiB)

assets/logo_light.png Normal file
Binary file added (25 KiB)

assets/pokemon_cat.png Normal file
Binary file added (336 KiB)

configs/finetune-ldm.toml Normal file

@ -0,0 +1,55 @@
script = "finetune-ldm.py" # not used for now
[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-ldm-training"
[models]
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}
[latent_diffusion]
unconditional_sampling_probability = 0.2
offset_noise = 0.1
[training]
duration = "1:epoch"
seed = 0
gpu_index = 0
num_epochs = 1
batch_size = 1
gradient_accumulation = "1:step"
clip_grad_norm = 2.0
clip_grad_value = 1.0
evaluation_interval = "1:epoch"
evaluation_seed = 0
[optimizer]
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
learning_rate = 1e-5
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2
[scheduler]
[dropout]
dropout_probability = 0.2
[dataset]
hf_repo = "acme/images"
revision = "main"
[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "1:epoch"
[test_diffusion]
prompts = [
"A cute cat",
]

configs/finetune-lora.toml Normal file

@ -0,0 +1,70 @@
script = "finetune-ldm-lora.py" # not used for now
[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-lora-training"
[models]
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
[latent_diffusion]
unconditional_sampling_probability = 0.05
offset_noise = 0.1
[lora]
rank = 16
trigger_phrase = "a spsh photo,"
use_only_trigger_probability = 1.0
unet_targets = ["CrossAttentionBlock2d"]
text_encoder_targets = ["TransformerLayer"]
lda_targets = []
[training]
duration = "1000:epoch"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "4:step"
clip_grad_norm = 1.0
# clip_grad_value = 1.0
evaluation_interval = "5:epoch"
evaluation_seed = 1
[optimizer]
optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
learning_rate = 1
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2
[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"
warmup = "500:step"
[dropout]
dropout_probability = 0.2
use_gyro_dropout = false
[dataset]
hf_repo = "acme/images"
revision = "main"
[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "1:step"
[test_diffusion]
num_inference_steps = 30
use_short_prompts = false
prompts = [
"a cute cat",
"a cute dog",
"a cute bird",
"a cute horse",
]

poetry.lock generated Normal file (3360 lines)
File diff suppressed because it is too large.

pyproject.toml Normal file

@ -0,0 +1,67 @@
[tool.poetry]
name = "refiners"
version = "0.1.0"
description = "The simplest way to train and run adapters on top of foundational models"
authors = [
"The Finegrain Team <bonjour@lagon.tech>",
]
license = "MIT"
readme = "README.md"
packages = [{include = "refiners", from = "src"}]
[tool.poetry.dependencies]
python = ">=3.10,<3.12"
jaxtyping = "^0.2.14"
torch = "^2.0.0"
safetensors = "^0.3.0"
numpy = "^1.24.2"
pillow = "^9.5.0"
datasets = {version = "^2.14.0", optional = true}
tomli = {version = "^2.0.1", optional = true}
wandb = {version = "^0.15.7", optional = true}
loguru = {version = "^0.7.0", optional = true}
bitsandbytes = {version = "^0.41.0", optional = true}
prodigyopt = {version = "^1.0", optional = true}
pydantic = {git = "https://github.com/pydantic/pydantic.git", rev = "v2.0b3", optional = true}
scipy = {version = "^1.11.1", optional = true}
[tool.poetry.extras]
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy"]
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
pytest = "^7.2.2"
isort = "^5.12.0"
ipykernel = "^6.22.0"
pyright = "^1.1.318"
ruff = "^0.0.281"
[tool.poetry.group.test.dependencies]
diffusers = "^0.18.0"
transformers = "^4.27.4"
piq = "^0.7.1"
invisible-watermark = "^0.2.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 120
preview = true
[tool.ruff]
ignore = [
"F722", # forward-annotation-syntax-error, because of Jaxtyping
"E731", # do-not-assign-lambda
"E501", # line-too-long, because Black (https://beta.ruff.rs/docs/faq/#is-ruff-compatible-with-black)
]
line-length = 120
[tool.pyright]
include = ["src/refiners", "tests", "scripts/training"]
exclude = ["**/__pycache__"]
reportMissingTypeStubs = "warning"

scripts/convert-clip-weights.py Normal file

@ -0,0 +1,50 @@
import torch
from safetensors.torch import save_file
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import DiffusionPipeline
from transformers.models.clip.modeling_clip import CLIPTextModel
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
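# Convert the diffusers CLIP text encoder weights to refiners' CLIPTextEncoderL layout: a state dict
# mapping is inferred by create_state_dict_mapping from a sample tokenized prompt, then the remapped
# weights are returned in half precision.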
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderL()
x = dst_model.tokenizer("Nice cat", sequence_length=77)
mapping = create_state_dict_mapping(src_model, dst_model, [x])
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="CLIPTextEncoderL.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(args.source).text_encoder
tensors = convert(src_model)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

(ControlNet weights conversion script)

@ -0,0 +1,203 @@
import torch
from diffusers import ControlNetModel
from safetensors.torch import save_file
from refiners.fluxion.utils import (
forward_order_of_execution,
verify_shape_match,
convert_state_dict,
)
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion import UNet
@torch.no_grad()
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
controlnet = Controlnet(name="mycn")
condition = torch.randn(1, 3, 512, 512)
controlnet.set_controlnet_condition(condition)
unet = UNet(4, clip_embedding_dim=768)
unet.insert(0, controlnet)
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding)
scheduler = DPMSolver(num_inference_steps=10)
timestep = scheduler.timesteps[0].unsqueeze(0)
unet.set_timestep(timestep.unsqueeze(0))
x = torch.randn(1, 4, 64, 64)
# We need the hack below because our implementation is not strictly equivalent
# to diffusers in order, since we compute the residuals inline instead of
# in a separate step.
source_order = forward_order_of_execution(controlnet_src, (x, timestep, clip_text_embedding, condition))
target_order = forward_order_of_execution(controlnet, (x,))
broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
expected_source_order = [
"down_blocks.0.attentions.0.proj_in",
"down_blocks.0.attentions.0.proj_out",
"down_blocks.0.attentions.1.proj_in",
"down_blocks.0.attentions.1.proj_out",
"controlnet_down_blocks.0",
"controlnet_down_blocks.1",
"controlnet_down_blocks.2",
"controlnet_down_blocks.3",
]
expected_target_order = [
"DownBlocks.Chain_1.Passthrough.Conv2d",
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_2.Passthrough.Conv2d",
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_3.Passthrough.Conv2d",
"DownBlocks.Chain_4.Passthrough.Conv2d",
]
fixed_source_order = [
"controlnet_down_blocks.0",
"down_blocks.0.attentions.0.proj_in",
"down_blocks.0.attentions.0.proj_out",
"controlnet_down_blocks.1",
"down_blocks.0.attentions.1.proj_in",
"down_blocks.0.attentions.1.proj_out",
"controlnet_down_blocks.2",
"controlnet_down_blocks.3",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = ("Conv2d", (torch.Size([640, 640, 1, 1]), torch.Size([640])))
expected_source_order = [
"down_blocks.1.attentions.0.proj_in",
"down_blocks.1.attentions.0.proj_out",
"down_blocks.1.attentions.1.proj_in",
"down_blocks.1.attentions.1.proj_out",
"controlnet_down_blocks.4",
"controlnet_down_blocks.5",
"controlnet_down_blocks.6",
]
expected_target_order = [
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_5.Passthrough.Conv2d",
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_6.Passthrough.Conv2d",
"DownBlocks.Chain_7.Passthrough.Conv2d",
]
fixed_source_order = [
"down_blocks.1.attentions.0.proj_in",
"down_blocks.1.attentions.0.proj_out",
"controlnet_down_blocks.4",
"down_blocks.1.attentions.1.proj_in",
"down_blocks.1.attentions.1.proj_out",
"controlnet_down_blocks.5",
"controlnet_down_blocks.6",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = ("Conv2d", (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
expected_source_order = [
"down_blocks.2.attentions.0.proj_in",
"down_blocks.2.attentions.0.proj_out",
"down_blocks.2.attentions.1.proj_in",
"down_blocks.2.attentions.1.proj_out",
"mid_block.attentions.0.proj_in",
"mid_block.attentions.0.proj_out",
"controlnet_down_blocks.7",
"controlnet_down_blocks.8",
"controlnet_down_blocks.9",
"controlnet_down_blocks.10",
"controlnet_down_blocks.11",
"controlnet_mid_block",
]
expected_target_order = [
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_8.Passthrough.Conv2d",
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"DownBlocks.Chain_9.Passthrough.Conv2d",
"DownBlocks.Chain_10.Passthrough.Conv2d",
"DownBlocks.Chain_11.Passthrough.Conv2d",
"DownBlocks.Chain_12.Passthrough.Conv2d",
"MiddleBlock.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
"MiddleBlock.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
"MiddleBlock.Passthrough.Conv2d",
]
fixed_source_order = [
"down_blocks.2.attentions.0.proj_in",
"down_blocks.2.attentions.0.proj_out",
"controlnet_down_blocks.7",
"down_blocks.2.attentions.1.proj_in",
"down_blocks.2.attentions.1.proj_out",
"controlnet_down_blocks.8",
"controlnet_down_blocks.9",
"controlnet_down_blocks.10",
"controlnet_down_blocks.11",
"mid_block.attentions.0.proj_in",
"mid_block.attentions.0.proj_out",
"controlnet_mid_block",
]
assert source_order[broken_k] == expected_source_order
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
assert verify_shape_match(source_order, target_order)
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
state_dict = convert_state_dict(controlnet_src.state_dict(), controlnet.state_dict(), state_dict_mapping=mapping)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=True,
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="output.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
controlnet_src = ControlNetModel.from_pretrained(args.source)
tensors = convert(controlnet_src)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

scripts/convert-lora-weights.py Normal file

@ -0,0 +1,115 @@
# Note: this conversion script currently only support simple LoRAs which adapt
# the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
import torch
from torch.nn.init import zeros_
from torch.nn import Parameter as TorchParameter
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
from refiners.adapters.lora import Lora
from refiners.fluxion.utils import create_state_dict_mapping
from diffusers import DiffusionPipeline
def get_weight(linear: fl.Linear) -> torch.Tensor:
assert linear.bias is None
return linear.state_dict()["weight"]
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
weights: list[torch.Tensor] = []
for lora in module.layers(layer_type=Lora):
linears = list(lora.layers(fl.Linear))
assert len(linears) == 2
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
@torch.no_grad()
def process(source: str, base_model: str, output_file: str) -> None:
diffusers_state_dict = torch.load(source, map_location="cpu") # type: ignore
diffusers_sd = DiffusionPipeline.from_pretrained(base_model) # type: ignore
diffusers_model = diffusers_sd.unet
refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
target = LoraTarget.CrossAttention
metadata = {"unet_targets": "CrossAttentionBlock2d"}
rank = diffusers_state_dict[
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
].shape[0]
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 768)
refiners_model.set_timestep(timestep)
refiners_model.set_clip_text_embedding(clip_text_embeddings)
refiners_args = (x,)
diffusers_args = (x, timestep, clip_text_embeddings)
diffusers_to_refiners = create_state_dict_mapping(refiners_model, diffusers_model, refiners_args, diffusers_args)
assert diffusers_to_refiners
apply_loras_to_target(refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
for layer in refiners_model.layers(layer_type=Lora):
zeros_(layer.Linear_1.weight)
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
for target_k in targets:
k_p, k_s = target_k.split(".processor.")
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
assert len(r) == 1
orig_k = r[0]
orig_path = orig_k.split(".")
p = refiners_model
for seg in orig_path[:-1]:
p = p[seg]
last_seg = (
"LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
)
p_down = TorchParameter(diffusers_state_dict[f"{target_k}_lora.down.weight"])
p_up = TorchParameter(diffusers_state_dict[f"{target_k}_lora.up.weight"])
p[last_seg].Lora.load_weights(p_down, p_up)
state_dict = build_loras_safetensors(refiners_model, key_prefix="unet.")
assert len(state_dict) == 320
save_to_safetensors(output_file, tensors=state_dict, metadata=metadata)
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=True,
help="Source file path (.bin)",
)
parser.add_argument(
"--base-model",
type=str,
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Base model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="output.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
process(source=args.source, base_model=args.base_model, output_file=args.output_file)
if __name__ == "__main__":
main()

(Refiners LoRA to SD-WebUI LoRA conversion script)

@ -0,0 +1,134 @@
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget
from refiners.fluxion.layers.module import Module
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import create_state_dict_mapping
import torch
from diffusers import DiffusionPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from transformers.models.clip.modeling_clip import CLIPTextModel
@torch.no_grad()
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep)
dst_model.set_clip_text_embedding(clip_text_embeddings)
dst_args = (x,)
return create_state_dict_mapping(src_model, dst_model, src_args, dst_args) # type: ignore
@torch.no_grad()
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
x = dst_model.tokenizer("Nice cat", sequence_length=77)
return create_state_dict_mapping(src_model, dst_model, [x]) # type: ignore
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-file",
type=str,
required=True,
help="Path to the input file with refiner's LoRA weights (safetensors format)",
)
parser.add_argument(
"-o",
"--output-file",
type=str,
required=True,
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
)
parser.add_argument(
"--sd15",
type=str,
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
)
args = parser.parse_args()
metadata = load_metadata_from_safetensors(args.input_file)
assert metadata is not None
tensors = load_from_safetensors(args.input_file)
diffusers_sd = DiffusionPipeline.from_pretrained(args.sd15) # type: ignore
state_dict: dict[str, torch.Tensor] = {}
for meta_key, meta_value in metadata.items():
match meta_key:
case "unet_targets":
src_model = diffusers_sd.unet # type: ignore
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
create_mapping = create_unet_mapping
key_prefix = "unet."
lora_prefix = "lora_unet_"
case "text_encoder_targets":
src_model = diffusers_sd.text_encoder # type: ignore
dst_model = CLIPTextEncoderL()
create_mapping = create_text_encoder_mapping
key_prefix = "text_encoder."
lora_prefix = "lora_te_"
case "lda_targets":
raise ValueError("SD-WebUI does not support LoRA for the auto-encoder")
case _:
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
submodule_to_key: dict[Module, str] = {}
for name, submodule in dst_model.named_modules():
submodule_to_key[submodule] = name
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
#
# lora_unet_down_blocks_0_attentions_0_proj_in.alpha
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
# ...
#
# Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See
# `convert_diffusers_name_to_compvis` for more details.
#
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
refiners_to_diffusers = create_mapping(src_model, dst_model) # type: ignore
assert refiners_to_diffusers is not None
# Compute the corresponding diffusers' keys where LoRA layers must be applied
lora_injection_points: list[str] = [
refiners_to_diffusers[submodule_to_key[linear]]
for target in [LoraTarget(t) for t in meta_value.split(",")]
for layer in dst_model.layers(layer_type=target.get_class())
for linear in layer.layers(fl.Linear)
]
    lora_weights = [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]
assert len(lora_injection_points) == len(lora_weights) // 2
# Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
for i, diffusers_key in enumerate(lora_injection_points):
lora_key = lora_prefix + diffusers_key.replace(".", "_")
# Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
# by default (see `lora_calc_updown` in SD-WebUI for more details)
state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
assert state_dict
save_to_safetensors(args.output_file, state_dict)
if __name__ == "__main__":
main()

scripts/convert-sd-lda-weights.py Normal file

@ -0,0 +1,50 @@
import torch
from safetensors.torch import save_file
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import DiffusionPipeline
from diffusers.models.autoencoder_kl import AutoencoderKL
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@torch.no_grad()
def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
dst_model = LatentDiffusionAutoencoder()
x = torch.randn(1, 3, 512, 512)
mapping = create_state_dict_mapping(src_model, dst_model, [x])
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="lda.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(args.source).vae
tensors = convert(src_model)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

(Stable Diffusion inpainting UNet conversion script)

@ -0,0 +1,59 @@
import torch
from safetensors.torch import save_file
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import StableDiffusionInpaintPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from refiners.foundationals.latent_diffusion.unet import UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = UNet(in_channels=9, clip_embedding_dim=768)
x = torch.randn(1, 9, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep)
dst_model.set_clip_text_embedding(clip_text_embeddings)
dst_args = (x,)
mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-inpainting",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_1_5_inpainting_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = StableDiffusionInpaintPipeline.from_pretrained(args.source).unet
tensors = convert(src_model)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

scripts/convert-sd-unet-weights.py Normal file

@ -0,0 +1,59 @@
import torch
from safetensors.torch import save_file
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import DiffusionPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from refiners.foundationals.latent_diffusion.unet import UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 768)
src_args = (x, timestep, clip_text_embeddings)
dst_model.set_timestep(timestep)
dst_model.set_clip_text_embedding(clip_text_embeddings)
dst_args = (x,)
mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="runwayml/stable-diffusion-v1-5",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_1_5_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(args.source).unet
tensors = convert(src_model)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

(SDXL text encoder CLIPTextEncoderG conversion script)

@ -0,0 +1,57 @@
import torch
from safetensors.torch import save_file # type: ignore
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import DiffusionPipeline # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
import refiners.fluxion.layers as fl
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderG()
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
x = dst_model.tokenizer("Nice cat", sequence_length=77)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
if mapping is None:
raise RuntimeError("Could not create state dict mapping")
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v.half() for k, v in state_dict.items()}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="stabilityai/stable-diffusion-xl-base-0.9",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="CLIPTextEncoderG.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2 # type: ignore
tensors = convert(src_model=src_model)
save_file(tensors=tensors, filename=args.output_file)
if __name__ == "__main__":
main()

(SDXL UNet conversion script)

@ -0,0 +1,68 @@
import torch
from safetensors.torch import save_file # type: ignore
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = SDXLUNet(in_channels=4)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([0])
clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
src_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
dst_model.set_timestep(timestep=timestep)
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
dst_model.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
dst_model.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
dst_args = (x,)
mapping = create_state_dict_mapping(
source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args # type: ignore
)
if mapping is None:
raise RuntimeError("Could not create state dict mapping")
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
)
return {k: v for k, v in state_dict.items()}
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="stabilityai/stable-diffusion-xl-base-0.9",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="stable_diffusion_xl_unet.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
tensors = convert(src_model)
save_file(tensors, args.output_file)
if __name__ == "__main__":
main()

scripts/training/finetune-ldm-lora.py Normal file

@ -0,0 +1,148 @@
import random
from typing import Any
from pydantic import BaseModel
from loguru import logger
from refiners.adapters.lora import LoraAdapter, Lora
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import LoraTarget
import refiners.fluxion.layers as fl
from torch import Tensor
from torch.utils.data import Dataset
from refiners.training_utils.callback import Callback
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
TextEmbeddingLatentsBatch,
TextEmbeddingLatentsDataset,
LatentDiffusionTrainer,
LatentDiffusionConfig,
)
class LoraConfig(BaseModel):
rank: int = 32
trigger_phrase: str = ""
use_only_trigger_probability: float = 0.0
unet_targets: list[LoraTarget]
text_encoder_targets: list[LoraTarget]
lda_targets: list[LoraTarget]
def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
for layer in module.layers(layer_type=target.get_class()):
for linear, parent in layer.walk(fl.Linear):
adapter = LoraAdapter(
target=linear,
rank=self.rank,
device=module.device,
dtype=module.dtype,
)
adapter.inject(parent)
for linear in adapter.Lora.layers(fl.Linear):
linear.requires_grad_(requires_grad=True)
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
def __init__(
self,
trainer: "LoraLatentDiffusionTrainer",
) -> None:
super().__init__(trainer=trainer)
self.trigger_phrase = trainer.config.lora.trigger_phrase
self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
logger.info(f"Trigger phrase: {self.trigger_phrase}")
def process_caption(self, caption: str) -> str:
caption = super().process_caption(caption=caption)
if self.trigger_phrase:
caption = (
f"{self.trigger_phrase} {caption}"
if random.random() < self.use_only_trigger_probability
else self.trigger_phrase
)
return caption
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
latent_diffusion: LatentDiffusionConfig
lora: LoraConfig
def model_post_init(self, __context: Any) -> None:
"""Pydantic v2 does post init differently, so we need to override this method too."""
logger.info("Freezing models to train only the loras.")
self.models["unet"].train = False
self.models["text_encoder"].train = False
self.models["lda"].train = False
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
def __init__(
self,
config: LoraLatentDiffusionConfig,
callbacks: "list[Callback[Any]] | None" = None,
) -> None:
super().__init__(config=config, callbacks=callbacks)
self.callbacks.extend((LoadLoras(), SaveLoras()))
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
return TriggerPhraseDataset(trainer=self)
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
lora_config = trainer.config.lora
for target in lora_config.unet_targets:
lora_config.apply_loras_to_target(module=trainer.unet, target=target)
for target in lora_config.text_encoder_targets:
lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
for target in lora_config.lda_targets:
lora_config.apply_loras_to_target(module=trainer.lda, target=target)
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
lora_config = trainer.config.lora
def get_weight(linear: fl.Linear) -> Tensor:
assert linear.bias is None
return linear.state_dict()["weight"]
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
weights: list[Tensor] = []
for lora in module.layers(layer_type=Lora):
linears = list(lora.layers(fl.Linear))
assert len(linears) == 2
# See `load_lora_weights` in refiners.adapters.lora
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
tensors: dict[str, Tensor] = {}
metadata: dict[str, str] = {}
if lora_config.unet_targets:
tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
if lora_config.text_encoder_targets:
tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
if lora_config.lda_targets:
tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
save_to_safetensors(
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
tensors=tensors,
metadata=metadata,
)
if __name__ == "__main__":
import sys
config_path = sys.argv[1]
config = LoraLatentDiffusionConfig.load_from_toml(
toml_path=config_path,
)
trainer = LoraLatentDiffusionTrainer(config=config)
trainer.train()

scripts/training/finetune-ldm.py Normal file

@ -0,0 +1,11 @@
from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer
if __name__ == "__main__":
import sys
config_path = sys.argv[1]
config = FinetuneLatentDiffusionConfig.load_from_toml(
toml_path=config_path,
)
trainer = LatentDiffusionTrainer(config=config)
trainer.train()

src/refiners/__init__.py Normal file (empty)


src/refiners/adapters/adapter.py Normal file

@ -0,0 +1,66 @@
import contextlib
import refiners.fluxion.layers as fl
from typing import Any, Generic, TypeVar, Iterator
T = TypeVar("T", bound=fl.Module)
TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
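# Generic adapter mixin: it wraps an existing module (the "target"), can replace it inside its
# parent chain with `inject`, and can restore the original module with `eject`.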
class Adapter(Generic[T]):
# we store _target into a one element list to avoid pytorch thinking it is a submodule
_target: "list[T]"
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
assert issubclass(cls, fl.Chain), f"Adapter {cls.__name__} must be a Chain"
@property
def target(self) -> T:
return self._target[0]
@contextlib.contextmanager
def setup_adapter(self, target: T) -> Iterator[None]:
assert isinstance(self, fl.Chain)
assert (not hasattr(self, "_modules")) or (
len(self) == 0
), "Call the Chain constructor in the setup_adapter context."
self._target = [target]
if not isinstance(self.target, fl.ContextModule):
yield
return
_old_can_refresh_parent = target._can_refresh_parent
target._can_refresh_parent = False
yield
target._can_refresh_parent = _old_can_refresh_parent
def inject(self, parent: fl.Chain | None = None) -> None:
assert isinstance(self, fl.Chain)
if parent is None:
if isinstance(self.target, fl.ContextModule):
parent = self.target.parent
else:
raise ValueError(f"parent of {self.target} is mandatory")
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
if self.target not in iter(parent):
raise ValueError(f"{self.target} is not in {parent}")
parent.replace(
old_module=self.target,
new_module=self,
old_module_parent=self.find_parent(self.target),
)
def eject(self) -> None:
assert isinstance(self, fl.Chain)
self.ensure_parent.replace(old_module=self, new_module=self.target)
def _pre_structural_copy(self) -> None:
if isinstance(self.target, fl.Chain):
raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
self._target = [source.target]

src/refiners/adapters/lora.py Normal file

@ -0,0 +1,88 @@
import refiners.fluxion.layers as fl
from refiners.adapters.adapter import Adapter
from torch.nn.init import zeros_, normal_
from torch import Tensor, device as Device, dtype as DType
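# Low-rank branch: a down projection to `rank` features followed by an up projection back to
# `out_features`, with the result multiplied by `scale`. The up projection is zero-initialized,
# so the branch is a no-op right after construction.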
class Lora(fl.Chain):
structural_attrs = ["in_features", "out_features", "rank", "scale"]
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
)
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def set_scale(self, scale: float) -> None:
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = down_weight
self.Linear_2.weight = up_weight
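# Adapter that sums the target Linear's output with the Lora branch (fl.Sum), i.e. injecting it
# adds a trainable low-rank update on top of the original layer.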
class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
structural_attrs = ["in_features", "out_features", "rank", "scale"]
def __init__(
self,
target: fl.Linear,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=device,
dtype=dtype,
),
)
self.Lora.set_scale(scale=scale)
def add_lora(self, lora: Lora) -> None:
self.append(module=lora)
def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
assert len(weights) % 2 == 0, "Number of weights must be even"
assert (
len(list(model.layers(layer_type=Lora))) == len(weights) // 2
), "Number of Lora layers must match number of weights"
for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])

(refiners.adapters range adapter module: RangeEncoder and RangeAdapter2d)

@ -0,0 +1,70 @@
import math
from torch import Tensor, arange, float32, exp, sin, cat, cos, device as Device, dtype as DType
from jaxtyping import Float, Int
from refiners.adapters.adapter import Adapter
import refiners.fluxion.layers as fl
def compute_sinusoidal_embedding(
x: Int[Tensor, "*batch 1"],
embedding_dim: int,
) -> Float[Tensor, "*batch 1 embedding_dim"]:
half_dim = embedding_dim // 2
# Note: it is important that this computation is done in float32.
# The result can be cast to lower precision later if necessary.
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
exponent /= half_dim
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
return embedding
class RangeEncoder(fl.Chain):
structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
def __init__(
self,
sinuosidal_embedding_dim: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.sinuosidal_embedding_dim = sinuosidal_embedding_dim
self.embedding_dim = embedding_dim
super().__init__(
fl.Lambda(self.compute_sinuosoidal_embedding),
fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim).to(self.dtype)
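# Adapter that sums the target Conv2d's output with a per-channel shift computed from an
# embedding read from the "range_adapter" context.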
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
structural_attrs = ["channels", "embedding_dim", "context_key"]
def __init__(
self,
target: fl.Conv2d,
channels: int,
embedding_dim: int,
context_key: str,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.channels = channels
self.embedding_dim = embedding_dim
self.context_key = context_key
with self.setup_adapter(target):
super().__init__(
target,
fl.Chain(
fl.UseContext("range_adapter", context_key),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
fl.View(-1, channels, 1, 1),
),
)

src/refiners/fluxion/__init__.py Normal file

@ -0,0 +1,3 @@
from refiners.fluxion.utils import save_to_safetensors, load_from_safetensors, norm, manual_seed, pad
__all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"]

src/refiners/fluxion/context.py Normal file

@ -0,0 +1,52 @@
from typing import Any
from torch import Tensor
Context = dict[str, Any]
Contexts = dict[str, Context]
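# Two-level store: a named context (e.g. "cross_attention") maps keys to arbitrary values,
# which layers such as UseContext read at run time.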
class ContextProvider:
def __init__(self) -> None:
self.contexts: Contexts = {}
def set_context(self, key: str, value: Context) -> None:
self.contexts[key] = value
def get_context(self, key: str) -> Any:
return self.contexts.get(key)
def update_contexts(self, new_contexts: Contexts) -> None:
for key, value in new_contexts.items():
if key not in self.contexts:
self.contexts[key] = value
else:
self.contexts[key].update(value)
@staticmethod
def create(contexts: Contexts) -> "ContextProvider":
provider = ContextProvider()
provider.update_contexts(contexts)
return provider
def __add__(self, other: "ContextProvider") -> "ContextProvider":
self.contexts.update(other.contexts)
return self
def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
other.contexts.update(self.contexts)
return other
def __bool__(self) -> bool:
return bool(self.contexts)
def _get_repr_for_value(self, value: Any) -> str:
if isinstance(value, Tensor):
return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
return repr(value)
def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}
def __repr__(self) -> str:
contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
return f"{self.__class__.__name__}(contexts={contexts_repr})"

src/refiners/fluxion/layers/__init__.py Normal file

@ -0,0 +1,82 @@
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
from refiners.fluxion.layers.basics import (
Identity,
View,
Flatten,
Unflatten,
Transpose,
Permute,
Reshape,
Squeeze,
Unsqueeze,
Slicing,
Parameter,
Buffer,
)
from refiners.fluxion.layers.chain import (
Lambda,
Sum,
Residual,
Return,
Chain,
UseContext,
SetContext,
Parallel,
Passthrough,
Breakpoint,
Concatenate,
)
from refiners.fluxion.layers.conv import Conv2d
from refiners.fluxion.layers.linear import Linear, MultiLinear
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding
__all__ = [
"Embedding",
"LayerNorm",
"GroupNorm",
"LayerNorm2d",
"GeLU",
"GLU",
"SiLU",
"ReLU",
"ApproximateGeLU",
"Attention",
"SelfAttention",
"SelfAttention2d",
"Identity",
"View",
"Flatten",
"Unflatten",
"Transpose",
"Permute",
"Squeeze",
"Unsqueeze",
"Reshape",
"Slicing",
"Parameter",
"Buffer",
"Lambda",
"Return",
"Sum",
"Residual",
"Chain",
"UseContext",
"SetContext",
"Parallel",
"Passthrough",
"Breakpoint",
"Concatenate",
"Conv2d",
"Linear",
"MultiLinear",
"Downsample",
"Upsample",
"Module",
"WeightedModule",
"ContextModule",
"Interpolate",
]

View file

@@ -0,0 +1,66 @@
from refiners.fluxion.layers.module import Module
from torch import Tensor, sigmoid
from torch.nn.functional import gelu, silu  # type: ignore
class Activation(Module):
def __init__(self) -> None:
super().__init__()
class SiLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return silu(x) # type: ignore
class ReLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x.relu()
class GeLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return gelu(x) # type: ignore
class ApproximateGeLU(Activation):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x * sigmoid(1.702 * x)
class GLU(Activation):
"""
Gated Linear Unit activation layer.
See https://arxiv.org/abs/2002.05202v1 for details.
"""
def __init__(self, activation: Activation) -> None:
super().__init__()
self.activation = activation
def __repr__(self):
return f"{self.__class__.__name__}(activation={self.activation})"
def forward(self, x: Tensor) -> Tensor:
assert x.shape[-1] % 2 == 0, "Non-batch input dimension must be divisible by 2"
output, gate = x.chunk(2, dim=-1)
return output * self.activation(gate)
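A minimal illustrative sketch of the GLU layer defined above; the input is split in two along its last dimension, and one half gates the other:

from torch import randn
from refiners.fluxion.layers import GLU, SiLU

glu = GLU(activation=SiLU())
x = randn(2, 16, 8)           # last dimension must be even
y = glu(x)                    # output half * silu(gate half)
assert y.shape == (2, 16, 4)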

View file

@@ -0,0 +1,189 @@
from jaxtyping import Float
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.layers.linear import Linear
from refiners.fluxion.layers.module import Module
from refiners.fluxion.layers.chain import Chain, Distribute, Parallel, Lambda
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.context import Contexts
def scaled_dot_product_attention(
query: Float[Tensor, "batch source_sequence_length dim"],
key: Float[Tensor, "batch target_sequence_length dim"],
value: Float[Tensor, "batch target_sequence_length dim"],
is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
return _scaled_dot_product_attention(query, key, value, is_causal=is_causal) # type: ignore
class ScaledDotProductAttention(Module):
def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
super().__init__()
self.num_heads = num_heads
self.is_causal = is_causal
def forward(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
return self.merge_multi_head(
scaled_dot_product_attention(
query=self.split_to_multi_head(query),
key=self.split_to_multi_head(key),
value=self.split_to_multi_head(value),
is_causal=(
is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
),
)
)
def split_to_multi_head(
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
assert (
len(x.shape) == 3
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert (
x.shape[-1] % self.num_heads == 0
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
def merge_multi_head(
self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"]
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
class Attention(Chain):
structural_attrs = [
"embedding_dim",
"num_heads",
"heads_dim",
"key_embedding_dim",
"value_embedding_dim",
"use_bias",
"is_causal",
]
def __init__(
self,
embedding_dim: int,
num_heads: int = 1,
key_embedding_dim: int | None = None,
value_embedding_dim: int | None = None,
use_bias: bool = True,
is_causal: bool | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert (
embedding_dim % num_heads == 0
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.heads_dim = embedding_dim // num_heads
self.key_embedding_dim = key_embedding_dim or embedding_dim
self.value_embedding_dim = value_embedding_dim or embedding_dim
self.use_bias = use_bias
self.is_causal = is_causal
super().__init__(
Distribute(
Linear(
in_features=self.embedding_dim,
out_features=self.embedding_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
Linear(
in_features=self.key_embedding_dim,
out_features=self.embedding_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
Linear(
in_features=self.value_embedding_dim,
out_features=self.embedding_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
),
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
Linear(
in_features=self.embedding_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class SelfAttention(Attention):
def __init__(
self,
embedding_dim: int,
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
embedding_dim=embedding_dim,
num_heads=num_heads,
use_bias=use_bias,
is_causal=is_causal,
device=device,
dtype=dtype,
)
self.insert(0, Parallel(Identity(), Identity(), Identity()))
class SelfAttention2d(SelfAttention):
structural_attrs = ["channels"]
def __init__(
self,
channels: int,
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
self.channels = channels
super().__init__(
embedding_dim=channels,
num_heads=num_heads,
use_bias=use_bias,
is_causal=is_causal,
device=device,
dtype=dtype,
)
self.insert(0, Lambda(self.tensor_2d_to_sequence))
self.append(Lambda(self.sequence_to_tensor_2d))
def init_context(self) -> Contexts:
return {"reshape": {"height": None, "width": None}}
def tensor_2d_to_sequence(
self, x: Float[Tensor, "batch channels height width"]
) -> Float[Tensor, "batch height*width channels"]:
height, width = x.shape[-2:]
self.set_context(context="reshape", value={"height": height, "width": width})
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
def sequence_to_tensor_2d(
self, x: Float[Tensor, "batch sequence_length channels"]
) -> Float[Tensor, "batch channels height width"]:
height, width = self.use_context("reshape").values()
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)
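For reference, a small illustrative sketch of the attention layers above; the shapes follow from the split/merge helpers of ScaledDotProductAttention and from the reshaping Lambdas of SelfAttention2d:

from torch import randn
from refiners.fluxion.layers import SelfAttention, SelfAttention2d

attention = SelfAttention(embedding_dim=64, num_heads=8)
assert attention(randn(2, 10, 64)).shape == (2, 10, 64)

attention_2d = SelfAttention2d(channels=64, num_heads=8)
assert attention_2d(randn(2, 64, 16, 16)).shape == (2, 64, 16, 16)  # height/width round-trip via the "reshape" context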

View file

@@ -0,0 +1,183 @@
from refiners.fluxion.layers.module import Module, WeightedModule
from torch import randn, Tensor, Size, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
class Identity(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x
class View(Module):
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.view(*self.shape)
def __repr__(self):
shape_repr = ", ".join([repr(s) for s in self.shape])
return f"{self.__class__.__name__}({shape_repr})"
class Flatten(Module):
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, x: Tensor) -> Tensor:
return x.flatten(self.start_dim, self.end_dim)
def __repr__(self):
return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"
class Unflatten(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor, sizes: Size) -> Tensor:
return x.unflatten(self.dim, sizes) # type: ignore
def __repr__(self):
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
class Reshape(Module):
"""
Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
dimension is preserved.
"""
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.reshape(x.shape[0], *self.shape)
def __repr__(self):
shape_repr = ", ".join([repr(s) for s in self.shape])
return f"{self.__class__.__name__}({shape_repr})"
class Transpose(Module):
def __init__(self, dim0: int, dim1: int) -> None:
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: Tensor) -> Tensor:
return x.transpose(self.dim0, self.dim1)
def __repr__(self):
return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"
class Permute(Module):
def __init__(self, *dims: int) -> None:
super().__init__()
self.dims = dims
def forward(self, x: Tensor) -> Tensor:
return x.permute(*self.dims)
def __repr__(self):
dims_repr = ", ".join([repr(d) for d in self.dims])
return f"{self.__class__.__name__}({dims_repr})"
class Slicing(Module):
def __init__(self, dim: int, start: int, length: int) -> None:
super().__init__()
self.dim = dim
self.start = start
self.length = length
def forward(self, x: Tensor) -> Tensor:
return x.narrow(self.dim, self.start, self.length)
def __repr__(self):
return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"
class Squeeze(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.squeeze(self.dim)
def __repr__(self):
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
class Unsqueeze(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.unsqueeze(self.dim)
def __repr__(self):
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
class Parameter(WeightedModule):
"""
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
@property
def device(self) -> Device:
return self.parameter.device
@property
def dtype(self) -> DType:
return self.parameter.dtype
def forward(self, _: Tensor) -> Tensor:
return self.parameter
def __repr__(self):
dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
class Buffer(WeightedModule):
"""
A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.
Buffers are not trainable.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
@property
def device(self) -> Device:
return self.buffer.device
@property
def dtype(self) -> DType:
return self.buffer.dtype
def forward(self, _: Tensor) -> Tensor:
return self.buffer
def __repr__(self):
dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"

View file

@@ -0,0 +1,466 @@
import inspect
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
from torch import Tensor, cat, device as Device, dtype as DType
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
from refiners.fluxion.context import Contexts, ContextProvider
T = TypeVar("T", bound=Module)
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
class Lambda(Module):
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
def __init__(self, func: Callable[..., Any]) -> None:
super().__init__()
self.func = func
def forward(self, *args: Any) -> Any:
return self.func(*args)
def __repr__(self):
func_name = getattr(self.func, "__name__", "partial_function")
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
def generate_unique_names(
modules: tuple[Module, ...],
) -> dict[str, Module]:
class_counts: dict[str, int] = {}
unique_names: list[tuple[str, Module]] = []
for module in modules:
class_name = module.__class__.__name__
class_counts[class_name] = class_counts.get(class_name, 0) + 1
name_counter: dict[str, int] = {}
for module in modules:
class_name = module.__class__.__name__
name_counter[class_name] = name_counter.get(class_name, 0) + 1
unique_name = f"{class_name}_{name_counter[class_name]}" if class_counts[class_name] > 1 else class_name
unique_names.append((unique_name, module))
return dict(unique_names)
class UseContext(ContextModule):
structural_attrs = ["context", "key", "func"]
def __init__(self, context: str, key: str) -> None:
super().__init__()
self.context = context
self.key = key
self.func: Callable[[Any], Any] = lambda x: x
def __call__(self, *args: Any) -> Any:
context = self.use_context(self.context)
assert context, f"context {self.context} is unset"
value = context.get(self.key)
assert value is not None, f"context entry {self.context}.{self.key} is unset"
return self.func(value)
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
self.func = func
return self
class SetContext(ContextModule):
"""A Module that sets a context value when executed.
The context needs to pre-exist in the context provider.
#TODO Is there a way to create the context if it doesn't exist?
"""
structural_attrs = ["context", "key", "callback"]
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
super().__init__()
self.context = context
self.key = key
self.callback = callback
def __call__(self, x: Tensor) -> Tensor:
if context := self.use_context(self.context):
if not self.callback:
context.update({self.key: x})
else:
self.callback(context[self.key], x)
return x
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
class ReturnException(Exception):
"""Exception raised when a Return module is encountered."""
def __init__(self, value: Tensor):
self.value = value
class Return(Module):
"""A Module that stops the execution of a Chain when encountered."""
def forward(self, x: Tensor):
raise ReturnException(x)
def structural_copy(m: T) -> T:
return m.structural_copy() if isinstance(m, ContextModule) else m
class Chain(ContextModule):
_modules: dict[str, Module]
_provider: ContextProvider
def __init__(self, *args: Module | Iterable[Module]) -> None:
super().__init__()
self._provider = ContextProvider()
modules = cast(
tuple[Module],
(
tuple(args[0])
if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
else tuple(args)
),
)
for module in modules:
# Violating this would mean a ContextModule ends up in two chains,
# with a single one correctly set as its parent.
assert (
(not isinstance(module, ContextModule))
or (not module._can_refresh_parent)
or (module.parent is None)
or (module.parent == self)
), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"
self._regenerate_keys(modules)
self._reset_context()
for module in self:
if isinstance(module, ContextModule) and module._can_refresh_parent and module.parent != self:
module._set_parent(self)
@property
def provider(self) -> ContextProvider:
return self._provider
def init_context(self) -> Contexts:
return {}
def _register_provider(self, context: Contexts | None = None) -> None:
if context:
self._provider.update_contexts(context)
for module in self:
if isinstance(module, Chain):
module._register_provider(context=self._provider.contexts)
def _reset_context(self) -> None:
self._register_provider(self.init_context())
def set_context(self, context: str, value: Any) -> None:
self._provider.set_context(context, value)
self._register_provider()
def debug_repr(self, layer_name: str = "") -> str:
lines: list[str] = []
tab = " "
tab_length = 0
for i, parent in enumerate(self.get_parents()[::-1]):
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
tab_length += 1
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
for name, _ in self._modules.items():
error_arrow = "⚠️" if name == layer_name else ""
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
return "\n".join(lines)
def call_layer(self, layer: Module, layer_name: str, *args: Any):
try:
return layer(*args)
except Exception as e:
pretty_print = self.debug_repr(layer_name)
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
def forward(self, *args: Any) -> Any:
result: tuple[Any] | Any = None
intermediate_args: tuple[Any, ...] = args
for name, layer in self._modules.items():
result = self.call_layer(layer, name, *intermediate_args)
intermediate_args = (result,) if not isinstance(result, tuple) else result
self._reset_context()
return result
def _regenerate_keys(self, modules: Iterable[Module]) -> None:
self._modules = generate_unique_names(tuple(modules)) # type: ignore
def __add__(self, other: "Chain | Module | list[Module]") -> "Chain":
if isinstance(other, Module):
other = Chain(other)
if isinstance(other, list):
other = Chain(*other)
return Chain(*self, *other)
def __getitem__(self, key: int | str | slice) -> Module:
if isinstance(key, slice):
return Chain(*list(self)[key])
elif isinstance(key, str):
return self._modules[key]
else:
return list(self)[key]
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
layer_name = self.__class__.__name__ if layer_name is None else layer_name
pretty_print = f"{layer_name}:\n"
tab = " " * (num_tab + 4)
module_strings: list[str] = []
for i, (name, module) in enumerate(self._modules.items()):
ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " "
module_str = (
module
if not isinstance(module, Chain)
else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)")
)
module_strings.append(f"{tab}{ident} {module_str}")
pretty_print += "\n".join(module_strings)
return pretty_print
def __repr__(self) -> str:
return self._pretty_print()
def __str__(self) -> str:
return f"<{self.__class__.__name__} at {hex(id(self))}>"
def __len__(self) -> int:
return len(self._modules)
@property
def device(self) -> Device | None:
wm = self.find(WeightedModule)
return None if wm is None else wm.device
@property
def dtype(self) -> DType | None:
wm = self.find(WeightedModule)
return None if wm is None else wm.dtype
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
if predicate is None:
predicate = lambda _m, _p: True
for module in self:
keep_going = True
try:
p = predicate(module, self)
except StopIteration:
p = False
keep_going = False
if p:
yield (module, self)
if keep_going and isinstance(module, Chain):
yield from module.walk(predicate)
@overload
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
...
@overload
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]:
...
def walk(
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
if isinstance(predicate, type):
return self._walk(lambda m, _: isinstance(m, predicate))
else:
return self._walk(predicate)
def layers(self, layer_type: type[T]) -> Iterator[T]:
for module, _ in self.walk(layer_type):
yield module
def find(self, layer_type: type[T]) -> T | None:
return next(self.layers(layer_type=layer_type), None)
def find_parent(self, module: Module) -> "Chain | None":
if module in self: # avoid DFS-crawling the whole tree
return self
for _, parent in self.walk(lambda m, _: m == module):
return parent
return None
def insert(self, index: int, module: Module) -> None: # type: ignore
if index < 0:
index = max(0, len(self._modules) + index + 1)
modules = list(self)
modules.insert(index, module)
self._regenerate_keys(modules)
if isinstance(module, ContextModule):
module._set_parent(self)
self._register_provider()
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i + 1, new_module)
return
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def append(self, module: Module) -> None: # type: ignore
modules = list(self)
modules.append(module)
self._regenerate_keys(modules)
if isinstance(module, ContextModule):
module._set_parent(self)
self._register_provider()
def pop(self, index: int = -1) -> Module | tuple[Module]: # type: ignore
modules = list(self)
if index < 0:
index = len(modules) + index
if index < 0 or index >= len(modules):
raise IndexError("Index out of range.")
removed_module = modules.pop(index)
if isinstance(removed_module, ContextModule):
removed_module._set_parent(None)
self._regenerate_keys(modules)
return removed_module
def remove(self, module: Module) -> None:
"""Remove a module from the chain."""
modules = list(self)
try:
modules.remove(module)
except ValueError:
raise ValueError(f"{module} is not in {self}")
self._regenerate_keys(modules)
if isinstance(module, ContextModule):
module._set_parent(None)
def replace(
self,
old_module: Module,
new_module: Module,
old_module_parent: "Chain | None" = None,
) -> None:
"""Replace a module in the chain with a new module."""
modules = list(self)
try:
modules[modules.index(old_module)] = new_module
except ValueError:
raise ValueError(f"{old_module} is not in {self}")
self._regenerate_keys(modules)
if isinstance(new_module, ContextModule):
new_module._set_parent(self)
if isinstance(old_module, ContextModule):
old_module._set_parent(old_module_parent)
def structural_copy(self: TChain) -> TChain:
"""Copy the structure of the Chain tree.
This method returns a recursive copy of the Chain tree where all inner nodes
(instances of Chain and its subclasses) are duplicated and all leaves
(regular Modules) are not.
Such copies can be adapted without disrupting the base model, but do not
require extra GPU memory since the weights are in the leaves and hence not copied.
This assumes all subclasses define the class variable `structural_attrs` which
contains a list of basic attributes set in the constructor. In complicated cases
it may be necessary to override this method instead.
"""
if hasattr(self, "_pre_structural_copy"):
self._pre_structural_copy()
modules = [structural_copy(m) for m in self]
# Instantiate the right subclass, but do not initialize.
clone = object.__new__(self.__class__)
# Copy all basic attributes of the class declared in `structural_attrs`.
for k in self.__class__.structural_attrs:
setattr(clone, k, getattr(self, k))
# Call constructor of Chain, which among other things refreshes the context tree.
Chain.__init__(clone, *modules)
for module in modules:
if isinstance(module, ContextModule):
module._set_parent(clone)
if hasattr(clone, "_post_structural_copy"):
clone._post_structural_copy(self)
return clone
class Parallel(Chain):
def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
class Distribute(Chain):
def forward(self, *args: Any) -> tuple[Tensor, ...]:
assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
class Passthrough(Chain):
def forward(self, *inputs: Any) -> Any:
super().forward(*inputs)
return inputs
class Sum(Chain):
def forward(self, *inputs: Any) -> Any:
output = None
for layer in self:
layer_output: Any = layer(*inputs)
if isinstance(layer_output, tuple):
layer_output = sum(layer_output) # type: ignore
output = layer_output if output is None else output + layer_output
return output
class Residual(Sum):
def __init__(self, *modules: Module) -> None:
super().__init__(Identity(), Chain(*modules))
class Breakpoint(Module):
def __init__(self, vscode: bool = True):
super().__init__()
self.vscode = vscode
def forward(self, *args: Any):
if self.vscode:
import debugpy # type: ignore
debugpy.breakpoint() # type: ignore
else:
breakpoint()
return args[0] if len(args) == 1 else args
class Concatenate(Chain):
structural_attrs = ["dim"]
def __init__(self, *modules: Module, dim: int = 0) -> None:
super().__init__(*modules)
self.dim = dim
def forward(self, *args: Any) -> Tensor:
outputs = [module(*args) for module in self]
return cat([output for output in outputs if output is not None], dim=self.dim)
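To make the Chain semantics concrete, here is a minimal illustrative sketch combining Chain, Residual, Sum and UseContext; it follows the `import refiners.fluxion.layers as fl` convention used elsewhere in this commit:

from torch import randn
import refiners.fluxion.layers as fl

model = fl.Chain(
    fl.Linear(in_features=8, out_features=8),
    fl.Residual(fl.Linear(in_features=8, out_features=8)),
    fl.Sum(fl.Identity(), fl.UseContext("example", "bias")),
)
model.set_context("example", {"bias": randn(1, 8)})  # propagated to the providers of nested Chains
assert model(randn(2, 8)).shape == (2, 8)

linears = list(model.layers(fl.Linear))  # two Linear layers, found by walking the tree
clone = model.structural_copy()          # inner Chains are duplicated, Linear weights are shared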

View file

@@ -0,0 +1,73 @@
from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
from torch import device as Device, dtype as DType
from refiners.fluxion.layers.module import WeightedModule
class Conv2d(_Conv2d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] | str = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
use_bias: bool = True,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
use_bias,
padding_mode,
device,
dtype,
)
self.in_channels = in_channels
self.out_channels = out_channels
self.padding = (padding,) if isinstance(padding, int) else padding
self.dilation = (dilation,) if isinstance(dilation, int) else dilation
self.groups = groups
self.use_bias = use_bias
self.padding_mode = padding_mode
class Conv1d(_Conv1d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] | str = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
use_bias: bool = True,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
use_bias,
padding_mode,
device,
dtype,
)
self.in_channels = in_channels
self.out_channels = out_channels
self.use_bias = use_bias

View file

@@ -0,0 +1,21 @@
from refiners.fluxion.layers.module import WeightedModule
from torch.nn import Embedding as _Embedding
from torch import Tensor, device as Device, dtype as DType
from jaxtyping import Float, Int
class Embedding(_Embedding, WeightedModule): # type: ignore
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
_Embedding.__init__( # type: ignore
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
)
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
return super().forward(x)

View file

@@ -0,0 +1,50 @@
from torch import device as Device, dtype as DType
from torch.nn import Linear as _Linear
from torch import Tensor
from refiners.fluxion.layers.module import Module, WeightedModule
from refiners.fluxion.layers.activations import ReLU
from refiners.fluxion.layers.chain import Chain
from jaxtyping import Float
class Linear(_Linear, WeightedModule):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
super().__init__( # type: ignore
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, x: Float[Tensor, "batch in_features"]) -> Float[Tensor, "batch out_features"]: # type: ignore
return super().forward(x)
class MultiLinear(Chain):
def __init__(
self,
input_dim: int,
output_dim: int,
inner_dim: int,
num_layers: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
layers: list[Module] = []
for i in range(num_layers - 1):
layers.append(Linear(input_dim if i == 0 else inner_dim, inner_dim, device=device, dtype=dtype))
layers.append(ReLU())
layers.append(Linear(inner_dim, output_dim, device=device, dtype=dtype))
super().__init__(layers)
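A short illustrative sketch of MultiLinear, which simply stacks Linear/ReLU pairs as built above:

from torch import randn
from refiners.fluxion.layers import MultiLinear

mlp = MultiLinear(input_dim=16, output_dim=4, inner_dim=32, num_layers=3)
assert mlp(randn(2, 16)).shape == (2, 4)  # Linear(16->32), ReLU, Linear(32->32), ReLU, Linear(32->4)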

View file

@@ -0,0 +1,100 @@
from pathlib import Path
from typing import Any, Generator, TypeVar
from torch import device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule
from refiners.fluxion.utils import load_from_safetensors
from refiners.fluxion.context import Context, ContextProvider
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
from refiners.fluxion.layers.chain import Chain
T = TypeVar("T", bound="Module")
TContextModule = TypeVar("TContextModule", bound="ContextModule")
class Module(TorchModule):
_parameters: dict[str, Any]
_buffers: dict[str, Any]
__getattr__: Callable[["Module", str], Any] # type: ignore
__setattr__: Callable[["Module", str, Any], None] # type: ignore
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)  # type: ignore
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
state_dict = load_from_safetensors(tensors_path)
self.load_state_dict(state_dict, strict=strict)
return self
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
return super().named_modules(*args) # type: ignore
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
return super().to(device=device, dtype=dtype) # type: ignore
class ContextModule(Module):
# we store parent into a one element list to avoid pytorch thinking it's a submodule
_parent: "list[Chain]"
_can_refresh_parent: bool = True # see usage in Adapter and Chain
# Contains simple attributes set on the instance by `__init__` in subclasses
# and copied by `structural_copy`. Note that is not the case of `device` since
# Chain's __init__ takes care of it.
structural_attrs: list[str] = []
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._parent = []
@property
def parent(self) -> "Chain | None":
return self._parent[0] if self._parent else None
@property
def ensure_parent(self) -> "Chain":
assert self._parent, "module is not bound to a Chain"
return self._parent[0]
def _set_parent(self, parent: "Chain | None") -> None:
if parent is None:
self._parent = []
return
# Always insert the module in the Chain first to avoid inconsistencies.
assert self in iter(parent), f"{self} not in {parent}"
self._parent = [parent]
@property
def provider(self) -> ContextProvider:
return self.ensure_parent.provider
def get_parents(self) -> "list[Chain]":
return self._parent + self._parent[0].get_parents() if self._parent else []
def use_context(self, context_name: str) -> Context:
"""Retrieve the context object from the module's context provider."""
context = self.provider.get_context(context_name)
assert context is not None, f"Context {context_name} not found."
return context
def structural_copy(self: TContextModule) -> TContextModule:
clone = object.__new__(self.__class__)
for k in self.__class__.structural_attrs:
setattr(clone, k, getattr(self, k))
ContextModule.__init__(clone)
return clone
class WeightedModule(Module):
@property
def device(self) -> Device:
return self.weight.device
@property
def dtype(self) -> DType:
return self.weight.dtype

View file

@@ -0,0 +1,75 @@
from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
from jaxtyping import Float
from refiners.fluxion.layers.module import WeightedModule
class LayerNorm(_LayerNorm, WeightedModule):
def __init__(
self,
normalized_shape: int | list[int],
eps: float = 0.00001,
elementwise_affine: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
device=device,
dtype=dtype,
)
class GroupNorm(_GroupNorm, WeightedModule):
def __init__(
self,
channels: int,
num_groups: int,
eps: float = 1e-5,
affine: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
num_groups=num_groups,
num_channels=channels,
eps=eps,
affine=affine,
device=device,
dtype=dtype,
)
self.channels = channels
self.num_groups = num_groups
self.eps = eps
self.affine = affine
class LayerNorm2d(WeightedModule):
"""
2D Layer Normalization module.
Parameters:
channels (int): Number of channels in the input tensor.
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
"""
def __init__(
self,
channels: int,
eps: float = 1e-6,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.weight = Parameter(ones(channels, device=device, dtype=dtype))
self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
self.eps = eps
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
x_mean = x.mean(1, keepdim=True)
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
return x_out
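A small illustrative check of LayerNorm2d: with the default weight and bias, every spatial position ends up normalized across the channel dimension:

from torch import randn
from refiners.fluxion.layers import LayerNorm2d

layer_norm = LayerNorm2d(channels=8)
x = randn(2, 8, 16, 16)
y = layer_norm(x)
assert y.shape == x.shape
assert y.mean(dim=1).abs().max().item() < 1e-4  # per-position channel mean is ~0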

View file

@@ -0,0 +1,100 @@
from refiners.fluxion.layers.chain import Chain, UseContext, SetContext
from refiners.fluxion.layers.conv import Conv2d
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.chain import Parallel, Lambda
from refiners.fluxion.layers.module import Module
from refiners.fluxion.utils import interpolate
from torch.nn.functional import pad
from torch import Tensor, Size, device as Device, dtype as DType
class Downsample(Chain):
structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
def __init__(
self,
channels: int,
scale_factor: int,
padding: int = 0,
register_shape: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Downsamples the input by the given scale factor.
If register_shape is True, the spatial shape of the input is appended to the "shapes" list of the "sampling"
context. This raises an error if the "sampling" context is not set or if its "shapes" entry is not a list.
"""
self.channels = channels
self.in_channels = channels
self.out_channels = channels
self.scale_factor = scale_factor
self.padding = padding
super().__init__(
Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=scale_factor,
padding=padding,
device=device,
dtype=dtype,
),
)
if padding == 0:
self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
shapes.append(x.shape[2:])
class Interpolate(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, shape: Size) -> Tensor:
return interpolate(x, shape)
class Upsample(Chain):
structural_attrs = ["channels", "upsample_factor"]
def __init__(
self,
channels: int,
upsample_factor: int | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Upsamples the input by the given scale factor.
If upsample_factor is None, the output shape is popped from the "shapes" list of the "sampling" context. This
raises an error if that context is not set or if its "shapes" list is empty (in that case, pair this layer with a
Downsample built with register_shape=True so shapes are pushed during the forward pass).
"""
self.channels = channels
self.upsample_factor = upsample_factor
super().__init__(
Parallel(
Identity(),
(
Lambda(self._get_static_shape)
if upsample_factor is not None
else UseContext(context="sampling", key="shapes").compose(lambda x: x.pop())
),
),
Interpolate(),
Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
)
def _get_static_shape(self, x: Tensor) -> Size:
assert self.upsample_factor is not None
return Size([size * self.upsample_factor for size in x.shape[2:]])
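An illustrative sketch of the static (context-free) paths of Downsample and Upsample; register_shape=False and an explicit upsample_factor avoid the "sampling" context entirely:

from torch import randn
from refiners.fluxion.layers import Downsample, Upsample

downsample = Downsample(channels=8, scale_factor=2, register_shape=False)
upsample = Upsample(channels=8, upsample_factor=2)
x = randn(1, 8, 16, 16)
assert downsample(x).shape == (1, 8, 8, 8)             # pad (0, 1, 0, 1) then stride-2 conv
assert upsample(downsample(x)).shape == (1, 8, 16, 16)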

View file

@@ -0,0 +1,262 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
from PIL import Image
from numpy import array, float32
from pathlib import Path
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
from torch.utils.hooks import RemovableHandle
if TYPE_CHECKING:
from refiners.fluxion.layers.module import Module
T = TypeVar("T")
E = TypeVar("E")
def norm(x: Tensor) -> Tensor:
return _norm(x) # type: ignore
def manual_seed(seed: int) -> None:
_manual_seed(seed)
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor:
return _pad(input=x, pad=pad, value=value) # type: ignore
def interpolate(x: Tensor, factor: float | Size, mode: str = "nearest") -> Tensor:
return (
_interpolate(x, scale_factor=factor, mode=mode)
if isinstance(factor, float | int)
else _interpolate(x, size=factor, mode=mode)
) # type: ignore
def bidirectional_mapping(mapping: Dict[str, str]) -> Dict[str, str]:
return {**mapping, **{value: key for key, value in mapping.items()}}
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
return tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(0)
def tensor_to_image(tensor: Tensor) -> Image.Image:
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
def safe_open(
path: Path | str,
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
device: Device | str = "cpu",
) -> dict[str, Tensor]:
framework_mapping = {
"pytorch": "pt",
"tensorflow": "tf",
"flax": "flax",
"numpy": "numpy",
}
return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore
def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
with safe_open(path=path, framework="pytorch") as tensors: # type: ignore
return tensors.metadata() # type: ignore
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore
BASIC_LAYERS: list[str] = [
"Conv1d",
"Conv2d",
"Conv3d",
"Linear",
"BatchNorm1d",
"BatchNorm2d",
"BatchNorm3d",
"LayerNorm",
"GroupNorm",
"Embedding",
"MaxPool2d",
"AvgPool2d",
"AdaptiveAvgPool2d",
]
ModelTypeShape = tuple[str, tuple[Size, ...]]
def is_basic_layer(module: "Module") -> bool:
return module.__class__.__name__ in BASIC_LAYERS
def get_module_signature(module: "Module") -> ModelTypeShape:
param_shapes = [p.shape for p in module.parameters()]
return (module.__class__.__name__, tuple(param_shapes))
def forward_order_of_execution(
module: "Module",
example_args: tuple[Any, ...],
key_skipper: Callable[[str], bool] | None = None,
) -> dict[ModelTypeShape, list[str]]:
key_skipper = key_skipper or (lambda _: False)
submodule_to_key: dict["Module", str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
def collect_execution_order_hook(layer: "Module", *_: Any):
layer_signature = get_module_signature(layer)
execution_order[layer_signature].append(submodule_to_key[layer])
hooks: list[RemovableHandle] = []
for name, submodule in module.named_modules():
if is_basic_layer(submodule) and not key_skipper(name):
submodule_to_key[submodule] = name
hook = submodule.register_forward_hook(collect_execution_order_hook)
hooks.append(hook)
with no_grad():
module(*example_args)
for hook in hooks:
hook.remove()
return dict(execution_order)
def print_side_by_side(
shape: ModelTypeShape,
source_keys: list[str],
target_keys: list[str],
):
print(f"{shape}")
max_len = max(len(source_keys), len(target_keys))
for i in range(max_len):
source_key = source_keys[i] if i < len(source_keys) else "---"
target_key = target_keys[i] if i < len(target_keys) else "---"
print(f"\t{source_key}\t{target_key}")
def verify_shape_match(
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
) -> bool:
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
shape_mismatched = False
for model_type_shape in model_type_shapes:
source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
print_side_by_side(model_type_shape, source_keys, target_keys)
return not shape_mismatched
def create_state_dict_mapping(
source_model: "Module",
target_model: "Module",
source_args: tuple[Any, ...],
target_args: tuple[Any, ...] | None = None,
source_key_skipper: Callable[[str], bool] | None = None,
target_key_skipper: Callable[[str], bool] | None = None,
) -> dict[str, str] | None:
if target_args is None:
target_args = source_args
source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
if not verify_shape_match(source_order, target_order):
return None
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
def convert_state_dict(
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
) -> dict[str, Tensor]:
converted_state_dict: dict[str, Tensor] = {}
for target_key in target_state_dict:
target_prefix, suffix = target_key.rsplit(".", 1)
source_prefix = state_dict_mapping[target_prefix]
source_key = ".".join([source_prefix, suffix])
converted_state_dict[target_key] = source_state_dict[source_key]
return converted_state_dict
def forward_store_outputs(
module: "Module",
example_args: tuple[Any, ...],
key_skipper: Callable[[str], bool] | None = None,
) -> list[tuple[str, Tensor]]:
key_skipper = key_skipper or (lambda _: False)
submodule_to_key: dict["Module", str] = {}
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
hooks: list[RemovableHandle] = []
for name, submodule in module.named_modules():
if is_basic_layer(submodule) and not key_skipper(name):
submodule_to_key[submodule] = name
hook = submodule.register_forward_hook(collect_execution_order_hook)
hooks.append(hook)
with no_grad():
module(*example_args)
for hook in hooks:
hook.remove()
return execution_order
def compare_models(
source_model: "Module",
target_model: "Module",
source_args: tuple[Any, ...],
target_args: tuple[Any, ...] | None = None,
source_key_skipper: Callable[[str], bool] | None = None,
target_key_skipper: Callable[[str], bool] | None = None,
threshold: float = 1e-5,
) -> bool:
if target_args is None:
target_args = source_args
source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
prev_source_key, prev_target_key = None, None
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
diff = norm(source_output - target_output).item()
if diff > threshold:
print(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
f" {target_key}, difference in norm: {diff}"
)
return False
prev_source_key, prev_target_key = source_key, target_key
return True
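To show how the conversion helpers above fit together, here is a hedged sketch that ports weights from a toy torch.nn.Sequential into an equivalent fluxion Chain; the two toy models are made up for the example, and keys are matched purely by execution order and parameter shapes:

from torch import nn, randn
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict

source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
target = fl.Chain(fl.Linear(4, 8), fl.ReLU(), fl.Linear(8, 4))

args = (randn(1, 4),)
mapping = create_state_dict_mapping(source, target, args)  # e.g. {"Linear_1": "0", "Linear_2": "2"}
assert mapping is not None
state_dict = convert_state_dict(source.state_dict(), target.state_dict(), mapping)
target.load_state_dict(state_dict)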

View file

View file

@@ -0,0 +1,250 @@
from torch import Tensor, arange, device as Device, dtype as DType
from refiners.fluxion.layers import (
ApproximateGeLU,
GeLU,
Linear,
LayerNorm,
Embedding,
Chain,
Sum,
SelfAttention,
Lambda,
Residual,
)
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
class PositionalTokenEncoder(Sum):
structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
def __init__(
self,
vocabulary_size: int,
embedding_dim: int,
positional_embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.vocabulary_size = vocabulary_size
self.positional_embedding_dim = positional_embedding_dim
super().__init__(
Embedding(
num_embeddings=vocabulary_size,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
Chain(
Lambda(self.get_position_ids),
Embedding(
num_embeddings=positional_embedding_dim,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
)
@property
def position_ids(self) -> Tensor:
return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]
class FeedForward(Chain):
structural_attrs = ["embedding_dim", "feedforward_dim"]
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
GeLU(),
Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
class TransformerLayer(Chain):
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
num_attention_heads: int = 1,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
super().__init__(
Residual(
LayerNorm(
normalized_shape=embedding_dim,
eps=layer_norm_eps,
device=device,
dtype=dtype,
),
SelfAttention(
embedding_dim=embedding_dim,
num_heads=num_attention_heads,
is_causal=True,
device=device,
dtype=dtype,
),
),
Residual(
LayerNorm(
normalized_shape=embedding_dim,
eps=layer_norm_eps,
device=device,
dtype=dtype,
),
FeedForward(
embedding_dim=embedding_dim,
feedforward_dim=feedforward_dim,
device=device,
dtype=dtype,
),
),
)
class CLIPTextEncoder(Chain):
structural_attrs = [
"embedding_dim",
"positional_embedding_dim",
"vocabulary_size",
"num_layers",
"num_attention_heads",
"feedforward_dim",
"layer_norm_eps",
"tokenizer",
]
def __init__(
self,
embedding_dim: int = 768,
positional_embedding_dim: int = 77,
vocabulary_size: int = 49408,
num_layers: int = 12,
num_attention_heads: int = 12,
feedforward_dim: int = 3072,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.embedding_dim = embedding_dim
self.positional_embedding_dim = positional_embedding_dim
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
self.tokenizer = CLIPTokenizer()
super().__init__(
PositionalTokenEncoder(
vocabulary_size=vocabulary_size,
embedding_dim=embedding_dim,
positional_embedding_dim=positional_embedding_dim,
device=device,
dtype=dtype,
),
*(
TransformerLayer(
embedding_dim=embedding_dim,
num_attention_heads=num_attention_heads,
feedforward_dim=feedforward_dim,
layer_norm_eps=layer_norm_eps,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
)
def encode(self, text: str) -> Tensor:
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
return self(tokens)
@property
def unconditional_text_embedding(self) -> Tensor:
return self.encode("")
class CLIPTextEncoderL(CLIPTextEncoder):
"""
CLIPTextEncoderL is the CLIP text encoder with the following parameters:
embedding_dim=768
num_layers=12
num_attention_heads=12
feedforward_dim=3072
We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=768,
num_layers=12,
num_attention_heads=12,
feedforward_dim=3072,
device=device,
dtype=dtype,
)
for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
parent.replace(old_module=gelu, new_module=ApproximateGeLU())
class CLIPTextEncoderH(CLIPTextEncoder):
"""
CLIPTextEncoderH is the CLIP text encoder with the following parameters:
embedding_dim=1024
num_layers=23
num_attention_heads=16
feedforward_dim=4096
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1024,
num_layers=23,
num_attention_heads=16,
feedforward_dim=4096,
device=device,
dtype=dtype,
)
class CLIPTextEncoderG(CLIPTextEncoder):
"""
CLIPTextEncoderG is the CLIP text encoder with the following parameters:
embedding_dim=1280
num_layers=32
num_attention_heads=20
feedforward_dim=5120
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1280,
num_layers=32,
num_attention_heads=20,
feedforward_dim=5120,
device=device,
dtype=dtype,
)
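A minimal illustrative sketch of the text encoder; the weights are randomly initialized here, so in practice converted CLIP weights would be loaded first (for instance with load_from_safetensors) before the embeddings are meaningful:

from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

text_encoder = CLIPTextEncoderL()
embedding = text_encoder.encode("a cute cat")  # tokenize, pad to 77 tokens, then forward
assert embedding.shape == (1, 77, 768)         # (batch, positional_embedding_dim, embedding_dim)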

View file

@@ -0,0 +1,108 @@
import gzip
from pathlib import Path
from functools import lru_cache
from itertools import islice
import re
from torch import Tensor, tensor
from refiners.fluxion import pad
class CLIPTokenizer:
def __init__(
self,
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
):
self.vocabulary_path = vocabulary_path
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
merge_tuples = [
tuple(merge.split())
for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
]
vocabulary = (
list(self.byte_to_unicode_mapping.values())
+ [v + "</w>" for v in self.byte_to_unicode_mapping.values()]
+ ["".join(merge) for merge in merge_tuples]
+ ["", ""]
)
self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
self.byte_pair_encoding_cache = {"": ""}
# Note: this regular expression does not support Unicode. It was changed so as
# to get rid of the dependence on the `regex` module. Unicode support could
# potentially be added back by leveraging the `\w` character class.
self.token_pattern = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
re.IGNORECASE,
)
self.start_of_text_token_id: int = 49406
self.end_of_text_token_id: int = 49407
def __call__(self, text: str, sequence_length: int) -> Tensor:
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
assert (
tokens.shape[1] <= sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
@lru_cache()
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
initial_byte_values = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
extra_unicode_values = (byte for byte in range(2**8) if byte not in initial_byte_values)
byte_values = initial_byte_values + list(extra_unicode_values)
unicode_values = [chr(value) for value in byte_values]
return dict(zip(byte_values, unicode_values))
def byte_pair_encoding(self, token: str) -> str:
if token in self.byte_pair_encoding_cache:
return self.byte_pair_encoding_cache[token]
def recursive_bpe(word: tuple[str, ...]) -> tuple[str, ...]:
if len(word) < 2:
return word
pairs = {(i, (word[i], word[i + 1])) for i in range(len(word) - 1)}
min_pair = min(
pairs,
key=lambda pair: self.byte_pair_encoding_ranks.get(pair[1], float("inf")),
)
if min_pair[1] not in self.byte_pair_encoding_ranks:
return word
new_word: list[str] = []
i = 0
while i < len(word):
if i == min_pair[0]:
new_word.append(min_pair[1][0] + min_pair[1][1])
i += 2
else:
new_word.append(word[i])
i += 1
return recursive_bpe(tuple(new_word))
word = tuple(token[:-1]) + (token[-1] + "</w>",)
result = " ".join(recursive_bpe(word))
self.byte_pair_encoding_cache[token] = result
return result
def encode(self, text: str, max_length: int | None = None) -> Tensor:
text = re.sub(r"\s+", " ", text.lower())
tokens = re.findall(self.token_pattern, text)
upper_bound = None
if max_length:
assert max_length >= 2
upper_bound = max_length - 2
encoded_tokens = islice(
(
self.token_to_id_mapping[subtoken]
for token in tokens
for subtoken in self.byte_pair_encoding(
"".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
).split(" ")
),
0,
upper_bound,
)
return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
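A short illustrative sketch of the tokenizer above; sequences are wrapped with the start/end-of-text tokens and right-padded with the end-of-text id:

from refiners.foundationals.clip.tokenizer import CLIPTokenizer

tokenizer = CLIPTokenizer()
tokens = tokenizer("a photo of a cat", sequence_length=77)
assert tokens.shape == (1, 77)
assert tokens[0, 0].item() == tokenizer.start_of_text_token_id  # 49406
assert tokens[0, -1].item() == tokenizer.end_of_text_token_id   # 49407 (also the padding value)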

View file

@@ -0,0 +1,201 @@
from typing import TypeVar
from torch import cat, float32, randn, tensor, device as Device, dtype as DType, Size, Tensor
from PIL import Image
import numpy as np
from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.fluxion.layers.module import Module
from refiners.foundationals.latent_diffusion.auto_encoder import (
LatentDiffusionAutoencoder,
)
from refiners.foundationals.clip.text_encoder import (
CLIPTextEncoder,
CLIPTextEncoderL,
)
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
from refiners.foundationals.latent_diffusion.unet import UNet
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
__all__ = [
"LatentDiffusionModel",
"UNet",
"DPMSolver",
"Scheduler",
"CLIPTextEncoder",
"LatentDiffusionAutoencoder",
]
class LatentDiffusionModel(Module):
def __init__(
self,
unet: UNet,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: CLIPTextEncoder,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = float32,
):
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device)
self.dtype = dtype
self.unet = unet.to(self.device, dtype=self.dtype)
self.lda = lda.to(self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(self.device, dtype=self.dtype)
self.scheduler = scheduler.to(self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int):
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
).to(device=device, dtype=dtype)
def init_latents(
self,
size: tuple[int, int],
init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None,
) -> Tensor:
if noise is None:
height, width = size
noise = randn(1, 4, height // 8, width // 8, device=self.device)
assert list(noise.shape[2:]) == [
size[0] // 8,
size[1] // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(init_image.resize(size))
return self.scheduler.add_noise(encoded_image, noise, self.steps[first_step])
@property
def steps(self) -> list[int]:
return self.scheduler.steps
@property
def timestep_embeddings(self) -> Tensor:
return self.timestep_encoder(self.scheduler.timesteps)
@property
def unconditional_clip_text_embeddings(self) -> Tensor:
return self.clip_text_encoder.unconditional_text_embedding
def compute_text_embedding(self, text: str) -> Tensor:
return self.clip_text_encoder.encode(text)
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
negative_clip_text_embedding: Tensor | None = None,
condition_scale: float = 7.5,
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(0)
self.unet.set_timestep(timestep)
negative_clip_text_embedding = (
self.clip_text_encoder.unconditional_text_embedding
if negative_clip_text_embedding is None
else negative_clip_text_embedding
)
clip_text_embeddings = cat((negative_clip_text_embedding, clip_text_embedding))
self.unet.set_clip_text_embedding(clip_text_embeddings)
latents = cat((x, x)) # for classifier-free guidance
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
return self.scheduler(x, noise=noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
device=self.device,
dtype=self.dtype,
)
class StableDiffusion_1(LatentDiffusionModel):
def __init__(
self,
unet: UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
):
unet = unet or UNet(in_channels=4, clip_embedding_dim=768)
lda = lda or LatentDiffusionAutoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
super().__init__(
unet,
lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
class StableDiffusion_1_Inpainting(StableDiffusion_1):
def __init__(
self,
unet: UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(unet, lda, clip_text_encoder, scheduler, device, dtype)
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
negative_clip_text_embedding: Tensor | None = None,
condition_scale: float = 7.5,
):
assert self.mask_latents is not None
assert self.target_image_latents is not None
x = cat((x, self.mask_latents, self.target_image_latents), dim=1)
return super().forward(x, step, clip_text_embedding, negative_clip_text_embedding, condition_scale)
def set_inpainting_conditions(
self,
target_image: Image.Image,
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
target_image = target_image.convert("RGB")
mask = mask.convert("L")
mask_tensor = tensor(np.array(mask).astype(np.float32) / 255.0).to(self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype)
self.mask_latents = interpolate(mask_tensor, Size(latents_size))
init_image_tensor = image_to_tensor(target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)
self.target_image_latents = self.lda.encode(masked_init_image)
return self.mask_latents, self.target_image_latents
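A minimal text-to-image sketch using the API above, assuming the UNet, autoencoder and text-encoder weights have been loaded separately (weight loading is handled elsewhere); with freshly initialized weights the loop runs but only produces noise:

```python
import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd = StableDiffusion_1(device="cpu")  # use "cuda" in practice
sd.set_num_inference_steps(30)

with torch.no_grad():
    clip_text_embedding = sd.compute_text_embedding("a cute cat")
    x = sd.init_latents(size=(512, 512))            # (1, 4, 64, 64) noise latents
    for step in sd.steps:
        x = sd(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
    image = sd.lda.decode_latents(x)                # back to a PIL image
image.save("cat.png")
```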

View file

@@ -0,0 +1,230 @@
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import (
Chain,
Conv2d,
GroupNorm,
Identity,
SiLU,
Downsample,
Upsample,
Sum,
SelfAttention2d,
Slicing,
)
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from torch import Tensor, device as Device, dtype as DType
from PIL import Image
class Resnet(Sum):
structural_attrs = ["in_channels", "out_channels"]
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.in_channels = in_channels
self.out_channels = out_channels
shortcut = (
Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else Identity()
)
super().__init__(
shortcut,
Chain(
GroupNorm(channels=in_channels, num_groups=num_groups, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
GroupNorm(channels=out_channels, num_groups=num_groups, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
)
class Encoder(Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
resnet_sizes: list[int] = [128, 256, 512, 512, 512]
input_channels: int = 3
latent_dim: int = 8
resnet_layers: list[Chain] = [
Chain(
[
Resnet(
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
out_channels=resnet_sizes[i],
device=device,
dtype=dtype,
),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
for i in range(len(resnet_sizes))
]
for _, layer in zip(range(3), resnet_layers):
channels: int = layer[-1].out_channels # type: ignore
layer.append(Downsample(channels=channels, scale_factor=2, device=device, dtype=dtype))
attention_layer = Sum(
Identity(),
Chain(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype),
),
)
resnet_layers[-1].insert_after_type(Resnet, attention_layer)
super().__init__(
Conv2d(
in_channels=input_channels,
out_channels=resnet_sizes[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
Chain(*resnet_layers),
Chain(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=resnet_sizes[-1],
out_channels=latent_dim,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
Chain(
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
Slicing(dim=1, start=0, length=4),
),
)
def init_context(self) -> Contexts:
return {"sampling": {"shapes": []}}
class Decoder(Chain):
structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
self.latent_dim: int = 4
self.output_channels: int = 3
resnet_sizes = self.resnet_sizes[::-1]
resnet_layers: list[Chain] = [
(
Chain(
[
Resnet(
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
out_channels=resnet_sizes[i],
device=device,
dtype=dtype,
),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
if i > 0
else Chain(
[
Resnet(in_channels=resnet_sizes[0], out_channels=resnet_sizes[i], device=device, dtype=dtype),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
)
for i in range(len(resnet_sizes))
]
attention_layer = Sum(
Identity(),
Chain(
GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype),
),
)
resnet_layers[0].insert(1, attention_layer)
for _, layer in zip(range(3), resnet_layers[1:]):
channels: int = layer[-1].out_channels
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
super().__init__(
Conv2d(
in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=1, device=device, dtype=dtype
),
Conv2d(
in_channels=self.latent_dim,
out_channels=resnet_sizes[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
Chain(*resnet_layers),
Chain(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=resnet_sizes[-1],
out_channels=self.output_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
)
class LatentDiffusionAutoencoder(Chain):
structural_attrs = ["encoder_scale"]
def __init__(
self,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.encoder_scale: float = 0.18215
super().__init__(
Encoder(device=device, dtype=dtype),
Decoder(device=device, dtype=dtype),
)
def encode(self, x: Tensor) -> Tensor:
encoder = self[0]
x = self.encoder_scale * encoder(x)
return x
def decode(self, x: Tensor) -> Tensor:
decoder = self[1]
x = decoder(x / self.encoder_scale)
return x
def encode_image(self, image: Image.Image) -> Tensor:
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
x = 2 * x - 1
return self.encode(x)
def decode_latents(self, x: Tensor) -> Image.Image:
x = self.decode(x)
x = (x + 1) / 2
return tensor_to_image(x)
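A small sketch of the autoencoder round trip, assuming pretrained weights have been loaded into the model (with untrained weights it only demonstrates the shapes and scaling):

```python
from PIL import Image
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

lda = LatentDiffusionAutoencoder(device="cpu")
image = Image.new("RGB", (512, 512), color="white")   # stand-in for a real photo

latents = lda.encode_image(image)            # shape (1, 4, 64, 64), scaled by encoder_scale
reconstruction = lda.decode_latents(latents) # back to a 512x512 PIL image
```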

View file

@@ -0,0 +1,150 @@
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
from refiners.foundationals.latent_diffusion.unet import DownBlocks, MiddleBlock, ResidualBlock, TimestepEncoder
from refiners.adapters.range_adapter import RangeAdapter2d
from typing import cast, Iterable
from torch import Tensor, device as Device, dtype as DType
class ConditionEncoder(Chain):
"""Encode an image to be used as a condition for Controlnet.
Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
"""
structural_attrs = ["out_channels"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.out_channels = (16, 32, 96, 256)
super().__init__(
Chain(
Conv2d(
in_channels=3,
out_channels=self.out_channels[0],
kernel_size=3,
stride=1,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
),
*(
Chain(
Conv2d(
in_channels=self.out_channels[i],
out_channels=self.out_channels[i],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
Conv2d(
in_channels=self.out_channels[i],
out_channels=self.out_channels[i + 1],
kernel_size=3,
stride=2,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
)
for i in range(len(self.out_channels) - 1)
),
Conv2d(
in_channels=self.out_channels[-1],
out_channels=320,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
)
class Controlnet(Passthrough):
structural_attrs = ["name", "scale"]
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
stored in the context.
It has to use the same context as the UNet: `unet` and `sampling`.
"""
self.name = name
self.scale: float = 1.0
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
)
# We run the condition encoder at each step. Caching the result
# is not worth it as subsequent runs take virtually no time (FG-374).
self.DownBlocks[0].append(
Sum(
Identity(),
Chain(
UseContext("controlnet", f"condition_{name}"),
ConditionEncoder(device=device, dtype=dtype),
),
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key=f"timestep_embedding_{self.name}",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
assert hasattr(block[0], "out_channels"), (
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
f" {block[0]} does not."
)
out_channels: int = block[0].out_channels
block.append(
Passthrough(
Conv2d(
in_channels=out_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype
),
Lambda(self._store_nth_residual(n)),
)
)
self.MiddleBlock.append(
Passthrough(
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
Lambda(self._store_nth_residual(12)),
)
)
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 13},
"sampling": {"shapes": []},
"controlnet": {f"condition_{self.name}": None},
"range_adapter": {f"timestep_embedding_{self.name}": None},
}
def _store_nth_residual(self, n: int):
def _store_residual(x: Tensor):
residuals = self.use_context("unet")["residuals"]
residuals[n] = residuals[n] + x * self.scale
return x
return _store_residual
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("controlnet", {f"condition_{self.name}": condition})
def set_scale(self, scale: float) -> None:
self.scale = scale
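A hedged sketch of how this module can be wired into Stable Diffusion 1.5. Since `Controlnet` is a `Passthrough` sharing the UNet's `unet` and `sampling` contexts, one plausible wiring, used here as an assumption, is to prepend it to the UNet chain so its residuals are accumulated before the down blocks run; the exact wiring used elsewhere in the repository may differ.

```python
import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.controlnet import Controlnet

sd = StableDiffusion_1(device="cpu")
controlnet = Controlnet(name="canny", device="cpu")
sd.unet.insert(0, controlnet)  # hypothetical wiring, see note above

condition = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed condition image
controlnet.set_controlnet_condition(condition)
controlnet.set_scale(0.8)  # how strongly the stored residuals influence the UNet
# ...then run the usual denoising loop with `sd`.
```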

View file

@@ -0,0 +1,203 @@
from torch import Tensor, Size, device as Device, dtype as DType
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import (
Identity,
Flatten,
Unflatten,
Transpose,
Chain,
Parallel,
LayerNorm,
Attention,
Sum,
UseContext,
Linear,
GLU,
GeLU,
GroupNorm,
Conv2d,
SelfAttention,
SetContext,
)
class CrossAttentionBlock(Chain):
structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
def __init__(
self,
embedding_dim: int,
context_embedding_dim: int,
context_key: str,
num_heads: int = 1,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.context_embedding_dim = context_embedding_dim
self.context = "cross_attention_block"
self.context_key = context_key
self.num_heads = num_heads
self.use_bias = use_bias
super().__init__(
Sum(
Identity(),
Chain(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
SelfAttention(
embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype
),
),
),
Sum(
Identity(),
Chain(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
Parallel(
Identity(),
UseContext(context=self.context, key=context_key),
UseContext(context=self.context, key=context_key),
),
Attention(
embedding_dim=embedding_dim,
num_heads=num_heads,
key_embedding_dim=context_embedding_dim,
value_embedding_dim=context_embedding_dim,
use_bias=use_bias,
device=device,
dtype=dtype,
),
),
),
Sum(
Identity(),
Chain(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype),
GLU(GeLU()),
Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
),
),
)
class StatefulFlatten(Chain):
structural_attrs = ["start_dim", "end_dim"]
def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
self.start_dim = start_dim
self.end_dim = end_dim
super().__init__(
SetContext(context=context, key=key, callback=self.push),
Flatten(start_dim=start_dim, end_dim=end_dim),
)
def push(self, sizes: list[Size], x: Tensor) -> None:
sizes.append(
x.shape[slice(self.start_dim, self.end_dim + 1 if self.end_dim >= 0 else x.ndim + self.end_dim + 1)]
)
class CrossAttentionBlock2d(Sum):
structural_attrs = [
"channels",
"in_channels",
"out_channels",
"context_embedding_dim",
"num_attention_heads",
"num_attention_layers",
"num_groups",
"context_key",
"use_linear_projection",
"projection_type",
]
def __init__(
self,
channels: int,
context_embedding_dim: int,
context_key: str,
num_attention_heads: int = 1,
num_attention_layers: int = 1,
num_groups: int = 32,
use_bias: bool = True,
use_linear_projection: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert channels % num_attention_heads == 0, "in_channels must be divisible by num_attention_heads"
self.channels = channels
self.in_channels = channels
self.out_channels = channels
self.context_embedding_dim = context_embedding_dim
self.num_attention_heads = num_attention_heads
self.num_attention_layers = num_attention_layers
self.num_groups = num_groups
self.context_key = context_key
self.use_linear_projection = use_linear_projection
self.projection_type = "Linear" if use_linear_projection else "Conv2d"
in_block = (
Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
)
if use_linear_projection
else Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),
)
)
out_block = (
Chain(
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
Transpose(1, 2),
Parallel(
Identity(),
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
),
Unflatten(dim=2),
)
if use_linear_projection
else Chain(
Transpose(1, 2),
Parallel(
Identity(),
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
),
Unflatten(dim=2),
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
)
)
super().__init__(
Identity(),
Chain(
in_block,
Chain(
CrossAttentionBlock(
embedding_dim=channels,
context_embedding_dim=context_embedding_dim,
context_key=context_key,
num_heads=num_attention_heads,
use_bias=use_bias,
device=device,
dtype=dtype,
)
for _ in range(num_attention_layers)
),
out_block,
),
)
def init_context(self) -> Contexts:
return {"flatten": {"sizes": []}}

View file

@@ -0,0 +1,101 @@
from enum import Enum
from pathlib import Path
from torch import Tensor, device as Device
from torch.nn import Parameter as TorchParameter
from refiners.adapters.lora import LoraAdapter, load_lora_weights
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
class LoraTarget(str, Enum):
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def get_lora_rank(weights: list[Tensor]) -> int:
ranks: set[int] = {w.shape[1] for w in weights[0::2]}
assert len(ranks) == 1
return ranks.pop()
def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
for layer in module.layers(layer_type=target.get_class()):
for linear, parent in layer.walk(fl.Linear):
adapter = LoraAdapter(
target=linear,
rank=rank,
scale=scale,
device=module.device,
dtype=module.dtype,
)
adapter.inject(parent)
class LoraWeights:
"""A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
def __init__(self, checkpoint_path: Path | str, device: Device | str):
self.metadata = load_metadata_from_safetensors(checkpoint_path)
self.tensors = load_from_safetensors(checkpoint_path, device=device)
def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
for meta_key, meta_value in self.metadata.items():
match meta_key:
case "unet_targets":
# TODO: support this transparently
if any(isinstance(module, Controlnet) for module in sd.unet):
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
model = sd.unet
key_prefix = "unet."
case "text_encoder_targets":
model = sd.clip_text_encoder
key_prefix = "text_encoder."
case "lda_targets":
model = sd.lda
key_prefix = "lda."
case _:
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
# TODO(FG-487): support loading multiple LoRA-s
if any(model.layers(LoraAdapter)):
raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
lora_weights = [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]
assert len(lora_weights) % 2 == 0
rank = get_lora_rank(lora_weights)
for target in meta_value.split(","):
apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
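A sketch of applying a checkpoint, assuming this file is importable as `refiners.foundationals.latent_diffusion.lora` (the file path is not shown here) and that a compatible safetensors checkpoint with the metadata described above exists at the placeholder path:

```python
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import LoraWeights  # assumed module path

sd = StableDiffusion_1(device="cpu")
lora = LoraWeights("path/to/lora.safetensors", device="cpu")  # placeholder path
lora.patch(sd, scale=1.0)  # injects LoraAdapter layers into the targeted sub-models
```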

View file

@@ -0,0 +1,11 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]

View file

@@ -0,0 +1,41 @@
from torch import Tensor, device as Device, arange, sqrt
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
class DDIM(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
) -> None:
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
self.timesteps = self._generate_timesteps()
def _generate_timesteps(self) -> Tensor:
"""
Generates decreasing timesteps with 'leading' spacing and an offset of 1,
matching the diffusers settings for the DDIM scheduler in Stable Diffusion 1.5.
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
)
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
self.previous_scale_factor = previous_scale_factor
return denoised_x
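Restated in math, the `__call__` above is the deterministic DDIM update (η = 0), with √ᾱ_t the cumulative scale factor and ε_θ the predicted noise:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}},
\qquad
x_{t_{\mathrm{prev}}} = \sqrt{\bar\alpha_{t_{\mathrm{prev}}}}\,\hat{x}_0
  + \sqrt{1-\bar\alpha_{t_{\mathrm{prev}}}}\,\epsilon_\theta
```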

View file

@@ -0,0 +1,75 @@
from torch import Tensor, device as Device, randn, arange, Generator, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
class DDPM(Scheduler):
"""
Denoising Diffusion Probabilistic Models (DDPM) is a diffusion model that uses a dedicated timestep schedule
and ancestral sampling (re-injecting a controlled amount of noise at each step) for the reverse diffusion process.
"""
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
) -> None:
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Generate the next step in the diffusion process.
This method adjusts the input data using added noise and an estimate of the denoised data, based on the current
step in the diffusion process. This adjusted data forms the next step in the diffusion process.
1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original
data and noise to the new step.
2. An estimate of the denoised data (`estimated_denoised_data`) is generated.
3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and
`current_data_coeff`) that balance their contribution to the denoised data for the next step.
4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated
denoised data and current data, adjusted by their respective coefficients.
5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on
the cumulative scaling factor and the current factor.
The output is the new data step for the next stage in the diffusion process.
"""
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < len(self.timesteps) - 1
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
),
)
current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], (
(self.scale_factors.cumprod(0))[previous_timestep]
if step < len(self.timesteps) - 1
else tensor(1, device=self.device)
)
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = (
x - (1 - current_cumulative_factor) ** 0.5 * noise
) / current_cumulative_factor**0.5
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
1 - current_cumulative_factor
)
current_data_coeff = (
current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor)
)
denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x
if step < len(self.timesteps) - 1:
variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor)
denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn(
x.shape, device=x.device, dtype=x.dtype, generator=generator
)
return denoised_x
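The numbered steps in the docstring correspond to the standard DDPM posterior, with ᾱ_t the cumulative scale factor and α_t = ᾱ_t / ᾱ_{t_prev} the per-step factor computed above:

```latex
\hat{x}_0 = \operatorname{clamp}\!\left(\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}},\,-1,\,1\right),
\qquad
\mu_t = \frac{\sqrt{\bar\alpha_{t_{\mathrm{prev}}}}\,(1-\alpha_t)}{1-\bar\alpha_t}\,\hat{x}_0
      + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t_{\mathrm{prev}}})}{1-\bar\alpha_t}\,x_t

x_{t_{\mathrm{prev}}} = \mu_t + \tilde\sigma_t\,\epsilon,
\qquad
\tilde\sigma_t^2 = \frac{1-\bar\alpha_{t_{\mathrm{prev}}}}{1-\bar\alpha_t}\,(1-\alpha_t),
\qquad
\epsilon \sim \mathcal{N}(0, I)\ \text{(omitted on the final step)}
```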

View file

@@ -0,0 +1,111 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp
from collections import deque
class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now.
"""
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
):
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
device=device,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
device=self.device,
).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
)
previous_ratio, current_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
exp_factor = exp(-(previous_ratio - current_ratio))
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = (
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
self.timesteps[step],
self.timesteps[step - 1],
)
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[current_timestep],
self.signal_to_noise_ratios[next_timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = (
self.noise_std[previous_timestep],
self.noise_std[current_timestep],
)
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
exp_neg_factor = exp(-(previous_ratio - current_ratio))
x_t = (
(previous_std / current_std) * x
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
)
return x_t
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs).
"""
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
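In the notation of the DPM-Solver++ paper, with α_t = √ᾱ_t, σ_t = √(1−ᾱ_t), λ_t = log(α_t/σ_t) and h = λ_{t_prev} − λ_t, the first-order update above reads as follows; the second-order update adds a correction term proportional to the difference of the last two data estimates:

```latex
x_{t_{\mathrm{prev}}} = \frac{\sigma_{t_{\mathrm{prev}}}}{\sigma_t}\,x_t
  \;-\; \alpha_{t_{\mathrm{prev}}}\left(e^{-h} - 1\right)\hat{x}_0
```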

View file

@@ -0,0 +1,95 @@
from abc import abstractmethod
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
from typing import TypeVar
T = TypeVar("T", bound="Scheduler")
class Scheduler:
"""
A base class for creating a diffusion model scheduler.
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates,
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
"""
timesteps: Tensor
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.device: Device = Device(device)
self.dtype: DType = dtype
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.scale_factors = (
1.0
- linspace(
start=initial_diffusion_rate**0.5,
end=final_diffusion_rate**0.5,
steps=num_train_timesteps,
dtype=dtype,
)
** 2
)
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `step` index.
This method should be overridden by subclasses to implement the specific diffusion process.
"""
...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
"""
Generates a tensor of timesteps.
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
"""
...
@property
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def add_noise(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1)
noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1)
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
self.timesteps = self.timesteps.to(device)
if dtype is not None:
self.dtype = dtype
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
self.noise_std = self.noise_std.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
return self
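For reference, the quantities precomputed in `__init__` correspond to the usual DDPM notation, with β_t linearly spaced in √β between the initial and final diffusion rates:

```latex
\texttt{scale\_factors} = \alpha_t = 1 - \beta_t,\quad
\bar\alpha_t = \prod_{s \le t} \alpha_s,\quad
\texttt{cumulative\_scale\_factors} = \sqrt{\bar\alpha_t},\quad
\texttt{noise\_std} = \sigma_t = \sqrt{1-\bar\alpha_t},\quad
\texttt{signal\_to\_noise\_ratios} = \lambda_t = \log\sqrt{\bar\alpha_t} - \log\sigma_t

\texttt{add\_noise}:\quad x_t = \sqrt{\bar\alpha_t}\,x_0 + \sigma_t\,\epsilon
```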

View file

@@ -0,0 +1,291 @@
from typing import cast
from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.unet import ResidualAccumulator, ResidualBlock, ResidualConcatenator
from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding
class TextTimeEmbedding(fl.Chain):
structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
self.time_ids_embedding_dim = 256
self.text_time_embedding_dim = 2816
super().__init__(
fl.Concatenate(
fl.UseContext(context="diffusion", key="pooled_text_embedding"),
fl.Chain(
fl.UseContext(context="diffusion", key="time_ids"),
fl.Unsqueeze(dim=-1),
fl.Lambda(func=self.compute_sinuosoidal_embedding),
fl.Reshape(-1),
),
dim=1,
),
fl.Linear(
in_features=self.text_time_embedding_dim,
out_features=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
fl.SiLU(),
fl.Linear(
in_features=self.timestep_embedding_dim,
out_features=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
)
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor:
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim).to(dtype=self.dtype)
class TimestepEncoder(fl.Passthrough):
structural_attrs = ["timestep_embedding_dim"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
super().__init__(
fl.Sum(
fl.Chain(
fl.UseContext(context="diffusion", key="timestep"),
RangeEncoder(
sinuosidal_embedding_dim=320,
embedding_dim=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
),
TextTimeEmbedding(device=device, dtype=dtype),
),
fl.SetContext(context="range_adapter", key="timestep_embedding"),
)
class SDXLCrossAttention(CrossAttentionBlock2d):
structural_attrs = ["channels", "num_attention_layers", "num_attention_heads"]
def __init__(
self,
channels: int,
num_attention_layers: int = 1,
num_attention_heads: int = 10,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
channels=channels,
context_embedding_dim=2048,
context_key="clip_text_embedding",
num_attention_layers=num_attention_layers,
num_attention_heads=num_attention_heads,
use_bias=False,
use_linear_projection=True,
device=device,
dtype=dtype,
)
class DownBlocks(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
in_block = fl.Chain(
fl.Conv2d(in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype)
)
first_blocks = [
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype),
),
]
second_blocks = [
fl.Chain(
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype),
),
]
third_blocks = [
fl.Chain(
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
]
super().__init__(
in_block,
*first_blocks,
*second_blocks,
*third_blocks,
)
class UpBlocks(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
first_blocks = [
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
]
second_blocks = [
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
fl.Upsample(channels=640, device=device, dtype=dtype),
),
]
third_blocks = [
fl.Chain(
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
),
]
super().__init__(
*first_blocks,
*second_blocks,
*third_blocks,
)
class MiddleBlock(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
)
class OutputBlock(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
)
class SDXLUNet(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(
TimestepEncoder(device=device, dtype=dtype),
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
fl.Residual(fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1])),
UpBlocks(device=device, dtype=dtype),
OutputBlock(device=device, dtype=dtype),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
block.append(module=ResidualAccumulator(n=n))
for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
block.insert(index=0, module=ResidualConcatenator(n=-n - 2))
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 10},
"diffusion": {"timestep": None, "time_ids": None, "pooled_text_embedding": None},
"range_adapter": {"timestep_embedding": None},
"sampling": {"shapes": []},
}
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
def set_timestep(self, timestep: Tensor) -> None:
self.set_context(context="diffusion", value={"timestep": timestep})
def set_time_ids(self, time_ids: Tensor) -> None:
self.set_context(context="diffusion", value={"time_ids": time_ids})
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
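A plumbing sketch for a single forward pass, assuming this file is importable (e.g. as `refiners.foundationals.latent_diffusion.sdxl_unet`, an assumed path); with untrained weights it only exercises the shapes and contexts:

```python
import torch
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet  # assumed module path

unet = SDXLUNet(in_channels=4, device="cpu")
unet.set_timestep(torch.tensor([999]))
unet.set_time_ids(torch.zeros(2, 6))                    # six SDXL conditioning ids per sample
unet.set_pooled_text_embedding(torch.randn(2, 1280))    # pooled text embedding
unet.set_clip_text_embedding(torch.randn(2, 77, 2048))  # cond + uncond text embeddings
x = torch.randn(2, 4, 128, 128)                         # 1024x1024 images in latent space
prediction = unet(x)                                    # predicted noise, same shape as x
```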

View file

@@ -0,0 +1,130 @@
from refiners.fluxion.layers import (
Passthrough,
Lambda,
Chain,
Concatenate,
UseContext,
SelfAttention,
SetContext,
Identity,
Parallel,
)
from refiners.adapters.adapter import Adapter
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from torch import Tensor
class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
def __init__(self, target: SelfAttention, context: str) -> None:
self.context = context
with self.setup_adapter(target):
super().__init__(SetContext(self.context, "norm"), target)
class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
def __init__(
self,
target: SelfAttention,
context: str,
sai: "SelfAttentionInjection",
) -> None:
self.context = context
self._sai = [sai] # only to support setting `style_cfg` dynamically
sa_guided = target.structural_copy()
assert isinstance(sa_guided[0], Parallel)
sa_guided.replace(
sa_guided[0],
Parallel(
Identity(),
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
),
)
with self.setup_adapter(target):
super().__init__(
Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
Lambda(self.compute_averaged_unconditioned_x),
)
def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
style_cfg = self._sai[0].style_cfg
x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
return x
class SelfAttentionInjection(Passthrough):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, unet: UNet, style_cfg: float = 0.5) -> None:
# `style_cfg` is the weight of the guide in the unconditioned diffusion branch.
# The sd-webui reference implementation recommends a value of 0.5.
self.style_cfg = style_cfg
self._adapters: list[ReferenceOnlyControlAdapter] = []
self._unet = [unet]
guide_unet = unet.structural_copy()
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
sa = attention_block.find(SelfAttention)
assert sa is not None and sa.parent is not None
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
unet.set_context(f"self_attention_context_{i}", {"norm": None})
sa = attention_block.find(SelfAttention)
assert sa is not None and sa.parent is not None
self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
super().__init__(
Lambda(self.copy_diffusion_context),
UseContext("self_attention_injection", "guide"),
guide_unet,
Lambda(self.restore_diffusion_context),
)
@property
def unet(self):
return self._unet[0]
def inject(self) -> None:
assert self not in self._unet[0], f"{self} is already injected"
for adapter in self._adapters:
adapter.inject()
self.unet.insert(0, self)
def eject(self) -> None:
assert self.unet[0] == self, f"{self} is not the first element of target UNet"
for adapter in self._adapters:
adapter.eject()
self.unet.pop(0)
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("self_attention_injection", {"guide": condition})
def copy_diffusion_context(self, x: Tensor) -> Tensor:
# This function avoids disrupting the accumulation of residuals in the UNet (e.g. when ControlNets are used).
self.set_context(
"self_attention_residuals_buffer",
{"buffer": self.use_context("unet")["residuals"]},
)
self.set_context(
"unet",
{"residuals": [0.0] * 13},
)
return x
def restore_diffusion_context(self, x: Tensor) -> Tensor:
self.set_context(
"unet",
{
"residuals": self.use_context("self_attention_residuals_buffer")["buffer"],
},
)
return x
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
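A sketch of enabling reference-only control on an SD 1.5 UNet. The module path is an assumption (the file name is not shown here); in practice the guide would come from encoding a reference image with the autoencoder and noising it to the current step.

```python
import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.self_attention_injection import (  # assumed module path
    SelfAttentionInjection,
)

sd = StableDiffusion_1(device="cpu")
sai = SelfAttentionInjection(sd.unet, style_cfg=0.5)
sai.inject()  # prepends itself to the UNet and patches its self-attention layers

guide = torch.randn(1, 4, 64, 64)  # stand-in for noised reference-image latents
sai.set_controlnet_condition(guide)
# ...run the usual denoising loop with `sd`, then:
sai.eject()  # restore the original UNet
```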

View file

@@ -0,0 +1,307 @@
from typing import cast, Iterable
from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
class TimestepEncoder(fl.Passthrough):
def __init__(
self,
context_key: str = "timestep_embedding",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext("diffusion", "timestep"),
RangeEncoder(320, 1280, device=device, dtype=dtype),
fl.SetContext("range_adapter", context_key),
)
class ResidualBlock(fl.Sum):
structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class CLIPLCrossAttention(CrossAttentionBlock2d):
def __init__(
self,
channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
channels=channels,
context_embedding_dim=768,
context_key="clip_text_embedding",
num_attention_heads=8,
use_bias=False,
device=device,
dtype=dtype,
)
class DownBlocks(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(
self,
in_channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.in_channels = in_channels
super().__init__(
fl.Chain(
fl.Conv2d(
in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype
)
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=1280, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
),
)
class UpBlocks(fl.Chain):
def __init__(
self,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
fl.Upsample(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
)
class MiddleBlock(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
)
class ResidualAccumulator(fl.Passthrough):
structural_attrs = ["n"]
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
structural_attrs = ["n"]
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)
class UNet(fl.Chain):
structural_attrs = ["in_channels", "clip_embedding_dim"]
def __init__(
self,
in_channels: int,
clip_embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.in_channels = in_channels
self.clip_embedding_dim = clip_embedding_dim
super().__init__(
TimestepEncoder(device=device, dtype=dtype),
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
fl.Sum(
fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
MiddleBlock(device=device, dtype=dtype),
),
UpBlocks(device=device, dtype=dtype),
fl.Chain(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=320,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
device=device,
dtype=dtype,
),
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
block.append(ResidualAccumulator(n))
for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
block.insert(0, ResidualConcatenator(-n - 2))
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 13},
"diffusion": {"timestep": None},
"range_adapter": {"timestep_embedding": None},
"sampling": {"shapes": []},
}
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
def set_timestep(self, timestep: Tensor) -> None:
self.set_context("diffusion", {"timestep": timestep})

0
src/refiners/py.typed Normal file
View file

View file

@@ -0,0 +1,17 @@
from importlib import import_module
from importlib.metadata import requires
import sys
refiners_requires = requires("refiners")
assert refiners_requires is not None
for dep in filter(lambda r: r.endswith('extra == "training"'), refiners_requires):
try:
import_module(dep.split(" ")[0])
except ImportError:
print(
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"
" refiners[training]`",
file=sys.stderr,
)
sys.exit(1)

View file

@@ -0,0 +1,186 @@
from typing import TYPE_CHECKING, Generic, Iterable, Any, TypeVar
from torch import tensor
from torch.nn import Parameter
from loguru import logger
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
__all__ = [
"Callback",
"GradientNormClipping",
"GradientValueClipping",
"ClockCallback",
"GradientNormLogging",
"MonitorLoss",
]
def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
"""
Clips the gradient norm of the given parameters, similarly to `torch.nn.utils.clip_grad_norm_`, using a precomputed `total_norm`.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
for gradient in gradients:
gradient.mul_(other=clip_coefficient) # type: ignore
def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None:
"""
    Clips the gradients of the parameters of a given model element-wise, similarly to `clip_grad_value_`.
"""
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to clip."
for gradient in gradients:
gradient.clamp_(min=-clip_value, max=clip_value)
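# Illustrative sketch (not part of the original file): value clipping on a dummy
# parameter carrying a synthetic gradient; the sizes and values are arbitrary.
def _clip_gradient_value_example() -> None:
    from torch import randn
    param = Parameter(randn(8))
    param.grad = randn(8) * 10  # synthetic, deliberately large gradient
    clip_gradient_value(parameters=[param], clip_value=1.0)
    assert param.grad is not None and param.grad.abs().max() <= 1.0  # clamped into [-1, 1]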
T = TypeVar("T")
class Callback(Generic[T]):
def on_train_begin(self, trainer: T) -> None:
...
def on_train_end(self, trainer: T) -> None:
...
def on_epoch_begin(self, trainer: T) -> None:
...
def on_epoch_end(self, trainer: T) -> None:
...
def on_batch_begin(self, trainer: T) -> None:
...
def on_batch_end(self, trainer: T) -> None:
...
def on_backward_begin(self, trainer: T) -> None:
...
def on_backward_end(self, trainer: T) -> None:
...
def on_optimizer_step_begin(self, trainer: T) -> None:
...
def on_optimizer_step_end(self, trainer: T) -> None:
...
def on_compute_loss_begin(self, trainer: T) -> None:
...
def on_compute_loss_end(self, trainer: T) -> None:
...
def on_evaluate_begin(self, trainer: T) -> None:
...
def on_evaluate_end(self, trainer: T) -> None:
...
def on_lr_scheduler_step_begin(self, trainer: T) -> None:
...
def on_lr_scheduler_step_end(self, trainer: T) -> None:
...
def on_checkpoint_save(self, trainer: T) -> None:
...
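# Illustrative sketch (not part of the original file): every hook above is a no-op,
# so a custom callback only overrides the events it cares about.
class _LogOnEvaluate(Callback["Trainer[BaseConfig, Any]"]):
    def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        logger.info("custom callback: evaluation is about to start")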
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.reset()
logger.info(f"""Starting training for a total of:
{trainer.clock.num_steps} steps.
{trainer.clock.num_epochs} epochs.
{trainer.clock.num_iterations} iterations.
""")
trainer.clock.start_timer()
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.stop_timer()
logger.info(f"""Training took:
{trainer.clock.time_elapsed} seconds.
{trainer.clock.iteration} iterations.
{trainer.clock.epoch} epochs.
{trainer.clock.step} steps.
""")
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Epoch {trainer.clock.epoch} started.")
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.epoch += 1
trainer.clock.num_batches_processed = 0
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Step {trainer.clock.step} started.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.clock.step += 1
trainer.clock.num_batches_processed += 1
trainer.clock.num_minibatches_processed += 1
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info(f"Iteration {trainer.clock.iteration} ended.")
trainer.clock.iteration += 1
trainer.clock.num_minibatches_processed = 0
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info("Evaluation started.")
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
logger.info("Evaluation ended.")
class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.epoch_losses: list[float] = []
self.iteration_losses: list[float] = []
def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
loss_value = trainer.loss.detach().cpu().item()
self.epoch_losses.append(loss_value)
self.iteration_losses.append(loss_value)
trainer.log(data={"step_loss": loss_value})
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
trainer.log(data={"average_iteration_loss": avg_iteration_loss})
self.iteration_losses = []
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
self.epoch_losses = []
def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_norm = trainer.config.training.clip_grad_norm
if clip_norm is not None:
clip_gradient_norm(
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
)
class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_value = trainer.config.training.clip_grad_value
if clip_value is not None:
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.log(data={"total_grad_norm": trainer.total_gradient_norm})

View file

@ -0,0 +1,242 @@
from logging import warn
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
from torch.optim import AdamW, SGD, Optimizer, Adam
from torch.nn import Parameter
from enum import Enum
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from pydantic import BaseModel, validator
import tomli
import refiners.fluxion.layers as fl
from prodigyopt import Prodigy # type: ignore
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
__all__ = [
"parse_number_unit_field",
"TimeUnit",
"TimeValue",
"TrainingConfig",
"OptimizerConfig",
"Optimizers",
]
class TimeUnit(Enum):
STEP = "step"
EPOCH = "epoch"
ITERATION = "iteration"
DEFAULT = "step"
class TimeValue(TypedDict):
number: int
unit: TimeUnit
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
match value:
case str(value_str):
number, unit = value_str.split(sep=":")
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
case int(number):
return {"number": number, "unit": TimeUnit.DEFAULT}
case {"number": int(number), "unit": str(unit)}:
return {"number": number, "unit": TimeUnit(value=unit.lower())}
case _:
raise ValueError(f"Unsupported value format: {value}")
class TrainingConfig(BaseModel):
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
seed: int = 0
gpu_index: int = 0
batch_size: int = 1
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
clip_grad_norm: float | None = None
clip_grad_value: float | None = None
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
evaluation_seed: int = 0
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
class Optimizers(str, Enum):
SGD = "SGD"
Adam = "Adam"
AdamW = "AdamW"
AdamW8bit = "AdamW8bit"
Lion8bit = "Lion8bit"
Prodigy = "Prodigy"
class SchedulerType(str, Enum):
STEP_LR = "StepLR"
EXPONENTIAL_LR = "ExponentialLR"
REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
COSINE_ANNEALING_LR = "CosineAnnealingLR"
CONSTANT_LR = "ConstantLR" # not to be confused with PyTorch's ConstantLR
LAMBDA_LR = "LambdaLR"
ONE_CYCLE_LR = "OneCycleLR"
MULTIPLICATIVE_LR = "MultiplicativeLR"
COSINE_ANNEALING_WARM_RESTARTS = "CosineAnnealingWarmRestarts"
CYCLIC_LR = "CyclicLR"
MULTI_STEP_LR = "MultiStepLR"
DEFAULT = "ConstantLR"
class SchedulerConfig(BaseModel):
scheduler_type: SchedulerType = SchedulerType.DEFAULT
update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
gamma: float = 0.1
lr_lambda: Callable[[int], float] | None = None
mode: Literal["min", "max"] = "min"
factor: float = 0.1
patience: int = 10
threshold: float = 1e-4
cooldown: int = 0
milestones: list[int] = []
base_lr: float = 1e-7
min_lr: float | list[float] = 0
max_lr: float | list[float] = 0
eta_min: float = 0
@validator("update_interval", "warmup", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
class OptimizerConfig(BaseModel):
optimizer: Optimizers
learning_rate: float = 1e-4
betas: tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
weight_decay: float = 0.0
def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
match self.optimizer:
case Optimizers.SGD:
return SGD(
params=model_parameters,
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
case Optimizers.Adam:
return Adam(
params=model_parameters,
lr=self.learning_rate,
betas=self.betas,
eps=self.eps,
weight_decay=self.weight_decay,
)
case Optimizers.AdamW:
return AdamW(
params=model_parameters,
lr=self.learning_rate,
betas=self.betas,
eps=self.eps,
weight_decay=self.weight_decay,
)
case Optimizers.AdamW8bit:
return AdamW8bit(
params=model_parameters,
lr=self.learning_rate,
betas=self.betas,
eps=self.eps,
weight_decay=self.weight_decay,
)
case Optimizers.Lion8bit:
return Lion8bit(
params=model_parameters,
lr=self.learning_rate,
betas=self.betas,
weight_decay=self.weight_decay, # type: ignore
)
case Optimizers.Prodigy:
if self.learning_rate != 1.0:
warn("Prodigy learning rate is not 1.0, this might cause instability.")
return Prodigy(
lr=self.learning_rate,
params=model_parameters,
betas=self.betas,
weight_decay=self.weight_decay, # type: ignore
safeguard_warmup=True,
)
class ModelConfig(BaseModel):
checkpoint: Path | None = None
train: bool = True
learning_rate: float | None = None # TODO: Implement this
class GyroDropoutConfig(BaseModel):
total_subnetworks: int = 512
    concurrent_subnetworks: int = 64
iters_per_epoch: int = 512
num_features_threshold: float = 5e5
class DropoutConfig(BaseModel):
dropout_probability: float = 0.0
gyro_dropout: GyroDropoutConfig | None = None
def apply_dropout(self, model: fl.Chain) -> None:
if self.dropout_probability > 0.0:
if self.gyro_dropout is not None:
apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump())
else:
apply_dropout(module=model, probability=self.dropout_probability)
class WandbConfig(BaseModel):
mode: Literal["online", "offline", "disabled"] = "online"
project: str
entity: str = "finegrain"
name: str | None = None
tags: list[str] = []
group: str | None = None
job_type: str | None = None
notes: str | None = None
class HuggingfaceDatasetConfig(BaseModel):
hf_repo: str = "finegrain/unsplash-dummy"
revision: str = "main"
split: str = "train"
use_verification: bool = False
class CheckpointingConfig(BaseModel):
save_folder: Path | None = None
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
@validator("save_interval", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
T = TypeVar("T", bound="BaseConfig")
class BaseConfig(BaseModel):
script: Path # TODO not used for now, but will be used by the cli
models: dict[str, ModelConfig]
wandb: WandbConfig
training: TrainingConfig
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
dataset: HuggingfaceDatasetConfig
checkpointing: CheckpointingConfig
@classmethod
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
with open(file=toml_path, mode="rb") as f:
config_dict = tomli.load(f)
return cls(**config_dict)
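# Illustrative sketch (hypothetical file path): a concrete project config subclasses
# BaseConfig and is loaded from a TOML file whose tables mirror the fields above
# (models, wandb, training, optimizer, scheduler, dropout, dataset, checkpointing).
def _load_config_example() -> BaseConfig:
    return BaseConfig.load_from_toml(toml_path="configs/finetune.toml")  # path is an assumption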

View file

@ -0,0 +1,202 @@
from typing import TYPE_CHECKING, Any, TypeVar
from torch import Tensor, randint, cat, rand
from torch.nn import Dropout as TorchDropout
import refiners.fluxion.layers as fl
from refiners.training_utils.callback import Callback
from refiners.adapters.adapter import Adapter
if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
__all__ = ["Dropout", "GyroDropout", "DropoutCallback"]
class Dropout(TorchDropout, fl.Module):
def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
super().__init__(p=probability, inplace=inplace)
class GyroDropout(fl.Module):
"""
GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training.
It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training.
This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional
dropout.
Parameters:
-----------
total_subnetworks:
The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks
that are precomputed and stored.
concurrent_subnetworks:
The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of
masks from the precomputed set is used to dropout different portions of the input.
dropout_probability: float, optional (default=0.5)
The probability that an element will be zeroed by the dropout.
iters_per_epoch:
Number of iterations per epoch, used to determine how often the masks should be updated.
num_features_threshold:
        If the number of features in the input is greater than this threshold, dropout is skipped, because the
        VRAM used by the GyroDropout masks grows proportionally to the number of features in the input.
"""
def __init__(
self,
total_subnetworks: int,
concurrent_subnetworks: int,
dropout_probability: float = 0.5,
iters_per_epoch: int = 1,
num_features_threshold: float = 5e5,
) -> None:
super().__init__()
assert (
iters_per_epoch >= total_subnetworks
), "The number of iterations per epoch must be greater than the number of masks"
self.dropout_probability = dropout_probability
self.iters_per_epoch = iters_per_epoch
self.total_subnetworks = total_subnetworks
self.concurrent_subnetworks = concurrent_subnetworks
self.scale = 1 / (1 - self.dropout_probability)
self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks
self.preselected_masks: Tensor | None = None
self.dropout_mask = None
self.training_step = 0
self.num_features_threshold = num_features_threshold
self.skip_high_num_features = False
def forward(self, x: Tensor) -> Tensor:
if not self.training:
return x
if self.skip_high_num_features:
return self.basic_dropout(x)
if self.training_step == 0:
num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1]
if num_features > self.num_features_threshold:
self.skip_high_num_features = True
self.basic_dropout = Dropout(probability=self.dropout_probability)
return self.basic_dropout(x)
self.init_masks(x=x)
if self.training_step % self.mask_update_interval == 0:
self.update_dropout_mask(x=x)
self.training_step += 1
return x * self.dropout_mask * self.scale
def init_masks(self, x: Tensor) -> None:
if x.dim() == 2:
self.preselected_masks = (
rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability
)
if x.dim() == 3:
self.preselected_masks = (
rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability
)
assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions"
self.preselected_masks = self.preselected_masks.float()
def update_dropout_mask(self, x: Tensor) -> None:
assert self.preselected_masks is not None
indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device)
selected_masks = self.preselected_masks[indices]
repeat_factor = x.shape[0] // self.concurrent_subnetworks
remaining = x.shape[0] % self.concurrent_subnetworks
repeated_masks = [selected_masks] * repeat_factor
if remaining > 0:
repeated_masks.append(selected_masks[:remaining])
final_masks = cat(tensors=repeated_masks, dim=0)
if x.dim() == 2:
self.dropout_mask = final_masks
if x.dim() == 3:
self.dropout_mask = final_masks.expand_as(x)
class DropoutAdapter(fl.Chain, Adapter[fl.Linear]):
def __init__(self, target: fl.Linear, probability: float = 0.5):
with self.setup_adapter(target):
super().__init__(target, Dropout(probability=probability))
class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]):
def __init__(
self,
target: fl.Linear,
probability: float = 0.5,
total_subnetworks: int = 512,
concurrent_subnetworks: int = 64,
iters_per_epoch: int = 512,
num_features_threshold: float = 5e5,
) -> None:
self.probability = probability
self.total_subnetworks = total_subnetworks
self.concurrent_subnetworks = concurrent_subnetworks
self.iters_per_epoch = iters_per_epoch
with self.setup_adapter(target):
super().__init__(
target,
GyroDropout(
total_subnetworks=total_subnetworks,
concurrent_subnetworks=concurrent_subnetworks,
dropout_probability=probability,
iters_per_epoch=iters_per_epoch,
num_features_threshold=num_features_threshold,
),
)
def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
for linear, parent in module.walk(fl.Linear):
if not linear.weight.requires_grad:
continue
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
adapter = DropoutAdapter(target=linear, probability=probability)
adapter.inject(parent)
def apply_gyro_dropout(
    module: fl.Chain,
    probability: float = 0.5,
    total_subnetworks: int = 32,
    concurrent_subnetworks: int = 16,
    iters_per_epoch: int = 32,
    num_features_threshold: float = 5e5,
) -> None:
for linear, parent in module.walk(fl.Linear):
if not linear.weight.requires_grad:
continue
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
adapter = GyroDropoutAdapter(
target=linear,
probability=probability,
total_subnetworks=total_subnetworks,
concurrent_subnetworks=concurrent_subnetworks,
            iters_per_epoch=iters_per_epoch,
            num_features_threshold=num_features_threshold,
)
adapter.inject(parent)
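# Illustrative sketch (not part of the original file): wrapping the trainable
# Linear layers of a toy chain with regular dropout adapters; dimensions are arbitrary.
def _apply_dropout_example() -> fl.Chain:
    model = fl.Chain(fl.Linear(in_features=16, out_features=16), fl.Linear(in_features=16, out_features=4))
    apply_dropout(module=model, probability=0.1)  # each Linear is now wrapped in a DropoutAdapter
    return model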
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
class DropoutCallback(Callback["Trainer[ConfigType, Any]"]):
def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None:
dropout_config = trainer.config.dropout
chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)]
for model in chain_models:
dropout_config.apply_dropout(model=model)

View file

@ -0,0 +1,23 @@
from datasets import load_dataset as _load_dataset, VerificationMode # type: ignore
from typing import Any, Generic, Protocol, TypeVar, cast
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
T = TypeVar("T", covariant=True)
class HuggingfaceDataset(Generic[T], Protocol):
def __getitem__(self, index: int) -> T:
...
def __len__(self) -> int:
...
def load_hf_dataset(
path: str, revision: str = "main", split: str = "train", use_verification: bool = False
) -> HuggingfaceDataset[Any]:
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
return cast(HuggingfaceDataset[Any], dataset)
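# Illustrative sketch: loading the dummy dataset used as the default in
# HuggingfaceDatasetConfig ("finegrain/unsplash-dummy"); the repository's
# availability is an assumption of this example.
def _load_hf_dataset_example() -> HuggingfaceDataset[Any]:
    return load_hf_dataset(path="finegrain/unsplash-dummy", revision="main", split="train", use_verification=False)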

View file

@ -0,0 +1,238 @@
from dataclasses import dataclass
from typing import Any, TypeVar, TypedDict, cast
from pydantic import BaseModel
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
from loguru import logger
from torch.utils.data import Dataset
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from torchvision.transforms import RandomCrop # type: ignore
import refiners.fluxion.layers as fl
from PIL import Image
from functools import cached_property
from refiners.training_utils.config import BaseConfig
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.schedulers import DPMSolver
from torch.nn.functional import mse_loss
import random
from refiners.training_utils.wandb import WandbLoggable
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
class LatentDiffusionConfig(BaseModel):
unconditional_sampling_probability: float = 0.2
offset_noise: float = 0.1
min_timestep: int = 0
max_timestep: int = 999
class TestDiffusionConfig(BaseModel):
seed: int = 0
num_inference_steps: int = 30
use_short_prompts: bool = False
prompts: list[str] = []
num_images_per_prompt: int = 1
class FinetuneLatentDiffusionConfig(BaseConfig):
latent_diffusion: LatentDiffusionConfig
test_diffusion: TestDiffusionConfig
@dataclass
class TextEmbeddingLatentsBatch:
text_embeddings: Tensor
latents: Tensor
class CaptionImage(TypedDict):
caption: str
image: Image.Image
ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig)
class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
self.trainer = trainer
self.config = trainer.config
self.device = self.trainer.device
self.lda = self.trainer.lda
self.text_encoder = self.trainer.text_encoder
self.dataset = self.load_huggingface_dataset()
self.process_image = RandomCrop(size=512) # TODO: make this configurable and add other transforms
logger.info(f"Loaded {len(self.dataset)} samples from dataset")
def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
dataset_config = self.config.dataset
logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
return cast(
HuggingfaceDataset[CaptionImage],
load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
)
def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
return resize_image(image=image, min_size=min_size, max_size=max_size)
def process_caption(self, caption: str) -> str:
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
item = self.dataset[index]
caption, image = item["caption"], item["image"]
resized_image = self.resize_image(image=image)
processed_image: Image.Image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
latents = cat(tensors=[item.latents for item in batch])
return TextEmbeddingLatentsBatch(text_embeddings=text_embeddings, latents=latents)
def __len__(self) -> int:
return len(self.dataset)
class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
@cached_property
def unet(self) -> UNet:
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
return UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
@cached_property
def text_encoder(self) -> CLIPTextEncoderL:
assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry."
return CLIPTextEncoderL(device=self.device).to(device=self.device)
@cached_property
def lda(self) -> LatentDiffusionAutoencoder:
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
return LatentDiffusionAutoencoder(device=self.device).to(device=self.device)
def load_models(self) -> dict[str, fl.Module]:
return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
return TextEmbeddingLatentsDataset(trainer=self)
@cached_property
def ddpm_scheduler(self) -> DDPM:
return DDPM(
num_inference_steps=1000,
device=self.device,
).to(device=self.device)
def sample_timestep(self) -> Tensor:
random_step = random.randint(
a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
)
self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
def sample_noise(self, size: tuple[int, int, int, int], dtype: DType | None = None) -> Tensor:
return sample_noise(
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
)
def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
clip_text_embedding, latents = batch.text_embeddings, batch.latents
timestep = self.sample_timestep()
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
prediction = self.unet(noisy_latents)
loss = mse_loss(input=prediction, target=noise)
return loss
def compute_evaluation(self) -> None:
sd = StableDiffusion_1(
unet=self.unet,
lda=self.lda,
clip_text_encoder=self.text_encoder,
scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
device=self.device,
)
prompts = self.config.test_diffusion.prompts
num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt
if self.config.test_diffusion.use_short_prompts:
prompts = [prompt.split(sep=",")[0] for prompt in prompts]
images: dict[str, WandbLoggable] = {}
for prompt in prompts:
canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
for i in range(num_images_per_prompt):
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
x = randn(1, 4, 64, 64, device=self.device)
clip_text_embedding = sd.compute_text_embedding(text=prompt).to(device=self.device)
negative_clip_text_embedding = sd.compute_text_embedding(text="").to(device=self.device)
for step in sd.steps:
x = sd(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
)
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)
def sample_noise(
size: tuple[int, int, int, int],
offset_noise: float = 0.1,
device: Device | str = "cpu",
dtype: DType | None = None,
generator: Generator | None = None,
) -> Tensor:
"""Sample noise from a normal distribution.
    If `offset_noise` is greater than 0, the noise is offset by a small per-channel amount, which allows the model
    to generate images with a wider range of contrast (see https://www.crosslabs.org/blog/diffusion-with-offset-noise).
"""
device = Device(device)
noise = randn(*size, generator=generator, device=device, dtype=dtype)
return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype)
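# Illustrative sketch (arbitrary latent shape): sampling SD-1.5-sized latent noise
# with the offset-noise trick enabled, made reproducible with an explicit generator.
def _sample_noise_example() -> Tensor:
    generator = Generator(device="cpu").manual_seed(2)
    return sample_noise(size=(1, 4, 64, 64), offset_noise=0.1, device="cpu", generator=generator)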
def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
image_min_size = min(image.size)
if image_min_size > max_size:
if image_min_size == image.size[0]:
image = image.resize(size=(max_size, int(max_size * image.size[1] / image.size[0])))
else:
image = image.resize(size=(int(max_size * image.size[0] / image.size[1]), max_size))
if image_min_size < min_size:
if image_min_size == image.size[0]:
image = image.resize(size=(min_size, int(min_size * image.size[1] / image.size[0])))
else:
image = image.resize(size=(int(min_size * image.size[0] / image.size[1]), min_size))
return image
class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
def on_train_begin(self, trainer: LatentDiffusionTrainer[Any]) -> None:
self.timestep_bins: dict[int, list[float]] = {i: [] for i in range(10)}
def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
loss_value = trainer.loss.detach().cpu().item()
current_step = trainer.current_step
bin_index = min(current_step // 100, 9)
self.timestep_bins[bin_index].append(loss_value)
def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
log_data = {}
for bin_index, losses in self.timestep_bins.items():
if losses:
avg_loss = sum(losses) / len(losses)
log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
self.timestep_bins[bin_index] = []
trainer.log(data=log_data)

View file

@ -0,0 +1,546 @@
from functools import cached_property, wraps
from pathlib import Path
import random
import time
import numpy as np
from torch import device as Device, Tensor, get_rng_state, no_grad, set_rng_state, cuda, stack
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torch.autograd import backward
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
from loguru import logger
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import manual_seed
from refiners.training_utils.wandb import WandbLogger, WandbLoggable
from refiners.training_utils.config import BaseConfig, TimeUnit, TimeValue, SchedulerType
from refiners.training_utils.dropout import DropoutCallback
from refiners.training_utils.callback import (
Callback,
ClockCallback,
GradientNormClipping,
GradientValueClipping,
GradientNormLogging,
MonitorLoss,
)
from torch.optim.lr_scheduler import (
StepLR,
ExponentialLR,
ReduceLROnPlateau,
CosineAnnealingLR,
LambdaLR,
OneCycleLR,
LRScheduler,
MultiplicativeLR,
CosineAnnealingWarmRestarts,
CyclicLR,
MultiStepLR,
)
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
def count_learnable_parameters(parameters: Iterable[Parameter]) -> int:
return sum(p.numel() for p in parameters if p.requires_grad)
def human_readable_number(number: int) -> str:
float_number = float(number)
for unit in ["", "K", "M", "G", "T", "P"]:
if abs(float_number) < 1000:
return f"{float_number:.1f}{unit}"
float_number /= 1000
return f"{float_number:.1f}E"
def seed_everything(seed: int | None = None) -> None:
if seed is None:
seed = random.randint(0, 2**32 - 1)
logger.info(f"Using random seed: {seed}")
random.seed(a=seed)
np.random.seed(seed=seed)
manual_seed(seed=seed)
cuda.manual_seed_all(seed=seed)
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
"""
Decorator for setting a random seed within the scope of a function.
This decorator sets the random seed for Python's built-in `random` module,
    `numpy`, `torch`, and `torch.cuda` at the beginning of the decorated function. After the
function is executed, it restores the state of the random number generators
to what it was before the function was called. This is useful for ensuring
reproducibility for specific parts of the code without affecting randomness
elsewhere.
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
random_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = get_rng_state()
cuda_torch_state = cuda.get_rng_state()
actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed)
result = func(*args, **kwargs)
random.setstate(random_state)
np.random.set_state(numpy_state)
set_rng_state(torch_state)
cuda.set_rng_state(cuda_torch_state)
return result
return inner_wrapper
return decorator
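# Illustrative sketch (not part of the original file): seeding a single function
# call without disturbing the surrounding RNG state.
@scoped_seed(seed=42)
def _generate_reproducible_noise() -> Tensor:
    from torch import randn
    return randn(2, 2)  # identical on every call thanks to the scoped seed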
class WarmupScheduler(LRScheduler):
_step_count: int # defined by LRScheduler
def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_steps: int = 0) -> None:
self.warmup_steps = warmup_steps
self.scheduler = scheduler
super().__init__(optimizer=optimizer)
def get_lr(self) -> list[float] | float: # type: ignore
if self._step_count < self.warmup_steps:
return [base_lr * self._step_count / self.warmup_steps for base_lr in self.base_lrs]
return self.scheduler.get_lr()
def step(self, epoch: int | None = None) -> None:
if self._step_count < self.warmup_steps:
super().step()
else:
self.scheduler.step(epoch=epoch)
self._step_count += 1
class TrainingClock:
def __init__(
self,
dataset_length: int,
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
checkpointing_save_interval: TimeValue,
) -> None:
self.dataset_length = dataset_length
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.checkpointing_save_interval = checkpointing_save_interval
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
self.loss: Tensor | None = None
@cached_property
def unit_to_steps(self) -> dict[TimeUnit, int]:
return {
TimeUnit.STEP: 1,
TimeUnit.EPOCH: self.num_batches_per_epoch,
TimeUnit.ITERATION: self.gradient_accumulation["number"] * {
TimeUnit.STEP: 1,
TimeUnit.EPOCH: self.num_batches_per_epoch,
}.get(self.gradient_accumulation["unit"], 1),
}
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
return number * self.unit_to_steps[unit]
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
return steps // self.unit_to_steps[unit]
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
number, unit = time_value["number"], time_value["unit"]
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
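    # Worked example (illustrative numbers, not part of the original file): with
    # dataset_length=1000 and batch_size=4, one epoch is 250 steps; with a
    # gradient_accumulation of {"number": 2, "unit": TimeUnit.STEP}, one iteration
    # is 2 steps, so a duration of {"number": 10, "unit": TimeUnit.EPOCH} converts
    # to 2500 steps and 1250 iterations.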
@cached_property
def num_epochs(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
@cached_property
def num_iterations(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
@cached_property
def num_steps(self) -> int:
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
@cached_property
def num_step_per_iteration(self) -> int:
return self.convert_time_unit_to_steps(
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
)
@cached_property
def num_step_per_evaluation(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
def reset(self) -> None:
self.start_time = None
self.end_time = None
self.step = 0
self.epoch = 0
self.iteration = 0
self.num_batches_processed = 0
self.num_minibatches_processed = 0
def start_timer(self) -> None:
self.start_time = time.time()
def stop_timer(self) -> None:
self.end_time = time.time()
@property
def time_elapsed(self) -> int:
assert self.start_time is not None, "Timer has not been started yet."
return int(time.time() - self.start_time)
@cached_property
    def evaluation_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
)
@cached_property
def lr_scheduler_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
)
@cached_property
def checkpointing_save_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"]
)
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
@property
def is_lr_scheduler_step(self) -> bool:
return self.step % self.lr_scheduler_interval_steps == 0
@property
def done(self) -> bool:
return self.step >= self.num_steps
@property
def is_evaluation_step(self) -> bool:
        return self.step % self.evaluation_interval_steps == 0
@property
def is_checkpointing_step(self) -> bool:
return self.step % self.checkpointing_save_interval_steps == 0
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
"""
    Computes the gradient norm of the parameters of a given model, matching the value returned by `clip_grad_norm_`.
"""
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
assert gradients, "The model has no gradients to compute the norm."
total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
return total_norm # type: ignore
Batch = TypeVar("Batch")
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
class Trainer(Generic[ConfigType, Batch]):
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
self.config = config
self.clock = TrainingClock(
dataset_length=self.dataset_length,
batch_size=config.training.batch_size,
training_duration=config.training.duration,
evaluation_interval=config.training.evaluation_interval,
gradient_accumulation=config.training.gradient_accumulation,
lr_scheduler_interval=config.scheduler.update_interval,
checkpointing_save_interval=config.checkpointing.save_interval,
)
self.callbacks = callbacks or []
self.callbacks += self.default_callbacks()
self.load_wandb()
self.load_models()
self.prepare_models()
self.prepare_checkpointing()
def default_callbacks(self) -> list[Callback[Any]]:
return [
ClockCallback(),
MonitorLoss(),
GradientNormLogging(),
GradientValueClipping(),
GradientNormClipping(),
DropoutCallback(),
]
@cached_property
def device(self) -> Device:
selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
logger.info(f"Using device: {selected_device}")
return selected_device
@property
def parameters(self) -> list[Parameter]:
"""Returns a list of all parameters in all models"""
return [param for model in self.models.values() for param in model.parameters()]
@property
def learnable_parameters(self) -> list[Parameter]:
"""Returns a list of learnable parameters in all models"""
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
@property
def learnable_parameter_count(self) -> int:
"""Returns the number of learnable parameters in all models"""
return count_learnable_parameters(parameters=self.learnable_parameters)
@property
def gradients(self) -> list[Tensor]:
"""Returns a list of detached gradients for all learnable parameters in all models"""
return [
param.grad.detach()
for model in self.models.values()
for param in model.parameters()
if param.grad is not None
]
@property
def total_gradient_norm(self) -> float:
"""Returns the total gradient norm for all learnable parameters in all models"""
return compute_grad_norm(parameters=self.parameters)
@cached_property
def optimizer(self) -> Optimizer:
formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
return optimizer
@cached_property
def lr_scheduler(self) -> LRScheduler:
config = self.config.scheduler
step_size = self.clock.convert_time_unit_to_steps(
number=config.update_interval["number"], unit=config.update_interval["unit"]
)
match config.scheduler_type:
case SchedulerType.CONSTANT_LR:
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
case SchedulerType.STEP_LR:
lr_scheduler = StepLR(optimizer=self.optimizer, step_size=step_size, gamma=config.gamma)
case SchedulerType.EXPONENTIAL_LR:
lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
case SchedulerType.COSINE_ANNEALING_LR:
lr_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=step_size, eta_min=config.eta_min)
case SchedulerType.REDUCE_LR_ON_PLATEAU:
lr_scheduler = cast(
LRScheduler,
ReduceLROnPlateau(
optimizer=self.optimizer,
mode=config.mode,
factor=config.factor,
patience=config.patience,
threshold=config.threshold,
cooldown=config.cooldown,
min_lr=config.min_lr,
),
)
case SchedulerType.LAMBDA_LR:
assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
case SchedulerType.ONE_CYCLE_LR:
lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=config.max_lr, total_steps=step_size)
case SchedulerType.MULTIPLICATIVE_LR:
assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=step_size)
case SchedulerType.CYCLIC_LR:
lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
case SchedulerType.MULTI_STEP_LR:
lr_scheduler = MultiStepLR(optimizer=self.optimizer, milestones=config.milestones, gamma=config.gamma)
case _:
raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
warmup_steps = self.clock.convert_time_unit_to_steps(number=config.warmup["number"], unit=config.warmup["unit"])
if warmup_steps > 0:
lr_scheduler = WarmupScheduler(
optimizer=self.optimizer,
scheduler=lr_scheduler,
warmup_steps=warmup_steps,
)
return lr_scheduler
@cached_property
def models(self) -> dict[str, fl.Module]:
return self.load_models()
def set_models_to_train_mode(self) -> None:
for model in self.models.values():
model.train()
def set_models_to_eval_mode(self) -> None:
for model in self.models.values():
model.eval()
def log(self, data: dict[str, WandbLoggable]) -> None:
self.wandb.log(data=data, step=self.clock.step)
def load_wandb(self) -> None:
init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()}
self.wandb = WandbLogger(init_config=init_config)
def prepare_model(self, model_name: str) -> None:
model = self.models[model_name]
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
model.load_from_safetensors(tensors_path=checkpoint)
else:
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
model.requires_grad_(requires_grad=self.config.models[model_name].train)
model.to(self.device)
model.zero_grad()
def prepare_models(self) -> None:
assert self.models, "No models found."
for model_name in self.models:
self.prepare_model(model_name=model_name)
def prepare_checkpointing(self) -> None:
if self.config.checkpointing.save_folder is not None:
assert self.config.checkpointing.save_folder.is_dir()
self.checkpoints_save_folder = (
self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name
)
self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
else:
self.checkpoints_save_folder = None
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
def load_models(self) -> dict[str, fl.Module]:
raise NotImplementedError("The `load_models` method must be implemented in the subclass.")
def load_dataset(self) -> Dataset[Batch]:
raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.")
@cached_property
def dataset(self) -> Dataset[Batch]:
return self.load_dataset()
@cached_property
def dataset_length(self) -> int:
assert hasattr(self.dataset, "__len__"), "The dataset must implement the `__len__` method."
return len(self.dataset) # type: ignore
@cached_property
def dataloader(self) -> DataLoader[Batch]:
collate_fn = getattr(self.dataset, "collate_fn", None)
return DataLoader(
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
)
@property
def checkpointing_enabled(self) -> bool:
return self.checkpoints_save_folder is not None
@property
def ensure_checkpoints_save_folder(self) -> Path:
assert self.checkpoints_save_folder is not None
return self.checkpoints_save_folder
def compute_loss(self, batch: Batch) -> Tensor:
raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.")
def compute_evaluation(self) -> None:
pass
def backward(self) -> None:
"""Backward pass on the loss."""
self._call_callbacks(event_name="on_backward_begin")
scaled_loss = self.loss / self.clock.num_step_per_iteration
backward(tensors=scaled_loss)
self._call_callbacks(event_name="on_backward_end")
if self.clock.is_optimizer_step:
self._call_callbacks(event_name="on_optimizer_step_begin")
self.optimizer.step()
self.optimizer.zero_grad()
self._call_callbacks(event_name="on_optimizer_step_end")
if self.clock.is_lr_scheduler_step:
self._call_callbacks(event_name="on_lr_scheduler_step_begin")
self.lr_scheduler.step()
self._call_callbacks(event_name="on_lr_scheduler_step_end")
if self.clock.is_evaluation_step:
self.evaluate()
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
self._call_callbacks(event_name="on_checkpoint_save")
def step(self, batch: Batch) -> None:
"""Perform a single training step."""
self._call_callbacks(event_name="on_compute_loss_begin")
loss = self.compute_loss(batch=batch)
self.loss = loss
self._call_callbacks(event_name="on_compute_loss_end")
self.backward()
def epoch(self) -> None:
"""Perform a single epoch."""
for batch in self.dataloader:
self._call_callbacks(event_name="on_batch_begin")
self.step(batch=batch)
self._call_callbacks(event_name="on_batch_end")
@staticmethod
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:
return instance.config.training.seed
@scoped_seed(seed=get_training_seed)
def train(self) -> None:
"""Train the model."""
self.set_models_to_train_mode()
self._call_callbacks(event_name="on_train_begin")
assert self.learnable_parameters, "There are no learnable parameters in the models."
self.evaluate()
while not self.clock.done:
self._call_callbacks(event_name="on_epoch_begin")
self.epoch()
self._call_callbacks(event_name="on_epoch_end")
self._call_callbacks(event_name="on_train_end")
@staticmethod
def get_evaluation_seed(instance: "Trainer[BaseConfig, Any]") -> int:
return instance.config.training.evaluation_seed
@no_grad()
@scoped_seed(seed=get_evaluation_seed)
def evaluate(self) -> None:
"""Evaluate the model."""
self.set_models_to_eval_mode()
self._call_callbacks(event_name="on_evaluate_begin")
self.compute_evaluation()
self._call_callbacks(event_name="on_evaluate_end")
self.set_models_to_train_mode()
def _call_callbacks(self, event_name: str) -> None:
for callback in self.callbacks:
getattr(callback, event_name)(self)

View file

@ -0,0 +1,61 @@
from typing import Any
import wandb
from PIL import Image
__all__ = [
"WandbLogger",
"WandbLoggable",
]
number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
def convert_to_wandb(value: WandbLoggable) -> Any:
match value:
case Image.Image():
return convert_to_wandb_image(value=value)
case list():
return convert_to_wandb_histogram(value=value)
case dict():
return convert_to_wandb_table(value=value)
case _:
return value
def convert_to_wandb_image(value: Image.Image) -> wandb.Image:
return wandb.Image(data_or_path=value)
def convert_to_wandb_histogram(value: list[number]) -> wandb.Histogram:
return wandb.Histogram(sequence=value)
def convert_to_wandb_table(value: dict[str, list[number]]) -> wandb.Table:
assert all(
isinstance(v, list) and len(v) == len(next(iter(value.values()))) for v in value.values()
), "Expected a dictionary of lists of the same size"
columns = list(value.keys())
data_rows = list(zip(*value.values()))
return wandb.Table(columns=columns, data=data_rows)
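# Illustrative sketch (made-up values): how the dispatcher above maps the supported
# loggable values to wandb objects.
def _convert_to_wandb_example() -> dict[str, Any]:
    return {
        "scalar": convert_to_wandb(value=0.5),  # numbers pass through unchanged
        "histogram": convert_to_wandb(value=[1.0, 2.0, 3.0]),  # becomes a wandb.Histogram
        "table": convert_to_wandb(value={"epoch": [1, 2], "loss": [0.9, 0.7]}),  # becomes a wandb.Table
    }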
class WandbLogger:
def __init__(self, init_config: dict[str, Any] = {}) -> None:
self.wandb_run = wandb.init(**init_config) # type: ignore
def log(self, data: dict[str, WandbLoggable], step: int) -> None:
converted_data = {key: convert_to_wandb(value=value) for key, value in data.items()}
self.wandb_run.log(converted_data, step=step) # type: ignore
def update_summary(self, key: str, value: Any) -> None:
self.wandb_run.summary[key] = value # type: ignore
@property
def project_name(self) -> str:
return self.wandb_run.project_name() # type: ignore
@property
def run_name(self) -> str:
return self.wandb_run.name or "" # type: ignore

0
tests/__init__.py Normal file
View file

View file

@ -0,0 +1,82 @@
import pytest
from refiners.adapters.adapter import Adapter
from refiners.fluxion.layers import Chain, Linear
class DummyLinearAdapter(Chain, Adapter[Linear]):
def __init__(self, target: Linear):
with self.setup_adapter(target):
super().__init__(target)
class DummyChainAdapter(Chain, Adapter[Chain]):
def __init__(self, target: Chain):
with self.setup_adapter(target):
super().__init__(target)
@pytest.fixture
def chain() -> Chain:
return Chain(Chain(Linear(2, 2)))
def test_weighted_module_adapter_insertion(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
adapter = DummyLinearAdapter(adaptee)
adapter.inject(parent)
assert adapter.parent == parent
assert adapter in iter(parent)
assert adaptee not in iter(parent)
adapter.eject()
assert adapter.parent is None
assert adapter not in iter(parent)
assert adaptee in iter(parent)
def test_chain_adapter_insertion(chain: Chain):
parent = chain
adaptee = parent.Chain
adapter = DummyChainAdapter(adaptee)
assert adaptee.parent == parent
adapter.inject()
assert adapter.parent == parent
assert adaptee.parent == adapter
assert adapter in iter(parent)
assert adaptee not in iter(parent)
adapter.eject()
assert adapter.parent is None
assert adaptee.parent == parent
assert adapter not in iter(parent)
assert adaptee in iter(parent)
def test_weighted_module_adapter_structural_copy(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
adapter = DummyLinearAdapter(adaptee)
adapter.inject(parent)
clone = chain.structural_copy()
cloned_adapter = clone.Chain.DummyLinearAdapter
assert cloned_adapter.parent == clone.Chain
assert cloned_adapter.target == adaptee
def test_chain_adapter_structural_copy(chain: Chain):
# Chain adapters cannot be copied by default.
adapter = DummyChainAdapter(chain.Chain)
adapter.inject()
with pytest.raises(RuntimeError):
chain.structural_copy()
adapter.eject()
chain.structural_copy()

View file

@ -0,0 +1,29 @@
from refiners.adapters.lora import Lora, LoraAdapter
from torch import randn, allclose
import refiners.fluxion.layers as fl
def test_lora() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=2),
)
x = randn(1, 1)
y = chain(x)
lora_adapter = LoraAdapter(chain.Chain.Linear_1)
lora_adapter.inject(chain.Chain)
assert isinstance(lora_adapter[1], Lora)
assert allclose(input=chain(x), other=y)
assert lora_adapter.parent == chain.Chain
lora_adapter.eject()
assert isinstance(chain.Chain[0], fl.Linear)
assert len(chain) == 2
lora_adapter.inject(chain.Chain)
assert isinstance(chain.Chain[0], LoraAdapter)

View file

@ -0,0 +1,25 @@
import torch
from refiners.adapters.adapter import Adapter
from refiners.adapters.range_adapter import RangeEncoder
from refiners.fluxion.layers import Chain, Linear
class DummyLinearAdapter(Chain, Adapter[Linear]):
def __init__(self, target: Linear):
with self.setup_adapter(target):
super().__init__(target)
def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-433
dtype = torch.float64
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
adaptee = chain.RangeEncoder.Linear_1
adapter = DummyLinearAdapter(adaptee)
adapter.inject(chain.RangeEncoder)
assert adapter.parent == chain.RangeEncoder
x = torch.tensor([42], device=test_device)
y = chain(x)
assert y.dtype == dtype

25
tests/conftest.py Normal file
View file

@ -0,0 +1,25 @@
import os
import torch
from pathlib import Path
from pytest import fixture
PARENT_PATH = Path(__file__).parent
@fixture(scope="session")
def test_device() -> torch.device:
test_device = os.getenv("REFINERS_TEST_DEVICE")
if not test_device:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(test_device)
@fixture(scope="session")
def test_weights_path() -> Path:
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
return Path(from_env) if from_env else PARENT_PATH / "weights"
@fixture(scope="session")
def test_e2e_path() -> Path:
return PARENT_PATH / "e2e"

709
tests/e2e/test_diffusion.py Normal file
View file

@ -0,0 +1,709 @@
import torch
import pytest
from typing import Iterator
from warnings import warn
from PIL import Image
from pathlib import Path
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
from tests.utils import ensure_similar_images
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_diffusion_ref"
@pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "cutecat_init.png").convert("RGB")
@pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
@pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_fn = {
"depth": "lllyasviel_control_v11f1p_sd15_depth",
"canny": "lllyasviel_control_v11p_sd15_canny",
"lineart": "lllyasviel_control_v11p_sd15_lineart",
"normals": "lllyasviel_control_v11p_sd15_normalbae",
"sam": "mfidabel_controlnet-segment-anything",
}
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
yield (cn_name, condition_image, expected_image, weights_path)
@pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny"
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors"
return expected_image, weights_path
@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
@pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
@pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "inpainting-target.png").convert("RGB")
@pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
@pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_refonly.png").convert("RGB")
@pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def lda_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "unet.safetensors"
if not unet_weights_std.is_file():
warn(f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def unet_weights_inpainting(test_weights_path: Path) -> Path:
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
if not unet_weights_inpainting.is_file():
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_inpainting
@pytest.fixture
def sd15_std(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
sd15 = StableDiffusion_1(device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
return sd15
@pytest.fixture
def sd15_std_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
return sd15
@pytest.fixture
def sd15_inpainting(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
unet = UNet(in_channels=9, clip_embedding_dim=768)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
return sd15
@pytest.fixture
def sd15_inpainting_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
) -> StableDiffusion_1_Inpainting:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
unet = UNet(in_channels=9, clip_embedding_dim=768)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
return sd15
@pytest.fixture
def sd15_ddim(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
return sd15
@torch.no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_std_random_init)
@torch.no_grad()
def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_float16
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
assert clip_text_embedding.dtype == torch.float16
assert negative_clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_std_init_image(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
expected_image_std_init_image: Image.Image,
):
sd15 = sd15_std
n_steps = 35
first_step = 5
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
manual_seed(2)
x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
with torch.no_grad():
for step in sd15.steps[first_step:]:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_std_init_image)
@torch.no_grad()
def test_diffusion_inpainting(
sd15_inpainting: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
kitchen_dog_mask: Image.Image,
expected_image_std_inpainting: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting
n_steps = 30
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
# Relaxed PSNR/SSIM thresholds: with float32 we already see sizeable differences, even versus our own runs.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@torch.no_grad()
def test_diffusion_inpainting_float16(
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
kitchen_dog_mask: Image.Image,
expected_image_std_inpainting: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting_float16
n_steps = 30
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
assert clip_text_embedding.dtype == torch.float16
assert negative_clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
# Thresholds relaxed further: float16 drifts from the reference even more than float32.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@torch.no_grad()
def test_diffusion_controlnet(
sd15_std: StableDiffusion_1,
controlnet_data: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_controlnet_structural_copy(
sd15_std: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15_base = sd15_std
sd15 = sd15_base.structural_copy()
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_controlnet_float16(
sd15_std_float16: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std_float16
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
if not cn_weights_path.is_file():
warn(f"could not find weights at {cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sd15.steps:
controlnet.set_controlnet_condition(cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
expected_image, lora_weights_path = lora_data_pokemon
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
lora_weights = LoraWeights(lora_weights_path, device=test_device)
lora_weights.patch(sd15, scale=1.0)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1,
condition_image_refonly: Image.Image,
expected_image_refonly: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim
prompt = "Chicken"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()
guide = sd15.lda.encode_image(condition_image_refonly)
guide = torch.cat((guide, guide))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
noise = torch.randn(2, 4, 64, 64, device=test_device)
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
sai.set_controlnet_condition(noised_guide)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproducibility only
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
@torch.no_grad()
def test_diffusion_inpainting_refonly(
sd15_inpainting: StableDiffusion_1_Inpainting,
scene_image_inpainting_refonly: Image.Image,
target_image_inpainting_refonly: Image.Image,
mask_image_inpainting_refonly: Image.Image,
expected_image_inpainting_refonly: Image.Image,
test_device: torch.device,
):
sd15 = sd15_inpainting
n_steps = 30
prompt = "" # unconditional
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
refonly_guide = torch.cat((refonly_guide, refonly_guide))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
refonly_noise = torch.randn_like(refonly_guide)
refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
# inpaint variation models")
refonly_noised_guide = torch.cat(
[refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
)
sai.set_controlnet_condition(refonly_noised_guide)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)

View file

@@ -0,0 +1,82 @@
# Note about this data
## Expected outputs
`expected_*.png` files are the outputs of the same diffusion runs performed with a different codebase, usually diffusers, using the same settings as ours (`DPMSolverMultistepScheduler`, VAE [patched to remove randomness](#vae-without-randomness), same seed...).
For instance here is how we generate `expected_std_random_init.png`:
```py
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float32,
).to("cuda)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
torch.manual_seed(2)
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=30,
guidance_scale=7.5,
)
output.images[0].save("expected_std_random_init.png")
```
Special cases:
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- `expected_inpainting_refonly.png` has been generated with refiners itself (and visually inspected to check that it looks reasonable).
## Other images
- `cutecat_init.png` is generated with the same Diffusers script and prompt but with seed 1234.
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
- `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated using open source software and models (see the Canny sketch after this list), namely:
- Canny: opencv-python
- Depth: https://github.com/isl-org/ZoeDepth
- Lineart: https://github.com/lllyasviel/ControlNet-v1-1-nightly/tree/main/annotator/lineart
- Normals: https://github.com/baegwangbin/surface_normal_uncertainty/tree/fe2b9f1
- SAM: https://huggingface.co/spaces/mfidabel/controlnet-segment-anything
- `cyberpunk_guide.png` [comes from Lexica](https://lexica.art/prompt/5ba40855-0d0c-4322-8722-51115985f573).
- `inpainting-mask.png`, `inpainting-scene.png` and `inpainting-target.png` have been generated as follows:
- `inpainting-mask.png`: negated version of a mask computed with [SAM](https://github.com/facebookresearch/segment-anything) automatic mask generation using the `vit_h` checkpoint
- `inpainting-scene.png`: cropped-to-square-and-resized version of https://unsplash.com/photos/RCz6eSVPGYU by @jannerboy62
- `inpainting-target.png`: computed with `convert <(convert -size 512x512 xc:white png:-) kitchen_dog.png <(convert inpainting-mask.png -negate png:-) -compose Over -composite inpainting-target.png`
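
As an illustration of the Canny item above, an edge guide can be reproduced with a short opencv-python sketch; the file names and the 100/200 thresholds below are assumptions, not the exact values used for the reference guides:

```py
import cv2

# Assumed input: the init image used elsewhere in these tests.
image = cv2.imread("cutecat_init.png")
# Detect edges on the grayscale image (thresholds are illustrative).
edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 100, 200)
# ControlNet expects a 3-channel guide, so replicate the edge map across channels.
cv2.imwrite("cutecat_guide_canny.png", cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR))
```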
## VAE without randomness
The patch below makes the VAE image encoding deterministic: instead of sampling from the latent distribution, it uses its mean, so a given init image always maps to the same latents.
```diff
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -524,13 +524,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- if isinstance(generator, list):
- init_latents = [
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
- ]
- init_latents = torch.cat(init_latents, dim=0)
- else:
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
+ init_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mean for i in range(batch_size)]
+ init_latents = torch.cat(init_latents, dim=0)
init_latents = self.vae.config.scaling_factor * init_latents
```
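
The same deterministic behaviour can also be obtained without patching diffusers, simply by taking the mean of the latent distribution; a minimal sketch, where the checkpoint name and the dummy input tensor are illustrative assumptions:

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# Dummy input standing in for a preprocessed image in [-1, 1], shape (B, 3, H, W).
image = torch.zeros(1, 3, 512, 512)

with torch.no_grad():
    posterior = vae.encode(image).latent_dist  # DiagonalGaussianDistribution
    # Use the mean instead of sampling: same image -> same latents, every run.
    latents = posterior.mean * vae.config.scaling_factor
```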

Binary image files (test reference data) added; contents not shown in this diff.

Some files were not shown because too many files have changed in this diff.