add InformativeDrawings

https://github.com/carolineec/informative-drawings

This is the preprocessor for the Lineart ControlNet.
Pierre Chapuis 2023-08-08 12:17:16 +02:00
parent e10f761a84
commit 97b162d9a0
9 changed files with 303 additions and 31 deletions

View file

@@ -0,0 +1,60 @@
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
# Code is at https://github.com/carolineec/informative-drawings
# Copy `model.py` into your `PYTHONPATH`. You can edit it to remove unnecessary code
# and imports if you want; we only need `Generator`.
import torch
from safetensors.torch import save_file
from refiners.fluxion.utils import (
create_state_dict_mapping,
convert_state_dict,
)
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from model import Generator


@torch.no_grad()
def convert(checkpoint: str, device: torch.device) -> dict[str, torch.Tensor]:
src_model = Generator(3, 1, 3)
src_model.load_state_dict(torch.load(checkpoint, map_location=device))
src_model.eval()
dst_model = InformativeDrawings()
x = torch.randn(1, 3, 512, 512)
mapping = create_state_dict_mapping(src_model, dst_model, [x])
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
    return {k: v.half() for k, v in state_dict.items()}


def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
dest="source",
required=False,
default="model2.pth",
help="Source model",
)
parser.add_argument(
"--output-file",
type=str,
required=False,
default="informative-drawings.safetensors",
help="Path for the output file",
)
args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensors = convert(args.source, device)
    save_file(tensors, args.output_file)


if __name__ == "__main__":
main()

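Aside (not part of the commit): a quick sanity check of the converted file. Given the cast in `convert` above, every tensor in the output should be fp16; the path is the script's default:

import torch
from safetensors.torch import load_file

# Load the converted checkpoint and verify the fp16 cast.
state = load_file("informative-drawings.safetensors")
assert all(tensor.dtype == torch.float16 for tensor in state.values())
print(f"converted {len(state)} tensors")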
View file

@@ -1,5 +1,5 @@
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d, InstanceNorm2d
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
from refiners.fluxion.layers.basics import (
Identity,
@@ -28,9 +28,10 @@ from refiners.fluxion.layers.chain import (
Breakpoint,
Concatenate,
)
from refiners.fluxion.layers.conv import Conv2d
from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
from refiners.fluxion.layers.linear import Linear, MultiLinear
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
from refiners.fluxion.layers.padding import ReflectionPad2d
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding
@@ -39,6 +40,7 @@ __all__ = [
"LayerNorm",
"GroupNorm",
"LayerNorm2d",
"InstanceNorm2d",
"GeLU",
"GLU",
"SiLU",
@@ -72,6 +74,7 @@ __all__ = [
"Breakpoint",
"Concatenate",
"Conv2d",
"ConvTranspose2d",
"Linear",
"MultiLinear",
"Downsample",
@@ -80,4 +83,5 @@ __all__ = [
"WeightedModule",
"ContextModule",
"Interpolate",
"ReflectionPad2d",
]

View file

@@ -1,19 +1,18 @@
from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
from torch import device as Device, dtype as DType
from torch import nn, device as Device, dtype as DType
from refiners.fluxion.layers.module import WeightedModule
class Conv2d(_Conv2d, WeightedModule):
class Conv2d(nn.Conv2d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] | str = 0,
dilation: int | tuple[int, ...] = 1,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] = 1,
padding: int | tuple[int, int] | str = 0,
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int, int] = 1,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
@@ -31,26 +30,19 @@ class Conv2d(_Conv2d, WeightedModule):
device,
dtype,
)
self.in_channels = in_channels
self.out_channels = out_channels
self.padding = (padding,) if isinstance(padding, int) else padding
self.dilation = (dilation,) if isinstance(dilation, int) else dilation
self.groups = groups
self.use_bias = use_bias
self.padding_mode = padding_mode
class Conv1d(_Conv1d, WeightedModule):
class Conv1d(nn.Conv1d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] | str = 0,
dilation: int | tuple[int, ...] = 1,
kernel_size: int | tuple[int],
stride: int | tuple[int] = 1,
padding: int | tuple[int] | str = 0,
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int] = 1,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
@@ -68,6 +60,35 @@ class Conv1d(_Conv1d, WeightedModule):
device,
dtype,
)
self.in_channels = in_channels
self.out_channels = out_channels
self.use_bias = use_bias
class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] = 1,
padding: int | tuple[int, int] = 0,
output_padding: int | tuple[int, int] = 0,
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int, int] = 1,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
bias=use_bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)

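Aside (illustrative, standard PyTorch semantics): with kernel_size=3, stride=2, padding=1 and output_padding=1, the configuration used by the upsampling blocks further down, a transposed convolution exactly doubles the spatial size, since H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding:

import torch
import refiners.fluxion.layers as fl

# 64 -> (64 - 1) * 2 - 2 * 1 + 3 + 1 = 128: height and width double.
up = fl.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.randn(1, 256, 64, 64)
assert up(x).shape == (1, 128, 128, 128)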
View file

@@ -1,10 +1,9 @@
from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
from torch import nn, ones, zeros, Tensor, sqrt, device as Device, dtype as DType
from jaxtyping import Float
from refiners.fluxion.layers.module import WeightedModule
from refiners.fluxion.layers.module import Module, WeightedModule
class LayerNorm(_LayerNorm, WeightedModule):
class LayerNorm(nn.LayerNorm, WeightedModule):
def __init__(
self,
normalized_shape: int | list[int],
@@ -21,7 +20,7 @@ class LayerNorm(_LayerNorm, WeightedModule):
)
class GroupNorm(_GroupNorm, WeightedModule):
class GroupNorm(nn.GroupNorm, WeightedModule):
def __init__(
self,
channels: int,
@@ -60,8 +59,8 @@ class LayerNorm2d(WeightedModule):
dtype: DType | None = None,
) -> None:
super().__init__()
self.weight = Parameter(ones(channels, device=device, dtype=dtype))
self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
self.eps = eps
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
@@ -70,3 +69,19 @@ class LayerNorm2d(WeightedModule):
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x_out


class InstanceNorm2d(nn.InstanceNorm2d, Module):
def __init__(
self,
num_features: int,
eps: float = 1e-05,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
num_features=num_features,
eps=eps,
device=device,
dtype=dtype,
)

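Aside (illustrative): with the defaults inherited from nn.InstanceNorm2d (affine=False, no running statistics), each (sample, channel) plane is normalized independently using a biased variance estimate:

import torch
import refiners.fluxion.layers as fl

norm = fl.InstanceNorm2d(num_features=64)
x = torch.randn(2, 64, 32, 32)
# Each (sample, channel) plane is normalized with its own mean and variance.
manual = (x - x.mean(dim=(2, 3), keepdim=True)) / torch.sqrt(x.var(dim=(2, 3), unbiased=False, keepdim=True) + norm.eps)
torch.testing.assert_close(norm(x), manual)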
View file

@@ -0,0 +1,7 @@
from torch import nn
from refiners.fluxion.layers.module import Module


class ReflectionPad2d(nn.ReflectionPad2d, Module):
def __init__(self, padding: int) -> None:
super().__init__(padding=padding)

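Aside (illustrative): reflection padding mirrors border pixels instead of inserting zeros, which is why the model below pairs it with otherwise unpadded 7x7 and 3x3 convolutions:

import torch
import refiners.fluxion.layers as fl

pad = fl.ReflectionPad2d(1)
x = torch.arange(9.0).reshape(1, 1, 3, 3)
# Rows and columns are mirrored at the border: [0, 1, 2] pads to [1, 0, 1, 2, 1].
print(pad(x))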
View file

@@ -82,6 +82,9 @@ BASIC_LAYERS: list[str] = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"Linear",
"BatchNorm1d",
"BatchNorm2d",

View file

@@ -0,0 +1,106 @@
# Adapted from https://github.com/carolineec/informative-drawings, MIT License
from torch import device as Device, dtype as DType
import refiners.fluxion.layers as fl


class InformativeDrawings(fl.Chain):
"""Model typically used as the preprocessor for the Lineart ControlNet.
Implements the paper "Learning to generate line drawings that convey
geometry and semantics" published in 2022 by Caroline Chan, Frédo Durand
and Phillip Isola - https://arxiv.org/abs/2203.12691
For use as a preprocessor it is recommended to use the weights for "Style 2".
"""
def __init__(
self,
in_channels: int = 3, # RGB
out_channels: int = 1, # Grayscale
n_residual_blocks: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Chain( # Initial convolution
fl.ReflectionPad2d(3),
fl.Conv2d(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(64, device=device, dtype=dtype),
fl.ReLU(),
),
*( # Downsampling
fl.Chain(
fl.Conv2d(
in_channels=64 * (2**i),
out_channels=128 * (2**i),
kernel_size=3,
stride=2,
padding=1,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(128 * (2**i), device=device, dtype=dtype),
fl.ReLU(),
)
for i in range(2)
),
*( # Residual blocks
fl.Residual(
fl.ReflectionPad2d(1),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(256, device=device, dtype=dtype),
fl.ReLU(),
fl.ReflectionPad2d(1),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(256, device=device, dtype=dtype),
)
for _ in range(n_residual_blocks)
),
*( # Upsampling
fl.Chain(
fl.ConvTranspose2d(
in_channels=128 * (2**i),
out_channels=64 * (2**i),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(64 * (2**i), device=device, dtype=dtype),
fl.ReLU(),
)
for i in reversed(range(2))
),
fl.Chain( # Output layer
fl.ReflectionPad2d(3),
fl.Conv2d(
in_channels=64,
out_channels=out_channels,
kernel_size=7,
device=device,
dtype=dtype,
),
fl.Sigmoid(),
)
)

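Aside (illustrative; it mirrors the end-to-end test below): running the preprocessor on an image. The weights path is the conversion script's default output and "input.png" is a placeholder:

import torch
from PIL import Image
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings

model = InformativeDrawings()
model.load_state_dict(load_from_safetensors("informative-drawings.safetensors"))

with torch.no_grad():
    lineart = model(image_to_tensor(Image.open("input.png").convert("RGB")))  # (1, 1, H, W) in [0, 1]
tensor_to_image(lineart.repeat(1, 3, 1, 1)).save("lineart.png")  # grayscale to RGB before saving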
View file

@@ -0,0 +1,56 @@
import torch
import pytest
from warnings import warn
from PIL import Image
from pathlib import Path
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from tests.utils import ensure_similar_images


@pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path:
    return test_e2e_path / "test_diffusion_ref"


@pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
    return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")


@pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
    return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")


@pytest.fixture(scope="module")
def informative_drawings_weights(test_weights_path: Path) -> Path:
    weights = test_weights_path / "informative-drawings.safetensors"
    if not weights.is_file():
        warn(f"could not find weights at {weights}, skipping")
        pytest.skip(allow_module_level=True)
    return weights


@pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
    model = InformativeDrawings(device=test_device)
    model.load_state_dict(load_from_safetensors(informative_drawings_weights))
    return model


@torch.no_grad()
def test_preprocessor_informative_drawing(
informative_drawings_model: InformativeDrawings,
cutecat_init: Image.Image,
expected_image_informative_drawings: Image.Image,
test_device: torch.device,
):
in_tensor = image_to_tensor(cutecat_init.convert("RGB"), device=test_device)
out_tensor = informative_drawings_model(in_tensor)
rgb_tensor = out_tensor.repeat(1, 3, 1, 1) # grayscale to RGB
image = tensor_to_image(rgb_tensor)
ensure_similar_images(image, expected_image_informative_drawings)