mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
add helper for multiple image prompts
This commit is contained in:
parent
fd01ba910e
commit
ca5c5a7ca5
|
@ -1,9 +1,9 @@
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
||||||
|
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
|
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -413,7 +413,53 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||||
|
|
||||||
|
@overload
|
||||||
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
|
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def compute_clip_image_embedding(
|
||||||
|
self, image_prompt: list[Image.Image], weights: list[float] | None = None
|
||||||
|
) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
def compute_clip_image_embedding(
|
||||||
|
self,
|
||||||
|
image_prompt: Tensor | Image.Image | list[Image.Image],
|
||||||
|
weights: list[float] | None = None,
|
||||||
|
concat_batches: bool = True,
|
||||||
|
) -> Tensor:
|
||||||
|
if isinstance(image_prompt, Image.Image):
|
||||||
|
image_prompt = self.preprocess_image(image_prompt)
|
||||||
|
elif isinstance(image_prompt, list):
|
||||||
|
assert all(isinstance(image, Image.Image) for image in image_prompt)
|
||||||
|
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
|
||||||
|
|
||||||
|
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
||||||
|
|
||||||
|
batch_size = image_prompt.shape[0]
|
||||||
|
if weights is not None:
|
||||||
|
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
|
||||||
|
if any(weight != 1.0 for weight in weights):
|
||||||
|
conditional_embedding *= (
|
||||||
|
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if batch_size > 1 and concat_batches:
|
||||||
|
# Create a longer image tokens sequence when a batch of images is given
|
||||||
|
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
|
||||||
|
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
|
||||||
|
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
|
||||||
|
|
||||||
|
return cat((negative_embedding, conditional_embedding))
|
||||||
|
|
||||||
|
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||||
clip_embedding = image_encoder(image_prompt)
|
clip_embedding = image_encoder(image_prompt)
|
||||||
conditional_embedding = self.image_proj(clip_embedding)
|
conditional_embedding = self.image_proj(clip_embedding)
|
||||||
|
@ -423,7 +469,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
||||||
clip_embedding = image_encoder(zeros_like(image_prompt))
|
clip_embedding = image_encoder(zeros_like(image_prompt))
|
||||||
negative_embedding = self.image_proj(clip_embedding)
|
negative_embedding = self.image_proj(clip_embedding)
|
||||||
return cat((negative_embedding, conditional_embedding))
|
return negative_embedding, conditional_embedding
|
||||||
|
|
||||||
def preprocess_image(
|
def preprocess_image(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in a new issue