mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
update ic_light adapter, bugfix + improve docstrings
This commit is contained in:
parent
2cb0f06119
commit
d7bbf9ad5d
|
@ -1,4 +1,5 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
SD1Autoencoder,
|
SD1Autoencoder,
|
||||||
|
@ -16,4 +17,5 @@ __all__ = [
|
||||||
"SD1ControlnetAdapter",
|
"SD1ControlnetAdapter",
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
"SD1T2IAdapter",
|
"SD1T2IAdapter",
|
||||||
|
"ICLight",
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,7 +3,7 @@ from PIL import Image
|
||||||
from torch.nn.init import zeros_ as zero_init
|
from torch.nn.init import zeros_ as zero_init
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import image_to_tensor, no_grad
|
from refiners.fluxion.utils import no_grad
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
|
||||||
|
@ -11,79 +11,82 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import Down
|
||||||
|
|
||||||
|
|
||||||
class ICLight(StableDiffusion_1):
|
class ICLight(StableDiffusion_1):
|
||||||
"""
|
"""IC-Light is a Stable Diffusion model that can be used to relight a reference image.
|
||||||
IC-Light is a Stable Diffusion model that can be used to relight a reference image.
|
|
||||||
|
|
||||||
At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
|
At initialization, the UNet will be patched to accept four additional input channels.
|
||||||
|
Only the text-conditioned relighting model is supported for now.
|
||||||
|
|
||||||
```example
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
Example:
|
||||||
from refiners.foundationals.clip import CLIPTextEncoderL
|
```py
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
|
import torch
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||||
dtype = torch.float32
|
from refiners.foundationals.clip import CLIPTextEncoderL
|
||||||
no_grad().__enter__()
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
|
||||||
manual_seed(42)
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
||||||
|
|
||||||
sd = ICLight(
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
patch_weights=load_from_safetensors(
|
dtype = torch.float32
|
||||||
path=hf_hub_download(
|
no_grad().__enter__()
|
||||||
repo_id="refiners/ic_light.sd1_5.fc",
|
manual_seed(42)
|
||||||
filename="model.safetensors",
|
|
||||||
|
sd = ICLight(
|
||||||
|
patch_weights=load_from_safetensors(
|
||||||
|
path=hf_hub_download(
|
||||||
|
repo_id="refiners/ic_light.sd1_5.fc",
|
||||||
|
filename="model.safetensors",
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
),
|
),
|
||||||
device=device,
|
device=device,
|
||||||
),
|
dtype=dtype,
|
||||||
unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
|
|
||||||
tensors_path=hf_hub_download(
|
|
||||||
repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
|
|
||||||
filename="model.safetensors",
|
|
||||||
)
|
|
||||||
),
|
|
||||||
clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
|
|
||||||
tensors_path=hf_hub_download(
|
|
||||||
repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
|
|
||||||
filename="model.safetensors",
|
|
||||||
)
|
|
||||||
),
|
|
||||||
lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
|
|
||||||
tensors_path=hf_hub_download(
|
|
||||||
repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
|
|
||||||
filename="model.safetensors",
|
|
||||||
)
|
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = "soft lighting, high-quality professional image"
|
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
|
||||||
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
|
||||||
|
|
||||||
image = Image.open("reference-image.png").resize((512, 512))
|
|
||||||
sd.set_ic_light_condition(image)
|
|
||||||
|
|
||||||
x = torch.randn(
|
|
||||||
size=(1, 4, 64, 64),
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
for step in sd.steps:
|
|
||||||
x = sd(
|
|
||||||
x=x,
|
|
||||||
step=step,
|
|
||||||
clip_text_embedding=clip_text_embedding,
|
|
||||||
condition_scale=1.5,
|
|
||||||
)
|
)
|
||||||
predicted_image = sd.lda.latents_to_image(x)
|
|
||||||
|
|
||||||
predicted_image.save("ic-light-output.png")
|
prompt = "soft lighting, high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
image = Image.open("reference-image.png").resize((512, 512))
|
||||||
|
sd.set_ic_light_condition(image)
|
||||||
|
|
||||||
|
x = torch.randn(
|
||||||
|
size=(1, 4, 64, 64),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for step in sd.steps:
|
||||||
|
x = sd(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=1.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd.lda.latents_to_image(x)
|
||||||
|
|
||||||
|
predicted_image.save("ic-light-output.png")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -109,9 +112,7 @@ class ICLight(StableDiffusion_1):
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def _extend_conv_in(self) -> None:
|
def _extend_conv_in(self) -> None:
|
||||||
"""
|
"""Extend to 8 the input channels of the first convolutional layer of the UNet."""
|
||||||
Extend to 8 the input channels of the first convolutional layer of the UNet.
|
|
||||||
"""
|
|
||||||
down_blocks = self.unet.ensure_find(DownBlocks)
|
down_blocks = self.unet.ensure_find(DownBlocks)
|
||||||
first_block = down_blocks.layer(0, fl.Chain)
|
first_block = down_blocks.layer(0, fl.Chain)
|
||||||
conv_in = first_block.ensure_find(fl.Conv2d)
|
conv_in = first_block.ensure_find(fl.Conv2d)
|
||||||
|
@ -129,9 +130,7 @@ class ICLight(StableDiffusion_1):
|
||||||
first_block.replace(old_module=conv_in, new_module=new_conv_in)
|
first_block.replace(old_module=conv_in, new_module=new_conv_in)
|
||||||
|
|
||||||
def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
|
def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
|
||||||
"""
|
"""Apply the weights patch to the UNet, modifying inplace the state dict."""
|
||||||
Apply the patch weights to the UNet, modifying inplace the state dict.
|
|
||||||
"""
|
|
||||||
current_state_dict = self.unet.state_dict()
|
current_state_dict = self.unet.state_dict()
|
||||||
new_state_dict = {
|
new_state_dict = {
|
||||||
key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
|
key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
|
||||||
|
@ -139,9 +138,17 @@ class ICLight(StableDiffusion_1):
|
||||||
self.unet.load_state_dict(new_state_dict)
|
self.unet.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
|
def compute_gray_composite(
|
||||||
"""
|
image: Image.Image,
|
||||||
Compute a grayscale composite of an image and a mask.
|
mask: Image.Image,
|
||||||
|
) -> Image.Image:
|
||||||
|
"""Compute a grayscale composite of an image and a mask.
|
||||||
|
|
||||||
|
IC-Light will recreate the image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to composite.
|
||||||
|
mask: The mask to use for the composite.
|
||||||
"""
|
"""
|
||||||
assert mask.mode == "L", "Mask must be a grayscale image"
|
assert mask.mode == "L", "Mask must be a grayscale image"
|
||||||
assert image.size == mask.size, "Image and mask must have the same size"
|
assert image.size == mask.size, "Image and mask must have the same size"
|
||||||
|
@ -149,28 +156,31 @@ class ICLight(StableDiffusion_1):
|
||||||
return Image.composite(image, background, mask)
|
return Image.composite(image, background, mask)
|
||||||
|
|
||||||
def set_ic_light_condition(
|
def set_ic_light_condition(
|
||||||
self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
mask: Image.Image | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Set the IC light condition.
|
||||||
Set the IC light condition.
|
|
||||||
|
Args:
|
||||||
|
image: The reference image.
|
||||||
|
mask: The mask to use for the reference image.
|
||||||
|
|
||||||
If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,
|
If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,
|
||||||
the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
|
the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
|
||||||
|
|
||||||
`use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
|
|
||||||
Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
|
|
||||||
see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
|
|
||||||
"""
|
"""
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
image = self.compute_gray_composite(image=image, mask=mask)
|
image = self.compute_gray_composite(image=image, mask=mask)
|
||||||
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
latents = self.lda.image_to_latents(image)
|
||||||
if use_rescaled_image:
|
|
||||||
image_tensor = 2 * image_tensor - 1
|
|
||||||
latents = self.lda.encode(image_tensor)
|
|
||||||
self._ic_light_condition = latents
|
self._ic_light_condition = latents
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
step: int,
|
||||||
|
*,
|
||||||
|
clip_text_embedding: torch.Tensor,
|
||||||
|
condition_scale: float = 2.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
|
assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
|
||||||
x = torch.cat((x, self._ic_light_condition), dim=1)
|
x = torch.cat((x, self._ic_light_condition), dim=1)
|
||||||
|
|
Loading…
Reference in a new issue