mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
modify some foundational tests to also test in float16 and bfloat16
This commit is contained in:
parent
b20474f8f5
commit
f3d2b6c325
|
@ -10,12 +10,16 @@ from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
|
def our_encoder(
|
||||||
|
test_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
) -> CLIPImageEncoderH:
|
||||||
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
|
||||||
if not weights.is_file():
|
if not weights.is_file():
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
encoder = CLIPImageEncoderH(device=test_device)
|
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(weights)
|
||||||
encoder.load_state_dict(tensors)
|
encoder.load_state_dict(tensors)
|
||||||
return encoder
|
return encoder
|
||||||
|
@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
|
def ref_encoder(
|
||||||
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore
|
stabilityai_unclip_weights_path: Path,
|
||||||
test_device # type: ignore
|
test_device: torch.device,
|
||||||
)
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
) -> CLIPVisionModelWithProjection:
|
||||||
|
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||||
|
stabilityai_unclip_weights_path,
|
||||||
|
subfolder="image_encoder",
|
||||||
|
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
@pytest.mark.flaky(reruns=3)
|
||||||
def test_encoder(
|
def test_encoder(
|
||||||
ref_encoder: CLIPVisionModelWithProjection,
|
ref_encoder: CLIPVisionModelWithProjection,
|
||||||
our_encoder: CLIPImageEncoderH,
|
our_encoder: CLIPImageEncoderH,
|
||||||
test_device: torch.device,
|
|
||||||
):
|
):
|
||||||
x = torch.randn(1, 3, 224, 224).to(test_device)
|
assert ref_encoder.dtype == our_encoder.dtype
|
||||||
|
assert ref_encoder.device == our_encoder.device
|
||||||
|
x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)
|
||||||
|
|
||||||
with no_grad():
|
|
||||||
ref_embeddings = ref_encoder(x).image_embeds
|
ref_embeddings = ref_encoder(x).image_embeds
|
||||||
our_embeddings = our_encoder(x)
|
our_embeddings = our_encoder(x)
|
||||||
|
|
||||||
assert ref_embeddings.shape == (1, 1024)
|
assert ref_embeddings.shape == (1, 1024)
|
||||||
assert our_embeddings.shape == (1, 1024)
|
assert our_embeddings.shape == (1, 1024)
|
||||||
|
|
||||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
|
assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
|
||||||
|
|
|
@ -30,13 +30,17 @@ PROMPTS = [
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
|
def our_encoder(
|
||||||
|
test_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
test_dtype_fp32_fp16: torch.dtype,
|
||||||
|
) -> CLIPTextEncoderL:
|
||||||
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
|
||||||
if not weights.is_file():
|
if not weights.is_file():
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
encoder = CLIPTextEncoderL(device=test_device)
|
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(weights)
|
||||||
|
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
|
||||||
encoder.load_state_dict(tensors)
|
encoder.load_state_dict(tensors)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
|
def ref_encoder(
|
||||||
return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device) # type: ignore
|
runwayml_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
test_dtype_fp32_fp16: torch.dtype,
|
||||||
|
) -> transformers.CLIPTextModel:
|
||||||
|
return transformers.CLIPTextModel.from_pretrained( # type: ignore
|
||||||
|
runwayml_weights_path,
|
||||||
|
subfolder="text_encoder",
|
||||||
|
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
||||||
|
@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
|
||||||
return long_prompt if request.param == "<long prompt>" else request.param
|
return long_prompt if request.param == "<long prompt>" else request.param
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
def test_encoder(
|
def test_encoder(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
ref_tokenizer: transformers.CLIPTokenizer,
|
ref_tokenizer: transformers.CLIPTokenizer,
|
||||||
ref_encoder: transformers.CLIPTextModel,
|
ref_encoder: transformers.CLIPTextModel,
|
||||||
our_encoder: CLIPTextEncoderL,
|
our_encoder: CLIPTextEncoderL,
|
||||||
test_device: torch.device,
|
|
||||||
):
|
):
|
||||||
ref_tokens = ref_tokenizer( # type: ignore
|
ref_tokens = ref_tokenizer( # type: ignore
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -89,8 +100,7 @@ def test_encoder(
|
||||||
our_tokens = tokenizer(prompt)
|
our_tokens = tokenizer(prompt)
|
||||||
assert torch.equal(our_tokens, ref_tokens)
|
assert torch.equal(our_tokens, ref_tokens)
|
||||||
|
|
||||||
with no_grad():
|
ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
|
||||||
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
|
||||||
our_embeddings = our_encoder(prompt)
|
our_embeddings = our_encoder(prompt)
|
||||||
|
|
||||||
assert ref_embeddings.shape == (1, 77, 768)
|
assert ref_embeddings.shape == (1, 77, 768)
|
||||||
|
@ -98,9 +108,8 @@ def test_encoder(
|
||||||
|
|
||||||
# FG-336 - Not strictly equal because we do not use the same implementation
|
# FG-336 - Not strictly equal because we do not use the same implementation
|
||||||
# of self-attention. We use `scaled_dot_product_attention` which can have
|
# of self-attention. We use `scaled_dot_product_attention` which can have
|
||||||
# numerical differences depending on the backend.
|
# numerical differences depending on the backend. Also we use FP16 weights.
|
||||||
# Also we use FP16 weights.
|
torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
|
||||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_string_tokenizer(
|
def test_list_string_tokenizer(
|
||||||
|
|
|
@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
|
||||||
) -> None:
|
) -> None:
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
input_data = torch.randn(
|
input_data = torch.randn(
|
||||||
(1, 3, resolution, resolution),
|
size=(1, 3, resolution, resolution),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_dinov2_float16(
|
def test_dinov2(
|
||||||
resolution: int,
|
resolution: int,
|
||||||
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> None:
|
) -> None:
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
input_data = torch.randn(
|
input_data = torch.randn(
|
||||||
(1, 3, resolution, resolution),
|
size=(1, 3, resolution, resolution),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
dtype=torch.float16,
|
dtype=test_dtype_fp32_bf16_fp16,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = model(input_data)
|
output = model(input_data)
|
||||||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||||
assert output.shape == (1, sequence_length, model.embedding_dim)
|
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||||
assert output.dtype == torch.float16
|
assert output.dtype == test_dtype_fp32_bf16_fp16
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
|
@ -162,7 +163,7 @@ def test_dinov2_batch_size(
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
input_data = torch.randn(
|
input_data = torch.randn(
|
||||||
(batch_size, 3, resolution, resolution),
|
size=(batch_size, 3, resolution, resolution),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
from refiners.fluxion.utils import no_grad
|
||||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -15,16 +15,37 @@ def ref_path() -> Path:
|
||||||
return Path(__file__).parent / "test_auto_encoder_ref"
|
return Path(__file__).parent / "test_auto_encoder_ref"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
|
||||||
def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
|
def lda(
|
||||||
lda_weights = test_weights_path / "lda.safetensors"
|
request: pytest.FixtureRequest,
|
||||||
if not lda_weights.is_file():
|
test_weights_path: Path,
|
||||||
warn(f"could not find weights at {lda_weights}, skipping")
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> LatentDiffusionAutoencoder:
|
||||||
|
model_version = request.param
|
||||||
|
match (model_version, test_dtype_fp32_bf16_fp16):
|
||||||
|
case ("SD1.5", _):
|
||||||
|
weight_path = test_weights_path / "lda.safetensors"
|
||||||
|
if not weight_path.is_file():
|
||||||
|
warn(f"could not find weights at {weight_path}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
encoder = LatentDiffusionAutoencoder(device=test_device)
|
model = SD1Autoencoder().load_from_safetensors(weight_path)
|
||||||
tensors = load_from_safetensors(lda_weights)
|
case ("SDXL", torch.float16):
|
||||||
encoder.load_state_dict(tensors)
|
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
|
||||||
return encoder
|
if not weight_path.is_file():
|
||||||
|
warn(f"could not find weights at {weight_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
||||||
|
case ("SDXL", _):
|
||||||
|
weight_path = test_weights_path / "sdxl-lda.safetensors"
|
||||||
|
if not weight_path.is_file():
|
||||||
|
warn(f"could not find weights at {weight_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
model = SDXLAutoencoder().load_from_safetensors(weight_path)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown model version: {model_version}")
|
||||||
|
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
|
77
tests/foundationals/latent_diffusion/test_models.py
Normal file
77
tests/foundationals/latent_diffusion/test_models.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
|
||||||
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||||
|
manual_seed(2)
|
||||||
|
latents_0 = LatentDiffusionModel.sample_noise(
|
||||||
|
size=(1, 4, 64, 64),
|
||||||
|
device=test_device,
|
||||||
|
dtype=test_dtype_fp32_bf16_fp16,
|
||||||
|
)
|
||||||
|
manual_seed(2)
|
||||||
|
latents_1 = LatentDiffusionModel.sample_noise(
|
||||||
|
size=(1, 4, 64, 64),
|
||||||
|
offset_noise=0.0, # should be no-op
|
||||||
|
device=test_device,
|
||||||
|
dtype=test_dtype_fp32_bf16_fp16,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||||
|
sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
text_embedding = sd.compute_clip_text_embedding("")
|
||||||
|
|
||||||
|
# run the pipeline of models, for a single step
|
||||||
|
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
||||||
|
|
||||||
|
assert output.shape == (1, 4, 64, 64)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||||
|
sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
target_image = Image.new("RGB", (512, 512))
|
||||||
|
mask = Image.new("L", (512, 512))
|
||||||
|
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
|
||||||
|
text_embedding = sd.compute_clip_text_embedding("")
|
||||||
|
|
||||||
|
# run the pipeline of models, for a single step
|
||||||
|
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
||||||
|
|
||||||
|
assert output.shape == (1, 4, 64, 64)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
|
||||||
|
sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
|
||||||
|
# prepare inputs
|
||||||
|
latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
|
||||||
|
text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
|
||||||
|
time_ids = sd.default_time_ids
|
||||||
|
|
||||||
|
# run the pipeline of models, for a single step
|
||||||
|
output = sd(
|
||||||
|
latent_noise,
|
||||||
|
step=0,
|
||||||
|
clip_text_embedding=text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output.shape == (1, 4, 128, 128)
|
|
@ -7,9 +7,15 @@ from refiners.foundationals.latent_diffusion import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
|
def refiners_sd15_unet(
|
||||||
unet = SD1UNet(in_channels=4, device=test_device)
|
test_device: torch.device,
|
||||||
return unet
|
test_dtype_fp32_bf16_fp16: torch.dtype,
|
||||||
|
) -> SD1UNet:
|
||||||
|
return SD1UNet(
|
||||||
|
in_channels=4,
|
||||||
|
device=test_device,
|
||||||
|
dtype=test_dtype_fp32_bf16_fp16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
|
def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
|
||||||
|
|
Loading…
Reference in a new issue