diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py
index 3b8e595..70526df 100644
--- a/src/refiners/fluxion/layers/basics.py
+++ b/src/refiners/fluxion/layers/basics.py
@@ -1,11 +1,26 @@
 import torch
-from torch import Size, Tensor, device as Device, dtype as DType, randn
+from torch import Size, Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
 
 from refiners.fluxion.layers.module import Module, WeightedModule
 
 
 class Identity(Module):
+    """Identity operator layer.
+
+    This layer simply returns the input tensor.
+
+    Example:
+        ```py
+        identity = fl.Identity()
+
+        tensor = torch.randn(10, 10)
+        output = identity(tensor)
+
+        assert torch.allclose(tensor, output)
+        ```
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
@@ -23,6 +38,25 @@ class View(Module):
 
 
 class GetArg(Module):
+    """GetArg operation layer.
+
+    This layer returns the nth tensor of the input arguments.
+
+    Example:
+        ```py
+        get_arg = fl.GetArg(1)
+
+        inputs = (
+            torch.randn(10, 10),
+            torch.randn(20, 20),
+            torch.randn(30, 30),
+        )
+        output = get_arg(*inputs)
+
+        assert torch.allclose(inputs[1], output)
+        ```
+    """
+
     def __init__(self, index: int) -> None:
         super().__init__()
         self.index = index
@@ -32,28 +66,86 @@ class GetArg(Module):
 
 
 class Flatten(Module):
-    def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
+    """Flatten operation layer.
+
+    This layer flattens the input tensor between the given dimensions.
+    See also [`torch.flatten`][torch.flatten].
+
+    Example:
+        ```py
+        flatten = fl.Flatten(start_dim=1)
+
+        tensor = torch.randn(10, 10, 10)
+        output = flatten(tensor)
+
+        assert output.shape == (10, 100)
+        ```
+    """
+
+    def __init__(
+        self,
+        start_dim: int = 0,
+        end_dim: int = -1,
+    ) -> None:
         super().__init__()
         self.start_dim = start_dim
         self.end_dim = end_dim
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.flatten(self.start_dim, self.end_dim)
+        return torch.flatten(
+            input=x,
+            start_dim=self.start_dim,
+            end_dim=self.end_dim,
+        )
 
 
 class Unflatten(Module):
+    """Unflatten operation layer.
+
+    This layer unflattens the input tensor at the given dimension with the given sizes.
+    See also [`torch.unflatten`][torch.unflatten].
+
+    Example:
+        ```py
+        unflatten = fl.Unflatten(dim=1)
+
+        tensor = torch.randn(10, 100)
+        output = unflatten(tensor, sizes=(10, 10))
+
+        assert output.shape == (10, 10, 10)
+        ```
+    """
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
 
     def forward(self, x: Tensor, sizes: Size) -> Tensor:
-        return x.unflatten(self.dim, sizes)  # type: ignore
+        return torch.unflatten(
+            input=x,
+            dim=self.dim,
+            sizes=sizes,
+        )
 
 
 class Reshape(Module):
-    """
-    Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
-    dimension is preserved.
+    """Reshape operation layer.
+
+    This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape).
+    See also [`torch.reshape`][torch.reshape].
+
+    Warning:
+        The first dimension (batch dimension) is forcefully preserved.
+
+    Example:
+        ```py
+        reshape = fl.Reshape(5, 2)
+
+        tensor = torch.randn(2, 10, 1)
+        output = reshape(tensor)
+
+        assert output.shape == (2, 5, 2)
+        ```
     """
 
     def __init__(self, *shape: int) -> None:
@@ -61,30 +153,95 @@
         self.shape = shape
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.reshape(x.shape[0], *self.shape)
+        return torch.reshape(
+            input=x,
+            shape=(x.shape[0], *self.shape),
+        )
 
 
 class Transpose(Module):
+    """Transpose operation layer.
+
+    This layer transposes the input tensor between the two given dimensions.
+    See also [`torch.transpose`][torch.transpose].
+
+    Example:
+        ```py
+        transpose = fl.Transpose(dim0=1, dim1=2)
+
+        tensor = torch.randn(10, 20, 30)
+        output = transpose(tensor)
+
+        assert output.shape == (10, 30, 20)
+        ```
+    """
+
     def __init__(self, dim0: int, dim1: int) -> None:
         super().__init__()
         self.dim0 = dim0
         self.dim1 = dim1
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.transpose(self.dim0, self.dim1)
+        return torch.transpose(
+            input=x,
+            dim0=self.dim0,
+            dim1=self.dim1,
+        )
 
 
 class Permute(Module):
+    """Permute operation layer.
+
+    This layer permutes the input tensor according to the given dimensions.
+    See also [`torch.permute`][torch.permute].
+
+    Example:
+        ```py
+        permute = fl.Permute(2, 0, 1)
+
+        tensor = torch.randn(10, 20, 30)
+        output = permute(tensor)
+
+        assert output.shape == (30, 10, 20)
+        ```
+    """
+
     def __init__(self, *dims: int) -> None:
         super().__init__()
         self.dims = dims
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.permute(*self.dims)
+        return torch.permute(
+            input=x,
+            dims=self.dims,
+        )
 
 
 class Slicing(Module):
-    def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
+    """Slicing operation layer.
+
+    This layer slices the input tensor at the given dimension between the given start and end indices.
+    See also [`torch.index_select`][torch.index_select].
+
+    Example:
+        ```py
+        slicing = fl.Slicing(dim=1, start=50)
+
+        tensor = torch.randn(10, 100)
+        output = slicing(tensor)
+
+        assert output.shape == (10, 50)
+        assert torch.allclose(output, tensor[:, 50:])
+        ```
+    """
+
+    def __init__(
+        self,
+        dim: int = 0,
+        start: int = 0,
+        end: int | None = None,
+        step: int = 1,
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.start = start
@@ -93,55 +250,162 @@
     def forward(self, x: Tensor) -> Tensor:
         dim_size = x.shape[self.dim]
+
+        # compute start index
         start = self.start if self.start >= 0 else dim_size + self.start
+        start = max(min(start, dim_size), 0)
+
+        # compute end index
         end = self.end or dim_size
         end = end if end >= 0 else dim_size + end
-        start = max(min(start, dim_size), 0)
         end = max(min(end, dim_size), 0)
-        if start >= end:
-            return self.get_empty_slice(x)
-        indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
-        return x.index_select(self.dim, indices)
 
-    def get_empty_slice(self, x: Tensor) -> Tensor:
-        """
-        Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
- """ + if start >= end: + return self._get_empty_slice(x) + + # compute indices + indices = torch.arange( + start=start, + end=end, + step=self.step, + device=x.device, + ) + + return torch.index_select( + input=x, + dim=self.dim, + index=indices, + ) + + def _get_empty_slice(self, x: Tensor) -> Tensor: + """Get an empty slice of the same shape as the input tensor (to mimic PyTorch's slicing behavior).""" + shape = list(x.shape) shape[self.dim] = 0 return torch.empty(*shape, device=x.device) class Squeeze(Module): + """Squeeze operation layer. + + This layer squeezes the input tensor at the given dimension. + See also [`torch.squeeze`][torch.squeeze]. + + Example: + ```py + squeeze = fl.Squeeze(dim=1) + + tensor = torch.randn(10, 1, 10) + output = squeeze(tensor) + + assert output.shape == (10, 10) + ``` + """ + def __init__(self, dim: int) -> None: super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: - return x.squeeze(self.dim) + return torch.squeeze( + input=x, + dim=self.dim, + ) class Unsqueeze(Module): + """Unsqueeze operation layer. + + This layer unsqueezes the input tensor at the given dimension. + See also [`torch.unsqueeze`][torch.unsqueeze]. + + Example: + ```py + unsqueeze = fl.Unsqueeze(dim=1) + + tensor = torch.randn(10, 10) + output = unsqueeze(tensor) + + assert output.shape == (10, 1, 10) + ``` + """ + def __init__(self, dim: int) -> None: super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: - return x.unsqueeze(self.dim) + return torch.unsqueeze( + input=x, + dim=self.dim, + ) class Sin(Module): + """Sine operator layer. + + This layer applies the sine function to the input tensor. + See also [`torch.sin`][torch.sin]. + + Example: + ```py + sin = fl.Sin() + + tensor = torch.tensor([0, torch.pi]) + output = sin(tensor) + + expected_output = torch.tensor([0.0, 0.0]) + assert torch.allclose(output, expected_output, atol=1e-6) + ``` + """ + def forward(self, x: Tensor) -> Tensor: return torch.sin(input=x) class Cos(Module): + """Cosine operator layer. + + This layer applies the cosine function to the input tensor. + See also [`torch.cos`][torch.cos]. + + Example: + ```py + cos = fl.Cos() + + tensor = torch.tensor([0, torch.pi]) + output = cos(tensor) + + expected_output = torch.tensor([1.0, -1.0]) + assert torch.allclose(output, expected_output, atol=1e-6) + ``` + """ + def forward(self, x: Tensor) -> Tensor: return torch.cos(input=x) class Multiply(Module): - def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None: + """Multiply operator layer. + + This layer scales and shifts the input tensor by the given scale and bias. + + Example: + ```py + multiply = fl.Multiply(scale=2, bias=1) + + tensor = torch.ones(1) + output = multiply(tensor) + + assert torch.allclose(output, torch.tensor([3.0])) + ``` + """ + + def __init__( + self, + scale: float = 1.0, + bias: float = 0.0, + ) -> None: super().__init__() self.scale = scale self.bias = bias @@ -151,14 +415,40 @@ class Multiply(Module): class Parameter(WeightedModule): - """ - A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias. + """Parameter layer. + + This layer simple wraps a PyTorch [`Parameter`][torch.nn.parameter.Parameter]. + When called, it simply returns the [`Parameter`][torch.nn.parameter.Parameter] Tensor. + + Attributes: + weight (torch.nn.parameter.Parameter): The parameter Tensor. 
""" - def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None: + def __init__( + self, + *dims: int, + requires_grad: bool = True, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: super().__init__() self.dims = dims - self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype)) + self.weight = TorchParameter( + requires_grad=requires_grad, + data=torch.randn( + *dims, + device=device, + dtype=dtype, + ), + ) def forward(self, x: Tensor) -> Tensor: return self.weight.expand(x.shape[0], *self.dims) + + @property + def requires_grad(self) -> bool: + return self.weight.requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + self.weight.requires_grad = value