use float32 reference for textual inversion (fixes tests on CPU)

This commit is contained in:
Pierre Chapuis 2023-09-11 14:13:40 +02:00
parent e5425e2968
commit dd0cca5855

View file

@ -45,7 +45,7 @@ def our_encoder_with_new_concepts(
def ref_sd15_with_new_concepts(
runwayml_weights_path: Path, test_textual_inversion_path: Path, test_device: torch.device
):
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path, torch_dtype=torch.float16).to(test_device) # type: ignore
pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device) # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "cat-toy") # type: ignore
pipe.load_textual_inversion(test_textual_inversion_path / "gta5-artwork") # type: ignore
return pipe