mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
test_sdxl_double_encoder: use proper weights
This commit is contained in:
parent
cc3b20320d
commit
32cba1afd8
|
@ -7,7 +7,6 @@ from torch import Tensor
|
|||
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
|
||||
|
||||
|
@ -41,6 +40,15 @@ def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
|
|||
return r
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def double_text_encoder_weights(test_weights_path: Path) -> Path:
|
||||
text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
|
||||
if not text_encoder_weights.is_file():
|
||||
warn(f"could not find weights at {text_encoder_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return text_encoder_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
|
||||
from diffusers import DiffusionPipeline # type: ignore
|
||||
|
@ -49,31 +57,9 @@ def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
|
||||
text_encoder_l = CLIPTextEncoderL()
|
||||
text_encoder_g_with_projection = CLIPTextEncoderG()
|
||||
text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
|
||||
|
||||
text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||
text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"
|
||||
|
||||
if not text_encoder_l_path.is_file():
|
||||
warn(f"could not find weights at {text_encoder_l_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
if not text_encoder_g_path.is_file():
|
||||
warn(f"could not find weights at {text_encoder_g_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
|
||||
text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)
|
||||
|
||||
linear = text_encoder_g_with_projection.pop(index=-1)
|
||||
assert isinstance(linear, fl.Linear)
|
||||
|
||||
double_text_encoder = DoubleTextEncoder(
|
||||
text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
|
||||
)
|
||||
def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
|
||||
double_text_encoder = DoubleTextEncoder()
|
||||
double_text_encoder.load_from_safetensors(double_text_encoder_weights)
|
||||
|
||||
return double_text_encoder
|
||||
|
||||
|
|
Loading…
Reference in a new issue