diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 96a79cc..90d35ba 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -25,7 +25,7 @@ class PositionalTokenEncoder(Sum):
         positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
@@ -36,7 +36,7 @@ class PositionalTokenEncoder(Sum):
                 dtype=dtype,
             ),
             Chain(
-                Lambda(self.get_position_ids),
+                Lambda(func=self.get_position_ids),
                 Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
@@ -48,7 +48,7 @@ class PositionalTokenEncoder(Sum):
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
@@ -145,7 +145,7 @@ class CLIPTextEncoder(Chain):
         layer_norm_eps: float = 1e-5,
         device: Device | str | None = None,
         dtype: DType | None = None,
-    ):
+    ) -> None:
         self.embedding_dim = embedding_dim
         self.positional_embedding_dim = positional_embedding_dim
         self.vocabulary_size = vocabulary_size
@@ -177,12 +177,12 @@ class CLIPTextEncoder(Chain):
         )
 
     def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(self.device)
+        tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
         return self(tokens)
 
     @property
     def unconditional_text_embedding(self) -> Tensor:
-        return self.encode("")
+        return self.encode(text="")
 
 
 class CLIPTextEncoderL(CLIPTextEncoder):
@@ -206,7 +206,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(lambda m, _: isinstance(m, GeLU)):
+        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
             parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
 
diff --git a/src/refiners/foundationals/clip/tokenizer.py b/src/refiners/foundationals/clip/tokenizer.py
index 4bfdcae..0213cfa 100644
--- a/src/refiners/foundationals/clip/tokenizer.py
+++ b/src/refiners/foundationals/clip/tokenizer.py
@@ -11,13 +11,16 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
-    ):
+    ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
             tuple(merge.split())
-            for merge in gzip.open(vocabulary_path).read().decode("utf-8").split("\n")[1 : 49152 - 256 - 2 + 1]
+            for merge in gzip.open(filename=vocabulary_path)
+            .read()
+            .decode(encoding="utf-8")
+            .split(sep="\n")[1 : 49152 - 256 - 2 + 1]
         ]
         vocabulary = (
             list(self.byte_to_unicode_mapping.values())
@@ -25,25 +28,25 @@ class CLIPTokenizer:
             + ["".join(merge) for merge in merge_tuples]
             + ["<|startoftext|>", "<|endoftext|>"]
         )
-        self.token_to_id_mapping = {token: i for i, token in enumerate(vocabulary)}
-        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(merge_tuples)}
+        self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
+        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
         self.byte_pair_encoding_cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
         # Note: this regular expression does not support Unicode. It was changed so
         # to get rid of the dependence on the `regex` module. Unicode support could
         # potentially be added back by leveraging the `\w` character class.
         self.token_pattern = re.compile(
-            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
-            re.IGNORECASE,
+            pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
+            flags=re.IGNORECASE,
         )
         self.start_of_text_token_id: int = 49406
         self.end_of_text_token_id: int = 49407
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(0)
+        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(tokens, (0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
@@ -83,13 +86,13 @@ class CLIPTokenizer:
             return recursive_bpe(tuple(new_word))
 
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
-        result = " ".join(recursive_bpe(word))
+        result = " ".join(recursive_bpe(word=word))
         self.byte_pair_encoding_cache[token] = result
         return result
 
     def encode(self, text: str, max_length: int | None = None) -> Tensor:
-        text = re.sub(r"\s+", " ", text.lower())
-        tokens = re.findall(self.token_pattern, text)
+        text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
+        tokens = re.findall(pattern=self.token_pattern, string=text)
         upper_bound = None
         if max_length:
             assert max_length >= 2
@@ -99,10 +102,10 @@ class CLIPTokenizer:
                 self.token_to_id_mapping[subtoken]
                 for token in tokens
                 for subtoken in self.byte_pair_encoding(
-                    "".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
-                ).split(" ")
+                    token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
+                ).split(sep=" ")
             ),
             0,
             upper_bound,
         )
-        return tensor([self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
+        return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
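
Usage sketch (not part of the patch): a minimal example of the keyword-argument call style this change settles on. The import paths are the two files touched above; the prompt string and the sequence length of 77 (CLIP-L's positional embedding size) are illustrative assumptions, not something introduced by the diff.

    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    # Tokenize a prompt into a (1, sequence_length) tensor of token ids,
    # padded on the right with the end-of-text token id.
    tokenizer = CLIPTokenizer()
    tokens = tokenizer("a photo of a cat", sequence_length=77)

    # Or encode end to end: encode() tokenizes internally, moves the ids to
    # the encoder's device, and runs them through the transformer chain.
    encoder = CLIPTextEncoderL()
    embedding = encoder.encode(text="a photo of a cat")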