import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
def test_unet_context_flush():
    """Running the same inputs through the UNet twice must give identical outputs.

    The CLIP text embedding is set once and deliberately NOT re-set before the
    second forward pass: if the UNet failed to flush its internal context
    between forwards, state leaking from the first pass would make the second
    output diverge.
    """
    manual_seed(0)

    # Deterministic dummy inputs: CLIP embedding, diffusion timestep, latent.
    clip_embedding = torch.randn(1, 77, 768)
    step = torch.randint(0, 999, size=(1, 1))
    latent = torch.randn(1, 4, 32, 32)

    unet = SD1UNet(in_channels=4)
    # Set once for both forward passes — not flushed between them.
    unet.set_clip_text_embedding(clip_text_embedding=clip_embedding)

    outputs: list[torch.Tensor] = []
    for _ in range(2):
        with no_grad():
            unet.set_timestep(timestep=step)
            outputs.append(unet(latent.clone()))

    assert torch.equal(outputs[0], outputs[1])