remove Buffer

Pierre Chapuis 2024-01-29 16:33:37 +01:00, committed by Cédric Deltheil
parent e6be1394ff
commit a1ad317b00
2 changed files with 0 additions and 26 deletions

refiners/fluxion/layers/__init__.py

@@ -1,7 +1,6 @@
 from refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU
 from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
 from refiners.fluxion.layers.basics import (
-    Buffer,
     Cos,
     Flatten,
     GetArg,
@@ -75,7 +74,6 @@ __all__ = [
     "Cos",
     "Multiply",
     "Matmul",
-    "Buffer",
     "Lambda",
     "Return",
     "Sum",

refiners/fluxion/layers/basics.py

@@ -162,27 +162,3 @@ class Parameter(WeightedModule):
 
     def forward(self, x: Tensor) -> Tensor:
         return self.weight.expand(x.shape[0], *self.dims)
-
-
-class Buffer(WeightedModule):
-    """
-    A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.
-
-    Buffers are not trainable.
-    """
-
-    def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
-        super().__init__()
-        self.dims = dims
-        self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
-
-    @property
-    def device(self) -> Device:
-        return self.buffer.device
-
-    @property
-    def dtype(self) -> DType:
-        return self.buffer.dtype
-
-    def forward(self, _: Tensor) -> Tensor:
-        return self.buffer
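
For code that depended on the removed layer, the same behavior is available directly through PyTorch's register_buffer, which Buffer wrapped. Below is a minimal sketch of an equivalent module in plain PyTorch; the ConstantBuffer name and the random initialization are illustrative, not part of refiners.

import torch
from torch import Tensor, nn


class ConstantBuffer(nn.Module):
    # Holds a fixed tensor as a non-trainable buffer, mirroring the removed
    # layer: the tensor follows .to(device) / dtype casts like a weight, but
    # optimizers ignore it because it is registered as a buffer, not a
    # parameter.
    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.register_buffer("buffer", torch.randn(*dims))

    def forward(self, _: Tensor) -> Tensor:
        # As in the removed layer, the input is ignored and the stored
        # tensor is returned as-is.
        return self.buffer


module = ConstantBuffer(4, 4)
out = module(torch.zeros(1))            # always the stored 4x4 tensor
assert list(module.parameters()) == []  # nothing for an optimizer to train

Any nn.Module can call register_buffer directly, so a dedicated wrapper layer is not required for this pattern.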