From 4352e78483be912b8d3c270cee0d937d7ce2d829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Sun, 24 Sep 2023 14:43:38 +0200 Subject: [PATCH] add pixel unshuffle to fluxion's layers --- src/refiners/fluxion/layers/__init__.py | 2 ++ src/refiners/fluxion/layers/pixelshuffle.py | 7 +++++++ 2 files changed, 9 insertions(+) create mode 100644 src/refiners/fluxion/layers/pixelshuffle.py diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index dddb809..1269b1a 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -40,6 +40,7 @@ from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d from refiners.fluxion.layers.linear import Linear, MultiLinear from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule from refiners.fluxion.layers.padding import ReflectionPad2d +from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.converter import Converter @@ -101,5 +102,6 @@ __all__ = [ "ContextModule", "Interpolate", "ReflectionPad2d", + "PixelUnshuffle", "Converter", ] diff --git a/src/refiners/fluxion/layers/pixelshuffle.py b/src/refiners/fluxion/layers/pixelshuffle.py new file mode 100644 index 0000000..003dafc --- /dev/null +++ b/src/refiners/fluxion/layers/pixelshuffle.py @@ -0,0 +1,7 @@ +from refiners.fluxion.layers.module import Module +from torch.nn import PixelUnshuffle as _PixelUnshuffle + + +class PixelUnshuffle(_PixelUnshuffle, Module): + def __init__(self, downscale_factor: int): + _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)