update ic_light adapter, bugfix + improve docstrings

Laurent 2024-08-21 10:20:11 +00:00 committed by Laureηt
parent 2cb0f06119
commit d7bbf9ad5d
2 changed files with 98 additions and 86 deletions

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py

@@ -1,4 +1,5 @@
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     SD1Autoencoder,
@@ -16,4 +17,5 @@ __all__ = [
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
     "SD1T2IAdapter",
+    "ICLight",
 ]
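As a side note, the re-export above means `ICLight` is now importable directly from the `stable_diffusion_1` package rather than its submodule; a minimal check (assuming refiners is installed):

```py
# The new re-export: both import paths resolve to the same class.
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import ICLight
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight as ICLightDirect

assert ICLight is ICLightDirect
```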

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py

@@ -3,7 +3,7 @@ from PIL import Image
 from torch.nn.init import zeros_ as zero_init

 from refiners.fluxion import layers as fl
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
@@ -11,12 +11,14 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks

 class ICLight(StableDiffusion_1):
-    """
-    IC-Light is a Stable Diffusion model that can be used to relight a reference image.
-    At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
+    """IC-Light is a Stable Diffusion model that can be used to relight a reference image.
+
+    At initialization, the UNet will be patched to accept four additional input channels.
+    Only the text-conditioned relighting model is supported for now.

-    ```example
+    Example:
+        ```py
         import torch
         from huggingface_hub import hf_hub_download
         from PIL import Image
@@ -84,6 +86,7 @@
         predicted_image = sd.lda.latents_to_image(x)
         predicted_image.save("ic-light-output.png")
+        ```
     """

     def __init__(
@@ -109,9 +112,7 @@

     @no_grad()
     def _extend_conv_in(self) -> None:
-        """
-        Extend to 8 the input channels of the first convolutional layer of the UNet.
-        """
+        """Extend the input channels of the first convolutional layer of the UNet to 8."""
         down_blocks = self.unet.ensure_find(DownBlocks)
         first_block = down_blocks.layer(0, fl.Chain)
         conv_in = first_block.ensure_find(fl.Conv2d)
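For intuition, `_extend_conv_in` swaps the UNet's first convolution for a wider one whose four extra input channels are zero-initialized, so the patched model initially behaves exactly like the original. A standalone sketch of the same idea in plain PyTorch (the layer sizes are the usual SD1 values, not taken from this diff):

```py
import torch
from torch import nn
from torch.nn.init import zeros_ as zero_init

# Pretrained first conv of an SD1 UNet: 4 latent channels in, 320 features out.
old_conv = nn.Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1)

# Wider replacement: 8 input channels (4 noisy latents + 4 condition latents).
new_conv = nn.Conv2d(in_channels=8, out_channels=320, kernel_size=3, padding=1)
with torch.no_grad():
    zero_init(new_conv.weight)                      # extra channels contribute nothing at first
    new_conv.weight[:, :4, :, :] = old_conv.weight  # copy the pretrained kernels
    new_conv.bias.copy_(old_conv.bias)

x = torch.randn(1, 8, 64, 64)
assert new_conv(x).shape == (1, 320, 64, 64)
```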
@@ -129,9 +130,7 @@
         first_block.replace(old_module=conv_in, new_module=new_conv_in)

     def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
-        """
-        Apply the patch weights to the UNet, modifying inplace the state dict.
-        """
+        """Apply the patch weights to the UNet, modifying its state dict in place."""
         current_state_dict = self.unet.state_dict()
         new_state_dict = {
             key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
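The patch itself is additive: every tensor in `weights` is an offset added to the matching UNet parameter before the state dict is reloaded. A toy illustration of the same arithmetic on a throwaway module (the `patch` dict here is hypothetical):

```py
import torch
from torch import nn

model = nn.Linear(2, 2)

# Hypothetical patch: one offset tensor per state-dict key.
patch = {key: 0.1 * torch.ones_like(tensor) for key, tensor in model.state_dict().items()}

# Same scheme as _apply_patch: add each offset to the current weights, then reload.
patched = {
    key: tensor + patch[key].to(tensor.device, tensor.dtype)
    for key, tensor in model.state_dict().items()
}
model.load_state_dict(patched)
```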
@@ -139,9 +138,17 @@
         self.unet.load_state_dict(new_state_dict)

     @staticmethod
-    def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
-        """
-        Compute a grayscale composite of an image and a mask.
+    def compute_gray_composite(
+        image: Image.Image,
+        mask: Image.Image,
+    ) -> Image.Image:
+        """Compute a grayscale composite of an image and a mask.
+
+        IC-Light will recreate the image based on this grayscale composite.
+
+        Args:
+            image: The image to composite.
+            mask: The mask to use for the composite.
         """
         assert mask.mode == "L", "Mask must be a grayscale image"
         assert image.size == mask.size, "Image and mask must have the same size"
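Concretely, the composite keeps the masked subject and replaces everything else with the 127-valued gray that IC-Light expects as background. A rough PIL equivalent (image and mask here are synthetic placeholders):

```py
from PIL import Image

image = Image.new("RGB", (512, 512), "blue")  # stand-in for the reference image
mask = Image.new("L", (512, 512), 0)          # grayscale mask: white = keep pixel
mask.paste(255, (128, 128, 384, 384))         # keep a centered square of the image

# Same recipe as compute_gray_composite: subject over a 127-valued gray background.
background = Image.new("RGB", image.size, (127, 127, 127))
composite = Image.composite(image, background, mask)
```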
@@ -149,28 +156,31 @@
         return Image.composite(image, background, mask)

     def set_ic_light_condition(
-        self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
+        self,
+        image: Image.Image,
+        mask: Image.Image | None = None,
     ) -> None:
-        """
-        Set the IC light condition.
+        """Set the IC light condition.
+
+        Args:
+            image: The reference image.
+            mask: The mask to use for the reference image.

         If a mask is provided, it will be used to compute a grayscale composite of the image and the mask; otherwise,
         the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
-
-        `use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
-        Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
-        see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
         """
         if mask is not None:
             image = self.compute_gray_composite(image=image, mask=mask)
-        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
-        if use_rescaled_image:
-            image_tensor = 2 * image_tensor - 1
-        latents = self.lda.encode(image_tensor)
+        latents = self.lda.image_to_latents(image)
         self._ic_light_condition = latents

     def __call__(
-        self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
+        self,
+        x: torch.Tensor,
+        step: int,
+        *,
+        clip_text_embedding: torch.Tensor,
+        condition_scale: float = 2.0,
     ) -> torch.Tensor:
         assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
         x = torch.cat((x, self._ic_light_condition), dim=1)
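This concatenation is the reason the conv-in was widened: at each denoising step the 4-channel noisy latents are stacked with the 4-channel latents of the encoded reference, producing the 8 input channels the patched UNet expects. A shape-only sketch:

```py
import torch

x = torch.randn(1, 4, 64, 64)                   # noisy latents at the current step
ic_light_condition = torch.randn(1, 4, 64, 64)  # encoded reference image latents

unet_input = torch.cat((x, ic_light_condition), dim=1)
assert unet_input.shape == (1, 8, 64, 64)       # matches the extended conv-in
```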