diff --git a/src/refiners/foundationals/clip/tokenizer.py b/src/refiners/foundationals/clip/tokenizer.py
index 0213cfa..a160ccc 100644
--- a/src/refiners/foundationals/clip/tokenizer.py
+++ b/src/refiners/foundationals/clip/tokenizer.py
@@ -11,6 +11,9 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
+        start_of_text_token_id: int = 49406,
+        end_of_text_token_id: int = 49407,
+        pad_token_id: int = 49407,
     ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
@@ -38,15 +41,16 @@ class CLIPTokenizer:
             pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
             flags=re.IGNORECASE,
         )
-        self.start_of_text_token_id: int = 49406
-        self.end_of_text_token_id: int = 49407
+        self.start_of_text_token_id: int = start_of_text_token_id
+        self.end_of_text_token_id: int = end_of_text_token_id
+        self.pad_token_id: int = pad_token_id
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
         tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
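
Below is a minimal usage sketch of the new constructor parameters (not part of the diff itself). It assumes refiners is installed with the bundled BPE vocabulary; the pad_token_id value 0 is purely illustrative. With the defaults, padding still reuses the end-of-text token id (49407), so existing callers keep the previous behaviour.

    # Hypothetical example: pad with token id 0 instead of the end-of-text token.
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    tokenizer = CLIPTokenizer(pad_token_id=0)  # 0 is an illustrative value, not a CLIP default
    tokens = tokenizer("a photo of a cat", sequence_length=77)
    assert tokens.shape == (1, 77)  # short prompts are right-padded up to sequence_length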