mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix typo + skip test if weights are not available
This commit is contained in:
parent
cf43cb191f
commit
78e69c7da0
|
@ -55,10 +55,18 @@ def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
|
||||||
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||||
|
|
||||||
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
|
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||||
text_encdoer_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
|
text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
|
||||||
|
|
||||||
|
if not text_encoder_l_path.is_file():
|
||||||
|
warn(f"could not find weights at {text_encoder_l_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
if not text_encoder_g_path.is_file():
|
||||||
|
warn(f"could not find weights at {text_encoder_g_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
|
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
|
||||||
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encdoer_g_path)
|
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)
|
||||||
|
|
||||||
linear = text_encoder_g_with_projection.pop(index=-1)
|
linear = text_encoder_g_with_projection.pop(index=-1)
|
||||||
assert isinstance(linear, fl.Linear)
|
assert isinstance(linear, fl.Linear)
|
||||||
|
|
Loading…
Reference in a new issue