2023-08-04 13:28:41 +00:00
|
|
|
from pathlib import Path
|
2023-12-11 10:46:38 +00:00
|
|
|
from typing import Any
|
2023-08-04 13:28:41 +00:00
|
|
|
from warnings import warn
|
2023-12-11 10:46:38 +00:00
|
|
|
|
2023-08-04 13:28:41 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2023-08-29 12:38:58 +00:00
|
|
|
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
|
2023-12-11 10:46:38 +00:00
|
|
|
from refiners.fluxion.utils import manual_seed
|
|
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
2023-08-04 13:28:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
    """Locate the local Stability AI SDXL base 1.0 weights directory.

    Resolves ``<test_weights_path>/stabilityai/stable-diffusion-xl-base-1.0``.
    When that directory does not exist, a warning is emitted and the requesting
    tests are skipped instead of failing.
    """
    weights_dir = test_weights_path.joinpath("stabilityai", "stable-diffusion-xl-base-1.0")
    if weights_dir.is_dir():
        return weights_dir
    warn(f"could not find Stability SDXL base weights at {weights_dir}, skipping")
    # pytest.skip raises, so control never falls off the end of this function.
    pytest.skip(allow_module_level=True)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
    """Load the reference diffusers SDXL pipeline from the local weights directory.

    ``diffusers`` is imported inside the fixture rather than at module level, so
    the import only happens when a test actually requests this fixture.
    """
    from diffusers import DiffusionPipeline  # type: ignore

    pipeline = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path)  # type: ignore
    return pipeline  # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
    """Expose the UNet sub-module of the diffusers SDXL pipeline.

    This is the source model for the weight-conversion test.
    """
    unet = diffusers_sdxl.unet
    return unet
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def refiners_sdxl_unet() -> SDXLUNet:
    """Build a refiners SDXL UNet with 4 input channels.

    This is the target model for the weight-conversion test. No weights are
    loaded here; only the constructor is called.
    """
    return SDXLUNet(in_channels=4)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
    """Convert diffusers' SDXL UNet into refiners' SDXLUNet and check the outputs agree.

    Both models are fed the same seeded random latents, timestep, CLIP text
    embeddings and SDXL extra conditioning (pooled text embeds, time ids).
    ``ModelConverter`` (threshold ``1e-2``) must succeed and end in the
    ``MODELS_OUTPUT_AGREE`` stage.
    """
    source = diffusers_sdxl_unet
    target = refiners_sdxl_unet

    manual_seed(seed=0)
    x = torch.randn(1, 4, 32, 32)
    timestep = torch.tensor(data=[0])
    clip_text_embeddings = torch.randn(1, 77, 2048)
    added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}

    # diffusers receives all conditioning as call arguments. The refiners target
    # only receives the latents; its conditioning is injected through setters below.
    target_args = (x,)
    source_args = {
        "positional": (x, timestep, clip_text_embeddings),
        "keyword": {"added_cond_kwargs": added_cond_kwargs},
    }

    old_forward = target.forward  # bound method of `target`

    def forward_with_context(*args: Any, **kwargs: Any) -> Any:
        # Stored as an instance attribute, this function is NOT bound, so no
        # `self` is passed in: `args` are the raw model inputs and are forwarded
        # unchanged to the already-bound original forward.
        # Conditioning is re-applied on every call so it is in place whenever the
        # converter invokes the target (presumably it may run it more than once).
        target.set_timestep(timestep=timestep)
        target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
        target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
        target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
        return old_forward(*args, **kwargs)

    target.forward = forward_with_context
    try:
        converter = ModelConverter(source_model=source, target_model=target, verbose=True, threshold=1e-2)
        assert converter.run(
            source_args=source_args,
            target_args=target_args,
        )
        assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
    finally:
        # `refiners_sdxl_unet` is module-scoped: undo the patch so it cannot leak
        # into other tests sharing the same instance.
        target.forward = old_forward
|