diff --git a/configs/finetune-textual-inversion.toml b/configs/finetune-textual-inversion.toml new file mode 100644 index 0000000..90597a8 --- /dev/null +++ b/configs/finetune-textual-inversion.toml @@ -0,0 +1,63 @@ +script = "finetune-ldm-textual_inversion.py" # not used for now + +[wandb] +mode = "offline" # "online", "offline", "disabled" +entity = "acme" +project = "test-textual-inversion" + +[models] +unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"} +text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"} +lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"} + +[latent_diffusion] +unconditional_sampling_probability = 0.05 +offset_noise = 0.1 + +[textual_inversion] +placeholder_token = "" +initializer_token = "toy" +# style_mode = true + +[training] +duration = "2000:step" +seed = 0 +gpu_index = 0 +batch_size = 4 +gradient_accumulation = "1:step" +evaluation_interval = "250:step" +evaluation_seed = 1 + +[optimizer] +optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" +learning_rate = 5e-4 +betas = [0.9, 0.999] +eps = 1e-8 +weight_decay = 1e-2 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" + +[dropout] +dropout_probability = 0 +use_gyro_dropout = false + +[dataset] +hf_repo = "acme/cat-toy" +revision = "main" +horizontal_flip = true +random_crop = true +resize_image_max_size = 512 + +[checkpointing] +# save_folder = "/path/to/ckpts" +save_interval = "250:step" + +[test_diffusion] +num_inference_steps = 30 +use_short_prompts = false +prompts = [ + "", + # "green grass, " +] diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py new file mode 100644 index 0000000..cf7d3f5 --- /dev/null +++ b/scripts/training/finetune-ldm-textual-inversion.py @@ -0,0 +1,164 @@ +from typing import Any +from pydantic import BaseModel +from loguru import logger +from torch.utils.data import Dataset +from torch import randn, Tensor +import random + +from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender +from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder +from refiners.foundationals.clip.tokenizer import CLIPTokenizer +from refiners.fluxion.utils import save_to_safetensors +from refiners.training_utils.callback import Callback +from refiners.training_utils.latent_diffusion import ( + FinetuneLatentDiffusionConfig, + TextEmbeddingLatentsBatch, + LatentDiffusionTrainer, + LatentDiffusionConfig, + TextEmbeddingLatentsDataset, +) + + +IMAGENET_TEMPLATES_SMALL = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +IMAGENET_STYLE_TEMPLATES_SMALL = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(TextEmbeddingLatentsDataset): + templates: list[str] = [] + placeholder_token: str = "" + + def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None: + super().__init__(trainer) + self.templates = ( + IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL + ) + self.placeholder_token = self.config.textual_inversion.placeholder_token + + def get_caption(self, index: int) -> str: + # Ignore the dataset caption, if any: use a template instead + return random.choice(self.templates).format(self.placeholder_token) + + +class TextualInversionConfig(BaseModel): + # The new token to be learned + placeholder_token: str = "*" + # The token to be used as initializer; if None, a random vector is used + initializer_token: str | None = None + style_mode: bool = False + + def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None: + adapter = ConceptExtender(target=text_encoder) + tokenizer = text_encoder.find(layer_type=CLIPTokenizer) + assert tokenizer is not None, "Tokenizer not found in text encoder." + token_encoder = text_encoder.find(layer_type=TokenEncoder) + assert token_encoder is not None, "Token encoder not found in text encoder." + if self.initializer_token is not None: + bpe = tokenizer.byte_pair_encoding(token=self.initializer_token) + assert " " not in bpe, "This initializer_token is not a single token." + token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device) + init_embedding = token_encoder(token).squeeze(0) + else: + token_encoder = text_encoder.find(layer_type=TokenEncoder) + assert token_encoder is not None, "Token encoder not found in text encoder." + init_embedding = randn(token_encoder.embedding_dim) + adapter.add_concept(self.placeholder_token, init_embedding) + adapter.inject() + + +class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig): + latent_diffusion: LatentDiffusionConfig + textual_inversion: TextualInversionConfig + + def model_post_init(self, __context: Any) -> None: + # Pydantic v2 does post init differently, so we need to override this method too. + logger.info("Freezing models to train only the new embedding.") + self.models["unet"].train = False + self.models["text_encoder"].train = False + self.models["lda"].train = False + + +class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]): + def __init__( + self, + config: TextualInversionLatentDiffusionConfig, + callbacks: "list[Callback[Any]] | None" = None, + ) -> None: + super().__init__(config=config, callbacks=callbacks) + self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion())) + + def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: + return TextualInversionDataset(trainer=self) + + +class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]): + def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None: + trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder) + + +class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]): + def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None: + embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender) + assert embedding_extender is not None, "Embedding extender not found in text encoder." + tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)} + + save_to_safetensors( + path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors + ) + + +if __name__ == "__main__": + import sys + + config_path = sys.argv[1] + config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path) + trainer = TextualInversionLatentDiffusionTrainer(config=config) + trainer.train() diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py index c178c7d..b33b789 100644 --- a/src/refiners/foundationals/clip/concepts.py +++ b/src/refiners/foundationals/clip/concepts.py @@ -71,7 +71,6 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]): class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]): old_weight: Parameter new_weight: Parameter - weight: Tensor def __init__( self, @@ -83,22 +82,21 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]): self.new_weight = Parameter( zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype) ) # requires_grad=True by default - self.weight = cat([self.old_weight, self.new_weight]) # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings def lookup(self, x: Tensor) -> Tensor: - return F.embedding(x, self.weight) + # Concatenate old and new weights for dynamic embedding updates during training + return F.embedding(x, cat([self.old_weight, self.new_weight])) def add_embedding(self, embedding: Tensor) -> None: assert embedding.shape == (self.old_weight.shape[1],) self.new_weight = Parameter( cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]) ) - self.weight = cat([self.old_weight, self.new_weight]) @property def num_embeddings(self) -> int: - return self.weight.shape[0] + return self.old_weight.shape[0] + self.new_weight.shape[0] class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]): @@ -115,12 +113,13 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]): ) def add_token(self, token: str, token_id: int) -> None: + token = token.lower() tokenizer = self.find(layer_type=CLIPTokenizer) assert tokenizer is not None, "Tokenizer not found." assert token_id not in tokenizer.token_to_id_mapping.values() tokenizer.token_to_id_mapping[token] = token_id current_pattern = tokenizer.token_pattern.pattern - new_pattern = token + "|" + current_pattern + new_pattern = re.escape(token) + "|" + current_pattern tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE) # Define the keyword as its own smallest subtoken tokenizer.byte_pair_encoding_cache[token] = token diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 2195e7c..2caca6f 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -95,9 +95,15 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]): def process_caption(self, caption: str) -> str: return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else "" + def get_caption(self, index: int) -> str: + return self.dataset[index]["caption"] + + def get_image(self, index: int) -> Image.Image: + return self.dataset[index]["image"] + def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch: - item = self.dataset[index] - caption, image = item["caption"], item["image"] + caption = self.get_caption(index=index) + image = self.get_image(index=index) resized_image = self.resize_image( image=image, min_size=self.config.dataset.resize_image_min_size, diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index aed634d..0ed4f06 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -4,10 +4,11 @@ import pytest from warnings import warn from pathlib import Path -from refiners.foundationals.clip.concepts import ConceptExtender +from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer from refiners.fluxion.utils import load_from_safetensors +import refiners.fluxion.layers as fl from diffusers import StableDiffusionPipeline # type: ignore import transformers # type: ignore @@ -84,6 +85,28 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch. return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")[""] # type: ignore +def test_tokenizer_with_special_character(): + clip_tokenizer = fl.Chain(CLIPTokenizer()) + token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer) + new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42 + token_extender.add_token("*", new_token_id) + token_extender.inject(clip_tokenizer) + + adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer) + assert adapted_clip_tokenizer is not None + + assert torch.allclose( + adapted_clip_tokenizer.encode("*"), + torch.Tensor( + [ + adapted_clip_tokenizer.start_of_text_token_id, + new_token_id, + adapted_clip_tokenizer.end_of_text_token_id, + ] + ).to(torch.int64), + ) + + def test_encoder( prompt: str, ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,