training_utils: fix naming issue timestep->step

This commit is contained in:
Cédric Deltheil 2023-12-05 09:57:08 +01:00 committed by Cédric Deltheil
parent 0dc3a17fbf
commit 46b4b4b462

View file

@ -29,8 +29,8 @@ from torch.nn import Module
class LatentDiffusionConfig(BaseModel): class LatentDiffusionConfig(BaseModel):
unconditional_sampling_probability: float = 0.2 unconditional_sampling_probability: float = 0.2
offset_noise: float = 0.1 offset_noise: float = 0.1
min_timestep: int = 0 min_step: int = 0
max_timestep: int = 999 max_step: int = 999
class TestDiffusionConfig(BaseModel): class TestDiffusionConfig(BaseModel):
@ -154,9 +154,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
).to(device=self.device) ).to(device=self.device)
def sample_timestep(self) -> Tensor: def sample_timestep(self) -> Tensor:
random_step = random.randint( random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
)
self.current_step = random_step self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0) return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)