foundationals: add clip image encoder

This commit is contained in:
Cédric Deltheil 2023-08-29 16:33:33 +02:00 committed by Cédric Deltheil
parent 32c1cfdbb1
commit d8004718c8
3 changed files with 235 additions and 50 deletions

View file

@ -0,0 +1,51 @@
from torch import Tensor, arange, device as Device, dtype as DType
import refiners.fluxion.layers as fl
class PositionalEncoder(fl.Chain):
structural_attrs = ["max_sequence_length", "embedding_dim"]
def __init__(
self,
max_sequence_length: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.max_sequence_length = max_sequence_length
self.embedding_dim = embedding_dim
super().__init__(
fl.Lambda(func=self.get_position_ids),
fl.Embedding(
num_embeddings=max_sequence_length,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
)
@property
def position_ids(self) -> Tensor:
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]
class FeedForward(fl.Chain):
structural_attrs = ["embedding_dim", "feedforward_dim"]
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
fl.GeLU(),
fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)

View file

@ -0,0 +1,182 @@
from torch import device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
class ClassEncoder(fl.Chain):
structural_attrs = ["embedding_dim"]
def __init__(
self,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
super().__init__(
fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
)
class PatchEncoder(fl.Chain):
structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int = 16,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.use_bias = use_bias
super().__init__(
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(self.patch_size, self.patch_size),
stride=(self.patch_size, self.patch_size),
use_bias=self.use_bias,
device=device,
dtype=dtype,
),
fl.Permute(0, 2, 3, 1),
)
class TransformerLayer(fl.Chain):
structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
def __init__(
self,
embedding_dim: int = 768,
feedforward_dim: int = 3072,
num_attention_heads: int = 12,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
self.num_attention_heads = num_attention_heads
self.layer_norm_eps = layer_norm_eps
super().__init__(
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.SelfAttention(
embedding_dim=embedding_dim, num_heads=num_attention_heads, device=device, dtype=dtype
),
),
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
),
)
class ViTEmbeddings(fl.Chain):
structural_attrs = ["image_size", "embedding_dim", "patch_size"]
def __init__(
self,
image_size: int = 224,
embedding_dim: int = 768,
patch_size: int = 32,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Concatenate(
ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
fl.Chain(
PatchEncoder(
in_channels=3,
out_channels=embedding_dim,
patch_size=patch_size,
use_bias=False,
device=device,
dtype=dtype,
),
fl.Reshape((image_size // patch_size) ** 2, embedding_dim),
),
dim=1,
),
fl.Residual(
PositionalEncoder(
max_sequence_length=(image_size // patch_size) ** 2 + 1,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
)
class CLIPImageEncoder(fl.Chain):
structural_attrs = [
"image_size",
"embedding_dim",
"patch_size",
"num_layers",
"num_attention_heads",
"feedforward_dim",
]
def __init__(
self,
image_size: int = 224,
embedding_dim: int = 768,
output_dim: int = 512,
patch_size: int = 32,
num_layers: int = 12,
num_attention_heads: int = 12,
feedforward_dim: int = 3072,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.image_size = image_size
self.embedding_dim = embedding_dim
self.patch_size = patch_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
super().__init__(
ViTEmbeddings(
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Chain(
TransformerLayer(
embedding_dim=embedding_dim,
feedforward_dim=feedforward_dim,
num_attention_heads=num_attention_heads,
layer_norm_eps=layer_norm_eps,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.Lambda(func=lambda x: x[:, 0, :]),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)
class CLIPImageEncoderH(CLIPImageEncoder):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1280,
output_dim=1024,
patch_size=14,
num_layers=32,
num_attention_heads=16,
feedforward_dim=5120,
device=device,
dtype=dtype,
)

View file

@ -1,5 +1,6 @@
from torch import Tensor, arange, device as Device, dtype as DType
from torch import device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@ -23,55 +24,6 @@ class TokenEncoder(fl.Embedding):
)
class PositionalEncoder(fl.Chain):
structural_attrs = ["max_sequence_length", "embedding_dim"]
def __init__(
self,
max_sequence_length: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.max_sequence_length = max_sequence_length
self.embedding_dim = embedding_dim
super().__init__(
fl.Lambda(func=self.get_position_ids),
fl.Embedding(
num_embeddings=max_sequence_length,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
)
@property
def position_ids(self) -> Tensor:
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]
class FeedForward(fl.Chain):
structural_attrs = ["embedding_dim", "feedforward_dim"]
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
fl.GeLU(),
fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
class TransformerLayer(fl.Chain):
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]