mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
support disabling CFG in LatentDiffusionModel
This commit is contained in:
parent
446967859d
commit
4a619e84f0
|
@ -19,6 +19,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
lda: LatentDiffusionAutoencoder,
|
||||
clip_text_encoder: fl.Chain,
|
||||
solver: Solver,
|
||||
classifier_free_guidance: bool = True,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
|
@ -29,6 +30,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.solver = solver.to(device=self.device, dtype=self.dtype)
|
||||
self.classifier_free_guidance = classifier_free_guidance
|
||||
|
||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||
|
@ -80,24 +82,33 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
def forward(
|
||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||
) -> Tensor:
|
||||
if self.classifier_free_guidance:
|
||||
assert clip_text_embedding.shape[0] % 2 == 0, f"invalid batch size: {clip_text_embedding.shape[0]}"
|
||||
|
||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
||||
|
||||
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
||||
latents = torch.cat(tensors=(x, x)) if self.classifier_free_guidance else x
|
||||
# scale latents for solvers that need it
|
||||
latents = self.solver.scale_model_input(latents, step=step)
|
||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||
|
||||
# classifier-free guidance
|
||||
if self.classifier_free_guidance:
|
||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||
predicted_noise = unconditional_prediction + condition_scale * (
|
||||
conditional_prediction - unconditional_prediction
|
||||
)
|
||||
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||
|
||||
if self.has_self_attention_guidance():
|
||||
predicted_noise += self.compute_self_attention_guidance(
|
||||
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
|
||||
x=x,
|
||||
noise=unconditional_prediction,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
predicted_noise = self.unet(latents)
|
||||
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||
|
||||
return self.solver(x, predicted_noise=predicted_noise, step=step)
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
|
||||
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
|
||||
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
|
||||
return time_ids.repeat(2, 1)
|
||||
return time_ids.repeat(2 if self.classifier_free_guidance else 1, 1)
|
||||
|
||||
def set_unet_context(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue