diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py index 188df0b..a3b551b 100644 --- a/src/refiners/foundationals/clip/text_encoder.py +++ b/src/refiners/foundationals/clip/text_encoder.py @@ -145,6 +145,7 @@ class CLIPTextEncoder(Chain): feedforward_dim: int = 3072, layer_norm_eps: float = 1e-5, use_quick_gelu: bool = False, + tokenizer: CLIPTokenizer | None = None, device: Device | str | None = None, dtype: DType | None = None, ) -> None: @@ -156,7 +157,7 @@ class CLIPTextEncoder(Chain): self.feedforward_dim = feedforward_dim self.layer_norm_eps = layer_norm_eps self.use_quick_gelu = use_quick_gelu - self.tokenizer = CLIPTokenizer() + self.tokenizer = tokenizer or CLIPTokenizer() super().__init__( PositionalTokenEncoder( vocabulary_size=vocabulary_size,