(doc/fluxion/basic) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:08:15 +00:00 committed by Laureηt
parent a7c048f5fb
commit 77b97b3c8e

View file

@ -1,11 +1,26 @@
import torch import torch
from torch import Size, Tensor, device as Device, dtype as DType, randn from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter from torch.nn import Parameter as TorchParameter
from refiners.fluxion.layers.module import Module, WeightedModule from refiners.fluxion.layers.module import Module, WeightedModule
class Identity(Module): class Identity(Module):
"""Identity operator layer.
This layer simply returns the input tensor.
Example:
```py
identity = fl.Identity()
tensor = torch.randn(10, 10)
output = identity(tensor)
assert torch.allclose(tensor, output)
```
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -23,6 +38,25 @@ class View(Module):
class GetArg(Module): class GetArg(Module):
"""GetArg operation layer.
This layer returns the nth tensor of the input arguments.
Example:
```py
get_arg = fl.GetArg(1)
inputs = (
torch.randn(10, 10),
torch.randn(20, 20),
torch.randn(30, 30),
)
output = get_arg(inputs)
assert torch.allclose(tensor[1], output)
```
"""
def __init__(self, index: int) -> None: def __init__(self, index: int) -> None:
super().__init__() super().__init__()
self.index = index self.index = index
@ -32,28 +66,86 @@ class GetArg(Module):
class Flatten(Module): class Flatten(Module):
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None: """Flatten operation layer.
This layer flattens the input tensor between the given dimensions.
See also [`torch.flatten`][torch.flatten].
Example:
```py
flatten = fl.Flatten(start_dim=1)
tensor = torch.randn(10, 10, 10)
output = flatten(tensor)
assert output.shape == (10, 100)
```
"""
def __init__(
self,
start_dim: int = 0,
end_dim: int = -1,
) -> None:
super().__init__() super().__init__()
self.start_dim = start_dim self.start_dim = start_dim
self.end_dim = end_dim self.end_dim = end_dim
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.flatten(self.start_dim, self.end_dim) return torch.flatten(
input=x,
start_dim=self.start_dim,
end_dim=self.end_dim,
)
class Unflatten(Module): class Unflatten(Module):
"""Unflatten operation layer.
This layer unflattens the input tensor at the given dimension with the given sizes.
See also [`torch.unflatten`][torch.unflatten].
Example:
```py
unflatten = fl.Unflatten(dim=1)
tensor = torch.randn(10, 100)
output = unflatten(tensor, sizes=(10, 10))
assert output_unflatten.shape == (10, 10, 10)
```
"""
def __init__(self, dim: int) -> None: def __init__(self, dim: int) -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
def forward(self, x: Tensor, sizes: Size) -> Tensor: def forward(self, x: Tensor, sizes: Size) -> Tensor:
return x.unflatten(self.dim, sizes) # type: ignore return torch.unflatten(
input=x,
dim=self.dim,
sizes=sizes,
)
class Reshape(Module): class Reshape(Module):
""" """Reshape operation layer.
Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
dimension is preserved. This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape).
See also [torch.reshape][torch.reshape].
Warning:
The first dimension (batch dimension) is forcefully preserved.
Example:
```py
reshape = fl.Reshape(5, 2)
tensor = torch.randn(2, 10, 1)
output = reshape(tensor)
assert output.shape == (2, 5, 2)
```
""" """
def __init__(self, *shape: int) -> None: def __init__(self, *shape: int) -> None:
@ -61,30 +153,95 @@ class Reshape(Module):
self.shape = shape self.shape = shape
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.reshape(x.shape[0], *self.shape) return torch.reshape(
input=x,
shape=(x.shape[0], *self.shape),
)
class Transpose(Module): class Transpose(Module):
"""Transpose operation layer.
This layer transposes the input tensor between the two given dimensions.
See also [`torch.transpose`][torch.transpose].
Example:
```py
transpose = fl.Transpose(dim0=1, dim1=2)
tensor = torch.randn(10, 20, 30)
output = transpose(tensor)
assert output.shape == (10, 30, 20)
```
"""
def __init__(self, dim0: int, dim1: int) -> None: def __init__(self, dim0: int, dim1: int) -> None:
super().__init__() super().__init__()
self.dim0 = dim0 self.dim0 = dim0
self.dim1 = dim1 self.dim1 = dim1
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.transpose(self.dim0, self.dim1) return torch.transpose(
input=x,
dim0=self.dim0,
dim1=self.dim1,
)
class Permute(Module): class Permute(Module):
"""Permute operation layer.
This layer permutes the input tensor according to the given dimensions.
See also [`torch.permute`][torch.permute].
Example:
```py
permute = fl.Permute(2, 0, 1)
tensor = torch.randn(10, 20, 30)
output = permute(tensor)
assert output.shape == (30, 10, 20)
```
"""
def __init__(self, *dims: int) -> None: def __init__(self, *dims: int) -> None:
super().__init__() super().__init__()
self.dims = dims self.dims = dims
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.permute(*self.dims) return torch.permute(
input=x,
dims=self.dims,
)
class Slicing(Module): class Slicing(Module):
def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None: """Slicing operation layer.
This layer slices the input tensor at the given dimension between the given start and end indices.
See also [`torch.index_select`][torch.index_select].
Example:
```py
slicing = fl.Slicing(dim=1, start=50)
tensor = torch.randn(10, 100)
output = slicing(tensor)
assert output.shape == (10, 50)
assert torch.allclose(output, tensor[:, 50:])
```
"""
def __init__(
self,
dim: int = 0,
start: int = 0,
end: int | None = None,
step: int = 1,
) -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.start = start self.start = start
@ -93,55 +250,162 @@ class Slicing(Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
dim_size = x.shape[self.dim] dim_size = x.shape[self.dim]
# compute start index
start = self.start if self.start >= 0 else dim_size + self.start start = self.start if self.start >= 0 else dim_size + self.start
start = max(min(start, dim_size), 0)
# compute end index
end = self.end or dim_size end = self.end or dim_size
end = end if end >= 0 else dim_size + end end = end if end >= 0 else dim_size + end
start = max(min(start, dim_size), 0)
end = max(min(end, dim_size), 0) end = max(min(end, dim_size), 0)
if start >= end:
return self.get_empty_slice(x)
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
return x.index_select(self.dim, indices)
def get_empty_slice(self, x: Tensor) -> Tensor: if start >= end:
""" return self._get_empty_slice(x)
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
""" # compute indices
indices = torch.arange(
start=start,
end=end,
step=self.step,
device=x.device,
)
return torch.index_select(
input=x,
dim=self.dim,
index=indices,
)
def _get_empty_slice(self, x: Tensor) -> Tensor:
"""Get an empty slice of the same shape as the input tensor (to mimic PyTorch's slicing behavior)."""
shape = list(x.shape) shape = list(x.shape)
shape[self.dim] = 0 shape[self.dim] = 0
return torch.empty(*shape, device=x.device) return torch.empty(*shape, device=x.device)
class Squeeze(Module): class Squeeze(Module):
"""Squeeze operation layer.
This layer squeezes the input tensor at the given dimension.
See also [`torch.squeeze`][torch.squeeze].
Example:
```py
squeeze = fl.Squeeze(dim=1)
tensor = torch.randn(10, 1, 10)
output = squeeze(tensor)
assert output.shape == (10, 10)
```
"""
def __init__(self, dim: int) -> None: def __init__(self, dim: int) -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.squeeze(self.dim) return torch.squeeze(
input=x,
dim=self.dim,
)
class Unsqueeze(Module): class Unsqueeze(Module):
"""Unsqueeze operation layer.
This layer unsqueezes the input tensor at the given dimension.
See also [`torch.unsqueeze`][torch.unsqueeze].
Example:
```py
unsqueeze = fl.Unsqueeze(dim=1)
tensor = torch.randn(10, 10)
output = unsqueeze(tensor)
assert output.shape == (10, 1, 10)
```
"""
def __init__(self, dim: int) -> None: def __init__(self, dim: int) -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.unsqueeze(self.dim) return torch.unsqueeze(
input=x,
dim=self.dim,
)
class Sin(Module): class Sin(Module):
"""Sine operator layer.
This layer applies the sine function to the input tensor.
See also [`torch.sin`][torch.sin].
Example:
```py
sin = fl.Sin()
tensor = torch.tensor([0, torch.pi])
output = sin(tensor)
expected_output = torch.tensor([0.0, 0.0])
assert torch.allclose(output, expected_output, atol=1e-6)
```
"""
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return torch.sin(input=x) return torch.sin(input=x)
class Cos(Module): class Cos(Module):
"""Cosine operator layer.
This layer applies the cosine function to the input tensor.
See also [`torch.cos`][torch.cos].
Example:
```py
cos = fl.Cos()
tensor = torch.tensor([0, torch.pi])
output = cos(tensor)
expected_output = torch.tensor([1.0, -1.0])
assert torch.allclose(output, expected_output, atol=1e-6)
```
"""
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return torch.cos(input=x) return torch.cos(input=x)
class Multiply(Module): class Multiply(Module):
def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None: """Multiply operator layer.
This layer scales and shifts the input tensor by the given scale and bias.
Example:
```py
multiply = fl.Multiply(scale=2, bias=1)
tensor = torch.ones(1)
output = multiply(tensor)
assert torch.allclose(output, torch.tensor([3.0]))
```
"""
def __init__(
self,
scale: float = 1.0,
bias: float = 0.0,
) -> None:
super().__init__() super().__init__()
self.scale = scale self.scale = scale
self.bias = bias self.bias = bias
@ -151,14 +415,40 @@ class Multiply(Module):
class Parameter(WeightedModule): class Parameter(WeightedModule):
""" """Parameter layer.
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
This layer simple wraps a PyTorch [`Parameter`][torch.nn.parameter.Parameter].
When called, it simply returns the [`Parameter`][torch.nn.parameter.Parameter] Tensor.
Attributes:
weight (torch.nn.parameter.Parameter): The parameter Tensor.
""" """
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None: def __init__(
self,
*dims: int,
requires_grad: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype)) self.weight = TorchParameter(
requires_grad=requires_grad,
data=torch.randn(
*dims,
device=device,
dtype=dtype,
),
)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return self.weight.expand(x.shape[0], *self.dims) return self.weight.expand(x.shape[0], *self.dims)
@property
def requires_grad(self) -> bool:
return self.weight.requires_grad
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
self.weight.requires_grad = value