mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
implement foreground conditioned ic light
This commit is contained in:
parent
928da1ee1c
commit
51dcd7772c
|
@ -19,6 +19,7 @@ class Args(argparse.Namespace):
|
||||||
half: bool
|
half: bool
|
||||||
verbose: bool
|
verbose: bool
|
||||||
skip_init_check: bool
|
skip_init_check: bool
|
||||||
|
override_weights: str | None
|
||||||
|
|
||||||
|
|
||||||
def setup_converter(args: Args) -> ModelConverter:
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
|
89
scripts/conversion/convert_ic_light.py
Normal file
89
scripts/conversion/convert_ic_light.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter
|
||||||
|
from huggingface_hub import hf_hub_download # type: ignore
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
class Args(argparse.Namespace):
|
||||||
|
source_path: str
|
||||||
|
output_path: str | None
|
||||||
|
subfolder: str
|
||||||
|
half: bool
|
||||||
|
verbose: bool
|
||||||
|
reference_unet_path: str
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners")
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source_path",
|
||||||
|
default="lllyasviel/ic-light",
|
||||||
|
help=(
|
||||||
|
"Can be a path to a .bin file, a .safetensors file or a model name from the Hugging Face Hub. Default:"
|
||||||
|
" lllyasviel/ic-light"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Output path (.safetensors) for converted model. If not provided, the output path will be the same as the"
|
||||||
|
" source path."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Prints additional information during conversion. Default: False",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference-unet-path",
|
||||||
|
type=str,
|
||||||
|
dest="reference_unet_path",
|
||||||
|
default="runwayml/stable-diffusion-v1-5",
|
||||||
|
help="Path to the reference UNet weights.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args(namespace=Args())
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.filename).stem}-refiners.safetensors"
|
||||||
|
|
||||||
|
patch_file = (
|
||||||
|
Path(args.source_path)
|
||||||
|
if args.source_path.endswith(".safetensors")
|
||||||
|
else Path(
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id=args.source_path,
|
||||||
|
filename=args.filename,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
patch_weights = load_from_safetensors(patch_file)
|
||||||
|
|
||||||
|
unet_args = UNetArgs(
|
||||||
|
source_path=args.reference_unet_path,
|
||||||
|
subfolder="unet",
|
||||||
|
half=False,
|
||||||
|
verbose=False,
|
||||||
|
skip_init_check=True,
|
||||||
|
override_weights=None,
|
||||||
|
)
|
||||||
|
converter = setup_unet_converter(args=unet_args)
|
||||||
|
result = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
|
||||||
|
source_state_dict=patch_weights,
|
||||||
|
target_state_dict=converter.target_model.state_dict(),
|
||||||
|
state_dict_mapping=converter.get_mapping(),
|
||||||
|
)
|
||||||
|
save_to_safetensors(path=args.output_path, tensors=result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -438,6 +438,14 @@ def download_sdxl_lightning_lora():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_ic_light():
|
||||||
|
download_file(
|
||||||
|
"https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors",
|
||||||
|
dest_folder=test_weights_dir,
|
||||||
|
expected_hash="bce70123",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def printg(msg: str):
|
def printg(msg: str):
|
||||||
"""print in green color"""
|
"""print in green color"""
|
||||||
print("\033[92m" + msg + "\033[0m")
|
print("\033[92m" + msg + "\033[0m")
|
||||||
|
@ -790,6 +798,16 @@ def convert_sdxl_lightning_base():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ic_light():
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_ic_light.py",
|
||||||
|
"tests/weights/iclight_sd15_fc.safetensors",
|
||||||
|
"tests/weights/iclight_sd15_fc-refiners.safetensors",
|
||||||
|
half=False,
|
||||||
|
expected_hash="be315c1f",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_all():
|
def download_all():
|
||||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||||
|
@ -811,6 +829,7 @@ def download_all():
|
||||||
download_lcm_lora()
|
download_lcm_lora()
|
||||||
download_sdxl_lightning_base()
|
download_sdxl_lightning_base()
|
||||||
download_sdxl_lightning_lora()
|
download_sdxl_lightning_lora()
|
||||||
|
download_ic_light()
|
||||||
|
|
||||||
|
|
||||||
def convert_all():
|
def convert_all():
|
||||||
|
@ -830,6 +849,7 @@ def convert_all():
|
||||||
convert_control_lora_fooocus()
|
convert_control_lora_fooocus()
|
||||||
convert_lcm_base()
|
convert_lcm_base()
|
||||||
convert_sdxl_lightning_base()
|
convert_sdxl_lightning_base()
|
||||||
|
convert_ic_light()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
@ -0,0 +1,182 @@
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.nn.init import zeros_ as zero_init
|
||||||
|
|
||||||
|
from refiners.fluxion import layers as fl
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, no_grad
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks, SD1UNet
|
||||||
|
|
||||||
|
|
||||||
|
class ICLight(StableDiffusion_1):
|
||||||
|
"""
|
||||||
|
IC-Light is a Stable Diffusion model that can be used to relight a reference image.
|
||||||
|
|
||||||
|
At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
|
||||||
|
|
||||||
|
```example
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||||
|
from refiners.foundationals.clip import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
no_grad().__enter__()
|
||||||
|
manual_seed(42)
|
||||||
|
|
||||||
|
sd = ICLight(
|
||||||
|
patch_weights=load_from_safetensors(
|
||||||
|
path=hf_hub_download(
|
||||||
|
repo_id="refiners/ic_light.sd1_5.fc",
|
||||||
|
filename="model.safetensors",
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
|
||||||
|
tensors_path=hf_hub_download(
|
||||||
|
repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
|
||||||
|
filename="model.safetensors",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "soft lighting, high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
image = Image.open("reference-image.png").resize((512, 512))
|
||||||
|
sd.set_ic_light_condition(image)
|
||||||
|
|
||||||
|
x = torch.randn(
|
||||||
|
size=(1, 4, 64, 64),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for step in sd.steps:
|
||||||
|
x = sd(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=1.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd.lda.latents_to_image(x)
|
||||||
|
|
||||||
|
predicted_image.save("ic-light-output.png")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_weights: dict[str, torch.Tensor],
|
||||||
|
unet: SD1UNet,
|
||||||
|
lda: SD1Autoencoder | None = None,
|
||||||
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
|
solver: Solver | None = None,
|
||||||
|
device: torch.device | str = "cpu",
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
unet=unet,
|
||||||
|
lda=lda,
|
||||||
|
clip_text_encoder=clip_text_encoder,
|
||||||
|
solver=solver,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self._extend_conv_in()
|
||||||
|
self._apply_patch(weights=patch_weights)
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def _extend_conv_in(self) -> None:
|
||||||
|
"""
|
||||||
|
Extend to 8 the input channels of the first convolutional layer of the UNet.
|
||||||
|
"""
|
||||||
|
down_blocks = self.unet.ensure_find(DownBlocks)
|
||||||
|
first_block = down_blocks.layer(0, fl.Chain)
|
||||||
|
conv_in = first_block.ensure_find(fl.Conv2d)
|
||||||
|
new_conv_in = fl.Conv2d(
|
||||||
|
in_channels=conv_in.in_channels + 4,
|
||||||
|
out_channels=conv_in.out_channels,
|
||||||
|
kernel_size=(conv_in.kernel_size[0], conv_in.kernel_size[1]),
|
||||||
|
padding=(int(conv_in.padding[0]), int(conv_in.padding[1])),
|
||||||
|
device=conv_in.device,
|
||||||
|
dtype=conv_in.dtype,
|
||||||
|
)
|
||||||
|
zero_init(new_conv_in.weight)
|
||||||
|
new_conv_in.bias = conv_in.bias
|
||||||
|
new_conv_in.weight[:, :4, :, :] = conv_in.weight
|
||||||
|
first_block.replace(old_module=conv_in, new_module=new_conv_in)
|
||||||
|
|
||||||
|
def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
"""
|
||||||
|
Apply the patch weights to the UNet, modifying inplace the state dict.
|
||||||
|
"""
|
||||||
|
current_state_dict = self.unet.state_dict()
|
||||||
|
new_state_dict = {
|
||||||
|
key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
|
||||||
|
}
|
||||||
|
self.unet.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Compute a grayscale composite of an image and a mask.
|
||||||
|
"""
|
||||||
|
assert mask.mode == "L", "Mask must be a grayscale image"
|
||||||
|
assert image.size == mask.size, "Image and mask must have the same size"
|
||||||
|
background = Image.new("RGB", image.size, (127, 127, 127))
|
||||||
|
return Image.composite(image, background, mask)
|
||||||
|
|
||||||
|
def set_ic_light_condition(
|
||||||
|
self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the IC light condition.
|
||||||
|
|
||||||
|
If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,
|
||||||
|
the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
|
||||||
|
|
||||||
|
`use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
|
||||||
|
Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
|
||||||
|
see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
|
||||||
|
"""
|
||||||
|
if mask is not None:
|
||||||
|
image = self.compute_gray_composite(image=image, mask=mask)
|
||||||
|
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
||||||
|
if use_rescaled_image:
|
||||||
|
image_tensor = 2 * image_tensor - 1
|
||||||
|
latents = self.lda.encode(image_tensor)
|
||||||
|
self._ic_light_condition = latents
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
|
||||||
|
x = torch.cat((x, self._ic_light_condition), dim=1)
|
||||||
|
return super().__call__(
|
||||||
|
x,
|
||||||
|
step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=condition_scale,
|
||||||
|
)
|
|
@ -12,6 +12,7 @@ from tests.utils import ensure_similar_images
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.latent_diffusion import (
|
from refiners.foundationals.latent_diffusion import (
|
||||||
ControlLoraAdapter,
|
ControlLoraAdapter,
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
|
@ -30,6 +31,8 @@ from refiners.foundationals.latent_diffusion.reference_only_control import Refer
|
||||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
|
||||||
SD1DiffusionTarget,
|
SD1DiffusionTarget,
|
||||||
SD1MultiDiffusion,
|
SD1MultiDiffusion,
|
||||||
|
@ -2564,3 +2567,58 @@ def test_multi_upscaler(
|
||||||
) -> None:
|
) -> None:
|
||||||
predicted_image = multi_upscaler.upscale(clarity_example)
|
predicted_image = multi_upscaler.upscale(clarity_example)
|
||||||
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def expected_ic_light(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ic_light_sd15_fc_weights(test_weights_path: Path) -> Path:
|
||||||
|
return test_weights_path / "iclight_sd15_fc-refiners.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ic_light_sd15_fc(
|
||||||
|
ic_light_sd15_fc_weights: Path,
|
||||||
|
unet_weights_std: Path,
|
||||||
|
lda_weights: Path,
|
||||||
|
text_encoder_weights: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> ICLight:
|
||||||
|
return ICLight(
|
||||||
|
patch_weights=load_from_safetensors(ic_light_sd15_fc_weights),
|
||||||
|
unet=SD1UNet(in_channels=4).load_from_safetensors(unet_weights_std),
|
||||||
|
lda=SD1Autoencoder().load_from_safetensors(lda_weights),
|
||||||
|
clip_text_encoder=CLIPTextEncoderL().load_from_safetensors(text_encoder_weights),
|
||||||
|
device=test_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_ic_light(
|
||||||
|
kitchen_dog: Image.Image,
|
||||||
|
kitchen_dog_mask: Image.Image,
|
||||||
|
ic_light_sd15_fc: ICLight,
|
||||||
|
expected_ic_light: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
sd = ic_light_sd15_fc
|
||||||
|
manual_seed(2)
|
||||||
|
clip_text_embedding = sd.compute_clip_text_embedding(
|
||||||
|
text="a photo of dog, purple neon lighting",
|
||||||
|
negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
|
||||||
|
)
|
||||||
|
ic_light_condition = sd.compute_gray_composite(image=kitchen_dog, mask=kitchen_dog_mask.convert("L"))
|
||||||
|
sd.set_ic_light_condition(ic_light_condition)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
for step in sd.steps:
|
||||||
|
x = sd(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=2.0,
|
||||||
|
)
|
||||||
|
predicted_image = sd.lda.latents_to_image(x)
|
||||||
|
ensure_similar_images(predicted_image, expected_ic_light, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
|
@ -60,6 +60,7 @@ Special cases:
|
||||||
- `expected_controlnet_canny_scale_decay.png`
|
- `expected_controlnet_canny_scale_decay.png`
|
||||||
- `expected_multi_diffusion_dpm.png`
|
- `expected_multi_diffusion_dpm.png`
|
||||||
- `expected_multi_upscaler.png`
|
- `expected_multi_upscaler.png`
|
||||||
|
- `expected_ic_light.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_ic_light.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_ic_light.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 438 KiB |
Loading…
Reference in a new issue