diff --git a/scripts/convert-informative-drawings-weights.py b/scripts/convert-informative-drawings-weights.py
new file mode 100644
index 0000000..adac5f0
--- /dev/null
+++ b/scripts/convert-informative-drawings-weights.py
@@ -0,0 +1,60 @@
+# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
+# Code is at https://github.com/carolineec/informative-drawings
+# Copy `model.py` into your `PYTHONPATH`. You can edit it to remove unnecessary code
+# and imports if you want; we only need `Generator`.
+
+import torch
+
+from safetensors.torch import save_file
+from refiners.fluxion.utils import (
+    create_state_dict_mapping,
+    convert_state_dict,
+)
+
+from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
+from model import Generator
+
+
+@torch.no_grad()
+def convert(checkpoint: str, device: torch.device) -> dict[str, torch.Tensor]:
+    src_model = Generator(3, 1, 3)
+    src_model.load_state_dict(torch.load(checkpoint, map_location=device))
+    src_model.eval()
+
+    dst_model = InformativeDrawings()
+
+    x = torch.randn(1, 3, 512, 512)
+
+    mapping = create_state_dict_mapping(src_model, dst_model, [x])
+    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    return {k: v.half() for k, v in state_dict.items()}
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--from",
+        type=str,
+        dest="source",
+        required=False,
+        default="model2.pth",
+        help="Source model",
+    )
+    parser.add_argument(
+        "--output-file",
+        type=str,
+        required=False,
+        default="informative-drawings.safetensors",
+        help="Path for the output file",
+    )
+    args = parser.parse_args()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    tensors = convert(args.source, device)
+    save_file(tensors, args.output_file)
+
+
+if __name__ == "__main__":
+    main()
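
Usage note (not part of the diff): assuming `model.py` from the upstream repository is importable and the "Style 2" checkpoint has been downloaded as `model2.pth`, the script can be invoked along these lines:

    python scripts/convert-informative-drawings-weights.py --from model2.pth --output-file informative-drawings.safetensors
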
diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py
index 170bac6..f0b59a5 100644
--- a/src/refiners/fluxion/layers/__init__.py
+++ b/src/refiners/fluxion/layers/__init__.py
@@ -1,5 +1,5 @@
 from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid
-from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
+from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d, InstanceNorm2d
 from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
 from refiners.fluxion.layers.basics import (
     Identity,
@@ -28,9 +28,10 @@ from refiners.fluxion.layers.chain import (
     Breakpoint,
     Concatenate,
 )
-from refiners.fluxion.layers.conv import Conv2d
+from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
 from refiners.fluxion.layers.linear import Linear, MultiLinear
 from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
+from refiners.fluxion.layers.padding import ReflectionPad2d
 from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
 from refiners.fluxion.layers.embedding import Embedding

@@ -39,6 +40,7 @@ __all__ = [
     "LayerNorm",
     "GroupNorm",
     "LayerNorm2d",
+    "InstanceNorm2d",
     "GeLU",
     "GLU",
     "SiLU",
@@ -72,6 +74,7 @@ __all__ = [
     "Breakpoint",
     "Concatenate",
     "Conv2d",
+    "ConvTranspose2d",
     "Linear",
     "MultiLinear",
     "Downsample",
@@ -80,4 +83,5 @@ __all__ = [
     "WeightedModule",
     "ContextModule",
     "Interpolate",
+    "ReflectionPad2d",
 ]
diff --git a/src/refiners/fluxion/layers/conv.py b/src/refiners/fluxion/layers/conv.py
index 1fc5124..fd24308 100644
--- a/src/refiners/fluxion/layers/conv.py
+++ b/src/refiners/fluxion/layers/conv.py
@@ -1,19 +1,18 @@
-from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
-from torch import device as Device, dtype as DType
+from torch import nn, device as Device, dtype as DType
 from refiners.fluxion.layers.module import WeightedModule


-class Conv2d(_Conv2d, WeightedModule):
+class Conv2d(nn.Conv2d, WeightedModule):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int | tuple[int, ...],
-        stride: int | tuple[int, ...] = 1,
-        padding: int | tuple[int, ...] | str = 0,
-        dilation: int | tuple[int, ...] = 1,
+        kernel_size: int | tuple[int, int],
+        stride: int | tuple[int, int] = 1,
+        padding: int | tuple[int, int] | str = 0,
         groups: int = 1,
         use_bias: bool = True,
+        dilation: int | tuple[int, int] = 1,
         padding_mode: str = "zeros",
         device: Device | str | None = None,
         dtype: DType | None = None,
@@ -31,26 +30,19 @@ class Conv2d(_Conv2d, WeightedModule):
             device,
             dtype,
         )
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.padding = (padding,) if isinstance(padding, int) else padding
-        self.dilation = (dilation,) if isinstance(dilation, int) else dilation
-        self.groups = groups
-        self.use_bias = use_bias
-        self.padding_mode = padding_mode


-class Conv1d(_Conv1d, WeightedModule):
+class Conv1d(nn.Conv1d, WeightedModule):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int | tuple[int, ...],
-        stride: int | tuple[int, ...] = 1,
-        padding: int | tuple[int, ...] | str = 0,
-        dilation: int | tuple[int, ...] = 1,
+        kernel_size: int | tuple[int],
+        stride: int | tuple[int] = 1,
+        padding: int | tuple[int] | str = 0,
         groups: int = 1,
         use_bias: bool = True,
+        dilation: int | tuple[int] = 1,
         padding_mode: str = "zeros",
         device: Device | str | None = None,
         dtype: DType | None = None,
@@ -68,6 +60,35 @@ class Conv1d(_Conv1d, WeightedModule):
             device,
             dtype,
         )
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.use_bias = use_bias
+
+
+class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int | tuple[int, int],
+        stride: int | tuple[int, int] = 1,
+        padding: int | tuple[int, int] = 0,
+        output_padding: int | tuple[int, int] = 0,
+        groups: int = 1,
+        use_bias: bool = True,
+        dilation: int | tuple[int, int] = 1,
+        padding_mode: str = "zeros",
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(  # type: ignore
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            dilation=dilation,
+            groups=groups,
+            bias=use_bias,
+            padding_mode=padding_mode,
+            device=device,
+            dtype=dtype,
+        )
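
Side note on the new `ConvTranspose2d` (illustrative, not part of the diff): the upsampling stages of the preprocessor below rely on the standard transposed-convolution output size, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1, so kernel_size=3, stride=2, padding=1, output_padding=1 exactly doubles each spatial dimension. A minimal sanity check:

    import torch
    from refiners.fluxion.layers import ConvTranspose2d

    # kernel_size=3, stride=2, padding=1, output_padding=1 doubles H and W
    up = ConvTranspose2d(in_channels=4, out_channels=4, kernel_size=3, stride=2, padding=1, output_padding=1)
    assert up(torch.randn(1, 4, 16, 16)).shape == (1, 4, 32, 32)  # (16 - 1) * 2 - 2 + 2 + 1 + 1 = 32
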
diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py
index 11d7d3a..97016cb 100644
--- a/src/refiners/fluxion/layers/norm.py
+++ b/src/refiners/fluxion/layers/norm.py
@@ -1,10 +1,9 @@
-from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
-from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
+from torch import nn, ones, zeros, Tensor, sqrt, device as Device, dtype as DType
 from jaxtyping import Float
-from refiners.fluxion.layers.module import WeightedModule
+from refiners.fluxion.layers.module import Module, WeightedModule


-class LayerNorm(_LayerNorm, WeightedModule):
+class LayerNorm(nn.LayerNorm, WeightedModule):
     def __init__(
         self,
         normalized_shape: int | list[int],
@@ -21,7 +20,7 @@ class LayerNorm(_LayerNorm, WeightedModule):
         )


-class GroupNorm(_GroupNorm, WeightedModule):
+class GroupNorm(nn.GroupNorm, WeightedModule):
     def __init__(
         self,
         channels: int,
@@ -60,8 +59,8 @@ class LayerNorm2d(WeightedModule):
         dtype: DType | None = None,
     ) -> None:
         super().__init__()
-        self.weight = Parameter(ones(channels, device=device, dtype=dtype))
-        self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
+        self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
+        self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
         self.eps = eps

     def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
@@ -70,3 +69,19 @@ class LayerNorm2d(WeightedModule):
         x_norm = (x - x_mean) / sqrt(x_var + self.eps)
         x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x_out
+
+
+class InstanceNorm2d(nn.InstanceNorm2d, Module):
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-05,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(  # type: ignore
+            num_features=num_features,
+            eps=eps,
+            device=device,
+            dtype=dtype,
+        )
diff --git a/src/refiners/fluxion/layers/padding.py b/src/refiners/fluxion/layers/padding.py
new file mode 100644
index 0000000..d8d7377
--- /dev/null
+++ b/src/refiners/fluxion/layers/padding.py
@@ -0,0 +1,7 @@
+from torch import nn
+from refiners.fluxion.layers.module import Module
+
+
+class ReflectionPad2d(nn.ReflectionPad2d, Module):
+    def __init__(self, padding: int) -> None:
+        super().__init__(padding=padding)
diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index fe64817..8b9b198 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -82,6 +82,9 @@ BASIC_LAYERS: list[str] = [
     "Conv1d",
     "Conv2d",
     "Conv3d",
+    "ConvTranspose1d",
+    "ConvTranspose2d",
+    "ConvTranspose3d",
     "Linear",
     "BatchNorm1d",
     "BatchNorm2d",
diff --git a/src/refiners/foundationals/latent_diffusion/preprocessors/__init__.py b/src/refiners/foundationals/latent_diffusion/preprocessors/__init__.py
new file mode 100644
index 0000000..e69de29
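
For orientation, here is how spatial shapes evolve through the generator defined below, assuming a 3x512x512 input (derived from the layer parameters; any input size divisible by 4 behaves the same way):

    # initial block (ReflectionPad2d(3) + 7x7 conv): 3x512x512  -> 64x512x512
    # downsampling (two strided 3x3 convs):          64x512x512 -> 128x256x256 -> 256x128x128
    # residual blocks (x n_residual_blocks):         256x128x128, shape preserved
    # upsampling (two transposed convs):             256x128x128 -> 128x256x256 -> 64x512x512
    # output block (7x7 conv + Sigmoid):             64x512x512 -> 1x512x512
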
+ """ + + def __init__( + self, + in_channels: int = 3, # RGB + out_channels: int = 1, # Grayscale + n_residual_blocks: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.Chain( # Initial convolution + fl.ReflectionPad2d(3), + fl.Conv2d( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + device=device, + dtype=dtype, + ), + fl.InstanceNorm2d(64, device=device, dtype=dtype), + fl.ReLU(), + ), + *( # Downsampling + fl.Chain( + fl.Conv2d( + in_channels=64 * (2**i), + out_channels=128 * (2**i), + kernel_size=3, + stride=2, + padding=1, + device=device, + dtype=dtype, + ), + fl.InstanceNorm2d(128 * (2**i), device=device, dtype=dtype), + fl.ReLU(), + ) + for i in range(2) + ), + *( # Residual blocks + fl.Residual( + fl.ReflectionPad2d(1), + fl.Conv2d( + in_channels=256, + out_channels=256, + kernel_size=3, + device=device, + dtype=dtype, + ), + fl.InstanceNorm2d(256, device=device, dtype=dtype), + fl.ReLU(), + fl.ReflectionPad2d(1), + fl.Conv2d( + in_channels=256, + out_channels=256, + kernel_size=3, + device=device, + dtype=dtype, + ), + fl.InstanceNorm2d(256, device=device, dtype=dtype), + ) + for _ in range(n_residual_blocks) + ), + *( # Upsampling + fl.Chain( + fl.ConvTranspose2d( + in_channels=128 * (2**i), + out_channels=64 * (2**i), + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + device=device, + dtype=dtype, + ), + fl.InstanceNorm2d(64 * (2**i), device=device, dtype=dtype), + fl.ReLU(), + ) + for i in reversed(range(2)) + ), + fl.Chain( # Output layer + fl.ReflectionPad2d(3), + fl.Conv2d( + in_channels=64, + out_channels=out_channels, + kernel_size=7, + device=device, + dtype=dtype, + ), + fl.Sigmoid(), + ) + ) diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py new file mode 100644 index 0000000..21b9355 --- /dev/null +++ b/tests/e2e/test_preprocessors.py @@ -0,0 +1,56 @@ +import torch +import pytest + +from warnings import warn +from PIL import Image +from pathlib import Path + +from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image +from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings + +from tests.utils import ensure_similar_images + + +@pytest.fixture(scope="module") +def diffusion_ref_path(test_e2e_path: Path) -> Path: + return test_e2e_path / "test_diffusion_ref" + + +@pytest.fixture(scope="module") +def cutecat_init(diffusion_ref_path: Path) -> Image.Image: + return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB") + + +@pytest.fixture +def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image: + return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB") + + +@pytest.fixture(scope="module") +def informative_drawings_weights(test_weights_path: Path) -> Path: + weights = test_weights_path / "informative-drawings.safetensors" + if not weights.is_file(): + warn(f"could not find weights at {test_weights_path}, skipping") + pytest.skip(allow_module_level=True) + return weights + + +@pytest.fixture +def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings: + model = InformativeDrawings(device=test_device) + model.load_state_dict(load_from_safetensors(informative_drawings_weights)) + return model + + +@torch.no_grad() +def test_preprocessor_informative_drawing( + informative_drawings_model: InformativeDrawings, + cutecat_init: Image.Image, + 
diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py
new file mode 100644
index 0000000..21b9355
--- /dev/null
+++ b/tests/e2e/test_preprocessors.py
@@ -0,0 +1,56 @@
+import torch
+import pytest
+
+from warnings import warn
+from PIL import Image
+from pathlib import Path
+
+from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
+from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
+
+from tests.utils import ensure_similar_images
+
+
+@pytest.fixture(scope="module")
+def diffusion_ref_path(test_e2e_path: Path) -> Path:
+    return test_e2e_path / "test_diffusion_ref"
+
+
+@pytest.fixture(scope="module")
+def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
+    return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
+
+
+@pytest.fixture
+def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
+    return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
+
+
+@pytest.fixture(scope="module")
+def informative_drawings_weights(test_weights_path: Path) -> Path:
+    weights = test_weights_path / "informative-drawings.safetensors"
+    if not weights.is_file():
+        warn(f"could not find weights at {test_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+    return weights
+
+
+@pytest.fixture
+def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
+    model = InformativeDrawings(device=test_device)
+    model.load_state_dict(load_from_safetensors(informative_drawings_weights))
+    return model
+
+
+@torch.no_grad()
+def test_preprocessor_informative_drawing(
+    informative_drawings_model: InformativeDrawings,
+    cutecat_init: Image.Image,
+    expected_image_informative_drawings: Image.Image,
+    test_device: torch.device,
+):
+    in_tensor = image_to_tensor(cutecat_init.convert("RGB"), device=test_device)
+    out_tensor = informative_drawings_model(in_tensor)
+    rgb_tensor = out_tensor.repeat(1, 3, 1, 1)  # grayscale to RGB
+    image = tensor_to_image(rgb_tensor)
+    ensure_similar_images(image, expected_image_informative_drawings)
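
Note: the conversion script's default `--output-file` name, `informative-drawings.safetensors`, is exactly what the `informative_drawings_weights` fixture looks for under `test_weights_path`, so converted weights can be dropped there as-is; when the file is absent the test is skipped with a warning rather than failed.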