training_utils: fix naming issue timestep->step

This commit is contained in:
Cédric Deltheil 2023-12-05 09:57:08 +01:00 committed by Cédric Deltheil
parent 0dc3a17fbf
commit 46b4b4b462

View file

@ -29,8 +29,8 @@ from torch.nn import Module
class LatentDiffusionConfig(BaseModel):
unconditional_sampling_probability: float = 0.2
offset_noise: float = 0.1
min_timestep: int = 0
max_timestep: int = 999
min_step: int = 0
max_step: int = 999
class TestDiffusionConfig(BaseModel):
@ -154,9 +154,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
).to(device=self.device)
def sample_timestep(self) -> Tensor:
random_step = random.randint(
a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
)
random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)