rename test_unet.py to test_sd15_unet.py + use test_device fixture
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
Laurent 2024-09-09 15:53:48 +00:00 committed by Laureηt
parent 4eceda810c
commit 2c0174f50e
2 changed files with 31 additions and 25 deletions

View file

@ -0,0 +1,31 @@
import pytest
import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
@pytest.fixture(scope="module")
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
unet = SD1UNet(in_channels=4, device=test_device)
return unet
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
manual_seed(0)
text_embedding = torch.randn(1, 77, 768, device=refiners_sd15_unet.device, dtype=refiners_sd15_unet.dtype)
timestep = torch.randint(0, 999, size=(1, 1), device=refiners_sd15_unet.device)
x = torch.randn(1, 4, 32, 32, device=refiners_sd15_unet.device, dtype=refiners_sd15_unet.dtype)
refiners_sd15_unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with no_grad():
refiners_sd15_unet.set_timestep(timestep=timestep)
y_1 = refiners_sd15_unet(x.clone())
with no_grad():
refiners_sd15_unet.set_timestep(timestep=timestep)
y_2 = refiners_sd15_unet(x.clone())
assert torch.equal(y_1, y_2)

View file

@ -1,25 +0,0 @@
import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
def test_unet_context_flush():
manual_seed(0)
text_embedding = torch.randn(1, 77, 768)
timestep = torch.randint(0, 999, size=(1, 1))
x = torch.randn(1, 4, 32, 32)
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())
with no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())
assert torch.equal(y_1, y_2)