Mirror of https://github.com/finegrain-ai/refiners.git
Synced 2024-11-13 00:28:14 +00:00
add start, end and pad tokens as parameters
This commit is contained in:
parent 63fda2bfd8
commit 4575e3dd91
@@ -11,6 +11,9 @@ class CLIPTokenizer:
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
+        start_of_text_token_id: int = 49406,
+        end_of_text_token_id: int = 49407,
+        pad_token_id: int = 49407,
     ) -> None:
         self.vocabulary_path = vocabulary_path
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
@@ -38,15 +41,16 @@ class CLIPTokenizer:
             pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
             flags=re.IGNORECASE,
         )
-        self.start_of_text_token_id: int = 49406
-        self.end_of_text_token_id: int = 49407
+        self.start_of_text_token_id: int = start_of_text_token_id
+        self.end_of_text_token_id: int = end_of_text_token_id
+        self.pad_token_id: int = pad_token_id
 
     def __call__(self, text: str, sequence_length: int) -> Tensor:
         tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= sequence_length
         ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.end_of_text_token_id)
+        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
 
     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
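A note on the defaults: 49406 and 49407 are the ids of <|startoftext|> and <|endoftext|> in the standard CLIP BPE vocabulary, and the old code reused the end-of-text id for padding. Exposing the pad id as a parameter matters because some CLIP variants pad differently; OpenCLIP-style checkpoints, for instance, pad with id 0. A minimal usage sketch under that assumption (the import path is inferred from the repository layout and is not part of this diff):

# Hypothetical usage sketch; the import path is an assumption based on the
# refiners repository layout, not something shown in this commit.
from refiners.foundationals.clip.tokenizer import CLIPTokenizer

# Defaults keep the old behavior: pad with <|endoftext|> (49407).
tokenizer = CLIPTokenizer()

# OpenCLIP-style padding: same start/end ids, but pad with 0.
openclip_tokenizer = CLIPTokenizer(pad_token_id=0)

# __call__ encodes the text, asserts it fits, then right-pads to sequence_length.
tokens = openclip_tokenizer("a photo of a cat", sequence_length=77)
print(tokens.shape)  # torch.Size([1, 77])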
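The `pad(...)` call at the end of `__call__` takes `x=` as a keyword, so it is presumably the repository's own helper wrapping `torch.nn.functional.pad` rather than the torch function itself. A standalone sketch of the padding behavior it relies on, using example token ids:

import torch
from torch.nn.functional import pad

# pad=(0, 2) adds nothing on the left and two values on the right of the
# last dimension, which is how the tokenizer fills a row up to sequence_length.
tokens = torch.tensor([[49406, 320, 49407]])  # <|startoftext|>, "a", <|endoftext|>
print(pad(tokens, pad=(0, 2), value=49407))
# tensor([[49406,   320, 49407, 49407, 49407]])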