diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 9c95701..187a1fb 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -29,8 +29,8 @@ from torch.nn import Module class LatentDiffusionConfig(BaseModel): unconditional_sampling_probability: float = 0.2 offset_noise: float = 0.1 - min_timestep: int = 0 - max_timestep: int = 999 + min_step: int = 0 + max_step: int = 999 class TestDiffusionConfig(BaseModel): @@ -154,9 +154,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): ).to(device=self.device) def sample_timestep(self) -> Tensor: - random_step = random.randint( - a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep - ) + random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step) self.current_step = random_step return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)