modify some foundational tests to also test in float16 and bfloat16

This commit is contained in:
Laurent 2024-10-03 08:47:37 +00:00 committed by Laureηt
parent b20474f8f5
commit f3d2b6c325
6 changed files with 170 additions and 45 deletions

View file

@ -10,12 +10,16 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH: def our_encoder(
test_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPImageEncoderH:
weights = test_weights_path / "CLIPImageEncoderH.safetensors" weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not weights.is_file(): if not weights.is_file():
warn(f"could not find weights at {weights}, skipping") warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True) pytest.skip(allow_module_level=True)
encoder = CLIPImageEncoderH(device=test_device) encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
tensors = load_from_safetensors(weights) tensors = load_from_safetensors(weights)
encoder.load_state_dict(tensors) encoder.load_state_dict(tensors)
return encoder return encoder
@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection: def ref_encoder(
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore stabilityai_unclip_weights_path: Path,
test_device # type: ignore test_device: torch.device,
) test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
stabilityai_unclip_weights_path,
subfolder="image_encoder",
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
@no_grad()
@pytest.mark.flaky(reruns=3)
def test_encoder( def test_encoder(
ref_encoder: CLIPVisionModelWithProjection, ref_encoder: CLIPVisionModelWithProjection,
our_encoder: CLIPImageEncoderH, our_encoder: CLIPImageEncoderH,
test_device: torch.device,
): ):
x = torch.randn(1, 3, 224, 224).to(test_device) assert ref_encoder.dtype == our_encoder.dtype
assert ref_encoder.device == our_encoder.device
x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)
with no_grad(): ref_embeddings = ref_encoder(x).image_embeds
ref_embeddings = ref_encoder(x).image_embeds our_embeddings = our_encoder(x)
our_embeddings = our_encoder(x)
assert ref_embeddings.shape == (1, 1024) assert ref_embeddings.shape == (1, 1024)
assert our_embeddings.shape == (1, 1024) assert our_embeddings.shape == (1, 1024)
assert (our_embeddings - ref_embeddings).abs().max() < 0.01 assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)

View file

@ -30,13 +30,17 @@ PROMPTS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL: def our_encoder(
test_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors" weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file(): if not weights.is_file():
warn(f"could not find weights at {weights}, skipping") warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True) pytest.skip(allow_module_level=True)
encoder = CLIPTextEncoderL(device=test_device)
tensors = load_from_safetensors(weights) tensors = load_from_safetensors(weights)
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
encoder.load_state_dict(tensors) encoder.load_state_dict(tensors)
return encoder return encoder
@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel: def ref_encoder(
return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device) # type: ignore runwayml_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
) -> transformers.CLIPTextModel:
return transformers.CLIPTextModel.from_pretrained( # type: ignore
runwayml_weights_path,
subfolder="text_encoder",
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL): def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
return long_prompt if request.param == "<long prompt>" else request.param return long_prompt if request.param == "<long prompt>" else request.param
@no_grad()
def test_encoder( def test_encoder(
prompt: str, prompt: str,
ref_tokenizer: transformers.CLIPTokenizer, ref_tokenizer: transformers.CLIPTokenizer,
ref_encoder: transformers.CLIPTextModel, ref_encoder: transformers.CLIPTextModel,
our_encoder: CLIPTextEncoderL, our_encoder: CLIPTextEncoderL,
test_device: torch.device,
): ):
ref_tokens = ref_tokenizer( # type: ignore ref_tokens = ref_tokenizer( # type: ignore
prompt, prompt,
@ -89,18 +100,16 @@ def test_encoder(
our_tokens = tokenizer(prompt) our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens) assert torch.equal(our_tokens, ref_tokens)
with no_grad(): ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0] our_embeddings = our_encoder(prompt)
our_embeddings = our_encoder(prompt)
assert ref_embeddings.shape == (1, 77, 768) assert ref_embeddings.shape == (1, 77, 768)
assert our_embeddings.shape == (1, 77, 768) assert our_embeddings.shape == (1, 77, 768)
# FG-336 - Not strictly equal because we do not use the same implementation # FG-336 - Not strictly equal because we do not use the same implementation
# of self-attention. We use `scaled_dot_product_attention` which can have # of self-attention. We use `scaled_dot_product_attention` which can have
# numerical differences depending on the backend. # numerical differences depending on the backend. Also we use FP16 weights.
# Also we use FP16 weights. torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
def test_list_string_tokenizer( def test_list_string_tokenizer(

View file

@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
) -> None: ) -> None:
manual_seed(2) manual_seed(2)
input_data = torch.randn( input_data = torch.randn(
(1, 3, resolution, resolution), size=(1, 3, resolution, resolution),
device=test_device, device=test_device,
) )
@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(
@no_grad() @no_grad()
def test_dinov2_float16( def test_dinov2(
resolution: int, resolution: int,
test_dtype_fp32_bf16_fp16: torch.dtype,
test_device: torch.device, test_device: torch.device,
) -> None: ) -> None:
if test_device.type == "cpu": if test_device.type == "cpu":
warn("not running on CPU, skipping") warn("not running on CPU, skipping")
pytest.skip() pytest.skip()
model = DINOv2_small(device=test_device, dtype=torch.float16) model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
manual_seed(2) manual_seed(2)
input_data = torch.randn( input_data = torch.randn(
(1, 3, resolution, resolution), size=(1, 3, resolution, resolution),
device=test_device, device=test_device,
dtype=torch.float16, dtype=test_dtype_fp32_bf16_fp16,
) )
output = model(input_data) output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1 sequence_length = (resolution // model.patch_size) ** 2 + 1
assert output.shape == (1, sequence_length, model.embedding_dim) assert output.shape == (1, sequence_length, model.embedding_dim)
assert output.dtype == torch.float16 assert output.dtype == test_dtype_fp32_bf16_fp16
@no_grad() @no_grad()
@ -162,7 +163,7 @@ def test_dinov2_batch_size(
batch_size = 4 batch_size = 4
manual_seed(2) manual_seed(2)
input_data = torch.randn( input_data = torch.randn(
(batch_size, 3, resolution, resolution), size=(batch_size, 3, resolution, resolution),
device=test_device, device=test_device,
) )

View file

@ -6,8 +6,8 @@ import torch
from PIL import Image from PIL import Image
from tests.utils import ensure_similar_images from tests.utils import ensure_similar_images
from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -15,16 +15,37 @@ def ref_path() -> Path:
return Path(__file__).parent / "test_auto_encoder_ref" return Path(__file__).parent / "test_auto_encoder_ref"
@pytest.fixture(scope="module") @pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder: def lda(
lda_weights = test_weights_path / "lda.safetensors" request: pytest.FixtureRequest,
if not lda_weights.is_file(): test_weights_path: Path,
warn(f"could not find weights at {lda_weights}, skipping") test_dtype_fp32_bf16_fp16: torch.dtype,
pytest.skip(allow_module_level=True) test_device: torch.device,
encoder = LatentDiffusionAutoencoder(device=test_device) ) -> LatentDiffusionAutoencoder:
tensors = load_from_safetensors(lda_weights) model_version = request.param
encoder.load_state_dict(tensors) match (model_version, test_dtype_fp32_bf16_fp16):
return encoder case ("SD1.5", _):
weight_path = test_weights_path / "lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SD1Autoencoder().load_from_safetensors(weight_path)
case ("SDXL", torch.float16):
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case ("SDXL", _):
weight_path = test_weights_path / "sdxl-lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case _:
raise ValueError(f"Unknown model version: {model_version}")
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
return model
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View file

@ -0,0 +1,77 @@
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
@no_grad()
def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
manual_seed(2)
latents_0 = LatentDiffusionModel.sample_noise(
size=(1, 4, 64, 64),
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)
manual_seed(2)
latents_1 = LatentDiffusionModel.sample_noise(
size=(1, 4, 64, 64),
offset_noise=0.0, # should be no-op
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
@no_grad()
def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
# prepare inputs
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
text_embedding = sd.compute_clip_text_embedding("")
# run the pipeline of models, for a single step
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
assert output.shape == (1, 4, 64, 64)
@no_grad()
def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
# prepare inputs
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
target_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512))
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
text_embedding = sd.compute_clip_text_embedding("")
# run the pipeline of models, for a single step
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
assert output.shape == (1, 4, 64, 64)
@no_grad()
def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
# prepare inputs
latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
time_ids = sd.default_time_ids
# run the pipeline of models, for a single step
output = sd(
latent_noise,
step=0,
clip_text_embedding=text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
assert output.shape == (1, 4, 128, 128)

View file

@ -7,9 +7,15 @@ from refiners.foundationals.latent_diffusion import SD1UNet
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet: def refiners_sd15_unet(
unet = SD1UNet(in_channels=4, device=test_device) test_device: torch.device,
return unet test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1UNet:
return SD1UNet(
in_channels=4,
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)
def test_unet_context_flush(refiners_sd15_unet: SD1UNet): def test_unet_context_flush(refiners_sd15_unet: SD1UNet):