mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
Add concepts learning via textual inversion
This commit is contained in:
parent
0f476ea18b
commit
9f6733de8e
63
configs/finetune-textual-inversion.toml
Normal file
63
configs/finetune-textual-inversion.toml
Normal file
|
@ -0,0 +1,63 @@
|
|||
script = "finetune-ldm-textual_inversion.py" # not used for now
|
||||
|
||||
[wandb]
|
||||
mode = "offline" # "online", "offline", "disabled"
|
||||
entity = "acme"
|
||||
project = "test-textual-inversion"
|
||||
|
||||
[models]
|
||||
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
||||
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
||||
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
||||
|
||||
[latent_diffusion]
|
||||
unconditional_sampling_probability = 0.05
|
||||
offset_noise = 0.1
|
||||
|
||||
[textual_inversion]
|
||||
placeholder_token = "<cat-toy>"
|
||||
initializer_token = "toy"
|
||||
# style_mode = true
|
||||
|
||||
[training]
|
||||
duration = "2000:step"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "1:step"
|
||||
evaluation_interval = "250:step"
|
||||
evaluation_seed = 1
|
||||
|
||||
[optimizer]
|
||||
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
||||
learning_rate = 5e-4
|
||||
betas = [0.9, 0.999]
|
||||
eps = 1e-8
|
||||
weight_decay = 1e-2
|
||||
|
||||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0
|
||||
use_gyro_dropout = false
|
||||
|
||||
[dataset]
|
||||
hf_repo = "acme/cat-toy"
|
||||
revision = "main"
|
||||
horizontal_flip = true
|
||||
random_crop = true
|
||||
resize_image_max_size = 512
|
||||
|
||||
[checkpointing]
|
||||
# save_folder = "/path/to/ckpts"
|
||||
save_interval = "250:step"
|
||||
|
||||
[test_diffusion]
|
||||
num_inference_steps = 30
|
||||
use_short_prompts = false
|
||||
prompts = [
|
||||
"<cat-toy>",
|
||||
# "green grass, <cat-toy>"
|
||||
]
|
164
scripts/training/finetune-ldm-textual-inversion.py
Normal file
164
scripts/training/finetune-ldm-textual-inversion.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from torch.utils.data import Dataset
|
||||
from torch import randn, Tensor
|
||||
import random
|
||||
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.latent_diffusion import (
|
||||
FinetuneLatentDiffusionConfig,
|
||||
TextEmbeddingLatentsBatch,
|
||||
LatentDiffusionTrainer,
|
||||
LatentDiffusionConfig,
|
||||
TextEmbeddingLatentsDataset,
|
||||
)
|
||||
|
||||
|
||||
IMAGENET_TEMPLATES_SMALL = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
IMAGENET_STYLE_TEMPLATES_SMALL = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(TextEmbeddingLatentsDataset):
|
||||
templates: list[str] = []
|
||||
placeholder_token: str = ""
|
||||
|
||||
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
|
||||
super().__init__(trainer)
|
||||
self.templates = (
|
||||
IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
|
||||
)
|
||||
self.placeholder_token = self.config.textual_inversion.placeholder_token
|
||||
|
||||
def get_caption(self, index: int) -> str:
|
||||
# Ignore the dataset caption, if any: use a template instead
|
||||
return random.choice(self.templates).format(self.placeholder_token)
|
||||
|
||||
|
||||
class TextualInversionConfig(BaseModel):
|
||||
# The new token to be learned
|
||||
placeholder_token: str = "*"
|
||||
# The token to be used as initializer; if None, a random vector is used
|
||||
initializer_token: str | None = None
|
||||
style_mode: bool = False
|
||||
|
||||
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
|
||||
adapter = ConceptExtender(target=text_encoder)
|
||||
tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found in text encoder."
|
||||
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||
if self.initializer_token is not None:
|
||||
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
|
||||
assert " " not in bpe, "This initializer_token is not a single token."
|
||||
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
|
||||
init_embedding = token_encoder(token).squeeze(0)
|
||||
else:
|
||||
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||
init_embedding = randn(token_encoder.embedding_dim)
|
||||
adapter.add_concept(self.placeholder_token, init_embedding)
|
||||
adapter.inject()
|
||||
|
||||
|
||||
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
textual_inversion: TextualInversionConfig
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
# Pydantic v2 does post init differently, so we need to override this method too.
|
||||
logger.info("Freezing models to train only the new embedding.")
|
||||
self.models["unet"].train = False
|
||||
self.models["text_encoder"].train = False
|
||||
self.models["lda"].train = False
|
||||
|
||||
|
||||
class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
|
||||
def __init__(
|
||||
self,
|
||||
config: TextualInversionLatentDiffusionConfig,
|
||||
callbacks: "list[Callback[Any]] | None" = None,
|
||||
) -> None:
|
||||
super().__init__(config=config, callbacks=callbacks)
|
||||
self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))
|
||||
|
||||
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||
return TextualInversionDataset(trainer=self)
|
||||
|
||||
|
||||
class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||
def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||
trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
|
||||
|
||||
|
||||
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||
embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
|
||||
assert embedding_extender is not None, "Embedding extender not found in text encoder."
|
||||
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
|
||||
|
||||
save_to_safetensors(
|
||||
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
config_path = sys.argv[1]
|
||||
config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
|
||||
trainer = TextualInversionLatentDiffusionTrainer(config=config)
|
||||
trainer.train()
|
|
@ -71,7 +71,6 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
|
|||
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||
old_weight: Parameter
|
||||
new_weight: Parameter
|
||||
weight: Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -83,22 +82,21 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
|||
self.new_weight = Parameter(
|
||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||
) # requires_grad=True by default
|
||||
self.weight = cat([self.old_weight, self.new_weight])
|
||||
|
||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||
def lookup(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
# Concatenate old and new weights for dynamic embedding updates during training
|
||||
return F.embedding(x, cat([self.old_weight, self.new_weight]))
|
||||
|
||||
def add_embedding(self, embedding: Tensor) -> None:
|
||||
assert embedding.shape == (self.old_weight.shape[1],)
|
||||
self.new_weight = Parameter(
|
||||
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
||||
)
|
||||
self.weight = cat([self.old_weight, self.new_weight])
|
||||
|
||||
@property
|
||||
def num_embeddings(self) -> int:
|
||||
return self.weight.shape[0]
|
||||
return self.old_weight.shape[0] + self.new_weight.shape[0]
|
||||
|
||||
|
||||
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
||||
|
@ -115,12 +113,13 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
|||
)
|
||||
|
||||
def add_token(self, token: str, token_id: int) -> None:
|
||||
token = token.lower()
|
||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
assert token_id not in tokenizer.token_to_id_mapping.values()
|
||||
tokenizer.token_to_id_mapping[token] = token_id
|
||||
current_pattern = tokenizer.token_pattern.pattern
|
||||
new_pattern = token + "|" + current_pattern
|
||||
new_pattern = re.escape(token) + "|" + current_pattern
|
||||
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
|
||||
# Define the keyword as its own smallest subtoken
|
||||
tokenizer.byte_pair_encoding_cache[token] = token
|
||||
|
|
|
@ -95,9 +95,15 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
|||
def process_caption(self, caption: str) -> str:
|
||||
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
|
||||
|
||||
def get_caption(self, index: int) -> str:
|
||||
return self.dataset[index]["caption"]
|
||||
|
||||
def get_image(self, index: int) -> Image.Image:
|
||||
return self.dataset[index]["image"]
|
||||
|
||||
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
|
||||
item = self.dataset[index]
|
||||
caption, image = item["caption"], item["image"]
|
||||
caption = self.get_caption(index=index)
|
||||
image = self.get_image(index=index)
|
||||
resized_image = self.resize_image(
|
||||
image=image,
|
||||
min_size=self.config.dataset.resize_image_min_size,
|
||||
|
|
|
@ -4,10 +4,11 @@ import pytest
|
|||
from warnings import warn
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
from diffusers import StableDiffusionPipeline # type: ignore
|
||||
import transformers # type: ignore
|
||||
|
@ -84,6 +85,28 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
|
|||
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
||||
|
||||
|
||||
def test_tokenizer_with_special_character():
|
||||
clip_tokenizer = fl.Chain(CLIPTokenizer())
|
||||
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
|
||||
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
|
||||
token_extender.add_token("*", new_token_id)
|
||||
token_extender.inject(clip_tokenizer)
|
||||
|
||||
adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
|
||||
assert adapted_clip_tokenizer is not None
|
||||
|
||||
assert torch.allclose(
|
||||
adapted_clip_tokenizer.encode("*"),
|
||||
torch.Tensor(
|
||||
[
|
||||
adapted_clip_tokenizer.start_of_text_token_id,
|
||||
new_token_id,
|
||||
adapted_clip_tokenizer.end_of_text_token_id,
|
||||
]
|
||||
).to(torch.int64),
|
||||
)
|
||||
|
||||
|
||||
def test_encoder(
|
||||
prompt: str,
|
||||
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
|
||||
|
|
Loading…
Reference in a new issue