mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00

cosmetic changes

This commit is contained in:
parent 17dc75421b
commit efe923a272
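The hunks below all apply the same cosmetic pattern: positional arguments become explicit keyword arguments, and `__init__` signatures gain a `-> None` return annotation. A minimal sketch of the pattern (the class and parameter names here are hypothetical, not taken from the diff):

    class Example:
        def __init__(self, size: int) -> None:  # explicit return annotation makes the signature fully typed
            self.size = size

    example = Example(size=42)  # call sites spell out keyword arguments, matching the style below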
@@ -25,7 +25,7 @@ class PositionalTokenEncoder(Sum):
         positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
@@ -36,7 +36,7 @@ class PositionalTokenEncoder(Sum):
                 dtype=dtype,
             ),
             Chain(
-                Lambda(self.get_position_ids),
+                Lambda(func=self.get_position_ids),
                 Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
@@ -48,7 +48,7 @@ class PositionalTokenEncoder(Sum):
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
        return self.position_ids[:, : x.shape[1]]
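To make the position-id logic above concrete, a minimal sketch assuming `arange` is torch's (the dimension and token count are made-up example values):

    import torch

    positional_embedding_dim = 8  # hypothetical value
    position_ids = torch.arange(positional_embedding_dim).reshape(1, -1)  # shape (1, 8)
    x = torch.zeros(1, 5, dtype=torch.long)                               # a batch with 5 tokens
    print(position_ids[:, : x.shape[1]])                                  # tensor([[0, 1, 2, 3, 4]])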
@@ -145,7 +145,7 @@ class CLIPTextEncoder(Chain):
         layer_norm_eps: float = 1e-5,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.embedding_dim = embedding_dim
         self.positional_embedding_dim = positional_embedding_dim
         self.vocabulary_size = vocabulary_size
@@ -177,12 +177,12 @@ class CLIPTextEncoder(Chain):
         )
 
     def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
+        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
         return self(tokens)
 
     @property
     def unconditional_text_embedding(self) -> Tensor:
-        return self.encode("")
+        return self.encode(text="")
 
 
 class CLIPTextEncoderL(CLIPTextEncoder):
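A minimal usage sketch of the encoder interface shown above (construction arguments and the prompt are illustrative; loading pretrained weights is outside this diff):

    encoder = CLIPTextEncoderL(device="cpu")                 # hypothetical construction
    conditional_embedding = encoder.encode(text="a photo of a cat")
    unconditional_embedding = encoder.unconditional_text_embedding  # equivalent to encoder.encode(text="")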
@@ -206,7 +206,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
+        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
             parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
 
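The GeLU swap above relies on refiners' walk/replace pattern. The sketch below mirrors that pattern in plain PyTorch; whether `ApproximateGeLU` uses the sigmoid ("quick GELU") form shown here or the tanh approximation is an assumption, not something stated in the diff:

    import torch
    import torch.nn as nn

    class QuickGELU(nn.Module):
        # Sigmoid-based GELU approximation, as used by the original CLIP implementation.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * torch.sigmoid(1.702 * x)

    def swap_gelu(module: nn.Module) -> None:
        # Recursively replace every nn.GELU child, mirroring the walk/replace pattern above.
        for name, child in module.named_children():
            if isinstance(child, nn.GELU):
                setattr(module, name, QuickGELU())
            else:
                swap_gelu(child)

    mlp = nn.Sequential(nn.Linear(4, 4), nn.GELU(), nn.Linear(4, 4))
    swap_gelu(mlp)  # the nn.GELU at index 1 is now a QuickGELU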
@@ -11,13 +11,16 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
-    ):
+    ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
             tuple(merge.split())
-            for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
+            for merge in gzip.open(filename=vocabulary_path)
+            .read()
+            .decode(encoding="utf-8")
+            .split(sep="\n")[1 : 49152 - 256 - 2 + 1]
         ]
         vocabulary = (
             list(self.byte_to_unicode_mapping.values())
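For context on the slice bound above: 49152 - 256 - 2 + 1 = 48895, so lines 1 through 48894 of the merges file are kept (line 0 is a header). Together with the 256 byte-level tokens, their end-of-word `</w>` variants (presumably added in unchanged lines between these hunks, given the `</w>` handling further down), and the two special tokens, this yields a 49408-entry vocabulary, consistent with the ids 49406 and 49407 below.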
@@ -25,25 +28,25 @@ class CLIPTokenizer:
             + ["".join(merge) for merge in merge_tuples]
             + ["<|startoftext|>", "<|endoftext|>"]
         )
-        self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
-        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
+        self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
+        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
         self.byte_pair_encoding_cache = {"<|startoftext|>": "<|startoftext|>"}
         # Note: this regular expression does not support Unicode. It was changed so
         # to get rid of the dependence on the `regex` module. Unicode support could
         # potentially be added back by leveraging the `\w` character class.
         self.token_pattern = re.compile(
-            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
-            re.IGNORECASE,
+            pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
+            flags=re.IGNORECASE,
         )
         self.start_of_text_token_id: int = 49406
         self.end_of_text_token_id: int = 49407
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
+        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
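To illustrate the padding step in `__call__` above: the `pad` used there takes `x=` as a keyword, which suggests a thin wrapper in refiners rather than `torch.nn.functional.pad` itself, so the sketch below calls the torch function directly under that assumption; the ids other than 49406/49407 and the sequence length of 77 are made up for illustration.

    import torch
    from torch.nn.functional import pad

    tokens = torch.tensor([[49406, 320, 1125, 49407]])  # hypothetical encoded prompt, shape (1, 4)
    sequence_length = 77                                 # example value
    padded = pad(tokens, (0, sequence_length - tokens.shape[1]), value=49407)  # right-pad with the end-of-text id
    assert padded.shape == (1, sequence_length)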
@@ -83,13 +86,13 @@ class CLIPTokenizer:
             return recursive_bpe(tuple(new_word))
 
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
-        result = " ".join(recursive_bpe(word))
+        result = " ".join(recursive_bpe(word=word))
         self.byte_pair_encoding_cache[token] = result
         return result
 
     def encode(self, text: str, max_length: int | None = None) -> Tensor:
-        text = re.sub(r"\s+", " ", text.lower())
-        tokens = re.findall(self.token_pattern, text)
+        text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
+        tokens = re.findall(pattern=self.token_pattern, string=text)
         upper_bound = None
         if max_length:
             assert max_length >= 2
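A small sketch of the pre-tokenization performed by `encode` above, using the same `token_pattern` regular expression and keyword-argument style as the diff (the input string is an arbitrary example):

    import re

    token_pattern = re.compile(
        pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
        flags=re.IGNORECASE,
    )
    text = re.sub(pattern=r"\s+", repl=" ", string="A   photo of a dog's  2 toys!!".lower())
    assert re.findall(pattern=token_pattern, string=text) == ["a", "photo", "of", "a", "dog", "'s", "2", "toys", "!!"]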
@@ -99,10 +102,10 @@ class CLIPTokenizer:
                 self.token_to_id_mapping[subtoken]
                 for token in tokens
                 for subtoken in self.byte_pair_encoding(
-                    "".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
-                ).split(" ")
+                    token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
+                ).split(sep=" ")
             ),
             0,
             upper_bound,
         )
-        return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
+        return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
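Putting the tokenizer interface together, a minimal usage sketch (assumes the bundled vocabulary file is available next to the module; 77 is just an example sequence length):

    tokenizer = CLIPTokenizer()
    tokens = tokenizer("a photo of a cat", sequence_length=77)
    # tokens has shape (1, 77): start-of-text id 49406, the BPE ids, end-of-text id 49407, then padding with 49407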