add IP-Adapter support for SD 1.5

Official repo: https://github.com/tencent-ailab/IP-Adapter
This commit is contained in:
Cédric Deltheil 2023-09-06 12:23:53 +02:00 committed by Cédric Deltheil
parent d4dd45fd4d
commit c55917e293
8 changed files with 579 additions and 1 deletions

View file

@ -0,0 +1,111 @@
from pathlib import Path
from typing import Any
import argparse
import torch
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1IPAdapter
from refiners.fluxion.utils import save_to_safetensors
def main() -> None:
parser = argparse.ArgumentParser(description="Converts a IP-Adapter diffusers model to refiners.")
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="ip-adapter_sd15.bin",
help="Path to the source model. (default: 'ip-adapter_sd15.bin').",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="ip-adapter_sd15.safetensors",
help="Path to save the converted model. (default: 'ip-adapter_sd15.safetensors').",
)
parser.add_argument("--verbose", action="store_true", dest="verbose")
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
unet = SD1UNet(in_channels=4)
ip_adapter = SD1IPAdapter(target=unet)
# Manual conversion to avoid any runtime dependency on IP-Adapter[1] custom classes
# [1]: https://github.com/tencent-ailab/IP-Adapter
state_dict: dict[str, torch.Tensor] = {}
image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor] = {
"Linear.weight": image_proj_weights["proj.weight"],
"Linear.bias": image_proj_weights["proj.bias"],
"LayerNorm.weight": image_proj_weights["norm.weight"],
"LayerNorm.bias": image_proj_weights["norm.bias"],
}
ip_adapter.image_proj.load_state_dict(state_dict=image_proj_state_dict)
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2
# Running:
#
# from diffusers import UNet2DConditionModel
# unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# for k in unet.attn_processors.keys():
# print(k)
#
# Gives:
#
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor
# down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor
# ...
# down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
# up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor
# ...
# up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn2.processor
#
# With attn1=self-attention and attn2=cross-attention, and middle block in last position. So in terms of increasing
# indices:
#
# DownBlocks -> [1, 3, 5, 7, 9, 11]
# MiddleBlock -> [31]
# UpBlocks -> [13, 15, 17, 19, 21, 23, 25, 27, 29]
cross_attn_mapping: list[int] = [1, 3, 5, 7, 9, 11, 31, 13, 15, 17, 19, 21, 23, 25, 27, 29]
for i, cross_attn in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i]
k_ip = f"{cross_attn_index}.to_k_ip.weight"
v_ip = f"{cross_attn_index}.to_v_ip.weight"
# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
cross_attn_state_dict: dict[str, Any] = {
cross_attn.get_parameter_name("wk_prime"): ip_adapter_weights[k_ip],
cross_attn.get_parameter_name("wv_prime"): ip_adapter_weights[v_ip],
}
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
for k, v in cross_attn_state_dict.items():
state_dict[f"ip_adapter.{i:03d}.{k}"] = v
if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()

View file

@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
StableDiffusion_1_Inpainting,
SD1UNet,
SD1ControlnetAdapter,
SD1IPAdapter,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
SDXLUNet,
@ -22,6 +23,7 @@ __all__ = [
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"DPMSolver",

View file

@ -4,10 +4,12 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
]

View file

@ -0,0 +1,348 @@
from enum import IntEnum
from pathlib import Path
from functools import partial
from torch import Tensor, as_tensor, cat, zeros_like, device as Device, dtype as DType
from PIL import Image
from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors
import refiners.fluxion.layers as fl
class ImageProjection(fl.Chain):
structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]
def __init__(
self,
clip_image_embedding_dim: int = 1024,
clip_text_embedding_dim: int = 768,
sequence_length: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.clip_image_embedding_dim = clip_image_embedding_dim
self.clip_text_embedding_dim = clip_text_embedding_dim
self.sequence_length = sequence_length
super().__init__(
fl.Linear(
in_features=clip_image_embedding_dim,
out_features=clip_text_embedding_dim * sequence_length,
device=device,
dtype=dtype,
),
fl.Reshape(sequence_length, clip_text_embedding_dim),
fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
)
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
# Fluxion's Attention layer drop-in replacement implementing Decoupled Cross-Attention
class IPAttention(fl.Chain):
structural_attrs = [
"embedding_dim",
"text_sequence_length",
"image_sequence_length",
"scale",
"num_heads",
"heads_dim",
"key_embedding_dim",
"value_embedding_dim",
"inner_dim",
"use_bias",
"is_causal",
]
def __init__(
self,
embedding_dim: int,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
num_heads: int = 1,
key_embedding_dim: int | None = None,
value_embedding_dim: int | None = None,
inner_dim: int | None = None,
use_bias: bool = True,
is_causal: bool | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert (
embedding_dim % num_heads == 0
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
self.embedding_dim = embedding_dim
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
self.num_heads = num_heads
self.heads_dim = embedding_dim // num_heads
self.key_embedding_dim = key_embedding_dim or embedding_dim
self.value_embedding_dim = value_embedding_dim or embedding_dim
self.inner_dim = inner_dim or embedding_dim
self.use_bias = use_bias
self.is_causal = is_causal
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
fl.Linear(
in_features=self.embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length),
fl.Linear(
in_features=self.key_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
fl.Linear(
in_features=self.key_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length),
fl.Linear(
in_features=self.key_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
fl.Linear(
in_features=self.key_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal),
fl.Lambda(func=self.scale_outputs),
),
),
fl.Linear(
in_features=self.inner_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
with self.setup_adapter(target):
super().__init__(
IPAttention(
embedding_dim=target.embedding_dim,
text_sequence_length=text_sequence_length,
image_sequence_length=image_sequence_length,
scale=scale,
num_heads=target.num_heads,
key_embedding_dim=target.key_embedding_dim,
value_embedding_dim=target.value_embedding_dim,
inner_dim=target.inner_dim,
use_bias=target.use_bias,
is_causal=target.is_causal,
device=target.device,
dtype=target.dtype,
)
)
def get_parameter_name(self, matrix: str, bias: bool = False) -> str:
match matrix:
case "wq":
index = 0
case "wk":
index = 1
case "wk_prime":
index = 2
case "wv":
index = 3
case "wv_prime":
index = 4
case "proj":
index = 5
case _:
raise ValueError(f"Unexpected matrix name {matrix}")
linear = list(self.IPAttention.layers(fl.Linear))[index]
param = getattr(linear, "bias" if bias else "weight")
name = next((n for n, p in self.named_parameters() if id(p) == id(param)), None)
assert name is not None
return name
class SD1IPAdapter(fl.Chain, Adapter[SD1UNet]):
def __init__(
self,
target: SD1UNet,
clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
self.image_proj = ImageProjection(device=target.device, dtype=target.dtype)
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
if weights is not None:
image_proj_state_dict: dict[str, Tensor] = {
k.removeprefix("image_proj."): v for k, v in weights.items() if k.startswith("image_proj.")
}
self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
# Retrieve original (frozen) cross-attention weights
# Note: this assumes the target UNet has already loaded weights
cross_attn_linears = list(cross_attn.target.layers(fl.Linear))
assert len(cross_attn_linears) == 4 # Wq, Wk, Wv and Proj
cross_attn_state_dict[cross_attn.get_parameter_name("wq")] = cross_attn_linears[0].weight
cross_attn_state_dict[cross_attn.get_parameter_name("wk")] = cross_attn_linears[1].weight
cross_attn_state_dict[cross_attn.get_parameter_name("wv")] = cross_attn_linears[2].weight
cross_attn_state_dict[cross_attn.get_parameter_name("proj")] = cross_attn_linears[3].weight
cross_attn_state_dict[cross_attn.get_parameter_name("proj", bias=True)] = cross_attn_linears[3].bias
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
@classmethod
def from_safetensors(
cls,
target: SD1UNet,
checkpoint_path: Path | str,
clip_image_encoder: CLIPImageEncoderH | None = None,
scale: float = 1.0,
):
weights = load_from_safetensors(checkpoint_path, device=target.device if target.device is not None else "cpu")
return cls(
target=target,
clip_image_encoder=clip_image_encoder,
scale=scale,
weights=weights,
)
def inject(self: "SD1IPAdapter", parent: fl.Chain | None = None) -> "SD1IPAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor | None) -> Tensor:
clip_embedding = self.clip_image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
negative_embedding = self.image_proj(zeros_like(clip_embedding))
return cat((negative_embedding, conditional_embedding))
def preprocess_image(
self,
image: Image.Image,
size: tuple[int, int] = (224, 224),
mean: list[float] | None = None,
std: list[float] | None = None,
) -> Tensor:
# Default mean and std are parameters from https://github.com/openai/CLIP
return self._normalize(
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
)
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
@staticmethod
def _normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
assert tensor.is_floating_point()
assert tensor.ndim >= 3
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean_tensor = as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
std_tensor = as_tensor(std, dtype=tensor.dtype, device=tensor.device)
if (std_tensor == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean_tensor.ndim == 1:
mean_tensor = mean_tensor.view(-1, 1, 1)
if std_tensor.ndim == 1:
std_tensor = std_tensor.view(-1, 1, 1)
return tensor.sub_(mean_tensor).div_(std_tensor)

View file

@ -13,6 +13,7 @@ from refiners.foundationals.latent_diffusion import (
StableDiffusion_1_Inpainting,
SD1UNet,
SD1ControlnetAdapter,
SD1IPAdapter,
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM
@ -42,6 +43,11 @@ def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")
@pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "woman.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@ -62,6 +68,11 @@ def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
@pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@ -182,6 +193,33 @@ def unet_weights_inpainting(test_weights_path: Path) -> Path:
return unet_weights_inpainting
@pytest.fixture(scope="module")
def lda_ft_mse_weights(test_weights_path: Path) -> Path:
lda_weights = test_weights_path / "lda_ft_mse.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return lda_weights
@pytest.fixture(scope="module")
def ip_adapter_weights(test_weights_path: Path) -> Path:
ip_adapter_weights = test_weights_path / "ip-adapter_sd15.safetensors"
if not ip_adapter_weights.is_file():
warn(f"could not find weights at {ip_adapter_weights}, skipping")
pytest.skip(allow_module_level=True)
return ip_adapter_weights
@pytest.fixture(scope="module")
def image_encoder_weights(test_weights_path: Path) -> Path:
image_encoder_weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not image_encoder_weights.is_file():
warn(f"could not find weights at {image_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return image_encoder_weights
@pytest.fixture
def sd15_std(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@ -270,6 +308,24 @@ def sd15_ddim(
return sd15
@pytest.fixture
def sd15_ddim_lda_ft_mse(
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_ft_mse_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
return sd15
@torch.no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@ -844,3 +900,60 @@ def test_diffusion_textual_inversion_random_init(
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
image_encoder_weights: Path,
woman_image: Image.Image,
expected_image_ip_adapter_woman: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
n_steps = 50
# See tencent-ailab/IP-Adapter best practices section:
#
# If you only use the image prompt, you can set the scale=1.0 and text_prompt="" (or some generic text
# prompts, e.g. "best quality", you can also use any negative text prompt).
#
# The prompts below are the ones used by default by IPAdapter's generate method if none are specified
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
ip_adapter = SD1IPAdapter.from_safetensors(target=sd15.unet, checkpoint_path=ip_adapter_weights)
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
sd15.set_num_inference_steps(n_steps)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)

View file

@ -35,7 +35,7 @@ output.images[0].save("std_random_init_expected.png")
Special cases:
- `expected_refonly.png` has been generated [with Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
- `expected_inpainting_refonly.png` has been generated with refiners itself (and inspected so that it looks reasonable).
- `expected_inpainting_refonly.png`, `expected_image_ip_adapter_woman.png` have been generated with refiners itself (and inspected so that it looks reasonable).
## Other images
@ -59,6 +59,8 @@ Special cases:
- `inpainting-scene.png`: cropped-to-square-and-resized version of https://unsplash.com/photos/RCz6eSVPGYU by @jannerboy62
- `inpainting-target.png`: computed with `convert <(convert -size 512x512 xc:white png:-) kitchen_dog.png <(convert inpainting-mask.png -negate png:-) -compose Over -composite inpainting-target.png`
- `woman.png` [comes from tencent-ailab/IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/blob/8b96670cc5c8ef00278b42c0c7b62fe8a74510b9/assets/images/woman.png).
## VAE without randomness
```diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB