implement SDXL + e2e test on random init

This commit is contained in:
limiteinductive 2023-09-06 18:43:02 +02:00 committed by Benjamin Trom
parent 02af8e9f0b
commit 2786117469
4 changed files with 183 additions and 1 deletions

View file

@ -1,8 +1,9 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
__all__ = [ __all__ = [
"SDXLUNet", "SDXLUNet",
"DoubleTextEncoder", "DoubleTextEncoder",
"StableDiffusion_XL",
] ]

View file

@ -0,0 +1,96 @@
import torch
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from torch import device as Device, dtype as DType, Tensor
class SDXLAutoencoder(LatentDiffusionAutoencoder):
encoder_scale: float = 0.13025
class StableDiffusion_XL(LatentDiffusionModel):
unet: SDXLUNet
clip_text_encoder: DoubleTextEncoder
def __init__(
self,
unet: SDXLUNet | None = None,
lda: SDXLAutoencoder | None = None,
clip_text_encoder: DoubleTextEncoder | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
)
# TODO: when negative_text is None, use zero tensor?
negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
)
@property
def default_time_ids(self) -> Tensor:
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
return time_ids.repeat(2, 1)
def set_unet_context(
self,
*,
timestep: Tensor,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**_: Tensor
) -> None:
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
self.unet.set_time_ids(time_ids=time_ids)
def forward(
self,
x: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
condition_scale: float = 5.0,
**kwargs: Tensor
) -> Tensor:
return super().forward(
x=x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=condition_scale,
**kwargs
)

View file

@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
@ -73,6 +74,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB") return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data( def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@ -326,6 +332,51 @@ def sd15_ddim_lda_ft_mse(
return sd15 return sd15
@pytest.fixture
def sdxl_lda_weights(test_weights_path: Path) -> Path:
sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
if not sdxl_lda_weights.is_file():
warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_lda_weights
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
if not sdxl_unet_weights.is_file():
warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_unet_weights
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
if not sdxl_double_text_encoder_weights.is_file():
warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
pytest.skip(allow_module_level=True)
return sdxl_double_text_encoder_weights
@pytest.fixture
def sdxl_ddim(
sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
if test_device.type == "cpu":
warn(message="not running on CPU, skipping")
pytest.skip()
scheduler = DDIM(num_inference_steps=30)
sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
return sdxl
@torch.no_grad() @torch.no_grad()
def test_diffusion_std_random_init( def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@ -957,3 +1008,37 @@ def test_diffusion_ip_adapter(
predicted_image = sd15.lda.decode_latents(x) predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman) ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_ddim_random_init
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
manual_seed(seed=2)
x = torch.randn(1, 4, 128, 128, device=test_device)
for step in sdxl.steps:
x = sdxl(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x=x)
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB