add start, end and pad tokens as parameter

limiteinductive 2023-08-15 16:34:25 +02:00 committed by Benjamin Trom
parent 63fda2bfd8
commit 4575e3dd91


@@ -11,6 +11,9 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
+        start_of_text_token_id: int = 49406,
+        end_of_text_token_id: int = 49407,
+        pad_token_id: int = 49407,
     ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
@@ -38,15 +41,16 @@ class CLIPTokenizer:
             pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
             flags=re.IGNORECASE,
         )
-        self.start_of_text_token_id: int = 49406
-        self.end_of_text_token_id: int = 49407
+        self.start_of_text_token_id: int = start_of_text_token_id
+        self.end_of_text_token_id: int = end_of_text_token_id
+        self.pad_token_id: int = pad_token_id
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
         tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
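
After this change the three special token ids can be overridden at construction time instead of being hard-coded, and padding uses pad_token_id rather than the end-of-text id. A minimal usage sketch follows; the import path and the zero-pad variant are assumptions for illustration, not part of this commit:

from refiners.foundationals.clip.tokenizer import CLIPTokenizer  # assumed import path

# Defaults reproduce the previous hard-coded behaviour:
# start_of_text=49406, end_of_text=49407, pad=49407.
tokenizer = CLIPTokenizer()

# Hypothetical variant for checkpoints that pad with 0 instead of the
# end-of-text token; only the pad id changes, start/end keep their defaults.
zero_pad_tokenizer = CLIPTokenizer(pad_token_id=0)

tokens = zero_pad_tokenizer("a photo of a cat", sequence_length=77)
# tokens has shape (1, 77); positions past the encoded prompt hold pad_token_id.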