Add support for learned concepts e.g. via textual inversion

This commit is contained in:
Doryan Kaced 2023-08-25 20:27:29 +02:00 committed by Doryan Kaced
parent 8b1719b1f9
commit 3680f9d196
9 changed files with 325 additions and 1 deletions

View file

@ -0,0 +1,121 @@
from refiners.adapters.adapter import Adapter
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
import refiners.fluxion.layers as fl
from typing import cast
from torch import Tensor, cat, zeros
import torch.nn.functional as F
from torch.nn import Parameter
import re
class ConceptExtender:
"""
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
Example:
import torch
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.fluxion.utils import load_from_safetensors
encoder = CLIPTextEncoderL(device="cuda")
tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
encoder.load_state_dict(tensors)
cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
extender = ConceptExtender(encoder)
extender.add_concept(token="<this-cat>", embedding=cat_embedding)
extender.inject()
# New concepts can be added at any time
extender.add_concept(token="<that-dog>", embedding=dog_embedding)
# Now the encoder can be used with the new concepts
"""
def __init__(self, target: CLIPTextEncoder) -> None:
try:
token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
except StopIteration:
raise RuntimeError("TokenEncoder not found.")
try:
clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
except StopIteration:
raise RuntimeError("Tokenizer not found.")
self.embedding_extender = EmbeddingExtender(token_encoder)
self.token_extender = TokenExtender(clip_tokenizer)
def add_concept(self, token: str, embedding: Tensor) -> None:
self.embedding_extender.add_embedding(embedding)
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
def inject(self) -> None:
self.embedding_extender.inject(self.token_encoder_parent)
self.token_extender.inject(self.clip_tokenizer_parent)
def eject(self) -> None:
self.embedding_extender.eject()
self.token_extender.eject()
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
old_weight: Parameter
new_weight: Parameter
weight: Tensor
def __init__(
self,
target: TokenEncoder,
) -> None:
with self.setup_adapter(target):
super().__init__(fl.Lambda(func=self.lookup))
self.old_weight = cast(Parameter, target.weight)
self.new_weight = Parameter(
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default
self.weight = cat([self.old_weight, self.new_weight])
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)
def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],)
self.new_weight = Parameter(
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
)
self.weight = cat([self.old_weight, self.new_weight])
@property
def num_embeddings(self) -> int:
return self.weight.shape[0]
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
def __init__(self, target: CLIPTokenizer) -> None:
with self.setup_adapter(target):
super().__init__(
CLIPTokenizer(
vocabulary_path=target.vocabulary_path,
sequence_length=target.sequence_length,
start_of_text_token_id=target.start_of_text_token_id,
end_of_text_token_id=target.end_of_text_token_id,
pad_token_id=target.pad_token_id,
)
)
def add_token(self, token: str, token_id: int) -> None:
tokenizer = self.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
assert token_id not in tokenizer.token_to_id_mapping.values()
tokenizer.token_to_id_mapping[token] = token_id
current_pattern = tokenizer.token_pattern.pattern
new_pattern = token + "|" + current_pattern
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
# Define the keyword as its own smallest subtoken
tokenizer.byte_pair_encoding_cache[token] = token

View file

@ -23,3 +23,7 @@ def test_weights_path() -> Path:
@fixture(scope="session")
def test_e2e_path() -> Path:
return PARENT_PATH / "e2e"
@fixture(scope="session")
def test_textual_inversion_path() -> Path:
return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref"

View file

@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
from refiners.foundationals.clip.concepts import ConceptExtender
from tests.utils import ensure_similar_images
@ -118,6 +119,16 @@ def condition_image_refonly(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
@pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
@pytest.fixture(scope="module")
def text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
@ -689,3 +700,41 @@ def test_diffusion_inpainting_refonly(
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
@torch.no_grad()
def test_diffusion_textual_inversion_random_init(
sd15_std: StableDiffusion_1,
expected_image_textual_inversion_random_init: Image.Image,
text_embedding_textual_inversion: torch.Tensor,
test_device: torch.device,
):
sd15 = sd15_std
conceptExtender = ConceptExtender(sd15.clip_text_encoder)
conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
conceptExtender.inject()
n_steps = 30
prompt = "a cute cat on a <gta5-artwork>"
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)

View file

@ -15,7 +15,7 @@ from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float32,
).to("cuda)
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
prompt = "a cute cat, detailed high-quality professional image"
@ -80,3 +80,34 @@ Special cases:
init_latents = self.vae.config.scaling_factor * init_latents
```
## Textual Inversion
- `expected_textual_inversion_random_init.png` has been generated with StableDiffusionPipeline, e.g.:
```py
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float32,
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.load_textual_inversion("sd-concepts-library/gta5-artwork")
prompt = "a cute cat on a <gta5-artwork>"
negative_prompt = ""
torch.manual_seed(2)
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=30,
guidance_scale=7.5,
)
output.images[0].save("expected_textual_inversion_random_init.png")
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 450 KiB

View file

@ -0,0 +1,113 @@
import torch
import pytest
from warnings import warn
from pathlib import Path
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.fluxion.utils import load_from_safetensors
from diffusers import StableDiffusionPipeline # type: ignore
import transformers # type: ignore
PROMPTS = [
"a cute cat", # a simple prompt
"This artwork is inspired by <gta5-artwork> and uses a <cat-toy> as a prop", # prompt with two added concepts
]
@pytest.fixture(scope="module")
def our_encoder_with_new_concepts(
test_weights_path: Path,
test_device: torch.device,
cat_embedding_textual_inversion: torch.Tensor,
gta5_artwork_embedding_textual_inversion: torch.Tensor,
) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPTextEncoderL(device=test_device)
tensors = load_from_safetensors(weights)
encoder.load_state_dict(tensors)
concept_extender = ConceptExtender(encoder)
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
concept_extender.add_concept("<gta5-artwork>", gta5_artwork_embedding_textual_inversion)
concept_extender.inject()
return encoder
@pytest.fixture(scope="module")
def ref_sd15_with_new_concepts(runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device):
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path, torch_dtype=torch.float16).to(test_device) # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
return pipe
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path):
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
if not r.is_dir():
warn(f"could not find RunwayML weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
return ref_sd15_with_new_concepts.tokenizer # type: ignore
@pytest.fixture(scope="module")
def ref_encoder_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTextModel:
return ref_sd15_with_new_concepts.text_encoder
@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest):
return request.param
@pytest.fixture(scope="module")
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
@pytest.fixture(scope="module")
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
def test_encoder(
prompt: str,
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
ref_encoder_with_new_concepts: transformers.CLIPTextModel,
our_encoder_with_new_concepts: CLIPTextEncoderL,
test_device: torch.device,
):
ref_tokens = ref_tokenizer_with_new_concepts( # type: ignore
prompt,
padding="max_length",
max_length=ref_tokenizer_with_new_concepts.model_max_length, # type: ignore
truncation=True,
return_tensors="pt",
).input_ids
assert isinstance(ref_tokens, torch.Tensor)
tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
assert tokenizer is not None
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
with torch.no_grad():
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder_with_new_concepts(prompt)
assert ref_embeddings.shape == (1, 77, 768)
assert our_embeddings.shape == (1, 77, 768)
# See `test_encoder` in test_text_encoder.py for details about the tolerance (0.04)
assert (our_embeddings - ref_embeddings).abs().max() < 0.04

View file

@ -0,0 +1,6 @@
# Note about this data
## Textual Inversion Concepts
- `textual_inversion/cat-toy` comes from https://huggingface.co/sd-concepts-library/cat-toy
- `textual_inversion/gta5-artwork` comes from https://huggingface.co/sd-concepts-library/gta5-artwork