mirror of https://github.com/finegrain-ai/refiners.git
parametrize tokenizer for text_encoder
This commit is contained in:
parent 4575e3dd91
commit 6594502c11
@@ -145,6 +145,7 @@ class CLIPTextEncoder(Chain):
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
         use_quick_gelu: bool = False,
+        tokenizer: CLIPTokenizer | None = None,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -156,7 +157,7 @@ class CLIPTextEncoder(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         self.use_quick_gelu = use_quick_gelu
-        self.tokenizer = CLIPTokenizer()
+        self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
                 vocabulary_size=vocabulary_size,
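A minimal usage sketch of the new constructor parameter. The import paths are assumptions based on the refiners package layout and are not shown in this diff; the behavior itself follows directly from the `tokenizer or CLIPTokenizer()` fallback above.

# Sketch of the parametrized tokenizer in use; import paths are assumptions,
# not part of this commit.
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer

# A tokenizer instance constructed ahead of time by the caller.
shared_tokenizer = CLIPTokenizer()

# With this commit, the encoder uses the injected tokenizer instead of
# unconditionally building its own.
encoder = CLIPTextEncoder(tokenizer=shared_tokenizer)

# Omitting the argument keeps the pre-commit behavior: the constructor
# falls back to a default CLIPTokenizer().
default_encoder = CLIPTextEncoder()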