diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index 697debe..55ca1d6 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -19,6 +19,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         lda: LatentDiffusionAutoencoder,
         clip_text_encoder: fl.Chain,
         solver: Solver,
+        classifier_free_guidance: bool = True,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
@@ -29,6 +30,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.lda = lda.to(device=self.device, dtype=self.dtype)
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
         self.solver = solver.to(device=self.device, dtype=self.dtype)
+        self.classifier_free_guidance = classifier_free_guidance
 
     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
         self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
@@ -80,24 +82,33 @@ class LatentDiffusionModel(fl.Module, ABC):
     def forward(
         self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
+        if self.classifier_free_guidance:
+            assert clip_text_embedding.shape[0] % 2 == 0, f"invalid batch size: {clip_text_embedding.shape[0]}"
+
         timestep = self.solver.timesteps[step].unsqueeze(dim=0)
         self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
 
-        latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
+        latents = torch.cat(tensors=(x, x)) if self.classifier_free_guidance else x
         # scale latents for solvers that need it
         latents = self.solver.scale_model_input(latents, step=step)
-        unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
 
-        # classifier-free guidance
-        predicted_noise = unconditional_prediction + condition_scale * (
-            conditional_prediction - unconditional_prediction
-        )
-        x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting
-
-        if self.has_self_attention_guidance():
-            predicted_noise += self.compute_self_attention_guidance(
-                x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
+        if self.classifier_free_guidance:
+            unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
+            predicted_noise = unconditional_prediction + condition_scale * (
+                conditional_prediction - unconditional_prediction
             )
+            x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting
+            if self.has_self_attention_guidance():
+                predicted_noise += self.compute_self_attention_guidance(
+                    x=x,
+                    noise=unconditional_prediction,
+                    step=step,
+                    clip_text_embedding=clip_text_embedding,
+                    **kwargs,
+                )
+        else:
+            predicted_noise = self.unet(latents)
+            x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting
 
         return self.solver(x, predicted_noise=predicted_noise, step=step)
 
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
index 7899365..31a747d 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -93,7 +93,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         # [original_height, original_width, crop_top, crop_left, target_height, target_width]
         # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
         time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
-        return time_ids.repeat(2, 1)
+        return time_ids.repeat(2 if self.classifier_free_guidance else 1, 1)
 
     def set_unet_context(
         self,
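A minimal standalone sketch (not part of the patch) of the two paths the new `classifier_free_guidance` flag toggles in `forward`: with the flag on, the latents are duplicated, the prediction is split into unconditional and conditional halves, and the two are blended with `condition_scale`; with the flag off, a single pass is used as-is. `fake_unet` and `predict_noise` are hypothetical stand-ins for the real UNet call; only the tensor bookkeeping mirrors the patched method.

```python
import torch
from torch import Tensor


def fake_unet(latents: Tensor) -> Tensor:
    # stand-in for the real UNet, which would also consume the timestep and text embedding
    return latents * 0.1


def predict_noise(x: Tensor, condition_scale: float = 7.5, classifier_free_guidance: bool = True) -> Tensor:
    if classifier_free_guidance:
        # doubled batch: first half is the unconditional pass, second half the conditional one
        latents = torch.cat(tensors=(x, x))
        unconditional, conditional = fake_unet(latents).chunk(2)
        return unconditional + condition_scale * (conditional - unconditional)
    # single conditional pass, no guidance blending
    return fake_unet(x)


x = torch.randn(1, 4, 64, 64)
print(predict_noise(x).shape)                                  # torch.Size([1, 4, 64, 64])
print(predict_noise(x, classifier_free_guidance=False).shape)  # torch.Size([1, 4, 64, 64])
```

This also illustrates the batch-size assertion added in the patch: the guided path assumes `clip_text_embedding` stacks the unconditional and conditional embeddings, so its batch dimension must be even, whereas the unguided path takes the embedding as-is (and SDXL's `default_time_ids` is repeated only once accordingly).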