mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00

(doc/fluxion/basic) add/convert docstrings to mkdocstrings format

This commit is contained in:
parent a7c048f5fb
commit 77b97b3c8e

@@ -1,11 +1,26 @@
 import torch
-from torch import Size, Tensor, device as Device, dtype as DType, randn
+from torch import Size, Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
 
 from refiners.fluxion.layers.module import Module, WeightedModule
 
 
 class Identity(Module):
+    """Identity operator layer.
+
+    This layer simply returns the input tensor.
+
+    Example:
+        ```py
+        identity = fl.Identity()
+
+        tensor = torch.randn(10, 10)
+        output = identity(tensor)
+
+        assert torch.allclose(tensor, output)
+        ```
+    """
+
     def __init__(self) -> None:
         super().__init__()
 
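Note: the new docstring examples reference an `fl` namespace that the snippets themselves never define. Judging by the refiners documentation conventions, they presumably assume a prelude along these lines (an assumption, not part of this commit):

```py
# Assumed prelude for the docstring examples (not shown in this diff):
import torch
import refiners.fluxion.layers as fl
```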
@@ -23,6 +38,25 @@ class View(Module):
 
 
 class GetArg(Module):
+    """GetArg operation layer.
+
+    This layer returns the nth tensor of the input arguments.
+
+    Example:
+        ```py
+        get_arg = fl.GetArg(1)
+
+        inputs = (
+            torch.randn(10, 10),
+            torch.randn(20, 20),
+            torch.randn(30, 30),
+        )
+        output = get_arg(*inputs)
+
+        assert torch.allclose(inputs[1], output)
+        ```
+    """
+
     def __init__(self, index: int) -> None:
         super().__init__()
         self.index = index
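`GetArg.forward` lies outside the hunks shown here, so the new example leans on behavior defined elsewhere in the file. A minimal standalone sketch of what that indexing presumably looks like (`GetArgSketch` is a hypothetical stand-in, not the real class, which derives from refiners' `Module`):

```py
import torch

class GetArgSketch:
    """Hypothetical stand-in mimicking only GetArg's indexing logic."""

    def __init__(self, index: int) -> None:
        self.index = index

    def __call__(self, *args: torch.Tensor) -> torch.Tensor:
        # return the nth positional argument, as the docstring describes
        return args[self.index]

get_arg = GetArgSketch(1)
inputs = (torch.randn(10, 10), torch.randn(20, 20), torch.randn(30, 30))
assert get_arg(*inputs) is inputs[1]
```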
@@ -32,28 +66,86 @@ class GetArg(Module):
 
 
 class Flatten(Module):
-    def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
+    """Flatten operation layer.
+
+    This layer flattens the input tensor between the given dimensions.
+    See also [`torch.flatten`][torch.flatten].
+
+    Example:
+        ```py
+        flatten = fl.Flatten(start_dim=1)
+
+        tensor = torch.randn(10, 10, 10)
+        output = flatten(tensor)
+
+        assert output.shape == (10, 100)
+        ```
+    """
+
+    def __init__(
+        self,
+        start_dim: int = 0,
+        end_dim: int = -1,
+    ) -> None:
         super().__init__()
         self.start_dim = start_dim
         self.end_dim = end_dim
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.flatten(self.start_dim, self.end_dim)
+        return torch.flatten(
+            input=x,
+            start_dim=self.start_dim,
+            end_dim=self.end_dim,
+        )
 
 
 class Unflatten(Module):
+    """Unflatten operation layer.
+
+    This layer unflattens the input tensor at the given dimension with the given sizes.
+    See also [`torch.unflatten`][torch.unflatten].
+
+    Example:
+        ```py
+        unflatten = fl.Unflatten(dim=1)
+
+        tensor = torch.randn(10, 100)
+        output = unflatten(tensor, sizes=(10, 10))
+
+        assert output.shape == (10, 10, 10)
+        ```
+    """
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
 
     def forward(self, x: Tensor, sizes: Size) -> Tensor:
-        return x.unflatten(self.dim, sizes)  # type: ignore
+        return torch.unflatten(
+            input=x,
+            dim=self.dim,
+            sizes=sizes,
+        )
 
 
 class Reshape(Module):
-    """
-    Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
-    dimension is preserved.
+    """Reshape operation layer.
+
+    This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape).
+    See also [`torch.reshape`][torch.reshape].
+
+    Warning:
+        The first dimension (batch dimension) is forcefully preserved.
+
+    Example:
+        ```py
+        reshape = fl.Reshape(5, 2)
+
+        tensor = torch.randn(2, 10, 1)
+        output = reshape(tensor)
+
+        assert output.shape == (2, 5, 2)
+        ```
     """
 
     def __init__(self, *shape: int) -> None:
@@ -61,30 +153,95 @@ class Reshape(Module):
         self.shape = shape
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.reshape(x.shape[0], *self.shape)
+        return torch.reshape(
+            input=x,
+            shape=(x.shape[0], *self.shape),
+        )
 
 
 class Transpose(Module):
+    """Transpose operation layer.
+
+    This layer transposes the input tensor between the two given dimensions.
+    See also [`torch.transpose`][torch.transpose].
+
+    Example:
+        ```py
+        transpose = fl.Transpose(dim0=1, dim1=2)
+
+        tensor = torch.randn(10, 20, 30)
+        output = transpose(tensor)
+
+        assert output.shape == (10, 30, 20)
+        ```
+    """
+
     def __init__(self, dim0: int, dim1: int) -> None:
         super().__init__()
         self.dim0 = dim0
         self.dim1 = dim1
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.transpose(self.dim0, self.dim1)
+        return torch.transpose(
+            input=x,
+            dim0=self.dim0,
+            dim1=self.dim1,
+        )
 
 
 class Permute(Module):
+    """Permute operation layer.
+
+    This layer permutes the input tensor according to the given dimensions.
+    See also [`torch.permute`][torch.permute].
+
+    Example:
+        ```py
+        permute = fl.Permute(2, 0, 1)
+
+        tensor = torch.randn(10, 20, 30)
+        output = permute(tensor)
+
+        assert output.shape == (30, 10, 20)
+        ```
+    """
+
     def __init__(self, *dims: int) -> None:
         super().__init__()
         self.dims = dims
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.permute(*self.dims)
+        return torch.permute(
+            input=x,
+            dims=self.dims,
+        )
 
 
 class Slicing(Module):
-    def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
+    """Slicing operation layer.
+
+    This layer slices the input tensor at the given dimension between the given start and end indices.
+    See also [`torch.index_select`][torch.index_select].
+
+    Example:
+        ```py
+        slicing = fl.Slicing(dim=1, start=50)
+
+        tensor = torch.randn(10, 100)
+        output = slicing(tensor)
+
+        assert output.shape == (10, 50)
+        assert torch.allclose(output, tensor[:, 50:])
+        ```
+    """
+
+    def __init__(
+        self,
+        dim: int = 0,
+        start: int = 0,
+        end: int | None = None,
+        step: int = 1,
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.start = start
@@ -93,55 +250,162 @@ class Slicing(Module):
 
     def forward(self, x: Tensor) -> Tensor:
         dim_size = x.shape[self.dim]
+
+        # compute start index
         start = self.start if self.start >= 0 else dim_size + self.start
+        start = max(min(start, dim_size), 0)
+
+        # compute end index
         end = self.end or dim_size
         end = end if end >= 0 else dim_size + end
-        start = max(min(start, dim_size), 0)
         end = max(min(end, dim_size), 0)
-        if start >= end:
-            return self.get_empty_slice(x)
-        indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
-        return x.index_select(self.dim, indices)
-
-    def get_empty_slice(self, x: Tensor) -> Tensor:
-        """
-        Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
-        """
+
+        if start >= end:
+            return self._get_empty_slice(x)
+
+        # compute indices
+        indices = torch.arange(
+            start=start,
+            end=end,
+            step=self.step,
+            device=x.device,
+        )
+
+        return torch.index_select(
+            input=x,
+            dim=self.dim,
+            index=indices,
+        )
+
+    def _get_empty_slice(self, x: Tensor) -> Tensor:
+        """Get an empty slice of the same shape as the input tensor (to mimic PyTorch's slicing behavior)."""
         shape = list(x.shape)
         shape[self.dim] = 0
         return torch.empty(*shape, device=x.device)
 
 
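The rewritten `forward` above clamps both indices into `[0, dim_size]` before the emptiness check, which is what lets it mirror Python's native slicing for negative and out-of-range values. A quick equivalence check under that reading (a sketch assuming `dim=0`; `slice_like_layer` is a hypothetical helper, not refiners API):

```py
import torch

def slice_like_layer(x: torch.Tensor, start: int, end: int | None, step: int) -> torch.Tensor:
    # same index arithmetic as Slicing.forward, specialized to dim=0
    dim_size = x.shape[0]
    start = start if start >= 0 else dim_size + start
    start = max(min(start, dim_size), 0)
    end = end or dim_size
    end = end if end >= 0 else dim_size + end
    end = max(min(end, dim_size), 0)
    if start >= end:
        return x[:0]  # empty slice, like _get_empty_slice
    indices = torch.arange(start=start, end=end, step=step, device=x.device)
    return torch.index_select(x, 0, indices)

x = torch.randn(10, 4)
for start, end, step in [(2, 8, 2), (-3, None, 1), (5, 2, 1), (0, 100, 3)]:
    assert torch.equal(slice_like_layer(x, start, end, step), x[start:end:step])
```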
 class Squeeze(Module):
+    """Squeeze operation layer.
+
+    This layer squeezes the input tensor at the given dimension.
+    See also [`torch.squeeze`][torch.squeeze].
+
+    Example:
+        ```py
+        squeeze = fl.Squeeze(dim=1)
+
+        tensor = torch.randn(10, 1, 10)
+        output = squeeze(tensor)
+
+        assert output.shape == (10, 10)
+        ```
+    """
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.squeeze(self.dim)
+        return torch.squeeze(
+            input=x,
+            dim=self.dim,
+        )
 
 
 class Unsqueeze(Module):
+    """Unsqueeze operation layer.
+
+    This layer unsqueezes the input tensor at the given dimension.
+    See also [`torch.unsqueeze`][torch.unsqueeze].
+
+    Example:
+        ```py
+        unsqueeze = fl.Unsqueeze(dim=1)
+
+        tensor = torch.randn(10, 10)
+        output = unsqueeze(tensor)
+
+        assert output.shape == (10, 1, 10)
+        ```
+    """
+
     def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.unsqueeze(self.dim)
+        return torch.unsqueeze(
+            input=x,
+            dim=self.dim,
+        )
 
 
 class Sin(Module):
+    """Sine operator layer.
+
+    This layer applies the sine function to the input tensor.
+    See also [`torch.sin`][torch.sin].
+
+    Example:
+        ```py
+        sin = fl.Sin()
+
+        tensor = torch.tensor([0, torch.pi])
+        output = sin(tensor)
+
+        expected_output = torch.tensor([0.0, 0.0])
+        assert torch.allclose(output, expected_output, atol=1e-6)
+        ```
+    """
+
     def forward(self, x: Tensor) -> Tensor:
         return torch.sin(input=x)
 
 
 class Cos(Module):
+    """Cosine operator layer.
+
+    This layer applies the cosine function to the input tensor.
+    See also [`torch.cos`][torch.cos].
+
+    Example:
+        ```py
+        cos = fl.Cos()
+
+        tensor = torch.tensor([0, torch.pi])
+        output = cos(tensor)
+
+        expected_output = torch.tensor([1.0, -1.0])
+        assert torch.allclose(output, expected_output, atol=1e-6)
+        ```
+    """
+
     def forward(self, x: Tensor) -> Tensor:
         return torch.cos(input=x)
 
 
 class Multiply(Module):
-    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
+    """Multiply operator layer.
+
+    This layer scales and shifts the input tensor by the given scale and bias.
+
+    Example:
+        ```py
+        multiply = fl.Multiply(scale=2, bias=1)
+
+        tensor = torch.ones(1)
+        output = multiply(tensor)
+
+        assert torch.allclose(output, torch.tensor([3.0]))
+        ```
+    """
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        bias: float = 0.0,
+    ) -> None:
         super().__init__()
         self.scale = scale
         self.bias = bias
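`Multiply.forward` sits between the hunks, so only the docstring example (`2 * 1 + 1 == 3`) hints at its behavior here; presumably it computes `scale * x + bias`, as in this hypothetical sketch:

```py
import torch

def multiply(x: torch.Tensor, scale: float = 1.0, bias: float = 0.0) -> torch.Tensor:
    # scale-and-shift, matching the numbers in Multiply's docstring example
    return scale * x + bias

assert torch.allclose(multiply(torch.ones(1), scale=2, bias=1), torch.tensor([3.0]))
```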
@@ -151,14 +415,40 @@ class Multiply(Module):
 
 
 class Parameter(WeightedModule):
-    """
-    A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
+    """Parameter layer.
+
+    This layer simply wraps a PyTorch [`Parameter`][torch.nn.parameter.Parameter].
+    When called, it returns the [`Parameter`][torch.nn.parameter.Parameter] Tensor.
+
+    Attributes:
+        weight (torch.nn.parameter.Parameter): The parameter Tensor.
     """
 
-    def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
+    def __init__(
+        self,
+        *dims: int,
+        requires_grad: bool = True,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
         super().__init__()
         self.dims = dims
-        self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
+        self.weight = TorchParameter(
+            requires_grad=requires_grad,
+            data=torch.randn(
+                *dims,
+                device=device,
+                dtype=dtype,
+            ),
+        )
 
     def forward(self, x: Tensor) -> Tensor:
         return self.weight.expand(x.shape[0], *self.dims)
+
+    @property
+    def requires_grad(self) -> bool:
+        return self.weight.requires_grad
+
+    @requires_grad.setter
+    def requires_grad(self, value: bool) -> None:
+        self.weight.requires_grad = value
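With the new `requires_grad` keyword and property, freezing a `Parameter` no longer requires touching `.weight` directly. A usage sketch, assuming refiners is installed and exposes this class as `fl.Parameter`:

```py
import torch
import refiners.fluxion.layers as fl

param = fl.Parameter(5, 2, requires_grad=False)
assert not param.requires_grad  # property reads through to the wrapped weight

x = torch.randn(3, 7)  # only the batch size (3) matters to forward
assert param(x).shape == (3, 5, 2)  # weight expanded over the batch dimension

param.requires_grad = True  # setter writes through to the wrapped weight
assert param.weight.requires_grad
```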