initial commit
31
.github/workflows/ci.yml
vendored
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
name: CI
|
||||||
|
on: push
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
lint_and_typecheck:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up python
|
||||||
|
id: setup-python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: snok/install-poetry@v1
|
||||||
|
with:
|
||||||
|
virtualenvs-create: true
|
||||||
|
virtualenvs-in-project: true
|
||||||
|
installer-parallel: true
|
||||||
|
|
||||||
|
- name: poetry install
|
||||||
|
run: poetry install --no-interaction --extras=training
|
||||||
|
|
||||||
|
- name: lint
|
||||||
|
run: poetry run ruff check .
|
||||||
|
|
||||||
|
- name: typecheck
|
||||||
|
run: poetry run pyright
|
28
.gitignore
vendored
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# compilation and distribution
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# virtual environments
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# unit tests
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# tests' model weights
|
||||||
|
tests/weights/
|
||||||
|
|
||||||
|
# ruff
|
||||||
|
.ruff_cache
|
||||||
|
|
||||||
|
# vscode
|
||||||
|
.vscode
|
||||||
|
|
||||||
|
# Weights & Biases (offline trainings)
|
||||||
|
wandb/
|
||||||
|
|
||||||
|
# macos
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# model weights
|
||||||
|
*.safetensors
|
21
LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Lagon Technologies
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
304
README.md
Normal file
|
@ -0,0 +1,304 @@
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="assets/logo_dark.png">
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="assets/logo_light.png">
|
||||||
|
<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
|
||||||
|
</picture>
|
||||||
|
|
||||||
|
**The simplest way to train and run adapters on top of foundational models**
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/refiners)](https://pypi.org/project/refiners/)
|
||||||
|
[![PyPI Status](https://badge.fury.io/py/refiners.svg)](https://badge.fury.io/py/refiners)
|
||||||
|
[![license](https://img.shields.io/badge/license-MIT-blue)](/LICENSE)
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
- [Motivation](#motivation)
|
||||||
|
- [Design](#design)
|
||||||
|
- [Downsides](#downsides)
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Key Concepts](#key-concepts)
|
||||||
|
- [The Chain class](#the-chain-class)
|
||||||
|
- [The Context API](#the-context-api)
|
||||||
|
- [The Adapter API](#the-adapter-api)
|
||||||
|
- [Getting Started](#getting-started)
|
||||||
|
- [Install](#install)
|
||||||
|
- [Hello World](#hello-world)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Credits](#credits)
|
||||||
|
- [Citation](#citation)
|
||||||
|
|
||||||
|
|
||||||
|
## Motivation
|
||||||
|
|
||||||
|
At [Finegrain](https://finegrain.ai), we're on a mission to automate product photography. Given our "no human in the loop approach", nailing the quality of the outputs we generate is paramount to our success.
|
||||||
|
|
||||||
|
That's why we're building Refiners.
|
||||||
|
|
||||||
|
It's a framework to easily bridge the last mile quality gap of foundational models like Stable Diffusion or Segment Anything Model (SAM), by adapting them to specific tasks with lightweight trainable and composable patches.
|
||||||
|
|
||||||
|
We decided to build Refiners in the open.
|
||||||
|
|
||||||
|
It's because model adaptation is a new paradigm that goes beyond our specific use cases. Our hope is to help people looking at creating their own adapters save time, whatever the foundation model they're using.
|
||||||
|
|
||||||
|
## Design
|
||||||
|
|
||||||
|
We are huge fans of PyTorch (we actually were core committers to [Torch](http://torch.ch/) in another life), but we felt it's too low level for the specific model adaptation task: PyTorch models are generally hard to understand, and their adaptation requires intricate ad hoc code.
|
||||||
|
|
||||||
|
Instead, we needed:
|
||||||
|
|
||||||
|
- A model structure that's human readable so that you know what models do and how they work right here, right now
|
||||||
|
- A mechanism to easily inject parameters in some target layers, or between them
|
||||||
|
- A way to easily pass data (like a conditioning input) between layers even when deeply nested
|
||||||
|
- Native support for iconic adapter types like LoRAs and their community trained incarnations (hosted on [Civitai](http://civitai.com/) and the likes)
|
||||||
|
|
||||||
|
Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.
|
||||||
|
|
||||||
|
## Downsides
|
||||||
|
|
||||||
|
As they say, there is no free lunch. Given Refiners comes with a new model structure, it can only work with models implemented that way. For now, we support Stable Diffusion 1.5, but more is in the making (SDXL, SAM, ...) - stay tuned.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Refiners library is made of:
|
||||||
|
|
||||||
|
1. An abstraction layer (called Fluxion) on top of [PyTorch](https://pytorch.org/) to easily build models
|
||||||
|
2. A zoo of compatible foundational models
|
||||||
|
3. Adapter APIs to easily patch supported foundational models
|
||||||
|
4. Training utils to train concrete adapters
|
||||||
|
5. Conversion scripts to easily use existing community adapters
|
||||||
|
|
||||||
|
## Key Concepts
|
||||||
|
|
||||||
|
### The Chain class
|
||||||
|
|
||||||
|
The `Chain` class is at the core of Refiners. It basically lets you express models as a composition of basic layers (linear, convolution, attention, etc) in a **declarative way**.
|
||||||
|
|
||||||
|
E.g.: this is how a Vision Transformer (ViT) looks like with Refiners:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
class ViT(fl.Chain):
|
||||||
|
# The Vision Transformer model structure is entirely defined in the constructor. It is
|
||||||
|
# ready-to-use right after i.e. no need to implement any forward function or add extra logic
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int = 512,
|
||||||
|
patch_size: int = 16,
|
||||||
|
image_size: int = 384,
|
||||||
|
num_layers: int = 12,
|
||||||
|
num_heads: int = 8,
|
||||||
|
):
|
||||||
|
num_patches = (image_size // patch_size)
|
||||||
|
super().__init__(
|
||||||
|
fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
|
||||||
|
fl.Reshape(num_patches**2, embedding_dim),
|
||||||
|
# The Residual layer implements the so-called skip-connection, i.e. x + F(x).
|
||||||
|
# Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
|
||||||
|
# weights are stored in the Parameter layer (note: there is no extra classification
|
||||||
|
# token in this toy example)
|
||||||
|
fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
|
||||||
|
# These are the transformer encoders:
|
||||||
|
*(
|
||||||
|
fl.Chain(
|
||||||
|
fl.LayerNorm(embedding_dim),
|
||||||
|
fl.Residual(
|
||||||
|
# The Parallel layer is used to pass multiple inputs to a downstream
|
||||||
|
# layer, here multiheaded self-attention
|
||||||
|
fl.Parallel(
|
||||||
|
fl.Identity(),
|
||||||
|
fl.Identity(),
|
||||||
|
fl.Identity()
|
||||||
|
),
|
||||||
|
fl.Attention(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
key_embedding_dim=embedding_dim,
|
||||||
|
value_embedding_dim=embedding_dim,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.LayerNorm(embedding_dim),
|
||||||
|
fl.Residual(
|
||||||
|
fl.Linear(embedding_dim, embedding_dim * 4),
|
||||||
|
fl.GeLU(),
|
||||||
|
fl.Linear(embedding_dim * 4, embedding_dim),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Linear(embedding_dim, embedding_dim * 4),
|
||||||
|
fl.GeLU(),
|
||||||
|
fl.Linear(embedding_dim * 4, embedding_dim),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
),
|
||||||
|
fl.Reshape(embedding_dim, num_patches, num_patches),
|
||||||
|
)
|
||||||
|
|
||||||
|
vit = ViT(embedding_dim=768, image_size=224, num_heads=12) # ~ViT-B/16 like
|
||||||
|
x = torch.randn(2, 3, 224, 224)
|
||||||
|
y = vit(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
### The Context API
|
||||||
|
|
||||||
|
The `Chain` class has a context provider that allows you to **pass data to layers even when deeply nested**.
|
||||||
|
|
||||||
|
E.g. to implement cross-attention you would just need to modify the ViT model like in the toy example below:
|
||||||
|
|
||||||
|
|
||||||
|
```diff
|
||||||
|
@@ -21,8 +21,8 @@
|
||||||
|
fl.Residual(
|
||||||
|
fl.Parallel(
|
||||||
|
fl.Identity(),
|
||||||
|
- fl.Identity(),
|
||||||
|
- fl.Identity()
|
||||||
|
+ fl.UseContext(context="cross_attention", key="my_embed"),
|
||||||
|
+ fl.UseContext(context="cross_attention", key="my_embed"),
|
||||||
|
), # used to pass multiple inputs to a layer
|
||||||
|
fl.Attention(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
@@ -49,5 +49,6 @@
|
||||||
|
)
|
||||||
|
|
||||||
|
vit = ViT(embedding_dim=768, image_size=224, num_heads=12) # ~ViT-B/16 like
|
||||||
|
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
|
||||||
|
x = torch.randn(2, 3, 224, 224)
|
||||||
|
y = vit(x)
|
||||||
|
```
|
||||||
|
|
||||||
|
### The Adapter API
|
||||||
|
|
||||||
|
The `Adapter` API lets you **easily patch models** by injecting parameters in targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also implement your custom adapters with it.
|
||||||
|
|
||||||
|
E.g. to inject LoRA layers in all attention's linear layers:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from refiners.adapters.lora import LoraAdapter
|
||||||
|
|
||||||
|
for layer in vit.layers(fl.Attention):
|
||||||
|
for linear, parent in layer.walk(fl.Linear):
|
||||||
|
adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
|
||||||
|
adapter.inject(parent)
|
||||||
|
|
||||||
|
# ... and load existing weights if the LoRAs are pretrained ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# inference only
|
||||||
|
pip install refiners
|
||||||
|
```
|
||||||
|
|
||||||
|
Or:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# inference + training
|
||||||
|
pip install 'refiners[training]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hello World
|
||||||
|
|
||||||
|
Here is how to perform a text-to-image inference using the Stable Diffusion 1.5 foundational model patched with a Pokemon LoRA:
|
||||||
|
|
||||||
|
Step 1: prepare the model weights in refiners' format:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
|
||||||
|
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
|
||||||
|
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.
|
||||||
|
|
||||||
|
Step 2: download and convert a community Pokemon LoRA, e.g. [this one](https://huggingface.co/pcuenq/pokemon-lora)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
|
||||||
|
python scripts/convert-lora-weights.py \
|
||||||
|
--from pytorch_lora_weights.bin \
|
||||||
|
--output-file pokemon_lora.safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 3: run inference using the GPU:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
sd15 = StableDiffusion_1(device="cuda")
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
||||||
|
|
||||||
|
# This uses the LoraAdapter internally and takes care to inject it where it should
|
||||||
|
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
|
||||||
|
lora_weights.patch(sd15, scale=1.0)
|
||||||
|
|
||||||
|
prompt = "a cute cat"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(30)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=sd15.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
predicted_image.save("pokemon_cat.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
You should get:
|
||||||
|
|
||||||
|
![pokemon cat output](assets/pokemon_cat.png)
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Refiners has a built-in training utils library and provides scripts that can be used as a starting point.
|
||||||
|
|
||||||
|
E.g. to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Credits
|
||||||
|
|
||||||
|
We took inspiration from these great projects:
|
||||||
|
|
||||||
|
- [tinygrad](https://github.com/tinygrad/tinygrad) - For something between PyTorch and [karpathy/micrograd](https://github.com/karpathy/micrograd)
|
||||||
|
- [Composer](https://github.com/mosaicml/composer) - A PyTorch Library for Efficient Neural Network Training
|
||||||
|
- [Keras](https://github.com/keras-team/keras) - Deep Learning for humans
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{the-finegrain-team-2023-refiners,
|
||||||
|
author = {Benjamin Trom and Pierre Chapuis and Cédric Deltheil},
|
||||||
|
title = {Refiners: The simplest way to train and run adapters on top of foundational models},
|
||||||
|
year = {2023},
|
||||||
|
publisher = {GitHub},
|
||||||
|
journal = {GitHub repository},
|
||||||
|
howpublished = {\url{https://github.com/finegrain-ai/refiners}}
|
||||||
|
}
|
||||||
|
```
|
BIN
assets/dropy.png
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
assets/logo_dark.png
Normal file
After Width: | Height: | Size: 27 KiB |
BIN
assets/logo_light.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
assets/pokemon_cat.png
Normal file
After Width: | Height: | Size: 336 KiB |
55
configs/finetune-ldm.toml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
script = "finetune-ldm.py" # not used for now
|
||||||
|
|
||||||
|
[wandb]
|
||||||
|
offline = "offline"
|
||||||
|
entity = "acme"
|
||||||
|
project = "test-ldm-training"
|
||||||
|
|
||||||
|
[models]
|
||||||
|
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
|
||||||
|
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
|
||||||
|
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}
|
||||||
|
|
||||||
|
[latent_diffusion]
|
||||||
|
unconditional_sampling_probability = 0.2
|
||||||
|
offset_noise = 0.1
|
||||||
|
|
||||||
|
[training]
|
||||||
|
duration = "1:epoch"
|
||||||
|
seed = 0
|
||||||
|
gpu_index = 0
|
||||||
|
num_epochs = 1
|
||||||
|
batch_size = 1
|
||||||
|
gradient_accumulation = "1:step"
|
||||||
|
clip_grad_norm = 2.0
|
||||||
|
clip_grad_value = 1.0
|
||||||
|
evaluation_interval = "1:epoch"
|
||||||
|
evaluation_seed = 0
|
||||||
|
|
||||||
|
|
||||||
|
[optimizer]
|
||||||
|
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
|
||||||
|
learning_rate = 1e-5
|
||||||
|
betas = [0.9, 0.999]
|
||||||
|
eps = 1e-8
|
||||||
|
weight_decay = 1e-2
|
||||||
|
|
||||||
|
|
||||||
|
[scheduler]
|
||||||
|
|
||||||
|
|
||||||
|
[dropout]
|
||||||
|
dropout_probability = 0.2
|
||||||
|
|
||||||
|
[dataset]
|
||||||
|
hf_repo = "acme/images"
|
||||||
|
revision = "main"
|
||||||
|
|
||||||
|
[checkpointing]
|
||||||
|
# save_folder = "/path/to/ckpts"
|
||||||
|
save_interval = "1:epoch"
|
||||||
|
|
||||||
|
[test_diffusion]
|
||||||
|
prompts = [
|
||||||
|
"A cute cat",
|
||||||
|
]
|
70
configs/finetune-lora.toml
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
script = "finetune-ldm-lora.py" # not used for now
|
||||||
|
|
||||||
|
[wandb]
|
||||||
|
mode = "offline" # "online", "offline", "disabled"
|
||||||
|
entity = "acme"
|
||||||
|
project = "test-lora-training"
|
||||||
|
|
||||||
|
[models]
|
||||||
|
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
||||||
|
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
||||||
|
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
||||||
|
|
||||||
|
[latent_diffusion]
|
||||||
|
unconditional_sampling_probability = 0.05
|
||||||
|
offset_noise = 0.1
|
||||||
|
|
||||||
|
[lora]
|
||||||
|
rank = 16
|
||||||
|
trigger_phrase = "a spsh photo,"
|
||||||
|
use_only_trigger_probability = 1.0
|
||||||
|
unet_targets = ["CrossAttentionBlock2d"]
|
||||||
|
text_encoder_targets = ["TransformerLayer"]
|
||||||
|
lda_targets = []
|
||||||
|
|
||||||
|
[training]
|
||||||
|
duration = "1000:epoch"
|
||||||
|
seed = 0
|
||||||
|
gpu_index = 0
|
||||||
|
batch_size = 4
|
||||||
|
gradient_accumulation = "4:step"
|
||||||
|
clip_grad_norm = 1.0
|
||||||
|
# clip_grad_value = 1.0
|
||||||
|
evaluation_interval = "5:epoch"
|
||||||
|
evaluation_seed = 1
|
||||||
|
|
||||||
|
|
||||||
|
[optimizer]
|
||||||
|
optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
||||||
|
learning_rate = 1
|
||||||
|
betas = [0.9, 0.999]
|
||||||
|
eps = 1e-8
|
||||||
|
weight_decay = 1e-2
|
||||||
|
|
||||||
|
[scheduler]
|
||||||
|
scheduler_type = "ConstantLR"
|
||||||
|
update_interval = "1:step"
|
||||||
|
warmup = "500:step"
|
||||||
|
|
||||||
|
|
||||||
|
[dropout]
|
||||||
|
dropout_probability = 0.2
|
||||||
|
use_gyro_dropout = false
|
||||||
|
|
||||||
|
[dataset]
|
||||||
|
hf_repo = "acme/images"
|
||||||
|
revision = "main"
|
||||||
|
|
||||||
|
[checkpointing]
|
||||||
|
# save_folder = "/path/to/ckpts"
|
||||||
|
save_interval = "1:step"
|
||||||
|
|
||||||
|
[test_diffusion]
|
||||||
|
num_inference_steps = 30
|
||||||
|
use_short_prompts = false
|
||||||
|
prompts = [
|
||||||
|
"a cute cat",
|
||||||
|
"a cute dog",
|
||||||
|
"a cute bird",
|
||||||
|
"a cute horse",
|
||||||
|
]
|
3360
poetry.lock
generated
Normal file
67
pyproject.toml
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
[tool.poetry]
|
||||||
|
name = "refiners"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "The simplest way to train and run adapters on top of foundational models"
|
||||||
|
authors = [
|
||||||
|
"The Finegrain Team <bonjour@lagon.tech>",
|
||||||
|
]
|
||||||
|
license = "MIT"
|
||||||
|
readme = "README.md"
|
||||||
|
packages = [{include = "refiners", from = "src"}]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.10,<3.12"
|
||||||
|
jaxtyping = "^0.2.14"
|
||||||
|
torch = "^2.0.0"
|
||||||
|
safetensors = "^0.3.0"
|
||||||
|
numpy = "^1.24.2"
|
||||||
|
pillow = "^9.5.0"
|
||||||
|
datasets = {version = "^2.14.0", optional = true}
|
||||||
|
tomli = {version = "^2.0.1", optional = true}
|
||||||
|
wandb = {version = "^0.15.7", optional = true}
|
||||||
|
loguru = {version = "^0.7.0", optional = true}
|
||||||
|
bitsandbytes = {version = "^0.41.0", optional = true}
|
||||||
|
prodigyopt = {version = "^1.0", optional = true}
|
||||||
|
pydantic = {git = "https://github.com/pydantic/pydantic.git", rev = "v2.0b3", optional = true}
|
||||||
|
scipy = {version = "^1.11.1", optional = true}
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.extras]
|
||||||
|
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy"]
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
black = "^23.1.0"
|
||||||
|
pytest = "^7.2.2"
|
||||||
|
isort = "^5.12.0"
|
||||||
|
ipykernel = "^6.22.0"
|
||||||
|
pyright = "^1.1.318"
|
||||||
|
ruff = "^0.0.281"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.test.dependencies]
|
||||||
|
diffusers = "^0.18.0"
|
||||||
|
transformers = "^4.27.4"
|
||||||
|
piq = "^0.7.1"
|
||||||
|
invisible-watermark = "^0.2.0"
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
||||||
|
preview = true
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
ignore = [
|
||||||
|
"F722", # forward-annotation-syntax-error, because of Jaxtyping
|
||||||
|
"E731", # do-not-assign-lambda
|
||||||
|
"E501", # line-too-long, because Black (https://beta.ruff.rs/docs/faq/#is-ruff-compatible-with-black)
|
||||||
|
]
|
||||||
|
line-length = 120
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
include = ["src/refiners", "tests", "scripts/training"]
|
||||||
|
exclude = ["**/__pycache__"]
|
||||||
|
reportMissingTypeStubs = "warning"
|
50
scripts/convert-clip-weights.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPTextModel
|
||||||
|
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = CLIPTextEncoderL()
|
||||||
|
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||||
|
mapping = create_state_dict_mapping(src_model, dst_model, [x])
|
||||||
|
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="CLIPTextEncoderL.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = DiffusionPipeline.from_pretrained(args.source).text_encoder
|
||||||
|
tensors = convert(src_model)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
203
scripts/convert-controlnet-weights.py
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
import torch
|
||||||
|
from diffusers import ControlNetModel
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
forward_order_of_execution,
|
||||||
|
verify_shape_match,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion import UNet
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
|
controlnet = Controlnet(name="mycn")
|
||||||
|
|
||||||
|
condition = torch.randn(1, 3, 512, 512)
|
||||||
|
controlnet.set_controlnet_condition(condition)
|
||||||
|
|
||||||
|
unet = UNet(4, clip_embedding_dim=768)
|
||||||
|
unet.insert(0, controlnet)
|
||||||
|
clip_text_embedding = torch.rand(1, 77, 768)
|
||||||
|
unet.set_clip_text_embedding(clip_text_embedding)
|
||||||
|
|
||||||
|
scheduler = DPMSolver(num_inference_steps=10)
|
||||||
|
timestep = scheduler.timesteps[0].unsqueeze(0)
|
||||||
|
unet.set_timestep(timestep.unsqueeze(0))
|
||||||
|
|
||||||
|
x = torch.randn(1, 4, 64, 64)
|
||||||
|
|
||||||
|
# We need the hack below because our implementation is not strictly equivalent
|
||||||
|
# to diffusers in order, since we compute the residuals inline instead of
|
||||||
|
# in a separate step.
|
||||||
|
|
||||||
|
source_order = forward_order_of_execution(controlnet_src, (x, timestep, clip_text_embedding, condition))
|
||||||
|
target_order = forward_order_of_execution(controlnet, (x,))
|
||||||
|
|
||||||
|
broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
|
||||||
|
|
||||||
|
expected_source_order = [
|
||||||
|
"down_blocks.0.attentions.0.proj_in",
|
||||||
|
"down_blocks.0.attentions.0.proj_out",
|
||||||
|
"down_blocks.0.attentions.1.proj_in",
|
||||||
|
"down_blocks.0.attentions.1.proj_out",
|
||||||
|
"controlnet_down_blocks.0",
|
||||||
|
"controlnet_down_blocks.1",
|
||||||
|
"controlnet_down_blocks.2",
|
||||||
|
"controlnet_down_blocks.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_target_order = [
|
||||||
|
"DownBlocks.Chain_1.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_2.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_3.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_4.Passthrough.Conv2d",
|
||||||
|
]
|
||||||
|
|
||||||
|
fixed_source_order = [
|
||||||
|
"controlnet_down_blocks.0",
|
||||||
|
"down_blocks.0.attentions.0.proj_in",
|
||||||
|
"down_blocks.0.attentions.0.proj_out",
|
||||||
|
"controlnet_down_blocks.1",
|
||||||
|
"down_blocks.0.attentions.1.proj_in",
|
||||||
|
"down_blocks.0.attentions.1.proj_out",
|
||||||
|
"controlnet_down_blocks.2",
|
||||||
|
"controlnet_down_blocks.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert source_order[broken_k] == expected_source_order
|
||||||
|
assert target_order[broken_k] == expected_target_order
|
||||||
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
|
broken_k = ("Conv2d", (torch.Size([640, 640, 1, 1]), torch.Size([640])))
|
||||||
|
|
||||||
|
expected_source_order = [
|
||||||
|
"down_blocks.1.attentions.0.proj_in",
|
||||||
|
"down_blocks.1.attentions.0.proj_out",
|
||||||
|
"down_blocks.1.attentions.1.proj_in",
|
||||||
|
"down_blocks.1.attentions.1.proj_out",
|
||||||
|
"controlnet_down_blocks.4",
|
||||||
|
"controlnet_down_blocks.5",
|
||||||
|
"controlnet_down_blocks.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_target_order = [
|
||||||
|
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_5.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_6.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_7.Passthrough.Conv2d",
|
||||||
|
]
|
||||||
|
|
||||||
|
fixed_source_order = [
|
||||||
|
"down_blocks.1.attentions.0.proj_in",
|
||||||
|
"down_blocks.1.attentions.0.proj_out",
|
||||||
|
"controlnet_down_blocks.4",
|
||||||
|
"down_blocks.1.attentions.1.proj_in",
|
||||||
|
"down_blocks.1.attentions.1.proj_out",
|
||||||
|
"controlnet_down_blocks.5",
|
||||||
|
"controlnet_down_blocks.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert source_order[broken_k] == expected_source_order
|
||||||
|
assert target_order[broken_k] == expected_target_order
|
||||||
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
|
broken_k = ("Conv2d", (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
|
||||||
|
|
||||||
|
expected_source_order = [
|
||||||
|
"down_blocks.2.attentions.0.proj_in",
|
||||||
|
"down_blocks.2.attentions.0.proj_out",
|
||||||
|
"down_blocks.2.attentions.1.proj_in",
|
||||||
|
"down_blocks.2.attentions.1.proj_out",
|
||||||
|
"mid_block.attentions.0.proj_in",
|
||||||
|
"mid_block.attentions.0.proj_out",
|
||||||
|
"controlnet_down_blocks.7",
|
||||||
|
"controlnet_down_blocks.8",
|
||||||
|
"controlnet_down_blocks.9",
|
||||||
|
"controlnet_down_blocks.10",
|
||||||
|
"controlnet_down_blocks.11",
|
||||||
|
"controlnet_mid_block",
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_target_order = [
|
||||||
|
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_8.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"DownBlocks.Chain_9.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_10.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_11.Passthrough.Conv2d",
|
||||||
|
"DownBlocks.Chain_12.Passthrough.Conv2d",
|
||||||
|
"MiddleBlock.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
|
||||||
|
"MiddleBlock.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
|
||||||
|
"MiddleBlock.Passthrough.Conv2d",
|
||||||
|
]
|
||||||
|
|
||||||
|
fixed_source_order = [
|
||||||
|
"down_blocks.2.attentions.0.proj_in",
|
||||||
|
"down_blocks.2.attentions.0.proj_out",
|
||||||
|
"controlnet_down_blocks.7",
|
||||||
|
"down_blocks.2.attentions.1.proj_in",
|
||||||
|
"down_blocks.2.attentions.1.proj_out",
|
||||||
|
"controlnet_down_blocks.8",
|
||||||
|
"controlnet_down_blocks.9",
|
||||||
|
"controlnet_down_blocks.10",
|
||||||
|
"controlnet_down_blocks.11",
|
||||||
|
"mid_block.attentions.0.proj_in",
|
||||||
|
"mid_block.attentions.0.proj_out",
|
||||||
|
"controlnet_mid_block",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert source_order[broken_k] == expected_source_order
|
||||||
|
assert target_order[broken_k] == expected_target_order
|
||||||
|
source_order[broken_k] = fixed_source_order
|
||||||
|
|
||||||
|
assert verify_shape_match(source_order, target_order)
|
||||||
|
|
||||||
|
mapping: dict[str, str] = {}
|
||||||
|
for model_type_shape in source_order:
|
||||||
|
source_keys = source_order[model_type_shape]
|
||||||
|
target_keys = target_order[model_type_shape]
|
||||||
|
mapping.update(zip(target_keys, source_keys))
|
||||||
|
|
||||||
|
state_dict = convert_state_dict(controlnet_src.state_dict(), controlnet.state_dict(), state_dict_mapping=mapping)
|
||||||
|
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=True,
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="output.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
controlnet_src = ControlNetModel.from_pretrained(args.source)
|
||||||
|
tensors = convert(controlnet_src)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
115
scripts/convert-lora-weights.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
# Note: this conversion script currently only support simple LoRAs which adapt
|
||||||
|
# the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.init import zeros_
|
||||||
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
||||||
|
from refiners.adapters.lora import Lora
|
||||||
|
from refiners.fluxion.utils import create_state_dict_mapping
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
||||||
|
assert linear.bias is None
|
||||||
|
return linear.state_dict()["weight"]
|
||||||
|
|
||||||
|
|
||||||
|
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
|
||||||
|
weights: list[torch.Tensor] = []
|
||||||
|
for lora in module.layers(layer_type=Lora):
|
||||||
|
linears = list(lora.layers(fl.Linear))
|
||||||
|
assert len(linears) == 2
|
||||||
|
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
|
||||||
|
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def process(source: str, base_model: str, output_file: str) -> None:
|
||||||
|
diffusers_state_dict = torch.load(source, map_location="cpu") # type: ignore
|
||||||
|
diffusers_sd = DiffusionPipeline.from_pretrained(base_model) # type: ignore
|
||||||
|
diffusers_model = diffusers_sd.unet
|
||||||
|
|
||||||
|
refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
target = LoraTarget.CrossAttention
|
||||||
|
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
||||||
|
rank = diffusers_state_dict[
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
|
||||||
|
].shape[0]
|
||||||
|
|
||||||
|
x = torch.randn(1, 4, 32, 32)
|
||||||
|
timestep = torch.tensor([0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||||
|
|
||||||
|
refiners_model.set_timestep(timestep)
|
||||||
|
refiners_model.set_clip_text_embedding(clip_text_embeddings)
|
||||||
|
refiners_args = (x,)
|
||||||
|
|
||||||
|
diffusers_args = (x, timestep, clip_text_embeddings)
|
||||||
|
|
||||||
|
diffusers_to_refiners = create_state_dict_mapping(refiners_model, diffusers_model, refiners_args, diffusers_args)
|
||||||
|
assert diffusers_to_refiners
|
||||||
|
|
||||||
|
apply_loras_to_target(refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
|
||||||
|
for layer in refiners_model.layers(layer_type=Lora):
|
||||||
|
zeros_(layer.Linear_1.weight)
|
||||||
|
|
||||||
|
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
|
||||||
|
for target_k in targets:
|
||||||
|
k_p, k_s = target_k.split(".processor.")
|
||||||
|
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
|
||||||
|
assert len(r) == 1
|
||||||
|
orig_k = r[0]
|
||||||
|
orig_path = orig_k.split(".")
|
||||||
|
p = refiners_model
|
||||||
|
for seg in orig_path[:-1]:
|
||||||
|
p = p[seg]
|
||||||
|
last_seg = (
|
||||||
|
"LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
||||||
|
)
|
||||||
|
p_down = TorchParameter(diffusers_state_dict[f"{target_k}_lora.down.weight"])
|
||||||
|
p_up = TorchParameter(diffusers_state_dict[f"{target_k}_lora.up.weight"])
|
||||||
|
p[last_seg].Lora.load_weights(p_down, p_up)
|
||||||
|
|
||||||
|
state_dict = build_loras_safetensors(refiners_model, key_prefix="unet.")
|
||||||
|
assert len(state_dict) == 320
|
||||||
|
save_to_safetensors(output_file, tensors=state_dict, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=True,
|
||||||
|
help="Source file path (.bin)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-model",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Base model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="output.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
process(source=args.source, base_model=args.base_model, output_file=args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
134
scripts/convert-loras-to-sdwebui.py
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.utils import create_state_dict_mapping
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPTextModel
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
|
||||||
|
x = torch.randn(1, 4, 32, 32)
|
||||||
|
timestep = torch.tensor([0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||||
|
|
||||||
|
src_args = (x, timestep, clip_text_embeddings)
|
||||||
|
dst_model.set_timestep(timestep)
|
||||||
|
dst_model.set_clip_text_embedding(clip_text_embeddings)
|
||||||
|
dst_args = (x,)
|
||||||
|
|
||||||
|
return create_state_dict_mapping(src_model, dst_model, src_args, dst_args) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
|
||||||
|
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||||
|
|
||||||
|
return create_state_dict_mapping(src_model, dst_model, [x]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"-i",
|
||||||
|
"--input-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the input file with refiner's LoRA weights (safetensors format)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-o",
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the output file with sd-webui's LoRA weights (safetensors format)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sd15",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Path (preferred) or repository ID of Stable Diffusion 1.5 model (Hugging Face diffusers format)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
metadata = load_metadata_from_safetensors(args.input_file)
|
||||||
|
assert metadata is not None
|
||||||
|
tensors = load_from_safetensors(args.input_file)
|
||||||
|
|
||||||
|
diffusers_sd = DiffusionPipeline.from_pretrained(args.sd15) # type: ignore
|
||||||
|
|
||||||
|
state_dict: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
for meta_key, meta_value in metadata.items():
|
||||||
|
match meta_key:
|
||||||
|
case "unet_targets":
|
||||||
|
src_model = diffusers_sd.unet # type: ignore
|
||||||
|
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
create_mapping = create_unet_mapping
|
||||||
|
key_prefix = "unet."
|
||||||
|
lora_prefix = "lora_unet_"
|
||||||
|
case "text_encoder_targets":
|
||||||
|
src_model = diffusers_sd.text_encoder # type: ignore
|
||||||
|
dst_model = CLIPTextEncoderL()
|
||||||
|
create_mapping = create_text_encoder_mapping
|
||||||
|
key_prefix = "text_encoder."
|
||||||
|
lora_prefix = "lora_te_"
|
||||||
|
case "lda_targets":
|
||||||
|
raise ValueError("SD-WebUI does not support LoRA for the auto-encoder")
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
||||||
|
|
||||||
|
submodule_to_key: dict[Module, str] = {}
|
||||||
|
for name, submodule in dst_model.named_modules():
|
||||||
|
submodule_to_key[submodule] = name
|
||||||
|
|
||||||
|
# SD-WebUI expects LoRA state dicts with keys derived from the diffusers format, e.g.:
|
||||||
|
#
|
||||||
|
# lora_unet_down_blocks_0_attentions_0_proj_in.alpha
|
||||||
|
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
|
||||||
|
# lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
|
||||||
|
# ...
|
||||||
|
#
|
||||||
|
# Internally SD-WebUI has some logic[1] to convert such keys into the CompVis format. See
|
||||||
|
# `convert_diffusers_name_to_compvis` for more details.
|
||||||
|
#
|
||||||
|
# [1]: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/394ffa7/extensions-builtin/Lora/lora.py#L158-L225
|
||||||
|
|
||||||
|
refiners_to_diffusers = create_mapping(src_model, dst_model) # type: ignore
|
||||||
|
assert refiners_to_diffusers is not None
|
||||||
|
|
||||||
|
# Compute the corresponding diffusers' keys where LoRA layers must be applied
|
||||||
|
lora_injection_points: list[str] = [
|
||||||
|
refiners_to_diffusers[submodule_to_key[linear]]
|
||||||
|
for target in [LoraTarget(t) for t in meta_value.split(",")]
|
||||||
|
for layer in dst_model.layers(layer_type=target.get_class())
|
||||||
|
for linear in layer.layers(fl.Linear)
|
||||||
|
]
|
||||||
|
|
||||||
|
lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
|
||||||
|
assert len(lora_injection_points) == len(lora_weights) // 2
|
||||||
|
|
||||||
|
# Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
|
||||||
|
for i, diffusers_key in enumerate(lora_injection_points):
|
||||||
|
lora_key = lora_prefix + diffusers_key.replace(".", "_")
|
||||||
|
# Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
|
||||||
|
# by default (see `lora_calc_updown` in SD-WebUI for more details)
|
||||||
|
state_dict[lora_key + ".lora_up.weight"] = lora_weights[2 * i]
|
||||||
|
state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
|
||||||
|
|
||||||
|
assert state_dict
|
||||||
|
save_to_safetensors(args.output_file, state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
50
scripts/convert-sd-lda-weights.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers.models.autoencoder_kl import AutoencoderKL
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = LatentDiffusionAutoencoder()
|
||||||
|
x = torch.randn(1, 3, 512, 512)
|
||||||
|
mapping = create_state_dict_mapping(src_model, dst_model, [x])
|
||||||
|
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="lda.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = DiffusionPipeline.from_pretrained(args.source).vae
|
||||||
|
tensors = convert(src_model)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
59
scripts/convert-sd-unet-inpainting-weights.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionInpaintPipeline
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
|
|
||||||
|
x = torch.randn(1, 9, 32, 32)
|
||||||
|
timestep = torch.tensor([0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||||
|
|
||||||
|
src_args = (x, timestep, clip_text_embeddings)
|
||||||
|
dst_model.set_timestep(timestep)
|
||||||
|
dst_model.set_clip_text_embedding(clip_text_embeddings)
|
||||||
|
dst_args = (x,)
|
||||||
|
|
||||||
|
mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
|
||||||
|
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-inpainting",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="stable_diffusion_1_5_inpainting_unet.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = StableDiffusionInpaintPipeline.from_pretrained(args.source).unet
|
||||||
|
tensors = convert(src_model)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
59
scripts/convert-sd-unet-weights.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
|
||||||
|
x = torch.randn(1, 4, 32, 32)
|
||||||
|
timestep = torch.tensor([0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||||
|
|
||||||
|
src_args = (x, timestep, clip_text_embeddings)
|
||||||
|
dst_model.set_timestep(timestep)
|
||||||
|
dst_model.set_clip_text_embedding(clip_text_embeddings)
|
||||||
|
dst_args = (x,)
|
||||||
|
|
||||||
|
mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
|
||||||
|
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="stable_diffusion_1_5_unet.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = DiffusionPipeline.from_pretrained(args.source).unet
|
||||||
|
tensors = convert(src_model)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
57
scripts/convert-sdxl-text-encoder-2.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file # type: ignore
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||||
|
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = CLIPTextEncoderG()
|
||||||
|
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
|
||||||
|
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||||
|
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||||
|
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
||||||
|
if mapping is None:
|
||||||
|
raise RuntimeError("Could not create state dict mapping")
|
||||||
|
state_dict = convert_state_dict(
|
||||||
|
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
||||||
|
)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="stabilityai/stable-diffusion-xl-base-0.9",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="CLIPTextEncoderG.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2 # type: ignore
|
||||||
|
tensors = convert(src_model=src_model)
|
||||||
|
save_file(tensors=tensors, filename=args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
68
scripts/convert-sdxl-unet-weights.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file # type: ignore
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
||||||
|
dst_model = SDXLUNet(in_channels=4)
|
||||||
|
|
||||||
|
x = torch.randn(1, 4, 32, 32)
|
||||||
|
timestep = torch.tensor([0])
|
||||||
|
clip_text_embeddings = torch.randn(1, 77, 2048)
|
||||||
|
|
||||||
|
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
||||||
|
src_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
|
||||||
|
dst_model.set_timestep(timestep=timestep)
|
||||||
|
dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||||
|
dst_model.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||||
|
dst_model.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||||
|
dst_args = (x,)
|
||||||
|
|
||||||
|
mapping = create_state_dict_mapping(
|
||||||
|
source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args # type: ignore
|
||||||
|
)
|
||||||
|
if mapping is None:
|
||||||
|
raise RuntimeError("Could not create state dict mapping")
|
||||||
|
state_dict = convert_state_dict(
|
||||||
|
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
||||||
|
)
|
||||||
|
return {k: v for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="stabilityai/stable-diffusion-xl-base-0.9",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="stable_diffusion_xl_unet.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet # type: ignore
|
||||||
|
tensors = convert(src_model)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
148
scripts/training/finetune-ldm-lora.py
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
import random
|
||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from loguru import logger
|
||||||
|
from refiners.adapters.lora import LoraAdapter, Lora
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.training_utils.latent_diffusion import (
|
||||||
|
FinetuneLatentDiffusionConfig,
|
||||||
|
TextEmbeddingLatentsBatch,
|
||||||
|
TextEmbeddingLatentsDataset,
|
||||||
|
LatentDiffusionTrainer,
|
||||||
|
LatentDiffusionConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraConfig(BaseModel):
|
||||||
|
rank: int = 32
|
||||||
|
trigger_phrase: str = ""
|
||||||
|
use_only_trigger_probability: float = 0.0
|
||||||
|
unet_targets: list[LoraTarget]
|
||||||
|
text_encoder_targets: list[LoraTarget]
|
||||||
|
lda_targets: list[LoraTarget]
|
||||||
|
|
||||||
|
def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
|
||||||
|
for layer in module.layers(layer_type=target.get_class()):
|
||||||
|
for linear, parent in layer.walk(fl.Linear):
|
||||||
|
adapter = LoraAdapter(
|
||||||
|
target=linear,
|
||||||
|
rank=self.rank,
|
||||||
|
device=module.device,
|
||||||
|
dtype=module.dtype,
|
||||||
|
)
|
||||||
|
adapter.inject(parent)
|
||||||
|
for linear in adapter.Lora.layers(fl.Linear):
|
||||||
|
linear.requires_grad_(requires_grad=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
trainer: "LoraLatentDiffusionTrainer",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(trainer=trainer)
|
||||||
|
self.trigger_phrase = trainer.config.lora.trigger_phrase
|
||||||
|
self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
|
||||||
|
logger.info(f"Trigger phrase: {self.trigger_phrase}")
|
||||||
|
|
||||||
|
def process_caption(self, caption: str) -> str:
|
||||||
|
caption = super().process_caption(caption=caption)
|
||||||
|
if self.trigger_phrase:
|
||||||
|
caption = (
|
||||||
|
f"{self.trigger_phrase} {caption}"
|
||||||
|
if random.random() < self.use_only_trigger_probability
|
||||||
|
else self.trigger_phrase
|
||||||
|
)
|
||||||
|
return caption
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||||
|
latent_diffusion: LatentDiffusionConfig
|
||||||
|
lora: LoraConfig
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
"""Pydantic v2 does post init differently, so we need to override this method too."""
|
||||||
|
logger.info("Freezing models to train only the loras.")
|
||||||
|
self.models["unet"].train = False
|
||||||
|
self.models["text_encoder"].train = False
|
||||||
|
self.models["lda"].train = False
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LoraLatentDiffusionConfig,
|
||||||
|
callbacks: "list[Callback[Any]] | None" = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config=config, callbacks=callbacks)
|
||||||
|
self.callbacks.extend((LoadLoras(), SaveLoras()))
|
||||||
|
|
||||||
|
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||||
|
return TriggerPhraseDataset(trainer=self)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||||
|
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||||
|
lora_config = trainer.config.lora
|
||||||
|
for target in lora_config.unet_targets:
|
||||||
|
lora_config.apply_loras_to_target(module=trainer.unet, target=target)
|
||||||
|
for target in lora_config.text_encoder_targets:
|
||||||
|
lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
|
||||||
|
for target in lora_config.lda_targets:
|
||||||
|
lora_config.apply_loras_to_target(module=trainer.lda, target=target)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||||
|
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||||
|
lora_config = trainer.config.lora
|
||||||
|
|
||||||
|
def get_weight(linear: fl.Linear) -> Tensor:
|
||||||
|
assert linear.bias is None
|
||||||
|
return linear.state_dict()["weight"]
|
||||||
|
|
||||||
|
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
|
||||||
|
weights: list[Tensor] = []
|
||||||
|
for lora in module.layers(layer_type=Lora):
|
||||||
|
linears = list(lora.layers(fl.Linear))
|
||||||
|
assert len(linears) == 2
|
||||||
|
# See `load_lora_weights` in refiners.adapters.lora
|
||||||
|
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
|
||||||
|
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
|
||||||
|
|
||||||
|
tensors: dict[str, Tensor] = {}
|
||||||
|
metadata: dict[str, str] = {}
|
||||||
|
|
||||||
|
if lora_config.unet_targets:
|
||||||
|
tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
|
||||||
|
metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
|
||||||
|
|
||||||
|
if lora_config.text_encoder_targets:
|
||||||
|
tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
|
||||||
|
metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
|
||||||
|
|
||||||
|
if lora_config.lda_targets:
|
||||||
|
tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
|
||||||
|
metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
|
||||||
|
|
||||||
|
save_to_safetensors(
|
||||||
|
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
|
||||||
|
tensors=tensors,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
config_path = sys.argv[1]
|
||||||
|
config = LoraLatentDiffusionConfig.load_from_toml(
|
||||||
|
toml_path=config_path,
|
||||||
|
)
|
||||||
|
trainer = LoraLatentDiffusionTrainer(config=config)
|
||||||
|
trainer.train()
|
11
scripts/training/finetune-ldm.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
config_path = sys.argv[1]
|
||||||
|
config = FinetuneLatentDiffusionConfig.load_from_toml(
|
||||||
|
toml_path=config_path,
|
||||||
|
)
|
||||||
|
trainer = LatentDiffusionTrainer(config=config)
|
||||||
|
trainer.train()
|
0
src/refiners/__init__.py
Normal file
0
src/refiners/adapters/__init__.py
Normal file
66
src/refiners/adapters/adapter.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
import contextlib
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from typing import Any, Generic, TypeVar, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=fl.Module)
|
||||||
|
TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
|
||||||
|
|
||||||
|
|
||||||
|
class Adapter(Generic[T]):
|
||||||
|
# we store _target into a one element list to avoid pytorch thinking it is a submodule
|
||||||
|
_target: "list[T]"
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
assert issubclass(cls, fl.Chain), f"Adapter {cls.__name__} must be a Chain"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target(self) -> T:
|
||||||
|
return self._target[0]
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def setup_adapter(self, target: T) -> Iterator[None]:
|
||||||
|
assert isinstance(self, fl.Chain)
|
||||||
|
assert (not hasattr(self, "_modules")) or (
|
||||||
|
len(self) == 0
|
||||||
|
), "Call the Chain constructor in the setup_adapter context."
|
||||||
|
self._target = [target]
|
||||||
|
|
||||||
|
if not isinstance(self.target, fl.ContextModule):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
_old_can_refresh_parent = target._can_refresh_parent
|
||||||
|
target._can_refresh_parent = False
|
||||||
|
yield
|
||||||
|
target._can_refresh_parent = _old_can_refresh_parent
|
||||||
|
|
||||||
|
def inject(self, parent: fl.Chain | None = None) -> None:
|
||||||
|
assert isinstance(self, fl.Chain)
|
||||||
|
|
||||||
|
if parent is None:
|
||||||
|
if isinstance(self.target, fl.ContextModule):
|
||||||
|
parent = self.target.parent
|
||||||
|
else:
|
||||||
|
raise ValueError(f"parent of {self.target} is mandatory")
|
||||||
|
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
|
||||||
|
if self.target not in iter(parent):
|
||||||
|
raise ValueError(f"{self.target} is not in {parent}")
|
||||||
|
|
||||||
|
parent.replace(
|
||||||
|
old_module=self.target,
|
||||||
|
new_module=self,
|
||||||
|
old_module_parent=self.find_parent(self.target),
|
||||||
|
)
|
||||||
|
|
||||||
|
def eject(self) -> None:
|
||||||
|
assert isinstance(self, fl.Chain)
|
||||||
|
self.ensure_parent.replace(old_module=self, new_module=self.target)
|
||||||
|
|
||||||
|
def _pre_structural_copy(self) -> None:
|
||||||
|
if isinstance(self.target, fl.Chain):
|
||||||
|
raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
|
||||||
|
|
||||||
|
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
||||||
|
self._target = [source.target]
|
88
src/refiners/adapters/lora.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
|
||||||
|
from torch.nn.init import zeros_, normal_
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
|
||||||
|
class Lora(fl.Chain):
|
||||||
|
structural_attrs = ["in_features", "out_features", "rank", "scale"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
rank: int = 16,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.rank = rank
|
||||||
|
self.scale: float = 1.0
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
|
||||||
|
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
|
||||||
|
fl.Lambda(func=self.scale_outputs),
|
||||||
|
)
|
||||||
|
|
||||||
|
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
|
||||||
|
zeros_(tensor=self.Linear_2.weight)
|
||||||
|
|
||||||
|
def scale_outputs(self, x: Tensor) -> Tensor:
|
||||||
|
return x * self.scale
|
||||||
|
|
||||||
|
def set_scale(self, scale: float) -> None:
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||||
|
self.Linear_1.weight = down_weight
|
||||||
|
self.Linear_2.weight = up_weight
|
||||||
|
|
||||||
|
|
||||||
|
class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
||||||
|
structural_attrs = ["in_features", "out_features", "rank", "scale"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: fl.Linear,
|
||||||
|
rank: int = 16,
|
||||||
|
scale: float = 1.0,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.in_features = target.in_features
|
||||||
|
self.out_features = target.out_features
|
||||||
|
self.rank = rank
|
||||||
|
self.scale = scale
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(
|
||||||
|
target,
|
||||||
|
Lora(
|
||||||
|
in_features=target.in_features,
|
||||||
|
out_features=target.out_features,
|
||||||
|
rank=rank,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.Lora.set_scale(scale=scale)
|
||||||
|
|
||||||
|
def add_lora(self, lora: Lora) -> None:
|
||||||
|
self.append(module=lora)
|
||||||
|
|
||||||
|
def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
|
||||||
|
self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
|
||||||
|
assert len(weights) % 2 == 0, "Number of weights must be even"
|
||||||
|
assert (
|
||||||
|
len(list(model.layers(layer_type=Lora))) == len(weights) // 2
|
||||||
|
), "Number of Lora layers must match number of weights"
|
||||||
|
for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
|
||||||
|
assert (
|
||||||
|
lora.rank == weights[i * 2].shape[1]
|
||||||
|
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
|
||||||
|
lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
|
70
src/refiners/adapters/range_adapter.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import math
|
||||||
|
from torch import Tensor, arange, float32, exp, sin, cat, cos, device as Device, dtype as DType
|
||||||
|
from jaxtyping import Float, Int
|
||||||
|
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
def compute_sinusoidal_embedding(
|
||||||
|
x: Int[Tensor, "*batch 1"],
|
||||||
|
embedding_dim: int,
|
||||||
|
) -> Float[Tensor, "*batch 1 embedding_dim"]:
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
# Note: it is important that this computation is done in float32.
|
||||||
|
# The result can be cast to lower precision later if necessary.
|
||||||
|
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
|
||||||
|
exponent /= half_dim
|
||||||
|
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
|
||||||
|
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class RangeEncoder(fl.Chain):
|
||||||
|
structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sinuosidal_embedding_dim: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.sinuosidal_embedding_dim = sinuosidal_embedding_dim
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
super().__init__(
|
||||||
|
fl.Lambda(self.compute_sinuosoidal_embedding),
|
||||||
|
fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
|
||||||
|
return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim).to(self.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
||||||
|
structural_attrs = ["channels", "embedding_dim", "context_key"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: fl.Conv2d,
|
||||||
|
channels: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
context_key: str,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channels = channels
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.context_key = context_key
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(
|
||||||
|
target,
|
||||||
|
fl.Chain(
|
||||||
|
fl.UseContext("range_adapter", context_key),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
||||||
|
fl.View(-1, channels, 1, 1),
|
||||||
|
),
|
||||||
|
)
|
3
src/refiners/fluxion/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors, load_from_safetensors, norm, manual_seed, pad
|
||||||
|
|
||||||
|
__all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"]
|
52
src/refiners/fluxion/context.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from typing import Any
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
Context = dict[str, Any]
|
||||||
|
Contexts = dict[str, Context]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextProvider:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.contexts: Contexts = {}
|
||||||
|
|
||||||
|
def set_context(self, key: str, value: Context) -> None:
|
||||||
|
self.contexts[key] = value
|
||||||
|
|
||||||
|
def get_context(self, key: str) -> Any:
|
||||||
|
return self.contexts.get(key)
|
||||||
|
|
||||||
|
def update_contexts(self, new_contexts: Contexts) -> None:
|
||||||
|
for key, value in new_contexts.items():
|
||||||
|
if key not in self.contexts:
|
||||||
|
self.contexts[key] = value
|
||||||
|
else:
|
||||||
|
self.contexts[key].update(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(contexts: Contexts) -> "ContextProvider":
|
||||||
|
provider = ContextProvider()
|
||||||
|
provider.update_contexts(contexts)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
def __add__(self, other: "ContextProvider") -> "ContextProvider":
|
||||||
|
self.contexts.update(other.contexts)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
|
||||||
|
other.contexts.update(self.contexts)
|
||||||
|
return other
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.contexts)
|
||||||
|
|
||||||
|
def _get_repr_for_value(self, value: Any) -> str:
|
||||||
|
if isinstance(value, Tensor):
|
||||||
|
return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
|
||||||
|
return repr(value)
|
||||||
|
|
||||||
|
def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
|
||||||
|
return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
|
||||||
|
return f"{self.__class__.__name__}(contexts={contexts_repr})"
|
82
src/refiners/fluxion/layers/__init__.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU
|
||||||
|
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
|
||||||
|
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
||||||
|
from refiners.fluxion.layers.basics import (
|
||||||
|
Identity,
|
||||||
|
View,
|
||||||
|
Flatten,
|
||||||
|
Unflatten,
|
||||||
|
Transpose,
|
||||||
|
Permute,
|
||||||
|
Reshape,
|
||||||
|
Squeeze,
|
||||||
|
Unsqueeze,
|
||||||
|
Slicing,
|
||||||
|
Parameter,
|
||||||
|
Buffer,
|
||||||
|
)
|
||||||
|
from refiners.fluxion.layers.chain import (
|
||||||
|
Lambda,
|
||||||
|
Sum,
|
||||||
|
Residual,
|
||||||
|
Return,
|
||||||
|
Chain,
|
||||||
|
UseContext,
|
||||||
|
SetContext,
|
||||||
|
Parallel,
|
||||||
|
Passthrough,
|
||||||
|
Breakpoint,
|
||||||
|
Concatenate,
|
||||||
|
)
|
||||||
|
from refiners.fluxion.layers.conv import Conv2d
|
||||||
|
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||||
|
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
||||||
|
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||||
|
from refiners.fluxion.layers.embedding import Embedding
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Embedding",
|
||||||
|
"LayerNorm",
|
||||||
|
"GroupNorm",
|
||||||
|
"LayerNorm2d",
|
||||||
|
"GeLU",
|
||||||
|
"GLU",
|
||||||
|
"SiLU",
|
||||||
|
"ReLU",
|
||||||
|
"ApproximateGeLU",
|
||||||
|
"Attention",
|
||||||
|
"SelfAttention",
|
||||||
|
"SelfAttention2d",
|
||||||
|
"Identity",
|
||||||
|
"View",
|
||||||
|
"Flatten",
|
||||||
|
"Unflatten",
|
||||||
|
"Transpose",
|
||||||
|
"Permute",
|
||||||
|
"Squeeze",
|
||||||
|
"Unsqueeze",
|
||||||
|
"Reshape",
|
||||||
|
"Slicing",
|
||||||
|
"Parameter",
|
||||||
|
"Buffer",
|
||||||
|
"Lambda",
|
||||||
|
"Return",
|
||||||
|
"Sum",
|
||||||
|
"Residual",
|
||||||
|
"Chain",
|
||||||
|
"UseContext",
|
||||||
|
"SetContext",
|
||||||
|
"Parallel",
|
||||||
|
"Passthrough",
|
||||||
|
"Breakpoint",
|
||||||
|
"Concatenate",
|
||||||
|
"Conv2d",
|
||||||
|
"Linear",
|
||||||
|
"MultiLinear",
|
||||||
|
"Downsample",
|
||||||
|
"Upsample",
|
||||||
|
"Module",
|
||||||
|
"WeightedModule",
|
||||||
|
"ContextModule",
|
||||||
|
"Interpolate",
|
||||||
|
]
|
66
src/refiners/fluxion/layers/activations.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
from torch.nn.functional import silu
|
||||||
|
from torch import Tensor, sigmoid
|
||||||
|
from torch.nn.functional import gelu # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Activation(Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
class SiLU(Activation):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return silu(x) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(Activation):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.relu()
|
||||||
|
|
||||||
|
|
||||||
|
class GeLU(Activation):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return gelu(x) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ApproximateGeLU(Activation):
|
||||||
|
"""
|
||||||
|
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||||
|
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x * sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class GLU(Activation):
|
||||||
|
"""
|
||||||
|
Gated Linear Unit activation layer.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/2002.05202v1 for details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, activation: Activation) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(activation={self.activation})"
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
assert x.shape[-1] % 2 == 0, "Non-batch input dimension must be divisible by 2"
|
||||||
|
output, gate = x.chunk(2, dim=-1)
|
||||||
|
return output * self.activation(gate)
|
189
src/refiners/fluxion/layers/attentions.py
Normal file
|
@ -0,0 +1,189 @@
|
||||||
|
from jaxtyping import Float
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from refiners.fluxion.layers.linear import Linear
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
from refiners.fluxion.layers.chain import Chain, Distribute, Parallel, Lambda
|
||||||
|
from refiners.fluxion.layers.basics import Identity
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
|
||||||
|
query: Float[Tensor, "batch source_sequence_length dim"],
|
||||||
|
key: Float[Tensor, "batch target_sequence_length dim"],
|
||||||
|
value: Float[Tensor, "batch target_sequence_length dim"],
|
||||||
|
is_causal: bool = False,
|
||||||
|
) -> Float[Tensor, "batch source_sequence_length dim"]:
|
||||||
|
return _scaled_dot_product_attention(query, key, value, is_causal=is_causal) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttention(Module):
|
||||||
|
def __init__(self, num_heads: int = 1, is_causal: bool | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.is_causal = is_causal
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Float[Tensor, "batch num_queries embedding_dim"],
|
||||||
|
key: Float[Tensor, "batch num_keys embedding_dim"],
|
||||||
|
value: Float[Tensor, "batch num_values embedding_dim"],
|
||||||
|
is_causal: bool | None = None,
|
||||||
|
) -> Float[Tensor, "batch num_queries dim"]:
|
||||||
|
return self.merge_multi_head(
|
||||||
|
scaled_dot_product_attention(
|
||||||
|
query=self.split_to_multi_head(query),
|
||||||
|
key=self.split_to_multi_head(key),
|
||||||
|
value=self.split_to_multi_head(value),
|
||||||
|
is_causal=(
|
||||||
|
is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_to_multi_head(
|
||||||
|
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
|
||||||
|
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
|
||||||
|
assert (
|
||||||
|
len(x.shape) == 3
|
||||||
|
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
|
||||||
|
assert (
|
||||||
|
x.shape[-1] % self.num_heads == 0
|
||||||
|
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
|
||||||
|
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
|
||||||
|
|
||||||
|
def merge_multi_head(
|
||||||
|
self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"]
|
||||||
|
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
|
||||||
|
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(Chain):
|
||||||
|
structural_attrs = [
|
||||||
|
"embedding_dim",
|
||||||
|
"num_heads",
|
||||||
|
"heads_dim",
|
||||||
|
"key_embedding_dim",
|
||||||
|
"value_embedding_dim",
|
||||||
|
"use_bias",
|
||||||
|
"is_causal",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
num_heads: int = 1,
|
||||||
|
key_embedding_dim: int | None = None,
|
||||||
|
value_embedding_dim: int | None = None,
|
||||||
|
use_bias: bool = True,
|
||||||
|
is_causal: bool | None = None,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert (
|
||||||
|
embedding_dim % num_heads == 0
|
||||||
|
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.heads_dim = embedding_dim // num_heads
|
||||||
|
self.key_embedding_dim = key_embedding_dim or embedding_dim
|
||||||
|
self.value_embedding_dim = value_embedding_dim or embedding_dim
|
||||||
|
self.use_bias = use_bias
|
||||||
|
self.is_causal = is_causal
|
||||||
|
super().__init__(
|
||||||
|
Distribute(
|
||||||
|
Linear(
|
||||||
|
in_features=self.embedding_dim,
|
||||||
|
out_features=self.embedding_dim,
|
||||||
|
bias=self.use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Linear(
|
||||||
|
in_features=self.key_embedding_dim,
|
||||||
|
out_features=self.embedding_dim,
|
||||||
|
bias=self.use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Linear(
|
||||||
|
in_features=self.value_embedding_dim,
|
||||||
|
out_features=self.embedding_dim,
|
||||||
|
bias=self.use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
|
||||||
|
Linear(
|
||||||
|
in_features=self.embedding_dim,
|
||||||
|
out_features=self.embedding_dim,
|
||||||
|
bias=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(Attention):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
num_heads: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
is_causal: bool | None = None,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
use_bias=use_bias,
|
||||||
|
is_causal=is_causal,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.insert(0, Parallel(Identity(), Identity(), Identity()))
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention2d(SelfAttention):
|
||||||
|
structural_attrs = ["channels"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
is_causal: bool | None = None,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
|
||||||
|
self.channels = channels
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
use_bias=use_bias,
|
||||||
|
is_causal=is_causal,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.insert(0, Lambda(self.tensor_2d_to_sequence))
|
||||||
|
self.append(Lambda(self.sequence_to_tensor_2d))
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {"reshape": {"height": None, "width": None}}
|
||||||
|
|
||||||
|
def tensor_2d_to_sequence(
|
||||||
|
self, x: Float[Tensor, "batch channels height width"]
|
||||||
|
) -> Float[Tensor, "batch height*width channels"]:
|
||||||
|
height, width = x.shape[-2:]
|
||||||
|
self.set_context(context="reshape", value={"height": height, "width": width})
|
||||||
|
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
def sequence_to_tensor_2d(
|
||||||
|
self, x: Float[Tensor, "batch sequence_length channels"]
|
||||||
|
) -> Float[Tensor, "batch channels height width"]:
|
||||||
|
height, width = self.use_context("reshape").values()
|
||||||
|
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)
|
183
src/refiners/fluxion/layers/basics.py
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||||
|
from torch import randn, Tensor, Size, device as Device, dtype as DType
|
||||||
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class View(Module):
|
||||||
|
def __init__(self, *shape: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.view(*self.shape)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
shape_repr = ", ".join([repr(s) for s in self.shape])
|
||||||
|
return f"{self.__class__.__name__}({shape_repr})"
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(Module):
|
||||||
|
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.start_dim = start_dim
|
||||||
|
self.end_dim = end_dim
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.flatten(self.start_dim, self.end_dim)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Unflatten(Module):
|
||||||
|
def __init__(self, dim: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, sizes: Size) -> Tensor:
|
||||||
|
return x.unflatten(self.dim, sizes) # type: ignore
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Reshape(Module):
|
||||||
|
"""
|
||||||
|
Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
|
||||||
|
dimension is preserved.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *shape: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.reshape(x.shape[0], *self.shape)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
shape_repr = ", ".join([repr(s) for s in self.shape])
|
||||||
|
return f"{self.__class__.__name__}({shape_repr})"
|
||||||
|
|
||||||
|
|
||||||
|
class Transpose(Module):
|
||||||
|
def __init__(self, dim0: int, dim1: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim0 = dim0
|
||||||
|
self.dim1 = dim1
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.transpose(self.dim0, self.dim1)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Permute(Module):
|
||||||
|
def __init__(self, *dims: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.permute(*self.dims)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
dims_repr = ", ".join([repr(d) for d in self.dims])
|
||||||
|
return f"{self.__class__.__name__}({dims_repr})"
|
||||||
|
|
||||||
|
|
||||||
|
class Slicing(Module):
|
||||||
|
def __init__(self, dim: int, start: int, length: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.start = start
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.narrow(self.dim, self.start, self.length)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Squeeze(Module):
|
||||||
|
def __init__(self, dim: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.squeeze(self.dim)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Unsqueeze(Module):
|
||||||
|
def __init__(self, dim: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x.unsqueeze(self.dim)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Parameter(WeightedModule):
|
||||||
|
"""
|
||||||
|
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Device:
|
||||||
|
return self.parameter.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
return self.parameter.dtype
|
||||||
|
|
||||||
|
def forward(self, _: Tensor) -> Tensor:
|
||||||
|
return self.parameter
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
|
||||||
|
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Buffer(WeightedModule):
|
||||||
|
"""
|
||||||
|
A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.
|
||||||
|
|
||||||
|
Buffers are not trainable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Device:
|
||||||
|
return self.buffer.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
return self.buffer.dtype
|
||||||
|
|
||||||
|
def forward(self, _: Tensor) -> Tensor:
|
||||||
|
return self.buffer
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
|
||||||
|
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
466
src/refiners/fluxion/layers/chain.py
Normal file
|
@ -0,0 +1,466 @@
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
||||||
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
from refiners.fluxion.layers.basics import Identity
|
||||||
|
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
|
||||||
|
from refiners.fluxion.context import Contexts, ContextProvider
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Module)
|
||||||
|
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
|
||||||
|
|
||||||
|
|
||||||
|
class Lambda(Module):
|
||||||
|
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[..., Any]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Any:
|
||||||
|
return self.func(*args)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
func_name = getattr(self.func, "__name__", "partial_function")
|
||||||
|
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unique_names(
|
||||||
|
modules: tuple[Module, ...],
|
||||||
|
) -> dict[str, Module]:
|
||||||
|
class_counts: dict[str, int] = {}
|
||||||
|
unique_names: list[tuple[str, Module]] = []
|
||||||
|
for module in modules:
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
||||||
|
name_counter: dict[str, int] = {}
|
||||||
|
for module in modules:
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
name_counter[class_name] = name_counter.get(class_name, 0) + 1
|
||||||
|
unique_name = f"{class_name}_{name_counter[class_name]}" if class_counts[class_name] > 1 else class_name
|
||||||
|
unique_names.append((unique_name, module))
|
||||||
|
return dict(unique_names)
|
||||||
|
|
||||||
|
|
||||||
|
class UseContext(ContextModule):
|
||||||
|
structural_attrs = ["context", "key", "func"]
|
||||||
|
|
||||||
|
def __init__(self, context: str, key: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.func: Callable[[Any], Any] = lambda x: x
|
||||||
|
|
||||||
|
def __call__(self, *args: Any) -> Any:
|
||||||
|
context = self.use_context(self.context)
|
||||||
|
assert context, f"context {self.context} is unset"
|
||||||
|
value = context.get(self.key)
|
||||||
|
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
||||||
|
return self.func(value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
||||||
|
self.func = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class SetContext(ContextModule):
|
||||||
|
"""A Module that sets a context value when executed.
|
||||||
|
|
||||||
|
The context need to pre exist in the context provider.
|
||||||
|
#TODO Is there a way to create the context if it doesn't exist?
|
||||||
|
"""
|
||||||
|
|
||||||
|
structural_attrs = ["context", "key", "callback"]
|
||||||
|
|
||||||
|
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
|
if context := self.use_context(self.context):
|
||||||
|
if not self.callback:
|
||||||
|
context.update({self.key: x})
|
||||||
|
else:
|
||||||
|
self.callback(context[self.key], x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
"""Exception raised when a Return module is encountered."""
|
||||||
|
|
||||||
|
def __init__(self, value: Tensor):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class Return(Module):
|
||||||
|
"""A Module that stops the execution of a Chain when encountered."""
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
raise ReturnException(x)
|
||||||
|
|
||||||
|
|
||||||
|
def structural_copy(m: T) -> T:
|
||||||
|
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||||
|
|
||||||
|
|
||||||
|
class Chain(ContextModule):
|
||||||
|
_modules: dict[str, Module]
|
||||||
|
_provider: ContextProvider
|
||||||
|
|
||||||
|
def __init__(self, *args: Module | Iterable[Module]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._provider = ContextProvider()
|
||||||
|
modules = cast(
|
||||||
|
tuple[Module],
|
||||||
|
(
|
||||||
|
tuple(args[0])
|
||||||
|
if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
|
||||||
|
else tuple(args)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
for module in modules:
|
||||||
|
# Violating this would mean a ContextModule ends up in two chains,
|
||||||
|
# with a single one correctly set as its parent.
|
||||||
|
assert (
|
||||||
|
(not isinstance(module, ContextModule))
|
||||||
|
or (not module._can_refresh_parent)
|
||||||
|
or (module.parent is None)
|
||||||
|
or (module.parent == self)
|
||||||
|
), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"
|
||||||
|
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
self._reset_context()
|
||||||
|
|
||||||
|
for module in self:
|
||||||
|
if isinstance(module, ContextModule) and module._can_refresh_parent and module.parent != self:
|
||||||
|
module._set_parent(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> ContextProvider:
|
||||||
|
return self._provider
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _register_provider(self, context: Contexts | None = None) -> None:
|
||||||
|
if context:
|
||||||
|
self._provider.update_contexts(context)
|
||||||
|
|
||||||
|
for module in self:
|
||||||
|
if isinstance(module, Chain):
|
||||||
|
module._register_provider(context=self._provider.contexts)
|
||||||
|
|
||||||
|
def _reset_context(self) -> None:
|
||||||
|
self._register_provider(self.init_context())
|
||||||
|
|
||||||
|
def set_context(self, context: str, value: Any) -> None:
|
||||||
|
self._provider.set_context(context, value)
|
||||||
|
self._register_provider()
|
||||||
|
|
||||||
|
def debug_repr(self, layer_name: str = "") -> str:
|
||||||
|
lines: list[str] = []
|
||||||
|
tab = " "
|
||||||
|
tab_length = 0
|
||||||
|
for i, parent in enumerate(self.get_parents()[::-1]):
|
||||||
|
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
|
||||||
|
tab_length += 1
|
||||||
|
|
||||||
|
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
|
||||||
|
|
||||||
|
for name, _ in self._modules.items():
|
||||||
|
error_arrow = "⚠️" if name == layer_name else ""
|
||||||
|
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def call_layer(self, layer: Module, layer_name: str, *args: Any):
|
||||||
|
try:
|
||||||
|
return layer(*args)
|
||||||
|
except Exception as e:
|
||||||
|
pretty_print = self.debug_repr(layer_name)
|
||||||
|
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Any:
|
||||||
|
result: tuple[Any] | Any = None
|
||||||
|
intermediate_args: tuple[Any, ...] = args
|
||||||
|
for name, layer in self._modules.items():
|
||||||
|
result = self.call_layer(layer, name, *intermediate_args)
|
||||||
|
intermediate_args = (result,) if not isinstance(result, tuple) else result
|
||||||
|
|
||||||
|
self._reset_context()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _regenerate_keys(self, modules: Iterable[Module]) -> None:
|
||||||
|
self._modules = generate_unique_names(tuple(modules)) # type: ignore
|
||||||
|
|
||||||
|
def __add__(self, other: "Chain | Module | list[Module]") -> "Chain":
|
||||||
|
if isinstance(other, Module):
|
||||||
|
other = Chain(other)
|
||||||
|
if isinstance(other, list):
|
||||||
|
other = Chain(*other)
|
||||||
|
return Chain(*self, *other)
|
||||||
|
|
||||||
|
def __getitem__(self, key: int | str | slice) -> Module:
|
||||||
|
if isinstance(key, slice):
|
||||||
|
return Chain(*list(self)[key])
|
||||||
|
elif isinstance(key, str):
|
||||||
|
return self._modules[key]
|
||||||
|
else:
|
||||||
|
return list(self)[key]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Module]:
|
||||||
|
return iter(self._modules.values())
|
||||||
|
|
||||||
|
def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
|
||||||
|
layer_name = self.__class__.__name__ if layer_name is None else layer_name
|
||||||
|
pretty_print = f"{layer_name}:\n"
|
||||||
|
tab = " " * (num_tab + 4)
|
||||||
|
module_strings: list[str] = []
|
||||||
|
for i, (name, module) in enumerate(self._modules.items()):
|
||||||
|
ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " "
|
||||||
|
module_str = (
|
||||||
|
module
|
||||||
|
if not isinstance(module, Chain)
|
||||||
|
else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)")
|
||||||
|
)
|
||||||
|
module_strings.append(f"{tab}{ident} {module_str}")
|
||||||
|
pretty_print += "\n".join(module_strings)
|
||||||
|
return pretty_print
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self._pretty_print()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__} at {hex(id(self))}>"
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._modules)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Device | None:
|
||||||
|
wm = self.find(WeightedModule)
|
||||||
|
return None if wm is None else wm.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DType | None:
|
||||||
|
wm = self.find(WeightedModule)
|
||||||
|
return None if wm is None else wm.dtype
|
||||||
|
|
||||||
|
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
|
if predicate is None:
|
||||||
|
predicate = lambda _m, _p: True
|
||||||
|
for module in self:
|
||||||
|
keep_going = True
|
||||||
|
try:
|
||||||
|
p = predicate(module, self)
|
||||||
|
except StopIteration:
|
||||||
|
p = False
|
||||||
|
keep_going = False
|
||||||
|
if p:
|
||||||
|
yield (module, self)
|
||||||
|
if keep_going and isinstance(module, Chain):
|
||||||
|
yield from module.walk(predicate)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def walk(
|
||||||
|
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None
|
||||||
|
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
||||||
|
if isinstance(predicate, type):
|
||||||
|
return self._walk(lambda m, _: isinstance(m, predicate))
|
||||||
|
else:
|
||||||
|
return self._walk(predicate)
|
||||||
|
|
||||||
|
def layers(self, layer_type: type[T]) -> Iterator[T]:
|
||||||
|
for module, _ in self.walk(layer_type):
|
||||||
|
yield module
|
||||||
|
|
||||||
|
def find(self, layer_type: type[T]) -> T | None:
|
||||||
|
return next(self.layers(layer_type=layer_type), None)
|
||||||
|
|
||||||
|
def find_parent(self, module: Module) -> "Chain | None":
|
||||||
|
if module in self: # avoid DFS-crawling the whole tree
|
||||||
|
return self
|
||||||
|
for _, parent in self.walk(lambda m, _: m == module):
|
||||||
|
return parent
|
||||||
|
return None
|
||||||
|
|
||||||
|
def insert(self, index: int, module: Module) -> None: # type: ignore
|
||||||
|
if index < 0:
|
||||||
|
index = max(0, len(self._modules) + index + 1)
|
||||||
|
modules = list(self)
|
||||||
|
modules.insert(index, module)
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
if isinstance(module, ContextModule):
|
||||||
|
module._set_parent(self)
|
||||||
|
self._register_provider()
|
||||||
|
|
||||||
|
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||||
|
for i, module in enumerate(self):
|
||||||
|
if isinstance(module, module_type):
|
||||||
|
self.insert(i + 1, new_module)
|
||||||
|
return
|
||||||
|
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||||
|
|
||||||
|
def append(self, module: Module) -> None: # type: ignore
|
||||||
|
modules = list(self)
|
||||||
|
modules.append(module)
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
if isinstance(module, ContextModule):
|
||||||
|
module._set_parent(self)
|
||||||
|
self._register_provider()
|
||||||
|
|
||||||
|
def pop(self, index: int = -1) -> Module | tuple[Module]: # type: ignore
|
||||||
|
modules = list(self)
|
||||||
|
if index < 0:
|
||||||
|
index = len(modules) + index
|
||||||
|
if index < 0 or index >= len(modules):
|
||||||
|
raise IndexError("Index out of range.")
|
||||||
|
removed_module = modules.pop(index)
|
||||||
|
if isinstance(removed_module, ContextModule):
|
||||||
|
removed_module._set_parent(None)
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
return removed_module
|
||||||
|
|
||||||
|
def remove(self, module: Module) -> None:
|
||||||
|
"""Remove a module from the chain."""
|
||||||
|
modules = list(self)
|
||||||
|
try:
|
||||||
|
modules.remove(module)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"{module} is not in {self}")
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
if isinstance(module, ContextModule):
|
||||||
|
module._set_parent(None)
|
||||||
|
|
||||||
|
def replace(
|
||||||
|
self,
|
||||||
|
old_module: Module,
|
||||||
|
new_module: Module,
|
||||||
|
old_module_parent: "Chain | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""Replace a module in the chain with a new module."""
|
||||||
|
modules = list(self)
|
||||||
|
try:
|
||||||
|
modules[modules.index(old_module)] = new_module
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"{old_module} is not in {self}")
|
||||||
|
self._regenerate_keys(modules)
|
||||||
|
if isinstance(new_module, ContextModule):
|
||||||
|
new_module._set_parent(self)
|
||||||
|
if isinstance(old_module, ContextModule):
|
||||||
|
old_module._set_parent(old_module_parent)
|
||||||
|
|
||||||
|
def structural_copy(self: TChain) -> TChain:
|
||||||
|
"""Copy the structure of the Chain tree.
|
||||||
|
|
||||||
|
This method returns a recursive copy of the Chain tree where all inner nodes
|
||||||
|
(instances of Chain and its subclasses) are duplicated and all leaves
|
||||||
|
(regular Modules) are not.
|
||||||
|
|
||||||
|
Such copies can be adapted without disrupting the base model, but do not
|
||||||
|
require extra GPU memory since the weights are in the leaves and hence not copied.
|
||||||
|
|
||||||
|
This assumes all subclasses define the class variable `structural_attrs` which
|
||||||
|
contains a list of basic attributes set in the constructor. In complicated cases
|
||||||
|
it may be required to overwrite that method.
|
||||||
|
"""
|
||||||
|
if hasattr(self, "_pre_structural_copy"):
|
||||||
|
self._pre_structural_copy()
|
||||||
|
|
||||||
|
modules = [structural_copy(m) for m in self]
|
||||||
|
|
||||||
|
# Instantiate the right subclass, but do not initialize.
|
||||||
|
clone = object.__new__(self.__class__)
|
||||||
|
|
||||||
|
# Copy all basic attributes of the class declared in `structural_attrs`.
|
||||||
|
for k in self.__class__.structural_attrs:
|
||||||
|
setattr(clone, k, getattr(self, k))
|
||||||
|
|
||||||
|
# Call constructor of Chain, which among other things refreshes the context tree.
|
||||||
|
Chain.__init__(clone, *modules)
|
||||||
|
|
||||||
|
for module in modules:
|
||||||
|
if isinstance(module, ContextModule):
|
||||||
|
module._set_parent(clone)
|
||||||
|
|
||||||
|
if hasattr(clone, "_post_structural_copy"):
|
||||||
|
clone._post_structural_copy(self)
|
||||||
|
|
||||||
|
return clone
|
||||||
|
|
||||||
|
|
||||||
|
class Parallel(Chain):
|
||||||
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
|
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||||
|
|
||||||
|
|
||||||
|
class Distribute(Chain):
|
||||||
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
|
assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
|
||||||
|
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||||
|
|
||||||
|
|
||||||
|
class Passthrough(Chain):
|
||||||
|
def forward(self, *inputs: Any) -> Any:
|
||||||
|
super().forward(*inputs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
class Sum(Chain):
|
||||||
|
def forward(self, *inputs: Any) -> Any:
|
||||||
|
output = None
|
||||||
|
for layer in self:
|
||||||
|
layer_output: Any = layer(*inputs)
|
||||||
|
if isinstance(layer_output, tuple):
|
||||||
|
layer_output = sum(layer_output) # type: ignore
|
||||||
|
output = layer_output if output is None else output + layer_output
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Residual(Sum):
|
||||||
|
def __init__(self, *modules: Module) -> None:
|
||||||
|
super().__init__(Identity(), Chain(*modules))
|
||||||
|
|
||||||
|
|
||||||
|
class Breakpoint(Module):
|
||||||
|
def __init__(self, vscode: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.vscode = vscode
|
||||||
|
|
||||||
|
def forward(self, *args: Any):
|
||||||
|
if self.vscode:
|
||||||
|
import debugpy # type: ignore
|
||||||
|
|
||||||
|
debugpy.breakpoint() # type: ignore
|
||||||
|
else:
|
||||||
|
breakpoint()
|
||||||
|
return args[0] if len(args) == 1 else args
|
||||||
|
|
||||||
|
|
||||||
|
class Concatenate(Chain):
|
||||||
|
structural_attrs = ["dim"]
|
||||||
|
|
||||||
|
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||||
|
super().__init__(*modules)
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Tensor:
|
||||||
|
outputs = [module(*args) for module in self]
|
||||||
|
return cat([output for output in outputs if output is not None], dim=self.dim)
|
73
src/refiners/fluxion/layers/conv.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d(_Conv2d, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int | tuple[int, ...],
|
||||||
|
stride: int | tuple[int, ...] = 1,
|
||||||
|
padding: int | tuple[int, ...] | str = 0,
|
||||||
|
dilation: int | tuple[int, ...] = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
padding_mode: str = "zeros",
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
use_bias,
|
||||||
|
padding_mode,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.padding = (padding,) if isinstance(padding, int) else padding
|
||||||
|
self.dilation = (dilation,) if isinstance(dilation, int) else dilation
|
||||||
|
self.groups = groups
|
||||||
|
self.use_bias = use_bias
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(_Conv1d, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int | tuple[int, ...],
|
||||||
|
stride: int | tuple[int, ...] = 1,
|
||||||
|
padding: int | tuple[int, ...] | str = 0,
|
||||||
|
dilation: int | tuple[int, ...] = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
padding_mode: str = "zeros",
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
use_bias,
|
||||||
|
padding_mode,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_bias = use_bias
|
21
src/refiners/fluxion/layers/embedding.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
from torch.nn import Embedding as _Embedding
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from jaxtyping import Float, Int
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(_Embedding, WeightedModule): # type: ignore
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
_Embedding.__init__( # type: ignore
|
||||||
|
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
|
||||||
|
return super().forward(x)
|
50
src/refiners/fluxion/layers/linear.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
from torch.nn import Linear as _Linear
|
||||||
|
from torch import Tensor
|
||||||
|
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||||
|
from refiners.fluxion.layers.activations import ReLU
|
||||||
|
from refiners.fluxion.layers.chain import Chain
|
||||||
|
|
||||||
|
from jaxtyping import Float
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(_Linear, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
in_features=in_features,
|
||||||
|
out_features=out_features,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Float[Tensor, "batch in_features"]) -> Float[Tensor, "batch out_features"]: # type: ignore
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiLinear(Chain):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
inner_dim: int,
|
||||||
|
num_layers: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
layers: list[Module] = []
|
||||||
|
for i in range(num_layers - 1):
|
||||||
|
layers.append(Linear(input_dim if i == 0 else inner_dim, inner_dim, device=device, dtype=dtype))
|
||||||
|
layers.append(ReLU())
|
||||||
|
layers.append(Linear(inner_dim, output_dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
super().__init__(layers)
|
100
src/refiners/fluxion/layers/module.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Generator, TypeVar
|
||||||
|
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
from torch.nn.modules.module import Module as TorchModule
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
from refiners.fluxion.context import Context, ContextProvider
|
||||||
|
|
||||||
|
from typing import Callable, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.fluxion.layers.chain import Chain
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="Module")
|
||||||
|
TContextModule = TypeVar("TContextModule", bound="ContextModule")
|
||||||
|
|
||||||
|
|
||||||
|
class Module(TorchModule):
|
||||||
|
_parameters: dict[str, Any]
|
||||||
|
_buffers: dict[str, Any]
|
||||||
|
|
||||||
|
__getattr__: Callable[["Module", str], Any] # type: ignore
|
||||||
|
__setattr__: Callable[["Module", str, Any], None] # type: ignore
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, *kwargs) # type: ignore
|
||||||
|
|
||||||
|
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
|
||||||
|
state_dict = load_from_safetensors(tensors_path)
|
||||||
|
self.load_state_dict(state_dict, strict=strict)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
|
||||||
|
return super().named_modules(*args) # type: ignore
|
||||||
|
|
||||||
|
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||||
|
return super().to(device=device, dtype=dtype) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ContextModule(Module):
|
||||||
|
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||||
|
_parent: "list[Chain]"
|
||||||
|
_can_refresh_parent: bool = True # see usage in Adapter and Chain
|
||||||
|
|
||||||
|
# Contains simple attributes set on the instance by `__init__` in subclasses
|
||||||
|
# and copied by `structural_copy`. Note that is not the case of `device` since
|
||||||
|
# Chain's __init__ takes care of it.
|
||||||
|
structural_attrs: list[str] = []
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, *kwargs)
|
||||||
|
self._parent = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parent(self) -> "Chain | None":
|
||||||
|
return self._parent[0] if self._parent else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ensure_parent(self) -> "Chain":
|
||||||
|
assert self._parent, "module is not bound to a Chain"
|
||||||
|
return self._parent[0]
|
||||||
|
|
||||||
|
def _set_parent(self, parent: "Chain | None") -> None:
|
||||||
|
if parent is None:
|
||||||
|
self._parent = []
|
||||||
|
return
|
||||||
|
# Always insert the module in the Chain first to avoid inconsistencies.
|
||||||
|
assert self in iter(parent), f"{self} not in {parent}"
|
||||||
|
self._parent = [parent]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self) -> ContextProvider:
|
||||||
|
return self.ensure_parent.provider
|
||||||
|
|
||||||
|
def get_parents(self) -> "list[Chain]":
|
||||||
|
return self._parent + self._parent[0].get_parents() if self._parent else []
|
||||||
|
|
||||||
|
def use_context(self, context_name: str) -> Context:
|
||||||
|
"""Retrieve the context object from the module's context provider."""
|
||||||
|
context = self.provider.get_context(context_name)
|
||||||
|
assert context is not None, f"Context {context_name} not found."
|
||||||
|
return context
|
||||||
|
|
||||||
|
def structural_copy(self: TContextModule) -> TContextModule:
|
||||||
|
clone = object.__new__(self.__class__)
|
||||||
|
for k in self.__class__.structural_attrs:
|
||||||
|
setattr(clone, k, getattr(self, k))
|
||||||
|
ContextModule.__init__(clone)
|
||||||
|
return clone
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedModule(Module):
|
||||||
|
@property
|
||||||
|
def device(self) -> Device:
|
||||||
|
return self.weight.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
return self.weight.dtype
|
75
src/refiners/fluxion/layers/norm.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
|
||||||
|
from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
|
||||||
|
from jaxtyping import Float
|
||||||
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(_LayerNorm, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
normalized_shape: int | list[int],
|
||||||
|
eps: float = 0.00001,
|
||||||
|
elementwise_affine: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
normalized_shape=normalized_shape,
|
||||||
|
eps=eps,
|
||||||
|
elementwise_affine=elementwise_affine,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm(_GroupNorm, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_groups: int,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
affine: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
num_groups=num_groups,
|
||||||
|
num_channels=channels,
|
||||||
|
eps=eps,
|
||||||
|
affine=affine,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.channels = channels
|
||||||
|
self.num_groups = num_groups
|
||||||
|
self.eps = eps
|
||||||
|
self.affine = affine
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm2d(WeightedModule):
|
||||||
|
"""
|
||||||
|
2D Layer Normalization module.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
channels (int): Number of channels in the input tensor.
|
||||||
|
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = Parameter(ones(channels, device=device, dtype=dtype))
|
||||||
|
self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
||||||
|
x_mean = x.mean(1, keepdim=True)
|
||||||
|
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
|
||||||
|
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||||
|
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
return x_out
|
100
src/refiners/fluxion/layers/sampling.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
from refiners.fluxion.layers.chain import Chain, UseContext, SetContext
|
||||||
|
from refiners.fluxion.layers.conv import Conv2d
|
||||||
|
from refiners.fluxion.layers.basics import Identity
|
||||||
|
from refiners.fluxion.layers.chain import Parallel, Lambda
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
from refiners.fluxion.utils import interpolate
|
||||||
|
from torch.nn.functional import pad
|
||||||
|
from torch import Tensor, Size, device as Device, dtype as DType
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(Chain):
|
||||||
|
structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
scale_factor: int,
|
||||||
|
padding: int = 0,
|
||||||
|
register_shape: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
"""Downsamples the input by the given scale factor.
|
||||||
|
|
||||||
|
If register_shape is True, the input shape is registered in the context. It will throw an error if the context
|
||||||
|
sampling is not set or if the context does not contain a list.
|
||||||
|
"""
|
||||||
|
self.channels = channels
|
||||||
|
self.in_channels = channels
|
||||||
|
self.out_channels = channels
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.padding = padding
|
||||||
|
super().__init__(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=scale_factor,
|
||||||
|
padding=padding,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if padding == 0:
|
||||||
|
self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
|
||||||
|
if register_shape:
|
||||||
|
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
|
||||||
|
|
||||||
|
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
|
||||||
|
shapes.append(x.shape[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class Interpolate(Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, shape: Size) -> Tensor:
|
||||||
|
return interpolate(x, shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(Chain):
|
||||||
|
structural_attrs = ["channels", "upsample_factor"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
upsample_factor: int | None = None,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
"""Upsamples the input by the given scale factor.
|
||||||
|
|
||||||
|
If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context
|
||||||
|
sampling is not set or if the context is empty (then you should use the dynamic version of Downsample).
|
||||||
|
"""
|
||||||
|
self.channels = channels
|
||||||
|
self.upsample_factor = upsample_factor
|
||||||
|
super().__init__(
|
||||||
|
Parallel(
|
||||||
|
Identity(),
|
||||||
|
(
|
||||||
|
Lambda(self._get_static_shape)
|
||||||
|
if upsample_factor is not None
|
||||||
|
else UseContext(context="sampling", key="shapes").compose(lambda x: x.pop())
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Interpolate(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_static_shape(self, x: Tensor) -> Size:
|
||||||
|
assert self.upsample_factor is not None
|
||||||
|
return Size([size * self.upsample_factor for size in x.shape[2:]])
|
262
src/refiners/fluxion/utils.py
Normal file
|
@ -0,0 +1,262 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
|
||||||
|
from PIL import Image
|
||||||
|
from numpy import array, float32
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors import safe_open as _safe_open # type: ignore
|
||||||
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
|
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||||
|
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||||
|
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
E = TypeVar("E")
|
||||||
|
|
||||||
|
|
||||||
|
def norm(x: Tensor) -> Tensor:
|
||||||
|
return _norm(x) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def manual_seed(seed: int) -> None:
|
||||||
|
_manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor:
|
||||||
|
return _pad(input=x, pad=pad, value=value) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate(x: Tensor, factor: float | Size, mode: str = "nearest") -> Tensor:
|
||||||
|
return (
|
||||||
|
_interpolate(x, scale_factor=factor, mode=mode)
|
||||||
|
if isinstance(factor, float | int)
|
||||||
|
else _interpolate(x, size=factor, mode=mode)
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def bidirectional_mapping(mapping: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
return {**mapping, **{value: key for key, value in mapping.items()}}
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
|
return tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||||
|
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def safe_open(
|
||||||
|
path: Path | str,
|
||||||
|
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
framework_mapping = {
|
||||||
|
"pytorch": "pt",
|
||||||
|
"tensorflow": "tf",
|
||||||
|
"flax": "flax",
|
||||||
|
"numpy": "numpy",
|
||||||
|
}
|
||||||
|
return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
|
||||||
|
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
|
||||||
|
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
|
||||||
|
with safe_open(path=path, framework="pytorch") as tensors: # type: ignore
|
||||||
|
return tensors.metadata() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
||||||
|
_save_file(tensors, path, metadata) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
BASIC_LAYERS: list[str] = [
|
||||||
|
"Conv1d",
|
||||||
|
"Conv2d",
|
||||||
|
"Conv3d",
|
||||||
|
"Linear",
|
||||||
|
"BatchNorm1d",
|
||||||
|
"BatchNorm2d",
|
||||||
|
"BatchNorm3d",
|
||||||
|
"LayerNorm",
|
||||||
|
"GroupNorm",
|
||||||
|
"Embedding",
|
||||||
|
"MaxPool2d",
|
||||||
|
"AvgPool2d",
|
||||||
|
"AdaptiveAvgPool2d",
|
||||||
|
]
|
||||||
|
|
||||||
|
ModelTypeShape = tuple[str, tuple[Size, ...]]
|
||||||
|
|
||||||
|
|
||||||
|
def is_basic_layer(module: "Module") -> bool:
|
||||||
|
return module.__class__.__name__ in BASIC_LAYERS
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_signature(module: "Module") -> ModelTypeShape:
|
||||||
|
param_shapes = [p.shape for p in module.parameters()]
|
||||||
|
return (module.__class__.__name__, tuple(param_shapes))
|
||||||
|
|
||||||
|
|
||||||
|
def forward_order_of_execution(
|
||||||
|
module: "Module",
|
||||||
|
example_args: tuple[Any, ...],
|
||||||
|
key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
) -> dict[ModelTypeShape, list[str]]:
|
||||||
|
key_skipper = key_skipper or (lambda _: False)
|
||||||
|
|
||||||
|
submodule_to_key: dict["Module", str] = {}
|
||||||
|
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: "Module", *_: Any):
|
||||||
|
layer_signature = get_module_signature(layer)
|
||||||
|
execution_order[layer_signature].append(submodule_to_key[layer])
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
for name, submodule in module.named_modules():
|
||||||
|
if is_basic_layer(submodule) and not key_skipper(name):
|
||||||
|
submodule_to_key[submodule] = name
|
||||||
|
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
with no_grad():
|
||||||
|
module(*example_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return dict(execution_order)
|
||||||
|
|
||||||
|
|
||||||
|
def print_side_by_side(
|
||||||
|
shape: ModelTypeShape,
|
||||||
|
source_keys: list[str],
|
||||||
|
target_keys: list[str],
|
||||||
|
):
|
||||||
|
print(f"{shape}")
|
||||||
|
max_len = max(len(source_keys), len(target_keys))
|
||||||
|
for i in range(max_len):
|
||||||
|
source_key = source_keys[i] if i < len(source_keys) else "---"
|
||||||
|
target_key = target_keys[i] if i < len(target_keys) else "---"
|
||||||
|
print(f"\t{source_key}\t{target_key}")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_shape_match(
|
||||||
|
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
||||||
|
) -> bool:
|
||||||
|
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||||
|
shape_missmatched = False
|
||||||
|
|
||||||
|
for model_type_shape in model_type_shapes:
|
||||||
|
source_keys = source_order.get(model_type_shape, [])
|
||||||
|
target_keys = target_order.get(model_type_shape, [])
|
||||||
|
|
||||||
|
if len(source_keys) != len(target_keys):
|
||||||
|
shape_missmatched = True
|
||||||
|
print_side_by_side(model_type_shape, source_keys, target_keys)
|
||||||
|
|
||||||
|
return not shape_missmatched
|
||||||
|
|
||||||
|
|
||||||
|
def create_state_dict_mapping(
|
||||||
|
source_model: "Module",
|
||||||
|
target_model: "Module",
|
||||||
|
source_args: tuple[Any, ...],
|
||||||
|
target_args: tuple[Any, ...] | None = None,
|
||||||
|
source_key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
target_key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
) -> dict[str, str] | None:
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
|
||||||
|
target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
|
||||||
|
|
||||||
|
if not verify_shape_match(source_order, target_order):
|
||||||
|
return None
|
||||||
|
|
||||||
|
mapping: dict[str, str] = {}
|
||||||
|
for model_type_shape in source_order:
|
||||||
|
source_keys = source_order[model_type_shape]
|
||||||
|
target_keys = target_order[model_type_shape]
|
||||||
|
mapping.update(zip(target_keys, source_keys))
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict(
|
||||||
|
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
converted_state_dict: dict[str, Tensor] = {}
|
||||||
|
for target_key in target_state_dict:
|
||||||
|
target_prefix, suffix = target_key.rsplit(".", 1)
|
||||||
|
source_prefix = state_dict_mapping[target_prefix]
|
||||||
|
source_key = ".".join([source_prefix, suffix])
|
||||||
|
converted_state_dict[target_key] = source_state_dict[source_key]
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def forward_store_outputs(
|
||||||
|
module: "Module",
|
||||||
|
example_args: tuple[Any, ...],
|
||||||
|
key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
) -> list[tuple[str, Tensor]]:
|
||||||
|
key_skipper = key_skipper or (lambda _: False)
|
||||||
|
submodule_to_key: dict["Module", str] = {}
|
||||||
|
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
|
||||||
|
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
for name, submodule in module.named_modules():
|
||||||
|
if is_basic_layer(submodule) and not key_skipper(name):
|
||||||
|
submodule_to_key[submodule] = name
|
||||||
|
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
with no_grad():
|
||||||
|
module(*example_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return execution_order
|
||||||
|
|
||||||
|
|
||||||
|
def compare_models(
|
||||||
|
source_model: "Module",
|
||||||
|
target_model: "Module",
|
||||||
|
source_args: tuple[Any, ...],
|
||||||
|
target_args: tuple[Any, ...] | None = None,
|
||||||
|
source_key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
target_key_skipper: Callable[[str], bool] | None = None,
|
||||||
|
threshold: float = 1e-5,
|
||||||
|
) -> bool:
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
|
||||||
|
target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
|
||||||
|
|
||||||
|
prev_source_key, prev_target_key = None, None
|
||||||
|
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
|
||||||
|
diff = norm(source_output - target_output).item()
|
||||||
|
if diff > threshold:
|
||||||
|
print(
|
||||||
|
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
|
||||||
|
f" {target_key}, difference in norm: {diff}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
prev_source_key, prev_target_key = source_key, target_key
|
||||||
|
|
||||||
|
return True
|
0
src/refiners/foundationals/__init__.py
Normal file
0
src/refiners/foundationals/clip/__init__.py
Normal file
BIN
src/refiners/foundationals/clip/bpe_simple_vocab_16e6.txt.gz
Normal file
0
src/refiners/foundationals/clip/image_encoder.py
Normal file
250
src/refiners/foundationals/clip/text_encoder.py
Normal file
|
@ -0,0 +1,250 @@
|
||||||
|
from torch import Tensor, arange, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from refiners.fluxion.layers import (
|
||||||
|
ApproximateGeLU,
|
||||||
|
GeLU,
|
||||||
|
Linear,
|
||||||
|
LayerNorm,
|
||||||
|
Embedding,
|
||||||
|
Chain,
|
||||||
|
Sum,
|
||||||
|
SelfAttention,
|
||||||
|
Lambda,
|
||||||
|
Residual,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalTokenEncoder(Sum):
|
||||||
|
structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocabulary_size: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
positional_embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
self.vocabulary_size = vocabulary_size
|
||||||
|
self.positional_embedding_dim = positional_embedding_dim
|
||||||
|
super().__init__(
|
||||||
|
Embedding(
|
||||||
|
num_embeddings=vocabulary_size,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Chain(
|
||||||
|
Lambda(self.get_position_ids),
|
||||||
|
Embedding(
|
||||||
|
num_embeddings=positional_embedding_dim,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position_ids(self) -> Tensor:
|
||||||
|
return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
|
||||||
|
|
||||||
|
def get_position_ids(self, x: Tensor) -> Tensor:
|
||||||
|
return self.position_ids[:, : x.shape[1]]
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(Chain):
|
||||||
|
structural_attrs = ["embedding_dim", "feedforward_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
feedforward_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
super().__init__(
|
||||||
|
Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
|
||||||
|
GeLU(),
|
||||||
|
Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(Chain):
|
||||||
|
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
feedforward_dim: int,
|
||||||
|
num_attention_heads: int = 1,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
super().__init__(
|
||||||
|
Residual(
|
||||||
|
LayerNorm(
|
||||||
|
normalized_shape=embedding_dim,
|
||||||
|
eps=layer_norm_eps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SelfAttention(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_heads=num_attention_heads,
|
||||||
|
is_causal=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Residual(
|
||||||
|
LayerNorm(
|
||||||
|
normalized_shape=embedding_dim,
|
||||||
|
eps=layer_norm_eps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
FeedForward(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
feedforward_dim=feedforward_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEncoder(Chain):
|
||||||
|
structural_attrs = [
|
||||||
|
"embedding_dim",
|
||||||
|
"positional_embedding_dim",
|
||||||
|
"vocabulary_size",
|
||||||
|
"num_layers",
|
||||||
|
"num_attention_heads",
|
||||||
|
"feedforward_dim",
|
||||||
|
"layer_norm_eps",
|
||||||
|
"tokenizer",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int = 768,
|
||||||
|
positional_embedding_dim: int = 77,
|
||||||
|
vocabulary_size: int = 49408,
|
||||||
|
num_layers: int = 12,
|
||||||
|
num_attention_heads: int = 12,
|
||||||
|
feedforward_dim: int = 3072,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.positional_embedding_dim = positional_embedding_dim
|
||||||
|
self.vocabulary_size = vocabulary_size
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.tokenizer = CLIPTokenizer()
|
||||||
|
super().__init__(
|
||||||
|
PositionalTokenEncoder(
|
||||||
|
vocabulary_size=vocabulary_size,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
positional_embedding_dim=positional_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
*(
|
||||||
|
TransformerLayer(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
feedforward_dim=feedforward_dim,
|
||||||
|
layer_norm_eps=layer_norm_eps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
),
|
||||||
|
LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, text: str) -> Tensor:
|
||||||
|
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
|
||||||
|
return self(tokens)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unconditional_text_embedding(self) -> Tensor:
|
||||||
|
return self.encode("")
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEncoderL(CLIPTextEncoder):
|
||||||
|
"""
|
||||||
|
CLIPTextEncoderL is the CLIP text encoder with the following parameters:
|
||||||
|
embedding_dim=768
|
||||||
|
num_layers=12
|
||||||
|
num_attention_heads=12
|
||||||
|
feedforward_dim=3072
|
||||||
|
|
||||||
|
We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
|
||||||
|
of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=768,
|
||||||
|
num_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
feedforward_dim=3072,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
|
||||||
|
parent.replace(old_module=gelu, new_module=ApproximateGeLU())
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEncoderH(CLIPTextEncoder):
|
||||||
|
"""
|
||||||
|
CLIPTextEncoderH is the CLIP text encoder with the following parameters:
|
||||||
|
embedding_dim=1024
|
||||||
|
num_layers=23
|
||||||
|
num_attention_heads=16
|
||||||
|
feedforward_dim=4096
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=1024,
|
||||||
|
num_layers=23,
|
||||||
|
num_attention_heads=16,
|
||||||
|
feedforward_dim=4096,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEncoderG(CLIPTextEncoder):
|
||||||
|
"""
|
||||||
|
CLIPTextEncoderG is the CLIP text encoder with the following parameters:
|
||||||
|
embedding_dim=1280
|
||||||
|
num_layers=32
|
||||||
|
num_attention_heads=16
|
||||||
|
feedforward_dim=5120
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=1280,
|
||||||
|
num_layers=32,
|
||||||
|
num_attention_heads=20,
|
||||||
|
feedforward_dim=5120,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
108
src/refiners/foundationals/clip/tokenizer.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
import gzip
|
||||||
|
from pathlib import Path
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import islice
|
||||||
|
import re
|
||||||
|
from torch import Tensor, tensor
|
||||||
|
from refiners.fluxion import pad
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTokenizer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
):
|
||||||
|
self.vocabulary_path = vocabulary_path
|
||||||
|
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
|
||||||
|
merge_tuples = [
|
||||||
|
tuple(merge.split())
|
||||||
|
for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||||
|
]
|
||||||
|
vocabulary = (
|
||||||
|
list(self.byte_to_unicode_mapping.values())
|
||||||
|
+ [v + "</w>" for v in self.byte_to_unicode_mapping.values()]
|
||||||
|
+ ["".join(merge) for merge in merge_tuples]
|
||||||
|
+ ["", ""]
|
||||||
|
)
|
||||||
|
self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
|
||||||
|
self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
|
||||||
|
self.byte_pair_encoding_cache = {"": ""}
|
||||||
|
# Note: this regular expression does not support Unicode. It was changed so
|
||||||
|
# to get rid of the dependence on the `regex` module. Unicode support could
|
||||||
|
# potentially be added back by leveraging the `\w` character class.
|
||||||
|
self.token_pattern = re.compile(
|
||||||
|
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
self.start_of_text_token_id: int = 49406
|
||||||
|
self.end_of_text_token_id: int = 49407
|
||||||
|
|
||||||
|
def __call__(self, text: str, sequence_length: int) -> Tensor:
|
||||||
|
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
|
||||||
|
assert (
|
||||||
|
tokens.shape[1] <= sequence_length
|
||||||
|
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
|
||||||
|
return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
|
||||||
|
initial_byte_values = (
|
||||||
|
list(range(ord("!"), ord("~") + 1))
|
||||||
|
+ list(range(ord("¡"), ord("¬") + 1))
|
||||||
|
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||||
|
)
|
||||||
|
extra_unicode_values = (byte for byte in range(2**8) if byte not in initial_byte_values)
|
||||||
|
byte_values = initial_byte_values + list(extra_unicode_values)
|
||||||
|
unicode_values = [chr(value) for value in byte_values]
|
||||||
|
return dict(zip(byte_values, unicode_values))
|
||||||
|
|
||||||
|
def byte_pair_encoding(self, token: str) -> str:
|
||||||
|
if token in self.byte_pair_encoding_cache:
|
||||||
|
return self.byte_pair_encoding_cache[token]
|
||||||
|
|
||||||
|
def recursive_bpe(word: tuple[str, ...]) -> tuple[str, ...]:
|
||||||
|
if len(word) < 2:
|
||||||
|
return word
|
||||||
|
pairs = {(i, (word[i], word[i + 1])) for i in range(len(word) - 1)}
|
||||||
|
min_pair = min(
|
||||||
|
pairs,
|
||||||
|
key=lambda pair: self.byte_pair_encoding_ranks.get(pair[1], float("inf")),
|
||||||
|
)
|
||||||
|
if min_pair[1] not in self.byte_pair_encoding_ranks:
|
||||||
|
return word
|
||||||
|
new_word: list[str] = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
if i == min_pair[0]:
|
||||||
|
new_word.append(min_pair[1][0] + min_pair[1][1])
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
return recursive_bpe(tuple(new_word))
|
||||||
|
|
||||||
|
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
||||||
|
result = " ".join(recursive_bpe(word))
|
||||||
|
self.byte_pair_encoding_cache[token] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
def encode(self, text: str, max_length: int | None = None) -> Tensor:
|
||||||
|
text = re.sub(r"\s+", " ", text.lower())
|
||||||
|
tokens = re.findall(self.token_pattern, text)
|
||||||
|
upper_bound = None
|
||||||
|
if max_length:
|
||||||
|
assert max_length >= 2
|
||||||
|
upper_bound = max_length - 2
|
||||||
|
encoded_tokens = islice(
|
||||||
|
(
|
||||||
|
self.token_to_id_mapping[subtoken]
|
||||||
|
for token in tokens
|
||||||
|
for subtoken in self.byte_pair_encoding(
|
||||||
|
"".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
|
||||||
|
).split(" ")
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
upper_bound,
|
||||||
|
)
|
||||||
|
return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
|
201
src/refiners/foundationals/latent_diffusion/__init__.py
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
from typing import TypeVar
|
||||||
|
from torch import cat, float32, randn, tensor, device as Device, dtype as DType, Size, Tensor
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, interpolate
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import (
|
||||||
|
LatentDiffusionAutoencoder,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.clip.text_encoder import (
|
||||||
|
CLIPTextEncoder,
|
||||||
|
CLIPTextEncoderL,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
|
||||||
|
|
||||||
|
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LatentDiffusionModel",
|
||||||
|
"UNet",
|
||||||
|
"DPMSolver",
|
||||||
|
"Scheduler",
|
||||||
|
"CLIPTextEncoder",
|
||||||
|
"LatentDiffusionAutoencoder",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusionModel(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: UNet,
|
||||||
|
lda: LatentDiffusionAutoencoder,
|
||||||
|
clip_text_encoder: CLIPTextEncoder,
|
||||||
|
scheduler: Scheduler,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = float32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.device: Device = device if isinstance(device, Device) else Device(device)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.unet = unet.to(self.device, dtype=self.dtype)
|
||||||
|
self.lda = lda.to(self.device, dtype=self.dtype)
|
||||||
|
self.clip_text_encoder = clip_text_encoder.to(self.device, dtype=self.dtype)
|
||||||
|
self.scheduler = scheduler.to(self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
def set_num_inference_steps(self, num_inference_steps: int):
|
||||||
|
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
||||||
|
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
||||||
|
device, dtype = self.scheduler.device, self.scheduler.dtype
|
||||||
|
self.scheduler = self.scheduler.__class__(
|
||||||
|
num_inference_steps,
|
||||||
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
|
).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def init_latents(
|
||||||
|
self,
|
||||||
|
size: tuple[int, int],
|
||||||
|
init_image: Image.Image | None = None,
|
||||||
|
first_step: int = 0,
|
||||||
|
noise: Tensor | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if noise is None:
|
||||||
|
height, width = size
|
||||||
|
noise = randn(1, 4, height // 8, width // 8, device=self.device)
|
||||||
|
assert list(noise.shape[2:]) == [
|
||||||
|
size[0] // 8,
|
||||||
|
size[1] // 8,
|
||||||
|
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
||||||
|
if init_image is None:
|
||||||
|
return noise
|
||||||
|
encoded_image = self.lda.encode_image(init_image.resize(size))
|
||||||
|
return self.scheduler.add_noise(encoded_image, noise, self.steps[first_step])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def steps(self) -> list[int]:
|
||||||
|
return self.scheduler.steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timestep_embeddings(self) -> Tensor:
|
||||||
|
return self.timestep_encoder(self.scheduler.timesteps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unconditional_clip_text_embeddings(self) -> Tensor:
|
||||||
|
return self.clip_text_encoder.unconditional_text_embedding
|
||||||
|
|
||||||
|
def compute_text_embedding(self, text: str) -> Tensor:
|
||||||
|
return self.clip_text_encoder.encode(text)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
step: int,
|
||||||
|
clip_text_embedding: Tensor,
|
||||||
|
negative_clip_text_embedding: Tensor | None = None,
|
||||||
|
condition_scale: float = 7.5,
|
||||||
|
) -> Tensor:
|
||||||
|
timestep = self.scheduler.timesteps[step].unsqueeze(0)
|
||||||
|
self.unet.set_timestep(timestep)
|
||||||
|
|
||||||
|
negative_clip_text_embedding = (
|
||||||
|
self.clip_text_encoder.unconditional_text_embedding
|
||||||
|
if negative_clip_text_embedding is None
|
||||||
|
else negative_clip_text_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_text_embeddings = cat((negative_clip_text_embedding, clip_text_embedding))
|
||||||
|
|
||||||
|
self.unet.set_clip_text_embedding(clip_text_embeddings)
|
||||||
|
latents = cat((x, x)) # for classifier-free guidance
|
||||||
|
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||||
|
|
||||||
|
# classifier-free guidance
|
||||||
|
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
|
||||||
|
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||||
|
return self.scheduler(x, noise=noise, step=step)
|
||||||
|
|
||||||
|
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
||||||
|
return self.__class__(
|
||||||
|
unet=self.unet.structural_copy(),
|
||||||
|
lda=self.lda.structural_copy(),
|
||||||
|
clip_text_encoder=self.clip_text_encoder.structural_copy(),
|
||||||
|
scheduler=self.scheduler,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion_1(LatentDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: UNet | None = None,
|
||||||
|
lda: LatentDiffusionAutoencoder | None = None,
|
||||||
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
|
scheduler: Scheduler | None = None,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = float32,
|
||||||
|
):
|
||||||
|
unet = unet or UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
lda = lda or LatentDiffusionAutoencoder()
|
||||||
|
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||||
|
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
unet,
|
||||||
|
lda,
|
||||||
|
clip_text_encoder=clip_text_encoder,
|
||||||
|
scheduler=scheduler,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: UNet | None = None,
|
||||||
|
lda: LatentDiffusionAutoencoder | None = None,
|
||||||
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
|
scheduler: Scheduler | None = None,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = float32,
|
||||||
|
):
|
||||||
|
self.mask_latents: Tensor | None = None
|
||||||
|
self.target_image_latents: Tensor | None = None
|
||||||
|
super().__init__(unet, lda, clip_text_encoder, scheduler, device, dtype)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
step: int,
|
||||||
|
clip_text_embedding: Tensor,
|
||||||
|
negative_clip_text_embedding: Tensor | None = None,
|
||||||
|
condition_scale: float = 7.5,
|
||||||
|
):
|
||||||
|
assert self.mask_latents is not None
|
||||||
|
assert self.target_image_latents is not None
|
||||||
|
x = cat((x, self.mask_latents, self.target_image_latents), dim=1)
|
||||||
|
return super().forward(x, step, clip_text_embedding, negative_clip_text_embedding, condition_scale)
|
||||||
|
|
||||||
|
def set_inpainting_conditions(
|
||||||
|
self,
|
||||||
|
target_image: Image.Image,
|
||||||
|
mask: Image.Image,
|
||||||
|
latents_size: tuple[int, int] = (64, 64),
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
target_image = target_image.convert("RGB")
|
||||||
|
mask = mask.convert("L")
|
||||||
|
|
||||||
|
mask_tensor = tensor(np.array(mask).astype(np.float32) / 255.0).to(self.device)
|
||||||
|
mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype)
|
||||||
|
self.mask_latents = interpolate(mask_tensor, Size(latents_size))
|
||||||
|
|
||||||
|
init_image_tensor = image_to_tensor(target_image, device=self.device, dtype=self.dtype) * 2 - 1
|
||||||
|
masked_init_image = init_image_tensor * (1 - mask_tensor)
|
||||||
|
self.target_image_latents = self.lda.encode(masked_init_image)
|
||||||
|
|
||||||
|
return self.mask_latents, self.target_image_latents
|
230
src/refiners/foundationals/latent_diffusion/auto_encoder.py
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.fluxion.layers import (
|
||||||
|
Chain,
|
||||||
|
Conv2d,
|
||||||
|
GroupNorm,
|
||||||
|
Identity,
|
||||||
|
SiLU,
|
||||||
|
Downsample,
|
||||||
|
Upsample,
|
||||||
|
Sum,
|
||||||
|
SelfAttention2d,
|
||||||
|
Slicing,
|
||||||
|
)
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class Resnet(Sum):
|
||||||
|
structural_attrs = ["in_channels", "out_channels"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_groups: int = 32,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
shortcut = (
|
||||||
|
Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
|
||||||
|
if in_channels != out_channels
|
||||||
|
else Identity()
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
shortcut,
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=in_channels, num_groups=num_groups, device=device, dtype=dtype),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
GroupNorm(channels=out_channels, num_groups=num_groups, device=device, dtype=dtype),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(Chain):
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
resnet_sizes: list[int] = [128, 256, 512, 512, 512]
|
||||||
|
input_channels: int = 3
|
||||||
|
latent_dim: int = 8
|
||||||
|
resnet_layers: list[Chain] = [
|
||||||
|
Chain(
|
||||||
|
[
|
||||||
|
Resnet(
|
||||||
|
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
|
||||||
|
out_channels=resnet_sizes[i],
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for i in range(len(resnet_sizes))
|
||||||
|
]
|
||||||
|
for _, layer in zip(range(3), resnet_layers):
|
||||||
|
channels: int = layer[-1].out_channels # type: ignore
|
||||||
|
layer.append(Downsample(channels=channels, scale_factor=2, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
attention_layer = Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
|
||||||
|
SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
resnet_layers[-1].insert_after_type(Resnet, attention_layer)
|
||||||
|
super().__init__(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=input_channels,
|
||||||
|
out_channels=resnet_sizes[0],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Chain(*resnet_layers),
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=resnet_sizes[-1],
|
||||||
|
out_channels=latent_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Chain(
|
||||||
|
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
|
||||||
|
Slicing(dim=1, start=0, length=4),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {"sampling": {"shapes": []}}
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(Chain):
|
||||||
|
structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
|
||||||
|
self.latent_dim: int = 4
|
||||||
|
self.output_channels: int = 3
|
||||||
|
resnet_sizes = self.resnet_sizes[::-1]
|
||||||
|
resnet_layers: list[Chain] = [
|
||||||
|
(
|
||||||
|
Chain(
|
||||||
|
[
|
||||||
|
Resnet(
|
||||||
|
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
|
||||||
|
out_channels=resnet_sizes[i],
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
|
||||||
|
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if i > 0
|
||||||
|
else Chain(
|
||||||
|
[
|
||||||
|
Resnet(in_channels=resnet_sizes[0], out_channels=resnet_sizes[i], device=device, dtype=dtype),
|
||||||
|
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for i in range(len(resnet_sizes))
|
||||||
|
]
|
||||||
|
attention_layer = Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype),
|
||||||
|
SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
resnet_layers[0].insert(1, attention_layer)
|
||||||
|
for _, layer in zip(range(3), resnet_layers[1:]):
|
||||||
|
channels: int = layer[-1].out_channels
|
||||||
|
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
|
||||||
|
super().__init__(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=1, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=self.latent_dim,
|
||||||
|
out_channels=resnet_sizes[0],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
Chain(*resnet_layers),
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=resnet_sizes[-1],
|
||||||
|
out_channels=self.output_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusionAutoencoder(Chain):
|
||||||
|
structural_attrs = ["encoder_scale"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.encoder_scale: float = 0.18215
|
||||||
|
super().__init__(
|
||||||
|
Encoder(device=device, dtype=dtype),
|
||||||
|
Decoder(device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode(self, x: Tensor) -> Tensor:
|
||||||
|
encoder = self[0]
|
||||||
|
x = self.encoder_scale * encoder(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x: Tensor) -> Tensor:
|
||||||
|
decoder = self[1]
|
||||||
|
x = decoder(x / self.encoder_scale)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def encode_image(self, image: Image.Image) -> Tensor:
|
||||||
|
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
||||||
|
x = 2 * x - 1
|
||||||
|
return self.encode(x)
|
||||||
|
|
||||||
|
def decode_latents(self, x: Tensor) -> Image.Image:
|
||||||
|
x = self.decode(x)
|
||||||
|
x = (x + 1) / 2
|
||||||
|
return tensor_to_image(x)
|
150
src/refiners/foundationals/latent_diffusion/controlnet.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import DownBlocks, MiddleBlock, ResidualBlock, TimestepEncoder
|
||||||
|
from refiners.adapters.range_adapter import RangeAdapter2d
|
||||||
|
from typing import cast, Iterable
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEncoder(Chain):
|
||||||
|
"""Encode an image to be used as a condition for Controlnet.
|
||||||
|
|
||||||
|
Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
structural_attrs = ["out_channels"]
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.out_channels = (16, 32, 96, 256)
|
||||||
|
super().__init__(
|
||||||
|
Chain(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=self.out_channels[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
),
|
||||||
|
*(
|
||||||
|
Chain(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=self.out_channels[i],
|
||||||
|
out_channels=self.out_channels[i],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=self.out_channels[i],
|
||||||
|
out_channels=self.out_channels[i + 1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
)
|
||||||
|
for i in range(len(self.out_channels) - 1)
|
||||||
|
),
|
||||||
|
Conv2d(
|
||||||
|
in_channels=self.out_channels[-1],
|
||||||
|
out_channels=320,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Controlnet(Passthrough):
|
||||||
|
structural_attrs = ["name", "scale"]
|
||||||
|
|
||||||
|
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
|
||||||
|
|
||||||
|
Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
|
||||||
|
stored in the context.
|
||||||
|
|
||||||
|
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.scale: float = 1.0
|
||||||
|
super().__init__(
|
||||||
|
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||||
|
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
|
||||||
|
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
||||||
|
MiddleBlock(device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
# We run the condition encoder at each step. Caching the result
|
||||||
|
# is not worth it as subsequent runs take virtually no time (FG-374).
|
||||||
|
self.DownBlocks[0].append(
|
||||||
|
Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
UseContext("controlnet", f"condition_{name}"),
|
||||||
|
ConditionEncoder(device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for residual_block in self.layers(ResidualBlock):
|
||||||
|
chain = residual_block.Chain
|
||||||
|
range_adapter = RangeAdapter2d(
|
||||||
|
target=chain.Conv2d_1,
|
||||||
|
channels=residual_block.out_channels,
|
||||||
|
embedding_dim=1280,
|
||||||
|
context_key=f"timestep_embedding_{self.name}",
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
range_adapter.inject(chain)
|
||||||
|
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
|
||||||
|
assert hasattr(block[0], "out_channels"), (
|
||||||
|
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
|
||||||
|
f" {block[0]} does not."
|
||||||
|
)
|
||||||
|
out_channels: int = block[0].out_channels
|
||||||
|
block.append(
|
||||||
|
Passthrough(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=out_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
Lambda(self._store_nth_residual(n)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.MiddleBlock.append(
|
||||||
|
Passthrough(
|
||||||
|
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
|
||||||
|
Lambda(self._store_nth_residual(12)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {
|
||||||
|
"unet": {"residuals": [0.0] * 13},
|
||||||
|
"sampling": {"shapes": []},
|
||||||
|
"controlnet": {f"condition_{self.name}": None},
|
||||||
|
"range_adapter": {f"timestep_embedding_{self.name}": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _store_nth_residual(self, n: int):
|
||||||
|
def _store_residual(x: Tensor):
|
||||||
|
residuals = self.use_context("unet")["residuals"]
|
||||||
|
residuals[n] = residuals[n] + x * self.scale
|
||||||
|
return x
|
||||||
|
|
||||||
|
return _store_residual
|
||||||
|
|
||||||
|
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||||
|
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
||||||
|
|
||||||
|
def set_scale(self, scale: float) -> None:
|
||||||
|
self.scale = scale
|
203
src/refiners/foundationals/latent_diffusion/cross_attention.py
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
from torch import Tensor, Size, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.fluxion.layers import (
|
||||||
|
Identity,
|
||||||
|
Flatten,
|
||||||
|
Unflatten,
|
||||||
|
Transpose,
|
||||||
|
Chain,
|
||||||
|
Parallel,
|
||||||
|
LayerNorm,
|
||||||
|
Attention,
|
||||||
|
Sum,
|
||||||
|
UseContext,
|
||||||
|
Linear,
|
||||||
|
GLU,
|
||||||
|
GeLU,
|
||||||
|
GroupNorm,
|
||||||
|
Conv2d,
|
||||||
|
SelfAttention,
|
||||||
|
SetContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionBlock(Chain):
|
||||||
|
structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
context_embedding_dim: int,
|
||||||
|
context_key: str,
|
||||||
|
num_heads: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.context_embedding_dim = context_embedding_dim
|
||||||
|
self.context = "cross_attention_block"
|
||||||
|
self.context_key = context_key
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.use_bias = use_bias
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||||
|
SelfAttention(
|
||||||
|
embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||||
|
Parallel(
|
||||||
|
Identity(),
|
||||||
|
UseContext(context=self.context, key=context_key),
|
||||||
|
UseContext(context=self.context, key=context_key),
|
||||||
|
),
|
||||||
|
Attention(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
key_embedding_dim=context_embedding_dim,
|
||||||
|
value_embedding_dim=context_embedding_dim,
|
||||||
|
use_bias=use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Sum(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||||
|
Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype),
|
||||||
|
GLU(GeLU()),
|
||||||
|
Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StatefulFlatten(Chain):
|
||||||
|
structural_attrs = ["start_dim", "end_dim"]
|
||||||
|
|
||||||
|
def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
|
||||||
|
self.start_dim = start_dim
|
||||||
|
self.end_dim = end_dim
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
SetContext(context=context, key=key, callback=self.push),
|
||||||
|
Flatten(start_dim=start_dim, end_dim=end_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def push(self, sizes: list[Size], x: Tensor) -> None:
|
||||||
|
sizes.append(
|
||||||
|
x.shape[slice(self.start_dim, self.end_dim + 1 if self.end_dim >= 0 else x.ndim + self.end_dim + 1)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionBlock2d(Sum):
|
||||||
|
structural_attrs = [
|
||||||
|
"channels",
|
||||||
|
"in_channels",
|
||||||
|
"out_channels",
|
||||||
|
"context_embedding_dim",
|
||||||
|
"num_attention_heads",
|
||||||
|
"num_attention_layers",
|
||||||
|
"num_groups",
|
||||||
|
"context_key",
|
||||||
|
"use_linear_projection",
|
||||||
|
"projection_type",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
context_embedding_dim: int,
|
||||||
|
context_key: str,
|
||||||
|
num_attention_heads: int = 1,
|
||||||
|
num_attention_layers: int = 1,
|
||||||
|
num_groups: int = 32,
|
||||||
|
use_bias: bool = True,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert channels % num_attention_heads == 0, "in_channels must be divisible by num_attention_heads"
|
||||||
|
self.channels = channels
|
||||||
|
self.in_channels = channels
|
||||||
|
self.out_channels = channels
|
||||||
|
self.context_embedding_dim = context_embedding_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_attention_layers = num_attention_layers
|
||||||
|
self.num_groups = num_groups
|
||||||
|
self.context_key = context_key
|
||||||
|
self.use_linear_projection = use_linear_projection
|
||||||
|
self.projection_type = "Linear" if use_linear_projection else "Conv2d"
|
||||||
|
|
||||||
|
in_block = (
|
||||||
|
Chain(
|
||||||
|
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
||||||
|
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||||
|
Transpose(1, 2),
|
||||||
|
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
if use_linear_projection
|
||||||
|
else Chain(
|
||||||
|
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, affine=True, device=device, dtype=dtype),
|
||||||
|
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
|
||||||
|
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
|
||||||
|
Transpose(1, 2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
out_block = (
|
||||||
|
Chain(
|
||||||
|
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
|
||||||
|
Transpose(1, 2),
|
||||||
|
Parallel(
|
||||||
|
Identity(),
|
||||||
|
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
|
||||||
|
),
|
||||||
|
Unflatten(dim=2),
|
||||||
|
)
|
||||||
|
if use_linear_projection
|
||||||
|
else Chain(
|
||||||
|
Transpose(1, 2),
|
||||||
|
Parallel(
|
||||||
|
Identity(),
|
||||||
|
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
|
||||||
|
),
|
||||||
|
Unflatten(dim=2),
|
||||||
|
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
Identity(),
|
||||||
|
Chain(
|
||||||
|
in_block,
|
||||||
|
Chain(
|
||||||
|
CrossAttentionBlock(
|
||||||
|
embedding_dim=channels,
|
||||||
|
context_embedding_dim=context_embedding_dim,
|
||||||
|
context_key=context_key,
|
||||||
|
num_heads=num_attention_heads,
|
||||||
|
use_bias=use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
for _ in range(num_attention_layers)
|
||||||
|
),
|
||||||
|
out_block,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {"flatten": {"sizes": []}}
|
101
src/refiners/foundationals/latent_diffusion/lora.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from torch import Tensor, device as Device
|
||||||
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
from refiners.adapters.lora import LoraAdapter, load_lora_weights
|
||||||
|
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||||
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
class LoraTarget(str, Enum):
|
||||||
|
Self = "self"
|
||||||
|
Attention = "Attention"
|
||||||
|
SelfAttention = "SelfAttention"
|
||||||
|
CrossAttention = "CrossAttentionBlock2d"
|
||||||
|
FeedForward = "FeedForward"
|
||||||
|
TransformerLayer = "TransformerLayer"
|
||||||
|
|
||||||
|
def get_class(self) -> type[fl.Chain]:
|
||||||
|
match self:
|
||||||
|
case LoraTarget.Self:
|
||||||
|
return fl.Chain
|
||||||
|
case LoraTarget.Attention:
|
||||||
|
return fl.Attention
|
||||||
|
case LoraTarget.SelfAttention:
|
||||||
|
return fl.SelfAttention
|
||||||
|
case LoraTarget.CrossAttention:
|
||||||
|
return CrossAttentionBlock2d
|
||||||
|
case LoraTarget.FeedForward:
|
||||||
|
return FeedForward
|
||||||
|
case LoraTarget.TransformerLayer:
|
||||||
|
return TransformerLayer
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_rank(weights: list[Tensor]) -> int:
|
||||||
|
ranks: set[int] = {w.shape[1] for w in weights[0::2]}
|
||||||
|
assert len(ranks) == 1
|
||||||
|
return ranks.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
|
||||||
|
for layer in module.layers(layer_type=target.get_class()):
|
||||||
|
for linear, parent in layer.walk(fl.Linear):
|
||||||
|
adapter = LoraAdapter(
|
||||||
|
target=linear,
|
||||||
|
rank=rank,
|
||||||
|
scale=scale,
|
||||||
|
device=module.device,
|
||||||
|
dtype=module.dtype,
|
||||||
|
)
|
||||||
|
adapter.inject(parent)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraWeights:
|
||||||
|
"""A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
|
||||||
|
|
||||||
|
metadata: dict[str, str] | None
|
||||||
|
tensors: dict[str, Tensor]
|
||||||
|
|
||||||
|
def __init__(self, checkpoint_path: Path | str, device: Device | str):
|
||||||
|
self.metadata = load_metadata_from_safetensors(checkpoint_path)
|
||||||
|
self.tensors = load_from_safetensors(checkpoint_path, device=device)
|
||||||
|
|
||||||
|
def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
|
||||||
|
assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||||
|
|
||||||
|
for meta_key, meta_value in self.metadata.items():
|
||||||
|
match meta_key:
|
||||||
|
case "unet_targets":
|
||||||
|
# TODO: support this transparently
|
||||||
|
if any([isinstance(module, Controlnet) for module in sd.unet]):
|
||||||
|
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
|
||||||
|
model = sd.unet
|
||||||
|
key_prefix = "unet."
|
||||||
|
case "text_encoder_targets":
|
||||||
|
model = sd.clip_text_encoder
|
||||||
|
key_prefix = "text_encoder."
|
||||||
|
case "lda_targets":
|
||||||
|
model = sd.lda
|
||||||
|
key_prefix = "lda."
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
||||||
|
|
||||||
|
# TODO(FG-487): support loading multiple LoRA-s
|
||||||
|
if any(model.layers(LoraAdapter)):
|
||||||
|
raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
|
||||||
|
|
||||||
|
lora_weights = [w for w in [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]]
|
||||||
|
assert len(lora_weights) % 2 == 0
|
||||||
|
|
||||||
|
rank = get_lora_rank(lora_weights)
|
||||||
|
for target in meta_value.split(","):
|
||||||
|
apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
|
||||||
|
|
||||||
|
assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
|
||||||
|
|
||||||
|
load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
|
|
@ -0,0 +1,11 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Scheduler",
|
||||||
|
"DPMSolver",
|
||||||
|
"DDPM",
|
||||||
|
"DDIM",
|
||||||
|
]
|
|
@ -0,0 +1,41 @@
|
||||||
|
from torch import Tensor, device as Device, arange, sqrt
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class DDIM(Scheduler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1_000,
|
||||||
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
|
||||||
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
"""
|
||||||
|
Generates decreasing timesteps with 'leading' spacing and offset of 1
|
||||||
|
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
|
||||||
|
"""
|
||||||
|
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||||
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
||||||
|
return timesteps.flip(0)
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
timestep, previous_timestep = (
|
||||||
|
self.timesteps[step],
|
||||||
|
self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
|
||||||
|
)
|
||||||
|
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
||||||
|
self.cumulative_scale_factors[previous_timestep]
|
||||||
|
if previous_timestep > 0
|
||||||
|
else self.cumulative_scale_factors[0]
|
||||||
|
)
|
||||||
|
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
||||||
|
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
|
||||||
|
|
||||||
|
self.previous_scale_factor = previous_scale_factor
|
||||||
|
|
||||||
|
return denoised_x
|
|
@ -0,0 +1,75 @@
|
||||||
|
from torch import Tensor, device as Device, randn, arange, Generator, tensor
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class DDPM(Scheduler):
|
||||||
|
"""
|
||||||
|
The Denoising Diffusion Probabilistic Models (DDPM) is a specific type of diffusion model,
|
||||||
|
which uses a specific strategy to generate the timesteps and applies the diffusion process in a specific way.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1_000,
|
||||||
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||||
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio
|
||||||
|
return timesteps.flip(0)
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
|
"""
|
||||||
|
Generate the next step in the diffusion process.
|
||||||
|
|
||||||
|
This method adjusts the input data using added noise and an estimate of the denoised data, based on the current
|
||||||
|
step in the diffusion process. This adjusted data forms the next step in the diffusion process.
|
||||||
|
|
||||||
|
1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original
|
||||||
|
data and noise to the new step.
|
||||||
|
2. An estimate of the denoised data (`estimated_denoised_data`) is generated.
|
||||||
|
3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and
|
||||||
|
`current_data_coeff`) that balance their contribution to the denoised data for the next step.
|
||||||
|
4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated
|
||||||
|
denoised data and current data, adjusted by their respective coefficients.
|
||||||
|
5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on
|
||||||
|
the cumulative scaling factor and the current factor.
|
||||||
|
|
||||||
|
The output is the new data step for the next stage in the diffusion process.
|
||||||
|
"""
|
||||||
|
timestep, previous_timestep = (
|
||||||
|
self.timesteps[step],
|
||||||
|
(
|
||||||
|
self.timesteps[step + 1]
|
||||||
|
if step < len(self.timesteps) - 1
|
||||||
|
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], (
|
||||||
|
(self.scale_factors.cumprod(0))[previous_timestep]
|
||||||
|
if step < len(self.timesteps) - 1
|
||||||
|
else tensor(1, device=self.device)
|
||||||
|
)
|
||||||
|
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
|
||||||
|
estimated_denoised_data = (
|
||||||
|
x - (1 - current_cumulative_factor) ** 0.5 * noise
|
||||||
|
) / current_cumulative_factor**0.5
|
||||||
|
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
|
||||||
|
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
|
||||||
|
1 - current_cumulative_factor
|
||||||
|
)
|
||||||
|
current_data_coeff = (
|
||||||
|
current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor)
|
||||||
|
)
|
||||||
|
denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x
|
||||||
|
if step < len(self.timesteps) - 1:
|
||||||
|
variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor)
|
||||||
|
denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn(
|
||||||
|
x.shape, device=x.device, dtype=x.dtype, generator=generator
|
||||||
|
)
|
||||||
|
return denoised_x
|
|
@ -0,0 +1,111 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
import numpy as np
|
||||||
|
from torch import Tensor, device as Device, tensor, exp
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class DPMSolver(Scheduler):
|
||||||
|
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||||
|
|
||||||
|
We only support noise prediction for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1_000,
|
||||||
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||||
|
self.initial_steps = 0
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
# We need to use numpy here because:
|
||||||
|
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
|
# torch.linspace(0,999,31)[15] is 499.5
|
||||||
|
# ...and we want the same result as the original codebase.
|
||||||
|
return tensor(
|
||||||
|
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||||
|
device=self.device,
|
||||||
|
).flip(0)
|
||||||
|
|
||||||
|
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
timestep, previous_timestep = (
|
||||||
|
self.timesteps[step],
|
||||||
|
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
|
||||||
|
)
|
||||||
|
previous_ratio, current_ratio = (
|
||||||
|
self.signal_to_noise_ratios[previous_timestep],
|
||||||
|
self.signal_to_noise_ratios[timestep],
|
||||||
|
)
|
||||||
|
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||||
|
previous_noise_std, current_noise_std = (
|
||||||
|
self.noise_std[previous_timestep],
|
||||||
|
self.noise_std[timestep],
|
||||||
|
)
|
||||||
|
exp_factor = exp(-(previous_ratio - current_ratio))
|
||||||
|
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
|
||||||
|
return denoised_x
|
||||||
|
|
||||||
|
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
previous_timestep, current_timestep, next_timestep = (
|
||||||
|
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
|
||||||
|
self.timesteps[step],
|
||||||
|
self.timesteps[step - 1],
|
||||||
|
)
|
||||||
|
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
|
||||||
|
previous_ratio, current_ratio, next_ratio = (
|
||||||
|
self.signal_to_noise_ratios[previous_timestep],
|
||||||
|
self.signal_to_noise_ratios[current_timestep],
|
||||||
|
self.signal_to_noise_ratios[next_timestep],
|
||||||
|
)
|
||||||
|
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||||
|
previous_std, current_std = (
|
||||||
|
self.noise_std[previous_timestep],
|
||||||
|
self.noise_std[current_timestep],
|
||||||
|
)
|
||||||
|
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||||
|
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||||
|
)
|
||||||
|
exp_neg_factor = exp(-(previous_ratio - current_ratio))
|
||||||
|
x_t = (
|
||||||
|
(previous_std / current_std) * x
|
||||||
|
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
|
||||||
|
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
|
||||||
|
)
|
||||||
|
return x_t
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
noise: Tensor,
|
||||||
|
step: int,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
||||||
|
|
||||||
|
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
|
||||||
|
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
||||||
|
(ODEs).
|
||||||
|
"""
|
||||||
|
current_timestep = self.timesteps[step]
|
||||||
|
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||||
|
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
||||||
|
self.estimated_data.append(estimated_denoised_data)
|
||||||
|
denoised_x = (
|
||||||
|
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||||
|
if (self.initial_steps == 0)
|
||||||
|
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||||
|
)
|
||||||
|
if self.initial_steps < 2:
|
||||||
|
self.initial_steps += 1
|
||||||
|
return denoised_x
|
|
@ -0,0 +1,95 @@
|
||||||
|
from abc import abstractmethod
|
||||||
|
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="Scheduler")
|
||||||
|
|
||||||
|
|
||||||
|
class Scheduler:
|
||||||
|
"""
|
||||||
|
A base class for creating a diffusion model scheduler.
|
||||||
|
|
||||||
|
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
|
||||||
|
which gradually transforms the original data distribution into a Gaussian one.
|
||||||
|
|
||||||
|
This process is described using several parameters such as initial and final diffusion rates,
|
||||||
|
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
timesteps: Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1_000,
|
||||||
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = float32,
|
||||||
|
):
|
||||||
|
self.device: Device = Device(device)
|
||||||
|
self.dtype: DType = dtype
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
self.num_train_timesteps = num_train_timesteps
|
||||||
|
self.initial_diffusion_rate = initial_diffusion_rate
|
||||||
|
self.final_diffusion_rate = final_diffusion_rate
|
||||||
|
self.scale_factors = (
|
||||||
|
1.0
|
||||||
|
- linspace(
|
||||||
|
start=initial_diffusion_rate**0.5,
|
||||||
|
end=final_diffusion_rate**0.5,
|
||||||
|
steps=num_train_timesteps,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
** 2
|
||||||
|
)
|
||||||
|
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||||
|
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||||
|
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||||
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
|
||||||
|
|
||||||
|
This method should be overridden by subclasses to implement the specific diffusion process.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
"""
|
||||||
|
Generates a tensor of timesteps.
|
||||||
|
|
||||||
|
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def steps(self) -> list[int]:
|
||||||
|
return list(range(self.num_inference_steps))
|
||||||
|
|
||||||
|
def add_noise(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
noise: Tensor,
|
||||||
|
step: int,
|
||||||
|
) -> Tensor:
|
||||||
|
timestep = self.timesteps[step]
|
||||||
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1)
|
||||||
|
noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1)
|
||||||
|
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
||||||
|
return noised_x
|
||||||
|
|
||||||
|
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||||
|
if device is not None:
|
||||||
|
self.device = Device(device)
|
||||||
|
self.timesteps = self.timesteps.to(device)
|
||||||
|
if dtype is not None:
|
||||||
|
self.dtype = dtype
|
||||||
|
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
|
||||||
|
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
|
||||||
|
self.noise_std = self.noise_std.to(device, dtype=dtype)
|
||||||
|
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
|
||||||
|
return self
|
291
src/refiners/foundationals/latent_diffusion/sdxl_unet.py
Normal file
|
@ -0,0 +1,291 @@
|
||||||
|
from typing import cast
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import ResidualAccumulator, ResidualBlock, ResidualConcatenator
|
||||||
|
from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TextTimeEmbedding(fl.Chain):
|
||||||
|
structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.timestep_embedding_dim = 1280
|
||||||
|
self.time_ids_embedding_dim = 256
|
||||||
|
self.text_time_embedding_dim = 2816
|
||||||
|
super().__init__(
|
||||||
|
fl.Concatenate(
|
||||||
|
fl.UseContext(context="diffusion", key="pooled_text_embedding"),
|
||||||
|
fl.Chain(
|
||||||
|
fl.UseContext(context="diffusion", key="time_ids"),
|
||||||
|
fl.Unsqueeze(dim=-1),
|
||||||
|
fl.Lambda(func=self.compute_sinuosoidal_embedding),
|
||||||
|
fl.Reshape(-1),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
fl.Linear(
|
||||||
|
in_features=self.text_time_embedding_dim,
|
||||||
|
out_features=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Linear(
|
||||||
|
in_features=self.timestep_embedding_dim,
|
||||||
|
out_features=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor:
|
||||||
|
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim).to(dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
structural_attrs = ["timestep_embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.timestep_embedding_dim = 1280
|
||||||
|
super().__init__(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Chain(
|
||||||
|
fl.UseContext(context="diffusion", key="timestep"),
|
||||||
|
RangeEncoder(
|
||||||
|
sinuosidal_embedding_dim=320,
|
||||||
|
embedding_dim=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TextTimeEmbedding(device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.SetContext(context="range_adapter", key="timestep_embedding"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLCrossAttention(CrossAttentionBlock2d):
|
||||||
|
structural_attrs = ["channels", "num_attention_layers", "num_attention_heads"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_attention_layers: int = 1,
|
||||||
|
num_attention_heads: int = 10,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
channels=channels,
|
||||||
|
context_embedding_dim=2048,
|
||||||
|
context_key="clip_text_embedding",
|
||||||
|
num_attention_layers=num_attention_layers,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
use_bias=False,
|
||||||
|
use_linear_projection=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlocks(fl.Chain):
|
||||||
|
structural_attrs = ["in_channels"]
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
in_block = fl.Chain(
|
||||||
|
fl.Conv2d(in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
first_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
second_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
third_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
in_block,
|
||||||
|
*first_blocks,
|
||||||
|
*second_blocks,
|
||||||
|
*third_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlocks(fl.Chain):
|
||||||
|
structural_attrs = []
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
first_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
fl.Upsample(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
second_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
fl.Upsample(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
third_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
*first_blocks,
|
||||||
|
*second_blocks,
|
||||||
|
*third_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MiddleBlock(fl.Chain):
|
||||||
|
structural_attrs = []
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputBlock(fl.Chain):
|
||||||
|
structural_attrs = []
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
fl.GroupNorm(channels=320, num_groups=32),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLUNet(fl.Chain):
|
||||||
|
structural_attrs = ["in_channels"]
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
super().__init__(
|
||||||
|
TimestepEncoder(device=device, dtype=dtype),
|
||||||
|
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
|
||||||
|
MiddleBlock(device=device, dtype=dtype),
|
||||||
|
fl.Residual(fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1])),
|
||||||
|
UpBlocks(device=device, dtype=dtype),
|
||||||
|
OutputBlock(device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
for residual_block in self.layers(ResidualBlock):
|
||||||
|
chain = residual_block.Chain
|
||||||
|
range_adapter = RangeAdapter2d(
|
||||||
|
target=chain.Conv2d_1,
|
||||||
|
channels=residual_block.out_channels,
|
||||||
|
embedding_dim=1280,
|
||||||
|
context_key="timestep_embedding",
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
range_adapter.inject(chain)
|
||||||
|
for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
|
||||||
|
block.append(module=ResidualAccumulator(n=n))
|
||||||
|
for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
|
||||||
|
block.insert(index=0, module=ResidualConcatenator(n=-n - 2))
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {
|
||||||
|
"unet": {"residuals": [0.0] * 10},
|
||||||
|
"diffusion": {"timestep": None, "time_ids": None, "pooled_text_embedding": None},
|
||||||
|
"range_adapter": {"timestep_embedding": None},
|
||||||
|
"sampling": {"shapes": []},
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||||
|
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
|
||||||
|
|
||||||
|
def set_timestep(self, timestep: Tensor) -> None:
|
||||||
|
self.set_context(context="diffusion", value={"timestep": timestep})
|
||||||
|
|
||||||
|
def set_time_ids(self, time_ids: Tensor) -> None:
|
||||||
|
self.set_context(context="diffusion", value={"time_ids": time_ids})
|
||||||
|
|
||||||
|
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
|
||||||
|
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
|
|
@ -0,0 +1,130 @@
|
||||||
|
from refiners.fluxion.layers import (
|
||||||
|
Passthrough,
|
||||||
|
Lambda,
|
||||||
|
Chain,
|
||||||
|
Concatenate,
|
||||||
|
UseContext,
|
||||||
|
SelfAttention,
|
||||||
|
SetContext,
|
||||||
|
Identity,
|
||||||
|
Parallel,
|
||||||
|
)
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
|
||||||
|
def __init__(self, target: SelfAttention, context: str) -> None:
|
||||||
|
self.context = context
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(SetContext(self.context, "norm"), target)
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: SelfAttention,
|
||||||
|
context: str,
|
||||||
|
sai: "SelfAttentionInjection",
|
||||||
|
) -> None:
|
||||||
|
self.context = context
|
||||||
|
self._sai = [sai] # only to support setting `style_cfg` dynamically
|
||||||
|
|
||||||
|
sa_guided = target.structural_copy()
|
||||||
|
assert isinstance(sa_guided[0], Parallel)
|
||||||
|
sa_guided.replace(
|
||||||
|
sa_guided[0],
|
||||||
|
Parallel(
|
||||||
|
Identity(),
|
||||||
|
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
|
||||||
|
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(
|
||||||
|
Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
|
||||||
|
Lambda(self.compute_averaged_unconditioned_x),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
|
||||||
|
style_cfg = self._sai[0].style_cfg
|
||||||
|
x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttentionInjection(Passthrough):
|
||||||
|
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
||||||
|
|
||||||
|
def __init__(self, unet: UNet, style_cfg: float = 0.5) -> None:
|
||||||
|
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
||||||
|
# This value is recommended to be 0.5 on the sdwebui repo.
|
||||||
|
self.style_cfg = style_cfg
|
||||||
|
self._adapters: list[ReferenceOnlyControlAdapter] = []
|
||||||
|
self._unet = [unet]
|
||||||
|
|
||||||
|
guide_unet = unet.structural_copy()
|
||||||
|
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
|
||||||
|
sa = attention_block.find(SelfAttention)
|
||||||
|
assert sa is not None and sa.parent is not None
|
||||||
|
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
|
||||||
|
|
||||||
|
for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
|
||||||
|
unet.set_context(f"self_attention_context_{i}", {"norm": None})
|
||||||
|
|
||||||
|
sa = attention_block.find(SelfAttention)
|
||||||
|
assert sa is not None and sa.parent is not None
|
||||||
|
|
||||||
|
self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
Lambda(self.copy_diffusion_context),
|
||||||
|
UseContext("self_attention_injection", "guide"),
|
||||||
|
guide_unet,
|
||||||
|
Lambda(self.restore_diffusion_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unet(self):
|
||||||
|
return self._unet[0]
|
||||||
|
|
||||||
|
def inject(self) -> None:
|
||||||
|
assert self not in self._unet[0], f"{self} is already injected"
|
||||||
|
for adapter in self._adapters:
|
||||||
|
adapter.inject()
|
||||||
|
self.unet.insert(0, self)
|
||||||
|
|
||||||
|
def eject(self) -> None:
|
||||||
|
assert self.unet[0] == self, f"{self} is not the first element of target UNet"
|
||||||
|
for adapter in self._adapters:
|
||||||
|
adapter.eject()
|
||||||
|
self.unet.pop(0)
|
||||||
|
|
||||||
|
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||||
|
self.set_context("self_attention_injection", {"guide": condition})
|
||||||
|
|
||||||
|
def copy_diffusion_context(self, x: Tensor) -> Tensor:
|
||||||
|
# This function allows to not disrupt the accumulation of residuals in the unet (if controlnet are used)
|
||||||
|
self.set_context(
|
||||||
|
"self_attention_residuals_buffer",
|
||||||
|
{"buffer": self.use_context("unet")["residuals"]},
|
||||||
|
)
|
||||||
|
self.set_context(
|
||||||
|
"unet",
|
||||||
|
{"residuals": [0.0] * 13},
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def restore_diffusion_context(self, x: Tensor) -> Tensor:
|
||||||
|
self.set_context(
|
||||||
|
"unet",
|
||||||
|
{
|
||||||
|
"residuals": self.use_context("self_attention_residuals_buffer")["buffer"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
|
||||||
|
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
|
307
src/refiners/foundationals/latent_diffusion/unet.py
Normal file
|
@ -0,0 +1,307 @@
|
||||||
|
from typing import cast, Iterable
|
||||||
|
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
|
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context_key: str = "timestep_embedding",
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
fl.UseContext("diffusion", "timestep"),
|
||||||
|
RangeEncoder(320, 1280, device=device, dtype=dtype),
|
||||||
|
fl.SetContext("range_adapter", context_key),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(fl.Sum):
|
||||||
|
structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_groups: int = 32,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
|
||||||
|
raise ValueError("Number of input and output channels must be divisible by num_groups.")
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.num_groups = num_groups
|
||||||
|
self.eps = eps
|
||||||
|
shortcut = (
|
||||||
|
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
|
||||||
|
if in_channels != out_channels
|
||||||
|
else fl.Identity()
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
fl.Chain(
|
||||||
|
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
shortcut,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPLCrossAttention(CrossAttentionBlock2d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
channels=channels,
|
||||||
|
context_embedding_dim=768,
|
||||||
|
context_key="clip_text_embedding",
|
||||||
|
num_attention_heads=8,
|
||||||
|
use_bias=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlocks(fl.Chain):
|
||||||
|
structural_attrs = ["in_channels"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
self.in_channels = in_channels
|
||||||
|
super().__init__(
|
||||||
|
fl.Chain(
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype)),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype)),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(fl.Downsample(channels=1280, scale_factor=2, padding=1, device=device, dtype=dtype)),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlocks(fl.Chain):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
fl.Upsample(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
fl.Upsample(channels=1280, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
|
||||||
|
fl.Upsample(channels=640, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MiddleBlock(fl.Chain):
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
|
||||||
|
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAccumulator(fl.Passthrough):
|
||||||
|
structural_attrs = ["n"]
|
||||||
|
|
||||||
|
def __init__(self, n: int) -> None:
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
fl.Residual(
|
||||||
|
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
|
||||||
|
),
|
||||||
|
fl.SetContext(context="unet", key="residuals", callback=self.update),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
|
||||||
|
residuals[self.n] = x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConcatenator(fl.Chain):
|
||||||
|
structural_attrs = ["n"]
|
||||||
|
|
||||||
|
def __init__(self, n: int) -> None:
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
fl.Concatenate(
|
||||||
|
fl.Identity(),
|
||||||
|
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UNet(fl.Chain):
|
||||||
|
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
clip_embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
):
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.clip_embedding_dim = clip_embedding_dim
|
||||||
|
super().__init__(
|
||||||
|
TimestepEncoder(device=device, dtype=dtype),
|
||||||
|
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
|
||||||
|
fl.Sum(
|
||||||
|
fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
|
||||||
|
MiddleBlock(device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
UpBlocks(),
|
||||||
|
fl.Chain(
|
||||||
|
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=320,
|
||||||
|
out_channels=4,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for residual_block in self.layers(ResidualBlock):
|
||||||
|
chain = residual_block.Chain
|
||||||
|
range_adapter = RangeAdapter2d(
|
||||||
|
target=chain.Conv2d_1,
|
||||||
|
channels=residual_block.out_channels,
|
||||||
|
embedding_dim=1280,
|
||||||
|
context_key="timestep_embedding",
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
range_adapter.inject(chain)
|
||||||
|
for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
|
||||||
|
block.append(ResidualAccumulator(n))
|
||||||
|
for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
|
||||||
|
block.insert(0, ResidualConcatenator(-n - 2))
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {
|
||||||
|
"unet": {"residuals": [0.0] * 13},
|
||||||
|
"diffusion": {"timestep": None},
|
||||||
|
"range_adapter": {"timestep_embedding": None},
|
||||||
|
"sampling": {"shapes": []},
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||||
|
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
|
||||||
|
|
||||||
|
def set_timestep(self, timestep: Tensor) -> None:
|
||||||
|
self.set_context("diffusion", {"timestep": timestep})
|
0
src/refiners/py.typed
Normal file
17
src/refiners/training_utils/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from importlib import import_module
|
||||||
|
from importlib.metadata import requires
|
||||||
|
import sys
|
||||||
|
|
||||||
|
refiners_requires = requires("refiners")
|
||||||
|
assert refiners_requires is not None
|
||||||
|
|
||||||
|
for dep in filter(lambda r: r.endswith('extra == "training"'), refiners_requires):
|
||||||
|
try:
|
||||||
|
import_module(dep.split(" ")[0])
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"
|
||||||
|
" refiners[training]`",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
186
src/refiners/training_utils/callback.py
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
from typing import TYPE_CHECKING, Generic, Iterable, Any, TypeVar
|
||||||
|
from torch import tensor
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Callback",
|
||||||
|
"GradientNormClipping",
|
||||||
|
"GradientValueClipping",
|
||||||
|
"ClockCallback",
|
||||||
|
"GradientNormLogging",
|
||||||
|
"MonitorLoss",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_norm(parameters: Iterable[Parameter], total_norm: float, clip_norm: float = 1.0) -> None:
|
||||||
|
"""
|
||||||
|
Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`.
|
||||||
|
"""
|
||||||
|
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to clip."
|
||||||
|
clip_coefficient = tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1)
|
||||||
|
for gradient in gradients:
|
||||||
|
gradient.mul_(other=clip_coefficient) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient_value(parameters: Iterable[Parameter], clip_value: float) -> None:
|
||||||
|
"""
|
||||||
|
Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`.
|
||||||
|
"""
|
||||||
|
gradients = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to clip."
|
||||||
|
for gradient in gradients:
|
||||||
|
gradient.clamp_(min=-clip_value, max=clip_value)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Callback(Generic[T]):
|
||||||
|
def on_train_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_train_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_epoch_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_batch_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_batch_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_backward_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_backward_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_optimizer_step_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_optimizer_step_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_compute_loss_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_compute_loss_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_evaluate_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_evaluate_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_lr_scheduler_step_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_lr_scheduler_step_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_checkpoint_save(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.reset()
|
||||||
|
logger.info(f"""Starting training for a total of:
|
||||||
|
{trainer.clock.num_steps} steps.
|
||||||
|
{trainer.clock.num_epochs} epochs.
|
||||||
|
{trainer.clock.num_iterations} iterations.
|
||||||
|
""")
|
||||||
|
trainer.clock.start_timer()
|
||||||
|
|
||||||
|
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.stop_timer()
|
||||||
|
logger.info(f"""Training took:
|
||||||
|
{trainer.clock.time_elapsed} seconds.
|
||||||
|
{trainer.clock.iteration} iterations.
|
||||||
|
{trainer.clock.epoch} epochs.
|
||||||
|
{trainer.clock.step} steps.
|
||||||
|
""")
|
||||||
|
|
||||||
|
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
logger.info(f"Epoch {trainer.clock.epoch} started.")
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.epoch += 1
|
||||||
|
trainer.clock.num_batches_processed = 0
|
||||||
|
|
||||||
|
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
logger.info(f"Step {trainer.clock.step} started.")
|
||||||
|
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.clock.step += 1
|
||||||
|
trainer.clock.num_batches_processed += 1
|
||||||
|
trainer.clock.num_minibatches_processed += 1
|
||||||
|
|
||||||
|
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
logger.info(f"Iteration {trainer.clock.iteration} ended.")
|
||||||
|
trainer.clock.iteration += 1
|
||||||
|
trainer.clock.num_minibatches_processed = 0
|
||||||
|
|
||||||
|
def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
logger.info("Evaluation started.")
|
||||||
|
|
||||||
|
def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
logger.info("Evaluation ended.")
|
||||||
|
|
||||||
|
|
||||||
|
class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
self.epoch_losses: list[float] = []
|
||||||
|
self.iteration_losses: list[float] = []
|
||||||
|
|
||||||
|
def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
loss_value = trainer.loss.detach().cpu().item()
|
||||||
|
self.epoch_losses.append(loss_value)
|
||||||
|
self.iteration_losses.append(loss_value)
|
||||||
|
trainer.log(data={"step_loss": loss_value})
|
||||||
|
|
||||||
|
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
||||||
|
trainer.log(data={"average_iteration_loss": avg_iteration_loss})
|
||||||
|
self.iteration_losses = []
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
|
||||||
|
trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
|
||||||
|
self.epoch_losses = []
|
||||||
|
|
||||||
|
def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
|
||||||
|
|
||||||
|
|
||||||
|
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
clip_norm = trainer.config.training.clip_grad_norm
|
||||||
|
if clip_norm is not None:
|
||||||
|
clip_gradient_norm(
|
||||||
|
parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
clip_value = trainer.config.training.clip_grad_value
|
||||||
|
if clip_value is not None:
|
||||||
|
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
|
trainer.log(data={"total_grad_norm": trainer.total_gradient_norm})
|
242
src/refiners/training_utils/config.py
Normal file
|
@ -0,0 +1,242 @@
|
||||||
|
from logging import warn
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
||||||
|
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||||
|
from torch.optim import AdamW, SGD, Optimizer, Adam
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from enum import Enum
|
||||||
|
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
import tomli
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from prodigyopt import Prodigy # type: ignore
|
||||||
|
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"parse_number_unit_field",
|
||||||
|
"TimeUnit",
|
||||||
|
"TimeValue",
|
||||||
|
"TrainingConfig",
|
||||||
|
"OptimizerConfig",
|
||||||
|
"Optimizers",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TimeUnit(Enum):
|
||||||
|
STEP = "step"
|
||||||
|
EPOCH = "epoch"
|
||||||
|
ITERATION = "iteration"
|
||||||
|
DEFAULT = "step"
|
||||||
|
|
||||||
|
|
||||||
|
class TimeValue(TypedDict):
|
||||||
|
number: int
|
||||||
|
unit: TimeUnit
|
||||||
|
|
||||||
|
|
||||||
|
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
||||||
|
match value:
|
||||||
|
case str(value_str):
|
||||||
|
number, unit = value_str.split(sep=":")
|
||||||
|
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
||||||
|
case int(number):
|
||||||
|
return {"number": number, "unit": TimeUnit.DEFAULT}
|
||||||
|
case {"number": int(number), "unit": str(unit)}:
|
||||||
|
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported value format: {value}")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseModel):
|
||||||
|
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||||
|
seed: int = 0
|
||||||
|
gpu_index: int = 0
|
||||||
|
batch_size: int = 1
|
||||||
|
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
||||||
|
clip_grad_norm: float | None = None
|
||||||
|
clip_grad_value: float | None = None
|
||||||
|
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||||
|
evaluation_seed: int = 0
|
||||||
|
|
||||||
|
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
|
||||||
|
def parse_field(cls, value: Any) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
class Optimizers(str, Enum):
|
||||||
|
SGD = "SGD"
|
||||||
|
Adam = "Adam"
|
||||||
|
AdamW = "AdamW"
|
||||||
|
AdamW8bit = "AdamW8bit"
|
||||||
|
Lion8bit = "Lion8bit"
|
||||||
|
Prodigy = "Prodigy"
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerType(str, Enum):
|
||||||
|
STEP_LR = "StepLR"
|
||||||
|
EXPONENTIAL_LR = "ExponentialLR"
|
||||||
|
REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
|
||||||
|
COSINE_ANNEALING_LR = "CosineAnnealingLR"
|
||||||
|
CONSTANT_LR = "ConstantLR" # not to be confused with PyTorch's ConstantLR
|
||||||
|
LAMBDA_LR = "LambdaLR"
|
||||||
|
ONE_CYCLE_LR = "OneCycleLR"
|
||||||
|
MULTIPLICATIVE_LR = "MultiplicativeLR"
|
||||||
|
COSINE_ANNEALING_WARM_RESTARTS = "CosineAnnealingWarmRestarts"
|
||||||
|
CYCLIC_LR = "CyclicLR"
|
||||||
|
MULTI_STEP_LR = "MultiStepLR"
|
||||||
|
DEFAULT = "ConstantLR"
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerConfig(BaseModel):
|
||||||
|
scheduler_type: SchedulerType = SchedulerType.DEFAULT
|
||||||
|
update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||||
|
warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
|
||||||
|
gamma: float = 0.1
|
||||||
|
lr_lambda: Callable[[int], float] | None = None
|
||||||
|
mode: Literal["min", "max"] = "min"
|
||||||
|
factor: float = 0.1
|
||||||
|
patience: int = 10
|
||||||
|
threshold: float = 1e-4
|
||||||
|
cooldown: int = 0
|
||||||
|
milestones: list[int] = []
|
||||||
|
base_lr: float = 1e-7
|
||||||
|
min_lr: float | list[float] = 0
|
||||||
|
max_lr: float | list[float] = 0
|
||||||
|
eta_min: float = 0
|
||||||
|
|
||||||
|
@validator("update_interval", "warmup", pre=True)
|
||||||
|
def parse_field(cls, value: Any) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerConfig(BaseModel):
|
||||||
|
optimizer: Optimizers
|
||||||
|
learning_rate: float = 1e-4
|
||||||
|
betas: tuple[float, float] = (0.9, 0.999)
|
||||||
|
eps: float = 1e-8
|
||||||
|
weight_decay: float = 0.0
|
||||||
|
|
||||||
|
def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
|
||||||
|
match self.optimizer:
|
||||||
|
case Optimizers.SGD:
|
||||||
|
return SGD(
|
||||||
|
params=model_parameters,
|
||||||
|
lr=self.learning_rate,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
case Optimizers.Adam:
|
||||||
|
return Adam(
|
||||||
|
params=model_parameters,
|
||||||
|
lr=self.learning_rate,
|
||||||
|
betas=self.betas,
|
||||||
|
eps=self.eps,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
case Optimizers.AdamW:
|
||||||
|
return AdamW(
|
||||||
|
params=model_parameters,
|
||||||
|
lr=self.learning_rate,
|
||||||
|
betas=self.betas,
|
||||||
|
eps=self.eps,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
case Optimizers.AdamW8bit:
|
||||||
|
return AdamW8bit(
|
||||||
|
params=model_parameters,
|
||||||
|
lr=self.learning_rate,
|
||||||
|
betas=self.betas,
|
||||||
|
eps=self.eps,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
case Optimizers.Lion8bit:
|
||||||
|
return Lion8bit(
|
||||||
|
params=model_parameters,
|
||||||
|
lr=self.learning_rate,
|
||||||
|
betas=self.betas,
|
||||||
|
weight_decay=self.weight_decay, # type: ignore
|
||||||
|
)
|
||||||
|
case Optimizers.Prodigy:
|
||||||
|
if self.learning_rate != 1.0:
|
||||||
|
warn("Prodigy learning rate is not 1.0, this might cause instability.")
|
||||||
|
return Prodigy(
|
||||||
|
lr=self.learning_rate,
|
||||||
|
params=model_parameters,
|
||||||
|
betas=self.betas,
|
||||||
|
weight_decay=self.weight_decay, # type: ignore
|
||||||
|
safeguard_warmup=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseModel):
|
||||||
|
checkpoint: Path | None = None
|
||||||
|
train: bool = True
|
||||||
|
learning_rate: float | None = None # TODO: Implement this
|
||||||
|
|
||||||
|
|
||||||
|
class GyroDropoutConfig(BaseModel):
|
||||||
|
total_subnetworks: int = 512
|
||||||
|
concurent_subnetworks: int = 64
|
||||||
|
iters_per_epoch: int = 512
|
||||||
|
num_features_threshold: float = 5e5
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutConfig(BaseModel):
|
||||||
|
dropout_probability: float = 0.0
|
||||||
|
gyro_dropout: GyroDropoutConfig | None = None
|
||||||
|
|
||||||
|
def apply_dropout(self, model: fl.Chain) -> None:
|
||||||
|
if self.dropout_probability > 0.0:
|
||||||
|
if self.gyro_dropout is not None:
|
||||||
|
apply_gyro_dropout(module=model, probability=self.dropout_probability, **self.gyro_dropout.model_dump())
|
||||||
|
else:
|
||||||
|
apply_dropout(module=model, probability=self.dropout_probability)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbConfig(BaseModel):
|
||||||
|
mode: Literal["online", "offline", "disabled"] = "online"
|
||||||
|
project: str
|
||||||
|
entity: str = "finegrain"
|
||||||
|
name: str | None = None
|
||||||
|
tags: list[str] = []
|
||||||
|
group: str | None = None
|
||||||
|
job_type: str | None = None
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceDatasetConfig(BaseModel):
|
||||||
|
hf_repo: str = "finegrain/unsplash-dummy"
|
||||||
|
revision: str = "main"
|
||||||
|
split: str = "train"
|
||||||
|
use_verification: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointingConfig(BaseModel):
|
||||||
|
save_folder: Path | None = None
|
||||||
|
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
|
||||||
|
|
||||||
|
@validator("save_interval", pre=True)
|
||||||
|
def parse_field(cls, value: Any) -> TimeValue:
|
||||||
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="BaseConfig")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConfig(BaseModel):
|
||||||
|
script: Path # TODO not used for now, but will be used by the cli
|
||||||
|
models: dict[str, ModelConfig]
|
||||||
|
wandb: WandbConfig
|
||||||
|
training: TrainingConfig
|
||||||
|
optimizer: OptimizerConfig
|
||||||
|
scheduler: SchedulerConfig
|
||||||
|
dropout: DropoutConfig
|
||||||
|
dataset: HuggingfaceDatasetConfig
|
||||||
|
checkpointing: CheckpointingConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
|
||||||
|
with open(file=toml_path, mode="rb") as f:
|
||||||
|
config_dict = tomli.load(f)
|
||||||
|
|
||||||
|
return cls(**config_dict)
|
202
src/refiners/training_utils/dropout.py
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
|
|
||||||
|
from torch import Tensor, randint, cat, rand
|
||||||
|
from torch.nn import Dropout as TorchDropout
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Dropout", "GyroDropout", "DropoutCallback"]
|
||||||
|
|
||||||
|
|
||||||
|
class Dropout(TorchDropout, fl.Module):
|
||||||
|
def __init__(self, probability: float = 0.5, inplace: bool = False) -> None:
|
||||||
|
super().__init__(p=probability, inplace=inplace)
|
||||||
|
|
||||||
|
|
||||||
|
class GyroDropout(fl.Module):
|
||||||
|
"""
|
||||||
|
GyroDropout is a variant of dropout that maximizes the ensemble effect during neural network training.
|
||||||
|
It pre-selects a fixed number of dropout masks and periodically selects a subset of them for training.
|
||||||
|
This leads to increased robustness and diversity among the subnetworks, improving accuracy compared to conventional
|
||||||
|
dropout.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
total_subnetworks:
|
||||||
|
The total number of pre-selected subnetworks ('Sigma'). These subnetworks are dropout masks
|
||||||
|
that are precomputed and stored.
|
||||||
|
|
||||||
|
concurrent_subnetworks:
|
||||||
|
The number of subnetworks to use concurrently in each forward pass ('Tau'). A random selection of
|
||||||
|
masks from the precomputed set is used to dropout different portions of the input.
|
||||||
|
|
||||||
|
dropout_probability: float, optional (default=0.5)
|
||||||
|
The probability that an element will be zeroed by the dropout.
|
||||||
|
|
||||||
|
iters_per_epoch:
|
||||||
|
Number of iterations per epoch, used to determine how often the masks should be updated.
|
||||||
|
|
||||||
|
num_features_threshold:
|
||||||
|
If the number of features in the input is greater than this threshold, dropout is skipped. This is because
|
||||||
|
gyro dropout mask size vram usage is proportional to the number of features in the input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
total_subnetworks: int,
|
||||||
|
concurrent_subnetworks: int,
|
||||||
|
dropout_probability: float = 0.5,
|
||||||
|
iters_per_epoch: int = 1,
|
||||||
|
num_features_threshold: float = 5e5,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
iters_per_epoch >= total_subnetworks
|
||||||
|
), "The number of iterations per epoch must be greater than the number of masks"
|
||||||
|
self.dropout_probability = dropout_probability
|
||||||
|
self.iters_per_epoch = iters_per_epoch
|
||||||
|
self.total_subnetworks = total_subnetworks
|
||||||
|
self.concurrent_subnetworks = concurrent_subnetworks
|
||||||
|
self.scale = 1 / (1 - self.dropout_probability)
|
||||||
|
self.mask_update_interval = int(self.iters_per_epoch / self.total_subnetworks) * self.concurrent_subnetworks
|
||||||
|
self.preselected_masks: Tensor | None = None
|
||||||
|
self.dropout_mask = None
|
||||||
|
self.training_step = 0
|
||||||
|
self.num_features_threshold = num_features_threshold
|
||||||
|
self.skip_high_num_features = False
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
if not self.training:
|
||||||
|
return x
|
||||||
|
if self.skip_high_num_features:
|
||||||
|
return self.basic_dropout(x)
|
||||||
|
if self.training_step == 0:
|
||||||
|
num_features = x.shape[1] * x.shape[2] if x.dim() == 3 else x.shape[1]
|
||||||
|
if num_features > self.num_features_threshold:
|
||||||
|
self.skip_high_num_features = True
|
||||||
|
self.basic_dropout = Dropout(probability=self.dropout_probability)
|
||||||
|
return self.basic_dropout(x)
|
||||||
|
self.init_masks(x=x)
|
||||||
|
|
||||||
|
if self.training_step % self.mask_update_interval == 0:
|
||||||
|
self.update_dropout_mask(x=x)
|
||||||
|
|
||||||
|
self.training_step += 1
|
||||||
|
|
||||||
|
return x * self.dropout_mask * self.scale
|
||||||
|
|
||||||
|
def init_masks(self, x: Tensor) -> None:
|
||||||
|
if x.dim() == 2:
|
||||||
|
self.preselected_masks = (
|
||||||
|
rand(self.total_subnetworks, x.shape[1], device=x.device) > self.dropout_probability
|
||||||
|
)
|
||||||
|
if x.dim() == 3:
|
||||||
|
self.preselected_masks = (
|
||||||
|
rand(self.total_subnetworks, x.shape[1], x.shape[2], device=x.device) > self.dropout_probability
|
||||||
|
)
|
||||||
|
|
||||||
|
assert self.preselected_masks is not None, "The input tensor must have 2 or 3 dimensions"
|
||||||
|
self.preselected_masks = self.preselected_masks.float()
|
||||||
|
|
||||||
|
def update_dropout_mask(self, x: Tensor) -> None:
|
||||||
|
assert self.preselected_masks is not None
|
||||||
|
indices = randint(low=0, high=self.total_subnetworks, size=(self.concurrent_subnetworks,), device=x.device)
|
||||||
|
selected_masks = self.preselected_masks[indices]
|
||||||
|
|
||||||
|
repeat_factor = x.shape[0] // self.concurrent_subnetworks
|
||||||
|
remaining = x.shape[0] % self.concurrent_subnetworks
|
||||||
|
repeated_masks = [selected_masks] * repeat_factor
|
||||||
|
if remaining > 0:
|
||||||
|
repeated_masks.append(selected_masks[:remaining])
|
||||||
|
final_masks = cat(tensors=repeated_masks, dim=0)
|
||||||
|
|
||||||
|
if x.dim() == 2:
|
||||||
|
self.dropout_mask = final_masks
|
||||||
|
if x.dim() == 3:
|
||||||
|
self.dropout_mask = final_masks.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
||||||
|
def __init__(self, target: fl.Linear, probability: float = 0.5):
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(target, Dropout(probability=probability))
|
||||||
|
|
||||||
|
|
||||||
|
class GyroDropoutAdapter(fl.Chain, Adapter[fl.Linear]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: fl.Linear,
|
||||||
|
probability: float = 0.5,
|
||||||
|
total_subnetworks: int = 512,
|
||||||
|
concurrent_subnetworks: int = 64,
|
||||||
|
iters_per_epoch: int = 512,
|
||||||
|
num_features_threshold: float = 5e5,
|
||||||
|
) -> None:
|
||||||
|
self.probability = probability
|
||||||
|
self.total_subnetworks = total_subnetworks
|
||||||
|
self.concurrent_subnetworks = concurrent_subnetworks
|
||||||
|
self.iters_per_epoch = iters_per_epoch
|
||||||
|
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(
|
||||||
|
target,
|
||||||
|
GyroDropout(
|
||||||
|
total_subnetworks=total_subnetworks,
|
||||||
|
concurrent_subnetworks=concurrent_subnetworks,
|
||||||
|
dropout_probability=probability,
|
||||||
|
iters_per_epoch=iters_per_epoch,
|
||||||
|
num_features_threshold=num_features_threshold,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
|
||||||
|
for linear, parent in module.walk(fl.Linear):
|
||||||
|
if not linear.weight.requires_grad:
|
||||||
|
continue
|
||||||
|
assert not (
|
||||||
|
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||||
|
), f"{linear} already has a dropout layer"
|
||||||
|
adapter = DropoutAdapter(target=linear, probability=probability)
|
||||||
|
adapter.inject(parent)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_gyro_dropout(
|
||||||
|
module: fl.Chain,
|
||||||
|
probability: float = 0.5,
|
||||||
|
total_subnetworks: int = 32,
|
||||||
|
concurrent_subnetworks: int = 16,
|
||||||
|
iters_per_epoch: int = 32,
|
||||||
|
) -> None:
|
||||||
|
for linear, parent in module.walk(fl.Linear):
|
||||||
|
if not linear.weight.requires_grad:
|
||||||
|
continue
|
||||||
|
assert not (
|
||||||
|
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||||
|
), f"{linear} already has a dropout layer"
|
||||||
|
adapter = GyroDropoutAdapter(
|
||||||
|
target=linear,
|
||||||
|
probability=probability,
|
||||||
|
total_subnetworks=total_subnetworks,
|
||||||
|
concurrent_subnetworks=concurrent_subnetworks,
|
||||||
|
iters_per_epoch=iters_per_epoch,
|
||||||
|
)
|
||||||
|
adapter.inject(parent)
|
||||||
|
|
||||||
|
|
||||||
|
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutCallback(Callback["Trainer[ConfigType, Any]"]):
|
||||||
|
def on_train_begin(self, trainer: "Trainer[ConfigType, Any]") -> None:
|
||||||
|
dropout_config = trainer.config.dropout
|
||||||
|
chain_models = [model for model in trainer.models.values() if isinstance(model, fl.Chain)]
|
||||||
|
for model in chain_models:
|
||||||
|
dropout_config.apply_dropout(model=model)
|
23
src/refiners/training_utils/huggingface_datasets.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
from datasets import load_dataset as _load_dataset, VerificationMode # type: ignore
|
||||||
|
from typing import Any, Generic, Protocol, TypeVar, cast
|
||||||
|
|
||||||
|
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceDataset(Generic[T], Protocol):
|
||||||
|
def __getitem__(self, index: int) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_dataset(
|
||||||
|
path: str, revision: str = "main", split: str = "train", use_verification: bool = False
|
||||||
|
) -> HuggingfaceDataset[Any]:
|
||||||
|
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
|
||||||
|
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
|
||||||
|
return cast(HuggingfaceDataset[Any], dataset)
|
238
src/refiners/training_utils/latent_diffusion.py
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, TypeVar, TypedDict, cast
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||||
|
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
|
||||||
|
from loguru import logger
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
from torchvision.transforms import RandomCrop # type: ignore
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from PIL import Image
|
||||||
|
from functools import cached_property
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers import DPMSolver
|
||||||
|
from torch.nn.functional import mse_loss
|
||||||
|
import random
|
||||||
|
from refiners.training_utils.wandb import WandbLoggable
|
||||||
|
from refiners.training_utils.trainer import Trainer
|
||||||
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusionConfig(BaseModel):
|
||||||
|
unconditional_sampling_probability: float = 0.2
|
||||||
|
offset_noise: float = 0.1
|
||||||
|
min_timestep: int = 0
|
||||||
|
max_timestep: int = 999
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiffusionConfig(BaseModel):
|
||||||
|
seed: int = 0
|
||||||
|
num_inference_steps: int = 30
|
||||||
|
use_short_prompts: bool = False
|
||||||
|
prompts: list[str] = []
|
||||||
|
num_images_per_prompt: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class FinetuneLatentDiffusionConfig(BaseConfig):
|
||||||
|
latent_diffusion: LatentDiffusionConfig
|
||||||
|
test_diffusion: TestDiffusionConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextEmbeddingLatentsBatch:
|
||||||
|
text_embeddings: Tensor
|
||||||
|
latents: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class CaptionImage(TypedDict):
|
||||||
|
caption: str
|
||||||
|
image: Image.Image
|
||||||
|
|
||||||
|
|
||||||
|
ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
||||||
|
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
|
||||||
|
self.trainer = trainer
|
||||||
|
self.config = trainer.config
|
||||||
|
self.device = self.trainer.device
|
||||||
|
self.lda = self.trainer.lda
|
||||||
|
self.text_encoder = self.trainer.text_encoder
|
||||||
|
self.dataset = self.load_huggingface_dataset()
|
||||||
|
self.process_image = RandomCrop(size=512) # TODO: make this configurable and add other transforms
|
||||||
|
logger.info(f"Loaded {len(self.dataset)} samples from dataset")
|
||||||
|
|
||||||
|
def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
|
||||||
|
dataset_config = self.config.dataset
|
||||||
|
logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
|
||||||
|
return cast(
|
||||||
|
HuggingfaceDataset[CaptionImage],
|
||||||
|
load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
|
||||||
|
)
|
||||||
|
|
||||||
|
def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
|
||||||
|
return resize_image(image=image, min_size=min_size, max_size=max_size)
|
||||||
|
|
||||||
|
def process_caption(self, caption: str) -> str:
|
||||||
|
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
|
||||||
|
item = self.dataset[index]
|
||||||
|
caption, image = item["caption"], item["image"]
|
||||||
|
resized_image = self.resize_image(image=image)
|
||||||
|
processed_image: Image.Image = self.process_image(resized_image)
|
||||||
|
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
||||||
|
processed_caption = self.process_caption(caption=caption)
|
||||||
|
clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
|
||||||
|
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
|
||||||
|
|
||||||
|
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
|
||||||
|
text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
|
||||||
|
latents = cat(tensors=[item.latents for item in batch])
|
||||||
|
return TextEmbeddingLatentsBatch(text_embeddings=text_embeddings, latents=latents)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.dataset)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
|
@cached_property
|
||||||
|
def unet(self) -> UNet:
|
||||||
|
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
|
||||||
|
return UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def text_encoder(self) -> CLIPTextEncoderL:
|
||||||
|
assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry."
|
||||||
|
return CLIPTextEncoderL(device=self.device).to(device=self.device)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def lda(self) -> LatentDiffusionAutoencoder:
|
||||||
|
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
|
||||||
|
return LatentDiffusionAutoencoder(device=self.device).to(device=self.device)
|
||||||
|
|
||||||
|
def load_models(self) -> dict[str, fl.Module]:
|
||||||
|
return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
|
||||||
|
|
||||||
|
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||||
|
return TextEmbeddingLatentsDataset(trainer=self)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def ddpm_scheduler(self) -> DDPM:
|
||||||
|
return DDPM(
|
||||||
|
num_inference_steps=1000,
|
||||||
|
device=self.device,
|
||||||
|
).to(device=self.device)
|
||||||
|
|
||||||
|
def sample_timestep(self) -> Tensor:
|
||||||
|
random_step = random.randint(
|
||||||
|
a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
|
||||||
|
)
|
||||||
|
self.current_step = random_step
|
||||||
|
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
||||||
|
|
||||||
|
def sample_noise(self, size: tuple[int, int, int, int], dtype: DType | None = None) -> Tensor:
|
||||||
|
return sample_noise(
|
||||||
|
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
|
||||||
|
clip_text_embedding, latents = batch.text_embeddings, batch.latents
|
||||||
|
timestep = self.sample_timestep()
|
||||||
|
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
|
||||||
|
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)
|
||||||
|
self.unet.set_timestep(timestep=timestep)
|
||||||
|
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||||
|
prediction = self.unet(noisy_latents)
|
||||||
|
loss = mse_loss(input=prediction, target=noise)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def compute_evaluation(self) -> None:
|
||||||
|
sd = StableDiffusion_1(
|
||||||
|
unet=self.unet,
|
||||||
|
lda=self.lda,
|
||||||
|
clip_text_encoder=self.text_encoder,
|
||||||
|
scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
prompts = self.config.test_diffusion.prompts
|
||||||
|
num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt
|
||||||
|
if self.config.test_diffusion.use_short_prompts:
|
||||||
|
prompts = [prompt.split(sep=",")[0] for prompt in prompts]
|
||||||
|
images: dict[str, WandbLoggable] = {}
|
||||||
|
for prompt in prompts:
|
||||||
|
canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
|
||||||
|
for i in range(num_images_per_prompt):
|
||||||
|
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
|
||||||
|
x = randn(1, 4, 64, 64, device=self.device)
|
||||||
|
clip_text_embedding = sd.compute_text_embedding(text=prompt).to(device=self.device)
|
||||||
|
negative_clip_text_embedding = sd.compute_text_embedding(text="").to(device=self.device)
|
||||||
|
for step in sd.steps:
|
||||||
|
x = sd(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
)
|
||||||
|
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
|
||||||
|
images[prompt] = canvas_image
|
||||||
|
self.log(data=images)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_noise(
|
||||||
|
size: tuple[int, int, int, int],
|
||||||
|
offset_noise: float = 0.1,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType | None = None,
|
||||||
|
generator: Generator | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Sample noise from a normal distribution.
|
||||||
|
|
||||||
|
If `offset_noise` is more than 0, the noise will be offset by a small amount. It allows the model to generate
|
||||||
|
images with a wider range of contrast https://www.crosslabs.org/blog/diffusion-with-offset-noise.
|
||||||
|
"""
|
||||||
|
device = Device(device)
|
||||||
|
noise = randn(*size, generator=generator, device=device, dtype=dtype)
|
||||||
|
return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
|
||||||
|
image_min_size = min(image.size)
|
||||||
|
if image_min_size > max_size:
|
||||||
|
if image_min_size == image.size[0]:
|
||||||
|
image = image.resize(size=(max_size, int(max_size * image.size[1] / image.size[0])))
|
||||||
|
else:
|
||||||
|
image = image.resize(size=(int(max_size * image.size[0] / image.size[1]), max_size))
|
||||||
|
if image_min_size < min_size:
|
||||||
|
if image_min_size == image.size[0]:
|
||||||
|
image = image.resize(size=(min_size, int(min_size * image.size[1] / image.size[0])))
|
||||||
|
else:
|
||||||
|
image = image.resize(size=(int(min_size * image.size[0] / image.size[1]), min_size))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
|
||||||
|
def on_train_begin(self, trainer: LatentDiffusionTrainer[Any]) -> None:
|
||||||
|
self.timestep_bins: dict[int, list[float]] = {i: [] for i in range(10)}
|
||||||
|
|
||||||
|
def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
|
||||||
|
loss_value = trainer.loss.detach().cpu().item()
|
||||||
|
current_step = trainer.current_step
|
||||||
|
bin_index = min(current_step // 100, 9)
|
||||||
|
self.timestep_bins[bin_index].append(loss_value)
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
|
||||||
|
log_data = {}
|
||||||
|
for bin_index, losses in self.timestep_bins.items():
|
||||||
|
if losses:
|
||||||
|
avg_loss = sum(losses) / len(losses)
|
||||||
|
log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
|
||||||
|
self.timestep_bins[bin_index] = []
|
||||||
|
|
||||||
|
trainer.log(data=log_data)
|
546
src/refiners/training_utils/trainer.py
Normal file
|
@ -0,0 +1,546 @@
|
||||||
|
from functools import cached_property, wraps
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from torch import device as Device, Tensor, get_rng_state, no_grad, set_rng_state, cuda, stack
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torch.autograd import backward
|
||||||
|
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||||
|
from loguru import logger
|
||||||
|
from refiners.fluxion import layers as fl
|
||||||
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
from refiners.training_utils.wandb import WandbLogger, WandbLoggable
|
||||||
|
from refiners.training_utils.config import BaseConfig, TimeUnit, TimeValue, SchedulerType
|
||||||
|
from refiners.training_utils.dropout import DropoutCallback
|
||||||
|
from refiners.training_utils.callback import (
|
||||||
|
Callback,
|
||||||
|
ClockCallback,
|
||||||
|
GradientNormClipping,
|
||||||
|
GradientValueClipping,
|
||||||
|
GradientNormLogging,
|
||||||
|
MonitorLoss,
|
||||||
|
)
|
||||||
|
from torch.optim.lr_scheduler import (
|
||||||
|
StepLR,
|
||||||
|
ExponentialLR,
|
||||||
|
ReduceLROnPlateau,
|
||||||
|
CosineAnnealingLR,
|
||||||
|
LambdaLR,
|
||||||
|
OneCycleLR,
|
||||||
|
LRScheduler,
|
||||||
|
MultiplicativeLR,
|
||||||
|
CosineAnnealingWarmRestarts,
|
||||||
|
CyclicLR,
|
||||||
|
MultiStepLR,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
|
||||||
|
|
||||||
|
|
||||||
|
def count_learnable_parameters(parameters: Iterable[Parameter]) -> int:
|
||||||
|
return sum(p.numel() for p in parameters if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
def human_readable_number(number: int) -> str:
|
||||||
|
float_number = float(number)
|
||||||
|
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||||
|
if abs(float_number) < 1000:
|
||||||
|
return f"{float_number:.1f}{unit}"
|
||||||
|
float_number /= 1000
|
||||||
|
return f"{float_number:.1f}E"
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed: int | None = None) -> None:
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
logger.info(f"Using random seed: {seed}")
|
||||||
|
random.seed(a=seed)
|
||||||
|
np.random.seed(seed=seed)
|
||||||
|
manual_seed(seed=seed)
|
||||||
|
cuda.manual_seed_all(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
|
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
||||||
|
"""
|
||||||
|
Decorator for setting a random seed within the scope of a function.
|
||||||
|
|
||||||
|
This decorator sets the random seed for Python's built-in `random` module,
|
||||||
|
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
||||||
|
function is executed, it restores the state of the random number generators
|
||||||
|
to what it was before the function was called. This is useful for ensuring
|
||||||
|
reproducibility for specific parts of the code without affecting randomness
|
||||||
|
elsewhere.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
@wraps(func)
|
||||||
|
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
random_state = random.getstate()
|
||||||
|
numpy_state = np.random.get_state()
|
||||||
|
torch_state = get_rng_state()
|
||||||
|
cuda_torch_state = cuda.get_rng_state()
|
||||||
|
actual_seed = seed(*args) if callable(seed) else seed
|
||||||
|
seed_everything(seed=actual_seed)
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
random.setstate(random_state)
|
||||||
|
np.random.set_state(numpy_state)
|
||||||
|
set_rng_state(torch_state)
|
||||||
|
cuda.set_rng_state(cuda_torch_state)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return inner_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupScheduler(LRScheduler):
|
||||||
|
_step_count: int # defined by LRScheduler
|
||||||
|
|
||||||
|
def __init__(self, optimizer: Optimizer, scheduler: LRScheduler, warmup_steps: int = 0) -> None:
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.scheduler = scheduler
|
||||||
|
super().__init__(optimizer=optimizer)
|
||||||
|
|
||||||
|
def get_lr(self) -> list[float] | float: # type: ignore
|
||||||
|
if self._step_count < self.warmup_steps:
|
||||||
|
return [base_lr * self._step_count / self.warmup_steps for base_lr in self.base_lrs]
|
||||||
|
return self.scheduler.get_lr()
|
||||||
|
|
||||||
|
def step(self, epoch: int | None = None) -> None:
|
||||||
|
if self._step_count < self.warmup_steps:
|
||||||
|
super().step()
|
||||||
|
else:
|
||||||
|
self.scheduler.step(epoch=epoch)
|
||||||
|
self._step_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingClock:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_length: int,
|
||||||
|
batch_size: int,
|
||||||
|
training_duration: TimeValue,
|
||||||
|
gradient_accumulation: TimeValue,
|
||||||
|
evaluation_interval: TimeValue,
|
||||||
|
lr_scheduler_interval: TimeValue,
|
||||||
|
checkpointing_save_interval: TimeValue,
|
||||||
|
) -> None:
|
||||||
|
self.dataset_length = dataset_length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.training_duration = training_duration
|
||||||
|
self.gradient_accumulation = gradient_accumulation
|
||||||
|
self.evaluation_interval = evaluation_interval
|
||||||
|
self.lr_scheduler_interval = lr_scheduler_interval
|
||||||
|
self.checkpointing_save_interval = checkpointing_save_interval
|
||||||
|
self.num_batches_per_epoch = dataset_length // batch_size
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self.step = 0
|
||||||
|
self.epoch = 0
|
||||||
|
self.iteration = 0
|
||||||
|
self.num_batches_processed = 0
|
||||||
|
self.num_minibatches_processed = 0
|
||||||
|
self.loss: Tensor | None = None
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||||
|
return {
|
||||||
|
TimeUnit.STEP: 1,
|
||||||
|
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||||
|
TimeUnit.ITERATION: self.gradient_accumulation["number"] * {
|
||||||
|
TimeUnit.STEP: 1,
|
||||||
|
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||||
|
}.get(self.gradient_accumulation["unit"], 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||||
|
return number * self.unit_to_steps[unit]
|
||||||
|
|
||||||
|
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||||
|
return steps // self.unit_to_steps[unit]
|
||||||
|
|
||||||
|
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||||
|
number, unit = time_value["number"], time_value["unit"]
|
||||||
|
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
||||||
|
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_epochs(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_iterations(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_steps(self) -> int:
|
||||||
|
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_step_per_iteration(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def num_step_per_evaluation(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self.step = 0
|
||||||
|
self.epoch = 0
|
||||||
|
self.iteration = 0
|
||||||
|
self.num_batches_processed = 0
|
||||||
|
self.num_minibatches_processed = 0
|
||||||
|
|
||||||
|
def start_timer(self) -> None:
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def stop_timer(self) -> None:
|
||||||
|
self.end_time = time.time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_elapsed(self) -> int:
|
||||||
|
assert self.start_time is not None, "Timer has not been started yet."
|
||||||
|
return int(time.time() - self.start_time)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def evalution_interval_steps(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def lr_scheduler_interval_steps(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def checkpointing_save_interval_steps(self) -> int:
|
||||||
|
return self.convert_time_unit_to_steps(
|
||||||
|
number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_optimizer_step(self) -> bool:
|
||||||
|
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_lr_scheduler_step(self) -> bool:
|
||||||
|
return self.step % self.lr_scheduler_interval_steps == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self) -> bool:
|
||||||
|
return self.step >= self.num_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_evaluation_step(self) -> bool:
|
||||||
|
return self.step % self.evalution_interval_steps == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_checkpointing_step(self) -> bool:
|
||||||
|
return self.step % self.checkpointing_save_interval_steps == 0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
|
||||||
|
"""
|
||||||
|
Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
|
||||||
|
"""
|
||||||
|
gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
|
||||||
|
assert gradients, "The model has no gradients to compute the norm."
|
||||||
|
total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
|
||||||
|
return total_norm # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
Batch = TypeVar("Batch")
|
||||||
|
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer(Generic[ConfigType, Batch]):
|
||||||
|
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.clock = TrainingClock(
|
||||||
|
dataset_length=self.dataset_length,
|
||||||
|
batch_size=config.training.batch_size,
|
||||||
|
training_duration=config.training.duration,
|
||||||
|
evaluation_interval=config.training.evaluation_interval,
|
||||||
|
gradient_accumulation=config.training.gradient_accumulation,
|
||||||
|
lr_scheduler_interval=config.scheduler.update_interval,
|
||||||
|
checkpointing_save_interval=config.checkpointing.save_interval,
|
||||||
|
)
|
||||||
|
self.callbacks = callbacks or []
|
||||||
|
self.callbacks += self.default_callbacks()
|
||||||
|
self.load_wandb()
|
||||||
|
self.load_models()
|
||||||
|
self.prepare_models()
|
||||||
|
self.prepare_checkpointing()
|
||||||
|
|
||||||
|
def default_callbacks(self) -> list[Callback[Any]]:
|
||||||
|
return [
|
||||||
|
ClockCallback(),
|
||||||
|
MonitorLoss(),
|
||||||
|
GradientNormLogging(),
|
||||||
|
GradientValueClipping(),
|
||||||
|
GradientNormClipping(),
|
||||||
|
DropoutCallback(),
|
||||||
|
]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def device(self) -> Device:
|
||||||
|
selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
|
||||||
|
logger.info(f"Using device: {selected_device}")
|
||||||
|
return selected_device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> list[Parameter]:
|
||||||
|
"""Returns a list of all parameters in all models"""
|
||||||
|
return [param for model in self.models.values() for param in model.parameters()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def learnable_parameters(self) -> list[Parameter]:
|
||||||
|
"""Returns a list of learnable parameters in all models"""
|
||||||
|
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def learnable_parameter_count(self) -> int:
|
||||||
|
"""Returns the number of learnable parameters in all models"""
|
||||||
|
return count_learnable_parameters(parameters=self.learnable_parameters)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gradients(self) -> list[Tensor]:
|
||||||
|
"""Returns a list of detached gradients for all learnable parameters in all models"""
|
||||||
|
return [
|
||||||
|
param.grad.detach()
|
||||||
|
for model in self.models.values()
|
||||||
|
for param in model.parameters()
|
||||||
|
if param.grad is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_gradient_norm(self) -> float:
|
||||||
|
"""Returns the total gradient norm for all learnable parameters in all models"""
|
||||||
|
return compute_grad_norm(parameters=self.parameters)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def optimizer(self) -> Optimizer:
|
||||||
|
formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
|
||||||
|
logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
|
||||||
|
optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def lr_scheduler(self) -> LRScheduler:
|
||||||
|
config = self.config.scheduler
|
||||||
|
step_size = self.clock.convert_time_unit_to_steps(
|
||||||
|
number=config.update_interval["number"], unit=config.update_interval["unit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
match config.scheduler_type:
|
||||||
|
case SchedulerType.CONSTANT_LR:
|
||||||
|
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lambda _: 1.0)
|
||||||
|
case SchedulerType.STEP_LR:
|
||||||
|
lr_scheduler = StepLR(optimizer=self.optimizer, step_size=step_size, gamma=config.gamma)
|
||||||
|
case SchedulerType.EXPONENTIAL_LR:
|
||||||
|
lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
|
||||||
|
case SchedulerType.COSINE_ANNEALING_LR:
|
||||||
|
lr_scheduler = CosineAnnealingLR(optimizer=self.optimizer, T_max=step_size, eta_min=config.eta_min)
|
||||||
|
case SchedulerType.REDUCE_LR_ON_PLATEAU:
|
||||||
|
lr_scheduler = cast(
|
||||||
|
LRScheduler,
|
||||||
|
ReduceLROnPlateau(
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
mode=config.mode,
|
||||||
|
factor=config.factor,
|
||||||
|
patience=config.patience,
|
||||||
|
threshold=config.threshold,
|
||||||
|
cooldown=config.cooldown,
|
||||||
|
min_lr=config.min_lr,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case SchedulerType.LAMBDA_LR:
|
||||||
|
assert config.lr_lambda is not None, "lr_lambda must be specified to use LambdaLR"
|
||||||
|
lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
||||||
|
case SchedulerType.ONE_CYCLE_LR:
|
||||||
|
lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=config.max_lr, total_steps=step_size)
|
||||||
|
case SchedulerType.MULTIPLICATIVE_LR:
|
||||||
|
assert config.lr_lambda is not None, "lr_lambda must be specified to use MultiplicativeLR"
|
||||||
|
lr_scheduler = MultiplicativeLR(optimizer=self.optimizer, lr_lambda=config.lr_lambda)
|
||||||
|
case SchedulerType.COSINE_ANNEALING_WARM_RESTARTS:
|
||||||
|
lr_scheduler = CosineAnnealingWarmRestarts(optimizer=self.optimizer, T_0=step_size)
|
||||||
|
case SchedulerType.CYCLIC_LR:
|
||||||
|
lr_scheduler = CyclicLR(optimizer=self.optimizer, base_lr=config.base_lr, max_lr=config.max_lr)
|
||||||
|
case SchedulerType.MULTI_STEP_LR:
|
||||||
|
lr_scheduler = MultiStepLR(optimizer=self.optimizer, milestones=config.milestones, gamma=config.gamma)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
|
||||||
|
|
||||||
|
warmup_steps = self.clock.convert_time_unit_to_steps(number=config.warmup["number"], unit=config.warmup["unit"])
|
||||||
|
if warmup_steps > 0:
|
||||||
|
lr_scheduler = WarmupScheduler(
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
scheduler=lr_scheduler,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def models(self) -> dict[str, fl.Module]:
|
||||||
|
return self.load_models()
|
||||||
|
|
||||||
|
def set_models_to_train_mode(self) -> None:
|
||||||
|
for model in self.models.values():
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
def set_models_to_eval_mode(self) -> None:
|
||||||
|
for model in self.models.values():
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
def log(self, data: dict[str, WandbLoggable]) -> None:
|
||||||
|
self.wandb.log(data=data, step=self.clock.step)
|
||||||
|
|
||||||
|
def load_wandb(self) -> None:
|
||||||
|
init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()}
|
||||||
|
self.wandb = WandbLogger(init_config=init_config)
|
||||||
|
|
||||||
|
def prepare_model(self, model_name: str) -> None:
|
||||||
|
model = self.models[model_name]
|
||||||
|
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
|
||||||
|
model.load_from_safetensors(tensors_path=checkpoint)
|
||||||
|
else:
|
||||||
|
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
|
||||||
|
model.requires_grad_(requires_grad=self.config.models[model_name].train)
|
||||||
|
model.to(self.device)
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
def prepare_models(self) -> None:
|
||||||
|
assert self.models, "No models found."
|
||||||
|
for model_name in self.models:
|
||||||
|
self.prepare_model(model_name=model_name)
|
||||||
|
|
||||||
|
def prepare_checkpointing(self) -> None:
|
||||||
|
if self.config.checkpointing.save_folder is not None:
|
||||||
|
assert self.config.checkpointing.save_folder.is_dir()
|
||||||
|
self.checkpoints_save_folder = (
|
||||||
|
self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name
|
||||||
|
)
|
||||||
|
self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
|
||||||
|
logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
|
||||||
|
else:
|
||||||
|
self.checkpoints_save_folder = None
|
||||||
|
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
|
||||||
|
|
||||||
|
def load_models(self) -> dict[str, fl.Module]:
|
||||||
|
raise NotImplementedError("The `load_models` method must be implemented in the subclass.")
|
||||||
|
|
||||||
|
def load_dataset(self) -> Dataset[Batch]:
|
||||||
|
raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def dataset(self) -> Dataset[Batch]:
|
||||||
|
return self.load_dataset()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def dataset_length(self) -> int:
|
||||||
|
assert hasattr(self.dataset, "__len__"), "The dataset must implement the `__len__` method."
|
||||||
|
return len(self.dataset) # type: ignore
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def dataloader(self) -> DataLoader[Batch]:
|
||||||
|
collate_fn = getattr(self.dataset, "collate_fn", None)
|
||||||
|
return DataLoader(
|
||||||
|
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def checkpointing_enabled(self) -> bool:
|
||||||
|
return self.checkpoints_save_folder is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ensure_checkpoints_save_folder(self) -> Path:
|
||||||
|
assert self.checkpoints_save_folder is not None
|
||||||
|
return self.checkpoints_save_folder
|
||||||
|
|
||||||
|
def compute_loss(self, batch: Batch) -> Tensor:
|
||||||
|
raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.")
|
||||||
|
|
||||||
|
def compute_evaluation(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def backward(self) -> None:
|
||||||
|
"""Backward pass on the loss."""
|
||||||
|
self._call_callbacks(event_name="on_backward_begin")
|
||||||
|
scaled_loss = self.loss / self.clock.num_step_per_iteration
|
||||||
|
backward(tensors=scaled_loss)
|
||||||
|
self._call_callbacks(event_name="on_backward_end")
|
||||||
|
if self.clock.is_optimizer_step:
|
||||||
|
self._call_callbacks(event_name="on_optimizer_step_begin")
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
self._call_callbacks(event_name="on_optimizer_step_end")
|
||||||
|
if self.clock.is_lr_scheduler_step:
|
||||||
|
self._call_callbacks(event_name="on_lr_scheduler_step_begin")
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
||||||
|
if self.clock.is_evaluation_step:
|
||||||
|
self.evaluate()
|
||||||
|
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
|
||||||
|
self._call_callbacks(event_name="on_checkpoint_save")
|
||||||
|
|
||||||
|
def step(self, batch: Batch) -> None:
|
||||||
|
"""Perform a single training step."""
|
||||||
|
self._call_callbacks(event_name="on_compute_loss_begin")
|
||||||
|
loss = self.compute_loss(batch=batch)
|
||||||
|
self.loss = loss
|
||||||
|
self._call_callbacks(event_name="on_compute_loss_end")
|
||||||
|
self.backward()
|
||||||
|
|
||||||
|
def epoch(self) -> None:
|
||||||
|
"""Perform a single epoch."""
|
||||||
|
for batch in self.dataloader:
|
||||||
|
self._call_callbacks(event_name="on_batch_begin")
|
||||||
|
self.step(batch=batch)
|
||||||
|
self._call_callbacks(event_name="on_batch_end")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:
|
||||||
|
return instance.config.training.seed
|
||||||
|
|
||||||
|
@scoped_seed(seed=get_training_seed)
|
||||||
|
def train(self) -> None:
|
||||||
|
"""Train the model."""
|
||||||
|
self.set_models_to_train_mode()
|
||||||
|
self._call_callbacks(event_name="on_train_begin")
|
||||||
|
assert self.learnable_parameters, "There are no learnable parameters in the models."
|
||||||
|
self.evaluate()
|
||||||
|
while not self.clock.done:
|
||||||
|
self._call_callbacks(event_name="on_epoch_begin")
|
||||||
|
self.epoch()
|
||||||
|
self._call_callbacks(event_name="on_epoch_end")
|
||||||
|
self._call_callbacks(event_name="on_train_end")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_evaluation_seed(instance: "Trainer[BaseConfig, Any]") -> int:
|
||||||
|
return instance.config.training.evaluation_seed
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
@scoped_seed(seed=get_evaluation_seed)
|
||||||
|
def evaluate(self) -> None:
|
||||||
|
"""Evaluate the model."""
|
||||||
|
self.set_models_to_eval_mode()
|
||||||
|
self._call_callbacks(event_name="on_evaluate_begin")
|
||||||
|
self.compute_evaluation()
|
||||||
|
self._call_callbacks(event_name="on_evaluate_end")
|
||||||
|
self.set_models_to_train_mode()
|
||||||
|
|
||||||
|
def _call_callbacks(self, event_name: str) -> None:
|
||||||
|
for callback in self.callbacks:
|
||||||
|
getattr(callback, event_name)(self)
|
61
src/refiners/training_utils/wandb.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from typing import Any
|
||||||
|
import wandb
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"WandbLogger",
|
||||||
|
"WandbLoggable",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
number = float | int
|
||||||
|
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_wandb(value: WandbLoggable) -> Any:
|
||||||
|
match value:
|
||||||
|
case Image.Image():
|
||||||
|
return convert_to_wandb_image(value=value)
|
||||||
|
case list():
|
||||||
|
return convert_to_wandb_histogram(value=value)
|
||||||
|
case dict():
|
||||||
|
return convert_to_wandb_table(value=value)
|
||||||
|
case _:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_wandb_image(value: Image.Image) -> wandb.Image:
|
||||||
|
return wandb.Image(data_or_path=value)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_wandb_histogram(value: list[number]) -> wandb.Histogram:
|
||||||
|
return wandb.Histogram(sequence=value)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_wandb_table(value: dict[str, list[number]]) -> wandb.Table:
|
||||||
|
assert all(
|
||||||
|
isinstance(v, list) and len(v) == len(next(iter(value.values()))) for v in value.values()
|
||||||
|
), "Expected a dictionary of lists of the same size"
|
||||||
|
columns = list(value.keys())
|
||||||
|
data_rows = list(zip(*value.values()))
|
||||||
|
return wandb.Table(columns=columns, data=data_rows)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbLogger:
|
||||||
|
def __init__(self, init_config: dict[str, Any] = {}) -> None:
|
||||||
|
self.wandb_run = wandb.init(**init_config) # type: ignore
|
||||||
|
|
||||||
|
def log(self, data: dict[str, WandbLoggable], step: int) -> None:
|
||||||
|
converted_data = {key: convert_to_wandb(value=value) for key, value in data.items()}
|
||||||
|
self.wandb_run.log(converted_data, step=step) # type: ignore
|
||||||
|
|
||||||
|
def update_summary(self, key: str, value: Any) -> None:
|
||||||
|
self.wandb_run.summary[key] = value # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def project_name(self) -> str:
|
||||||
|
return self.wandb_run.project_name() # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_name(self) -> str:
|
||||||
|
return self.wandb_run.name or "" # type: ignore
|
0
tests/__init__.py
Normal file
82
tests/adapters/test_adapter.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import pytest
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
from refiners.fluxion.layers import Chain, Linear
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
|
||||||
|
def __init__(self, target: Linear):
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(target)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChainAdapter(Chain, Adapter[Chain]):
|
||||||
|
def __init__(self, target: Chain):
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(target)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chain() -> Chain:
|
||||||
|
return Chain(Chain(Linear(2, 2)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_module_adapter_insertion(chain: Chain):
|
||||||
|
parent = chain.Chain
|
||||||
|
adaptee = parent.Linear
|
||||||
|
|
||||||
|
adapter = DummyLinearAdapter(adaptee)
|
||||||
|
|
||||||
|
adapter.inject(parent)
|
||||||
|
assert adapter.parent == parent
|
||||||
|
assert adapter in iter(parent)
|
||||||
|
assert adaptee not in iter(parent)
|
||||||
|
|
||||||
|
adapter.eject()
|
||||||
|
assert adapter.parent is None
|
||||||
|
assert adapter not in iter(parent)
|
||||||
|
assert adaptee in iter(parent)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_adapter_insertion(chain: Chain):
|
||||||
|
parent = chain
|
||||||
|
adaptee = parent.Chain
|
||||||
|
|
||||||
|
adapter = DummyChainAdapter(adaptee)
|
||||||
|
assert adaptee.parent == parent
|
||||||
|
|
||||||
|
adapter.inject()
|
||||||
|
assert adapter.parent == parent
|
||||||
|
assert adaptee.parent == adapter
|
||||||
|
assert adapter in iter(parent)
|
||||||
|
assert adaptee not in iter(parent)
|
||||||
|
|
||||||
|
adapter.eject()
|
||||||
|
assert adapter.parent is None
|
||||||
|
assert adaptee.parent == parent
|
||||||
|
assert adapter not in iter(parent)
|
||||||
|
assert adaptee in iter(parent)
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_module_adapter_structural_copy(chain: Chain):
|
||||||
|
parent = chain.Chain
|
||||||
|
adaptee = parent.Linear
|
||||||
|
|
||||||
|
adapter = DummyLinearAdapter(adaptee)
|
||||||
|
adapter.inject(parent)
|
||||||
|
|
||||||
|
clone = chain.structural_copy()
|
||||||
|
cloned_adapter = clone.Chain.DummyLinearAdapter
|
||||||
|
assert cloned_adapter.parent == clone.Chain
|
||||||
|
assert cloned_adapter.target == adaptee
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_adapter_structural_copy(chain: Chain):
|
||||||
|
# Chain adapters cannot be copied by default.
|
||||||
|
adapter = DummyChainAdapter(chain.Chain)
|
||||||
|
adapter.inject()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
chain.structural_copy()
|
||||||
|
|
||||||
|
adapter.eject()
|
||||||
|
chain.structural_copy()
|
29
tests/adapters/test_lora.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
from refiners.adapters.lora import Lora, LoraAdapter
|
||||||
|
from torch import randn, allclose
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Chain(
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
),
|
||||||
|
fl.Linear(in_features=1, out_features=2),
|
||||||
|
)
|
||||||
|
x = randn(1, 1)
|
||||||
|
y = chain(x)
|
||||||
|
|
||||||
|
lora_adapter = LoraAdapter(chain.Chain.Linear_1)
|
||||||
|
lora_adapter.inject(chain.Chain)
|
||||||
|
|
||||||
|
assert isinstance(lora_adapter[1], Lora)
|
||||||
|
assert allclose(input=chain(x), other=y)
|
||||||
|
assert lora_adapter.parent == chain.Chain
|
||||||
|
|
||||||
|
lora_adapter.eject()
|
||||||
|
assert isinstance(chain.Chain[0], fl.Linear)
|
||||||
|
assert len(chain) == 2
|
||||||
|
|
||||||
|
lora_adapter.inject(chain.Chain)
|
||||||
|
assert isinstance(chain.Chain[0], LoraAdapter)
|
25
tests/adapters/test_range_adapter.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
import torch
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
from refiners.adapters.range_adapter import RangeEncoder
|
||||||
|
from refiners.fluxion.layers import Chain, Linear
|
||||||
|
|
||||||
|
|
||||||
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
|
||||||
|
def __init__(self, target: Linear):
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-433
|
||||||
|
dtype = torch.float64
|
||||||
|
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
|
||||||
|
|
||||||
|
adaptee = chain.RangeEncoder.Linear_1
|
||||||
|
adapter = DummyLinearAdapter(adaptee)
|
||||||
|
adapter.inject(chain.RangeEncoder)
|
||||||
|
|
||||||
|
assert adapter.parent == chain.RangeEncoder
|
||||||
|
|
||||||
|
x = torch.tensor([42], device=test_device)
|
||||||
|
y = chain(x)
|
||||||
|
assert y.dtype == dtype
|
25
tests/conftest.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from pytest import fixture
|
||||||
|
|
||||||
|
PARENT_PATH = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_device() -> torch.device:
|
||||||
|
test_device = os.getenv("REFINERS_TEST_DEVICE")
|
||||||
|
if not test_device:
|
||||||
|
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
return torch.device(test_device)
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_weights_path() -> Path:
|
||||||
|
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
|
||||||
|
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_e2e_path() -> Path:
|
||||||
|
return PARENT_PATH / "e2e"
|
709
tests/e2e/test_diffusion.py
Normal file
|
@ -0,0 +1,709 @@
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
from warnings import warn
|
||||||
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
|
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||||
|
|
||||||
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
return test_e2e_path / "test_diffusion_ref"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def cutecat_init(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "cutecat_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def kitchen_dog(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "kitchen_dog.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
|
def controlnet_data(
|
||||||
|
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||||
|
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||||
|
cn_name: str = request.param
|
||||||
|
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
|
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
|
weights_fn = {
|
||||||
|
"depth": "lllyasviel_control_v11f1p_sd15_depth",
|
||||||
|
"canny": "lllyasviel_control_v11p_sd15_canny",
|
||||||
|
"lineart": "lllyasviel_control_v11p_sd15_lineart",
|
||||||
|
"normals": "lllyasviel_control_v11p_sd15_normalbae",
|
||||||
|
"sam": "mfidabel_controlnet-segment-anything",
|
||||||
|
}
|
||||||
|
|
||||||
|
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
|
||||||
|
yield (cn_name, condition_image, expected_image, weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
|
cn_name = "canny"
|
||||||
|
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
|
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||||
|
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11p_sd15_canny.safetensors"
|
||||||
|
return cn_name, condition_image, expected_image, weights_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
|
||||||
|
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||||
|
weights_path = test_weights_path / "loras" / "pcuenq_pokemon_lora.safetensors"
|
||||||
|
return expected_image, weights_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "inpainting-mask.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "inpainting-target.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_refonly.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def condition_image_refonly(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def text_encoder_weights(test_weights_path: Path) -> Path:
|
||||||
|
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||||
|
if not text_encoder_weights.is_file():
|
||||||
|
warn(f"could not find weights at {text_encoder_weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return text_encoder_weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def lda_weights(test_weights_path: Path) -> Path:
|
||||||
|
lda_weights = test_weights_path / "lda.safetensors"
|
||||||
|
if not lda_weights.is_file():
|
||||||
|
warn(f"could not find weights at {lda_weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return lda_weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def unet_weights_std(test_weights_path: Path) -> Path:
|
||||||
|
unet_weights_std = test_weights_path / "unet.safetensors"
|
||||||
|
if not unet_weights_std.is_file():
|
||||||
|
warn(f"could not find weights at {unet_weights_std}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return unet_weights_std
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def unet_weights_inpainting(test_weights_path: Path) -> Path:
|
||||||
|
unet_weights_inpainting = test_weights_path / "inpainting" / "unet.safetensors"
|
||||||
|
if not unet_weights_inpainting.is_file():
|
||||||
|
warn(f"could not find weights at {unet_weights_inpainting}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return unet_weights_inpainting
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_std(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
sd15 = StableDiffusion_1(device=test_device)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_std_float16(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_inpainting(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1_Inpainting:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
unet = UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_inpainting_float16(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1_Inpainting:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
unet = UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_ddim(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
ddim_scheduler = DDIM(num_inference_steps=20)
|
||||||
|
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||||
|
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_std_random_init(
|
||||||
|
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_std_random_init_float16(
|
||||||
|
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||||
|
):
|
||||||
|
sd15 = sd15_std_float16
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
assert clip_text_embedding.dtype == torch.float16
|
||||||
|
assert negative_clip_text_embedding.dtype == torch.float16
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_std_init_image(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
cutecat_init: Image.Image,
|
||||||
|
expected_image_std_init_image: Image.Image,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
n_steps = 35
|
||||||
|
first_step = 5
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps[first_step:]:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_init_image)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_inpainting(
|
||||||
|
sd15_inpainting: StableDiffusion_1_Inpainting,
|
||||||
|
kitchen_dog: Image.Image,
|
||||||
|
kitchen_dog_mask: Image.Image,
|
||||||
|
expected_image_std_inpainting: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_inpainting
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
# PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_inpainting_float16(
|
||||||
|
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
|
||||||
|
kitchen_dog: Image.Image,
|
||||||
|
kitchen_dog_mask: Image.Image,
|
||||||
|
expected_image_std_inpainting: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_inpainting_float16
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
assert clip_text_embedding.dtype == torch.float16
|
||||||
|
assert negative_clip_text_embedding.dtype == torch.float16
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
# PSNR and SSIM values are large because float16 is even worse than float32.
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_controlnet(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
controlnet_data: tuple[str, Image.Image, Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
|
||||||
|
|
||||||
|
if not cn_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {cn_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
|
controlnet = Controlnet(name=cn_name, device=test_device)
|
||||||
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
|
controlnet.set_scale(0.5)
|
||||||
|
sd15.unet.insert(0, controlnet)
|
||||||
|
|
||||||
|
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
controlnet.set_controlnet_condition(cn_condition)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_controlnet_structural_copy(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15_base = sd15_std
|
||||||
|
sd15 = sd15_base.structural_copy()
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
|
||||||
|
|
||||||
|
if not cn_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {cn_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
|
controlnet = Controlnet(name=cn_name, device=test_device)
|
||||||
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
|
controlnet.set_scale(0.5)
|
||||||
|
sd15.unet.insert(0, controlnet)
|
||||||
|
|
||||||
|
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
controlnet.set_controlnet_condition(cn_condition)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_controlnet_float16(
|
||||||
|
sd15_std_float16: StableDiffusion_1,
|
||||||
|
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std_float16
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
|
||||||
|
|
||||||
|
if not cn_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {cn_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
|
controlnet = Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
|
||||||
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
|
controlnet.set_scale(0.5)
|
||||||
|
sd15.unet.insert(0, controlnet)
|
||||||
|
|
||||||
|
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
controlnet.set_controlnet_condition(cn_condition)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
negative_clip_text_embedding=negative_clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_lora(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
lora_data_pokemon: tuple[Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
expected_image, lora_weights_path = lora_data_pokemon
|
||||||
|
|
||||||
|
if not lora_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
lora_weights = LoraWeights(lora_weights_path, device=test_device)
|
||||||
|
lora_weights.patch(sd15, scale=1.0)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_refonly(
|
||||||
|
sd15_ddim: StableDiffusion_1,
|
||||||
|
condition_image_refonly: Image.Image,
|
||||||
|
expected_image_refonly: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_ddim
|
||||||
|
prompt = "Chicken"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
|
||||||
|
sai = SelfAttentionInjection(sd15.unet)
|
||||||
|
sai.inject()
|
||||||
|
|
||||||
|
guide = sd15.lda.encode_image(condition_image_refonly)
|
||||||
|
guide = torch.cat((guide, guide))
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
noise = torch.randn(2, 4, 64, 64, device=test_device)
|
||||||
|
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
|
||||||
|
sai.set_controlnet_condition(noised_guide)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_inpainting_refonly(
|
||||||
|
sd15_inpainting: StableDiffusion_1_Inpainting,
|
||||||
|
scene_image_inpainting_refonly: Image.Image,
|
||||||
|
target_image_inpainting_refonly: Image.Image,
|
||||||
|
mask_image_inpainting_refonly: Image.Image,
|
||||||
|
expected_image_inpainting_refonly: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_inpainting
|
||||||
|
n_steps = 30
|
||||||
|
prompt = "" # unconditional
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
||||||
|
|
||||||
|
sai = SelfAttentionInjection(sd15.unet)
|
||||||
|
sai.inject()
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||||
|
|
||||||
|
refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
|
||||||
|
refonly_guide = torch.cat((refonly_guide, refonly_guide))
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
refonly_noise = torch.randn_like(refonly_guide)
|
||||||
|
refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
|
||||||
|
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
|
||||||
|
# inpaint variation models")
|
||||||
|
refonly_noised_guide = torch.cat(
|
||||||
|
[refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
sai.set_controlnet_condition(refonly_noised_guide)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
82
tests/e2e/test_diffusion_ref/README.md
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
# Note about this data
|
||||||
|
|
||||||
|
## Expected outputs
|
||||||
|
|
||||||
|
`expected_*.png` files are the output of the same diffusion run with a different codebase, usually diffusers with the same settings as us (`DPMSolverMultistepScheduler`, VAE [patched to remove randomness](#vae-without-randomness), same seed...).
|
||||||
|
|
||||||
|
For instance here is how we generate `expected_std_random_init.png`:
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import DPMSolverMultistepScheduler
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
).to("cuda)
|
||||||
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
torch.manual_seed(2)
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
num_inference_steps=30,
|
||||||
|
guidance_scale=7.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.images[0].save("std_random_init_expected.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
Special cases:
|
||||||
|
|
||||||
|
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
|
||||||
|
- `expected_inpainting_refonly.png` has been generated with refiners itself (and inspected so that it looks reasonable).
|
||||||
|
|
||||||
|
## Other images
|
||||||
|
|
||||||
|
- `cutecat_init.png` is generated with the same Diffusers script and prompt but with seed 1234.
|
||||||
|
|
||||||
|
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
|
||||||
|
|
||||||
|
- `kitchen_mask.png` is made manually.
|
||||||
|
|
||||||
|
- Controlnet guides have been manually generated using open source software and models, namely:
|
||||||
|
- Canny: opencv-python
|
||||||
|
- Depth: https://github.com/isl-org/ZoeDepth
|
||||||
|
- Lineart: https://github.com/lllyasviel/ControlNet-v1-1-nightly/tree/main/annotator/lineart
|
||||||
|
- Normals: https://github.com/baegwangbin/surface_normal_uncertainty/tree/fe2b9f1
|
||||||
|
- SAM: https://huggingface.co/spaces/mfidabel/controlnet-segment-anything
|
||||||
|
|
||||||
|
- `cyberpunk_guide.png` [comes from Lexica](https://lexica.art/prompt/5ba40855-0d0c-4322-8722-51115985f573).
|
||||||
|
|
||||||
|
- `inpainting-mask.png`, `inpainting-scene.png` and `inpainting-target.png` have been generated as follows:
|
||||||
|
- `inpainting-mask.png`: negated version of a mask computed with [SAM](https://github.com/facebookresearch/segment-anything) automatic mask generation using the `vit_h` checkpoint
|
||||||
|
- `inpainting-scene.png`: cropped-to-square-and-resized version of https://unsplash.com/photos/RCz6eSVPGYU by @jannerboy62
|
||||||
|
- `inpainting-target.png`: computed with `convert <(convert -size 512x512 xc:white png:-) kitchen_dog.png <(convert inpainting-mask.png -negate png:-) -compose Over -composite inpainting-target.png`
|
||||||
|
|
||||||
|
## VAE without randomness
|
||||||
|
|
||||||
|
```diff
|
||||||
|
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
|
||||||
|
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
|
||||||
|
@@ -524,13 +524,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
- if isinstance(generator, list):
|
||||||
|
- init_latents = [
|
||||||
|
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||||
|
- ]
|
||||||
|
- init_latents = torch.cat(init_latents, dim=0)
|
||||||
|
- else:
|
||||||
|
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||||
|
+ init_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mean for i in range(batch_size)]
|
||||||
|
+ init_latents = torch.cat(init_latents, dim=0)
|
||||||
|
|
||||||
|
init_latents = self.vae.config.scaling_factor * init_latents
|
||||||
|
```
|
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_canny.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_depth.png
Normal file
After Width: | Height: | Size: 49 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_lineart.png
Normal file
After Width: | Height: | Size: 124 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_normals.png
Normal file
After Width: | Height: | Size: 142 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_guide_sam.png
Normal file
After Width: | Height: | Size: 7.2 KiB |
BIN
tests/e2e/test_diffusion_ref/cutecat_init.png
Normal file
After Width: | Height: | Size: 387 KiB |
BIN
tests/e2e/test_diffusion_ref/cyberpunk_guide.png
Normal file
After Width: | Height: | Size: 563 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_canny.png
Normal file
After Width: | Height: | Size: 416 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_depth.png
Normal file
After Width: | Height: | Size: 409 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_lineart.png
Normal file
After Width: | Height: | Size: 468 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_normals.png
Normal file
After Width: | Height: | Size: 447 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_sam.png
Normal file
After Width: | Height: | Size: 386 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_inpainting_refonly.png
Normal file
After Width: | Height: | Size: 461 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_lora_pokemon.png
Normal file
After Width: | Height: | Size: 336 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_refonly.png
Normal file
After Width: | Height: | Size: 700 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_std_init_image.png
Normal file
After Width: | Height: | Size: 393 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_std_inpainting.png
Normal file
After Width: | Height: | Size: 316 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_std_random_init.png
Normal file
After Width: | Height: | Size: 491 KiB |
BIN
tests/e2e/test_diffusion_ref/inpainting-mask.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
tests/e2e/test_diffusion_ref/inpainting-scene.png
Normal file
After Width: | Height: | Size: 476 KiB |
BIN
tests/e2e/test_diffusion_ref/inpainting-target.png
Normal file
After Width: | Height: | Size: 225 KiB |
BIN
tests/e2e/test_diffusion_ref/kitchen_dog.png
Normal file
After Width: | Height: | Size: 379 KiB |
BIN
tests/e2e/test_diffusion_ref/kitchen_dog_mask.png
Normal file
After Width: | Height: | Size: 10 KiB |