mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
Add T2I-Adapter subsection to SDXL Adaptation guide
This commit is contained in:
parent
fbb1fcb8ff
commit
75830e2179
|
@ -155,7 +155,8 @@ predicted_image.save("vanilla_sdxl.png")
|
|||
It's time to execute your code. The resulting image should look like this:
|
||||
|
||||
<figure markdown>
|
||||
<img src="vanilla_sdxl.webp" alt="Image title" width="400">
|
||||
<img src="vanilla_sdxl.webp" alt="Generated image of a castle using default SDXL weights" width="400">
|
||||
<figcaption>Generated image of a castle using default SDXL weights.</figcaption>
|
||||
</figure>
|
||||
|
||||
It is not really what we prompted the model for, unfortunately. To get a more futuristic-looking castle, you can either go for tedious prompt engineering, or use a pretrainered LoRA tailored to our use case, like the [Sci-fi Environments](https://civitai.com/models/105945?modelVersionId=140624) LoRA available on Civitai. We'll now show you how the LoRA option works with Refiners.
|
||||
|
@ -235,7 +236,8 @@ manager.add_loras("scifi-lora", tensors=scifi_lora_weights)
|
|||
You should get something like this - pretty neat, isn't it?
|
||||
|
||||
<figure markdown>
|
||||
<img src="scifi_sdxl.webp" alt="Image title" width="400">
|
||||
<img src="scifi_sdxl.webp" alt="Sci-fi castle" width="400">
|
||||
<figcaption>Generated image of a castle in sci-fi style.</figcaption>
|
||||
</figure>
|
||||
|
||||
## Multiple LoRAs
|
||||
|
@ -256,7 +258,7 @@ manager.add_multiple_loras(
|
|||
)
|
||||
```
|
||||
|
||||
Adapters such as LoRAs also have a [scale](https://github.com/finegrain-ai/refiners/blob/fd01ba910efb764b4521254cded2530b6c31cbd4/src/refiners/fluxion/adapters/lora.py#L17) (roughly) quantifying the effect of this Adapter.
|
||||
Adapters such as LoRAs also have a [scale][refiners.fluxion.adapters.Lora.scale] (roughly) quantifying the effect of this Adapter.
|
||||
Refiners allows setting different scales for each Adapter, allowing the user to balance the effect of each Adapter:
|
||||
|
||||
```py
|
||||
|
@ -334,7 +336,8 @@ manager.add_multiple_loras(
|
|||
The results are looking great:
|
||||
|
||||
<figure markdown>
|
||||
<img src="scifi_pixel_sdxl.webp" alt="Image title" width="400">
|
||||
<img src="scifi_pixel_sdxl.webp" alt="Sci-fi Pixel Art castle" width="400">
|
||||
<figcaption>Generated image of a castle in sci-fi, pixel art style.</figcaption>
|
||||
</figure>
|
||||
|
||||
## Multiple LoRAs + IP-Adapter
|
||||
|
@ -346,7 +349,7 @@ For instance, IP-Adapter (covered in [a previous blog post](https://blog.finegra
|
|||
In our example, consider this image of the [Neuschwanstein Castle](https://en.wikipedia.org/wiki/Neuschwanstein_Castle):
|
||||
|
||||
<figure markdown>
|
||||
<img src="german-castle.jpg" alt="Image title" width="400">
|
||||
<img src="german-castle.jpg" alt="Castle Image" width="400">
|
||||
<figcaption>Credits: Bayerische Schlösserverwaltung, Anton Brandl</figcaption>
|
||||
</figure>
|
||||
|
||||
|
@ -470,10 +473,152 @@ with torch.no_grad():
|
|||
The result looks convincing: we do get a *pixel-art, futuristic-looking Neuschwanstein castle*!
|
||||
|
||||
<figure markdown>
|
||||
<img src="scifi_pixel_IP_sdxl.webp" alt="Image title" width="400">
|
||||
<img src="scifi_pixel_IP_sdxl.webp" alt="Generated image in sci-fi, pixel art style, using IP-Adapter." width="400">
|
||||
<figcaption>Generated image in sci-fi, pixel art style, using IP-Adapter.</figcaption>
|
||||
</figure>
|
||||
|
||||
|
||||
## Everything else + T2I-Adapter
|
||||
|
||||
T2I-Adapters[^1] are a powerful class of Adapters aiming at controlling the Text-to-Image (T2I) diffusion process with external control signals, such as canny edges or pose estimations inputs.
|
||||
In this section, we will compose our previous example with the [Depth-Zoe Adapter](https://huggingface.co/TencentARC/t2i-adapter-depth-zoe-sdxl-1.0), providing a depth condition to the diffusion process using the following depth map as input signal:
|
||||
|
||||
<figure markdown>
|
||||
<img src="zoe-depth-map-german-castle.png" alt="Input depth map of the initial castle image" width="400">
|
||||
<figcaption>Input depth map of the initial castle image.</figcaption>
|
||||
</figure>
|
||||
|
||||
First, download the image as well as the weights of T2I-Depth-Zoe-Adapter by calling the following commands:
|
||||
|
||||
```bash
|
||||
curl -O https://refine.rs/guides/adapting_sdxl/zoe-depth-map-german-castle.png
|
||||
python scripts/conversion/convert_diffusers_t2i_adapter.py --from "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0" --to t2i_depth_zoe_xl.safetensors --half
|
||||
```
|
||||
|
||||
Then, just inject it as usual:
|
||||
|
||||
```py
|
||||
# Load T2I-Adapter
|
||||
t2i_adapter = SDXLT2IAdapter(
|
||||
target=sdxl.unet,
|
||||
name="zoe-depth",
|
||||
weights=load_from_safetensors("t2i_depth_zoe_xl.safetensors"),
|
||||
scale=0.72,
|
||||
).inject()
|
||||
```
|
||||
|
||||
Finally, at runtime, compute the embedding of the input condition through the `t2i_adapter` object, and set its embedding calling `.set_condition_features()`:
|
||||
|
||||
```py
|
||||
from refiners.fluxion.utils import image_to_tensor, interpolate
|
||||
|
||||
image_depth_condition = Image.open("zoe-depth-map-german-castle.png")
|
||||
|
||||
with torch.no_grad():
|
||||
condition = image_to_tensor(image_depth_condition.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# Spatial dimensions should be divisible by default downscale factor (=16 for T2IAdapter ConditionEncoder)
|
||||
condition = interpolate(condition, torch.Size((1024, 1024)))
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
```
|
||||
|
||||
??? example "Expand to see the entire end-to-end code"
|
||||
|
||||
```py
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad, image_to_tensor, interpolate
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL, SDXLT2IAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||
|
||||
# Load SDXL
|
||||
sdxl = StableDiffusion_XL(device="cuda", dtype=torch.float16)
|
||||
sdxl.clip_text_encoder.load_from_safetensors("DoubleCLIPTextEncoder.safetensors")
|
||||
sdxl.unet.load_from_safetensors("sdxl-unet.safetensors")
|
||||
sdxl.lda.load_from_safetensors("sdxl-lda.safetensors")
|
||||
|
||||
# Load LoRAs weights from disk and inject them into target
|
||||
manager = SDLoraManager(sdxl)
|
||||
scifi_lora_weights = load_from_safetensors("Sci-fi_Environments_sdxl.safetensors")
|
||||
pixel_art_lora_weights = load_from_safetensors("pixel-art-xl-v1.1.safetensors")
|
||||
manager.add_multiple_loras(
|
||||
tensors={"scifi-lora": scifi_lora_weights, "pixel-art-lora": pixel_art_lora_weights},
|
||||
scale={"scifi-lora": 1.5, "pixel-art-lora": 1.55},
|
||||
)
|
||||
|
||||
# Load IP-Adapter
|
||||
ip_adapter = SDXLIPAdapter(
|
||||
target=sdxl.unet,
|
||||
weights=load_from_safetensors("ip-adapter-plus_sdxl_vit-h.safetensors"),
|
||||
scale=1.0,
|
||||
fine_grained=True, # Use fine-grained IP-Adapter (IP-Adapter Plus)
|
||||
)
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors("CLIPImageEncoderH.safetensors")
|
||||
ip_adapter.inject()
|
||||
|
||||
# Load T2I-Adapter
|
||||
t2i_adapter = SDXLT2IAdapter(
|
||||
target=sdxl.unet,
|
||||
name="zoe-depth",
|
||||
weights=load_from_safetensors("t2i_depth_zoe_xl.safetensors"),
|
||||
scale=0.72,
|
||||
).inject()
|
||||
|
||||
# Hyperparameters
|
||||
prompt = "a futuristic castle surrounded by a forest, mountains in the background"
|
||||
image_prompt = Image.open("german-castle.jpg")
|
||||
image_depth_condition = Image.open("zoe-depth-map-german-castle.png")
|
||||
seed = 42
|
||||
sdxl.set_inference_steps(50, first_step=0)
|
||||
sdxl.set_self_attention_guidance(
|
||||
enable=True, scale=0.75
|
||||
) # Enable self-attention guidance to enhance the quality of the generated images
|
||||
|
||||
with no_grad():
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt + ", best quality, high quality",
|
||||
negative_text="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
condition = image_to_tensor(image_depth_condition.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
# Spatial dimensions should be divisible by default downscale factor (=16 for T2IAdapter ConditionEncoder)
|
||||
condition = interpolate(condition, torch.Size((1024, 1024)))
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
manual_seed(seed=seed)
|
||||
x = torch.randn(size=(1, 4, 128, 128), device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
||||
# Diffusion process
|
||||
for step in sdxl.steps:
|
||||
if step % 10 == 0:
|
||||
print(f"Step {step}")
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
|
||||
predicted_image.save("scifi_pixel_IP_T2I_sdxl.png")
|
||||
|
||||
```
|
||||
|
||||
The results look convincing: the depth and proportions of the initial castle are more faithful, while preserving our *futuristic, pixel-art style*!
|
||||
<figure markdown>
|
||||
<img src="scifi_pixel_IP_T2I_sdxl.webp" alt="Generated image in sci-fi, pixel art style, using IP and T2I Adapters" width="400">
|
||||
<figcaption>Generated image in sci-fi, pixel art style, using IP and T2I Adapters.</figcaption>
|
||||
</figure>
|
||||
|
||||
## Wrap up
|
||||
|
||||
As you can see in this guide, composing Adapters on top of foundation models is pretty seamless in Refiners, allowing practitioners to quickly test out different combinations of Adapters for their needs. We encourage you to try out different ones, and even train some yourselves!
|
||||
|
||||
[^1]: Mou, C., Wang, X., Xie, L., Zhang, J., Qi, Z., Shan, Y., & Qie, X. (2023). T2i-adapter: Learning adapters to dig out more controllable ability for text-to-image diffusion models.
|
||||
|
|
BIN
docs/guides/adapting_sdxl/scifi_pixel_IP_T2I_sdxl.webp
Normal file
BIN
docs/guides/adapting_sdxl/scifi_pixel_IP_T2I_sdxl.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 132 KiB |
BIN
docs/guides/adapting_sdxl/zoe-depth-map-german-castle.png
Normal file
BIN
docs/guides/adapting_sdxl/zoe-depth-map-german-castle.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 265 KiB |
Loading…
Reference in a new issue