A microframework on top of PyTorch with first-class citizen APIs for foundation model adaptation https://refine.rs/

Finegrain Refiners Library

The simplest way to train and run adapters on top of foundational models

Design Pillars

We are huge fans of PyTorch (we actually were core committers to Torch in another life), but we felt it was too low-level for the specific task of model adaptation: PyTorch models are generally hard to understand, and adapting them requires intricate ad hoc code.

Instead, we needed:

  • A model structure that's human-readable, so that you know what models do and how they work right here, right now
  • A mechanism to easily inject parameters in some target layers, or between them
  • A way to easily pass data (like a conditioning input) between layers even when deeply nested
  • Native support for iconic adapter types like LoRAs and their community-trained incarnations (hosted on Civitai and the like)

Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.

Key Concepts

The Chain class

The Chain class is at the core of Refiners. It lets you express models declaratively, as a composition of basic layers (linear, convolution, attention, etc.).

For example, this is what a Vision Transformer (ViT) looks like with Refiners:

import torch
import refiners.fluxion.layers as fl

class ViT(fl.Chain):
    # The Vision Transformer model structure is entirely defined in the constructor. It is
    # ready to use right away, i.e. there is no need to implement a forward function or add extra logic
    def __init__(
        self,
        embedding_dim: int = 512,
        patch_size: int = 16,
        image_size: int = 384,
        num_layers: int = 12,
        num_heads: int = 8,
    ):
        num_patches = image_size // patch_size  # number of patches along each side of the image
        super().__init__(
            fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
            fl.Reshape(num_patches**2, embedding_dim),
            # The Residual layer implements the so-called skip-connection, i.e. x + F(x).
            # Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
            # weights are stored in the Parameter layer (note: there is no extra classification
            # token in this toy example)
            fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
            # These are the transformer encoders:
            *(
                fl.Chain(
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        # The Parallel layer is used to pass multiple inputs to a downstream
                        # layer, here multiheaded self-attention
                        fl.Parallel(
                            fl.Identity(),
                            fl.Identity(),
                            fl.Identity()
                        ),
                        fl.Attention(
                            embedding_dim=embedding_dim,
                            num_heads=num_heads,
                            key_embedding_dim=embedding_dim,
                            value_embedding_dim=embedding_dim,
                        ),
                    ),
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                )
                for _ in range(num_layers)
            ),
            fl.Reshape(embedding_dim, num_patches, num_patches),
        )

vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
x = torch.randn(2, 3, 224, 224)
y = vit(x)
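The comments in the ViT example describe three building blocks. As a toy illustration of their semantics only, here is a framework-free sketch on plain Python numbers: a `Chain` feeds each layer's output to the next (unpacking a tuple into several inputs), a `Residual` returns x + F(x), and a `Parallel` passes the same input to every branch and returns the outputs as a tuple. These classes are hypothetical stand-ins that merely mirror the names `fl.Chain`, `fl.Residual` and `fl.Parallel`; they are not Refiners' implementation, which works on `torch.Tensor`s and `nn.Module`s.

```python
from typing import Any, Callable

Layer = Callable[..., Any]


class Chain:
    """Toy sequential container: each layer's output feeds the next layer."""

    def __init__(self, *layers: Layer) -> None:
        self.layers = list(layers)

    def __call__(self, *args: Any) -> Any:
        for layer in self.layers:
            out = layer(*args)
            # A tuple output (e.g. from Parallel) is unpacked into several inputs.
            args = out if isinstance(out, tuple) else (out,)
        return args[0] if len(args) == 1 else args


class Residual(Chain):
    """Toy skip connection: returns x + F(x), where F is the wrapped chain."""

    def __call__(self, *args: Any) -> Any:
        return super().__call__(*args) + args[0]


class Parallel(Chain):
    """Toy fan-out: every branch receives the same inputs; outputs form a tuple."""

    def __call__(self, *args: Any) -> Any:
        return tuple(layer(*args) for layer in self.layers)


def identity(x: Any) -> Any:
    return x


model = Chain(
    lambda x: x + 1,
    Residual(lambda x: x * 2),  # x + 2x = 3x
    Parallel(identity, identity, identity),  # like the query, key and value inputs
    lambda q, k, v: q + k + v,  # placeholder for the attention layer
)
print(model(1))  # prints 18
```

The final lambda stands in for `fl.Attention`, which in the ViT example likewise receives the three outputs of the `fl.Parallel` layer as separate inputs.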

The Context API

The Chain class has a context provider that allows you to pass data to layers even when deeply nested.

For example, to implement cross-attention you only need to modify the ViT model as in the toy diff below:

@@ -21,8 +21,8 @@
                     fl.Residual(
                         fl.Parallel(
                             fl.Identity(),
-                            fl.Identity(),
-                            fl.Identity()
+                            fl.UseContext(context="cross_attention", key="my_embed"),
+                            fl.UseContext(context="cross_attention", key="my_embed"),
                         ),  # used to pass multiple inputs to a layer
                         fl.Attention(
                             embedding_dim=embedding_dim,
@@ -49,5 +49,6 @@
         )

 vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
 x = torch.randn(2, 3, 224, 224)
 y = vit(x)
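One plausible way to implement such a context provider is sketched below on plain Python values, under our own assumptions: every layer keeps a pointer to its parent, `set_context` stores values on the outermost chain, and a `UseContext` layer ignores its input and walks up to the root to fetch `contexts[context][key]`. The classes (`Layer`, `Fn`, `Chain`, `Parallel`, `UseContext`) are hypothetical toys that mirror the names in the diff; Refiners' actual mechanism may differ.

```python
from typing import Any, Callable, Optional


class Layer:
    """Toy base layer that records its parent so it can reach the root chain."""

    parent: Optional["Layer"] = None

    def root(self) -> "Layer":
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def __call__(self, *args: Any) -> Any:
        raise NotImplementedError


class Fn(Layer):
    """Wraps a plain function as a layer."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn

    def __call__(self, *args: Any) -> Any:
        return self.fn(*args)


class Chain(Layer):
    """Toy sequential container; the root chain's `contexts` acts as the provider."""

    def __init__(self, *layers: Layer) -> None:
        self.layers = list(layers)
        self.contexts: dict[str, dict[str, Any]] = {}
        for layer in self.layers:
            layer.parent = self

    def set_context(self, context: str, value: dict[str, Any]) -> None:
        self.contexts[context] = value

    def __call__(self, *args: Any) -> Any:
        for layer in self.layers:
            out = layer(*args)
            args = out if isinstance(out, tuple) else (out,)
        return args[0] if len(args) == 1 else args


class Parallel(Chain):
    """Toy fan-out: every branch receives the same inputs."""

    def __call__(self, *args: Any) -> Any:
        return tuple(layer(*args) for layer in self.layers)


class UseContext(Layer):
    """Ignores its input and returns a value stored in the root chain's context."""

    def __init__(self, context: str, key: str) -> None:
        self.context = context
        self.key = key

    def __call__(self, *args: Any) -> Any:
        root = self.root()
        assert isinstance(root, Chain), "UseContext must live inside a Chain"
        return root.contexts[self.context][self.key]


model = Chain(
    Fn(lambda x: x * 10),
    Chain(  # nested one level down, like a transformer block
        Parallel(Fn(lambda x: x), UseContext("cross_attention", "my_embed")),
        Fn(lambda query, key: query + key),  # placeholder for cross-attention
    ),
)
model.set_context("cross_attention", {"my_embed": 5})
print(model(2))  # prints 25
```

In the diff above, the two `fl.UseContext` layers replace the last two branches of the `fl.Parallel`, so the first input to attention still comes from the patch tokens while the other two come from `my_embed`. The context tensor has shape (2, 196, 768): the batch size matches `x`, and 768 matches `key_embedding_dim` and `value_embedding_dim` (both set to `embedding_dim`). Its 196 tokens equal the number of patch tokens, (224 // 16)² = 14² = 196, although cross-attention does not in general require the two sequence lengths to match.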

The Adapter API

The Adapter API lets you easily patch models by injecting parameters into targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also use it to implement your own custom adapters.

For example, to inject LoRA layers into all the linear layers of the attention blocks:

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import SingleLoraAdapter

# `vit` is the model instance built in the ViT example above

for layer in vit.layers(fl.Attention):
    for linear, parent in layer.walk(fl.Linear):
        SingleLoraAdapter(target=linear, rank=64).inject(parent)

# ... and load existing weights if the LoRAs are pretrained ...
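To illustrate what injecting an adapter means, here is a toy, dependency-free sketch of LoRA on plain Python lists, under our own assumptions. A LoRA adapter wraps a frozen linear layer W and adds a low-rank update, y = W x + scale * up(down(x)), where `down` projects the input to `rank` dimensions and `up` projects back. Following the initialization from the LoRA paper (random down projection, zero up projection), the adapted layer starts out identical to the original, and `inject(parent)` swaps the target for the adapter inside its parent's layer list. All names here (`matvec`, `Linear`, `Chain`, `LoraAdapter`) are hypothetical and only echo the `SingleLoraAdapter(target=..., rank=...).inject(parent)` call above; they are not Refiners' classes.

```python
import random
from typing import Callable

Matrix = list[list[float]]
Vector = list[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class Linear:
    """Toy bias-free linear layer: y = W x."""

    def __init__(self, weight: Matrix) -> None:
        self.weight = weight

    def __call__(self, x: Vector) -> Vector:
        return matvec(self.weight, x)


class Chain:
    """Toy parent container holding a mutable list of child layers."""

    def __init__(self, *layers: Callable[[Vector], Vector]) -> None:
        self.layers = list(layers)

    def __call__(self, x: Vector) -> Vector:
        for layer in self.layers:
            x = layer(x)
        return x


class LoraAdapter:
    """Toy LoRA: y = W x + scale * up(down(x)), with a rank-r bottleneck."""

    def __init__(self, target: Linear, rank: int, scale: float = 1.0, seed: int = 0) -> None:
        out_features, in_features = len(target.weight), len(target.weight[0])
        rng = random.Random(seed)
        self.target = target  # the frozen, pretrained layer being wrapped
        # down: rank x in_features, small random values; up: out_features x rank, zeros,
        # so the adapted layer initially behaves exactly like the target.
        self.down = [[rng.gauss(0.0, 0.01) for _ in range(in_features)] for _ in range(rank)]
        self.up = [[0.0] * rank for _ in range(out_features)]
        self.scale = scale

    def __call__(self, x: Vector) -> Vector:
        base = self.target(x)
        delta = matvec(self.up, matvec(self.down, x))
        return [b + self.scale * d for b, d in zip(base, delta)]

    def inject(self, parent: Chain) -> None:
        # Replace the target in place; it stays reachable through self.target.
        index = next(i for i, layer in enumerate(parent.layers) if layer is self.target)
        parent.layers[index] = self


block = Chain(Linear([[1.0, 0.0], [0.0, 1.0]]), Linear([[2.0, 0.0], [0.0, 2.0]]))
x = [1.0, 2.0]
before = block(x)
for linear in [layer for layer in block.layers if isinstance(layer, Linear)]:
    LoraAdapter(target=linear, rank=1).inject(block)
print(block(x) == before)  # True: the zero-initialized `up` makes the update a no-op
```

Since only `down` and `up` are trained, a rank-r adapter on an out × in layer adds r × (in + out) parameters instead of out × in. For instance, rank 64 on a 768 × 768 projection adds 64 × 1536 = 98,304 parameters, versus 589,824 for the full weight matrix.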

Adapter Zoo

For now, given Finegrain's mission, we are focusing on image editing tasks. We support:

| Adapter          | Foundation Model |
|------------------|------------------|
| LoRA             | SD15, SDXL       |
| ControlNets      | SD15             |
| Ref Only Control | SD15             |
| IP-Adapter       | SD15, SDXL       |
| T2I-Adapter      | SD15, SDXL       |

Getting Started

Install

Refiners is still an early-stage project, so we recommend using the main branch directly with Poetry.

If you just want to use Refiners directly, clone the repository and run:

poetry install --all-extras

There is currently a bug with PyTorch 2.0.1 and Poetry; to work around it, run:

poetry run pip install --upgrade torch torchvision

If you want to depend on Refiners from a project that uses Poetry, run:

poetry add git+ssh://git@github.com:finegrain-ai/refiners.git#main

If you want to run tests, we provide a script to download and convert all the necessary weights first. Be aware that this will use around 50 GB of disk space.

poetry shell
./scripts/prepare-test-weights.sh
pytest

Hello World

Here is how to perform a text-to-image inference using the Stable Diffusion 1.5 foundational model patched with a Pokemon LoRA:

Step 1: convert the model weights to the Refiners format:

python scripts/conversion/convert_transformers_clip_text_model.py --to clip.safetensors
python scripts/conversion/convert_diffusers_autoencoder_kl.py --to lda.safetensors
python scripts/conversion/convert_diffusers_unet.py --to unet.safetensors

Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5, which takes some time. If you already have this repository cloned locally, use the --from /path/to/stable-diffusion-v1-5 option instead.

Step 2: download and convert a community Pokemon LoRA, e.g. pcuenq/pokemon-lora:

curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
python scripts/conversion/convert_diffusers_lora.py \
  --from pytorch_lora_weights.bin \
  --to pokemon_lora.safetensors

Step 3: run inference using the GPU:

from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.fluxion.utils import manual_seed
import torch


sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_from_safetensors("clip.safetensors")
sd15.lda.load_from_safetensors("lda.safetensors")
sd15.unet.load_from_safetensors("unet.safetensors")

SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()

prompt = "a cute cat"

with torch.no_grad():
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

sd15.set_num_inference_steps(30)

manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=sd15.device)

with torch.no_grad():
    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x)
    predicted_image.save("pokemon_cat.png")

You should get:

[Image: pokemon cat output]
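The `condition_scale=7.5` argument in the inference loop is, we assume, a classifier-free guidance scale: at each step the denoiser predicts the noise both with and without the text conditioning, and the two predictions are extrapolated. The minimal sketch below shows only that combination rule on plain lists; the function name and the list representation are hypothetical, since Refiners performs this on tensors inside the model's forward pass.

```python
def classifier_free_guidance(
    unconditional: list[float], conditional: list[float], scale: float
) -> list[float]:
    """Combine noise predictions: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for u, c in zip(unconditional, conditional)]


print(classifier_free_guidance([0.0, 1.0], [1.0, 3.0], scale=7.5))  # [7.5, 16.0]
```

With `scale=1.0` the rule returns the conditional prediction unchanged, with `scale=0.0` it returns the unconditional one, and values above 1 push the result further in the direction of the prompt.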

Training

Refiners has a built-in library of training utilities and provides scripts that can be used as a starting point.

For example, to train a LoRA on top of Stable Diffusion, copy and edit configs/finetune-lora.toml to suit your needs, then launch the training as follows:

python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml

Motivation

At Finegrain, we're on a mission to automate product photography. Given our "no human in the loop" approach, nailing the quality of the outputs we generate is paramount to our success.

That's why we're building Refiners.

It's a framework that makes it easy to bridge the last-mile quality gap of foundational models like Stable Diffusion or the Segment Anything Model (SAM) by adapting them to specific tasks with lightweight, trainable and composable patches.

We decided to build Refiners in the open because model adaptation is a new paradigm that goes beyond our specific use cases. We hope to help anyone creating their own adapters save time, whatever foundation model they are using.

Awesome Adaptation Papers

If you're interested in understanding the diversity of use cases for foundation model adaptation (potentially beyond the specific adapters supported by Refiners), we suggest you take a look at these outstanding papers:

SAM

SD

BLIP

Credits

We took inspiration from these great projects:

Citation

@misc{the-finegrain-team-2023-refiners,
  author = {Benjamin Trom and Pierre Chapuis and Cédric Deltheil},
  title = {Refiners: The simplest way to train and run adapters on top of foundational models},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/finegrain-ai/refiners}}
}