From a5d3c2971b84f6faa4762b1cf5a07f4f812bb1f5 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 21 Aug 2024 10:20:11 +0000
Subject: [PATCH] update ic_light adapter, bugfix + improve docstrings

---
 .../stable_diffusion_1/__init__.py |   2 +
 .../stable_diffusion_1/ic_light.py | 182 +++++++++---------
 2 files changed, 98 insertions(+), 86 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
index f8cd4c7..2914ff3 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
@@ -1,4 +1,5 @@
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     SD1Autoencoder,
@@ -16,4 +17,5 @@ __all__ = [
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
     "SD1T2IAdapter",
+    "ICLight",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
index 0f31dbe..8a9a0bb 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py
@@ -3,7 +3,7 @@ from PIL import Image
 from torch.nn.init import zeros_ as zero_init
 
 from refiners.fluxion import layers as fl
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
@@ -11,79 +11,82 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import Down
 class ICLight(StableDiffusion_1):
-    """
-    IC-Light is a Stable Diffusion model that can be used to relight a reference image.
+    """IC-Light is a Stable Diffusion model that can be used to relight a reference image.
 
-    At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
+    At initialization, the UNet will be patched to accept four additional input channels.
+    Only the text-conditioned relighting model is supported for now.
 
-    ```example
-    import torch
-    from huggingface_hub import hf_hub_download
-    from PIL import Image
-    from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
-    from refiners.foundationals.clip import CLIPTextEncoderL
-    from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
-    from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    dtype = torch.float32
-    no_grad().__enter__()
-    manual_seed(42)
-
-    sd = ICLight(
-        patch_weights=load_from_safetensors(
-            path=hf_hub_download(
-                repo_id="refiners/ic_light.sd1_5.fc",
-                filename="model.safetensors",
-            ),
-            device=device,
-        ),
-        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
-                filename="model.safetensors",
-            )
-        ),
-        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
-                filename="model.safetensors",
-            )
-        ),
-        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
-                filename="model.safetensors",
-            )
-        ),
-        device=device,
-        dtype=dtype,
-    )
-
-    prompt = "soft lighting, high-quality professional image"
-    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-
-    image = Image.open("reference-image.png").resize((512, 512))
-    sd.set_ic_light_condition(image)
-
-    x = torch.randn(
-        size=(1, 4, 64, 64),
-        device=device,
-        dtype=dtype,
-    )
-
-    for step in sd.steps:
-        x = sd(
-            x=x,
-            step=step,
-            clip_text_embedding=clip_text_embedding,
-            condition_scale=1.5,
-        )
-        predicted_image = sd.lda.latents_to_image(x)
-
-    predicted_image.save("ic-light-output.png")
-    ```
+    Example:
+        ```py
+        import torch
+        from huggingface_hub import hf_hub_download
+        from PIL import Image
+
+        from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
+        from refiners.foundationals.clip import CLIPTextEncoderL
+        from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
+        from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        dtype = torch.float32
+        no_grad().__enter__()
+        manual_seed(42)
+
+        sd = ICLight(
+            patch_weights=load_from_safetensors(
+                path=hf_hub_download(
+                    repo_id="refiners/ic_light.sd1_5.fc",
+                    filename="model.safetensors",
+                ),
+                device=device,
+            ),
+            unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
+                    filename="model.safetensors",
+                )
+            ),
+            clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
+                    filename="model.safetensors",
+                )
+            ),
+            lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
+                    filename="model.safetensors",
+                )
+            ),
+            device=device,
+            dtype=dtype,
+        )
+
+        prompt = "soft lighting, high-quality professional image"
+        negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+        clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+        image = Image.open("reference-image.png").resize((512, 512))
+        sd.set_ic_light_condition(image)
+
+        x = torch.randn(
+            size=(1, 4, 64, 64),
+            device=device,
+            dtype=dtype,
+        )
+
+        for step in sd.steps:
+            x = sd(
+                x=x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=1.5,
+            )
+        predicted_image = sd.lda.latents_to_image(x)
+
+        predicted_image.save("ic-light-output.png")
+        ```
     """
 
     def __init__(
@@ -109,9 +112,7 @@ class ICLight(StableDiffusion_1):
 
     @no_grad()
     def _extend_conv_in(self) -> None:
-        """
-        Extend to 8 the input channels of the first convolutional layer of the UNet.
-        """
+        """Extend the input channels of the UNet's first convolutional layer to 8."""
         down_blocks = self.unet.ensure_find(DownBlocks)
         first_block = down_blocks.layer(0, fl.Chain)
         conv_in = first_block.ensure_find(fl.Conv2d)
@@ -129,9 +130,7 @@ class ICLight(StableDiffusion_1):
         first_block.replace(old_module=conv_in, new_module=new_conv_in)
 
     def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
-        """
-        Apply the patch weights to the UNet, modifying inplace the state dict.
-        """
+        """Apply the weights patch to the UNet, modifying its state dict in place."""
         current_state_dict = self.unet.state_dict()
         new_state_dict = {
             key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
         }
         self.unet.load_state_dict(new_state_dict)
@@ -139,9 +138,17 @@ class ICLight(StableDiffusion_1):
     @staticmethod
-    def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
-        """
-        Compute a grayscale composite of an image and a mask.
+    def compute_gray_composite(
+        image: Image.Image,
+        mask: Image.Image,
+    ) -> Image.Image:
+        """Compute a grayscale composite of an image and a mask.
+
+        IC-Light will recreate the image based on this composite.
+
+        Args:
+            image: The image to composite.
+            mask: The mask to use for the composite.
         """
         assert mask.mode == "L", "Mask must be a grayscale image"
         assert image.size == mask.size, "Image and mask must have the same size"
@@ -149,28 +156,31 @@ class ICLight(StableDiffusion_1):
         return Image.composite(image, background, mask)
 
     def set_ic_light_condition(
-        self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
+        self,
+        image: Image.Image,
+        mask: Image.Image | None = None,
     ) -> None:
-        """
-        Set the IC light condition.
+        """Set the IC light condition.
+
+        Args:
+            image: The reference image.
+            mask: The mask to use for the reference image.
 
         If a mask is provided, it will be used to compute a grayscale composite of the image
         and the mask ; otherwise, the image will be used as is, but note that IC-Light requires
         a 127-valued gray background to work.
-
-        `use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
-        Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
-        see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
         """
         if mask is not None:
             image = self.compute_gray_composite(image=image, mask=mask)
-        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
-        if use_rescaled_image:
-            image_tensor = 2 * image_tensor - 1
-        latents = self.lda.encode(image_tensor)
+        latents = self.lda.image_to_latents(image)
         self._ic_light_condition = latents
 
     def __call__(
-        self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
+        self,
+        x: torch.Tensor,
+        step: int,
+        *,
+        clip_text_embedding: torch.Tensor,
+        condition_scale: float = 2.0,
     ) -> torch.Tensor:
         assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
         x = torch.cat((x, self._ic_light_condition), dim=1)
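
For reference, a minimal sketch of the mask-based flow this patch exposes, building on the `sd` instance from the class docstring example above. The file names `subject.png` and `subject_mask.png` are hypothetical placeholders, not part of the patch:

```py
from PIL import Image

# Hypothetical input files: the mask must be grayscale (mode "L") and match
# the image size, per the asserts in compute_gray_composite.
image = Image.open("subject.png").convert("RGB").resize((512, 512))
mask = Image.open("subject_mask.png").convert("L").resize((512, 512))

# With a mask, set_ic_light_condition composites the image over the
# 127-valued gray background IC-Light expects, then encodes the result
# to latents via the autoencoder (lda.image_to_latents).
sd.set_ic_light_condition(image, mask=mask)
```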