mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-10 07:21:59 +00:00
foundationals: add clip image encoder
This commit is contained in:
parent
32c1cfdbb1
commit
d8004718c8
51
src/refiners/foundationals/clip/common.py
Normal file
51
src/refiners/foundationals/clip/common.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
from torch import Tensor, arange, device as Device, dtype as DType
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoder(fl.Chain):
|
||||||
|
structural_attrs = ["max_sequence_length", "embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_sequence_length: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.max_sequence_length = max_sequence_length
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
super().__init__(
|
||||||
|
fl.Lambda(func=self.get_position_ids),
|
||||||
|
fl.Embedding(
|
||||||
|
num_embeddings=max_sequence_length,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position_ids(self) -> Tensor:
|
||||||
|
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
||||||
|
|
||||||
|
def get_position_ids(self, x: Tensor) -> Tensor:
|
||||||
|
return self.position_ids[:, : x.shape[1]]
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(fl.Chain):
|
||||||
|
structural_attrs = ["embedding_dim", "feedforward_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
feedforward_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
super().__init__(
|
||||||
|
fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
|
||||||
|
fl.GeLU(),
|
||||||
|
fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
|
)
|
182
src/refiners/foundationals/clip/image_encoder.py
Normal file
182
src/refiners/foundationals/clip/image_encoder.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
|
||||||
|
|
||||||
|
|
||||||
|
class ClassEncoder(fl.Chain):
|
||||||
|
structural_attrs = ["embedding_dim"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
super().__init__(
|
||||||
|
fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
|
||||||
|
fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEncoder(fl.Chain):
|
||||||
|
structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
patch_size: int = 16,
|
||||||
|
use_bias: bool = True,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.use_bias = use_bias
|
||||||
|
super().__init__(
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=self.in_channels,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=(self.patch_size, self.patch_size),
|
||||||
|
stride=(self.patch_size, self.patch_size),
|
||||||
|
use_bias=self.use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.Permute(0, 2, 3, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(fl.Chain):
|
||||||
|
structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int = 768,
|
||||||
|
feedforward_dim: int = 3072,
|
||||||
|
num_attention_heads: int = 12,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
super().__init__(
|
||||||
|
fl.Residual(
|
||||||
|
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||||
|
fl.SelfAttention(
|
||||||
|
embedding_dim=embedding_dim, num_heads=num_attention_heads, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
|
fl.Residual(
|
||||||
|
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||||
|
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ViTEmbeddings(fl.Chain):
|
||||||
|
structural_attrs = ["image_size", "embedding_dim", "patch_size"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: int = 224,
|
||||||
|
embedding_dim: int = 768,
|
||||||
|
patch_size: int = 32,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
fl.Concatenate(
|
||||||
|
ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||||
|
fl.Chain(
|
||||||
|
PatchEncoder(
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=embedding_dim,
|
||||||
|
patch_size=patch_size,
|
||||||
|
use_bias=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.Reshape((image_size // patch_size) ** 2, embedding_dim),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
fl.Residual(
|
||||||
|
PositionalEncoder(
|
||||||
|
max_sequence_length=(image_size // patch_size) ** 2 + 1,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPImageEncoder(fl.Chain):
|
||||||
|
structural_attrs = [
|
||||||
|
"image_size",
|
||||||
|
"embedding_dim",
|
||||||
|
"patch_size",
|
||||||
|
"num_layers",
|
||||||
|
"num_attention_heads",
|
||||||
|
"feedforward_dim",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: int = 224,
|
||||||
|
embedding_dim: int = 768,
|
||||||
|
output_dim: int = 512,
|
||||||
|
patch_size: int = 32,
|
||||||
|
num_layers: int = 12,
|
||||||
|
num_attention_heads: int = 12,
|
||||||
|
feedforward_dim: int = 3072,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.image_size = image_size
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
super().__init__(
|
||||||
|
ViTEmbeddings(
|
||||||
|
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||||
|
fl.Chain(
|
||||||
|
TransformerLayer(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
feedforward_dim=feedforward_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
layer_norm_eps=layer_norm_eps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
),
|
||||||
|
fl.Lambda(func=lambda x: x[:, 0, :]),
|
||||||
|
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||||
|
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPImageEncoderH(CLIPImageEncoder):
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=1280,
|
||||||
|
output_dim=1024,
|
||||||
|
patch_size=14,
|
||||||
|
num_layers=32,
|
||||||
|
num_attention_heads=16,
|
||||||
|
feedforward_dim=5120,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
|
@ -1,5 +1,6 @@
|
||||||
from torch import Tensor, arange, device as Device, dtype as DType
|
from torch import device as Device, dtype as DType
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,55 +24,6 @@ class TokenEncoder(fl.Embedding):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoder(fl.Chain):
|
|
||||||
structural_attrs = ["max_sequence_length", "embedding_dim"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_sequence_length: int,
|
|
||||||
embedding_dim: int,
|
|
||||||
device: Device | str | None = None,
|
|
||||||
dtype: DType | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.max_sequence_length = max_sequence_length
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
super().__init__(
|
|
||||||
fl.Lambda(func=self.get_position_ids),
|
|
||||||
fl.Embedding(
|
|
||||||
num_embeddings=max_sequence_length,
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def position_ids(self) -> Tensor:
|
|
||||||
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
|
|
||||||
|
|
||||||
def get_position_ids(self, x: Tensor) -> Tensor:
|
|
||||||
return self.position_ids[:, : x.shape[1]]
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(fl.Chain):
|
|
||||||
structural_attrs = ["embedding_dim", "feedforward_dim"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embedding_dim: int,
|
|
||||||
feedforward_dim: int,
|
|
||||||
device: Device | str | None = None,
|
|
||||||
dtype: DType | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
self.feedforward_dim = feedforward_dim
|
|
||||||
super().__init__(
|
|
||||||
fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
|
|
||||||
fl.GeLU(),
|
|
||||||
fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerLayer(fl.Chain):
|
class TransformerLayer(fl.Chain):
|
||||||
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
|
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue