diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
index 4f11099..e3500c6 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
@@ -1,8 +1,9 @@
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
-
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 __all__ = [
     "SDXLUNet",
     "DoubleTextEncoder",
+    "StableDiffusion_XL",
 ]
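With this re-export, StableDiffusion_XL joins SDXLUNet and DoubleTextEncoder in the subpackage's public API, so callers no longer need to reach into the model module. A minimal sketch of the resulting import surface (commentary, not part of the patch):

    # StableDiffusion_XL can now be imported from the subpackage root
    # rather than from ...stable_diffusion_xl.model directly.
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
        DoubleTextEncoder,
        SDXLUNet,
        StableDiffusion_XL,
    )
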
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
new file mode 100644
index 0000000..0e4c783
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -0,0 +1,96 @@
+import torch
+from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
+from torch import device as Device, dtype as DType, Tensor
+
+
+class SDXLAutoencoder(LatentDiffusionAutoencoder):
+    encoder_scale: float = 0.13025
+
+
+class StableDiffusion_XL(LatentDiffusionModel):
+    unet: SDXLUNet
+    clip_text_encoder: DoubleTextEncoder
+
+    def __init__(
+        self,
+        unet: SDXLUNet | None = None,
+        lda: SDXLAutoencoder | None = None,
+        clip_text_encoder: DoubleTextEncoder | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        unet = unet or SDXLUNet(in_channels=4)
+        lda = lda or SDXLAutoencoder()
+        clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
+        scheduler = scheduler or DDIM(num_inference_steps=30)
+
+        super().__init__(
+            unet=unet,
+            lda=lda,
+            clip_text_encoder=clip_text_encoder,
+            scheduler=scheduler,
+            device=device,
+            dtype=dtype,
+        )
+
+    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
+        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
+        if text == negative_text:
+            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
+                tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
+            )
+
+        # TODO: when negative_text is None, use zero tensor?
+        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
+
+        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
+            tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
+        )
+
+    @property
+    def default_time_ids(self) -> Tensor:
+        # [original_height, original_width, crop_top, crop_left, target_height, target_width]
+        # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
+        time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
+        return time_ids.repeat(2, 1)
+
+    def set_unet_context(
+        self,
+        *,
+        timestep: Tensor,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor,
+        **_: Tensor
+    ) -> None:
+        self.unet.set_timestep(timestep=timestep)
+        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
+        self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
+        self.unet.set_time_ids(time_ids=time_ids)
+
+    def forward(
+        self,
+        x: Tensor,
+        step: int,
+        *,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor,
+        condition_scale: float = 5.0,
+        **kwargs: Tensor
+    ) -> Tensor:
+        return super().forward(
+            x=x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=condition_scale,
+            **kwargs
+        )
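Two conventions in model.py deserve a note: compute_clip_text_embedding returns the negative and positive embeddings concatenated along the batch dimension (negative first), which is the layout the condition_scale guidance in the base LatentDiffusionModel consumes, and default_time_ids repeats the paper's [h_orig, w_orig, c_top, c_left, h_tgt, w_tgt] vector once per guidance branch, hence repeat(2, 1). Below is a minimal inference sketch against this API; it mirrors the e2e test added later in this patch, weight loading is elided, and the CUDA device and step count are assumptions:

    import torch
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL

    sdxl = StableDiffusion_XL(device="cuda")  # weight loading elided; see the sdxl_ddim fixture below

    with torch.no_grad():
        # Negative then positive embeddings, stacked on the batch dimension.
        clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
            text="a cute cat", negative_text="lowres, bad anatomy"
        )
        time_ids = sdxl.default_time_ids  # 1024x1024 target, no cropping, one row per branch

        sdxl.set_num_inference_steps(num_inference_steps=30)
        x = torch.randn(1, 4, 128, 128, device=sdxl.device)  # 128 = 1024 // 8

        for step in sdxl.steps:
            x = sdxl(
                x,
                step=step,
                clip_text_embedding=clip_text_embedding,
                pooled_text_embedding=pooled_text_embedding,
                time_ids=time_ids,
                condition_scale=5.0,
            )
        image = sdxl.lda.decode_latents(x=x)
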
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index a2098be..13b3ca3 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.clip.concepts import ConceptExtender
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 from tests.utils import ensure_similar_images
 
@@ -73,6 +74,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
+
+
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
 def controlnet_data(
     ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@@ -326,6 +332,51 @@ def sd15_ddim_lda_ft_mse(
     return sd15
 
 
+@pytest.fixture
+def sdxl_lda_weights(test_weights_path: Path) -> Path:
+    sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
+    if not sdxl_lda_weights.is_file():
+        warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sdxl_lda_weights
+
+
+@pytest.fixture
+def sdxl_unet_weights(test_weights_path: Path) -> Path:
+    sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
+    if not sdxl_unet_weights.is_file():
+        warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sdxl_unet_weights
+
+
+@pytest.fixture
+def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
+    sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
+    if not sdxl_double_text_encoder_weights.is_file():
+        warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sdxl_double_text_encoder_weights
+
+
+@pytest.fixture
+def sdxl_ddim(
+    sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
+) -> StableDiffusion_XL:
+    if test_device.type == "cpu":
+        warn(message="not running on CPU, skipping")
+        pytest.skip()
+
+    scheduler = DDIM(num_inference_steps=30)
+    sdxl = StableDiffusion_XL(scheduler=scheduler, device=test_device)
+
+    sdxl.clip_text_encoder.load_from_safetensors(tensors_path=sdxl_text_encoder_weights)
+    sdxl.lda.load_from_safetensors(tensors_path=sdxl_lda_weights)
+    sdxl.unet.load_from_safetensors(tensors_path=sdxl_unet_weights)
+
+    return sdxl
+
+
 @torch.no_grad()
 def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@@ -957,3 +1008,38 @@ def test_diffusion_ip_adapter(
     predicted_image = sd15.lda.decode_latents(x)
 
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
+
+
+@torch.no_grad()
+def test_sdxl_random_init(
+    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
+) -> None:
+    sdxl = sdxl_ddim
+    expected_image = expected_sdxl_ddim_random_init
+    n_steps = 30
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+
+    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
+        text=prompt, negative_text=negative_prompt
+    )
+    time_ids = sdxl.default_time_ids
+
+    sdxl.set_num_inference_steps(num_inference_steps=n_steps)
+
+    manual_seed(seed=2)
+    x = torch.randn(1, 4, 128, 128, device=test_device)
+
+    for step in sdxl.steps:
+        x = sdxl(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=5,
+        )
+    predicted_image = sdxl.lda.decode_latents(x=x)
+
+    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
diff --git a/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_ddim_random_init.png b/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_ddim_random_init.png
new file mode 100644
index 0000000..236bef5
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_cutecat_sdxl_ddim_random_init.png differ
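One sizing detail is implicit in the test: the (1, 4, 128, 128) initial latent follows from the autoencoder's 8x spatial downscaling, so the 1024x1024 target encoded in default_time_ids needs a 128x128 latent. A small helper makes the relationship explicit (hypothetical, not part of this patch):

    # Hypothetical helper, not in the patch: initial latent shape for a target
    # resolution, assuming the usual 8x spatial downscaling of the SDXL autoencoder.
    def latent_shape(height: int, width: int, batch: int = 1, channels: int = 4) -> tuple[int, int, int, int]:
        assert height % 8 == 0 and width % 8 == 0, "target resolution should be a multiple of 8"
        return (batch, channels, height // 8, width // 8)

    assert latent_shape(1024, 1024) == (1, 4, 128, 128)  # shape used in test_sdxl_random_init

With the seed fixed via manual_seed(seed=2), the run is deterministic, which is what allows comparison against the committed reference image at min_psnr=35 / min_ssim=0.98.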