from torch import Tensor, arange, device as Device, dtype as DType

import refiners.fluxion.layers as fl


class PositionalEncoder(fl.Chain):
    """Position-embedding lookup for a token sequence.

    Emits one embedding-table row per position of the input, truncated to the
    input's sequence length. Callers add the result to token embeddings
    (typically by wrapping this chain in a Residual).
    """

    structural_attrs = ["max_sequence_length", "embedding_dim"]

    def __init__(
        self,
        max_sequence_length: int,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.max_sequence_length = max_sequence_length
        self.embedding_dim = embedding_dim
        position_table = fl.Embedding(
            num_embeddings=max_sequence_length,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
        )
        super().__init__(fl.Lambda(func=self.get_position_ids), position_table)

    @property
    def position_ids(self) -> Tensor:
        # (1, max_sequence_length) index row, built on demand so it always
        # lives on the chain's current device (self.device comes from fl.Chain).
        return arange(self.max_sequence_length, device=self.device).unsqueeze(0)

    def get_position_ids(self, x: Tensor) -> Tensor:
        # Keep only as many position ids as the batch's sequence length.
        sequence_length = x.shape[1]
        return self.position_ids[:, :sequence_length]


class FeedForward(fl.Chain):
    """Transformer MLP sub-block: Linear -> GeLU -> Linear.

    Expands from embedding_dim to feedforward_dim, then projects back down,
    preserving the input's embedding dimension.
    """

    structural_attrs = ["embedding_dim", "feedforward_dim"]

    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        expand = fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype)
        contract = fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype)
        super().__init__(expand, fl.GeLU(), contract)
from torch import device as Device, dtype as DType

import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward


class ClassEncoder(fl.Chain):
    """Produce a (batch, 1, embedding_dim) class-token embedding.

    The learned parameter is expanded to the incoming batch size; the caller
    concatenates it in front of the patch embeddings.
    """

    structural_attrs = ["embedding_dim"]

    def __init__(
        self,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        super().__init__(
            # Forward both the input (needed only for its batch size) and the parameter.
            fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
            fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
        )


class PatchEncoder(fl.Chain):
    """Split an image into non-overlapping patches and embed each one.

    Implemented as a strided Conv2d (kernel == stride == patch_size) followed
    by a permute to channels-last, yielding (batch, H/p, W/p, out_channels).
    """

    structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 16,
        use_bias: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.use_bias = use_bias
        super().__init__(
            fl.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=(self.patch_size, self.patch_size),
                stride=(self.patch_size, self.patch_size),
                use_bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            # NCHW -> NHWC so the embedding dim ends up last.
            fl.Permute(0, 2, 3, 1),
        )


class TransformerLayer(fl.Chain):
    """Pre-LayerNorm transformer encoder layer.

    Residual(LN -> SelfAttention) followed by Residual(LN -> FeedForward).
    """

    structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]

    def __init__(
        self,
        embedding_dim: int = 768,
        feedforward_dim: int = 3072,
        num_attention_heads: int = 12,
        layer_norm_eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        self.num_attention_heads = num_attention_heads
        self.layer_norm_eps = layer_norm_eps
        super().__init__(
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
                fl.SelfAttention(
                    embedding_dim=embedding_dim, num_heads=num_attention_heads, device=device, dtype=dtype
                ),
            ),
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
                FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
            ),
        )


class ViTEmbeddings(fl.Chain):
    """ViT input embedding: [class token ; patch embeddings] + position embeddings.

    The patch count is (image_size // patch_size) ** 2; the position table has
    one extra row for the class token.
    """

    structural_attrs = ["image_size", "embedding_dim", "patch_size"]

    def __init__(
        self,
        image_size: int = 224,
        embedding_dim: int = 768,
        patch_size: int = 32,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # FIX: these attributes are declared in structural_attrs but were never
        # assigned, unlike every sibling class in this module — accessing a
        # declared structural attribute would have raised AttributeError.
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
        # Hoisted: used both for the flattened patch count and the position table size.
        num_patches = (image_size // patch_size) ** 2
        super().__init__(
            fl.Concatenate(
                ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
                fl.Chain(
                    PatchEncoder(
                        in_channels=3,
                        out_channels=embedding_dim,
                        patch_size=patch_size,
                        use_bias=False,
                        device=device,
                        dtype=dtype,
                    ),
                    # Flatten the (H/p, W/p) grid into a single sequence axis.
                    fl.Reshape(num_patches, embedding_dim),
                ),
                dim=1,
            ),
            fl.Residual(
                PositionalEncoder(
                    max_sequence_length=num_patches + 1,  # +1 for the class token
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class CLIPImageEncoder(fl.Chain):
    """CLIP vision transformer: embeddings, transformer stack, projected class token.

    Output is the class-token representation projected to output_dim
    (bias-free final Linear, as in CLIP's visual projection).
    """

    structural_attrs = [
        "image_size",
        "embedding_dim",
        "patch_size",
        "num_layers",
        "num_attention_heads",
        "feedforward_dim",
    ]

    def __init__(
        self,
        image_size: int = 224,
        embedding_dim: int = 768,
        output_dim: int = 512,
        patch_size: int = 32,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        feedforward_dim: int = 3072,
        layer_norm_eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.feedforward_dim = feedforward_dim
        super().__init__(
            ViTEmbeddings(
                image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
            ),
            # Pre-transformer LayerNorm (CLIP ViT applies one before the stack).
            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
            fl.Chain(
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    feedforward_dim=feedforward_dim,
                    num_attention_heads=num_attention_heads,
                    layer_norm_eps=layer_norm_eps,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ),
            # Keep only the class token (sequence position 0).
            fl.Lambda(func=lambda x: x[:, 0, :]),
            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
            fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
        )


class CLIPImageEncoderH(CLIPImageEncoder):
    """ViT-H/14 configuration of CLIPImageEncoder (image_size keeps the 224 default)."""

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=1280,
            output_dim=1024,
            patch_size=14,
            num_layers=32,
            num_attention_heads=16,
            feedforward_dim=5120,
            device=device,
            dtype=dtype,
        )
-> Tensor: - return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1) - - def get_position_ids(self, x: Tensor) -> Tensor: - return self.position_ids[:, : x.shape[1]] - - -class FeedForward(fl.Chain): - structural_attrs = ["embedding_dim", "feedforward_dim"] - - def __init__( - self, - embedding_dim: int, - feedforward_dim: int, - device: Device | str | None = None, - dtype: DType | None = None, - ) -> None: - self.embedding_dim = embedding_dim - self.feedforward_dim = feedforward_dim - super().__init__( - fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype), - fl.GeLU(), - fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype), - ) - - class TransformerLayer(fl.Chain): structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]