add start, end and pad tokens as parameter

This commit is contained in:
limiteinductive 2023-08-15 16:34:25 +02:00 committed by Benjamin Trom
parent 63fda2bfd8
commit 4575e3dd91

View file

@ -11,6 +11,9 @@ class CLIPTokenizer:
def __init__(
self,
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
start_of_text_token_id: int = 49406,
end_of_text_token_id: int = 49407,
pad_token_id: int = 49407,
) -> None:
self.vocabulary_path = vocabulary_path
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
@ -38,15 +41,16 @@ class CLIPTokenizer:
pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
flags=re.IGNORECASE,
)
self.start_of_text_token_id: int = 49406
self.end_of_text_token_id: int = 49407
self.start_of_text_token_id: int = start_of_text_token_id
self.end_of_text_token_id: int = end_of_text_token_id
self.pad_token_id: int = pad_token_id
def __call__(self, text: str, sequence_length: int) -> Tensor:
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
assert (
tokens.shape[1] <= sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
@lru_cache()
def get_bytes_to_unicode_mapping(self) -> dict[int, str]: