diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 55ca1d6..79e3272 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -48,14 +48,18 @@ class LatentDiffusionModel(fl.Module, ABC): height // 8, width // 8, ], f"noise shape is not compatible: {noise.shape}, with size: {size}" + if init_image is None: - return noise - encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height))) - return self.solver.add_noise( - x=encoded_image, - noise=noise, - step=self.solver.first_inference_step, - ) + x = noise + else: + encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height))) + x = self.solver.add_noise( + x=encoded_image, + noise=noise, + step=self.solver.first_inference_step, + ) + + return self.solver.scale_model_input(x, step=-1) @property def steps(self) -> list[int]: diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index 3c062ec..e88c48c 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -63,11 +63,15 @@ class Euler(Solver): Args: x: The model input. - step: The current step. + step: The current step. This method is called with `step=-1` in `init_latents`. Returns: The scaled model input. """ + + if step == -1: + return x * self.init_noise_sigma + sigma = self.sigmas[step] return x / ((sigma**2 + 1) ** 0.5) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 3b1f9f5..8fdab23 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -826,8 +826,7 @@ def test_diffusion_std_random_init_euler( sd15.set_inference_steps(30) manual_seed(2) - x = torch.randn(1, 4, 64, 64, device=test_device) - x = x * euler_solver.init_noise_sigma + x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype) for step in sd15.steps: x = sd15( @@ -1997,11 +1996,7 @@ def test_diffusion_sdxl_euler_deterministic( time_ids = sdxl.default_time_ids sdxl.set_inference_steps(30) manual_seed(2) - x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) - - # init latents must be scaled for Euler - # TODO make init_latents work - x = x * sdxl.solver.init_noise_sigma + x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype) for step in sdxl.steps: x = sdxl(