support disabling CFG in LatentDiffusionModel

This commit is contained in:
Pierre Chapuis 2024-01-18 18:33:34 +01:00
parent 446967859d
commit 4a619e84f0
2 changed files with 23 additions and 12 deletions

View file

@ -19,6 +19,7 @@ class LatentDiffusionModel(fl.Module, ABC):
lda: LatentDiffusionAutoencoder, lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Chain, clip_text_encoder: fl.Chain,
solver: Solver, solver: Solver,
classifier_free_guidance: bool = True,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = torch.float32, dtype: DType = torch.float32,
) -> None: ) -> None:
@ -29,6 +30,7 @@ class LatentDiffusionModel(fl.Module, ABC):
self.lda = lda.to(device=self.device, dtype=self.dtype) self.lda = lda.to(device=self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype) self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.solver = solver.to(device=self.device, dtype=self.dtype) self.solver = solver.to(device=self.device, dtype=self.dtype)
self.classifier_free_guidance = classifier_free_guidance
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None: def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step) self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
@ -80,24 +82,33 @@ class LatentDiffusionModel(fl.Module, ABC):
def forward( def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
) -> Tensor: ) -> Tensor:
if self.classifier_free_guidance:
assert clip_text_embedding.shape[0] % 2 == 0, f"invalid batch size: {clip_text_embedding.shape[0]}"
timestep = self.solver.timesteps[step].unsqueeze(dim=0) timestep = self.solver.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance latents = torch.cat(tensors=(x, x)) if self.classifier_free_guidance else x
# scale latents for solvers that need it # scale latents for solvers that need it
latents = self.solver.scale_model_input(latents, step=step) latents = self.solver.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance if self.classifier_free_guidance:
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
predicted_noise = unconditional_prediction + condition_scale * ( predicted_noise = unconditional_prediction + condition_scale * (
conditional_prediction - unconditional_prediction conditional_prediction - unconditional_prediction
) )
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
if self.has_self_attention_guidance(): if self.has_self_attention_guidance():
predicted_noise += self.compute_self_attention_guidance( predicted_noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs x=x,
noise=unconditional_prediction,
step=step,
clip_text_embedding=clip_text_embedding,
**kwargs,
) )
else:
predicted_noise = self.unet(latents)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
return self.solver(x, predicted_noise=predicted_noise, step=step) return self.solver(x, predicted_noise=predicted_noise, step=step)

View file

@ -93,7 +93,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
# [original_height, original_width, crop_top, crop_left, target_height, target_width] # [original_height, original_width, crop_top, crop_left, target_height, target_width]
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device) time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
return time_ids.repeat(2, 1) return time_ids.repeat(2 if self.classifier_free_guidance else 1, 1)
def set_unet_context( def set_unet_context(
self, self,