add helper for multiple image prompts

This commit is contained in:
Cédric Deltheil 2024-01-29 17:06:56 +00:00 committed by Cédric Deltheil
parent fd01ba910e
commit ca5c5a7ca5

View file

@ -1,9 +1,9 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@ -413,7 +413,53 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
@overload
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
...
@overload
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
...
@overload
def compute_clip_image_embedding(
self, image_prompt: list[Image.Image], weights: list[float] | None = None
) -> Tensor:
...
def compute_clip_image_embedding(
self,
image_prompt: Tensor | Image.Image | list[Image.Image],
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
if isinstance(image_prompt, Image.Image):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
assert all(isinstance(image, Image.Image) for image in image_prompt)
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
batch_size = image_prompt.shape[0]
if weights is not None:
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
if any(weight != 1.0 for weight in weights):
conditional_embedding *= (
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
.unsqueeze(-1)
.unsqueeze(-1)
)
if batch_size > 1 and concat_batches:
# Create a longer image tokens sequence when a batch of images is given
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
return cat((negative_embedding, conditional_embedding))
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
clip_embedding = image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
@ -423,7 +469,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
clip_embedding = image_encoder(zeros_like(image_prompt))
negative_embedding = self.image_proj(clip_embedding)
return cat((negative_embedding, conditional_embedding))
return negative_embedding, conditional_embedding
def preprocess_image(
self,