mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
add InformativeDrawings
https://github.com/carolineec/informative-drawings This is the preprocessor for the Lineart ControlNet.
This commit is contained in:
parent
e10f761a84
commit
97b162d9a0
60
scripts/convert-informative-drawings-weights.py
Normal file
60
scripts/convert-informative-drawings-weights.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
|
||||
# Code is at https://github.com/carolineec/informative-drawings
|
||||
# Copy `model.py` into a directory on your `PYTHONPATH`. You can edit it to remove unnecessary code
|
||||
# and imports if you want, we only need `Generator`.
|
||||
|
||||
import torch
|
||||
|
||||
from safetensors.torch import save_file
|
||||
from refiners.fluxion.utils import (
|
||||
create_state_dict_mapping,
|
||||
convert_state_dict,
|
||||
)
|
||||
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
from model import Generator
|
||||
|
||||
|
||||
@torch.no_grad()
def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]:
    """Convert an original Informative Drawings checkpoint into an `InformativeDrawings` state dict.

    Args:
        checkpoint: Path of the original PyTorch checkpoint to load into `Generator`.
        device: Forwarded to `torch.load` as `map_location`. Accepts a `torch.device`
            or a device string such as "cpu" or "cuda".

    Returns:
        The converted state dict, with every tensor cast to half precision.

    Raises:
        RuntimeError: If no mapping between the two models is produced.
    """
    # NOTE(review): `torch.load` unpickles the file, so only convert checkpoints
    # that come from a trusted source.
    src_model = Generator(3, 1, 3)
    src_model.load_state_dict(torch.load(checkpoint, map_location=device))
    src_model.eval()

    dst_model = InformativeDrawings()

    # Dummy input handed to `create_state_dict_mapping` together with both models.
    x = torch.randn(1, 3, 512, 512)

    mapping = create_state_dict_mapping(src_model, dst_model, [x])
    # Presumably a failed mapping is signalled by `None` (confirm against
    # `create_state_dict_mapping`); fail with a clear message in that case.
    if mapping is None:
        raise RuntimeError("Model conversion failed: could not map source weights to InformativeDrawings")

    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
    return {k: v.half() for k, v in state_dict.items()}
|
||||
|
||||
|
||||
def main():
    """Parse the command line, convert the source checkpoint and save it as a safetensors file."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--from",
        dest="source",
        default="model2.pth",
        required=False,
        type=str,
        help="Source model",
    )
    cli.add_argument(
        "--output-file",
        default="informative-drawings.safetensors",
        required=False,
        type=str,
        help="Path for the output file",
    )
    options = cli.parse_args()

    # Use CUDA when it is available, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    save_file(convert(options.source, device), options.output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,5 +1,5 @@
|
|||
from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid
|
||||
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d
|
||||
from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d, InstanceNorm2d
|
||||
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
||||
from refiners.fluxion.layers.basics import (
|
||||
Identity,
|
||||
|
@ -28,9 +28,10 @@ from refiners.fluxion.layers.chain import (
|
|||
Breakpoint,
|
||||
Concatenate,
|
||||
)
|
||||
from refiners.fluxion.layers.conv import Conv2d
|
||||
from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
|
||||
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
||||
from refiners.fluxion.layers.padding import ReflectionPad2d
|
||||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||
from refiners.fluxion.layers.embedding import Embedding
|
||||
|
||||
|
@ -39,6 +40,7 @@ __all__ = [
|
|||
"LayerNorm",
|
||||
"GroupNorm",
|
||||
"LayerNorm2d",
|
||||
"InstanceNorm2d",
|
||||
"GeLU",
|
||||
"GLU",
|
||||
"SiLU",
|
||||
|
@ -72,6 +74,7 @@ __all__ = [
|
|||
"Breakpoint",
|
||||
"Concatenate",
|
||||
"Conv2d",
|
||||
"ConvTranspose2d",
|
||||
"Linear",
|
||||
"MultiLinear",
|
||||
"Downsample",
|
||||
|
@ -80,4 +83,5 @@ __all__ = [
|
|||
"WeightedModule",
|
||||
"ContextModule",
|
||||
"Interpolate",
|
||||
"ReflectionPad2d",
|
||||
]
|
||||
|
|
|
@ -1,19 +1,18 @@
|
|||
from torch.nn import Conv2d as _Conv2d, Conv1d as _Conv1d
|
||||
from torch import device as Device, dtype as DType
|
||||
from torch import nn, device as Device, dtype as DType
|
||||
from refiners.fluxion.layers.module import WeightedModule
|
||||
|
||||
|
||||
class Conv2d(_Conv2d, WeightedModule):
|
||||
class Conv2d(nn.Conv2d, WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] | str = 0,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
kernel_size: int | tuple[int, int],
|
||||
stride: int | tuple[int, int] = 1,
|
||||
padding: int | tuple[int, int] | str = 0,
|
||||
groups: int = 1,
|
||||
use_bias: bool = True,
|
||||
dilation: int | tuple[int, int] = 1,
|
||||
padding_mode: str = "zeros",
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
|
@ -31,26 +30,19 @@ class Conv2d(_Conv2d, WeightedModule):
|
|||
device,
|
||||
dtype,
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.padding = (padding,) if isinstance(padding, int) else padding
|
||||
self.dilation = (dilation,) if isinstance(dilation, int) else dilation
|
||||
self.groups = groups
|
||||
self.use_bias = use_bias
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
|
||||
class Conv1d(_Conv1d, WeightedModule):
|
||||
class Conv1d(nn.Conv1d, WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] | str = 0,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
kernel_size: int | tuple[int],
|
||||
stride: int | tuple[int] = 1,
|
||||
padding: int | tuple[int] | str = 0,
|
||||
groups: int = 1,
|
||||
use_bias: bool = True,
|
||||
dilation: int | tuple[int] = 1,
|
||||
padding_mode: str = "zeros",
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
|
@ -68,6 +60,35 @@ class Conv1d(_Conv1d, WeightedModule):
|
|||
device,
|
||||
dtype,
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_bias = use_bias
|
||||
|
||||
|
||||
class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
    """`torch.nn.ConvTranspose2d` combined with `WeightedModule`.

    Every constructor argument is forwarded to `nn.ConvTranspose2d`;
    `use_bias` is passed on under the name `bias`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] = 1,
        padding: int | tuple[int, int] = 0,
        output_padding: int | tuple[int, int] = 0,
        groups: int = 1,
        use_bias: bool = True,
        dilation: int | tuple[int, int] = 1,
        padding_mode: str = "zeros",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # Required arguments go positionally, the optional ones by keyword in
        # the same order as the signature above.
        super().__init__(  # type: ignore
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=use_bias,
            dilation=dilation,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from torch import ones, zeros, Tensor, sqrt, device as Device, dtype as DType
|
||||
from torch.nn import GroupNorm as _GroupNorm, Parameter, LayerNorm as _LayerNorm
|
||||
from torch import nn, ones, zeros, Tensor, sqrt, device as Device, dtype as DType
|
||||
from jaxtyping import Float
|
||||
from refiners.fluxion.layers.module import WeightedModule
|
||||
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||
|
||||
|
||||
class LayerNorm(_LayerNorm, WeightedModule):
|
||||
class LayerNorm(nn.LayerNorm, WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: int | list[int],
|
||||
|
@ -21,7 +20,7 @@ class LayerNorm(_LayerNorm, WeightedModule):
|
|||
)
|
||||
|
||||
|
||||
class GroupNorm(_GroupNorm, WeightedModule):
|
||||
class GroupNorm(nn.GroupNorm, WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
|
@ -60,8 +59,8 @@ class LayerNorm2d(WeightedModule):
|
|||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = Parameter(ones(channels, device=device, dtype=dtype))
|
||||
self.bias = Parameter(zeros(channels, device=device, dtype=dtype))
|
||||
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
|
||||
|
@ -70,3 +69,19 @@ class LayerNorm2d(WeightedModule):
|
|||
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
|
||||
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||
return x_out
|
||||
|
||||
|
||||
class InstanceNorm2d(nn.InstanceNorm2d, Module):
    """`torch.nn.InstanceNorm2d` exposed as a fluxion `Module`.

    Only `num_features`, `eps`, `device` and `dtype` can be set here; no other
    argument is passed to `nn.InstanceNorm2d`.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-05,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(num_features, eps=eps, device=device, dtype=dtype)  # type: ignore
|
||||
|
|
7
src/refiners/fluxion/layers/padding.py
Normal file
7
src/refiners/fluxion/layers/padding.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from torch import nn
|
||||
from refiners.fluxion.layers.module import Module
|
||||
|
||||
|
||||
class ReflectionPad2d(nn.ReflectionPad2d, Module):
    """`torch.nn.ReflectionPad2d` exposed as a fluxion `Module`.

    Args:
        padding: Forwarded unchanged to `nn.ReflectionPad2d`. Either a single int
            applied to every side, or a 4-tuple
            `(left, right, top, bottom)`.
    """

    def __init__(self, padding: int | tuple[int, int, int, int]) -> None:
        super().__init__(padding=padding)
|
|
@ -82,6 +82,9 @@ BASIC_LAYERS: list[str] = [
|
|||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"ConvTranspose1d",
|
||||
"ConvTranspose2d",
|
||||
"ConvTranspose3d",
|
||||
"Linear",
|
||||
"BatchNorm1d",
|
||||
"BatchNorm2d",
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
# Adapted from https://github.com/carolineec/informative-drawings, MIT License
|
||||
|
||||
from torch import device as Device, dtype as DType
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class InformativeDrawings(fl.Chain):
    """Line-drawing generator, typically used as the preprocessor for the Lineart ControlNet.

    Implements the paper "Learning to generate line drawings that convey
    geometry and semantics" published in 2022 by Caroline Chan, Frédo Durand
    and Phillip Isola - https://arxiv.org/abs/2203.12691

    When used as a preprocessor, the weights for "Style 2" are recommended.
    """

    def __init__(
        self,
        in_channels: int = 3,  # RGB
        out_channels: int = 1,  # Grayscale
        n_residual_blocks: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # The stages are built in the order in which they are chained below, so
        # sub-modules are created in the same sequence as the layers run.

        # Initial convolution: in_channels -> 64
        stem = fl.Chain(
            fl.ReflectionPad2d(3),
            fl.Conv2d(
                in_channels=in_channels,
                out_channels=64,
                kernel_size=7,
                device=device,
                dtype=dtype,
            ),
            fl.InstanceNorm2d(64, device=device, dtype=dtype),
            fl.ReLU(),
        )

        # Downsampling: 64 -> 128 -> 256 channels, stride 2 at each step
        downsamplers = [
            fl.Chain(
                fl.Conv2d(
                    in_channels=width // 2,
                    out_channels=width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
                fl.InstanceNorm2d(width, device=device, dtype=dtype),
                fl.ReLU(),
            )
            for width in (128, 256)
        ]

        # Residual blocks, all operating on 256 channels
        residual_blocks = [
            fl.Residual(
                fl.ReflectionPad2d(1),
                fl.Conv2d(
                    in_channels=256,
                    out_channels=256,
                    kernel_size=3,
                    device=device,
                    dtype=dtype,
                ),
                fl.InstanceNorm2d(256, device=device, dtype=dtype),
                fl.ReLU(),
                fl.ReflectionPad2d(1),
                fl.Conv2d(
                    in_channels=256,
                    out_channels=256,
                    kernel_size=3,
                    device=device,
                    dtype=dtype,
                ),
                fl.InstanceNorm2d(256, device=device, dtype=dtype),
            )
            for _ in range(n_residual_blocks)
        ]

        # Upsampling: 256 -> 128 -> 64 channels, stride 2 at each step
        upsamplers = [
            fl.Chain(
                fl.ConvTranspose2d(
                    in_channels=2 * width,
                    out_channels=width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    device=device,
                    dtype=dtype,
                ),
                fl.InstanceNorm2d(width, device=device, dtype=dtype),
                fl.ReLU(),
            )
            for width in (128, 64)
        ]

        # Output layer: 64 -> out_channels, followed by a sigmoid
        head = fl.Chain(
            fl.ReflectionPad2d(3),
            fl.Conv2d(
                in_channels=64,
                out_channels=out_channels,
                kernel_size=7,
                device=device,
                dtype=dtype,
            ),
            fl.Sigmoid(),
        )

        super().__init__(stem, *downsamplers, *residual_blocks, *upsamplers, head)
|
56
tests/e2e/test_preprocessors.py
Normal file
56
tests/e2e/test_preprocessors.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
import pytest
|
||||
|
||||
from warnings import warn
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def diffusion_ref_path(test_e2e_path: Path) -> Path:
    """Return the `test_diffusion_ref` directory located under the e2e test path."""
    return test_e2e_path.joinpath("test_diffusion_ref")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def cutecat_init(diffusion_ref_path: Path) -> Image.Image:
    """Return the `cutecat_init.png` reference image converted to RGB."""
    # `Image.open` keeps the underlying file open; `convert` returns a separate
    # image, so the handle is released as soon as the conversion is done.
    with Image.open(diffusion_ref_path / "cutecat_init.png") as source:
        return source.convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
def expected_image_informative_drawings(diffusion_ref_path: Path) -> Image.Image:
    """Return the `cutecat_guide_lineart.png` reference image converted to RGB."""
    # `Image.open` keeps the underlying file open; `convert` returns a separate
    # image, so the handle is released as soon as the conversion is done.
    with Image.open(diffusion_ref_path / "cutecat_guide_lineart.png") as source:
        return source.convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def informative_drawings_weights(test_weights_path: Path) -> Path:
    """Return the path of `informative-drawings.safetensors`, skipping when it is missing.

    When the file does not exist under `test_weights_path`, a warning is emitted
    and the test is skipped with an explicit reason.
    """
    weights = test_weights_path / "informative-drawings.safetensors"
    if not weights.is_file():
        warn(f"could not find weights at {test_weights_path}, skipping")
        # Give the skip a reason so the test report says which file is missing.
        pytest.skip(f"missing weights file: {weights}", allow_module_level=True)
    return weights
|
||||
|
||||
|
||||
@pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
    """Build `InformativeDrawings` on the test device and load the safetensors weights into it."""
    lineart_model = InformativeDrawings(device=test_device)
    weights_state_dict = load_from_safetensors(informative_drawings_weights)
    lineart_model.load_state_dict(weights_state_dict)
    return lineart_model
|
||||
|
||||
|
||||
@torch.no_grad()
def test_preprocessor_informative_drawing(
    informative_drawings_model: InformativeDrawings,
    cutecat_init: Image.Image,
    expected_image_informative_drawings: Image.Image,
    test_device: torch.device,
):
    """Run the model on the cat image and compare its output with the reference line art."""
    source_image = cutecat_init.convert("RGB")
    lineart = informative_drawings_model(image_to_tensor(source_image, device=test_device))

    # grayscale to RGB: repeat the channel dimension three times
    predicted = tensor_to_image(lineart.repeat(1, 3, 1, 1))

    ensure_similar_images(predicted, expected_image_informative_drawings)
|
Loading…
Reference in a new issue