mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
implement SDXL + e2e test on random init
This commit is contained in:
parent
02af8e9f0b
commit
2786117469
|
@ -1,8 +1,9 @@
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
|
||||
__all__ = [
|
||||
"SDXLUNet",
|
||||
"DoubleTextEncoder",
|
||||
"StableDiffusion_XL",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
import torch
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
from torch import device as Device, dtype as DType, Tensor
|
||||
|
||||
|
||||
class SDXLAutoencoder(LatentDiffusionAutoencoder):
    """Latent diffusion autoencoder variant used by the SDXL pipeline."""

    # NOTE(review): presumably the latent scaling factor of the SDXL VAE,
    # replacing the value used by the base autoencoder - confirm against
    # the reference configuration.
    encoder_scale: float = 0.13025
|
||||
|
||||
|
||||
class StableDiffusion_XL(LatentDiffusionModel):
    """Latent diffusion model assembled from an SDXL UNet, autoencoder and double text encoder.

    Every component passed as ``None`` (or otherwise falsy) is replaced by a
    freshly constructed default; the default scheduler is a 30-step DDIM.
    """

    unet: SDXLUNet
    clip_text_encoder: DoubleTextEncoder

    def __init__(
        self,
        unet: SDXLUNet | None = None,
        lda: SDXLAutoencoder | None = None,
        clip_text_encoder: DoubleTextEncoder | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        """Build the model, falling back to default components where none is given.

        Args:
            unet: SDXL UNet; defaults to ``SDXLUNet(in_channels=4)``.
            lda: autoencoder; defaults to ``SDXLAutoencoder()``.
            clip_text_encoder: text encoder; defaults to ``DoubleTextEncoder()``.
            scheduler: scheduler; defaults to ``DDIM(num_inference_steps=30)``.
            device: device forwarded to the parent constructor.
            dtype: dtype forwarded to the parent constructor.
        """
        super().__init__(
            unet=unet or SDXLUNet(in_channels=4),
            lda=lda or SDXLAutoencoder(),
            clip_text_encoder=clip_text_encoder or DoubleTextEncoder(),
            scheduler=scheduler or DDIM(num_inference_steps=30),
            device=device,
            dtype=dtype,
        )

    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
        """Encode a prompt and its negative prompt into stacked embeddings.

        Args:
            text: the conditional prompt.
            negative_text: the negative prompt; ``None`` is encoded as ``""``.

        Returns:
            A pair ``(embedding, pooled_embedding)``. In each tensor the
            negative entry is concatenated before the conditional one along
            dimension 0.
        """
        positive, positive_pooled = self.clip_text_encoder(text)

        if text == negative_text:
            # Identical prompts: reuse the result instead of encoding twice.
            negative, negative_pooled = positive, positive_pooled
        else:
            # TODO: when negative_text is None, use zero tensor?
            negative, negative_pooled = self.clip_text_encoder(negative_text or "")

        stacked_embedding = torch.cat(tensors=(negative, positive), dim=0)
        stacked_pooled_embedding = torch.cat(tensors=(negative_pooled, positive_pooled), dim=0)
        return stacked_embedding, stacked_pooled_embedding

    @property
    def default_time_ids(self) -> Tensor:
        """Default micro-conditioning ids, created on ``self.device`` and repeated into two rows.

        Layout of each row:
        [original_height, original_width, crop_top, crop_left, target_height, target_width]
        See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
        """
        # NOTE(review): the two rows presumably pair with the negative and
        # conditional entries of the text embeddings - confirm.
        return torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device).repeat(2, 1)

    def set_unet_context(
        self,
        *,
        timestep: Tensor,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **_: Tensor,
    ) -> None:
        """Forward the conditioning tensors to the UNet's setters.

        Extra keyword arguments are accepted and ignored.
        """
        sdxl_unet = self.unet
        sdxl_unet.set_timestep(timestep=timestep)
        sdxl_unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
        sdxl_unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
        sdxl_unet.set_time_ids(time_ids=time_ids)

    def forward(
        self,
        x: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        condition_scale: float = 5.0,
        **kwargs: Tensor,
    ) -> Tensor:
        """Delegate to the parent ``forward``, making the SDXL-specific keyword arguments explicit."""
        return super().forward(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=condition_scale,
            **kwargs,
        )
|
|
@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
|||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
@ -73,6 +74,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
    """Load the reference image for the SDXL DDIM random-init test, converted to RGB."""
    reference_file = ref_path / "expected_cutecat_sdxl_ddim_random_init.png"
    reference_image = Image.open(fp=reference_file)
    return reference_image.convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||
def controlnet_data(
|
||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||
|
@ -326,6 +332,51 @@ def sd15_ddim_lda_ft_mse(
|
|||
return sd15
|
||||
|
||||
|
||||
@pytest.fixture
def sdxl_lda_weights(test_weights_path: Path) -> Path:
    """Return the path to ``sdxl-lda.safetensors``, skipping when the file is absent."""
    sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
    if sdxl_lda_weights.is_file():
        return sdxl_lda_weights

    # Missing weights are not a failure: warn, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
    """Return the path to ``sdxl-unet.safetensors``, skipping when the file is absent."""
    sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
    if sdxl_unet_weights.is_file():
        return sdxl_unet_weights

    # Missing weights are not a failure: warn, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
    """Return the path to ``DoubleCLIPTextEncoder.safetensors``, skipping when the file is absent."""
    sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
    if sdxl_double_text_encoder_weights.is_file():
        return sdxl_double_text_encoder_weights

    # Missing weights are not a failure: warn, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
def sdxl_ddim(
    sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
    """Build an SDXL model with a 30-step DDIM scheduler and load its weights.

    The fixture skips when the test device is the CPU.
    """
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    model = StableDiffusion_XL(scheduler=DDIM(num_inference_steps=30), device=test_device)

    # Each sub-model is loaded from its own safetensors file, in this order.
    components_and_weights = (
        (model.clip_text_encoder, sdxl_text_encoder_weights),
        (model.lda, sdxl_lda_weights),
        (model.unet, sdxl_unet_weights),
    )
    for component, weights_path in components_and_weights:
        component.load_from_safetensors(tensors_path=weights_path)

    return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_std_random_init(
|
||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
|
@ -957,3 +1008,37 @@ def test_diffusion_ip_adapter(
|
|||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
def test_sdxl_random_init(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
) -> None:
    """End-to-end SDXL check: denoise seeded random latents and compare with a reference image.

    The test runs under ``torch.no_grad()`` like the other end-to-end
    diffusion tests in this module: it is inference only, so recording an
    autograd graph across 30 UNet evaluations would only waste memory.

    Args:
        sdxl_ddim: SDXL model with weights loaded and a DDIM scheduler.
        expected_sdxl_ddim_random_init: reference image to compare against.
        test_device: device on which the initial latents are created.
    """
    sdxl = sdxl_ddim
    expected_image = expected_sdxl_ddim_random_init
    n_steps = 30

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_num_inference_steps(num_inference_steps=n_steps)

    # Seed right before sampling so the initial latents are reproducible.
    manual_seed(seed=2)
    x = torch.randn(1, 4, 128, 128, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.decode_latents(x=x)

    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue