Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-25 07:38:45 +00:00.

Commit 17b085341a ("brush up main README"), parent 4e80d295bd. Changed file: README.md.

<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
</picture>

**The simplest way to train and run adapters on top of foundation models** ([dive in!](https://blog.finegrain.ai/posts/simplifying-ai-code/))

[**Manifesto**](https://refine.rs/home/why/) | [**Documentation**](https://refine.rs) | [**Guides**](https://refine.rs/guides/adapting_sdxl/) | [**Discord**](https://discord.gg/mCmjNUVV7d)

______________________________________________________________________

- Added [SDXL 1.0](https://github.com/Stability-AI/generative-models) to foundation models
- Added the ability to add new concepts to the CLIP text encoder, e.g. via [Textual Inversion](https://arxiv.org/abs/2208.01618)

## Getting Started

### Install

Refiners is still an early-stage project and we do not release minor versions yet. The current recommended way to install Refiners is from source using [Rye](https://rye-up.com/):

```bash
git clone "git@github.com:finegrain-ai/refiners.git"
cd refiners
rye sync --all-features
```

Alternatively, you can install the latest version via a git install:

```bash
pip install git+https://github.com/finegrain-ai/refiners.git
```

To include the training utils, use:

```bash
pip install 'refiners[training] @ git+https://github.com/finegrain-ai/refiners.git'
```
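
To quickly check that the install worked, a minimal sanity check (our suggestion, not part of the upstream instructions) is to import the top-level package:

```python
# If the install succeeded, this prints "refiners" and exits without error.
import refiners

print(refiners.__name__)
```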

### Documentation

Refiners comes with a MkDocs-based documentation website available at https://refine.rs. There you will find a [quick start guide](https://refine.rs/getting-started/recommended/), a description of the [key concepts](https://refine.rs/concepts/chain/), as well as in-depth foundation model adaptation [guides](https://refine.rs/guides/adapting_sdxl/).

### Hello World

Goal: turn Refiners' mascot into a [Dragon Quest Slime](https://en.wikipedia.org/wiki/Slime_(Dragon_Quest)) plush in a one-shot manner, thanks to a powerful combo of adapters:

- IP-Adapter: to capture the Slime plush's visual appearance into an image prompt (no prompt engineering needed)
- T2I-Adapter: to guide the generation with the mascot's geometry
- Self-Attention Guidance (SAG): to increase the sharpness

![hello world overview](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/hello_world_overview.png)

**Step 1**: convert SDXL weights to the Refiners format:

```bash
python scripts/conversion/convert_transformers_clip_text_model.py --from "stabilityai/stable-diffusion-xl-base-1.0" --subfolder2 text_encoder_2 --to clip_text_xl.safetensors --half
python scripts/conversion/convert_diffusers_unet.py --from "stabilityai/stable-diffusion-xl-base-1.0" --to unet_xl.safetensors --half
python scripts/conversion/convert_diffusers_autoencoder_kl.py --from "madebyollin/sdxl-vae-fp16-fix" --subfolder "" --to lda_xl.safetensors --half
```

> Note: this will download the original weights from https://huggingface.co/, which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stabilityai/stable-diffusion-xl-base-1.0` option instead.

Then convert the IP-Adapter and T2I-Adapter weights (note: SAG is parameter-free):

```bash
python scripts/conversion/convert_diffusers_t2i_adapter.py --from "TencentARC/t2i-adapter-canny-sdxl-1.0" --to t2i_canny_xl.safetensors --half
python scripts/conversion/convert_transformers_clip_image_model.py --from "stabilityai/stable-diffusion-2-1-unclip" --to clip_image.safetensors --half
curl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin
python scripts/conversion/convert_diffusers_ip_adapter.py --from ip-adapter_sdxl_vit-h.bin --half
```

**Step 2**: download the input images:

```bash
curl -O https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_logo.png
curl -O https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_canny.png
curl -O https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dragon_quest_slime.jpg
```

**Step 3**: generate an image using the GPU:

```python
import torch

from PIL import Image

from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors

# Load inputs
init_image = Image.open("dropy_logo.png")
image_prompt = Image.open("dragon_quest_slime.jpg")
condition_image = Image.open("dropy_canny.png")

# Load SDXL
sdxl = StableDiffusion_XL(device="cuda", dtype=torch.float16)
sdxl.clip_text_encoder.load_from_safetensors("clip_text_xl.safetensors")
sdxl.lda.load_from_safetensors("lda_xl.safetensors")
sdxl.unet.load_from_safetensors("unet_xl.safetensors")

# Load and inject adapters
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors("ip-adapter_sdxl_vit-h.safetensors"))
ip_adapter.clip_image_encoder.load_from_safetensors("clip_image.safetensors")
ip_adapter.inject()

t2i_adapter = SDXLT2IAdapter(
    target=sdxl.unet, name="canny", weights=load_from_safetensors("t2i_canny_xl.safetensors")
).inject()

# Tune parameters
seed = 9752
ip_adapter.set_scale(0.85)
t2i_adapter.set_scale(0.8)
sdxl.set_inference_steps(50, first_step=1)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)

with no_grad():
    # Note: default text prompts for IP-Adapter
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
    )
    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
    ip_adapter.set_clip_image_embedding(clip_image_embedding)
    time_ids = sdxl.default_time_ids

    condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))

    manual_seed(seed=seed)
    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(
        device=sdxl.device, dtype=sdxl.dtype
    )

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )
    predicted_image = sdxl.lda.decode_latents(x=x)

predicted_image.save("output.png")
print("done: see output.png")
```

You should get:

![dropy slime output](https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy_slime_9752.png)

## Adapter Zoo

For now, given [finegrain](https://finegrain.ai)'s mission, we are focusing on image editing tasks. We support:

| Adapter          | Foundation Model |
| ---------------- | ---------------- |
| LoRA             | `SD15` `SDXL`    |
| ControlNets      | `SD15`           |
| Ref Only Control | `SD15`           |
| IP-Adapter       | `SD15` `SDXL`    |
| T2I-Adapter      | `SD15` `SDXL`    |

## Motivation

At [Finegrain](https://finegrain.ai), we're on a mission to automate product photography. Given our "no human in the loop" approach, nailing the quality of the outputs we generate is paramount to our success.

That's why we're building Refiners.

It's a framework to easily bridge the last-mile quality gap of foundation models like Stable Diffusion or Segment Anything Model (SAM), by adapting them to specific tasks with lightweight trainable and composable patches.

We decided to build Refiners in the open because model adaptation is a new paradigm that goes beyond our specific use cases. Our hope is to help people looking to create their own adapters save time, whatever the foundation model they're using.

## Design Pillars

We are huge fans of PyTorch (we actually were core committers to [Torch](http://torch.ch/) in another life), but we felt it was too low-level for the specific task of model adaptation: PyTorch models are generally hard to understand, and adapting them requires intricate ad hoc code.

Instead, we needed:

- A model structure that's human readable, so that you know what models do and how they work right here, right now
- A mechanism to easily inject parameters in some target layers, or between them
- A way to easily pass data (like a conditioning input) between layers, even when deeply nested
- Native support for iconic adapter types like LoRAs and their community-trained incarnations (hosted on [Civitai](http://civitai.com/) and the likes)

Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.

## Key Concepts

### The Chain class

The `Chain` class is at the core of Refiners. It basically lets you express models as a composition of basic layers (linear, convolution, attention, etc.) in a **declarative way**.
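
Before the full example, here is a minimal sketch (not taken from the original README) of what "declarative" means in practice: a two-layer MLP is just a `Chain` of the same `fl.Linear` and `fl.GeLU` layers used in the ViT below, with no forward function to write.

```python
import torch

import refiners.fluxion.layers as fl

# A toy MLP declared as a Chain: the structure is the model.
mlp = fl.Chain(
    fl.Linear(8, 32),
    fl.GeLU(),
    fl.Linear(32, 4),
)

y = mlp(torch.randn(2, 8))  # output shape: (2, 4)
```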

E.g. this is what a Vision Transformer (ViT) looks like with Refiners:

```python
import torch

import refiners.fluxion.layers as fl


class ViT(fl.Chain):
    # The Vision Transformer model structure is entirely defined in the constructor. It is
    # ready-to-use right after, i.e. there is no need to implement any forward function or add extra logic
    def __init__(
        self,
        embedding_dim: int = 512,
        patch_size: int = 16,
        image_size: int = 384,
        num_layers: int = 12,
        num_heads: int = 8,
    ):
        num_patches = (image_size // patch_size)
        super().__init__(
            fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
            fl.Reshape(num_patches**2, embedding_dim),
            # The Residual layer implements the so-called skip-connection, i.e. x + F(x).
            # Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
            # weights are stored in the Parameter layer (note: there is no extra classification
            # token in this toy example)
            fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
            # These are the transformer encoders:
            *(
                fl.Chain(
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        # The Parallel layer is used to pass multiple inputs to a downstream
                        # layer, here multiheaded self-attention
                        fl.Parallel(
                            fl.Identity(),
                            fl.Identity(),
                            fl.Identity(),
                        ),
                        fl.Attention(
                            embedding_dim=embedding_dim,
                            num_heads=num_heads,
                            key_embedding_dim=embedding_dim,
                            value_embedding_dim=embedding_dim,
                        ),
                    ),
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                    fl.Chain(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                )
                for _ in range(num_layers)
            ),
            fl.Reshape(embedding_dim, num_patches, num_patches),
        )


vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
x = torch.randn(2, 3, 224, 224)
y = vit(x)
```
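
Because the resulting model is a plain `Chain`, you can also traverse and inspect it programmatically. As a small sketch (reusing the `vit` instance defined above and the `layers` helper that also appears in the Adapter API example below):

```python
# Count the attention layers of the ViT declared above: one per encoder block.
attentions = list(vit.layers(fl.Attention))
print(len(attentions))  # 12
```
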
### The Context API

The `Chain` class has a context provider that allows you to **pass data to layers even when deeply nested**.

E.g. to implement cross-attention, you would just need to modify the ViT model as in the toy example below:

```diff
@@ -21,8 +21,8 @@
             fl.Residual(
                 fl.Parallel(
                     fl.Identity(),
-                    fl.Identity(),
-                    fl.Identity()
+                    fl.UseContext(context="cross_attention", key="my_embed"),
+                    fl.UseContext(context="cross_attention", key="my_embed"),
                 ),  # used to pass multiple inputs to a layer
                 fl.Attention(
                     embedding_dim=embedding_dim,
@@ -49,5 +49,6 @@
         )

 vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
 x = torch.randn(2, 3, 224, 224)
 y = vit(x)
```
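
For a more self-contained illustration, here is a sketch (not from the original README) that only combines pieces already shown above: a `Residual` computes x + F(x), and using a `UseContext` layer as F means the added tensor comes from the chain's context rather than from the input.

```python
import torch

import refiners.fluxion.layers as fl

tiny = fl.Chain(
    fl.Residual(fl.UseContext(context="toy", key="bias")),
)
# Set the context from outside the chain, as done with the ViT above.
tiny.set_context("toy", {"bias": torch.ones(3)})
print(tiny(torch.zeros(2, 3)))  # all ones: the context tensor was added to the input
```
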
### The Adapter API

The `Adapter` API lets you **easily patch models** by injecting parameters in targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also use it to implement your own custom adapters.

E.g. to inject LoRA layers in all the attention blocks' linear layers:

```python
from refiners.fluxion.adapters.lora import SingleLoraAdapter

for layer in vit.layers(fl.Attention):
    for linear, parent in layer.walk(fl.Linear):
        SingleLoraAdapter(target=linear, rank=64).inject(parent)

# ... and load existing weights if the LoRAs are pretrained ...
```
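
Adapters can also be removed to restore the original model. Here is a sketch of the same injection loop (assuming, as the examples above suggest, that `inject` returns the adapter and that the `Adapter` API exposes a matching `eject` method):

```python
# Keep references to the injected adapters so they can be ejected later.
lora_adapters = [
    SingleLoraAdapter(target=linear, rank=64).inject(parent)
    for layer in vit.layers(fl.Attention)
    for linear, parent in layer.walk(fl.Linear)
]

# ... use or train the patched model, then restore the unpatched ViT:
for adapter in lora_adapters:
    adapter.eject()
```
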
## Awesome Adaptation Papers

If you're interested in understanding the diversity of use cases for foundation model adaptation (potentially beyond the specific adapters supported by Refiners), we suggest you take a look at these outstanding papers:

### SAM

- [Medical SAM Adapter](https://arxiv.org/abs/2304.12620)
- [3DSAM-adapter](https://arxiv.org/abs/2306.13465)
- [SAM-adapter](https://arxiv.org/abs/2304.09148)
- [Cross Modality Attention Adapter](https://arxiv.org/abs/2307.01124)

### SD

- [ControlNet](https://arxiv.org/abs/2302.05543)
- [T2I-Adapter](https://arxiv.org/abs/2302.08453)
- [IP-Adapter](https://arxiv.org/abs/2308.06721)

### BLIP

- [UniAdapter](https://arxiv.org/abs/2302.06605)

## Projects using Refiners