# Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-15 01:28:14 +00:00) — Python, 140 lines, 4.4 KiB.
|
from pathlib import Path
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||
|
|
||
|
from refiners.fluxion.utils import load_from_safetensors
|
||
|
from refiners.foundationals.latent_diffusion import (
|
||
|
CLIPTextEncoderL,
|
||
|
DoubleTextEncoder,
|
||
|
SD1Autoencoder,
|
||
|
SD1UNet,
|
||
|
SDXLAutoencoder,
|
||
|
SDXLUNet,
|
||
|
StableDiffusion_1,
|
||
|
StableDiffusion_XL,
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sd15_autoencoder(sd15_autoencoder_weights_path: Path) -> SD1Autoencoder:
    """Build an SD1.5 autoencoder and load its weights from a safetensors file.

    Module-scoped, so the model is built once per test module.
    """
    model = SD1Autoencoder()
    model.load_state_dict(load_from_safetensors(sd15_autoencoder_weights_path))
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sd15_unet(sd15_unet_weights_path: Path) -> SD1UNet:
    """Build an SD1.5 UNet (4 input channels) and load its safetensors weights.

    Module-scoped, so the model is built once per test module.
    """
    model = SD1UNet(in_channels=4)
    model.load_state_dict(load_from_safetensors(sd15_unet_weights_path))
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sd15_text_encoder(sd15_text_encoder_weights_path: Path) -> CLIPTextEncoderL:
    """Build the SD1.5 CLIP text encoder and load its safetensors weights.

    Module-scoped, so the model is built once per test module.
    """
    encoder = CLIPTextEncoderL()
    encoder.load_state_dict(load_from_safetensors(sd15_text_encoder_weights_path))
    return encoder
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sd15(
    refiners_sd15_autoencoder: SD1Autoencoder,
    refiners_sd15_unet: SD1UNet,
    refiners_sd15_text_encoder: CLIPTextEncoderL,
) -> StableDiffusion_1:
    """Assemble a refiners SD1.5 model from the module-scoped component fixtures."""
    sd15 = StableDiffusion_1(
        unet=refiners_sd15_unet,
        lda=refiners_sd15_autoencoder,
        clip_text_encoder=refiners_sd15_text_encoder,
    )
    return sd15
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sdxl_autoencoder(sdxl_autoencoder_weights_path: Path) -> SDXLAutoencoder:
    """Build an SDXL autoencoder and load its weights from a safetensors file.

    Module-scoped, so the model is built once per test module.
    """
    model = SDXLAutoencoder()
    model.load_state_dict(load_from_safetensors(sdxl_autoencoder_weights_path))
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_path: Path) -> SDXLUNet:
    """Build an SDXL UNet (4 input channels) and load its safetensors weights.

    Module-scoped, so the model is built once per test module.
    """
    model = SDXLUNet(in_channels=4)
    model.load_state_dict(load_from_safetensors(sdxl_unet_weights_path))
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sdxl_text_encoder(sdxl_text_encoder_weights_path: Path) -> DoubleTextEncoder:
    """Build the SDXL double text encoder and load its safetensors weights.

    Module-scoped, so the model is built once per test module.
    """
    encoder = DoubleTextEncoder()
    encoder.load_state_dict(load_from_safetensors(sdxl_text_encoder_weights_path))
    return encoder
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def refiners_sdxl(
    refiners_sdxl_autoencoder: SDXLAutoencoder,
    refiners_sdxl_unet: SDXLUNet,
    refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> StableDiffusion_XL:
    """Assemble a refiners SDXL model from the module-scoped SDXL component fixtures.

    The text encoder must be the SDXL ``DoubleTextEncoder`` provided by the
    ``refiners_sdxl_text_encoder`` fixture, not the SD1.5 ``CLIPTextEncoderL``.
    """
    return StableDiffusion_XL(
        lda=refiners_sdxl_autoencoder,
        unet=refiners_sdxl_unet,
        clip_text_encoder=refiners_sdxl_text_encoder,
    )
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
|
||
|
def refiners_autoencoder(
|
||
|
request: pytest.FixtureRequest,
|
||
|
refiners_sd15_autoencoder: SD1Autoencoder,
|
||
|
refiners_sdxl_autoencoder: SDXLAutoencoder,
|
||
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||
|
) -> SD1Autoencoder | SDXLAutoencoder:
|
||
|
model_version = request.param
|
||
|
match (model_version, test_dtype_fp32_bf16_fp16):
|
||
|
case ("SD1.5", _):
|
||
|
return refiners_sd15_autoencoder
|
||
|
case ("SDXL", torch.float16):
|
||
|
return refiners_sdxl_autoencoder
|
||
|
case ("SDXL", _):
|
||
|
return refiners_sdxl_autoencoder
|
||
|
case _:
|
||
|
raise ValueError(f"Unknown model version: {model_version}")
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def diffusers_sd15_pipeline(
    sd15_diffusers_runwayml_path: str,
    use_local_weights: bool,
) -> StableDiffusionPipeline:
    """Load the reference diffusers SD1.5 pipeline from ``sd15_diffusers_runwayml_path``.

    ``use_local_weights`` is forwarded as ``local_files_only``.
    """
    pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(  # type: ignore
        sd15_diffusers_runwayml_path,
        local_files_only=use_local_weights,
    )
    return pipeline
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def diffusers_sdxl_pipeline(
    sdxl_diffusers_stabilityai_path: str,
    use_local_weights: bool,
) -> StableDiffusionXLPipeline:
    """Load the reference diffusers SDXL pipeline from ``sdxl_diffusers_stabilityai_path``.

    ``use_local_weights`` is forwarded as ``local_files_only``.
    """
    pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(  # type: ignore
        sdxl_diffusers_stabilityai_path,
        local_files_only=use_local_weights,
    )
    return pipeline
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl_pipeline: StableDiffusionXLPipeline) -> UNet2DConditionModel:
    """Expose the ``unet`` attribute of the diffusers SDXL pipeline fixture."""
    unet: UNet2DConditionModel = diffusers_sdxl_pipeline.unet  # type: ignore
    return unet
|