From e7c1db50e0764b60853973cc539e12b366803714 Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Thu, 17 Aug 2023 11:00:47 +0200
Subject: [PATCH] turn CLIPTokenizer into a fl.Module

---
 scripts/convert-clip-weights.py                 |  7 +++++--
 scripts/convert-loras-to-sdwebui.py             |  8 +++++---
 scripts/convert-sdxl-text-encoder-2.py          |  7 +++++--
 src/refiners/foundationals/clip/text_encoder.py | 10 +++-------
 src/refiners/foundationals/clip/tokenizer.py    | 16 ++++++++++------
 .../foundationals/latent_diffusion/__init__.py  |  2 +-
 src/refiners/training_utils/latent_diffusion.py |  2 +-
 tests/foundationals/clip/test_text_encoder.py   |  9 ++++++---
 8 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/scripts/convert-clip-weights.py b/scripts/convert-clip-weights.py
index 72eb760..3e52e4a 100644
--- a/scripts/convert-clip-weights.py
+++ b/scripts/convert-clip-weights.py
@@ -6,13 +6,16 @@ from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore

 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 @torch.no_grad()
 def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderL()
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore
     assert mapping is not None, "Model conversion failed"
     state_dict = convert_state_dict(
         source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
diff --git a/scripts/convert-loras-to-sdwebui.py b/scripts/convert-loras-to-sdwebui.py
index 931c10b..0b9a3f7 100644
--- a/scripts/convert-loras-to-sdwebui.py
+++ b/scripts/convert-loras-to-sdwebui.py
@@ -4,6 +4,7 @@ from refiners.fluxion.utils import (
     save_to_safetensors,
 )
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 from refiners.foundationals.latent_diffusion.unet import UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
 from refiners.fluxion.layers.module import Module
@@ -33,9 +34,10 @@ def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dic

 @torch.no_grad()
 def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-
-    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore


 def main() -> None:
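
Note: the conversion scripts (including the SDXL one below) all switch to the same
lookup pattern, since the tokenizer now lives inside the encoder's Chain rather than
hanging off a tokenizer attribute. A minimal sketch of that pattern, assuming a
default CLIPTextEncoderL (the prompt string is illustrative):

    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    encoder = CLIPTextEncoderL()
    # find() is expected to return the first module of the requested type
    # inside the Chain, or None when there is no match.
    tokenizer = encoder.find(layer_type=CLIPTokenizer)
    assert tokenizer is not None, "Could not find tokenizer"
    tokens = tokenizer("Nice cat")  # (1, sequence_length) tensor of token ids
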
diff --git a/scripts/convert-sdxl-text-encoder-2.py b/scripts/convert-sdxl-text-encoder-2.py
index ad67020..c1cee00 100644
--- a/scripts/convert-sdxl-text-encoder-2.py
+++ b/scripts/convert-sdxl-text-encoder-2.py
@@ -6,6 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
 import refiners.fluxion.layers as fl

@@ -15,8 +16,10 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderG()
     # Extra projection layer (see CLIPTextModelWithProjection in transformers)
     dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore
     if mapping is None:
         raise RuntimeError("Could not create state dict mapping")
     state_dict = convert_state_dict(
diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 6683959..6a85101 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -131,7 +131,6 @@ class CLIPTextEncoder(fl.Chain):
         "feedforward_dim",
         "layer_norm_eps",
         "use_quick_gelu",
-        "tokenizer",
     ]

     def __init__(
@@ -156,8 +155,9 @@ class CLIPTextEncoder(fl.Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         self.use_quick_gelu = use_quick_gelu
-        self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
+            tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
+            fl.Converter(set_dtype=False),
             fl.Sum(
                 TokenEncoder(
                     vocabulary_size=vocabulary_size,
@@ -189,13 +189,9 @@ class CLIPTextEncoder(fl.Chain):
         for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
             parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

-    def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
-        return self(tokens)
-
     @property
     def unconditional_text_embedding(self) -> Tensor:
-        return self.encode(text="")
+        return self("")


 class CLIPTextEncoderL(CLIPTextEncoder):
diff --git a/src/refiners/foundationals/clip/tokenizer.py b/src/refiners/foundationals/clip/tokenizer.py
index a160ccc..6f06531 100644
--- a/src/refiners/foundationals/clip/tokenizer.py
+++ b/src/refiners/foundationals/clip/tokenizer.py
@@ -5,17 +5,21 @@ from itertools import islice
 import re
 from torch import Tensor, tensor
 from refiners.fluxion import pad
+import refiners.fluxion.layers as fl


-class CLIPTokenizer:
+class CLIPTokenizer(fl.Module):
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
+        sequence_length: int = 77,
         start_of_text_token_id: int = 49406,
         end_of_text_token_id: int = 49407,
         pad_token_id: int = 49407,
     ) -> None:
+        super().__init__()
         self.vocabulary_path = vocabulary_path
+        self.sequence_length = sequence_length
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
@@ -45,12 +49,12 @@ class CLIPTokenizer:
         self.end_of_text_token_id: int = end_of_text_token_id
         self.pad_token_id: int = pad_token_id

-    def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
+    def forward(self, text: str) -> Tensor:
+        tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
         assert (
-            tokens.shape[1] <= sequence_length
-        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
+            tokens.shape[1] <= self.sequence_length
+        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
+        return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
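
Note: the two diffs above are the heart of the change. The tokenizer, followed by an
fl.Converter with set_dtype=False (presumably so the integer token tensor is moved to
the Chain's device without having its dtype cast), now sits at the head of the Chain,
and the dedicated encode() method disappears: the encoder consumes raw text directly.
A before/after sketch, assuming a default CLIPTextEncoderL:

    import torch
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

    encoder = CLIPTextEncoderL()
    with torch.no_grad():
        # before: tokens = encoder.tokenizer("a cat", sequence_length=77)
        #         embedding = encoder(tokens)
        embedding = encoder("a cat")  # tokenization now happens inside the Chain
    assert embedding.shape == (1, 77, 768)  # 77 tokens, 768-dim embeddings for the L variant
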
diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index 30e19fa..b46c094 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -88,7 +88,7 @@ class LatentDiffusionModel(Module):
         return self.clip_text_encoder.unconditional_text_embedding

     def compute_text_embedding(self, text: str) -> Tensor:
-        return self.clip_text_encoder.encode(text)
+        return self.clip_text_encoder(text)

     def forward(
         self,
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index 253ff3e..8967d16 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -89,7 +89,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         processed_image: Image.Image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
-        clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
+        clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
         return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)

     def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py
index 8cb6abc..75d2fe4 100644
--- a/tests/foundationals/clip/test_text_encoder.py
+++ b/tests/foundationals/clip/test_text_encoder.py
@@ -7,7 +7,8 @@ from pathlib import Path
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.fluxion.utils import load_from_safetensors

-import transformers # type: ignore
+import transformers  # type: ignore
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 long_prompt = """
@@ -86,12 +87,14 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
+    tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None
+    our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)

     with torch.no_grad():
         ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
-        our_embeddings = our_encoder(our_tokens.to(test_device))
+        our_embeddings = our_encoder(prompt)

     assert ref_embeddings.shape == (1, 77, 768)
     assert our_embeddings.shape == (1, 77, 768)
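
Note: with CLIPTokenizer now an fl.Module, calling it dispatches to forward(), and
sequence_length is fixed at construction time instead of being passed on every call.
A small sketch of standalone usage (the default of 77 matches what the encoders use):

    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    tokenizer = CLIPTokenizer(sequence_length=77)
    tokens = tokenizer("Nice cat")  # __call__ -> forward(text)
    assert tokens.shape == (1, 77)  # right-padded with pad_token_id (49407)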