fix broken self-attention guidance with ip-adapter

The #168 and #177 refactorings caused this regression. A new end-to-end
test has been added for proper coverage.

(This fix will be revisited at some point)
limiteinductive 2024-01-16 16:13:40 +01:00 committed by Cédric Deltheil
parent d9ae7ca6a5
commit 2b977bc69e
6 changed files with 160 additions and 11 deletions
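
For context: self-attention guidance (SAG) runs an extra UNet pass on degraded latents, conditioned only on the negative (unconditional) embeddings. With an IP-Adapter injected, the UNet context still held the full batched CLIP image embedding during that pass, which is what broke. The diffs below clone the context value, swap in its negative half, run the UNet, then restore the original. A minimal standalone sketch of that pattern, assuming a plain dict as the context store (the function and its arguments are illustrative, not the refiners API):

    import torch

    def sag_unet_pass(unet, degraded_latents: torch.Tensor, contexts: dict) -> torch.Tensor:
        if "clip_image_embedding" in contexts:
            saved = contexts["clip_image_embedding"].clone()  # full (negative, positive) pair
            contexts["clip_image_embedding"], _ = saved.chunk(2)  # keep only the negative half
            try:
                return unet(degraded_latents)
            finally:
                contexts["clip_image_embedding"] = saved  # restore for the regular passes
        return unet(degraded_latents)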

View file

@@ -209,6 +209,16 @@ def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
     download_sd_tokenizer(hf_repo_id, "tokenizer_2")
 
 
+def download_vae_fp16_fix():
+    download_files(
+        urls=[
+            "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/raw/main/config.json",
+            "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/diffusion_pytorch_model.safetensors",
+        ],
+        dest_folder=os.path.join(test_weights_dir, "madebyollin", "sdxl-vae-fp16-fix"),
+    )
+
+
 def download_vae_ft_mse():
     download_files(
         urls=[
@@ -433,6 +443,17 @@ def convert_vae_ft_mse():
     )
 
 
+def convert_vae_fp16_fix():
+    run_conversion_script(
+        "convert_diffusers_autoencoder_kl.py",
+        "tests/weights/madebyollin/sdxl-vae-fp16-fix",
+        "tests/weights/sdxl-lda-fp16-fix.safetensors",
+        additional_args=["--subfolder", "''"],
+        half=True,
+        expected_hash="98c7e998",
+    )
+
+
 def convert_lora():
     os.makedirs("tests/weights/loras", exist_ok=True)
     run_conversion_script(
@@ -610,6 +631,7 @@ def download_all():
     download_sd15("runwayml/stable-diffusion-inpainting")
     download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
     download_vae_ft_mse()
+    download_vae_fp16_fix()
     download_lora()
     download_preprocessors()
     download_controlnet()
@@ -624,6 +646,7 @@ def convert_all():
     convert_sd15()
     convert_sdxl()
     convert_vae_ft_mse()
+    convert_vae_fp16_fix()
     convert_lora()
     convert_preprocessors()
     convert_controlnet()
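
`download_files` and `run_conversion_script` are pre-existing helpers in this script and are not shown in the diff. A hedged sketch of what a `download_files`-compatible helper could look like, inferred only from the call sites above (the real implementation may differ):

    import os
    import urllib.request

    def download_files(urls: list[str], dest_folder: str) -> None:
        # Fetch each URL into dest_folder, keeping the file's basename.
        os.makedirs(dest_folder, exist_ok=True)
        for url in urls:
            dest = os.path.join(dest_folder, url.rsplit("/", maxsplit=1)[-1])
            if not os.path.exists(dest):  # skip weights that were already fetched
                urllib.request.urlretrieve(url, dest)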

View file

@@ -89,9 +89,17 @@ class StableDiffusion_1(LatentDiffusionModel):
             classifier_free_guidance=True,
         )
 
-        negative_embedding, _ = clip_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
-        degraded_noise = self.unet(degraded_latents)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)
 
         return sag.scale * (noise - degraded_noise)
@@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
             step=step,
             classifier_free_guidance=True,
         )
 
-        negative_embedding, _ = clip_text_embedding.chunk(2)
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
-        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
         x = torch.cat(
             tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
             dim=1,
         )
-        degraded_noise = self.unet(x)
+
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)
 
         return sag.scale * (noise - degraded_noise)
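
The inline comment flags the repeated swap/restore blocks as a future refactoring target. One possible direction (an illustration, not part of this commit) is to scope the substitution in a context manager so the original embedding is restored even if the UNet call raises:

    from contextlib import contextmanager

    @contextmanager
    def negative_image_embedding(ip_adapter_context: dict):
        # Temporarily replace the batched embedding with its negative half.
        full_embedding = ip_adapter_context["clip_image_embedding"].clone()
        ip_adapter_context["clip_image_embedding"], _ = full_embedding.chunk(2)
        try:
            yield
        finally:
            ip_adapter_context["clip_image_embedding"] = full_embedding

The branches above would then reduce to a single `with negative_image_embedding(ip_adapter_context): degraded_noise = self.unet(x)`.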

View file

@@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
             classifier_free_guidance=True,
         )
 
-        negative_embedding, _ = clip_text_embedding.chunk(2)
+        negative_text_embedding, _ = clip_text_embedding.chunk(2)
         negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
         time_ids, _ = time_ids.chunk(2)
         self.set_unet_context(
             timestep=timestep,
-            clip_text_embedding=negative_embedding,
+            clip_text_embedding=negative_text_embedding,
             pooled_text_embedding=negative_pooled_embedding,
             time_ids=time_ids,
+            **kwargs,
         )
-
-        degraded_noise = self.unet(degraded_latents)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)
 
         return sag.scale * (noise - degraded_noise)

View file

@@ -242,6 +242,20 @@ def expected_freeu(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
 
 
+@pytest.fixture
+def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
+    assets = Path(__file__).parent.parent.parent / "assets"
+    dropy = assets / "dropy_logo.png"
+    image_prompt = assets / "dragon_quest_slime.jpg"
+    condition_image = assets / "dropy_canny.png"
+    return (
+        Image.open(fp=dropy).convert(mode="RGB"),
+        Image.open(fp=image_prompt).convert(mode="RGB"),
+        Image.open(fp=condition_image).convert(mode="RGB"),
+        Image.open(fp=ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
+    )
+
+
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
     return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
@@ -488,6 +502,15 @@ def sdxl_lda_weights(test_weights_path: Path) -> Path:
     return sdxl_lda_weights
 
 
+@pytest.fixture
+def sdxl_lda_fp16_fix_weights(test_weights_path: Path) -> Path:
+    sdxl_lda_weights = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
+    if not sdxl_lda_weights.is_file():
+        warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sdxl_lda_weights
+
+
 @pytest.fixture
 def sdxl_unet_weights(test_weights_path: Path) -> Path:
     sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
@@ -524,6 +547,24 @@ def sdxl_ddim(
     return sdxl
 
 
+@pytest.fixture
+def sdxl_ddim_lda_fp16_fix(
+    sdxl_text_encoder_weights: Path, sdxl_lda_fp16_fix_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
+) -> StableDiffusion_XL:
+    if test_device.type == "cpu":
+        warn(message="not running on CPU, skipping")
+        pytest.skip()
+
+    scheduler = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_fp16_fix_weights)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+
+    return sdxl
+
+
 @no_grad()
 def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@@ -1702,3 +1743,62 @@ def test_freeu(
     predicted_image = sd15.lda.decode_latents(x)
 
     ensure_similar_images(predicted_image, expected_freeu)
+
+
+@no_grad()
+def test_hello_world(
+    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
+    t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
+    sdxl_ip_adapter_weights: Path,
+    image_encoder_weights: Path,
+    hello_world_assets: tuple[Image.Image, Image.Image, Image.Image, Image.Image],
+) -> None:
+    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
+    sdxl.dtype = torch.float16  # FIXME: should not be necessary
+
+    name, _, _, weights_path = t2i_adapter_xl_data_canny
+    init_image, image_prompt, condition_image, expected_image = hello_world_assets
+
+    if not weights_path.is_file():
+        warn(f"could not find weights at {weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    ip_adapter = SDXLIPAdapter(target=sdxl.unet, weights=load_from_safetensors(sdxl_ip_adapter_weights))
+    ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
+    ip_adapter.inject()
+
+    image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
+    ip_adapter.set_clip_image_embedding(image_embedding)
+
+    # Note: default text prompts for IP-Adapter
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
+    )
+    time_ids = sdxl.default_time_ids
+
+    t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
+    condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
+    t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
+
+    first_step = 1
+    ip_adapter.set_scale(0.85)
+    t2i_adapter.set_scale(0.8)
+    sdxl.set_num_inference_steps(50)
+    sdxl.set_self_attention_guidance(enable=True, scale=0.75)
+
+    manual_seed(9752)
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
+        device=sdxl.device, dtype=sdxl.dtype
+    )
+
+    for step in sdxl.steps[first_step:]:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+        )
+    predicted_image = sdxl.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image)
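
Assuming a standard pytest setup (the diff does not show file paths), the new end-to-end test can be selected by name, for example:

    import pytest

    # Run only the new end-to-end test; -k matches on the test name.
    pytest.main(["-k", "test_hello_world"])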

View file

@@ -47,6 +47,7 @@ Special cases:
 - `expected_cutecat_sdxl_ddim_random_init_sag.png`
 - `expected_restart.png`
 - `expected_freeu.png`
+- `expected_dropy_slime_9752.png`
 
 ## Other images

Binary file not shown.

After  |  Size: 1.1 MiB