mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
Add support for learned concepts e.g. via textual inversion
This commit is contained in:
parent
8b1719b1f9
commit
3680f9d196
121
src/refiners/foundationals/clip/concepts.py
Normal file
121
src/refiners/foundationals/clip/concepts.py
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
from refiners.adapters.adapter import Adapter
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from torch import Tensor, cat, zeros
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import Parameter
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class ConceptExtender:
|
||||||
|
"""
|
||||||
|
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
import torch
|
||||||
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
|
||||||
|
encoder = CLIPTextEncoderL(device="cuda")
|
||||||
|
tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
|
||||||
|
encoder.load_state_dict(tensors)
|
||||||
|
|
||||||
|
cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
|
||||||
|
dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
|
||||||
|
|
||||||
|
extender = ConceptExtender(encoder)
|
||||||
|
extender.add_concept(token="<this-cat>", embedding=cat_embedding)
|
||||||
|
extender.inject()
|
||||||
|
# New concepts can be added at any time
|
||||||
|
extender.add_concept(token="<that-dog>", embedding=dog_embedding)
|
||||||
|
|
||||||
|
# Now the encoder can be used with the new concepts
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, target: CLIPTextEncoder) -> None:
|
||||||
|
try:
|
||||||
|
token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
|
||||||
|
except StopIteration:
|
||||||
|
raise RuntimeError("TokenEncoder not found.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
|
||||||
|
except StopIteration:
|
||||||
|
raise RuntimeError("Tokenizer not found.")
|
||||||
|
|
||||||
|
self.embedding_extender = EmbeddingExtender(token_encoder)
|
||||||
|
self.token_extender = TokenExtender(clip_tokenizer)
|
||||||
|
|
||||||
|
def add_concept(self, token: str, embedding: Tensor) -> None:
|
||||||
|
self.embedding_extender.add_embedding(embedding)
|
||||||
|
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
|
||||||
|
|
||||||
|
def inject(self) -> None:
|
||||||
|
self.embedding_extender.inject(self.token_encoder_parent)
|
||||||
|
self.token_extender.inject(self.clip_tokenizer_parent)
|
||||||
|
|
||||||
|
def eject(self) -> None:
|
||||||
|
self.embedding_extender.eject()
|
||||||
|
self.token_extender.eject()
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
|
old_weight: Parameter
|
||||||
|
new_weight: Parameter
|
||||||
|
weight: Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: TokenEncoder,
|
||||||
|
) -> None:
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(fl.Lambda(func=self.lookup))
|
||||||
|
self.old_weight = cast(Parameter, target.weight)
|
||||||
|
self.new_weight = Parameter(
|
||||||
|
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||||
|
) # requires_grad=True by default
|
||||||
|
self.weight = cat([self.old_weight, self.new_weight])
|
||||||
|
|
||||||
|
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||||
|
def lookup(self, x: Tensor) -> Tensor:
|
||||||
|
return F.embedding(x, self.weight)
|
||||||
|
|
||||||
|
def add_embedding(self, embedding: Tensor) -> None:
|
||||||
|
assert embedding.shape == (self.old_weight.shape[1],)
|
||||||
|
self.new_weight = Parameter(
|
||||||
|
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
||||||
|
)
|
||||||
|
self.weight = cat([self.old_weight, self.new_weight])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_embeddings(self) -> int:
|
||||||
|
return self.weight.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
||||||
|
def __init__(self, target: CLIPTokenizer) -> None:
|
||||||
|
with self.setup_adapter(target):
|
||||||
|
super().__init__(
|
||||||
|
CLIPTokenizer(
|
||||||
|
vocabulary_path=target.vocabulary_path,
|
||||||
|
sequence_length=target.sequence_length,
|
||||||
|
start_of_text_token_id=target.start_of_text_token_id,
|
||||||
|
end_of_text_token_id=target.end_of_text_token_id,
|
||||||
|
pad_token_id=target.pad_token_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_token(self, token: str, token_id: int) -> None:
|
||||||
|
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None, "Tokenizer not found."
|
||||||
|
assert token_id not in tokenizer.token_to_id_mapping.values()
|
||||||
|
tokenizer.token_to_id_mapping[token] = token_id
|
||||||
|
current_pattern = tokenizer.token_pattern.pattern
|
||||||
|
new_pattern = token + "|" + current_pattern
|
||||||
|
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
|
||||||
|
# Define the keyword as its own smallest subtoken
|
||||||
|
tokenizer.byte_pair_encoding_cache[token] = token
|
|
@ -23,3 +23,7 @@ def test_weights_path() -> Path:
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def test_e2e_path() -> Path:
|
def test_e2e_path() -> Path:
|
||||||
return PARENT_PATH / "e2e"
|
return PARENT_PATH / "e2e"
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_textual_inversion_path() -> Path:
|
||||||
|
return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref"
|
||||||
|
|
|
@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||||
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
@ -118,6 +119,16 @@ def condition_image_refonly(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||||
|
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def text_encoder_weights(test_weights_path: Path) -> Path:
|
def text_encoder_weights(test_weights_path: Path) -> Path:
|
||||||
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||||
|
@ -689,3 +700,41 @@ def test_diffusion_inpainting_refonly(
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_textual_inversion_random_init(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
expected_image_textual_inversion_random_init: Image.Image,
|
||||||
|
text_embedding_textual_inversion: torch.Tensor,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
|
||||||
|
conceptExtender = ConceptExtender(sd15.clip_text_encoder)
|
||||||
|
conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
|
||||||
|
conceptExtender.inject()
|
||||||
|
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a cute cat on a <gta5-artwork>"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from diffusers import StableDiffusionPipeline
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
"runwayml/stable-diffusion-v1-5",
|
"runwayml/stable-diffusion-v1-5",
|
||||||
torch_dtype=torch.float32,
|
torch_dtype=torch.float32,
|
||||||
).to("cuda)
|
).to("cuda")
|
||||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
||||||
prompt = "a cute cat, detailed high-quality professional image"
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
@ -80,3 +80,34 @@ Special cases:
|
||||||
|
|
||||||
init_latents = self.vae.config.scaling_factor * init_latents
|
init_latents = self.vae.config.scaling_factor * init_latents
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Textual Inversion
|
||||||
|
|
||||||
|
- `expected_textual_inversion_random_init.png` has been generated with StableDiffusionPipeline, e.g.:
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import DPMSolverMultistepScheduler
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"runwayml/stable-diffusion-v1-5",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
).to("cuda")
|
||||||
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||||
|
pipe.load_textual_inversion("sd-concepts-library/gta5-artwork")
|
||||||
|
|
||||||
|
prompt = "a cute cat on a <gta5-artwork>"
|
||||||
|
negative_prompt = ""
|
||||||
|
|
||||||
|
torch.manual_seed(2)
|
||||||
|
output = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
num_inference_steps=30,
|
||||||
|
guidance_scale=7.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
output.images[0].save("expected_textual_inversion_random_init.png")
|
||||||
|
```
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 450 KiB |
113
tests/foundationals/clip/test_concepts.py
Normal file
113
tests/foundationals/clip/test_concepts.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from warnings import warn
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionPipeline # type: ignore
|
||||||
|
import transformers # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
PROMPTS = [
|
||||||
|
"a cute cat", # a simple prompt
|
||||||
|
"This artwork is inspired by <gta5-artwork> and uses a <cat-toy> as a prop", # prompt with two added concepts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def our_encoder_with_new_concepts(
|
||||||
|
test_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
cat_embedding_textual_inversion: torch.Tensor,
|
||||||
|
gta5_artwork_embedding_textual_inversion: torch.Tensor,
|
||||||
|
) -> CLIPTextEncoderL:
|
||||||
|
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
encoder = CLIPTextEncoderL(device=test_device)
|
||||||
|
tensors = load_from_safetensors(weights)
|
||||||
|
encoder.load_state_dict(tensors)
|
||||||
|
concept_extender = ConceptExtender(encoder)
|
||||||
|
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
|
||||||
|
concept_extender.add_concept("<gta5-artwork>", gta5_artwork_embedding_textual_inversion)
|
||||||
|
concept_extender.inject()
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_sd15_with_new_concepts(runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device):
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path, torch_dtype=torch.float16).to(test_device) # type: ignore
|
||||||
|
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
|
||||||
|
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def runwayml_weights_path(test_weights_path: Path):
|
||||||
|
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
|
||||||
|
if not r.is_dir():
|
||||||
|
warn(f"could not find RunwayML weights at {r}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
|
||||||
|
return ref_sd15_with_new_concepts.tokenizer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_encoder_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTextModel:
|
||||||
|
return ref_sd15_with_new_concepts.text_encoder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=PROMPTS)
|
||||||
|
def prompt(request: pytest.FixtureRequest):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||||
|
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||||
|
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder(
|
||||||
|
prompt: str,
|
||||||
|
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
|
||||||
|
ref_encoder_with_new_concepts: transformers.CLIPTextModel,
|
||||||
|
our_encoder_with_new_concepts: CLIPTextEncoderL,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
ref_tokens = ref_tokenizer_with_new_concepts( # type: ignore
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=ref_tokenizer_with_new_concepts.model_max_length, # type: ignore
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_ids
|
||||||
|
assert isinstance(ref_tokens, torch.Tensor)
|
||||||
|
tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None
|
||||||
|
our_tokens = tokenizer(prompt)
|
||||||
|
assert torch.equal(our_tokens, ref_tokens)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
|
||||||
|
our_embeddings = our_encoder_with_new_concepts(prompt)
|
||||||
|
|
||||||
|
assert ref_embeddings.shape == (1, 77, 768)
|
||||||
|
assert our_embeddings.shape == (1, 77, 768)
|
||||||
|
|
||||||
|
# See `test_encoder` in test_text_encoder.py for details about the tolerance (0.04)
|
||||||
|
assert (our_embeddings - ref_embeddings).abs().max() < 0.04
|
6
tests/foundationals/clip/test_concepts_ref/README.md
Normal file
6
tests/foundationals/clip/test_concepts_ref/README.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
# Note about this data
|
||||||
|
|
||||||
|
## Textual Inversion Concepts
|
||||||
|
|
||||||
|
- `textual_inversion/cat-toy` comes from https://huggingface.co/sd-concepts-library/cat-toy
|
||||||
|
- `textual_inversion/gta5-artwork` comes from https://huggingface.co/sd-concepts-library/gta5-artwork
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue