diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py
index 69d1412..135f82a 100644
--- a/src/refiners/fluxion/layers/sampling.py
+++ b/src/refiners/fluxion/layers/sampling.py
@@ -10,7 +10,32 @@ from refiners.fluxion.layers.module import Module
 from refiners.fluxion.utils import interpolate
 
 
+class Interpolate(Module):
+    """Interpolate layer.
+
+    This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        x: Tensor,
+        shape: Size,
+    ) -> Tensor:
+        return interpolate(x, shape)
+
+
 class Downsample(Chain):
+    """Downsample layer.
+
+    This layer downsamples the input by the given scale factor.
+
+    Raises:
+        RuntimeError: If the "sampling" context is not set or does not contain a list.
+    """
+
     def __init__(
         self,
         channels: int,
@@ -20,16 +45,22 @@ class Downsample(Chain):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ):
-        """Downsamples the input by the given scale factor.
+        """Initializes the Downsample layer.
 
-        If register_shape is True, the input shape is registered in the context. It will throw an error if the context
-        sampling is not set or if the context does not contain a list.
+        Args:
+            channels: The number of input and output channels.
+            scale_factor: The factor by which to downsample the input.
+            padding: The amount of zero-padding added to both sides of the input.
+            register_shape: If True, registers the input shape in the "sampling" context.
+            device: The device to use for the convolutional layer.
+            dtype: The dtype to use for the convolutional layer.
         """
         self.channels = channels
         self.in_channels = channels
         self.out_channels = channels
         self.scale_factor = scale_factor
         self.padding = padding
+
         super().__init__(
             Conv2d(
                 in_channels=channels,
@@ -41,25 +72,41 @@ class Downsample(Chain):
                 dtype=dtype,
             ),
         )
+
         if padding == 0:
             zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
-            self.insert(0, Lambda(zero_pad))
-        if register_shape:
-            self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
+            self.insert(
+                index=0,
+                module=Lambda(func=zero_pad),
+            )
 
-    def register_shape(self, shapes: list[Size], x: Tensor) -> None:
+        if register_shape:
+            self.insert(
+                index=0,
+                module=SetContext(
+                    context="sampling",
+                    key="shapes",
+                    callback=self.register_shape,
+                ),
+            )
+
+    def register_shape(
+        self,
+        shapes: list[Size],
+        x: Tensor,
+    ) -> None:
         shapes.append(x.shape[2:])
 
 
-class Interpolate(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor, shape: Size) -> Tensor:
-        return interpolate(x, shape)
-
-
 class Upsample(Chain):
+    """Upsample layer.
+
+    This layer upsamples the input by the given scale factor.
+
+    Raises:
+        RuntimeError: If the "sampling" context is not set or if the context is empty.
+    """
+
     def __init__(
         self,
         channels: int,
@@ -67,10 +114,14 @@ class Upsample(Chain):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ):
-        """Upsamples the input by the given scale factor.
+        """Initializes the Upsample layer.
 
-        If upsample_factor is None, the input shape is taken from the context. It will throw an error if the context
-        sampling is not set or if the context is empty (then you should use the dynamic version of Downsample).
+        Args:
+            channels: The number of input and output channels.
+            upsample_factor: The factor by which to upsample the input.
+                If None, the input shape is taken from the "sampling" context.
+            device: The device to use for the convolutional layer.
+            dtype: The dtype to use for the convolutional layer.
         """
         self.channels = channels
         self.upsample_factor = upsample_factor
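For context while reviewing, here is a minimal usage sketch of the three layers this patch documents. It is not part of the diff: the module path, the keyword arguments, and the expected shapes in the comments are assumptions based on the signatures and docstrings shown above (including the 3x3 convolution with stride equal to scale_factor whose arguments the hunks elide).

# Usage sketch for review purposes; not part of the patch.
import torch

from refiners.fluxion.layers.sampling import Downsample, Interpolate, Upsample

x = torch.randn(1, 4, 32, 32)

# Halve the spatial resolution. register_shape=False keeps the call
# self-contained: no "sampling" context is required standalone.
down = Downsample(channels=4, scale_factor=2, padding=1, register_shape=False)
y = down(x)  # expected: (1, 4, 16, 16)

# Restore the resolution with an explicit factor. The context-driven path
# (upsample_factor=None) needs a surrounding Chain that provides the
# "sampling" context, so it is not exercised here.
up = Upsample(channels=4, upsample_factor=2)
z = up(y)  # expected: (1, 4, 32, 32)
assert z.shape == x.shape

# Interpolate resizes to an arbitrary target shape.
resized = Interpolate()(x, torch.Size((64, 64)))  # expected: (1, 4, 64, 64)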