mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/sampling) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
be75f68893
commit
08349d97d7
|
@ -10,7 +10,32 @@ from refiners.fluxion.layers.module import Module
|
||||||
from refiners.fluxion.utils import interpolate
|
from refiners.fluxion.utils import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
class Interpolate(Module):
|
||||||
|
"""Interpolate layer.
|
||||||
|
|
||||||
|
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
shape: Size,
|
||||||
|
) -> Tensor:
|
||||||
|
return interpolate(x, shape)
|
||||||
|
|
||||||
|
|
||||||
class Downsample(Chain):
|
class Downsample(Chain):
|
||||||
|
"""Downsample layer.
|
||||||
|
|
||||||
|
This layer downsamples the input by the given scale factor.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the context sampling is not set or if the context does not contain a list.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
|
@ -20,16 +45,22 @@ class Downsample(Chain):
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
):
|
):
|
||||||
"""Downsamples the input by the given scale factor.
|
"""Initializes the Downsample layer.
|
||||||
|
|
||||||
If register_shape is True, the input shape is registered in the context. It will throw an error if the context
|
Args:
|
||||||
sampling is not set or if the context does not contain a list.
|
channels: The number of input and output channels.
|
||||||
|
scale_factor: The factor by which to downsample the input.
|
||||||
|
padding: The amount of zero-padding added to both sides of the input.
|
||||||
|
register_shape: If True, registers the input shape in the context.
|
||||||
|
device: The device to use for the convolutional layer.
|
||||||
|
dtype: The dtype to use for the convolutional layer.
|
||||||
"""
|
"""
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.in_channels = channels
|
self.in_channels = channels
|
||||||
self.out_channels = channels
|
self.out_channels = channels
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Conv2d(
|
Conv2d(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
|
@ -41,25 +72,41 @@ class Downsample(Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if padding == 0:
|
if padding == 0:
|
||||||
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
|
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
|
||||||
self.insert(0, Lambda(zero_pad))
|
self.insert(
|
||||||
if register_shape:
|
index=0,
|
||||||
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
|
module=Lambda(func=zero_pad),
|
||||||
|
)
|
||||||
|
|
||||||
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
|
if register_shape:
|
||||||
|
self.insert(
|
||||||
|
index=0,
|
||||||
|
module=SetContext(
|
||||||
|
context="sampling",
|
||||||
|
key="shapes",
|
||||||
|
callback=self.register_shape,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_shape(
|
||||||
|
self,
|
||||||
|
shapes: list[Size],
|
||||||
|
x: Tensor,
|
||||||
|
) -> None:
|
||||||
shapes.append(x.shape[2:])
|
shapes.append(x.shape[2:])
|
||||||
|
|
||||||
|
|
||||||
class Interpolate(Module):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, shape: Size) -> Tensor:
|
|
||||||
return interpolate(x, shape)
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(Chain):
|
class Upsample(Chain):
|
||||||
|
"""Upsample layer.
|
||||||
|
|
||||||
|
This layer upsamples the input by the given scale factor.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the context sampling is not set or if the context is empty.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
|
@ -67,10 +114,14 @@ class Upsample(Chain):
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
):
|
):
|
||||||
"""Upsamples the input by the given scale factor.
|
"""Initializes the Upsample layer.
|
||||||
|
|
||||||
If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context
|
Args:
|
||||||
sampling is not set or if the context is empty (then you should use the dynamic version of Downsample).
|
channels: The number of input and output channels.
|
||||||
|
upsample_factor: The factor by which to upsample the input.
|
||||||
|
If None, the input shape is taken from the context.
|
||||||
|
device: The device to use for the convolutional layer.
|
||||||
|
dtype: The dtype to use for the convolutional layer.
|
||||||
"""
|
"""
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.upsample_factor = upsample_factor
|
self.upsample_factor = upsample_factor
|
||||||
|
|
Loading…
Reference in a new issue