From cf20621894fe95a2aace6e24e8516325c3a6e370 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 1 Feb 2024 21:52:38 +0000 Subject: [PATCH] (doc/fluxion/conv) add/convert docstrings to mkdocstrings format --- src/refiners/fluxion/layers/conv.py | 77 ++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/src/refiners/fluxion/layers/conv.py b/src/refiners/fluxion/layers/conv.py index 5df0d4c..ed1b3ee 100644 --- a/src/refiners/fluxion/layers/conv.py +++ b/src/refiners/fluxion/layers/conv.py @@ -4,6 +4,33 @@ from refiners.fluxion.layers.module import WeightedModule class Conv2d(nn.Conv2d, WeightedModule): + """2D Convolutional layer. + + This layer wraps [`torch.nn.Conv2d`][torch.nn.Conv2d]. + + Receives: + (Real[Tensor, "batch in_channels in_height in_width"]): + + Returns: + (Real[Tensor, "batch out_channels out_height out_width"]): + + Example: + ```py + conv2d = fl.Conv2d( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ) + + tensor = torch.randn(2, 3, 128, 128) + output = conv2d(tensor) + + assert output.shape == (2, 32, 128, 128) + ``` + """ + def __init__( self, in_channels: int, @@ -19,22 +46,49 @@ class Conv2d(nn.Conv2d, WeightedModule): dtype: DType | None = None, ) -> None: super().__init__( # type: ignore - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - use_bias, - padding_mode, - device, - dtype, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=use_bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, ) self.use_bias = use_bias class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule): + """2D Transposed Convolutional layer. + + This layer wraps [`torch.nn.ConvTranspose2d`][torch.nn.ConvTranspose2d]. 
+ + Receives: + (Real[Tensor, "batch in_channels in_height in_width"]): + + Returns: + (Real[Tensor, "batch out_channels out_height out_width"]): + + Example: + ```py + conv_transpose2d = fl.ConvTranspose2d( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ) + + tensor = torch.randn(2, 3, 128, 128) + output = conv_transpose2d(tensor) + + assert output.shape == (2, 32, 128, 128) + ``` + """ + def __init__( self, in_channels: int, @@ -64,3 +118,4 @@ class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule): device=device, dtype=dtype, ) + self.use_bias = use_bias