mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
add pixel unshuffle to fluxion's layers
This commit is contained in:
parent
251277a0a8
commit
4352e78483
|
@ -40,6 +40,7 @@ from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
|
||||||
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||||
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
||||||
from refiners.fluxion.layers.padding import ReflectionPad2d
|
from refiners.fluxion.layers.padding import ReflectionPad2d
|
||||||
|
from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
|
||||||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||||
from refiners.fluxion.layers.embedding import Embedding
|
from refiners.fluxion.layers.embedding import Embedding
|
||||||
from refiners.fluxion.layers.converter import Converter
|
from refiners.fluxion.layers.converter import Converter
|
||||||
|
@ -101,5 +102,6 @@ __all__ = [
|
||||||
"ContextModule",
|
"ContextModule",
|
||||||
"Interpolate",
|
"Interpolate",
|
||||||
"ReflectionPad2d",
|
"ReflectionPad2d",
|
||||||
|
"PixelUnshuffle",
|
||||||
"Converter",
|
"Converter",
|
||||||
]
|
]
|
||||||
|
|
7
src/refiners/fluxion/layers/pixelshuffle.py
Normal file
7
src/refiners/fluxion/layers/pixelshuffle.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
from torch.nn import PixelUnshuffle as _PixelUnshuffle
|
||||||
|
|
||||||
|
|
||||||
|
class PixelUnshuffle(_PixelUnshuffle, Module):
|
||||||
|
def __init__(self, downscale_factor: int):
|
||||||
|
_PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)
|
Loading…
Reference in a new issue