From 46b4b4b4623f16efae7b3195382cfec4c245488e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Tue, 5 Dec 2023 09:57:08 +0100
Subject: [PATCH] training_utils: fix naming issue timestep->step

---
 src/refiners/training_utils/latent_diffusion.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index 9c95701..187a1fb 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -29,8 +29,8 @@ from torch.nn import Module
 class LatentDiffusionConfig(BaseModel):
     unconditional_sampling_probability: float = 0.2
     offset_noise: float = 0.1
-    min_timestep: int = 0
-    max_timestep: int = 999
+    min_step: int = 0
+    max_step: int = 999
 
 
 class TestDiffusionConfig(BaseModel):
@@ -154,9 +154,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         ).to(device=self.device)
 
     def sample_timestep(self) -> Tensor:
-        random_step = random.randint(
-            a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
-        )
+        random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
         self.current_step = random_step
         return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)