diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 3738698..2771255 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,9 +1,9 @@
 import math
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -413,7 +413,53 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
     def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
         self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
 
+    @overload
     def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
+        ...
+
+    @overload
+    def compute_clip_image_embedding(
+        self, image_prompt: list[Image.Image], weights: list[float] | None = None
+    ) -> Tensor:
+        ...
+
+    def compute_clip_image_embedding(
+        self,
+        image_prompt: Tensor | Image.Image | list[Image.Image],
+        weights: list[float] | None = None,
+        concat_batches: bool = True,
+    ) -> Tensor:
+        if isinstance(image_prompt, Image.Image):
+            image_prompt = self.preprocess_image(image_prompt)
+        elif isinstance(image_prompt, list):
+            assert all(isinstance(image, Image.Image) for image in image_prompt)
+            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+
+        negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
+
+        batch_size = image_prompt.shape[0]
+        if weights is not None:
+            assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
+            if any(weight != 1.0 for weight in weights):
+                conditional_embedding *= (
+                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    .unsqueeze(-1)
+                    .unsqueeze(-1)
+                )
+
+        if batch_size > 1 and concat_batches:
+            # Create a longer image tokens sequence when a batch of images is given
+            # See https://github.com/tencent-ailab/IP-Adapter/issues/99
+            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+
+        return cat((negative_embedding, conditional_embedding))
+
+    def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
         clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
@@ -423,7 +469,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
             # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
             clip_embedding = image_encoder(zeros_like(image_prompt))
             negative_embedding = self.image_proj(clip_embedding)
-        return cat((negative_embedding, conditional_embedding))
+        return negative_embedding, conditional_embedding
 
     def preprocess_image(
         self,
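
For context, a minimal usage sketch of the new overloads, assuming an already-instantiated and weight-loaded `IPAdapter` subclass bound to the name `ip_adapter` (variable names and image paths below are illustrative, not part of the patch):

```python
from PIL import Image

# Single image: a Tensor or PIL.Image.Image works as before.
single_embedding = ip_adapter.compute_clip_image_embedding(Image.open("style.png"))

# Batch of images with per-image weights: each conditional embedding is scaled
# by its weight, then (with concat_batches=True, the default) the per-image
# token sequences are concatenated into one longer sequence.
multi_embedding = ip_adapter.compute_clip_image_embedding(
    [Image.open("style.png"), Image.open("composition.png")],
    weights=[1.0, 0.6],
)

# The returned tensor stacks the negative and conditional embeddings; it is
# intended to be passed to set_clip_image_embedding before denoising.
ip_adapter.set_clip_image_embedding(multi_embedding)
```

The concatenation of per-image token sequences (rather than averaging the embeddings) follows the multi-image approach discussed in https://github.com/tencent-ailab/IP-Adapter/issues/99, as referenced in the patch.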