add maxpool to refiners layer

This commit is contained in:
Benjamin Trom 2023-11-16 17:28:01 +01:00
parent f666bc82f5
commit 2d4c4774f4
2 changed files with 45 additions and 0 deletions

View file

@ -44,6 +44,7 @@ from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.embedding import Embedding
from refiners.fluxion.layers.converter import Converter from refiners.fluxion.layers.converter import Converter
from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
__all__ = [ __all__ = [
"Embedding", "Embedding",
@ -104,4 +105,6 @@ __all__ = [
"ReflectionPad2d", "ReflectionPad2d",
"PixelUnshuffle", "PixelUnshuffle",
"Converter", "Converter",
"MaxPool1d",
"MaxPool2d",
] ]

View file

@ -0,0 +1,42 @@
from torch import nn
from refiners.fluxion.layers.module import Module
class MaxPool1d(nn.MaxPool1d, Module):
def __init__(
self,
kernel_size: int,
stride: int | None = None,
padding: int = 0,
dilation: int = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)
class MaxPool2d(nn.MaxPool2d, Module):
def __init__(
self,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] | None = None,
padding: int | tuple[int, int] = (0, 0),
dilation: int | tuple[int, int] = (1, 1),
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding, # type: ignore
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)