diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index 610ec88..694f10f 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -19,6 +19,7 @@ class Args(argparse.Namespace): half: bool verbose: bool skip_init_check: bool + override_weights: str | None def setup_converter(args: Args) -> ModelConverter: diff --git a/scripts/conversion/convert_ic_light.py b/scripts/conversion/convert_ic_light.py new file mode 100644 index 0000000..ddc367e --- /dev/null +++ b/scripts/conversion/convert_ic_light.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter +from huggingface_hub import hf_hub_download # type: ignore + +from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors + + +class Args(argparse.Namespace): + source_path: str + output_path: str | None + subfolder: str + half: bool + verbose: bool + reference_unet_path: str + + +def main() -> None: + parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners") + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="lllyasviel/ic-light", + help=( + "Can be a path to a .bin file, a .safetensors file or a model name from the Hugging Face Hub. Default:" + " lllyasviel/ic-light" + ), + ) + parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.") + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Output path (.safetensors) for converted model. If not provided, the output path will be the same as the" + " source path." + ), + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Prints additional information during conversion. 
Default: False", + ) + parser.add_argument( + "--reference-unet-path", + type=str, + dest="reference_unet_path", + default="runwayml/stable-diffusion-v1-5", + help="Path to the reference UNet weights.", + ) + args = parser.parse_args(namespace=Args()) + if args.output_path is None: + args.output_path = f"{Path(args.filename).stem}-refiners.safetensors" + + patch_file = ( + Path(args.source_path) + if args.source_path.endswith(".safetensors") + else Path( + hf_hub_download( + repo_id=args.source_path, + filename=args.filename, + ) + ) + ) + patch_weights = load_from_safetensors(patch_file) + + unet_args = UNetArgs( + source_path=args.reference_unet_path, + subfolder="unet", + half=False, + verbose=False, + skip_init_check=True, + override_weights=None, + ) + converter = setup_unet_converter(args=unet_args) + result = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] + source_state_dict=patch_weights, + target_state_dict=converter.target_model.state_dict(), + state_dict_mapping=converter.get_mapping(), + ) + save_to_safetensors(path=args.output_path, tensors=result) + + +if __name__ == "__main__": + main() diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index f30f87a..c20447c 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -438,6 +438,14 @@ def download_sdxl_lightning_lora(): ) +def download_ic_light(): + download_file( + "https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors", + dest_folder=test_weights_dir, + expected_hash="bce70123", + ) + + def printg(msg: str): """print in green color""" print("\033[92m" + msg + "\033[0m") @@ -790,6 +798,16 @@ def convert_sdxl_lightning_base(): ) +def convert_ic_light(): + run_conversion_script( + "convert_ic_light.py", + "tests/weights/iclight_sd15_fc.safetensors", + "tests/weights/iclight_sd15_fc-refiners.safetensors", + half=False, + expected_hash="be315c1f", + ) + + def download_all(): print(f"\nAll weights will be downloaded to {test_weights_dir}\n") download_sd15("runwayml/stable-diffusion-v1-5") @@ -811,6 +829,7 @@ def download_all(): download_lcm_lora() download_sdxl_lightning_base() download_sdxl_lightning_lora() + download_ic_light() def convert_all(): @@ -830,6 +849,7 @@ def convert_all(): convert_control_lora_fooocus() convert_lcm_base() convert_sdxl_lightning_base() + convert_ic_light() def main(): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py new file mode 100644 index 0000000..0f31dbe --- /dev/null +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py @@ -0,0 +1,182 @@ +import torch +from PIL import Image +from torch.nn.init import zeros_ as zero_init + +from refiners.fluxion import layers as fl +from refiners.fluxion.utils import image_to_tensor, no_grad +from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL +from refiners.foundationals.latent_diffusion.solvers.solver import Solver +from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1 +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks, SD1UNet + + +class ICLight(StableDiffusion_1): + """ + IC-Light is a Stable Diffusion model that can be used to relight a reference image. + + At initialization, the UNet will be patched to accept four additional input channels. 
Only the text-conditioned relighting model is supported for now. + + ```example + import torch + from huggingface_hub import hf_hub_download + from PIL import Image + + from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad + from refiners.foundationals.clip import CLIPTextEncoderL + from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet + from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + no_grad().__enter__() + manual_seed(42) + + sd = ICLight( + patch_weights=load_from_safetensors( + path=hf_hub_download( + repo_id="refiners/ic_light.sd1_5.fc", + filename="model.safetensors", + ), + device=device, + ), + unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors( + tensors_path=hf_hub_download( + repo_id="refiners/realistic_vision.v5_1.sd1_5.unet", + filename="model.safetensors", + ) + ), + clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors( + tensors_path=hf_hub_download( + repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder", + filename="model.safetensors", + ) + ), + lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors( + tensors_path=hf_hub_download( + repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder", + filename="model.safetensors", + ) + ), + device=device, + dtype=dtype, + ) + + prompt = "soft lighting, high-quality professional image" + negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" + clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + + image = Image.open("reference-image.png").resize((512, 512)) + sd.set_ic_light_condition(image) + + x = torch.randn( + size=(1, 4, 64, 64), + device=device, + dtype=dtype, + ) + + for step in sd.steps: + x = sd( + x=x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=1.5, + ) + predicted_image = sd.lda.latents_to_image(x) + + predicted_image.save("ic-light-output.png") + """ + + def __init__( + self, + patch_weights: dict[str, torch.Tensor], + unet: SD1UNet, + lda: SD1Autoencoder | None = None, + clip_text_encoder: CLIPTextEncoderL | None = None, + solver: Solver | None = None, + device: torch.device | str = "cpu", + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__( + unet=unet, + lda=lda, + clip_text_encoder=clip_text_encoder, + solver=solver, + device=device, + dtype=dtype, + ) + self._extend_conv_in() + self._apply_patch(weights=patch_weights) + + @no_grad() + def _extend_conv_in(self) -> None: + """ + Extend to 8 the input channels of the first convolutional layer of the UNet. + """ + down_blocks = self.unet.ensure_find(DownBlocks) + first_block = down_blocks.layer(0, fl.Chain) + conv_in = first_block.ensure_find(fl.Conv2d) + new_conv_in = fl.Conv2d( + in_channels=conv_in.in_channels + 4, + out_channels=conv_in.out_channels, + kernel_size=(conv_in.kernel_size[0], conv_in.kernel_size[1]), + padding=(int(conv_in.padding[0]), int(conv_in.padding[1])), + device=conv_in.device, + dtype=conv_in.dtype, + ) + zero_init(new_conv_in.weight) + new_conv_in.bias = conv_in.bias + new_conv_in.weight[:, :4, :, :] = conv_in.weight + first_block.replace(old_module=conv_in, new_module=new_conv_in) + + def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None: + """ + Apply the patch weights to the UNet, modifying inplace the state dict. 
+ """ + current_state_dict = self.unet.state_dict() + new_state_dict = { + key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items() + } + self.unet.load_state_dict(new_state_dict) + + @staticmethod + def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image: + """ + Compute a grayscale composite of an image and a mask. + """ + assert mask.mode == "L", "Mask must be a grayscale image" + assert image.size == mask.size, "Image and mask must have the same size" + background = Image.new("RGB", image.size, (127, 127, 127)) + return Image.composite(image, background, mask) + + def set_ic_light_condition( + self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False + ) -> None: + """ + Set the IC light condition. + + If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise, + the image will be used as is, but note that IC-Light requires a 127-valued gray background to work. + + `use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the + Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results. + see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265 + """ + if mask is not None: + image = self.compute_gray_composite(image=image, mask=mask) + image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype) + if use_rescaled_image: + image_tensor = 2 * image_tensor - 1 + latents = self.lda.encode(image_tensor) + self._ic_light_condition = latents + + def __call__( + self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0 + ) -> torch.Tensor: + assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first" + x = torch.cat((x, self._ic_light_condition), dim=1) + return super().__call__( + x, + step, + clip_text_embedding=clip_text_embedding, + condition_scale=condition_scale, + ) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index d4e4865..a8f5396 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -12,6 +12,7 @@ from tests.utils import ensure_similar_images from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad from refiners.foundationals.clip.concepts import ConceptExtender +from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.latent_diffusion import ( ControlLoraAdapter, SD1ControlnetAdapter, @@ -30,6 +31,8 @@ from refiners.foundationals.latent_diffusion.reference_only_control import Refer from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver +from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight +from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import ( SD1DiffusionTarget, SD1MultiDiffusion, @@ -2564,3 +2567,58 @@ def test_multi_upscaler( ) -> None: predicted_image = multi_upscaler.upscale(clarity_example) 
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99) + + +@pytest.fixture(scope="module") +def expected_ic_light(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_ic_light.png").convert("RGB") + + +@pytest.fixture(scope="module") +def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path: + return test_weights_path / "iclight_sd15_fc-refiners.safetensors" + + +@pytest.fixture(scope="module") +def ic_light_sd15_fc( + ic_light_sd15_fc_weights: Path, + unet_weights_std: Path, + lda_weights: Path, + text_encoder_weights: Path, + test_device: torch.device, +) -> ICLight: + return ICLight( + patch_weights=load_from_safetensors(ic_light_sd15_fc_weights), + unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std), + lda=SD1Autoencoder().load_from_safetensors(lda_weights), + clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights), + device=test_device, + ) + + +@no_grad() +def test_ic_light( + kitchen_dog: Image.Image, + kitchen_dog_mask: Image.Image, + ic_light_sd15_fc: ICLight, + expected_ic_light: Image.Image, + test_device: torch.device, +) -> None: + sd = ic_light_sd15_fc + manual_seed(2) + clip_text_embedding = sd.compute_clip_text_embedding( + text="a photo of dog, purple neon lighting", + negative_text="lowres, bad anatomy, bad hands, cropped, worst quality", + ) + ic_light_condition = sd.compute_gray_composite(image=kitchen_dog, mask=kitchen_dog_mask.convert("L")) + sd.set_ic_light_condition(ic_light_condition) + x = torch.randn(1, 4, 64, 64, device=test_device) + for step in sd.steps: + x = sd( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=2.0, + ) + predicted_image = sd.lda.latents_to_image(x) + ensure_similar_images(predicted_image, expected_ic_light, min_psnr=35, min_ssim=0.99) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 926127f..a0ef337 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -60,6 +60,7 @@ Special cases: - `expected_controlnet_canny_scale_decay.png` - `expected_multi_diffusion_dpm.png` - `expected_multi_upscaler.png` + - `expected_ic_light.png` ## Other images diff --git a/tests/e2e/test_diffusion_ref/expected_ic_light.png b/tests/e2e/test_diffusion_ref/expected_ic_light.png new file mode 100644 index 0000000..677e2a1 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_ic_light.png differ
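For context on how the new files fit together: the file produced by `convert_ic_light.py` is not a standalone UNet checkpoint but a set of per-key offsets, which `ICLight._apply_patch` adds onto the base SD 1.5 UNet state dict after `_extend_conv_in` has widened the first convolution from 4 to 8 input channels (new channels zero-initialized). A minimal sketch of both mechanisms on plain PyTorch objects, with illustrative shapes and key names rather than the real UNet layout:

```python
import torch

# Additive patching: patched weight = base weight + IC-Light offset, key by key
# (mirrors ICLight._apply_patch; the key name is illustrative, not a real UNet key).
base = {"down_blocks.0.conv.weight": torch.randn(320, 8, 3, 3)}
offsets = {"down_blocks.0.conv.weight": torch.randn(320, 8, 3, 3)}
patched = {key: tensor + offsets[key].to(tensor.device, tensor.dtype) for key, tensor in base.items()}

# Widening the input convolution: the original 4-channel kernel is copied over and
# the 4 new input channels stay zero-initialized, so the extended UNet behaves like
# the original one until the offsets are applied (mirrors ICLight._extend_conv_in).
old_conv = torch.nn.Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1)
new_conv = torch.nn.Conv2d(in_channels=8, out_channels=320, kernel_size=3, padding=1)
with torch.no_grad():
    torch.nn.init.zeros_(new_conv.weight)
    new_conv.weight[:, :4] = old_conv.weight
    new_conv.bias.copy_(old_conv.bias)
```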
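The conditioning path added by `set_ic_light_condition` and `__call__` then reduces to: composite the reference image over a 127-gray background, encode that composite with the SD 1.5 autoencoder, and concatenate the resulting latents with the noisy latents along the channel dimension at every denoising step. A standalone sketch of just that data flow, using placeholder images and a random tensor standing in for `lda.encode`:

```python
import torch
from PIL import Image

# Gray composite: keep the subject where the mask is white, 127-gray elsewhere
# (mirrors ICLight.compute_gray_composite).
image = Image.new("RGB", (512, 512), (200, 180, 160))  # placeholder reference image
mask = Image.new("L", (512, 512), 255)  # placeholder foreground mask
background = Image.new("RGB", image.size, (127, 127, 127))
composite = Image.composite(image, background, mask)

# The composite would be encoded to latents by the autoencoder; a random tensor
# stands in for `lda.encode(...)` here.
condition_latents = torch.randn(1, 4, 64, 64)
noisy_latents = torch.randn(1, 4, 64, 64)

# Channel-wise concatenation feeds the 8-channel UNet (mirrors ICLight.__call__).
unet_input = torch.cat((noisy_latents, condition_latents), dim=1)
assert unet_input.shape == (1, 8, 64, 64)
```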