(doc/fluxion/conv) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 21:52:38 +00:00 committed by Laureηt
parent beb6dfb1c4
commit cf20621894

View file

@ -4,6 +4,33 @@ from refiners.fluxion.layers.module import WeightedModule
class Conv2d(nn.Conv2d, WeightedModule): class Conv2d(nn.Conv2d, WeightedModule):
"""2D Convolutional layer.
This layer wraps [`torch.nn.Conv2d`][torch.nn.Conv2d].
Receives:
(Real[Tensor, "batch in_channels in_height in_width"]):
Returns:
(Real[Tensor, "batch out_channels out_height out_width"]):
Example:
```py
conv2d = fl.Conv2d(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
)
tensor = torch.randn(2, 3, 128, 128)
output = conv2d(tensor)
assert output.shape == (2, 32, 128, 128)
```
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
@ -19,22 +46,49 @@ class Conv2d(nn.Conv2d, WeightedModule):
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
super().__init__( # type: ignore super().__init__( # type: ignore
in_channels, in_channels=in_channels,
out_channels, out_channels=out_channels,
kernel_size, kernel_size=kernel_size,
stride, stride=stride,
padding, padding=padding,
dilation, dilation=dilation,
groups, groups=groups,
use_bias, bias=use_bias,
padding_mode, padding_mode=padding_mode,
device, device=device,
dtype, dtype=dtype,
) )
self.use_bias = use_bias self.use_bias = use_bias
class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule): class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
"""2D Transposed Convolutional layer.
This layer wraps [`torch.nn.ConvTranspose2d`][torch.nn.ConvTranspose2d].
Receives:
(Real[Tensor, "batch in_channels in_height in_width"]):
Returns:
(Real[Tensor, "batch out_channels out_height out_width"]):
Example:
```py
conv2d = fl.ConvTranspose2d(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
)
tensor = torch.randn(2, 3, 128, 128)
output = conv2d(tensor)
assert output.shape == (2, 32, 128, 128)
```
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
@ -64,3 +118,4 @@ class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.use_bias = use_bias