(doc/fluxion/sampling) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 21:49:02 +00:00 committed by Laureηt
parent be75f68893
commit 08349d97d7

View file

@ -10,7 +10,32 @@ from refiners.fluxion.layers.module import Module
from refiners.fluxion.utils import interpolate
class Interpolate(Module):
"""Interpolate layer.
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
"""
def __init__(self) -> None:
super().__init__()
def forward(
self,
x: Tensor,
shape: Size,
) -> Tensor:
return interpolate(x, shape)
class Downsample(Chain):
"""Downsample layer.
This layer downsamples the input by the given scale factor.
Raises:
RuntimeError: If the context sampling is not set or if the context does not contain a list.
"""
def __init__(
self,
channels: int,
@ -20,16 +45,22 @@ class Downsample(Chain):
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Downsamples the input by the given scale factor.
"""Initializes the Downsample layer.
If register_shape is True, the input shape is registered in the context. It will throw an error if the context
sampling is not set or if the context does not contain a list.
Args:
channels: The number of input and output channels.
scale_factor: The factor by which to downsample the input.
padding: The amount of zero-padding added to both sides of the input.
register_shape: If True, registers the input shape in the context.
device: The device to use for the convolutional layer.
dtype: The dtype to use for the convolutional layer.
"""
self.channels = channels
self.in_channels = channels
self.out_channels = channels
self.scale_factor = scale_factor
self.padding = padding
super().__init__(
Conv2d(
in_channels=channels,
@ -41,25 +72,41 @@ class Downsample(Chain):
dtype=dtype,
),
)
if padding == 0:
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
self.insert(0, Lambda(zero_pad))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
self.insert(
index=0,
module=Lambda(func=zero_pad),
)
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
if register_shape:
self.insert(
index=0,
module=SetContext(
context="sampling",
key="shapes",
callback=self.register_shape,
),
)
def register_shape(
self,
shapes: list[Size],
x: Tensor,
) -> None:
shapes.append(x.shape[2:])
class Interpolate(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, shape: Size) -> Tensor:
return interpolate(x, shape)
class Upsample(Chain):
"""Upsample layer.
This layer upsamples the input by the given scale factor.
Raises:
RuntimeError: If the context sampling is not set or if the context is empty.
"""
def __init__(
self,
channels: int,
@ -67,10 +114,14 @@ class Upsample(Chain):
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Upsamples the input by the given scale factor.
"""Initializes the Upsample layer.
If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context
sampling is not set or if the context is empty (then you should use the dynamic version of Downsample).
Args:
channels: The number of input and output channels.
upsample_factor: The factor by which to upsample the input.
If None, the input shape is taken from the context.
device: The device to use for the convolutional layer.
dtype: The dtype to use for the convolutional layer.
"""
self.channels = channels
self.upsample_factor = upsample_factor