(doc/fluxion/basic) add/convert docstrings to mkdocstrings format

Laurent 2024-02-01 22:08:15 +00:00 committed by Laureηt
parent a7c048f5fb
commit 77b97b3c8e


@@ -1,11 +1,26 @@
import torch
from torch import Size, Tensor, device as Device, dtype as DType, randn
from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from refiners.fluxion.layers.module import Module, WeightedModule
class Identity(Module):
"""Identity operator layer.
This layer simply returns the input tensor.
Example:
```py
identity = fl.Identity()
tensor = torch.randn(10, 10)
output = identity(tensor)
assert torch.allclose(tensor, output)
```
"""
def __init__(self) -> None:
super().__init__()
@@ -23,6 +38,25 @@ class View(Module):
class GetArg(Module):
"""GetArg operation layer.
This layer returns the nth tensor of the input arguments.
Example:
```py
get_arg = fl.GetArg(1)
inputs = (
torch.randn(10, 10),
torch.randn(20, 20),
torch.randn(30, 30),
)
output = get_arg(inputs)
assert torch.allclose(inputs[1], output)
```
"""
def __init__(self, index: int) -> None:
super().__init__()
self.index = index
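The `forward` pass falls inside the hunk elided below; per the docstring it just indexes into the positional arguments. A minimal sketch, assuming that behavior (hypothetical, not the file's verbatim code):

```py
def forward(self, *args: Tensor) -> Tensor:
    # return the index-th positional argument, untouched
    return args[self.index]
```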
@@ -32,28 +66,86 @@ class GetArg(Module):
class Flatten(Module):
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
"""Flatten operation layer.
This layer flattens the input tensor between the given dimensions.
See also [`torch.flatten`][torch.flatten].
Example:
```py
flatten = fl.Flatten(start_dim=1)
tensor = torch.randn(10, 10, 10)
output = flatten(tensor)
assert output.shape == (10, 100)
```
"""
def __init__(
self,
start_dim: int = 0,
end_dim: int = -1,
) -> None:
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, x: Tensor) -> Tensor:
return x.flatten(self.start_dim, self.end_dim)
return torch.flatten(
input=x,
start_dim=self.start_dim,
end_dim=self.end_dim,
)
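Note that with the defaults (`start_dim=0`, `end_dim=-1`) every dimension is flattened, batch included; a small usage sketch:

```py
flatten = fl.Flatten()  # default bounds flatten all dimensions
tensor = torch.randn(2, 3, 4)
output = flatten(tensor)
assert output.shape == (24,)
```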
class Unflatten(Module):
"""Unflatten operation layer.
This layer unflattens the input tensor at the given dimension with the given sizes.
See also [`torch.unflatten`][torch.unflatten].
Example:
```py
unflatten = fl.Unflatten(dim=1)
tensor = torch.randn(10, 100)
output = unflatten(tensor, sizes=(10, 10))
assert output.shape == (10, 10, 10)
```
"""
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor, sizes: Size) -> Tensor:
return x.unflatten(self.dim, sizes) # type: ignore
return torch.unflatten(
input=x,
dim=self.dim,
sizes=sizes,
)
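As with `torch.unflatten`, one entry of `sizes` may be `-1` and is then inferred from the remaining elements; a hedged usage sketch:

```py
unflatten = fl.Unflatten(dim=1)
tensor = torch.randn(10, 100)
output = unflatten(tensor, sizes=(4, -1))  # -1 is inferred as 25
assert output.shape == (10, 4, 25)
```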
class Reshape(Module):
"""
Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
dimension is preserved.
"""Reshape operation layer.
This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape).
See also [`torch.reshape`][torch.reshape].
Warning:
The first dimension (batch dimension) is forcefully preserved.
Example:
```py
reshape = fl.Reshape(5, 2)
tensor = torch.randn(2, 10, 1)
output = reshape(tensor)
assert output.shape == (2, 5, 2)
```
"""
def __init__(self, *shape: int) -> None:
@@ -61,30 +153,95 @@ class Reshape(Module):
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.reshape(x.shape[0], *self.shape)
return torch.reshape(
input=x,
shape=(x.shape[0], *self.shape),
)
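Because the batch dimension is re-attached before the call to `torch.reshape`, a `-1` in the target shape only spans the per-sample elements; a small sketch of that behavior:

```py
reshape = fl.Reshape(-1)  # flattens everything except the batch dimension
tensor = torch.randn(2, 5, 4)
output = reshape(tensor)
assert output.shape == (2, 20)
```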
class Transpose(Module):
"""Transpose operation layer.
This layer transposes the input tensor between the two given dimensions.
See also [`torch.transpose`][torch.transpose].
Example:
```py
transpose = fl.Transpose(dim0=1, dim1=2)
tensor = torch.randn(10, 20, 30)
output = transpose(tensor)
assert output.shape == (10, 30, 20)
```
"""
def __init__(self, dim0: int, dim1: int) -> None:
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: Tensor) -> Tensor:
return x.transpose(self.dim0, self.dim1)
return torch.transpose(
input=x,
dim0=self.dim0,
dim1=self.dim1,
)
class Permute(Module):
"""Permute operation layer.
This layer permutes the input tensor according to the given dimensions.
See also [`torch.permute`][torch.permute].
Example:
```py
permute = fl.Permute(2, 0, 1)
tensor = torch.randn(10, 20, 30)
output = permute(tensor)
assert output.shape == (30, 10, 20)
```
"""
def __init__(self, *dims: int) -> None:
super().__init__()
self.dims = dims
def forward(self, x: Tensor) -> Tensor:
return x.permute(*self.dims)
return torch.permute(
input=x,
dims=self.dims,
)
class Slicing(Module):
def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
"""Slicing operation layer.
This layer slices the input tensor at the given dimension between the given start and end indices.
See also [`torch.index_select`][torch.index_select].
Example:
```py
slicing = fl.Slicing(dim=1, start=50)
tensor = torch.randn(10, 100)
output = slicing(tensor)
assert output.shape == (10, 50)
assert torch.allclose(output, tensor[:, 50:])
```
"""
def __init__(
self,
dim: int = 0,
start: int = 0,
end: int | None = None,
step: int = 1,
) -> None:
super().__init__()
self.dim = dim
self.start = start
@@ -93,55 +250,162 @@ class Slicing(Module):
def forward(self, x: Tensor) -> Tensor:
dim_size = x.shape[self.dim]
# compute start index
start = self.start if self.start >= 0 else dim_size + self.start
start = max(min(start, dim_size), 0)
# compute end index
end = self.end or dim_size
end = end if end >= 0 else dim_size + end
start = max(min(start, dim_size), 0)
end = max(min(end, dim_size), 0)
if start >= end:
return self.get_empty_slice(x)
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
return x.index_select(self.dim, indices)
def get_empty_slice(self, x: Tensor) -> Tensor:
"""
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
"""
if start >= end:
return self._get_empty_slice(x)
# compute indices
indices = torch.arange(
start=start,
end=end,
step=self.step,
device=x.device,
)
return torch.index_select(
input=x,
dim=self.dim,
index=indices,
)
def _get_empty_slice(self, x: Tensor) -> Tensor:
"""Get an empty slice of the same shape as the input tensor (to mimic PyTorch's slicing behavior)."""
shape = list(x.shape)
shape[self.dim] = 0
return torch.empty(*shape, device=x.device)
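Following the index arithmetic above, negative bounds count from the end of the dimension and an inverted range yields an empty slice, mirroring plain tensor slicing; a usage sketch:

```py
slicing = fl.Slicing(dim=1, start=-3)
tensor = torch.randn(2, 5)
output = slicing(tensor)
assert output.shape == (2, 3)
assert torch.allclose(output, tensor[:, -3:])

empty = fl.Slicing(dim=1, start=4, end=2)(tensor)  # start >= end
assert empty.shape == (2, 0)
```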
class Squeeze(Module):
"""Squeeze operation layer.
This layer squeezes the input tensor at the given dimension.
See also [`torch.squeeze`][torch.squeeze].
Example:
```py
squeeze = fl.Squeeze(dim=1)
tensor = torch.randn(10, 1, 10)
output = squeeze(tensor)
assert output.shape == (10, 10)
```
"""
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.squeeze(self.dim)
return torch.squeeze(
input=x,
dim=self.dim,
)
class Unsqueeze(Module):
"""Unsqueeze operation layer.
This layer unsqueezes the input tensor at the given dimension.
See also [`torch.unsqueeze`][torch.unsqueeze].
Example:
```py
unsqueeze = fl.Unsqueeze(dim=1)
tensor = torch.randn(10, 10)
output = unsqueeze(tensor)
assert output.shape == (10, 1, 10)
```
"""
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.unsqueeze(self.dim)
return torch.unsqueeze(
input=x,
dim=self.dim,
)
class Sin(Module):
"""Sine operator layer.
This layer applies the sine function to the input tensor.
See also [`torch.sin`][torch.sin].
Example:
```py
sin = fl.Sin()
tensor = torch.tensor([0, torch.pi])
output = sin(tensor)
expected_output = torch.tensor([0.0, 0.0])
assert torch.allclose(output, expected_output, atol=1e-6)
```
"""
def forward(self, x: Tensor) -> Tensor:
return torch.sin(input=x)
class Cos(Module):
"""Cosine operator layer.
This layer applies the cosine function to the input tensor.
See also [`torch.cos`][torch.cos].
Example:
```py
cos = fl.Cos()
tensor = torch.tensor([0, torch.pi])
output = cos(tensor)
expected_output = torch.tensor([1.0, -1.0])
assert torch.allclose(output, expected_output, atol=1e-6)
```
"""
def forward(self, x: Tensor) -> Tensor:
return torch.cos(input=x)
class Multiply(Module):
def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
"""Multiply operator layer.
This layer scales and shifts the input tensor by the given scale and bias.
Example:
```py
multiply = fl.Multiply(scale=2, bias=1)
tensor = torch.ones(1)
output = multiply(tensor)
assert torch.allclose(output, torch.tensor([3.0]))
```
"""
def __init__(
self,
scale: float = 1.0,
bias: float = 0.0,
) -> None:
super().__init__()
self.scale = scale
self.bias = bias
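The `forward` pass falls inside the elided hunk below; per the docstring it scales then shifts, so it presumably amounts to the following (hypothetical sketch, not the file's verbatim code):

```py
def forward(self, x: Tensor) -> Tensor:
    # scale then shift the input tensor
    return self.scale * x + self.bias
```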
@@ -151,14 +415,40 @@ class Multiply(Module):
class Parameter(WeightedModule):
"""
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
"""Parameter layer.
This layer wraps a PyTorch [`Parameter`][torch.nn.parameter.Parameter].
When called, it returns the [`Parameter`][torch.nn.parameter.Parameter] Tensor expanded to the batch size of the input.
Attributes:
weight (torch.nn.parameter.Parameter): The parameter Tensor.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
def __init__(
self,
*dims: int,
requires_grad: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.dims = dims
self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
self.weight = TorchParameter(
requires_grad=requires_grad,
data=torch.randn(
*dims,
device=device,
dtype=dtype,
),
)
def forward(self, x: Tensor) -> Tensor:
return self.weight.expand(x.shape[0], *self.dims)
@property
def requires_grad(self) -> bool:
return self.weight.requires_grad
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
self.weight.requires_grad = value
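Since `forward` expands the weight along the batch dimension of its input, a hedged usage sketch:

```py
param = fl.Parameter(16, 32)
tensor = torch.randn(4, 3)  # only the input's batch size matters
output = param(tensor)
assert output.shape == (4, 16, 32)
```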