mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
write StyleAligned
e2e test
This commit is contained in:
parent
60c0780fe7
commit
da3c3602fb
|
@ -30,6 +30,7 @@ from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSc
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
|
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,6 +151,11 @@ def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
    """Load the StyleAligned reference image from ``ref_path``, converted to RGB."""
    reference_file = ref_path / "expected_style_aligned.png"
    reference_image = Image.open(fp=reference_file)
    return reference_image.convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
def controlnet_data(
|
def controlnet_data(
|
||||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||||
|
@ -2140,3 +2146,79 @@ def test_hello_world(
|
||||||
predicted_image = sdxl.lda.latents_to_image(x)
|
predicted_image = sdxl.lda.latents_to_image(x)
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image)
|
ensure_similar_images(predicted_image, expected_image)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_style_aligned(
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
    expected_style_aligned: Image.Image,
):
    """End-to-end test of the StyleAligned adapter on SDXL.

    Injects a ``StyleAlignedAdapter`` into the UNet, denoises one latent per
    prompt in a single batch, tiles the decoded images horizontally and
    compares the result against the ``expected_style_aligned`` reference.
    """
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary

    style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
    style_aligned_adapter.inject()

    set_of_prompts = [
        "a toy train. macro photo. 3d game asset",
        "a toy airplane. macro photo. 3d game asset",
        "a toy bicycle. macro photo. 3d game asset",
        "a toy car. macro photo. 3d game asset",
        "a toy boat. macro photo. 3d game asset",
    ]
    batch_size = len(set_of_prompts)

    # Spatial size of the initial latents and of each tile in the merged image.
    # NOTE(review): tile_size presumably equals the decoded size of a
    # latent_size x latent_size latent -- confirm against the autoencoder.
    latent_size = 128
    tile_size = 1024

    # create (context) embeddings from prompts
    # TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
    unconds: list[torch.Tensor] = []
    conds: list[torch.Tensor] = []
    pooled_unconds: list[torch.Tensor] = []
    pooled_conds: list[torch.Tensor] = []
    for prompt in set_of_prompts:
        clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)

        # each embedding is split in two halves, collected separately per prompt
        uncond, cond = clip_text_embedding.chunk(2)
        pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)

        unconds.append(uncond)
        conds.append(cond)
        pooled_unconds.append(pooled_uncond)
        pooled_conds.append(pooled_cond)

    # re-assemble batched embeddings: every prompt's first half, then every
    # prompt's second half
    clip_text_embedding = torch.cat(
        (torch.cat(unconds, dim=0), torch.cat(conds, dim=0)),
        dim=0,
    )
    pooled_text_embedding = torch.cat(
        (torch.cat(pooled_unconds, dim=0), torch.cat(pooled_conds, dim=0)),
        dim=0,
    )

    time_ids = sdxl.default_time_ids.repeat(batch_size, 1)

    # initialize latents
    manual_seed(seed=2)
    x = torch.randn(
        (batch_size, 4, latent_size, latent_size),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )

    # decode latents, one image per batch element
    # NOTE(review): another test in this file decodes with `sdxl.lda.latents_to_image`;
    # consider aligning once the two are confirmed to be equivalent.
    predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x]

    # tile all images horizontally
    merged_image = Image.new("RGB", (tile_size * len(predicted_images), tile_size))
    for index, predicted_image in enumerate(predicted_images):
        merged_image.paste(predicted_image, (index * tile_size, 0))

    # compare against reference image
    ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
|
@ -56,6 +56,7 @@ Special cases:
|
||||||
- `expected_controllora_PyraCanny.png`
|
- `expected_controllora_PyraCanny.png`
|
||||||
- `expected_controllora_PyraCanny+CPDS.png`
|
- `expected_controllora_PyraCanny+CPDS.png`
|
||||||
- `expected_controllora_disabled.png`
|
- `expected_controllora_disabled.png`
|
||||||
|
- `expected_style_aligned.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_style_aligned.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_style_aligned.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.4 MiB |
Loading…
Reference in a new issue