diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 8caef45..172820f 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -7,7 +7,6 @@ from torch import Tensor from refiners.fluxion.utils import manual_seed import refiners.fluxion.layers as fl -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder @@ -41,6 +40,15 @@ def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path: return r +@pytest.fixture(scope="module") +def double_text_encoder_weights(test_weights_path: Path) -> Path: + text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors" + if not text_encoder_weights.is_file(): + warn(f"could not find weights at {text_encoder_weights}, skipping") + pytest.skip(allow_module_level=True) + return text_encoder_weights + + @pytest.fixture(scope="module") def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any: from diffusers import DiffusionPipeline # type: ignore @@ -49,31 +57,9 @@ def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any: @pytest.fixture(scope="module") -def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder: - text_encoder_l = CLIPTextEncoderL() - text_encoder_g_with_projection = CLIPTextEncoderG() - text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False)) - - text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors" - text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors" - - if not text_encoder_l_path.is_file(): - warn(f"could not find weights at {text_encoder_l_path}, skipping") - pytest.skip(allow_module_level=True) - - if not text_encoder_g_path.is_file(): - warn(f"could not find weights at {text_encoder_g_path}, skipping") - pytest.skip(allow_module_level=True) - - text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path) - text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path) - - linear = text_encoder_g_with_projection.pop(index=-1) - assert isinstance(linear, fl.Linear) - - double_text_encoder = DoubleTextEncoder( - text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear - ) +def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder: + double_text_encoder = DoubleTextEncoder() + double_text_encoder.load_from_safetensors(double_text_encoder_weights) return double_text_encoder