(doc/fluxion/sampling) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 21:49:02 +00:00 committed by Laureηt
parent be75f68893
commit 08349d97d7

View file

@ -10,7 +10,32 @@ from refiners.fluxion.layers.module import Module
from refiners.fluxion.utils import interpolate from refiners.fluxion.utils import interpolate
class Interpolate(Module):
"""Interpolate layer.
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
"""
def __init__(self) -> None:
super().__init__()
def forward(
self,
x: Tensor,
shape: Size,
) -> Tensor:
return interpolate(x, shape)
class Downsample(Chain): class Downsample(Chain):
"""Downsample layer.
This layer downsamples the input by the given scale factor.
Raises:
RuntimeError: If the context sampling is not set or if the context does not contain a list.
"""
def __init__( def __init__(
self, self,
channels: int, channels: int,
@ -20,16 +45,22 @@ class Downsample(Chain):
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
): ):
"""Downsamples the input by the given scale factor. """Initializes the Downsample layer.
If register_shape is True, the input shape is registered in the context. It will throw an error if the context Args:
sampling is not set or if the context does not contain a list. channels: The number of input and output channels.
scale_factor: The factor by which to downsample the input.
padding: The amount of zero-padding added to both sides of the input.
register_shape: If True, registers the input shape in the context.
device: The device to use for the convolutional layer.
dtype: The dtype to use for the convolutional layer.
""" """
self.channels = channels self.channels = channels
self.in_channels = channels self.in_channels = channels
self.out_channels = channels self.out_channels = channels
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.padding = padding self.padding = padding
super().__init__( super().__init__(
Conv2d( Conv2d(
in_channels=channels, in_channels=channels,
@ -41,25 +72,41 @@ class Downsample(Chain):
dtype=dtype, dtype=dtype,
), ),
) )
if padding == 0: if padding == 0:
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1)) zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
self.insert(0, Lambda(zero_pad)) self.insert(
if register_shape: index=0,
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape)) module=Lambda(func=zero_pad),
)
def register_shape(self, shapes: list[Size], x: Tensor) -> None: if register_shape:
self.insert(
index=0,
module=SetContext(
context="sampling",
key="shapes",
callback=self.register_shape,
),
)
def register_shape(
self,
shapes: list[Size],
x: Tensor,
) -> None:
shapes.append(x.shape[2:]) shapes.append(x.shape[2:])
class Interpolate(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, shape: Size) -> Tensor:
return interpolate(x, shape)
class Upsample(Chain): class Upsample(Chain):
"""Upsample layer.
This layer upsamples the input by the given scale factor.
Raises:
RuntimeError: If the context sampling is not set or if the context is empty.
"""
def __init__( def __init__(
self, self,
channels: int, channels: int,
@ -67,10 +114,14 @@ class Upsample(Chain):
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
): ):
"""Upsamples the input by the given scale factor. """Initializes the Upsample layer.
If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context Args:
sampling is not set or if the context is empty (then you should use the dynamic version of Downsample). channels: The number of input and output channels.
upsample_factor: The factor by which to upsample the input.
If None, the input shape is taken from the context.
device: The device to use for the convolutional layer.
dtype: The dtype to use for the convolutional layer.
""" """
self.channels = channels self.channels = channels
self.upsample_factor = upsample_factor self.upsample_factor = upsample_factor