mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
Add support for learned concepts e.g. via textual inversion
This commit is contained in:
parent
8b1719b1f9
commit
3680f9d196
121
src/refiners/foundationals/clip/concepts.py
Normal file
121
src/refiners/foundationals/clip/concepts.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
from refiners.adapters.adapter import Adapter
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
import refiners.fluxion.layers as fl
|
||||
from typing import cast
|
||||
|
||||
from torch import Tensor, cat, zeros
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
import re
|
||||
|
||||
|
||||
class ConceptExtender:
|
||||
"""
|
||||
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
|
||||
|
||||
Example:
|
||||
import torch
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
|
||||
encoder = CLIPTextEncoderL(device="cuda")
|
||||
tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
|
||||
encoder.load_state_dict(tensors)
|
||||
|
||||
cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
|
||||
dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
|
||||
|
||||
extender = ConceptExtender(encoder)
|
||||
extender.add_concept(token="<this-cat>", embedding=cat_embedding)
|
||||
extender.inject()
|
||||
# New concepts can be added at any time
|
||||
extender.add_concept(token="<that-dog>", embedding=dog_embedding)
|
||||
|
||||
# Now the encoder can be used with the new concepts
|
||||
"""
|
||||
|
||||
def __init__(self, target: CLIPTextEncoder) -> None:
|
||||
try:
|
||||
token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
|
||||
except StopIteration:
|
||||
raise RuntimeError("TokenEncoder not found.")
|
||||
|
||||
try:
|
||||
clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Tokenizer not found.")
|
||||
|
||||
self.embedding_extender = EmbeddingExtender(token_encoder)
|
||||
self.token_extender = TokenExtender(clip_tokenizer)
|
||||
|
||||
def add_concept(self, token: str, embedding: Tensor) -> None:
|
||||
self.embedding_extender.add_embedding(embedding)
|
||||
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
|
||||
|
||||
def inject(self) -> None:
|
||||
self.embedding_extender.inject(self.token_encoder_parent)
|
||||
self.token_extender.inject(self.clip_tokenizer_parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.embedding_extender.eject()
|
||||
self.token_extender.eject()
|
||||
|
||||
|
||||
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||
old_weight: Parameter
|
||||
new_weight: Parameter
|
||||
weight: Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: TokenEncoder,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(fl.Lambda(func=self.lookup))
|
||||
self.old_weight = cast(Parameter, target.weight)
|
||||
self.new_weight = Parameter(
|
||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||
) # requires_grad=True by default
|
||||
self.weight = cat([self.old_weight, self.new_weight])
|
||||
|
||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||
def lookup(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
||||
def add_embedding(self, embedding: Tensor) -> None:
|
||||
assert embedding.shape == (self.old_weight.shape[1],)
|
||||
self.new_weight = Parameter(
|
||||
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
||||
)
|
||||
self.weight = cat([self.old_weight, self.new_weight])
|
||||
|
||||
@property
|
||||
def num_embeddings(self) -> int:
|
||||
return self.weight.shape[0]
|
||||
|
||||
|
||||
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
||||
def __init__(self, target: CLIPTokenizer) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
CLIPTokenizer(
|
||||
vocabulary_path=target.vocabulary_path,
|
||||
sequence_length=target.sequence_length,
|
||||
start_of_text_token_id=target.start_of_text_token_id,
|
||||
end_of_text_token_id=target.end_of_text_token_id,
|
||||
pad_token_id=target.pad_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
def add_token(self, token: str, token_id: int) -> None:
|
||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
assert token_id not in tokenizer.token_to_id_mapping.values()
|
||||
tokenizer.token_to_id_mapping[token] = token_id
|
||||
current_pattern = tokenizer.token_pattern.pattern
|
||||
new_pattern = token + "|" + current_pattern
|
||||
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
|
||||
# Define the keyword as its own smallest subtoken
|
||||
tokenizer.byte_pair_encoding_cache[token] = token
|
|
@ -23,3 +23,7 @@ def test_weights_path() -> Path:
|
|||
@fixture(scope="session")
|
||||
def test_e2e_path() -> Path:
|
||||
return PARENT_PATH / "e2e"
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_textual_inversion_path() -> Path:
|
||||
return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref"
|
||||
|
|
|
@ -14,6 +14,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
|
|||
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
@ -118,6 +119,16 @@ def condition_image_refonly(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def text_encoder_weights(test_weights_path: Path) -> Path:
|
||||
text_encoder_weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||
|
@ -689,3 +700,41 @@ def test_diffusion_inpainting_refonly(
|
|||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_textual_inversion_random_init(
|
||||
sd15_std: StableDiffusion_1,
|
||||
expected_image_textual_inversion_random_init: Image.Image,
|
||||
text_embedding_textual_inversion: torch.Tensor,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
|
||||
conceptExtender = ConceptExtender(sd15.clip_text_encoder)
|
||||
conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
|
||||
conceptExtender.inject()
|
||||
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat on a <gta5-artwork>"
|
||||
|
||||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
|
|
@ -15,7 +15,7 @@ from diffusers import StableDiffusionPipeline
|
|||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float32,
|
||||
).to("cuda)
|
||||
).to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
|
@ -80,3 +80,34 @@ Special cases:
|
|||
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
```
|
||||
|
||||
## Textual Inversion
|
||||
|
||||
- `expected_textual_inversion_random_init.png` has been generated with StableDiffusionPipeline, e.g.:
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float32,
|
||||
).to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.load_textual_inversion("sd-concepts-library/gta5-artwork")
|
||||
|
||||
prompt = "a cute cat on a <gta5-artwork>"
|
||||
negative_prompt = ""
|
||||
|
||||
torch.manual_seed(2)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=7.5,
|
||||
)
|
||||
|
||||
output.images[0].save("expected_textual_inversion_random_init.png")
|
||||
```
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 450 KiB |
113
tests/foundationals/clip/test_concepts.py
Normal file
113
tests/foundationals/clip/test_concepts.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import torch
|
||||
import pytest
|
||||
|
||||
from warnings import warn
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
|
||||
from diffusers import StableDiffusionPipeline # type: ignore
|
||||
import transformers # type: ignore
|
||||
|
||||
|
||||
PROMPTS = [
|
||||
"a cute cat", # a simple prompt
|
||||
"This artwork is inspired by <gta5-artwork> and uses a <cat-toy> as a prop", # prompt with two added concepts
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def our_encoder_with_new_concepts(
|
||||
test_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
cat_embedding_textual_inversion: torch.Tensor,
|
||||
gta5_artwork_embedding_textual_inversion: torch.Tensor,
|
||||
) -> CLIPTextEncoderL:
|
||||
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
encoder = CLIPTextEncoderL(device=test_device)
|
||||
tensors = load_from_safetensors(weights)
|
||||
encoder.load_state_dict(tensors)
|
||||
concept_extender = ConceptExtender(encoder)
|
||||
concept_extender.add_concept("<cat-toy>", cat_embedding_textual_inversion)
|
||||
concept_extender.add_concept("<gta5-artwork>", gta5_artwork_embedding_textual_inversion)
|
||||
concept_extender.inject()
|
||||
return encoder
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_sd15_with_new_concepts(runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device):
|
||||
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path, torch_dtype=torch.float16).to(test_device) # type: ignore
|
||||
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
|
||||
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
|
||||
return pipe
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def runwayml_weights_path(test_weights_path: Path):
|
||||
r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
|
||||
if not r.is_dir():
|
||||
warn(f"could not find RunwayML weights at {r}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_tokenizer_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTokenizer:
|
||||
return ref_sd15_with_new_concepts.tokenizer # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_encoder_with_new_concepts(ref_sd15_with_new_concepts: StableDiffusionPipeline) -> transformers.CLIPTextModel:
|
||||
return ref_sd15_with_new_concepts.text_encoder
|
||||
|
||||
|
||||
@pytest.fixture(params=PROMPTS)
|
||||
def prompt(request: pytest.FixtureRequest):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
||||
|
||||
|
||||
def test_encoder(
|
||||
prompt: str,
|
||||
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
|
||||
ref_encoder_with_new_concepts: transformers.CLIPTextModel,
|
||||
our_encoder_with_new_concepts: CLIPTextEncoderL,
|
||||
test_device: torch.device,
|
||||
):
|
||||
ref_tokens = ref_tokenizer_with_new_concepts( # type: ignore
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=ref_tokenizer_with_new_concepts.model_max_length, # type: ignore
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
assert isinstance(ref_tokens, torch.Tensor)
|
||||
tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None
|
||||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
|
||||
our_embeddings = our_encoder_with_new_concepts(prompt)
|
||||
|
||||
assert ref_embeddings.shape == (1, 77, 768)
|
||||
assert our_embeddings.shape == (1, 77, 768)
|
||||
|
||||
# See `test_encoder` in test_text_encoder.py for details about the tolerance (0.04)
|
||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.04
|
6
tests/foundationals/clip/test_concepts_ref/README.md
Normal file
6
tests/foundationals/clip/test_concepts_ref/README.md
Normal file
|
@ -0,0 +1,6 @@
|
|||
# Note about this data
|
||||
|
||||
## Textual Inversion Concepts
|
||||
|
||||
- `textual_inversion/cat-toy` comes from https://huggingface.co/sd-concepts-library/cat-toy
|
||||
- `textual_inversion/gta5-artwork` comes from https://huggingface.co/sd-concepts-library/gta5-artwork
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue