mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
38 lines · 1.2 KiB · Python
import pytest
|
|
import torch
|
|
|
|
from refiners.fluxion import manual_seed
|
|
from refiners.fluxion.utils import no_grad
|
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
|
|
|
|
|
@pytest.fixture(scope="module")
def refiners_sd15_unet(
    test_device: torch.device,
    test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1UNet:
    """Provide one SD1.5 UNet instance shared by all tests in this module.

    The UNet is built with 4 input channels on the device and dtype supplied by
    the ``test_device`` / ``test_dtype_fp32_bf16_fp16`` fixtures (defined
    elsewhere, presumably a conftest; the dtype fixture is presumably
    parametrized over fp32/bf16/fp16 as its name suggests — confirm there).
    No weights are loaded here.

    Because the fixture is module-scoped, the same instance — including any
    state stored on it — is reused across tests in this module.
    """
    unet = SD1UNet(in_channels=4, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
    return unet
|
|
|
|
|
|
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
    """Two forward passes on identical inputs must give bit-identical outputs.

    The CLIP text embedding is set only once, before both passes, and is not
    re-set in between. The timestep is re-set before each pass. If anything the
    first forward pass leaves in the UNet's context altered the second pass (the
    "context flush" this test is named after), the outputs would differ.
    """
    unet = refiners_sd15_unet

    manual_seed(0)
    device, dtype = unet.device, unet.dtype
    clip_embedding = torch.randn(1, 77, 768, device=device, dtype=dtype)
    step = torch.randint(0, 999, size=(1, 1), device=device)
    latents = torch.randn(1, 4, 32, 32, device=device, dtype=dtype)

    # Set once and intentionally NOT re-set between the two forward passes.
    unet.set_clip_text_embedding(clip_text_embedding=clip_embedding)

    outputs: list[torch.Tensor] = []
    for _ in range(2):
        with no_grad():
            unet.set_timestep(timestep=step)
            # Feed a fresh copy each time so the shared `latents` tensor cannot
            # be modified by a forward pass.
            outputs.append(unet(latents.clone()))

    first, second = outputs
    assert torch.equal(first, second)
|