refiners/tests/foundationals/latent_diffusion/test_unet.py
Cédric Deltheil b933fabf31 unet: get rid of clip_embedding attribute for SD1
It is implicitly defined by the underlying cross-attention layer. This
also makes it consistent with SDXL.
2023-09-01 19:23:33 +02:00

24 lines
670 B
Python

from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.fluxion import manual_seed
import torch
def test_unet_context_flush() -> None:
    """Check that the CLIP text embedding survives repeated UNet forward passes.

    The text embedding is set exactly once, before any forward call, while the
    timestep is re-set before each call. Two forward passes on identical
    (cloned) latents must then produce bit-identical outputs.
    """
    manual_seed(0)
    # Random draws happen in a fixed order: embedding, timestep, latents, then
    # the UNet's own parameter initialisation when it is constructed below.
    clip_text = torch.randn(1, 77, 768)
    step = torch.randint(0, 999, size=(1, 1))
    latents = torch.randn(1, 4, 32, 32)

    model = SD1UNet(in_channels=4)
    # Deliberately set only once: this context must not be flushed between forward calls.
    model.set_clip_text_embedding(clip_text_embedding=clip_text)

    def run_once():
        # One inference pass: re-set the timestep, then feed a fresh copy of the latents.
        with torch.no_grad():
            model.set_timestep(timestep=step)
            return model(latents.clone())

    first, second = run_once(), run_once()
    assert torch.equal(first, second)