# refiners/tests/foundationals/latent_diffusion/conftest.py
# Snapshot metadata: 2024-10-14 15:12:59 +02:00 — 140 lines, 4.4 KiB, Python.
from pathlib import Path
import pytest
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
DoubleTextEncoder,
SD1Autoencoder,
SD1UNet,
SDXLAutoencoder,
SDXLUNet,
StableDiffusion_1,
StableDiffusion_XL,
)
@pytest.fixture(scope="module")
def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder:
    """Module-scoped refiners SD1.5 autoencoder with weights loaded from safetensors."""
    model = SD1Autoencoder()
    model.load_state_dict(load_from_safetensors(sd15_autoencoder_weights_path))
    return model
@pytest.fixture(scope="module")
def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet:
    """Module-scoped refiners SD1.5 UNet (4 latent channels) with weights loaded from safetensors."""
    model = SD1UNet(in_channels=4)
    model.load_state_dict(load_from_safetensors(sd15_unet_weights_path))
    return model
@pytest.fixture(scope="module")
def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL:
    """Module-scoped refiners CLIP-L text encoder with SD1.5 weights loaded from safetensors."""
    encoder = CLIPTextEncoderL()
    encoder.load_state_dict(load_from_safetensors(sd15_text_encoder_weights_path))
    return encoder
@pytest.fixture(scope="module")
def refiners_sd15(
    refiners_sd15_autoencoder: SD1Autoencoder,
    refiners_sd15_unet: SD1UNet,
    refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> StableDiffusion_1:
    """Assemble a refiners SD1.5 pipeline from its module-scoped components."""
    pipeline = StableDiffusion_1(
        unet=refiners_sd15_unet,
        lda=refiners_sd15_autoencoder,
        clip_text_encoder=refiners_sd15_text_encoder,
    )
    return pipeline
@pytest.fixture(scope="module")
def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder:
    """Module-scoped refiners SDXL autoencoder with weights loaded from safetensors."""
    model = SDXLAutoencoder()
    model.load_state_dict(load_from_safetensors(sdxl_autoencoder_weights_path))
    return model
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet:
    """Module-scoped refiners SDXL UNet (4 latent channels) with weights loaded from safetensors."""
    model = SDXLUNet(in_channels=4)
    model.load_state_dict(load_from_safetensors(sdxl_unet_weights_path))
    return model
@pytest.fixture(scope="module")
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
    """Module-scoped refiners SDXL double text encoder with weights loaded from safetensors."""
    encoder = DoubleTextEncoder()
    encoder.load_state_dict(load_from_safetensors(sdxl_text_encoder_weights_path))
    return encoder
@pytest.fixture(scope="module")
def refiners_sdxl(
    refiners_sdxl_autoencoder: SDXLAutoencoder,
    refiners_sdxl_unet: SDXLUNet,
    refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> StableDiffusion_XL:
    """Assemble a refiners SDXL pipeline from its module-scoped components.

    Bug fix: this fixture previously requested `refiners_sd15_text_encoder`
    (a CLIPTextEncoderL) while annotating it as DoubleTextEncoder — SDXL needs
    the SDXL double text encoder fixture defined in this file.
    """
    return StableDiffusion_XL(
        lda=refiners_sdxl_autoencoder,
        unet=refiners_sdxl_unet,
        clip_text_encoder=refiners_sdxl_text_encoder,
    )
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def refiners_autoencoder(
request: pytest.FixtureRequest,
refiners_sd15_autoencoder: SD1Autoencoder,
refiners_sdxl_autoencoder: SDXLAutoencoder,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1Autoencoder | SDXLAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
return refiners_sd15_autoencoder
case ("SDXL", torch.float16):
return refiners_sdxl_autoencoder
case ("SDXL", _):
return refiners_sdxl_autoencoder
case _:
raise ValueError(f"Unknown model version: {model_version}")
@pytest.fixture(scope="module")
def diffusers_sd15_pipeline(
    sd15_diffusers_runwayml_path: str,
    use_local_weights: bool,
) -> StableDiffusionPipeline:
    """Reference diffusers SD1.5 pipeline (restricted to local files when requested)."""
    pipeline = StableDiffusionPipeline.from_pretrained(  # type: ignore
        sd15_diffusers_runwayml_path,
        local_files_only=use_local_weights,
    )
    return pipeline
@pytest.fixture(scope="module")
def diffusers_sdxl_pipeline(
    sdxl_diffusers_stabilityai_path: str,
    use_local_weights: bool,
) -> StableDiffusionXLPipeline:
    """Reference diffusers SDXL pipeline (restricted to local files when requested)."""
    pipeline = StableDiffusionXLPipeline.from_pretrained(  # type: ignore
        sdxl_diffusers_stabilityai_path,
        local_files_only=use_local_weights,
    )
    return pipeline
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel:
    """Expose the UNet submodule of the reference diffusers SDXL pipeline."""
    unet = diffusers_sdxl_pipeline.unet  # type: ignore
    return unet