mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Add concepts learning via textual inversion
This commit is contained in:
parent
0f476ea18b
commit
9f6733de8e
63
configs/finetune-textual-inversion.toml
Normal file
63
configs/finetune-textual-inversion.toml
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
script = "finetune-ldm-textual_inversion.py" # not used for now
|
||||||
|
|
||||||
|
[wandb]
|
||||||
|
mode = "offline" # "online", "offline", "disabled"
|
||||||
|
entity = "acme"
|
||||||
|
project = "test-textual-inversion"
|
||||||
|
|
||||||
|
[models]
|
||||||
|
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
||||||
|
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
||||||
|
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
||||||
|
|
||||||
|
[latent_diffusion]
|
||||||
|
unconditional_sampling_probability = 0.05
|
||||||
|
offset_noise = 0.1
|
||||||
|
|
||||||
|
[textual_inversion]
|
||||||
|
placeholder_token = "<cat-toy>"
|
||||||
|
initializer_token = "toy"
|
||||||
|
# style_mode = true
|
||||||
|
|
||||||
|
[training]
|
||||||
|
duration = "2000:step"
|
||||||
|
seed = 0
|
||||||
|
gpu_index = 0
|
||||||
|
batch_size = 4
|
||||||
|
gradient_accumulation = "1:step"
|
||||||
|
evaluation_interval = "250:step"
|
||||||
|
evaluation_seed = 1
|
||||||
|
|
||||||
|
[optimizer]
|
||||||
|
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
||||||
|
learning_rate = 5e-4
|
||||||
|
betas = [0.9, 0.999]
|
||||||
|
eps = 1e-8
|
||||||
|
weight_decay = 1e-2
|
||||||
|
|
||||||
|
[scheduler]
|
||||||
|
scheduler_type = "ConstantLR"
|
||||||
|
update_interval = "1:step"
|
||||||
|
|
||||||
|
[dropout]
|
||||||
|
dropout_probability = 0
|
||||||
|
use_gyro_dropout = false
|
||||||
|
|
||||||
|
[dataset]
|
||||||
|
hf_repo = "acme/cat-toy"
|
||||||
|
revision = "main"
|
||||||
|
horizontal_flip = true
|
||||||
|
random_crop = true
|
||||||
|
resize_image_max_size = 512
|
||||||
|
|
||||||
|
[checkpointing]
|
||||||
|
# save_folder = "/path/to/ckpts"
|
||||||
|
save_interval = "250:step"
|
||||||
|
|
||||||
|
[test_diffusion]
|
||||||
|
num_inference_steps = 30
|
||||||
|
use_short_prompts = false
|
||||||
|
prompts = [
|
||||||
|
"<cat-toy>",
|
||||||
|
# "green grass, <cat-toy>"
|
||||||
|
]
|
164
scripts/training/finetune-ldm-textual-inversion.py
Normal file
164
scripts/training/finetune-ldm-textual-inversion.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from loguru import logger
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch import randn, Tensor
|
||||||
|
import random
|
||||||
|
|
||||||
|
from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.training_utils.latent_diffusion import (
|
||||||
|
FinetuneLatentDiffusionConfig,
|
||||||
|
TextEmbeddingLatentsBatch,
|
||||||
|
LatentDiffusionTrainer,
|
||||||
|
LatentDiffusionConfig,
|
||||||
|
TextEmbeddingLatentsDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGENET_TEMPLATES_SMALL = [
|
||||||
|
"a photo of a {}",
|
||||||
|
"a rendering of a {}",
|
||||||
|
"a cropped photo of the {}",
|
||||||
|
"the photo of a {}",
|
||||||
|
"a photo of a clean {}",
|
||||||
|
"a photo of a dirty {}",
|
||||||
|
"a dark photo of the {}",
|
||||||
|
"a photo of my {}",
|
||||||
|
"a photo of the cool {}",
|
||||||
|
"a close-up photo of a {}",
|
||||||
|
"a bright photo of the {}",
|
||||||
|
"a cropped photo of a {}",
|
||||||
|
"a photo of the {}",
|
||||||
|
"a good photo of the {}",
|
||||||
|
"a photo of one {}",
|
||||||
|
"a close-up photo of the {}",
|
||||||
|
"a rendition of the {}",
|
||||||
|
"a photo of the clean {}",
|
||||||
|
"a rendition of a {}",
|
||||||
|
"a photo of a nice {}",
|
||||||
|
"a good photo of a {}",
|
||||||
|
"a photo of the nice {}",
|
||||||
|
"a photo of the small {}",
|
||||||
|
"a photo of the weird {}",
|
||||||
|
"a photo of the large {}",
|
||||||
|
"a photo of a cool {}",
|
||||||
|
"a photo of a small {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
IMAGENET_STYLE_TEMPLATES_SMALL = [
|
||||||
|
"a painting in the style of {}",
|
||||||
|
"a rendering in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"the painting in the style of {}",
|
||||||
|
"a clean painting in the style of {}",
|
||||||
|
"a dirty painting in the style of {}",
|
||||||
|
"a dark painting in the style of {}",
|
||||||
|
"a picture in the style of {}",
|
||||||
|
"a cool painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a bright painting in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"a good painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a rendition in the style of {}",
|
||||||
|
"a nice painting in the style of {}",
|
||||||
|
"a small painting in the style of {}",
|
||||||
|
"a weird painting in the style of {}",
|
||||||
|
"a large painting in the style of {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionDataset(TextEmbeddingLatentsDataset):
|
||||||
|
templates: list[str] = []
|
||||||
|
placeholder_token: str = ""
|
||||||
|
|
||||||
|
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
|
||||||
|
super().__init__(trainer)
|
||||||
|
self.templates = (
|
||||||
|
IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
|
||||||
|
)
|
||||||
|
self.placeholder_token = self.config.textual_inversion.placeholder_token
|
||||||
|
|
||||||
|
def get_caption(self, index: int) -> str:
|
||||||
|
# Ignore the dataset caption, if any: use a template instead
|
||||||
|
return random.choice(self.templates).format(self.placeholder_token)
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionConfig(BaseModel):
|
||||||
|
# The new token to be learned
|
||||||
|
placeholder_token: str = "*"
|
||||||
|
# The token to be used as initializer; if None, a random vector is used
|
||||||
|
initializer_token: str | None = None
|
||||||
|
style_mode: bool = False
|
||||||
|
|
||||||
|
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
|
||||||
|
adapter = ConceptExtender(target=text_encoder)
|
||||||
|
tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None, "Tokenizer not found in text encoder."
|
||||||
|
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||||
|
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||||
|
if self.initializer_token is not None:
|
||||||
|
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
|
||||||
|
assert " " not in bpe, "This initializer_token is not a single token."
|
||||||
|
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
|
||||||
|
init_embedding = token_encoder(token).squeeze(0)
|
||||||
|
else:
|
||||||
|
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||||
|
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||||
|
init_embedding = randn(token_encoder.embedding_dim)
|
||||||
|
adapter.add_concept(self.placeholder_token, init_embedding)
|
||||||
|
adapter.inject()
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||||
|
latent_diffusion: LatentDiffusionConfig
|
||||||
|
textual_inversion: TextualInversionConfig
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
# Pydantic v2 does post init differently, so we need to override this method too.
|
||||||
|
logger.info("Freezing models to train only the new embedding.")
|
||||||
|
self.models["unet"].train = False
|
||||||
|
self.models["text_encoder"].train = False
|
||||||
|
self.models["lda"].train = False
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TextualInversionLatentDiffusionConfig,
|
||||||
|
callbacks: "list[Callback[Any]] | None" = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(config=config, callbacks=callbacks)
|
||||||
|
self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))
|
||||||
|
|
||||||
|
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||||
|
return TextualInversionDataset(trainer=self)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||||
|
def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||||
|
trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||||
|
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||||
|
embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
|
||||||
|
assert embedding_extender is not None, "Embedding extender not found in text encoder."
|
||||||
|
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
|
||||||
|
|
||||||
|
save_to_safetensors(
|
||||||
|
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
config_path = sys.argv[1]
|
||||||
|
config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
|
||||||
|
trainer = TextualInversionLatentDiffusionTrainer(config=config)
|
||||||
|
trainer.train()
|
|
@ -71,7 +71,6 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
|
||||||
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
old_weight: Parameter
|
old_weight: Parameter
|
||||||
new_weight: Parameter
|
new_weight: Parameter
|
||||||
weight: Tensor
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -83,22 +82,21 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
self.new_weight = Parameter(
|
self.new_weight = Parameter(
|
||||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||||
) # requires_grad=True by default
|
) # requires_grad=True by default
|
||||||
self.weight = cat([self.old_weight, self.new_weight])
|
|
||||||
|
|
||||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||||
def lookup(self, x: Tensor) -> Tensor:
|
def lookup(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
# Concatenate old and new weights for dynamic embedding updates during training
|
||||||
|
return F.embedding(x, cat([self.old_weight, self.new_weight]))
|
||||||
|
|
||||||
def add_embedding(self, embedding: Tensor) -> None:
|
def add_embedding(self, embedding: Tensor) -> None:
|
||||||
assert embedding.shape == (self.old_weight.shape[1],)
|
assert embedding.shape == (self.old_weight.shape[1],)
|
||||||
self.new_weight = Parameter(
|
self.new_weight = Parameter(
|
||||||
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
||||||
)
|
)
|
||||||
self.weight = cat([self.old_weight, self.new_weight])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_embeddings(self) -> int:
|
def num_embeddings(self) -> int:
|
||||||
return self.weight.shape[0]
|
return self.old_weight.shape[0] + self.new_weight.shape[0]
|
||||||
|
|
||||||
|
|
||||||
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
||||||
|
@ -115,12 +113,13 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_token(self, token: str, token_id: int) -> None:
|
def add_token(self, token: str, token_id: int) -> None:
|
||||||
|
token = token.lower()
|
||||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||||
assert tokenizer is not None, "Tokenizer not found."
|
assert tokenizer is not None, "Tokenizer not found."
|
||||||
assert token_id not in tokenizer.token_to_id_mapping.values()
|
assert token_id not in tokenizer.token_to_id_mapping.values()
|
||||||
tokenizer.token_to_id_mapping[token] = token_id
|
tokenizer.token_to_id_mapping[token] = token_id
|
||||||
current_pattern = tokenizer.token_pattern.pattern
|
current_pattern = tokenizer.token_pattern.pattern
|
||||||
new_pattern = token + "|" + current_pattern
|
new_pattern = re.escape(token) + "|" + current_pattern
|
||||||
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
|
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
|
||||||
# Define the keyword as its own smallest subtoken
|
# Define the keyword as its own smallest subtoken
|
||||||
tokenizer.byte_pair_encoding_cache[token] = token
|
tokenizer.byte_pair_encoding_cache[token] = token
|
||||||
|
|
|
@ -95,9 +95,15 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
||||||
def process_caption(self, caption: str) -> str:
|
def process_caption(self, caption: str) -> str:
|
||||||
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
|
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
|
||||||
|
|
||||||
|
def get_caption(self, index: int) -> str:
|
||||||
|
return self.dataset[index]["caption"]
|
||||||
|
|
||||||
|
def get_image(self, index: int) -> Image.Image:
|
||||||
|
return self.dataset[index]["image"]
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
|
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
|
||||||
item = self.dataset[index]
|
caption = self.get_caption(index=index)
|
||||||
caption, image = item["caption"], item["image"]
|
image = self.get_image(index=index)
|
||||||
resized_image = self.resize_image(
|
resized_image = self.resize_image(
|
||||||
image=image,
|
image=image,
|
||||||
min_size=self.config.dataset.resize_image_min_size,
|
min_size=self.config.dataset.resize_image_min_size,
|
||||||
|
|
|
@ -4,10 +4,11 @@ import pytest
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.fluxion.utils import load_from_safetensors
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
from diffusers import StableDiffusionPipeline # type: ignore
|
from diffusers import StableDiffusionPipeline # type: ignore
|
||||||
import transformers # type: ignore
|
import transformers # type: ignore
|
||||||
|
@ -84,6 +85,28 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
|
||||||
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokenizer_with_special_character():
|
||||||
|
clip_tokenizer = fl.Chain(CLIPTokenizer())
|
||||||
|
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
|
||||||
|
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
|
||||||
|
token_extender.add_token("*", new_token_id)
|
||||||
|
token_extender.inject(clip_tokenizer)
|
||||||
|
|
||||||
|
adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
|
||||||
|
assert adapted_clip_tokenizer is not None
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
adapted_clip_tokenizer.encode("*"),
|
||||||
|
torch.Tensor(
|
||||||
|
[
|
||||||
|
adapted_clip_tokenizer.start_of_text_token_id,
|
||||||
|
new_token_id,
|
||||||
|
adapted_clip_tokenizer.end_of_text_token_id,
|
||||||
|
]
|
||||||
|
).to(torch.int64),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_encoder(
|
def test_encoder(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
|
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
|
||||||
|
|
Loading…
Reference in a new issue