mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
implement SDXL + e2e test on random init
This commit is contained in:
parent
02af8e9f0b
commit
2786117469
|
@ -1,8 +1,9 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
|
"StableDiffusion_XL",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
import torch
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
from torch import device as Device, dtype as DType, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLAutoencoder(LatentDiffusionAutoencoder):
    """Latent diffusion autoencoder used as the default `lda` of `StableDiffusion_XL`.

    The only thing this subclass declares is its own `encoder_scale` value.
    """

    # NOTE(review): presumably replaces a different scale defined on the base
    # autoencoder -- confirm against `LatentDiffusionAutoencoder`.
    encoder_scale: float = 0.13025
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion_XL(LatentDiffusionModel):
    """Latent diffusion model assembled from an SDXL UNet, autoencoder and double text encoder.

    On top of `LatentDiffusionModel`, this class forwards two extra conditioning
    inputs to the UNet: a pooled text embedding and "time ids".
    """

    unet: SDXLUNet
    clip_text_encoder: DoubleTextEncoder

    def __init__(
        self,
        unet: SDXLUNet | None = None,
        lda: SDXLAutoencoder | None = None,
        clip_text_encoder: DoubleTextEncoder | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        """Build the model, creating a default for every component left as `None`.

        Args:
            unet: UNet to use; defaults to `SDXLUNet(in_channels=4)`.
            lda: Autoencoder to use; defaults to `SDXLAutoencoder()`.
            clip_text_encoder: Text encoder to use; defaults to `DoubleTextEncoder()`.
            scheduler: Scheduler to use; defaults to `DDIM(num_inference_steps=30)`.
            device: Device forwarded to the parent constructor.
            dtype: Dtype forwarded to the parent constructor.
        """
        # Compare against None explicitly: `x or Default()` would silently discard a
        # caller-supplied component whose truth value happens to be False (for
        # instance an object defining `__len__` or `__bool__`).
        if unet is None:
            unet = SDXLUNet(in_channels=4)
        if lda is None:
            lda = SDXLAutoencoder()
        if clip_text_encoder is None:
            clip_text_encoder = DoubleTextEncoder()
        if scheduler is None:
            scheduler = DDIM(num_inference_steps=30)

        super().__init__(
            unet=unet,
            lda=lda,
            clip_text_encoder=clip_text_encoder,
            scheduler=scheduler,
            device=device,
            dtype=dtype,
        )

    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
        """Encode the prompt and the negative prompt.

        Args:
            text: Prompt to encode.
            negative_text: Negative prompt; `None` is encoded as the empty string.

        Returns:
            A pair `(embedding, pooled_embedding)`. In each tensor the negative
            result comes first and the conditional result second, concatenated
            along dimension 0.
        """
        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)

        if text == negative_text:
            # Identical prompts: reuse the conditional result instead of encoding twice.
            negative_embedding, negative_pooled_embedding = conditional_embedding, conditional_pooled_embedding
        else:
            # TODO: when negative_text is None, use zero tensor?
            negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")

        embedding = torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
        pooled_embedding = torch.cat(tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0)
        return embedding, pooled_embedding

    @property
    def default_time_ids(self) -> Tensor:
        """Default micro-conditioning time ids, as a tensor of shape (2, 6) on `self.device`.

        Both rows hold the same six values.
        """
        # [original_height, original_width, crop_top, crop_left, target_height, target_width]
        # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
        time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
        # NOTE(review): the two rows presumably line up with the (negative, conditional)
        # pair produced by `compute_clip_text_embedding` -- confirm.
        return time_ids.repeat(2, 1)

    def set_unet_context(
        self,
        *,
        timestep: Tensor,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **_: Tensor,
    ) -> None:
        """Push the conditioning inputs of the current step into the UNet.

        Args:
            timestep: Passed to `unet.set_timestep`.
            clip_text_embedding: Passed to `unet.set_clip_text_embedding`.
            pooled_text_embedding: Passed to `unet.set_pooled_text_embedding`.
            time_ids: Passed to `unet.set_time_ids`.
            **_: Any additional keyword tensors are accepted and ignored.
        """
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
        self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
        self.unet.set_time_ids(time_ids=time_ids)

    def forward(
        self,
        x: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        condition_scale: float = 5.0,
        **kwargs: Tensor,
    ) -> Tensor:
        """Run one step by delegating to `LatentDiffusionModel.forward`.

        This override only makes the SDXL-specific inputs (`pooled_text_embedding`
        and `time_ids`) explicit, required keyword arguments; every argument is
        forwarded unchanged to the parent implementation.

        Args:
            x: Input tensor for this step.
            step: Step index.
            clip_text_embedding: Text embedding to condition on.
            pooled_text_embedding: Pooled text embedding to condition on.
            time_ids: Time ids to condition on.
            condition_scale: Scale forwarded to the parent implementation.
            **kwargs: Extra keyword tensors forwarded to the parent implementation.

        Returns:
            Whatever the parent `forward` returns for this step.
        """
        return super().forward(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=condition_scale,
            **kwargs,
        )
|
|
@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
|
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
@ -73,6 +74,11 @@ def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
    """Open the reference image for the SDXL DDIM random-init test and convert it to RGB."""
    reference_file = ref_path / "expected_cutecat_sdxl_ddim_random_init.png"
    return Image.open(fp=reference_file).convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
def controlnet_data(
|
def controlnet_data(
|
||||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||||
|
@ -326,6 +332,51 @@ def sd15_ddim_lda_ft_mse(
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sdxl_lda_weights(test_weights_path: Path) -> Path:
    """Return the path of the SDXL autoencoder weights, or skip when the file is missing."""
    sdxl_lda_weights = test_weights_path / "sdxl-lda.safetensors"
    if sdxl_lda_weights.is_file():
        return sdxl_lda_weights

    # Weights are absent: emit a warning, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_lda_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sdxl_unet_weights(test_weights_path: Path) -> Path:
    """Return the path of the SDXL UNet weights, or skip when the file is missing."""
    sdxl_unet_weights = test_weights_path / "sdxl-unet.safetensors"
    if sdxl_unet_weights.is_file():
        return sdxl_unet_weights

    # Weights are absent: emit a warning, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_unet_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sdxl_text_encoder_weights(test_weights_path: Path) -> Path:
    """Return the path of the SDXL double text encoder weights, or skip when the file is missing."""
    sdxl_double_text_encoder_weights = test_weights_path / "DoubleCLIPTextEncoder.safetensors"
    if sdxl_double_text_encoder_weights.is_file():
        return sdxl_double_text_encoder_weights

    # Weights are absent: emit a warning, then skip (pytest.skip raises).
    warn(message=f"could not find weights at {sdxl_double_text_encoder_weights}, skipping")
    pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sdxl_ddim(
    sdxl_text_encoder_weights: Path, sdxl_lda_weights: Path, sdxl_unet_weights: Path, test_device: torch.device
) -> StableDiffusion_XL:
    """Build an SDXL model with a 30-step DDIM scheduler and load its three weight files.

    The fixture skips when `test_device` is a CPU device.
    """
    if test_device.type == "cpu":
        warn(message="not running on CPU, skipping")
        pytest.skip()

    model = StableDiffusion_XL(scheduler=DDIM(num_inference_steps=30), device=test_device)

    # Load order: text encoder, autoencoder, UNet.
    for component, weights in (
        (model.clip_text_encoder, sdxl_text_encoder_weights),
        (model.lda, sdxl_lda_weights),
        (model.unet, sdxl_unet_weights),
    ):
        component.load_from_safetensors(tensors_path=weights)

    return model
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_diffusion_std_random_init(
|
def test_diffusion_std_random_init(
|
||||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||||
|
@ -957,3 +1008,37 @@ def test_diffusion_ip_adapter(
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def test_sdxl_random_init(
    sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
) -> None:
    """End-to-end SDXL generation from seeded random latents, compared with a reference image.

    Runs under `torch.no_grad()`, like the other diffusion tests in this module:
    each step feeds its output into the next one, so without it autograd would
    record a graph spanning the whole denoising loop.
    """
    sdxl = sdxl_ddim
    expected_image = expected_sdxl_ddim_random_init
    n_steps = 30

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=prompt, negative_text=negative_prompt
    )
    time_ids = sdxl.default_time_ids

    sdxl.set_num_inference_steps(num_inference_steps=n_steps)

    # Seed before sampling so the initial latents match the reference run.
    manual_seed(seed=2)
    x = torch.randn(1, 4, 128, 128, device=test_device)

    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=5,
        )
    predicted_image = sdxl.lda.decode_latents(x=x)

    ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue