cosmetic changes

This commit is contained in:
parent 17dc75421b
commit efe923a272
@@ -25,7 +25,7 @@ class PositionalTokenEncoder(Sum):
         positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
@@ -36,7 +36,7 @@ class PositionalTokenEncoder(Sum):
                 dtype=dtype,
             ),
             Chain(
-                Lambda(self.get_position_ids),
+                Lambda(func=self.get_position_ids),
                 Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
@@ -48,7 +48,7 @@ class PositionalTokenEncoder(Sum):

     @property
     def position_ids(self) -> Tensor:
-        return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)

     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
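As context for the hunks above: position_ids builds a (1, positional_embedding_dim) index tensor with arange, and get_position_ids slices it to the length of the incoming token tensor. A minimal standalone sketch of that behaviour with plain torch (the value 77 and the dummy input are assumptions for illustration, not part of this diff):

import torch

positional_embedding_dim = 77  # assumed context length, for illustration only
position_ids = torch.arange(positional_embedding_dim).reshape(1, -1)  # shape (1, 77)

tokens = torch.zeros(1, 12, dtype=torch.int64)  # dummy batch with 12 token ids
sliced = position_ids[:, : tokens.shape[1]]  # positions 0..11, shape (1, 12)
assert sliced.shape == (1, 12)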
@@ -145,7 +145,7 @@ class CLIPTextEncoder(Chain):
         layer_norm_eps: float = 1e-5,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.embedding_dim = embedding_dim
         self.positional_embedding_dim = positional_embedding_dim
         self.vocabulary_size = vocabulary_size
@@ -177,12 +177,12 @@ class CLIPTextEncoder(Chain):
         )

     def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
+        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
         return self(tokens)

     @property
     def unconditional_text_embedding(self) -> Tensor:
-        return self.encode("")
+        return self.encode(text="")


 class CLIPTextEncoderL(CLIPTextEncoder):
@@ -206,7 +206,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
+        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
             parent.replace(old_module=gelu, new_module=ApproximateGeLU())


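Before moving on to the tokenizer file below, here is a hypothetical usage sketch of the encoder touched above. The import path is an assumption about the repository layout at this commit, the weights are freshly initialized rather than loaded from a checkpoint, and none of this is part of the change itself:

import torch
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL  # assumed import path

encoder = CLIPTextEncoderL(device="cuda" if torch.cuda.is_available() else "cpu")
conditional = encoder.encode(text="a photograph of an astronaut riding a horse")
unconditional = encoder.unconditional_text_embedding  # calls encode(text="") internally
print(conditional.shape, unconditional.shape)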
@@ -11,13 +11,16 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
-    ):
+    ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
             tuple(merge.split())
-            for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
+            for merge in gzip.open(filename=vocabulary_path)
+            .read()
+            .decode(encoding="utf-8")
+            .split(sep="\n")[1 : 49152 - 256 - 2 + 1]
         ]
         vocabulary = (
             list(self.byte_to_unicode_mapping.values())
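The reflowed comprehension above parses the gzipped BPE merges file: skip the first (header) line, then read one merge pair per line within a bounded slice. A self-contained sketch of the same parsing against an in-memory stand-in file (the header text and merge pairs are invented for illustration):

import gzip
import io

# Tiny stand-in for bpe_simple_vocab_16e6.txt.gz: a header line, then one merge per line.
fake_vocab = "#header line (skipped)\nt h\nth e</w>\nc a\nca t</w>\n"
buffer = io.BytesIO()
with gzip.GzipFile(fileobj=buffer, mode="wb") as archive:
    archive.write(fake_vocab.encode("utf-8"))
buffer.seek(0)

merge_tuples = [
    tuple(merge.split())
    for merge in gzip.open(filename=buffer).read().decode(encoding="utf-8").split(sep="\n")[1:]
    if merge  # the real code bounds the slice instead: [1 : 49152 - 256 - 2 + 1]
]
print(merge_tuples)  # [('t', 'h'), ('th', 'e</w>'), ('c', 'a'), ('ca', 't</w>')]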
@@ -25,25 +28,25 @@ class CLIPTokenizer:
             + ["".join(merge) for merge in merge_tuples]
             + ["<|startoftext|>", "<|endoftext|>"]
         )
-        self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
-        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
+        self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
+        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
         self.byte_pair_encoding_cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
         # Note: this regular expression does not support Unicode. It was changed so
         # to get rid of the dependence on the `regex` module. Unicode support could
         # potentially be added back by leveraging the `\w` character class.
         self.token_pattern = re.compile(
-            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
-            re.IGNORECASE,
+            pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
+            flags=re.IGNORECASE,
         )
         self.start_of_text_token_id: int = 49406
         self.end_of_text_token_id: int = 49407

     def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
+        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)

     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
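__call__ above encodes the text, asserts that it fits, then right-pads with the end-of-text id up to sequence_length. The pad called there appears to be a project-level wrapper (it takes x= and pad= keywords); the sketch below uses torch's own functional pad, which takes the tensor positionally. The sequence length of 77 and the short token row are assumptions for illustration, while 49406/49407 are the start/end-of-text ids from this diff:

import torch
from torch.nn.functional import pad

sequence_length = 77  # assumed context length
tokens = torch.tensor([[49406, 320, 1125, 49407]])  # illustrative ids wrapped in start/end-of-text
assert tokens.shape[1] <= sequence_length, "Text is too long"
padded = pad(tokens, (0, sequence_length - tokens.shape[1]), value=49407)  # right-pad with end-of-text id
print(padded.shape)  # torch.Size([1, 77])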
@@ -83,13 +86,13 @@ class CLIPTokenizer:
                 return recursive_bpe(tuple(new_word))

         word = tuple(token[:-1]) + (token[-1] + "</w>",)
-        result = " ".join(recursive_bpe(word))
+        result = " ".join(recursive_bpe(word=word))
         self.byte_pair_encoding_cache[token] = result
         return result

     def encode(self, text: str, max_length: int | None = None) -> Tensor:
-        text = re.sub(r"\s+", " ", text.lower())
-        tokens = re.findall(self.token_pattern, text)
+        text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
+        tokens = re.findall(pattern=self.token_pattern, string=text)
         upper_bound = None
         if max_length:
             assert max_length >= 2
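encode above first collapses whitespace and lowercases the text, then splits it with the token pattern compiled in __init__. A quick standalone check of that regex on a made-up sentence (the pattern string is copied from this diff; the sample text is invented):

import re

token_pattern = re.compile(
    pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
    flags=re.IGNORECASE,
)
text = re.sub(pattern=r"\s+", repl=" ", string="A  photo of 2 cats, isn't it?".lower())
print(re.findall(pattern=token_pattern, string=text))
# ['a', 'photo', 'of', '2', 'cats', ',', 'isn', "'t", 'it', '?']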
@@ -99,10 +102,10 @@ class CLIPTokenizer:
                 self.token_to_id_mapping[subtoken]
                 for token in tokens
                 for subtoken in self.byte_pair_encoding(
-                    "".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
-                ).split(" ")
+                    token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
+                ).split(sep=" ")
             ),
             0,
             upper_bound,
         )
-        return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
+        return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])