correctly scale init latents for Euler scheduler

This commit is contained in:
Pierre Chapuis 2024-02-23 16:45:21 +01:00
parent bf0ba58541
commit 7e4e0f0650
3 changed files with 18 additions and 15 deletions

View file

@@ -48,15 +48,19 @@ class LatentDiffusionModel(fl.Module, ABC):
height // 8, height // 8,
width // 8, width // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}" ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None: if init_image is None:
return noise x = noise
else:
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height))) encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
return self.solver.add_noise( x = self.solver.add_noise(
x=encoded_image, x=encoded_image,
noise=noise, noise=noise,
step=self.solver.first_inference_step, step=self.solver.first_inference_step,
) )
return self.solver.scale_model_input(x, step=-1)
@property @property
def steps(self) -> list[int]: def steps(self) -> list[int]:
return self.solver.inference_steps return self.solver.inference_steps

View file

@@ -63,11 +63,15 @@ class Euler(Solver):
Args: Args:
x: The model input. x: The model input.
step: The current step. step: The current step. This method is called with `step=-1` in `init_latents`.
Returns: Returns:
The scaled model input. The scaled model input.
""" """
if step == -1:
return x * self.init_noise_sigma
sigma = self.sigmas[step] sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5) return x / ((sigma**2 + 1) ** 0.5)

View file

@@ -826,8 +826,7 @@ def test_diffusion_std_random_init_euler(
sd15.set_inference_steps(30) sd15.set_inference_steps(30)
manual_seed(2) manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device) x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)
x = x * euler_solver.init_noise_sigma
for step in sd15.steps: for step in sd15.steps:
x = sd15( x = sd15(
@@ -1997,11 +1996,7 @@ def test_diffusion_sdxl_euler_deterministic(
time_ids = sdxl.default_time_ids time_ids = sdxl.default_time_ids
sdxl.set_inference_steps(30) sdxl.set_inference_steps(30)
manual_seed(2) manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
# init latents must be scaled for Euler
# TODO make init_latents work
x = x * sdxl.solver.init_noise_sigma
for step in sdxl.steps: for step in sdxl.steps:
x = sdxl( x = sdxl(