From 2d4c4774f4a0951de00dace26c86ea75e1ed6b61 Mon Sep 17 00:00:00 2001 From: Benjamin Trom Date: Thu, 16 Nov 2023 17:28:01 +0100 Subject: [PATCH] add maxpool to refiners layer --- src/refiners/fluxion/layers/__init__.py | 3 ++ src/refiners/fluxion/layers/maxpool.py | 42 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 src/refiners/fluxion/layers/maxpool.py diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index 1269b1a..de38cc0 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -44,6 +44,7 @@ from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.converter import Converter +from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d __all__ = [ "Embedding", @@ -104,4 +105,6 @@ __all__ = [ "ReflectionPad2d", "PixelUnshuffle", "Converter", + "MaxPool1d", + "MaxPool2d", ] diff --git a/src/refiners/fluxion/layers/maxpool.py b/src/refiners/fluxion/layers/maxpool.py new file mode 100644 index 0000000..60ffd0a --- /dev/null +++ b/src/refiners/fluxion/layers/maxpool.py @@ -0,0 +1,42 @@ +from torch import nn +from refiners.fluxion.layers.module import Module + + +class MaxPool1d(nn.MaxPool1d, Module): + def __init__( + self, + kernel_size: int, + stride: int | None = None, + padding: int = 0, + dilation: int = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super().__init__( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + ) + + +class MaxPool2d(nn.MaxPool2d, Module): + def __init__( + self, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] | None = None, + padding: int | tuple[int, int] = (0, 0), + dilation: int | tuple[int, int] = (1, 1), + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super().__init__( + kernel_size=kernel_size, + stride=stride, + padding=padding, # type: ignore + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + )