cosmetic changes

This commit is contained in:
limiteinductive 2023-08-16 13:18:23 +02:00 committed by Benjamin Trom
parent 17dc75421b
commit efe923a272
2 changed files with 24 additions and 21 deletions

View file

@ -25,7 +25,7 @@ class PositionalTokenEncoder(Sum):
positional_embedding_dim: int, positional_embedding_dim: int,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
): ) -> None:
self.vocabulary_size = vocabulary_size self.vocabulary_size = vocabulary_size
self.positional_embedding_dim = positional_embedding_dim self.positional_embedding_dim = positional_embedding_dim
super().__init__( super().__init__(
@ -36,7 +36,7 @@ class PositionalTokenEncoder(Sum):
dtype=dtype, dtype=dtype,
), ),
Chain( Chain(
Lambda(self.get_position_ids), Lambda(func=self.get_position_ids),
Embedding( Embedding(
num_embeddings=positional_embedding_dim, num_embeddings=positional_embedding_dim,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
@ -48,7 +48,7 @@ class PositionalTokenEncoder(Sum):
@property @property
def position_ids(self) -> Tensor: def position_ids(self) -> Tensor:
return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1) return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor: def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]] return self.position_ids[:, : x.shape[1]]
@ -145,7 +145,7 @@ class CLIPTextEncoder(Chain):
layer_norm_eps: float = 1e-5, layer_norm_eps: float = 1e-5,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
): ) -> None:
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.positional_embedding_dim = positional_embedding_dim self.positional_embedding_dim = positional_embedding_dim
self.vocabulary_size = vocabulary_size self.vocabulary_size = vocabulary_size
@ -177,12 +177,12 @@ class CLIPTextEncoder(Chain):
) )
def encode(self, text: str) -> Tensor: def encode(self, text: str) -> Tensor:
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device) tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
return self(tokens) return self(tokens)
@property @property
def unconditional_text_embedding(self) -> Tensor: def unconditional_text_embedding(self) -> Tensor:
return self.encode("") return self.encode(text="")
class CLIPTextEncoderL(CLIPTextEncoder): class CLIPTextEncoderL(CLIPTextEncoder):
@ -206,7 +206,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)): for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
parent.replace(old_module=gelu, new_module=ApproximateGeLU()) parent.replace(old_module=gelu, new_module=ApproximateGeLU())

View file

@ -11,13 +11,16 @@ class CLIPTokenizer:
def __init__( def __init__(
self, self,
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz", vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
): ) -> None:
self.vocabulary_path = vocabulary_path self.vocabulary_path = vocabulary_path
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping() self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()} self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
merge_tuples = [ merge_tuples = [
tuple(merge.split()) tuple(merge.split())
for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1] for merge in gzip.open(filename=vocabulary_path)
.read()
.decode(encoding="utf-8")
.split(sep="\n")[1 : 49152 - 256 - 2 + 1]
] ]
vocabulary = ( vocabulary = (
list(self.byte_to_unicode_mapping.values()) list(self.byte_to_unicode_mapping.values())
@ -25,25 +28,25 @@ class CLIPTokenizer:
+ ["".join(merge) for merge in merge_tuples] + ["".join(merge) for merge in merge_tuples]
+ ["<|startoftext|>", "<|endoftext|>"] + ["<|startoftext|>", "<|endoftext|>"]
) )
self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)} self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)} self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
self.byte_pair_encoding_cache = {"": ""} self.byte_pair_encoding_cache = {"": ""}
# Note: this regular expression does not support Unicode. It was changed so # Note: this regular expression does not support Unicode. It was changed so
# to get rid of the dependence on the `regex` module. Unicode support could # to get rid of the dependence on the `regex` module. Unicode support could
# potentially be added back by leveraging the `\w` character class. # potentially be added back by leveraging the `\w` character class.
self.token_pattern = re.compile( self.token_pattern = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""", pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
re.IGNORECASE, flags=re.IGNORECASE,
) )
self.start_of_text_token_id: int = 49406 self.start_of_text_token_id: int = 49406
self.end_of_text_token_id: int = 49407 self.end_of_text_token_id: int = 49407
def __call__(self, text: str, sequence_length: int) -> Tensor: def __call__(self, text: str, sequence_length: int) -> Tensor:
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0) tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
assert ( assert (
tokens.shape[1] <= sequence_length tokens.shape[1] <= sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}" ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id) return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
@lru_cache() @lru_cache()
def get_bytes_to_unicode_mapping(self) -> dict[int, str]: def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
@ -83,13 +86,13 @@ class CLIPTokenizer:
return recursive_bpe(tuple(new_word)) return recursive_bpe(tuple(new_word))
word = tuple(token[:-1]) + (token[-1] + "</w>",) word = tuple(token[:-1]) + (token[-1] + "</w>",)
result = " ".join(recursive_bpe(word)) result = " ".join(recursive_bpe(word=word))
self.byte_pair_encoding_cache[token] = result self.byte_pair_encoding_cache[token] = result
return result return result
def encode(self, text: str, max_length: int | None = None) -> Tensor: def encode(self, text: str, max_length: int | None = None) -> Tensor:
text = re.sub(r"\s+", " ", text.lower()) text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
tokens = re.findall(self.token_pattern, text) tokens = re.findall(pattern=self.token_pattern, string=text)
upper_bound = None upper_bound = None
if max_length: if max_length:
assert max_length >= 2 assert max_length >= 2
@ -99,10 +102,10 @@ class CLIPTokenizer:
self.token_to_id_mapping[subtoken] self.token_to_id_mapping[subtoken]
for token in tokens for token in tokens
for subtoken in self.byte_pair_encoding( for subtoken in self.byte_pair_encoding(
"".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8")) token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
).split(" ") ).split(sep=" ")
), ),
0, 0,
upper_bound, upper_bound,
) )
return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id]) return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])