mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/sampling) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
be75f68893
commit
08349d97d7
|
@ -10,7 +10,32 @@ from refiners.fluxion.layers.module import Module
|
|||
from refiners.fluxion.utils import interpolate
|
||||
|
||||
|
||||
class Interpolate(Module):
|
||||
"""Interpolate layer.
|
||||
|
||||
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
shape: Size,
|
||||
) -> Tensor:
|
||||
return interpolate(x, shape)
|
||||
|
||||
|
||||
class Downsample(Chain):
|
||||
"""Downsample layer.
|
||||
|
||||
This layer downsamples the input by the given scale factor.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the context sampling is not set or if the context does not contain a list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
|
@ -20,16 +45,22 @@ class Downsample(Chain):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
):
|
||||
"""Downsamples the input by the given scale factor.
|
||||
"""Initializes the Downsample layer.
|
||||
|
||||
If register_shape is True, the input shape is registered in the context. It will throw an error if the context
|
||||
sampling is not set or if the context does not contain a list.
|
||||
Args:
|
||||
channels: The number of input and output channels.
|
||||
scale_factor: The factor by which to downsample the input.
|
||||
padding: The amount of zero-padding added to both sides of the input.
|
||||
register_shape: If True, registers the input shape in the context.
|
||||
device: The device to use for the convolutional layer.
|
||||
dtype: The dtype to use for the convolutional layer.
|
||||
"""
|
||||
self.channels = channels
|
||||
self.in_channels = channels
|
||||
self.out_channels = channels
|
||||
self.scale_factor = scale_factor
|
||||
self.padding = padding
|
||||
|
||||
super().__init__(
|
||||
Conv2d(
|
||||
in_channels=channels,
|
||||
|
@ -41,25 +72,41 @@ class Downsample(Chain):
|
|||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
if padding == 0:
|
||||
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
|
||||
self.insert(0, Lambda(zero_pad))
|
||||
if register_shape:
|
||||
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
|
||||
self.insert(
|
||||
index=0,
|
||||
module=Lambda(func=zero_pad),
|
||||
)
|
||||
|
||||
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
|
||||
if register_shape:
|
||||
self.insert(
|
||||
index=0,
|
||||
module=SetContext(
|
||||
context="sampling",
|
||||
key="shapes",
|
||||
callback=self.register_shape,
|
||||
),
|
||||
)
|
||||
|
||||
def register_shape(
|
||||
self,
|
||||
shapes: list[Size],
|
||||
x: Tensor,
|
||||
) -> None:
|
||||
shapes.append(x.shape[2:])
|
||||
|
||||
|
||||
class Interpolate(Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor, shape: Size) -> Tensor:
|
||||
return interpolate(x, shape)
|
||||
|
||||
|
||||
class Upsample(Chain):
|
||||
"""Upsample layer.
|
||||
|
||||
This layer upsamples the input by the given scale factor.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the context sampling is not set or if the context is empty.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
|
@ -67,10 +114,14 @@ class Upsample(Chain):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
):
|
||||
"""Upsamples the input by the given scale factor.
|
||||
"""Initializes the Upsample layer.
|
||||
|
||||
If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context
|
||||
sampling is not set or if the context is empty (then you should use the dynamic version of Downsample).
|
||||
Args:
|
||||
channels: The number of input and output channels.
|
||||
upsample_factor: The factor by which to upsample the input.
|
||||
If None, the input shape is taken from the context.
|
||||
device: The device to use for the convolutional layer.
|
||||
dtype: The dtype to use for the convolutional layer.
|
||||
"""
|
||||
self.channels = channels
|
||||
self.upsample_factor = upsample_factor
|
||||
|
|
Loading…
Reference in a new issue