mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
add maxpool to refiners layer
This commit is contained in:
parent
f666bc82f5
commit
2d4c4774f4
|
@ -44,6 +44,7 @@ from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
|
|||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||
from refiners.fluxion.layers.embedding import Embedding
|
||||
from refiners.fluxion.layers.converter import Converter
|
||||
from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
|
||||
|
||||
__all__ = [
|
||||
"Embedding",
|
||||
|
@ -104,4 +105,6 @@ __all__ = [
|
|||
"ReflectionPad2d",
|
||||
"PixelUnshuffle",
|
||||
"Converter",
|
||||
"MaxPool1d",
|
||||
"MaxPool2d",
|
||||
]
|
||||
|
|
42
src/refiners/fluxion/layers/maxpool.py
Normal file
42
src/refiners/fluxion/layers/maxpool.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from torch import nn
|
||||
from refiners.fluxion.layers.module import Module
|
||||
|
||||
|
||||
class MaxPool1d(nn.MaxPool1d, Module):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int,
|
||||
stride: int | None = None,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
return_indices: bool = False,
|
||||
ceil_mode: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
return_indices=return_indices,
|
||||
ceil_mode=ceil_mode,
|
||||
)
|
||||
|
||||
|
||||
class MaxPool2d(nn.MaxPool2d, Module):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int | tuple[int, int],
|
||||
stride: int | tuple[int, int] | None = None,
|
||||
padding: int | tuple[int, int] = (0, 0),
|
||||
dilation: int | tuple[int, int] = (1, 1),
|
||||
return_indices: bool = False,
|
||||
ceil_mode: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding, # type: ignore
|
||||
dilation=dilation,
|
||||
return_indices=return_indices,
|
||||
ceil_mode=ceil_mode,
|
||||
)
|
Loading…
Reference in a new issue