mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add pixel unshuffle to fluxion's layers
This commit is contained in:
parent
251277a0a8
commit
4352e78483
|
@ -40,6 +40,7 @@ from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
|
|||
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||
from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
|
||||
from refiners.fluxion.layers.padding import ReflectionPad2d
|
||||
from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
|
||||
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
|
||||
from refiners.fluxion.layers.embedding import Embedding
|
||||
from refiners.fluxion.layers.converter import Converter
|
||||
|
@ -101,5 +102,6 @@ __all__ = [
|
|||
"ContextModule",
|
||||
"Interpolate",
|
||||
"ReflectionPad2d",
|
||||
"PixelUnshuffle",
|
||||
"Converter",
|
||||
]
|
||||
|
|
7
src/refiners/fluxion/layers/pixelshuffle.py
Normal file
7
src/refiners/fluxion/layers/pixelshuffle.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from refiners.fluxion.layers.module import Module
|
||||
from torch.nn import PixelUnshuffle as _PixelUnshuffle
|
||||
|
||||
|
||||
class PixelUnshuffle(_PixelUnshuffle, Module):
|
||||
def __init__(self, downscale_factor: int):
|
||||
_PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)
|
Loading…
Reference in a new issue