mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
correctly scale init latents for Euler scheduler
This commit is contained in:
parent
bf0ba58541
commit
7e4e0f0650
|
@ -48,14 +48,18 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
height // 8,
|
||||
width // 8,
|
||||
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
||||
|
||||
if init_image is None:
|
||||
return noise
|
||||
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
|
||||
return self.solver.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
step=self.solver.first_inference_step,
|
||||
)
|
||||
x = noise
|
||||
else:
|
||||
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
|
||||
x = self.solver.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
step=self.solver.first_inference_step,
|
||||
)
|
||||
|
||||
return self.solver.scale_model_input(x, step=-1)
|
||||
|
||||
@property
|
||||
def steps(self) -> list[int]:
|
||||
|
|
|
@ -63,11 +63,15 @@ class Euler(Solver):
|
|||
|
||||
Args:
|
||||
x: The model input.
|
||||
step: The current step.
|
||||
step: The current step. This method is called with `step=-1` in `init_latents`.
|
||||
|
||||
Returns:
|
||||
The scaled model input.
|
||||
"""
|
||||
|
||||
if step == -1:
|
||||
return x * self.init_noise_sigma
|
||||
|
||||
sigma = self.sigmas[step]
|
||||
return x / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
|
|
|
@ -826,8 +826,7 @@ def test_diffusion_std_random_init_euler(
|
|||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
x = x * euler_solver.init_noise_sigma
|
||||
x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
|
@ -1997,11 +1996,7 @@ def test_diffusion_sdxl_euler_deterministic(
|
|||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_inference_steps(30)
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
||||
# init latents must be scaled for Euler
|
||||
# TODO make init_latents work
|
||||
x = x * sdxl.solver.init_noise_sigma
|
||||
x = sdxl.init_latents((1024, 1024)).to(sdxl.device, sdxl.dtype)
|
||||
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
|
|
Loading…
Reference in a new issue