mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
training_utils: fix naming issue timestep->step
This commit is contained in:
parent
0dc3a17fbf
commit
46b4b4b462
|
@ -29,8 +29,8 @@ from torch.nn import Module
|
||||||
class LatentDiffusionConfig(BaseModel):
|
class LatentDiffusionConfig(BaseModel):
|
||||||
unconditional_sampling_probability: float = 0.2
|
unconditional_sampling_probability: float = 0.2
|
||||||
offset_noise: float = 0.1
|
offset_noise: float = 0.1
|
||||||
min_timestep: int = 0
|
min_step: int = 0
|
||||||
max_timestep: int = 999
|
max_step: int = 999
|
||||||
|
|
||||||
|
|
||||||
class TestDiffusionConfig(BaseModel):
|
class TestDiffusionConfig(BaseModel):
|
||||||
|
@ -154,9 +154,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
).to(device=self.device)
|
).to(device=self.device)
|
||||||
|
|
||||||
def sample_timestep(self) -> Tensor:
|
def sample_timestep(self) -> Tensor:
|
||||||
random_step = random.randint(
|
random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
|
||||||
a=self.config.latent_diffusion.min_timestep, b=self.config.latent_diffusion.max_timestep
|
|
||||||
)
|
|
||||||
self.current_step = random_step
|
self.current_step = random_step
|
||||||
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue