From 6594502c1117859882b5fe2a46176c9c8be210e4 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Wed, 16 Aug 2023 13:44:44 +0200 Subject: [PATCH] parametrize tokenizer for text_encoder --- src/refiners/foundationals/clip/text_encoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py index 188df0b..a3b551b 100644 --- a/src/refiners/foundationals/clip/text_encoder.py +++ b/src/refiners/foundationals/clip/text_encoder.py @@ -145,6 +145,7 @@ class CLIPTextEncoder(Chain): feedforward_dim: int = 3072, layer_norm_eps: float = 1e-5, use_quick_gelu: bool = False, + tokenizer: CLIPTokenizer | None = None, device: Device | str | None = None, dtype: DType | None = None, ) -> None: @@ -156,7 +157,7 @@ class CLIPTextEncoder(Chain): self.feedforward_dim = feedforward_dim self.layer_norm_eps = layer_norm_eps self.use_quick_gelu = use_quick_gelu - self.tokenizer = CLIPTokenizer() + self.tokenizer = tokenizer or CLIPTokenizer() super().__init__( PositionalTokenEncoder( vocabulary_size=vocabulary_size,