mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
modify some foundational tests to also test in float16 and bfloat16
This commit is contained in:
parent
b20474f8f5
commit
f3d2b6c325
|
@ -10,12 +10,16 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
|
||||
def our_encoder(
|
||||
test_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||
) -> CLIPImageEncoderH:
|
||||
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
encoder = CLIPImageEncoderH(device=test_device)
|
||||
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
tensors = load_from_safetensors(weights)
|
||||
encoder.load_state_dict(tensors)
|
||||
return encoder
|
||||
|
@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
|
||||
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore
|
||||
test_device # type: ignore
|
||||
)
|
||||
def ref_encoder(
|
||||
stabilityai_unclip_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||
) -> CLIPVisionModelWithProjection:
|
||||
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||
stabilityai_unclip_weights_path,
|
||||
subfolder="image_encoder",
|
||||
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
|
||||
|
||||
@no_grad()
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
def test_encoder(
|
||||
ref_encoder: CLIPVisionModelWithProjection,
|
||||
our_encoder: CLIPImageEncoderH,
|
||||
test_device: torch.device,
|
||||
):
|
||||
x = torch.randn(1, 3, 224, 224).to(test_device)
|
||||
assert ref_encoder.dtype == our_encoder.dtype
|
||||
assert ref_encoder.device == our_encoder.device
|
||||
x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)
|
||||
|
||||
with no_grad():
|
||||
ref_embeddings = ref_encoder(x).image_embeds
|
||||
our_embeddings = our_encoder(x)
|
||||
ref_embeddings = ref_encoder(x).image_embeds
|
||||
our_embeddings = our_encoder(x)
|
||||
|
||||
assert ref_embeddings.shape == (1, 1024)
|
||||
assert our_embeddings.shape == (1, 1024)
|
||||
|
||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
|
||||
assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
|
||||
|
|
|
@ -30,13 +30,17 @@ PROMPTS = [
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
|
||||
def our_encoder(
|
||||
test_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
test_dtype_fp32_fp16: torch.dtype,
|
||||
) -> CLIPTextEncoderL:
|
||||
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
encoder = CLIPTextEncoderL(device=test_device)
|
||||
tensors = load_from_safetensors(weights)
|
||||
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
|
||||
encoder.load_state_dict(tensors)
|
||||
return encoder
|
||||
|
||||
|
@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
|
||||
return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device) # type: ignore
|
||||
def ref_encoder(
|
||||
runwayml_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
test_dtype_fp32_fp16: torch.dtype,
|
||||
) -> transformers.CLIPTextModel:
|
||||
return transformers.CLIPTextModel.from_pretrained( # type: ignore
|
||||
runwayml_weights_path,
|
||||
subfolder="text_encoder",
|
||||
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
|
||||
|
||||
|
||||
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
||||
|
@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
|
|||
return long_prompt if request.param == "<long prompt>" else request.param
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_encoder(
|
||||
prompt: str,
|
||||
ref_tokenizer: transformers.CLIPTokenizer,
|
||||
ref_encoder: transformers.CLIPTextModel,
|
||||
our_encoder: CLIPTextEncoderL,
|
||||
test_device: torch.device,
|
||||
):
|
||||
ref_tokens = ref_tokenizer( # type: ignore
|
||||
prompt,
|
||||
|
@ -89,18 +100,16 @@ def test_encoder(
|
|||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with no_grad():
|
||||
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
||||
our_embeddings = our_encoder(prompt)
|
||||
ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
|
||||
our_embeddings = our_encoder(prompt)
|
||||
|
||||
assert ref_embeddings.shape == (1, 77, 768)
|
||||
assert our_embeddings.shape == (1, 77, 768)
|
||||
|
||||
# FG-336 - Not strictly equal because we do not use the same implementation
|
||||
# of self-attention. We use `scaled_dot_product_attention` which can have
|
||||
# numerical differences depending on the backend.
|
||||
# Also we use FP16 weights.
|
||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
|
||||
# numerical differences depending on the backend. Also we use FP16 weights.
|
||||
torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
|
||||
|
||||
|
||||
def test_list_string_tokenizer(
|
||||
|
|
|
@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
|
|||
) -> None:
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(1, 3, resolution, resolution),
|
||||
size=(1, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
)
|
||||
|
||||
|
@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_dinov2_float16(
|
||||
def test_dinov2(
|
||||
resolution: int,
|
||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||
test_device: torch.device,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
||||
model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(1, 3, resolution, resolution),
|
||||
size=(1, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
dtype=torch.float16,
|
||||
dtype=test_dtype_fp32_bf16_fp16,
|
||||
)
|
||||
|
||||
output = model(input_data)
|
||||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||
assert output.dtype == torch.float16
|
||||
assert output.dtype == test_dtype_fp32_bf16_fp16
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -162,7 +163,7 @@ def test_dinov2_batch_size(
|
|||
batch_size = 4
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(batch_size, 3, resolution, resolution),
|
||||
size=(batch_size, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
)
|
||||
|
||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
|||
from PIL import Image
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -15,16 +15,37 @@ def ref_path() -> Path:
|
|||
return Path(__file__).parent / "test_auto_encoder_ref"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
|
||||
lda_weights = test_weights_path / "lda.safetensors"
|
||||
if not lda_weights.is_file():
|
||||
warn(f"could not find weights at {lda_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
encoder = LatentDiffusionAutoencoder(device=test_device)
|
||||
tensors = load_from_safetensors(lda_weights)
|
||||
encoder.load_state_dict(tensors)
|
||||
return encoder
|
||||
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
|
||||
def lda(
|
||||
request: pytest.FixtureRequest,
|
||||
test_weights_path: Path,
|
||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||
test_device: torch.device,
|
||||
) -> LatentDiffusionAutoencoder:
|
||||
model_version = request.param
|
||||
match (model_version, test_dtype_fp32_bf16_fp16):
|
||||
case ("SD1.5", _):
|
||||
weight_path = test_weights_path / "lda.safetensors"
|
||||
if not weight_path.is_file():
|
||||
warn(f"could not find weights at {weight_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
model = SD1Autoencoder().load_from_safetensors(weight_path)
|
||||
case ("SDXL", torch.float16):
|
||||
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
||||
if not weight_path.is_file():
|
||||
warn(f"could not find weights at {weight_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
||||
case ("SDXL", _):
|
||||
weight_path = test_weights_path / "sdxl-lda.safetensors"
|
||||
if not weight_path.is_file():
|
||||
warn(f"could not find weights at {weight_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model version: {model_version}")
|
||||
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
77
tests/foundationals/latent_diffusion/test_models.py
Normal file
77
tests/foundationals/latent_diffusion/test_models.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||
manual_seed(2)
|
||||
latents_0 = LatentDiffusionModel.sample_noise(
|
||||
size=(1, 4, 64, 64),
|
||||
device=test_device,
|
||||
dtype=test_dtype_fp32_bf16_fp16,
|
||||
)
|
||||
manual_seed(2)
|
||||
latents_1 = LatentDiffusionModel.sample_noise(
|
||||
size=(1, 4, 64, 64),
|
||||
offset_noise=0.0, # should be no-op
|
||||
device=test_device,
|
||||
dtype=test_dtype_fp32_bf16_fp16,
|
||||
)
|
||||
|
||||
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||
sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
|
||||
# prepare inputs
|
||||
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
text_embedding = sd.compute_clip_text_embedding("")
|
||||
|
||||
# run the pipeline of models, for a single step
|
||||
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
||||
|
||||
assert output.shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||
sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
|
||||
# prepare inputs
|
||||
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
target_image = Image.new("RGB", (512, 512))
|
||||
mask = Image.new("L", (512, 512))
|
||||
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
|
||||
text_embedding = sd.compute_clip_text_embedding("")
|
||||
|
||||
# run the pipeline of models, for a single step
|
||||
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
||||
|
||||
assert output.shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||
sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
|
||||
# prepare inputs
|
||||
latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||
text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
|
||||
time_ids = sd.default_time_ids
|
||||
|
||||
# run the pipeline of models, for a single step
|
||||
output = sd(
|
||||
latent_noise,
|
||||
step=0,
|
||||
clip_text_embedding=text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
|
||||
assert output.shape == (1, 4, 128, 128)
|
|
@ -7,9 +7,15 @@ from refiners.foundationals.latent_diffusion import SD1UNet
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
|
||||
unet = SD1UNet(in_channels=4, device=test_device)
|
||||
return unet
|
||||
def refiners_sd15_unet(
|
||||
test_device: torch.device,
|
||||
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||
) -> SD1UNet:
|
||||
return SD1UNet(
|
||||
in_channels=4,
|
||||
device=test_device,
|
||||
dtype=test_dtype_fp32_bf16_fp16,
|
||||
)
|
||||
|
||||
|
||||
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
|
||||
|
|
Loading…
Reference in a new issue