remove Buffer

This commit is contained in:
Pierre Chapuis 2024-01-29 16:33:37 +01:00 committed by Cédric Deltheil
parent e6be1394ff
commit a1ad317b00
2 changed files with 0 additions and 26 deletions

View file

@ -1,7 +1,6 @@
from refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU from refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
from refiners.fluxion.layers.basics import ( from refiners.fluxion.layers.basics import (
Buffer,
Cos, Cos,
Flatten, Flatten,
GetArg, GetArg,
@ -75,7 +74,6 @@ __all__ = [
"Cos", "Cos",
"Multiply", "Multiply",
"Matmul", "Matmul",
"Buffer",
"Lambda", "Lambda",
"Return", "Return",
"Sum", "Sum",

View file

@ -162,27 +162,3 @@ class Parameter(WeightedModule):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return self.weight.expand(x.shape[0], *self.dims) return self.weight.expand(x.shape[0], *self.dims)
class Buffer(WeightedModule):
"""
A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.
Buffers are not trainable.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.dims = dims
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
@property
def device(self) -> Device:
return self.buffer.device
@property
def dtype(self) -> DType:
return self.buffer.dtype
def forward(self, _: Tensor) -> Tensor:
return self.buffer