mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
support disabling CFG in LatentDiffusionModel
This commit is contained in:
parent
446967859d
commit
4a619e84f0
|
@ -19,6 +19,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
lda: LatentDiffusionAutoencoder,
|
lda: LatentDiffusionAutoencoder,
|
||||||
clip_text_encoder: fl.Chain,
|
clip_text_encoder: fl.Chain,
|
||||||
solver: Solver,
|
solver: Solver,
|
||||||
|
classifier_free_guidance: bool = True,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -29,6 +30,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||||
self.solver = solver.to(device=self.device, dtype=self.dtype)
|
self.solver = solver.to(device=self.device, dtype=self.dtype)
|
||||||
|
self.classifier_free_guidance = classifier_free_guidance
|
||||||
|
|
||||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||||
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||||
|
@ -80,24 +82,33 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
if self.classifier_free_guidance:
|
||||||
|
assert clip_text_embedding.shape[0] % 2 == 0, f"invalid batch size: {clip_text_embedding.shape[0]}"
|
||||||
|
|
||||||
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
|
||||||
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
||||||
|
|
||||||
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
latents = torch.cat(tensors=(x, x)) if self.classifier_free_guidance else x
|
||||||
# scale latents for solvers that need it
|
# scale latents for solvers that need it
|
||||||
latents = self.solver.scale_model_input(latents, step=step)
|
latents = self.solver.scale_model_input(latents, step=step)
|
||||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
|
||||||
|
|
||||||
# classifier-free guidance
|
if self.classifier_free_guidance:
|
||||||
predicted_noise = unconditional_prediction + condition_scale * (
|
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||||
conditional_prediction - unconditional_prediction
|
predicted_noise = unconditional_prediction + condition_scale * (
|
||||||
)
|
conditional_prediction - unconditional_prediction
|
||||||
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
|
||||||
|
|
||||||
if self.has_self_attention_guidance():
|
|
||||||
predicted_noise += self.compute_self_attention_guidance(
|
|
||||||
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
|
|
||||||
)
|
)
|
||||||
|
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||||
|
if self.has_self_attention_guidance():
|
||||||
|
predicted_noise += self.compute_self_attention_guidance(
|
||||||
|
x=x,
|
||||||
|
noise=unconditional_prediction,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
predicted_noise = self.unet(latents)
|
||||||
|
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||||
|
|
||||||
return self.solver(x, predicted_noise=predicted_noise, step=step)
|
return self.solver(x, predicted_noise=predicted_noise, step=step)
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
|
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
|
||||||
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
|
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
|
||||||
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
|
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
|
||||||
return time_ids.repeat(2, 1)
|
return time_ids.repeat(2 if self.classifier_free_guidance else 1, 1)
|
||||||
|
|
||||||
def set_unet_context(
|
def set_unet_context(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in a new issue