mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
turn CLIPTokenizer into a fl.Module
This commit is contained in:
parent
1ad4e1a35a
commit
e7c1db50e0
|
@ -6,13 +6,16 @@ from diffusers import DiffusionPipeline # type: ignore
|
|||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
||||
dst_model = CLIPTextEncoderL()
|
||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Could not find tokenizer"
|
||||
tokens = tokenizer("Nice cat")
|
||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||
assert mapping is not None, "Model conversion failed"
|
||||
state_dict = convert_state_dict(
|
||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
||||
|
|
|
@ -4,6 +4,7 @@ from refiners.fluxion.utils import (
|
|||
save_to_safetensors,
|
||||
)
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||
from refiners.fluxion.layers.module import Module
|
||||
|
@ -33,9 +34,10 @@ def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dic
|
|||
|
||||
@torch.no_grad()
|
||||
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
|
||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||
|
||||
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Could not find tokenizer"
|
||||
tokens = tokenizer("Nice cat")
|
||||
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
@ -6,6 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
|||
from diffusers import DiffusionPipeline # type: ignore
|
||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
@ -15,8 +16,10 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
|||
dst_model = CLIPTextEncoderG()
|
||||
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
|
||||
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
||||
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Could not find tokenizer"
|
||||
tokens = tokenizer("Nice cat")
|
||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||
if mapping is None:
|
||||
raise RuntimeError("Could not create state dict mapping")
|
||||
state_dict = convert_state_dict(
|
||||
|
|
|
@ -131,7 +131,6 @@ class CLIPTextEncoder(fl.Chain):
|
|||
"feedforward_dim",
|
||||
"layer_norm_eps",
|
||||
"use_quick_gelu",
|
||||
"tokenizer",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
@ -156,8 +155,9 @@ class CLIPTextEncoder(fl.Chain):
|
|||
self.feedforward_dim = feedforward_dim
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.use_quick_gelu = use_quick_gelu
|
||||
self.tokenizer = tokenizer or CLIPTokenizer()
|
||||
super().__init__(
|
||||
tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
|
||||
fl.Converter(set_dtype=False),
|
||||
fl.Sum(
|
||||
TokenEncoder(
|
||||
vocabulary_size=vocabulary_size,
|
||||
|
@ -189,13 +189,9 @@ class CLIPTextEncoder(fl.Chain):
|
|||
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
||||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||
|
||||
def encode(self, text: str) -> Tensor:
|
||||
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
|
||||
return self(tokens)
|
||||
|
||||
@property
|
||||
def unconditional_text_embedding(self) -> Tensor:
|
||||
return self.encode(text="")
|
||||
return self("")
|
||||
|
||||
|
||||
class CLIPTextEncoderL(CLIPTextEncoder):
|
||||
|
|
|
@ -5,17 +5,21 @@ from itertools import islice
|
|||
import re
|
||||
from torch import Tensor, tensor
|
||||
from refiners.fluxion import pad
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class CLIPTokenizer:
|
||||
class CLIPTokenizer(fl.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
|
||||
sequence_length: int = 77,
|
||||
start_of_text_token_id: int = 49406,
|
||||
end_of_text_token_id: int = 49407,
|
||||
pad_token_id: int = 49407,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocabulary_path = vocabulary_path
|
||||
self.sequence_length = sequence_length
|
||||
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
|
||||
merge_tuples = [
|
||||
|
@ -45,12 +49,12 @@ class CLIPTokenizer:
|
|||
self.end_of_text_token_id: int = end_of_text_token_id
|
||||
self.pad_token_id: int = pad_token_id
|
||||
|
||||
def __call__(self, text: str, sequence_length: int) -> Tensor:
|
||||
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
|
||||
def forward(self, text: str) -> Tensor:
|
||||
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
|
||||
assert (
|
||||
tokens.shape[1] <= sequence_length
|
||||
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
|
||||
return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
|
||||
tokens.shape[1] <= self.sequence_length
|
||||
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
|
||||
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
|
||||
|
||||
@lru_cache()
|
||||
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
|
||||
|
|
|
@ -88,7 +88,7 @@ class LatentDiffusionModel(Module):
|
|||
return self.clip_text_encoder.unconditional_text_embedding
|
||||
|
||||
def compute_text_embedding(self, text: str) -> Tensor:
|
||||
return self.clip_text_encoder.encode(text)
|
||||
return self.clip_text_encoder(text)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -89,7 +89,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
|||
processed_image: Image.Image = self.process_image(resized_image)
|
||||
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
||||
processed_caption = self.process_caption(caption=caption)
|
||||
clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
|
||||
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
|
||||
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
|
||||
|
||||
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
|
||||
|
|
|
@ -8,6 +8,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
|||
from refiners.fluxion.utils import load_from_safetensors
|
||||
|
||||
import transformers # type: ignore
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
||||
|
||||
long_prompt = """
|
||||
|
@ -86,12 +87,14 @@ def test_encoder(
|
|||
return_tensors="pt",
|
||||
).input_ids
|
||||
assert isinstance(ref_tokens, torch.Tensor)
|
||||
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
|
||||
tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None
|
||||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
||||
our_embeddings = our_encoder(our_tokens.to(test_device))
|
||||
our_embeddings = our_encoder(prompt)
|
||||
|
||||
assert ref_embeddings.shape == (1, 77, 768)
|
||||
assert our_embeddings.shape == (1, 77, 768)
|
||||
|
|
Loading…
Reference in a new issue