cosmetic changes

limiteinductive 2023-08-16 13:18:23 +02:00 committed by Benjamin Trom
parent 17dc75421b
commit efe923a272
2 changed files with 24 additions and 21 deletions


@@ -25,7 +25,7 @@ class PositionalTokenEncoder(Sum):
         positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
@@ -36,7 +36,7 @@ class PositionalTokenEncoder(Sum):
                 dtype=dtype,
             ),
             Chain(
-                Lambda(self.get_position_ids),
+                Lambda(func=self.get_position_ids),
                 Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
@@ -48,7 +48,7 @@ class PositionalTokenEncoder(Sum):
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
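For reference, the two helpers touched in the hunk above build the positional indices that get embedded: `position_ids` is a row vector `[[0, 1, ..., positional_embedding_dim - 1]]` and `get_position_ids` truncates it to the input's sequence length. A minimal sketch of the same computation in plain PyTorch, assuming `arange` here is the `torch.arange` re-export:

import torch

positional_embedding_dim = 77
position_ids = torch.arange(end=positional_embedding_dim).reshape(1, -1)  # shape (1, 77)
x = torch.zeros(1, 5, dtype=torch.int64)  # hypothetical batch of 5 token ids
print(position_ids[:, : x.shape[1]])  # tensor([[0, 1, 2, 3, 4]])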
@@ -145,7 +145,7 @@ class CLIPTextEncoder(Chain):
         layer_norm_eps: float = 1e-5,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.embedding_dim = embedding_dim
         self.positional_embedding_dim = positional_embedding_dim
         self.vocabulary_size = vocabulary_size
@@ -177,12 +177,12 @@ class CLIPTextEncoder(Chain):
         )
 
     def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
+        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
         return self(tokens)
 
     @property
    def unconditional_text_embedding(self) -> Tensor:
-        return self.encode("")
+        return self.encode(text="")
 
 
 class CLIPTextEncoderL(CLIPTextEncoder):
@@ -206,7 +206,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
+        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
             parent.replace(old_module=gelu, new_module=ApproximateGeLU())
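The hunk above runs right after the call to the parent constructor: `walk` yields every `GeLU` module together with its parent, and each one is swapped for an `ApproximateGeLU`. A short usage sketch of the resulting encoder API, hedged because only `device`, `dtype`, `encode` and `unconditional_text_embedding` are visible in this diff (the remaining constructor parameters are assumed to have defaults):

encoder = CLIPTextEncoderL(device="cpu")  # assumption: other parameters default
conditional = encoder.encode(text="a photo of an astronaut riding a horse")
unconditional = encoder.unconditional_text_embedding  # same as encoder.encode(text="")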


@@ -11,13 +11,16 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
-    ):
+    ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
             tuple(merge.split())
-            for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
+            for merge in gzip.open(filename=vocabulary_path)
+            .read()
+            .decode(encoding="utf-8")
+            .split(sep="\n")[1 : 49152 - 256 - 2 + 1]
         ]
         vocabulary = (
             list(self.byte_to_unicode_mapping.values())
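Each line of the gzipped BPE vocabulary is a space-separated merge pair; the slice `[1 : 49152 - 256 - 2 + 1]` skips the first (header) line and keeps exactly 48894 merges. A hedged sketch of that parsing step on an in-memory stand-in (the header text and the two merge strings below are made up for illustration):

import gzip
import io

data = gzip.compress("bpe header line\ni n</w>\nt h\n".encode("utf-8"))
lines = gzip.open(filename=io.BytesIO(data)).read().decode(encoding="utf-8").split(sep="\n")[1:]
merge_tuples = [tuple(merge.split()) for merge in lines if merge]
print(merge_tuples)  # [('i', 'n</w>'), ('t', 'h')]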
@@ -25,25 +28,25 @@ class CLIPTokenizer:
             + ["".join(merge) for merge in merge_tuples]
             + ["<|startoftext|>", "<|endoftext|>"]
         )
-        self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
-        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
+        self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
+        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
         self.byte_pair_encoding_cache = {"": ""}
         # Note: this regular expression does not support Unicode. It was changed so
         # to get rid of the dependence on the `regex` module. Unicode support could
         # potentially be added back by leveraging the `\w` character class.
         self.token_pattern = re.compile(
-            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
-            re.IGNORECASE,
+            pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
+            flags=re.IGNORECASE,
         )
         self.start_of_text_token_id: int = 49406
         self.end_of_text_token_id: int = 49407
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
+        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
@@ -83,13 +86,13 @@ class CLIPTokenizer:
             return recursive_bpe(tuple(new_word))
 
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
-        result = " ".join(recursive_bpe(word))
+        result = " ".join(recursive_bpe(word=word))
         self.byte_pair_encoding_cache[token] = result
         return result
 
     def encode(self, text: str, max_length: int | None = None) -> Tensor:
-        text = re.sub(r"\s+", " ", text.lower())
-        tokens = re.findall(self.token_pattern, text)
+        text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
+        tokens = re.findall(pattern=self.token_pattern, string=text)
         upper_bound = None
         if max_length:
             assert max_length >= 2
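For reference, `encode` first collapses whitespace and lowercases the text, then splits it with the token pattern compiled in `__init__`. A small sketch of just that front end, reusing the same regular expression:

import re

token_pattern = re.compile(
    pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
    flags=re.IGNORECASE,
)
text = re.sub(pattern=r"\s+", repl=" ", string="A  Photo of a CAT!".lower())
print(re.findall(pattern=token_pattern, string=text))  # ['a', 'photo', 'of', 'a', 'cat', '!']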
@@ -99,10 +102,10 @@ class CLIPTokenizer:
                 self.token_to_id_mapping[subtoken]
                 for token in tokens
                 for subtoken in self.byte_pair_encoding(
-                    "".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
-                ).split(" ")
+                    token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
+                ).split(sep=" ")
             ),
             0,
             upper_bound,
         )
-        return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
+        return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
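Putting the tokenizer pieces above together, a minimal usage sketch; the sequence length of 77 is an assumption (the usual CLIP context size), not something this diff fixes:

tokenizer = CLIPTokenizer()  # loads the bundled bpe_simple_vocab_16e6.txt.gz by default
tokens = tokenizer(text="a photo of a cat", sequence_length=77)
# tokens has shape (1, 77): start-of-text id 49406, the BPE token ids,
# end-of-text id 49407, then right padding with 49407 up to sequence_length.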