diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index dddb809..1269b1a 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -40,6 +40,7 @@ from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d from refiners.fluxion.layers.linear import Linear, MultiLinear from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule from refiners.fluxion.layers.padding import ReflectionPad2d +from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.converter import Converter @@ -101,5 +102,6 @@ __all__ = [ "ContextModule", "Interpolate", "ReflectionPad2d", + "PixelUnshuffle", "Converter", ] diff --git a/src/refiners/fluxion/layers/pixelshuffle.py b/src/refiners/fluxion/layers/pixelshuffle.py new file mode 100644 index 0000000..003dafc --- /dev/null +++ b/src/refiners/fluxion/layers/pixelshuffle.py @@ -0,0 +1,7 @@ +from refiners.fluxion.layers.module import Module +from torch.nn import PixelUnshuffle as _PixelUnshuffle + + +class PixelUnshuffle(_PixelUnshuffle, Module): + def __init__(self, downscale_factor: int): + _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)