mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add maxpool to refiners layer
This commit is contained in:
parent
f666bc82f5
commit
2d4c4774f4
|
@ -44,6 +44,7 @@ from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
|
||||||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||||
from refiners.fluxion.layers.embedding import Embedding
|
from refiners.fluxion.layers.embedding import Embedding
|
||||||
from refiners.fluxion.layers.converter import Converter
|
from refiners.fluxion.layers.converter import Converter
|
||||||
|
from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Embedding",
|
"Embedding",
|
||||||
|
@ -104,4 +105,6 @@ __all__ = [
|
||||||
"ReflectionPad2d",
|
"ReflectionPad2d",
|
||||||
"PixelUnshuffle",
|
"PixelUnshuffle",
|
||||||
"Converter",
|
"Converter",
|
||||||
|
"MaxPool1d",
|
||||||
|
"MaxPool2d",
|
||||||
]
|
]
|
||||||
|
|
42
src/refiners/fluxion/layers/maxpool.py
Normal file
42
src/refiners/fluxion/layers/maxpool.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from torch import nn
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPool1d(nn.MaxPool1d, Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int | None = None,
|
||||||
|
padding: int = 0,
|
||||||
|
dilation: int = 1,
|
||||||
|
return_indices: bool = False,
|
||||||
|
ceil_mode: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
return_indices=return_indices,
|
||||||
|
ceil_mode=ceil_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPool2d(nn.MaxPool2d, Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size: int | tuple[int, int],
|
||||||
|
stride: int | tuple[int, int] | None = None,
|
||||||
|
padding: int | tuple[int, int] = (0, 0),
|
||||||
|
dilation: int | tuple[int, int] = (1, 1),
|
||||||
|
return_indices: bool = False,
|
||||||
|
ceil_mode: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding, # type: ignore
|
||||||
|
dilation=dilation,
|
||||||
|
return_indices=return_indices,
|
||||||
|
ceil_mode=ceil_mode,
|
||||||
|
)
|
Loading…
Reference in a new issue