mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix broken self-attention guidance with ip-adapter
The #168 and #177 refactorings caused this regression. A new end-to-end test has been added for proper coverage. (This fix will be revisited at some point)
This commit is contained in:
parent
d9ae7ca6a5
commit
2b977bc69e
|
@ -209,6 +209,16 @@ def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
|
|||
download_sd_tokenizer(hf_repo_id, "tokenizer_2")
|
||||
|
||||
|
||||
def download_vae_fp16_fix():
|
||||
download_files(
|
||||
urls=[
|
||||
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
|
||||
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
|
||||
)
|
||||
|
||||
|
||||
def download_vae_ft_mse():
|
||||
download_files(
|
||||
urls=[
|
||||
|
@ -433,6 +443,17 @@ def convert_vae_ft_mse():
|
|||
)
|
||||
|
||||
|
||||
def convert_vae_fp16_fix():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/madebyollin/sdxl-vae-fp16-fix",
|
||||
"tests/weights/sdxl-lda-fp16-fix.safetensors",
|
||||
additional_args=["--subfolder", "''"],
|
||||
half=True,
|
||||
expected_hash="98c7e998",
|
||||
)
|
||||
|
||||
|
||||
def convert_lora():
|
||||
os.makedirs("tests/weights/loras", exist_ok=True)
|
||||
run_conversion_script(
|
||||
|
@ -610,6 +631,7 @@ def download_all():
|
|||
download_sd15("runwayml/stable-diffusion-inpainting")
|
||||
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
download_vae_ft_mse()
|
||||
download_vae_fp16_fix()
|
||||
download_lora()
|
||||
download_preprocessors()
|
||||
download_controlnet()
|
||||
|
@ -624,6 +646,7 @@ def convert_all():
|
|||
convert_sd15()
|
||||
convert_sdxl()
|
||||
convert_vae_ft_mse()
|
||||
convert_vae_fp16_fix()
|
||||
convert_lora()
|
||||
convert_preprocessors()
|
||||
convert_controlnet()
|
||||
|
|
|
@ -89,10 +89,18 @@ class StableDiffusion_1(LatentDiffusionModel):
|
|||
classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
if "ip_adapter" in self.unet.provider.contexts:
|
||||
# this implementation is a bit hacky, it should be refactored in the future
|
||||
ip_adapter_context = self.unet.use_context("ip_adapter")
|
||||
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
||||
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
||||
else:
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
|
||||
return sag.scale * (noise - degraded_noise)
|
||||
|
||||
|
@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|||
step=step,
|
||||
classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
x = torch.cat(
|
||||
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
|
||||
dim=1,
|
||||
)
|
||||
degraded_noise = self.unet(x)
|
||||
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
|
||||
if "ip_adapter" in self.unet.provider.contexts:
|
||||
# this implementation is a bit hacky, it should be refactored in the future
|
||||
ip_adapter_context = self.unet.use_context("ip_adapter")
|
||||
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
||||
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
|
||||
degraded_noise = self.unet(x)
|
||||
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
||||
else:
|
||||
degraded_noise = self.unet(x)
|
||||
|
||||
return sag.scale * (noise - degraded_noise)
|
||||
|
|
|
@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
negative_text_embedding, _ = clip_text_embedding.chunk(2)
|
||||
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
time_ids, _ = time_ids.chunk(2)
|
||||
|
||||
self.set_unet_context(
|
||||
timestep=timestep,
|
||||
clip_text_embedding=negative_embedding,
|
||||
clip_text_embedding=negative_text_embedding,
|
||||
pooled_text_embedding=negative_pooled_embedding,
|
||||
time_ids=time_ids,
|
||||
**kwargs,
|
||||
)
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
if "ip_adapter" in self.unet.provider.contexts:
|
||||
# this implementation is a bit hacky, it should be refactored in the future
|
||||
ip_adapter_context = self.unet.use_context("ip_adapter")
|
||||
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
|
||||
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
|
||||
else:
|
||||
degraded_noise = self.unet(degraded_latents)
|
||||
|
||||
return sag.scale * (noise - degraded_noise)
|
||||
|
|
|
@ -242,6 +242,20 @@ def expected_freeu(ref_path: Path) -> Image.Image:
|
|||
return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
|
||||
assets = Path(__file__).parent.parent.parent / "assets"
|
||||
dropy = assets / "dropy_logo.png"
|
||||
image_prompt = assets / "dragon_quest_slime.jpg"
|
||||
condition_image = assets / "dropy_canny.png"
|
||||
return (
|
||||
Image.open(fp=dropy).convert(mode="RGB"),
|
||||
Image.open(fp=image_prompt).convert(mode="RGB"),
|
||||
Image.open(fp=condition_image).convert(mode="RGB"),
|
||||
Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
|
@ -488,6 +502,15 @@ def sdxl_lda_weights(test_weights_path: Path) -> Path:
|
|||
return sdxl_lda_weights
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
|
||||
sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
||||
if not sdxl_lda_weights.is_file():
|
||||
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return sdxl_lda_weights
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sdxl_unet_weights(test_weights_path: Path) -> Path:
|
||||
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
|
||||
|
@ -524,6 +547,24 @@ def sdxl_ddim(
|
|||
return sdxl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sdxl_ddim_lda_fp16_fix(
|
||||
sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
|
||||
) -> StableDiffusion_XL:
|
||||
if test_device.type == "cpu":
|
||||
warn(message="not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
scheduler = DDIM(num_inference_steps=30)
|
||||
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
|
||||
|
||||
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
|
||||
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
|
||||
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
|
||||
|
||||
return sdxl
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init(
|
||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
|
@ -1702,3 +1743,62 @@ def test_freeu(
|
|||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_freeu)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_hello_world(
|
||||
sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
|
||||
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||
sdxl_ip_adapter_weights: Path,
|
||||
image_encoder_weights: Path,
|
||||
hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
|
||||
) -> None:
|
||||
sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
|
||||
sdxl.dtype = torch.float16 # FIXME: should not be necessary
|
||||
|
||||
name, _, _, weights_path = t2i_adapter_xl_data_canny
|
||||
init_image, image_prompt, condition_image, expected_image = hello_world_assets
|
||||
|
||||
if not weights_path.is_file():
|
||||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
|
||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
|
||||
ip_adapter.set_clip_image_embedding(image_embedding)
|
||||
|
||||
# Note: default text prompts for IP-Adapter
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
|
||||
|
||||
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
first_step = 1
|
||||
ip_adapter.set_scale(0.85)
|
||||
t2i_adapter.set_scale(0.8)
|
||||
sdxl.set_num_inference_steps(50)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manual_seed(9752)
|
||||
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
|
||||
device=sdxl.device, dtype=sdxl.dtype
|
||||
)
|
||||
for step in sdxl.steps[first_step:]:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
|
|
@ -47,6 +47,7 @@ Special cases:
|
|||
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
||||
- `expected_restart.png`
|
||||
- `expected_freeu.png`
|
||||
- `expected_dropy_slime_9752.png`
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_dropy_slime_9752.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_dropy_slime_9752.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
Loading…
Reference in a new issue