parametrize tokenizer for text_encoder

limiteinductive 2023-08-16 13:44:44 +02:00 committed by Benjamin Trom
parent 4575e3dd91
commit 6594502c11


@@ -145,6 +145,7 @@ class CLIPTextEncoder(Chain):
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
         use_quick_gelu: bool = False,
+        tokenizer: CLIPTokenizer | None = None,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -156,7 +157,7 @@ class CLIPTextEncoder(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         self.use_quick_gelu = use_quick_gelu
-        self.tokenizer = CLIPTokenizer()
+        self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
                 vocabulary_size=vocabulary_size,
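
With this change, callers can inject an existing tokenizer instead of always getting a freshly constructed default. A minimal usage sketch follows, assuming the refiners import paths shown here (the exact module locations of CLIPTextEncoder and CLIPTokenizer may differ at this commit):

    # NOTE: import paths are an assumption; adjust them to wherever
    # CLIPTextEncoder and CLIPTokenizer live in your refiners version.
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    # Share one tokenizer instance with the text encoder instead of
    # letting the encoder build its own.
    tokenizer = CLIPTokenizer()
    encoder = CLIPTextEncoder(tokenizer=tokenizer)

    # Omitting the argument keeps the old behavior: a default
    # CLIPTokenizer is created internally (tokenizer or CLIPTokenizer()).
    default_encoder = CLIPTextEncoder()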