mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add InformativeDrawings
https://github.com/carolineec/informative-drawings This is the preprocessor for the Lineart ControlNet.
This commit is contained in:
parent
e10f761a84
commit
97b162d9a0
60
scripts/convert-informative-drawings-weights.py
Normal file
60
scripts/convert-informative-drawings-weights.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
|
||||||
|
# Code is at https://github.com/carolineec/informative-drawings
|
||||||
|
# Copy `model.py` in your `PYTHONPATH`. You can edit it to remove un-necessary code
|
||||||
|
# and imports if you want, we only need `Generator`.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from refiners.fluxion.utils import (
|
||||||
|
create_state_dict_mapping,
|
||||||
|
convert_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||||
|
from model import Generator
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(checkpoint: str, device: torch.device) -> dict[str, torch.Tensor]:
|
||||||
|
src_model = Generator(3, 1, 3)
|
||||||
|
src_model.load_state_dict(torch.load(checkpoint, map_location=device))
|
||||||
|
src_model.eval()
|
||||||
|
|
||||||
|
dst_model = InformativeDrawings()
|
||||||
|
|
||||||
|
x = torch.randn(1, 3, 512, 512)
|
||||||
|
|
||||||
|
mapping = create_state_dict_mapping(src_model, dst_model, [x])
|
||||||
|
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||||
|
return {k: v.half() for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--from",
|
||||||
|
type=str,
|
||||||
|
dest="source",
|
||||||
|
required=False,
|
||||||
|
default="model2.pth",
|
||||||
|
help="Source model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="informative-drawings.safetensors",
|
||||||
|
help="Path for the output file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
tensors = convert(args.source, device)
|
||||||
|
save_file(tensors, args.output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,5 +1,5 @@
|
||||||
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid
|
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid
|
||||||
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
|
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d, InstanceNorm2d
|
||||||
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
||||||
from refiners.fluxion.layers.basics import (
|
from refiners.fluxion.layers.basics import (
|
||||||
Identity,
|
Identity,
|
||||||
|
@ -28,9 +28,10 @@ from refiners.fluxion.layers.chain import (
|
||||||
Breakpoint,
|
Breakpoint,
|
||||||
Concatenate,
|
Concatenate,
|
||||||
)
|
)
|
||||||
from refiners.fluxion.layers.conv import Conv2d
|
from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
|
||||||
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||||
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
||||||
|
from refiners.fluxion.layers.padding import ReflectionPad2d
|
||||||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||||
from refiners.fluxion.layers.embedding import Embedding
|
from refiners.fluxion.layers.embedding import Embedding
|
||||||
|
|
||||||
|
@ -39,6 +40,7 @@ __all__ = [
|
||||||
"LayerNorm",
|
"LayerNorm",
|
||||||
"GroupNorm",
|
"GroupNorm",
|
||||||
"LayerNorm2d",
|
"LayerNorm2d",
|
||||||
|
"InstanceNorm2d",
|
||||||
"GeLU",
|
"GeLU",
|
||||||
"GLU",
|
"GLU",
|
||||||
"SiLU",
|
"SiLU",
|
||||||
|
@ -72,6 +74,7 @@ __all__ = [
|
||||||
"Breakpoint",
|
"Breakpoint",
|
||||||
"Concatenate",
|
"Concatenate",
|
||||||
"Conv2d",
|
"Conv2d",
|
||||||
|
"ConvTranspose2d",
|
||||||
"Linear",
|
"Linear",
|
||||||
"MultiLinear",
|
"MultiLinear",
|
||||||
"Downsample",
|
"Downsample",
|
||||||
|
@ -80,4 +83,5 @@ __all__ = [
|
||||||
"WeightedModule",
|
"WeightedModule",
|
||||||
"ContextModule",
|
"ContextModule",
|
||||||
"Interpolate",
|
"Interpolate",
|
||||||
|
"ReflectionPad2d",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,19 +1,18 @@
|
||||||
from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
|
from torch import nn, device as Device, dtype as DType
|
||||||
from torch import device as Device, dtype as DType
|
|
||||||
from refiners.fluxion.layers.module import WeightedModule
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
|
||||||
|
|
||||||
class Conv2d(_Conv2d, WeightedModule):
|
class Conv2d(nn.Conv2d, WeightedModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
kernel_size: int | tuple[int, ...],
|
kernel_size: int | tuple[int, int],
|
||||||
stride: int | tuple[int, ...] = 1,
|
stride: int | tuple[int, int] = 1,
|
||||||
padding: int | tuple[int, ...] | str = 0,
|
padding: int | tuple[int, int] | str = 0,
|
||||||
dilation: int | tuple[int, ...] = 1,
|
|
||||||
groups: int = 1,
|
groups: int = 1,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
|
dilation: int | tuple[int, int] = 1,
|
||||||
padding_mode: str = "zeros",
|
padding_mode: str = "zeros",
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
|
@ -31,26 +30,19 @@ class Conv2d(_Conv2d, WeightedModule):
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.padding = (padding,) if isinstance(padding, int) else padding
|
|
||||||
self.dilation = (dilation,) if isinstance(dilation, int) else dilation
|
|
||||||
self.groups = groups
|
|
||||||
self.use_bias = use_bias
|
|
||||||
self.padding_mode = padding_mode
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(_Conv1d, WeightedModule):
|
class Conv1d(nn.Conv1d, WeightedModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
kernel_size: int | tuple[int, ...],
|
kernel_size: int | tuple[int],
|
||||||
stride: int | tuple[int, ...] = 1,
|
stride: int | tuple[int] = 1,
|
||||||
padding: int | tuple[int, ...] | str = 0,
|
padding: int | tuple[int] | str = 0,
|
||||||
dilation: int | tuple[int, ...] = 1,
|
|
||||||
groups: int = 1,
|
groups: int = 1,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
|
dilation: int | tuple[int] = 1,
|
||||||
padding_mode: str = "zeros",
|
padding_mode: str = "zeros",
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
|
@ -68,6 +60,35 @@ class Conv1d(_Conv1d, WeightedModule):
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.use_bias = use_bias
|
class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int | tuple[int, int],
|
||||||
|
stride: int | tuple[int, int] = 1,
|
||||||
|
padding: int | tuple[int, int] = 0,
|
||||||
|
output_padding: int | tuple[int, int] = 0,
|
||||||
|
groups: int = 1,
|
||||||
|
use_bias: bool = True,
|
||||||
|
dilation: int | tuple[int, int] = 1,
|
||||||
|
padding_mode: str = "zeros",
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=use_bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
|
from torch import nn, ones, zeros, Tensor, sqrt, device as Device, dtype as DType
|
||||||
from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from refiners.fluxion.layers.module import WeightedModule
|
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(_LayerNorm, WeightedModule):
|
class LayerNorm(nn.LayerNorm, WeightedModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
normalized_shape: int | list[int],
|
normalized_shape: int | list[int],
|
||||||
|
@ -21,7 +20,7 @@ class LayerNorm(_LayerNorm, WeightedModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm(_GroupNorm, WeightedModule):
|
class GroupNorm(nn.GroupNorm, WeightedModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
|
@ -60,8 +59,8 @@ class LayerNorm2d(WeightedModule):
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = Parameter(ones(channels, device=device, dtype=dtype))
|
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
|
||||||
self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
|
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
||||||
|
@ -70,3 +69,19 @@ class LayerNorm2d(WeightedModule):
|
||||||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||||
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceNorm2d(nn.InstanceNorm2d, Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_features: int,
|
||||||
|
eps: float = 1e-05,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__( # type: ignore
|
||||||
|
num_features=num_features,
|
||||||
|
eps=eps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
7
src/refiners/fluxion/layers/padding.py
Normal file
7
src/refiners/fluxion/layers/padding.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from torch import nn
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
|
||||||
|
|
||||||
|
class ReflectionPad2d(nn.ReflectionPad2d, Module):
|
||||||
|
def __init__(self, padding: int) -> None:
|
||||||
|
super().__init__(padding=padding)
|
|
@ -82,6 +82,9 @@ BASIC_LAYERS: list[str] = [
|
||||||
"Conv1d",
|
"Conv1d",
|
||||||
"Conv2d",
|
"Conv2d",
|
||||||
"Conv3d",
|
"Conv3d",
|
||||||
|
"ConvTranspose1d",
|
||||||
|
"ConvTranspose2d",
|
||||||
|
"ConvTranspose3d",
|
||||||
"Linear",
|
"Linear",
|
||||||
"BatchNorm1d",
|
"BatchNorm1d",
|
||||||
"BatchNorm2d",
|
"BatchNorm2d",
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Adapted from https://github.com/carolineec/informative-drawings, MIT License
|
||||||
|
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class InformativeDrawings(fl.Chain):
|
||||||
|
"""Model typically used as the preprocessor for the Lineart ControlNet.
|
||||||
|
|
||||||
|
Implements the paper "Learning to generate line drawings that convey
|
||||||
|
geometry and semantics" published in 2022 by Caroline Chan, Frédo Durand
|
||||||
|
and Phillip Isola - https://arxiv.org/abs/2203.12691
|
||||||
|
|
||||||
|
For use as a preprocessor it is recommended to use the weights for "Style 2".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 3, # RGB
|
||||||
|
out_channels: int = 1, # Grayscale
|
||||||
|
n_residual_blocks: int = 3,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
fl.Chain( # Initial convolution
|
||||||
|
fl.ReflectionPad2d(3),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=64,
|
||||||
|
kernel_size=7,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.InstanceNorm2d(64, device=device, dtype=dtype),
|
||||||
|
fl.ReLU(),
|
||||||
|
),
|
||||||
|
*( # Downsampling
|
||||||
|
fl.Chain(
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=64 * (2**i),
|
||||||
|
out_channels=128 * (2**i),
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.InstanceNorm2d(128 * (2**i), device=device, dtype=dtype),
|
||||||
|
fl.ReLU(),
|
||||||
|
)
|
||||||
|
for i in range(2)
|
||||||
|
),
|
||||||
|
*( # Residual blocks
|
||||||
|
fl.Residual(
|
||||||
|
fl.ReflectionPad2d(1),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=256,
|
||||||
|
out_channels=256,
|
||||||
|
kernel_size=3,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.InstanceNorm2d(256, device=device, dtype=dtype),
|
||||||
|
fl.ReLU(),
|
||||||
|
fl.ReflectionPad2d(1),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=256,
|
||||||
|
out_channels=256,
|
||||||
|
kernel_size=3,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.InstanceNorm2d(256, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
for _ in range(n_residual_blocks)
|
||||||
|
),
|
||||||
|
*( # Upsampling
|
||||||
|
fl.Chain(
|
||||||
|
fl.ConvTranspose2d(
|
||||||
|
in_channels=128 * (2**i),
|
||||||
|
out_channels=64 * (2**i),
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
output_padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.InstanceNorm2d(64 * (2**i), device=device, dtype=dtype),
|
||||||
|
fl.ReLU(),
|
||||||
|
)
|
||||||
|
for i in reversed(range(2))
|
||||||
|
),
|
||||||
|
fl.Chain( # Output layer
|
||||||
|
fl.ReflectionPad2d(3),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=64,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=7,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.Sigmoid(),
|
||||||
|
)
|
||||||
|
)
|
56
tests/e2e/test_preprocessors.py
Normal file
56
tests/e2e/test_preprocessors.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from warnings import warn
|
||||||
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
|
||||||
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||||
|
|
||||||
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def diffusion_ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
return test_e2e_path / "test_diffusion_ref"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(diffusion_ref_path / "cutecat_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(diffusion_ref_path / "cutecat_guide_lineart.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def informative_drawings_weights(test_weights_path: Path) -> Path:
|
||||||
|
weights = test_weights_path / "informative-drawings.safetensors"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {test_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
|
||||||
|
model = InformativeDrawings(device=test_device)
|
||||||
|
model.load_state_dict(load_from_safetensors(informative_drawings_weights))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_preprocessor_informative_drawing(
|
||||||
|
informative_drawings_model: InformativeDrawings,
|
||||||
|
cutecat_init: Image.Image,
|
||||||
|
expected_image_informative_drawings: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
in_tensor = image_to_tensor(cutecat_init.convert("RGB"), device=test_device)
|
||||||
|
out_tensor = informative_drawings_model(in_tensor)
|
||||||
|
rgb_tensor = out_tensor.repeat(1, 3, 1, 1) # grayscale to RGB
|
||||||
|
image = tensor_to_image(rgb_tensor)
|
||||||
|
ensure_similar_images(image, expected_image_informative_drawings)
|
Loading…
Reference in a new issue