mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
turn CLIPTokenizer into a fl.Module
This commit is contained in:
parent
1ad4e1a35a
commit
e7c1db50e0
|
@ -6,13 +6,16 @@ from diffusers import DiffusionPipeline # type: ignore
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||||
|
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
||||||
dst_model = CLIPTextEncoderL()
|
dst_model = CLIPTextEncoderL()
|
||||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
assert tokenizer is not None, "Could not find tokenizer"
|
||||||
|
tokens = tokenizer("Nice cat")
|
||||||
|
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||||
assert mapping is not None, "Model conversion failed"
|
assert mapping is not None, "Model conversion failed"
|
||||||
state_dict = convert_state_dict(
|
state_dict = convert_state_dict(
|
||||||
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
|
||||||
|
|
|
@ -4,6 +4,7 @@ from refiners.fluxion.utils import (
|
||||||
save_to_safetensors,
|
save_to_safetensors,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.unet import UNet
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||||
from refiners.fluxion.layers.module import Module
|
from refiners.fluxion.layers.module import Module
|
||||||
|
@ -33,9 +34,10 @@ def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dic
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
|
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
|
||||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None, "Could not find tokenizer"
|
||||||
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
tokens = tokenizer("Nice cat")
|
||||||
|
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|
|
@ -6,6 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||||
|
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
@ -15,8 +16,10 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
|
||||||
dst_model = CLIPTextEncoderG()
|
dst_model = CLIPTextEncoderG()
|
||||||
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
|
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
|
||||||
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||||
x = dst_model.tokenizer("Nice cat", sequence_length=77)
|
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
|
||||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
assert tokenizer is not None, "Could not find tokenizer"
|
||||||
|
tokens = tokenizer("Nice cat")
|
||||||
|
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
|
||||||
if mapping is None:
|
if mapping is None:
|
||||||
raise RuntimeError("Could not create state dict mapping")
|
raise RuntimeError("Could not create state dict mapping")
|
||||||
state_dict = convert_state_dict(
|
state_dict = convert_state_dict(
|
||||||
|
|
|
@ -131,7 +131,6 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
"feedforward_dim",
|
"feedforward_dim",
|
||||||
"layer_norm_eps",
|
"layer_norm_eps",
|
||||||
"use_quick_gelu",
|
"use_quick_gelu",
|
||||||
"tokenizer",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -156,8 +155,9 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
self.feedforward_dim = feedforward_dim
|
self.feedforward_dim = feedforward_dim
|
||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.use_quick_gelu = use_quick_gelu
|
self.use_quick_gelu = use_quick_gelu
|
||||||
self.tokenizer = tokenizer or CLIPTokenizer()
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
|
||||||
|
fl.Converter(set_dtype=False),
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
TokenEncoder(
|
TokenEncoder(
|
||||||
vocabulary_size=vocabulary_size,
|
vocabulary_size=vocabulary_size,
|
||||||
|
@ -189,13 +189,9 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
||||||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||||
|
|
||||||
def encode(self, text: str) -> Tensor:
|
|
||||||
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
|
|
||||||
return self(tokens)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unconditional_text_embedding(self) -> Tensor:
|
def unconditional_text_embedding(self) -> Tensor:
|
||||||
return self.encode(text="")
|
return self("")
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoderL(CLIPTextEncoder):
|
class CLIPTextEncoderL(CLIPTextEncoder):
|
||||||
|
|
|
@ -5,17 +5,21 @@ from itertools import islice
|
||||||
import re
|
import re
|
||||||
from torch import Tensor, tensor
|
from torch import Tensor, tensor
|
||||||
from refiners.fluxion import pad
|
from refiners.fluxion import pad
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
class CLIPTokenizer:
|
class CLIPTokenizer(fl.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
|
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
sequence_length: int = 77,
|
||||||
start_of_text_token_id: int = 49406,
|
start_of_text_token_id: int = 49406,
|
||||||
end_of_text_token_id: int = 49407,
|
end_of_text_token_id: int = 49407,
|
||||||
pad_token_id: int = 49407,
|
pad_token_id: int = 49407,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.vocabulary_path = vocabulary_path
|
self.vocabulary_path = vocabulary_path
|
||||||
|
self.sequence_length = sequence_length
|
||||||
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
|
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
|
||||||
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
|
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
|
||||||
merge_tuples = [
|
merge_tuples = [
|
||||||
|
@ -45,12 +49,12 @@ class CLIPTokenizer:
|
||||||
self.end_of_text_token_id: int = end_of_text_token_id
|
self.end_of_text_token_id: int = end_of_text_token_id
|
||||||
self.pad_token_id: int = pad_token_id
|
self.pad_token_id: int = pad_token_id
|
||||||
|
|
||||||
def __call__(self, text: str, sequence_length: int) -> Tensor:
|
def forward(self, text: str) -> Tensor:
|
||||||
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
|
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
|
||||||
assert (
|
assert (
|
||||||
tokens.shape[1] <= sequence_length
|
tokens.shape[1] <= self.sequence_length
|
||||||
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
|
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
|
||||||
return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
|
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
|
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
|
||||||
|
|
|
@ -88,7 +88,7 @@ class LatentDiffusionModel(Module):
|
||||||
return self.clip_text_encoder.unconditional_text_embedding
|
return self.clip_text_encoder.unconditional_text_embedding
|
||||||
|
|
||||||
def compute_text_embedding(self, text: str) -> Tensor:
|
def compute_text_embedding(self, text: str) -> Tensor:
|
||||||
return self.clip_text_encoder.encode(text)
|
return self.clip_text_encoder(text)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -89,7 +89,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
||||||
processed_image: Image.Image = self.process_image(resized_image)
|
processed_image: Image.Image = self.process_image(resized_image)
|
||||||
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
||||||
processed_caption = self.process_caption(caption=caption)
|
processed_caption = self.process_caption(caption=caption)
|
||||||
clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
|
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
|
||||||
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
|
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
|
||||||
|
|
||||||
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
|
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.fluxion.utils import load_from_safetensors
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
|
|
||||||
import transformers # type: ignore
|
import transformers # type: ignore
|
||||||
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
|
||||||
|
|
||||||
long_prompt = """
|
long_prompt = """
|
||||||
|
@ -86,12 +87,14 @@ def test_encoder(
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
).input_ids
|
).input_ids
|
||||||
assert isinstance(ref_tokens, torch.Tensor)
|
assert isinstance(ref_tokens, torch.Tensor)
|
||||||
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
|
tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
|
||||||
|
assert tokenizer is not None
|
||||||
|
our_tokens = tokenizer(prompt)
|
||||||
assert torch.equal(our_tokens, ref_tokens)
|
assert torch.equal(our_tokens, ref_tokens)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
||||||
our_embeddings = our_encoder(our_tokens.to(test_device))
|
our_embeddings = our_encoder(prompt)
|
||||||
|
|
||||||
assert ref_embeddings.shape == (1, 77, 768)
|
assert ref_embeddings.shape == (1, 77, 768)
|
||||||
assert our_embeddings.shape == (1, 77, 768)
|
assert our_embeddings.shape == (1, 77, 768)
|
||||||
|
|
Loading…
Reference in a new issue