test_sdxl_double_encoder: use proper weights

This commit is contained in:
Cédric Deltheil 2023-09-11 16:07:51 +02:00 committed by Cédric Deltheil
parent cc3b20320d
commit 32cba1afd8

View file

@ -7,7 +7,6 @@ from torch import Tensor
from refiners.fluxion.utils import manual_seed
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
@ -41,6 +40,15 @@ def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
return r
@pytest.fixture(scope="module")
def double_text_encoder_weights(test_weights_path: Path) -> Path:
text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not text_encoder_weights.is_file():
warn(f"could not find weights at {text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return text_encoder_weights
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
@ -49,31 +57,9 @@ def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
@pytest.fixture(scope="module")
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
text_encoder_l = CLIPTextEncoderL()
text_encoder_g_with_projection = CLIPTextEncoderG()
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
if not text_encoder_l_path.is_file():
warn(f"could not find weights at {text_encoder_l_path}, skipping")
pytest.skip(allow_module_level=True)
if not text_encoder_g_path.is_file():
warn(f"could not find weights at {text_encoder_g_path}, skipping")
pytest.skip(allow_module_level=True)
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)
linear = text_encoder_g_with_projection.pop(index=-1)
assert isinstance(linear, fl.Linear)
double_text_encoder = DoubleTextEncoder(
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
)
def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
double_text_encoder = DoubleTextEncoder()
double_text_encoder.load_from_safetensors(double_text_encoder_weights)
return double_text_encoder