add maxpool to refiners layer

This commit is contained in:
Benjamin Trom 2023-11-16 17:28:01 +01:00
parent f666bc82f5
commit 2d4c4774f4
2 changed files with 45 additions and 0 deletions

View file

@ -44,6 +44,7 @@ from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding
from refiners.fluxion.layers.converter import Converter
from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
__all__ = [
"Embedding",
@ -104,4 +105,6 @@ __all__ = [
"ReflectionPad2d",
"PixelUnshuffle",
"Converter",
"MaxPool1d",
"MaxPool2d",
]

View file

@ -0,0 +1,42 @@
from torch import nn
from refiners.fluxion.layers.module import Module
class MaxPool1d(nn.MaxPool1d, Module):
def __init__(
self,
kernel_size: int,
stride: int | None = None,
padding: int = 0,
dilation: int = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)
class MaxPool2d(nn.MaxPool2d, Module):
def __init__(
self,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] | None = None,
padding: int | tuple[int, int] = (0, 0),
dilation: int | tuple[int, int] = (1, 1),
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding, # type: ignore
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)