Layers

Activation

Activation()

Bases: Module, ABC

Base class for activation layers.

Activation layers apply a (non-linear) function to their input.

Receives:

    x (Tensor)

Returns:

    Tensor
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:
    super().__init__()

Attention

Attention(
    embedding_dim: int,
    num_heads: int = 1,
    key_embedding_dim: int | None = None,
    value_embedding_dim: int | None = None,
    inner_dim: int | None = None,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Multi-Head Attention layer.

See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details.

This layer simply chains:
  • a Distribute layer with three Linear projections, which compute the query, key and value from the inputs
  • a ScaledDotProductAttention layer
  • a final Linear layer, which projects the attention output back to embedding_dim

Receives:

    Query (Float[Tensor, 'batch sequence_length embedding_dim'])
    Key (Float[Tensor, 'batch sequence_length embedding_dim'])
    Value (Float[Tensor, 'batch sequence_length embedding_dim'])

Returns:

    Float[Tensor, 'batch sequence_length embedding_dim']
Example
attention = fl.Attention(num_heads=8, embedding_dim=128)

tensor = torch.randn(2, 10, 128)
output = attention(tensor, tensor, tensor)

assert output.shape == (2, 10, 128)

Parameters:

    embedding_dim (int): The embedding dimension of the input and output tensors. (required)
    num_heads (int): The number of heads to use. Default: 1
    key_embedding_dim (int | None): The embedding dimension of the key tensor. Default: None
    value_embedding_dim (int | None): The embedding dimension of the value tensor. Default: None
    inner_dim (int | None): The inner dimension of the linear layers. Default: None
    use_bias (bool): Whether to use bias in the linear layers. Default: True
    is_causal (bool): Whether to use causal attention. Default: False
    is_optimized (bool): Whether to use optimized attention. Default: True
    device (device | str | None): The device to use. Default: None
    dtype (dtype | None): The dtype to use. Default: None
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(
    self,
    embedding_dim: int,
    num_heads: int = 1,
    key_embedding_dim: int | None = None,
    value_embedding_dim: int | None = None,
    inner_dim: int | None = None,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the Attention layer.

    Args:
        embedding_dim: The embedding dimension of the input and output tensors.
        num_heads: The number of heads to use.
        key_embedding_dim: The embedding dimension of the key tensor.
        value_embedding_dim: The embedding dimension of the value tensor.
        inner_dim: The inner dimension of the linear layers.
        use_bias: Whether to use bias in the linear layers.
        is_causal: Whether to use causal attention.
        is_optimized: Whether to use optimized attention.
        device: The device to use.
        dtype: The dtype to use.
    """
    assert (
        embedding_dim % num_heads == 0
    ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    self.heads_dim = embedding_dim // num_heads
    self.key_embedding_dim = key_embedding_dim or embedding_dim
    self.value_embedding_dim = value_embedding_dim or embedding_dim
    self.inner_dim = inner_dim or embedding_dim
    self.use_bias = use_bias
    self.is_causal = is_causal
    self.is_optimized = is_optimized

    super().__init__(
        Distribute(
            Linear(  # Query projection
                in_features=self.embedding_dim,
                out_features=self.inner_dim,
                bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            Linear(  # Key projection
                in_features=self.key_embedding_dim,
                out_features=self.inner_dim,
                bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            Linear(  # Value projection
                in_features=self.value_embedding_dim,
                out_features=self.inner_dim,
                bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
        ),
        ScaledDotProductAttention(
            num_heads=num_heads,
            is_causal=is_causal,
            is_optimized=is_optimized,
        ),
        Linear(  # Output projection
            in_features=self.inner_dim,
            out_features=self.embedding_dim,
            bias=True,
            device=device,
            dtype=dtype,
        ),
    )

Breakpoint

Breakpoint(vscode: bool = True)

Bases: ContextModule

Breakpoint layer.

This layer pauses execution when encountered, and opens a debugger.
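
A minimal sketch (not from the original docs) of dropping a Breakpoint into a Chain to inspect intermediate activations; calling the chain would pause and open a debugger, so the call is left commented out:

breakpoint_chain = fl.Chain(
    fl.Linear(32, 64),
    fl.Breakpoint(vscode=False),
    fl.Linear(64, 128),
)
# output = breakpoint_chain(torch.randn(2, 32))  # execution would pause here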

Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, vscode: bool = True):
    super().__init__()
    self.vscode = vscode

Chain

Chain(*args: Module | Iterable[Module])

Bases: ContextModule

Chain layer.

This layer is the main building block of Fluxion. It is used to compose other layers in a sequential manner. Similarly to torch.nn.Sequential, it calls each of its sub-layers in order, chaining their outputs as inputs to the next sublayer. However, it also provides additional methods to manipulate its sub-layers and their context.

Example
chain = fl.Chain(
    fl.Linear(32, 64),
    fl.ReLU(),
    fl.Linear(64, 128),
)

tensor = torch.randn(2, 32)
output = chain(tensor)

assert output.shape == (2, 128)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

device property

device: device | None

The PyTorch device of the Chain's parameters.

dtype property

dtype: dtype | None

The PyTorch dtype of the Chain's parameters.

provider property

provider: ContextProvider

The ContextProvider of the Chain.

append

append(module: Module) -> None

Append a new module to the chain.

Parameters:

    module (Module): The module to append. (required)
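
A short sketch (not from the original docs), following the Chain example above:

chain = fl.Chain(fl.Linear(32, 64))
chain.append(fl.ReLU())

assert isinstance(chain[1], fl.ReLU)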
Source code in src/refiners/fluxion/layers/chain.py
def append(self, module: Module) -> None:
    """Append a new module to the chain.

    Args:
        module: The module to append.
    """
    self.insert(-1, module)

ensure_find

ensure_find(layer_type: type[T]) -> T

Walk the Chain's sub-module tree and return the first layer of the given type.

Parameters:

    layer_type (type[T]): The type of layer to find. (required)

Returns:

    T: The first module of the given layer_type.

Raises:

    AssertionError: If the module doesn't exist.
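
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
relu = chain.ensure_find(fl.ReLU)

assert relu is chain[1]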

Source code in src/refiners/fluxion/layers/chain.py
def ensure_find(self, layer_type: type[T]) -> T:
    """Walk the Chain's sub-module tree and return the first layer of the given type.

    Args:
        layer_type: The type of layer to find.

    Returns:
        The first module of the given layer_type.

    Raises:
        AssertionError: If the module doesn't exist.
    """
    r = self.find(layer_type)
    assert r is not None, f"could not find {layer_type} in {self}"
    return r

ensure_find_parent

ensure_find_parent(module: Module) -> Chain

Walk the Chain's sub-module tree and return the parent of the given module.

Parameters:

    module (Module): The module whose parent to find. (required)

Returns:

    Chain: The parent of the given module.

Raises:

    AssertionError: If the module doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def ensure_find_parent(self, module: Module) -> "Chain":
    """Walk the Chain's sub-module tree and return the parent of the given module.

    Args:
        module: The module whose parent to find.

    Returns:
        The parent of the given module.

    Raises:
        AssertionError: If the module doesn't exist.
    """
    r = self.find_parent(module)
    assert r is not None, f"could not find {module} in {self}"
    return r

find

find(layer_type: type[T]) -> T | None

Walk the Chain's sub-module tree and return the first layer of the given type.

Parameters:

    layer_type (type[T]): The type of layer to find. (required)

Returns:

    T | None: The first module of the given layer_type, or None if it doesn't exist.

Source code in src/refiners/fluxion/layers/chain.py
def find(self, layer_type: type[T]) -> T | None:
    """Walk the Chain's sub-module tree and return the first layer of the given type.

    Args:
        layer_type: The type of layer to find.

    Returns:
        The first module of the given layer_type, or None if it doesn't exist.
    """
    return next(self.layers(layer_type=layer_type), None)

find_parent

find_parent(module: Module) -> Chain | None

Walk the Chain's sub-module tree and return the parent of the given module.

Parameters:

    module (Module): The module whose parent to find. (required)

Returns:

    Chain | None: The parent of the given module, or None if it doesn't exist.
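
A short sketch (not from the original docs):

linear = fl.Linear(32, 64)
chain = fl.Chain(fl.Chain(linear))

assert chain.find_parent(linear) is chain[0]  # the inner Chain
assert chain.find_parent(fl.Linear(1, 1)) is None  # not in the tree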

Source code in src/refiners/fluxion/layers/chain.py
def find_parent(self, module: Module) -> "Chain | None":
    """Walk the Chain's sub-module tree and return the parent of the given module.

    Args:
        module: The module whose parent to find.

    Returns:
        The parent of the given module, or None if it doesn't exist.
    """
    if module in self:  # avoid DFS-crawling the whole tree
        return self
    for _, parent in self.walk(lambda m, _: m == module):
        return parent
    return None

init_context

init_context() -> Contexts

Initialize the context provider with some default values.

This method is called when the Chain is created, and when it is reset. This method may be overridden by subclasses to provide default values for the context provider.

Source code in src/refiners/fluxion/layers/chain.py
def init_context(self) -> Contexts:
    """Initialize the context provider with some default values.

    This method is called when the Chain is created, and when it is reset.
    This method may be overridden by subclasses to provide default values for the context provider.
    """
    return {}

insert

insert(index: int, module: Module) -> None

Insert a new module in the chain.

Parameters:

    index (int): The index at which to insert the module. (required)
    module (Module): The module to insert. (required)

Raises:

    IndexError: If the index is out of range.
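
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.Linear(64, 128))
chain.insert(1, fl.ReLU())

assert isinstance(chain[1], fl.ReLU)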

Source code in src/refiners/fluxion/layers/chain.py
def insert(self, index: int, module: Module) -> None:
    """Insert a new module in the chain.

    Args:
        index: The index at which to insert the module.
        module: The module to insert.

    Raises:
        IndexError: If the index is out of range.
    """
    if index < 0:
        index = max(0, len(self._modules) + index + 1)
    modules = list(self)
    modules.insert(index, module)
    self._regenerate_keys(modules)
    if isinstance(module, ContextModule):
        module._set_parent(self)
    self._register_provider()

insert_after_type

insert_after_type(
    module_type: type[Module], new_module: Module
) -> None

Insert a new module in the chain, right after the first module of the given type.

Parameters:

    module_type (type[Module]): The type of module to insert after. (required)
    new_module (Module): The module to insert. (required)

Raises:

    ValueError: If no module of the given type exists in the chain.
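
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.Linear(64, 128))
chain.insert_after_type(fl.Linear, fl.ReLU())

assert isinstance(chain[1], fl.ReLU)  # right after the first Linear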

Source code in src/refiners/fluxion/layers/chain.py
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
    """Insert a new module in the chain, right after the first module of the given type.

    Args:
        module_type: The type of module to insert after.
        new_module: The module to insert.

    Raises:
        ValueError: If no module of the given type exists in the chain.
    """
    for i, module in enumerate(self):
        if isinstance(module, module_type):
            self.insert(i + 1, new_module)
            return
    raise ValueError(f"No module of type {module_type.__name__} found in the chain.")

insert_before_type

insert_before_type(
    module_type: type[Module], new_module: Module
) -> None

Insert a new module in the chain, right before the first module of the given type.

Parameters:

    module_type (type[Module]): The type of module to insert before. (required)
    new_module (Module): The module to insert. (required)

Raises:

    ValueError: If no module of the given type exists in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
    """Insert a new module in the chain, right before the first module of the given type.

    Args:
        module_type: The type of module to insert before.
        new_module: The module to insert.

    Raises:
        ValueError: If no module of the given type exists in the chain.
    """
    for i, module in enumerate(self):
        if isinstance(module, module_type):
            self.insert(i, new_module)
            return
    raise ValueError(f"No module of type {module_type.__name__} found in the chain.")

layer

layer(
    key: str | int | Sequence[str | int],
    layer_type: type[T] = Module,
) -> T

Access a layer of the Chain given its type.

Example
# same as my_chain["Linear_2"], asserts it is a Linear
my_chain.layer("Linear_2", fl.Linear)


# same as my_chain[3], asserts it is a Linear
my_chain.layer(3, fl.Linear)

# probably won't work
my_chain.layer("Conv2d", fl.Linear)


# same as my_chain["foo"][42]["bar"],
# assuming bar is a MyType and all parents are Chains
my_chain.layer(("foo", 42, "bar"), fl.MyType)

Parameters:

    key (str | int | Sequence[str | int]): The key or path of the layer. (required)
    layer_type (type[T]): The type of the layer. Default: Module

Returns:

    T: The layer.

Raises:

    AssertionError: If the layer doesn't exist or the type is invalid.

Source code in src/refiners/fluxion/layers/chain.py
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:
    """Access a layer of the Chain given its type.

    Example:
        ```py
        # same as my_chain["Linear_2"], asserts it is a Linear
        my_chain.layer("Linear_2", fl.Linear)


        # same as my_chain[3], asserts it is a Linear
        my_chain.layer(3, fl.Linear)

        # probably won't work
        my_chain.layer("Conv2d", fl.Linear)


        # same as my_chain["foo"][42]["bar"],
        # assuming bar is a MyType and all parents are Chains
        my_chain.layer(("foo", 42, "bar"), fl.MyType)
        ```

    Args:
        key: The key or path of the layer.
        layer_type: The type of the layer.

    Yields:
        The layer.

    Raises:
        AssertionError: If the layer doesn't exist or the type is invalid.
    """
    if isinstance(key, (str, int)):
        r = self[key]
        assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}"
        return r
    if len(key) == 0:
        assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}"
        return self
    if len(key) == 1:
        return self.layer(key[0], layer_type)
    return self.layer(key[0], Chain).layer(key[1:], layer_type)

layers

layers(
    layer_type: type[T], recurse: bool = False
) -> Iterator[T]

Walk the Chain's sub-module tree and yield each layer of the given type.

Parameters:

    layer_type (type[T]): The type of layer to yield. (required)
    recurse (bool): Whether to recurse into sub-Chains. Default: False

Yields:

    T: Each module of the given layer_type.
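
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.ReLU(), fl.Linear(64, 32))
for linear in chain.layers(fl.Linear):
    print(linear.in_features, linear.out_features)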

Source code in src/refiners/fluxion/layers/chain.py
def layers(
    self,
    layer_type: type[T],
    recurse: bool = False,
) -> Iterator[T]:
    """Walk the Chain's sub-module tree and yield each layer of the given type.

    Args:
        layer_type: The type of layer to yield.
        recurse: Whether to recurse into sub-Chains.

    Yields:
        Each module of the given layer_type.
    """
    for module, _ in self.walk(layer_type, recurse):
        yield module

pop

pop(index: int = -1) -> Module

Pop a module from the chain at the given index.

Parameters:

    index (int): The index of the module to pop. Default: -1

Returns:

    Module: The popped module.

Raises:

    IndexError: If the index is out of range.

Source code in src/refiners/fluxion/layers/chain.py
def pop(self, index: int = -1) -> Module:
    """Pop a module from the chain at the given index.

    Args:
        index: The index of the module to pop.

    Returns:
        The popped module.

    Raises:
        IndexError: If the index is out of range.
    """
    modules = list(self)
    if index < 0:
        index = len(modules) + index
    if index < 0 or index >= len(modules):
        raise IndexError("Index out of range.")
    removed_module = modules.pop(index)
    if isinstance(removed_module, ContextModule):
        removed_module._set_parent(None)
    self._regenerate_keys(modules)
    return removed_module

remove

remove(module: Module) -> None

Remove a module from the chain.

Parameters:

    module (Module): The module to remove. (required)

Raises:

    ValueError: If the module is not in the chain.

Source code in src/refiners/fluxion/layers/chain.py
def remove(self, module: Module) -> None:
    """Remove a module from the chain.

    Args:
        module: The module to remove.

    Raises:
        ValueError: If the module is not in the chain.
    """
    modules = list(self)
    try:
        modules.remove(module)
    except ValueError:
        raise ValueError(f"{module} is not in {self}")
    self._regenerate_keys(modules)
    if isinstance(module, ContextModule):
        module._set_parent(None)

replace

replace(
    old_module: Module,
    new_module: Module,
    old_module_parent: Chain | None = None,
) -> None

Replace a module in the chain with a new module.

Parameters:

    old_module (Module): The module to replace. (required)
    new_module (Module): The module to replace with. (required)
    old_module_parent (Chain | None): The parent of the old module. If None, the old module is orphanized. Default: None

Raises:

    ValueError: If the module is not in the chain.
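
A short sketch (not from the original docs): swapping an activation in place.

chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
old_relu = chain.ensure_find(fl.ReLU)
chain.replace(old_relu, fl.GeLU())

assert isinstance(chain[1], fl.GeLU)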

Source code in src/refiners/fluxion/layers/chain.py
def replace(
    self,
    old_module: Module,
    new_module: Module,
    old_module_parent: "Chain | None" = None,
) -> None:
    """Replace a module in the chain with a new module.

    Args:
        old_module: The module to replace.
        new_module: The module to replace with.
        old_module_parent: The parent of the old module.
            If None, the old module is orphanized.

    Raises:
        ValueError: If the module is not in the chain.
    """
    modules = list(self)
    try:
        modules[modules.index(old_module)] = new_module
    except ValueError:
        raise ValueError(f"{old_module} is not in {self}")
    self._regenerate_keys(modules)
    if isinstance(new_module, ContextModule):
        new_module._set_parent(self)
    if isinstance(old_module, ContextModule):
        old_module._set_parent(old_module_parent)
    self._register_provider()

set_context

set_context(context: str, value: Any) -> None

Set a value in the context provider.

Parameters:

    context (str): The context to update. (required)
    value (Any): The value to set. (required)
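
A short sketch (not from the original docs; the context name and value are illustrative): a context is a named object shared through the chain, which modules can read back with use_context.

chain = fl.Chain(fl.Linear(8, 8))
chain.set_context("metadata", {"step": 0})

assert chain.use_context("metadata") == {"step": 0}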
Source code in src/refiners/fluxion/layers/chain.py
def set_context(self, context: str, value: Any) -> None:
    """Set a value in the context provider.

    Args:
        context: The context to update.
        value: The value to set.
    """
    self._provider.set_context(context, value)
    self._register_provider()

structural_copy

structural_copy() -> TChain

Copy the structure of the Chain tree.

This method returns a recursive copy of the Chain tree where all inner nodes (instances of Chain and its subclasses) are duplicated and all leaves (regular Modules) are not.

Such copies can be adapted without disrupting the base model, but do not require extra GPU memory since the weights are in the leaves and hence not copied.
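
A short sketch (not from the original docs), illustrating the sharing described above:

base = fl.Chain(fl.Linear(32, 64), fl.ReLU())
clone = base.structural_copy()

assert clone is not base  # the Chain node is duplicated
assert clone[0] is base[0]  # the leaves (and their weights) are shared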

Source code in src/refiners/fluxion/layers/chain.py
def structural_copy(self: TChain) -> TChain:
    """Copy the structure of the Chain tree.

    This method returns a recursive copy of the Chain tree where all inner nodes
    (instances of Chain and its subclasses) are duplicated and all leaves
    (regular Modules) are not.

    Such copies can be adapted without disrupting the base model, but do not
    require extra GPU memory since the weights are in the leaves and hence not copied.
    """
    if hasattr(self, "_pre_structural_copy"):
        assert callable(self._pre_structural_copy)
        self._pre_structural_copy()

    modules = [structural_copy(m) for m in self]
    clone = super().structural_copy()
    clone._provider = ContextProvider.create(clone.init_context())

    for module in modules:
        clone.append(module=module)

    if hasattr(clone, "_post_structural_copy"):
        assert callable(clone._post_structural_copy)
        clone._post_structural_copy(self)

    return clone

walk

walk(
    predicate: (
        Callable[[Module, Chain], bool] | None
    ) = None,
    recurse: bool = False,
) -> Iterator[tuple[Module, Chain]]
walk(
    predicate: type[T], recurse: bool = False
) -> Iterator[tuple[T, Chain]]
walk(
    predicate: (
        type[T] | Callable[[Module, Chain], bool] | None
    ) = None,
    recurse: bool = False,
) -> (
    Iterator[tuple[T, Chain]]
    | Iterator[tuple[Module, Chain]]
)

Walk the Chain's sub-module tree and yield each module that matches the predicate.

Parameters:

    predicate (type[T] | Callable[[Module, Chain], bool] | None): The predicate to match. Default: None
    recurse (bool): Whether to recurse into sub-Chains. Default: False

Yields:

    Iterator[tuple[T, Chain]] | Iterator[tuple[Module, Chain]]: Each module that matches the predicate.
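
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.Chain(fl.ReLU(), fl.Linear(64, 64)))
for linear, parent in chain.walk(fl.Linear):
    print(linear, "is a child of", parent.__class__.__name__)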

Source code in src/refiners/fluxion/layers/chain.py
def walk(
    self,
    predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None,
    recurse: bool = False,
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
    """Walk the Chain's sub-module tree and yield each module that matches the predicate.

    Args:
        predicate: The predicate to match.
        recurse: Whether to recurse into sub-Chains.

    Yields:
        Each module that matches the predicate.
    """

    if get_origin(predicate) is not None:
        raise ValueError(f"subscripted generics cannot be used as predicates")

    if isinstance(predicate, type):
        # if the predicate is a Module type
        # build a predicate function that matches the type
        return self._walk(
            predicate=lambda m, _: isinstance(m, predicate),
            recurse=recurse,
        )
    else:
        return self._walk(
            predicate=predicate,
            recurse=recurse,
        )

Concatenate

Concatenate(*modules: Module, dim: int = 0)

Bases: Chain

Concatenation layer.

This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.

Example
concatenate = fl.Concatenate(
    fl.Linear(32, 128),
    fl.Linear(32, 128),
    dim=1,
)

tensor = torch.randn(2, 32)
output = concatenate(tensor)

assert output.shape == (2, 256)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *modules: Module, dim: int = 0) -> None:
    super().__init__(*modules)
    self.dim = dim

ContextModule

ContextModule(*args: Any, **kwargs: Any)

Bases: Module

A module containing a ContextProvider.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, *kwargs)
    self._parent = []

ensure_parent property

ensure_parent: Chain

Return the module's parent, or raise an error if module is an orphan.

parent property

parent: Chain | None

Return the module's parent, or None if module is an orphan.

provider property

provider: ContextProvider

Return the module's context provider.

get_parents

get_parents() -> list[Chain]

Recursively retrieve the module's parents.

Source code in src/refiners/fluxion/layers/module.py
def get_parents(self) -> "list[Chain]":
    """Recursively retrieve the module's parents."""
    return self._parent + self._parent[0].get_parents() if self._parent else []

get_path

get_path(
    parent: Chain | None = None, top: Module | None = None
) -> str

Get the path of the module in the chain.

Parameters:

    parent (Chain | None): The parent of the module in the chain. Default: None
    top (Module | None): The top module of the chain. If None, the path will be relative to the root of the chain. Default: None
Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
    """Get the path of the module in the chain.

    Args:
        parent: The parent of the module in the chain.
        top: The top module of the chain.
            If None, the path will be relative to the root of the chain.
    """

    return super().get_path(parent=parent or self.parent, top=top)

use_context

use_context(context_name: str) -> Context

Retrieve the context object from the module's context provider.

Source code in src/refiners/fluxion/layers/module.py
def use_context(self, context_name: str) -> Context:
    """Retrieve the context object from the module's context provider."""
    context = self.provider.get_context(context_name)
    assert context is not None, f"Context {context_name} not found."
    return context

Conv2d

Conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = (1, 1),
    padding: int | tuple[int, int] | str = (0, 0),
    groups: int = 1,
    use_bias: bool = True,
    dilation: int | tuple[int, int] = (1, 1),
    padding_mode: str = "zeros",
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Conv2d, WeightedModule

2D Convolutional layer.

This layer wraps torch.nn.Conv2d.

Receives:

    Real[Tensor, 'batch in_channels in_height in_width']

Returns:

    Real[Tensor, 'batch out_channels out_height out_width']
Example
conv2d = fl.Conv2d(
    in_channels=3,
    out_channels=32,
    kernel_size=3,
    stride=1,
    padding=1,
)

tensor = torch.randn(2, 3, 128, 128)
output = conv2d(tensor)

assert output.shape == (2, 32, 128, 128)
Source code in src/refiners/fluxion/layers/conv.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = (1, 1),
    padding: int | tuple[int, int] | str = (0, 0),
    groups: int = 1,
    use_bias: bool = True,
    dilation: int | tuple[int, int] = (1, 1),
    padding_mode: str = "zeros",
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__(  # type: ignore
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=use_bias,
        padding_mode=padding_mode,
        device=device,
        dtype=dtype,
    )
    self.use_bias = use_bias

ConvTranspose2d

ConvTranspose2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 0,
    output_padding: int | tuple[int, int] = 0,
    groups: int = 1,
    use_bias: bool = True,
    dilation: int | tuple[int, int] = 1,
    padding_mode: str = "zeros",
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: ConvTranspose2d, WeightedModule

2D Transposed Convolutional layer.

This layer wraps torch.nn.ConvTranspose2d.

Receives:

    Real[Tensor, 'batch in_channels in_height in_width']

Returns:

    Real[Tensor, 'batch out_channels out_height out_width']
Example
conv2d = fl.ConvTranspose2d(
    in_channels=3,
    out_channels=32,
    kernel_size=3,
    stride=1,
    padding=1,
)

tensor = torch.randn(2, 3, 128, 128)
output = conv2d(tensor)

assert output.shape == (2, 32, 128, 128)
Source code in src/refiners/fluxion/layers/conv.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = 1,
    padding: int | tuple[int, int] = 0,
    output_padding: int | tuple[int, int] = 0,
    groups: int = 1,
    use_bias: bool = True,
    dilation: int | tuple[int, int] = 1,
    padding_mode: str = "zeros",
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__(  # type: ignore
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
        bias=use_bias,
        padding_mode=padding_mode,
        device=device,
        dtype=dtype,
    )
    self.use_bias = use_bias

Converter

Converter(set_device: bool = True, set_dtype: bool = True)

Bases: ContextModule

A Converter class that adjusts tensor properties based on a parent module's settings.

This class inherits from ContextModule and provides functionality to adjust the device and dtype of input tensor(s) to match the parent module's attributes.

Note

Ensure the parent module has device and dtype attributes if set_device or set_dtype are set to True.

Parameters:

    set_device (bool): If True, matches the device of the input tensor(s) to the parent's device. Default: True
    set_dtype (bool): If True, matches the dtype of the input tensor(s) to the parent's dtype. Default: True
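
A short sketch (not from the original docs) of the dtype-matching behavior, assuming the Converter's parent Chain holds float64 weights:

chain = fl.Chain(
    fl.Converter(set_device=False, set_dtype=True),
    fl.Linear(8, 8, dtype=torch.float64),
)

tensor = torch.randn(2, 8)  # float32
output = chain(tensor)

assert output.dtype == torch.float64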
Source code in src/refiners/fluxion/layers/converter.py
def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
    """Initializes the Converter layer.

    Args:
        set_device: If True, matches the device of the input tensor(s) to the parent's device.
        set_dtype: If True, matches the dtype of the input tensor(s) to the parent's dtype.
    """
    super().__init__()
    self.set_device = set_device
    self.set_dtype = set_dtype

Cos

Cos(*args: Any, **kwargs: Any)

Bases: Module

Cosine operator layer.

This layer applies the cosine function to the input tensor. See also torch.cos.

Example
cos = fl.Cos()

tensor = torch.tensor([0, torch.pi])
output = cos(tensor)

expected_output = torch.tensor([1.0, -1.0])
assert torch.allclose(output, expected_output, atol=1e-6)
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]

Distribute

Distribute(*args: Module | Iterable[Module])

Bases: Chain

Distribute layer.

This layer calls its sub-modules in parallel with their respective input, and returns a tuple of their outputs.

Example
distribute = fl.Distribute(
    fl.Linear(32, 128),
    fl.Linear(64, 256),
)

tensor1 = torch.randn(2, 32)
tensor2 = torch.randn(4, 64)
outputs = distribute(tensor1, tensor2)

assert len(outputs) == 2
assert outputs[0].shape == (2, 128)
assert outputs[1].shape == (4, 256)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

Downsample

Downsample(
    channels: int,
    scale_factor: int,
    padding: int = 0,
    register_shape: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Downsample layer.

This layer downsamples the input by the given scale factor.

Raises:

    RuntimeError: If the context sampling is not set or if the context does not contain a list.

Parameters:

    channels (int): The number of input and output channels. (required)
    scale_factor (int): The factor by which to downsample the input. (required)
    padding (int): The amount of zero-padding added to both sides of the input. Default: 0
    register_shape (bool): If True, registers the input shape in the context. Default: True
    device (device | str | None): The device to use for the convolutional layer. Default: None
    dtype (dtype | None): The dtype to use for the convolutional layer. Default: None
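
A short sketch (not from the original docs); register_shape is disabled here so no "sampling" context is needed:

downsample = fl.Downsample(channels=16, scale_factor=2, register_shape=False)

tensor = torch.randn(2, 16, 64, 64)
output = downsample(tensor)

assert output.shape == (2, 16, 32, 32)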
Source code in src/refiners/fluxion/layers/sampling.py
def __init__(
    self,
    channels: int,
    scale_factor: int,
    padding: int = 0,
    register_shape: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
):
    """Initializes the Downsample layer.

    Args:
        channels: The number of input and output channels.
        scale_factor: The factor by which to downsample the input.
        padding: The amount of zero-padding added to both sides of the input.
        register_shape: If True, registers the input shape in the context.
        device: The device to use for the convolutional layer.
        dtype: The dtype to use for the convolutional layer.
    """
    self.channels = channels
    self.in_channels = channels
    self.out_channels = channels
    self.scale_factor = scale_factor
    self.padding = padding

    super().__init__(
        Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=scale_factor,
            padding=padding,
            device=device,
            dtype=dtype,
        ),
    )

    if padding == 0:
        zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
        self.insert(
            index=0,
            module=Lambda(func=zero_pad),
        )

    if register_shape:
        self.insert(
            index=0,
            module=SetContext(
                context="sampling",
                key="shapes",
                callback=self.register_shape,
            ),
        )

Embedding

Embedding(
    num_embeddings: int,
    embedding_dim: int,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Embedding, WeightedModule

Embedding layer.

This layer wraps torch.nn.Embedding.

Receives:

    Int[Tensor, 'batch length']

Returns:

    Float[Tensor, 'batch length embedding_dim']
Example
embedding = fl.Embedding(
    num_embeddings=10,
    embedding_dim=128
)

tensor = torch.randint(0, 10, (2, 10))
output = embedding(tensor)

assert output.shape == (2, 10, 128)

Parameters:

    num_embeddings (int): The number of embeddings. (required)
    embedding_dim (int): The dimension of the embeddings. (required)
    device (device | str | None): The device to use for the embedding layer. Default: None
    dtype (dtype | None): The dtype to use for the embedding layer. Default: None
Source code in src/refiners/fluxion/layers/embedding.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    device: Device | str | None = None,
    dtype: DType | None = None,
):
    """Initializes the Embedding layer.

    Args:
        num_embeddings: The number of embeddings.
        embedding_dim: The dimension of the embeddings.
        device: The device to use for the embedding layer.
        dtype: The dtype to use for the embedding layer.
    """
    _Embedding.__init__(  # type: ignore
        self,
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        device=device,
        dtype=dtype,
    )

Flatten

Flatten(start_dim: int = 0, end_dim: int = -1)

Bases: Module

Flatten operation layer.

This layer flattens the input tensor between the given dimensions. See also torch.flatten.

Example
flatten = fl.Flatten(start_dim=1)

tensor = torch.randn(10, 10, 10)
output = flatten(tensor)

assert output.shape == (10, 100)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(
    self,
    start_dim: int = 0,
    end_dim: int = -1,
) -> None:
    super().__init__()
    self.start_dim = start_dim
    self.end_dim = end_dim

GLU

GLU(activation: Activation)

Bases: Activation

Gated Linear Unit activation function.

See [arXiv:2002.05202] GLU Variants Improve Transformer for more details.

Example
glu = fl.GLU(fl.ReLU())
tensor = torch.tensor([[1.0, 0.0, -1.0, 1.0]])
output = glu(tensor)
assert torch.allclose(output, torch.tensor([0.0, 0.0]))
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self, activation: Activation) -> None:
    super().__init__()
    self.activation = activation

GeLU

GeLU(approximation: GeLUApproximation = NONE)

Bases: Activation

Gaussian Error Linear Unit activation function.

This activation can be quite expensive to compute; a few approximations are available, see GeLUApproximation.

See [arXiv:1606.08415] Gaussian Error Linear Units for more details.

Example
gelu = fl.GeLU()

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = gelu(tensor)
Source code in src/refiners/fluxion/layers/activations.py
def __init__(
    self,
    approximation: GeLUApproximation = GeLUApproximation.NONE,
) -> None:
    super().__init__()
    self.approximation = approximation

GeLUApproximation

Bases: Enum

Approximation methods for the Gaussian Error Linear Unit activation function.

Attributes:

    NONE: No approximation, use the original formula.
    TANH: Use the tanh approximation.
    SIGMOID: Use the sigmoid approximation.
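
A short sketch (not from the original docs), assuming GeLUApproximation is exported alongside the other layers:

gelu = fl.GeLU(approximation=fl.GeLUApproximation.TANH)

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = gelu(tensor)

assert output.shape == (1, 3)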

GetArg

GetArg(index: int)

Bases: Module

GetArg operation layer.

This layer returns the nth tensor of the input arguments.

Example
get_arg = fl.GetArg(1)

inputs = (
    torch.randn(10, 10),
    torch.randn(20, 20),
    torch.randn(30, 30),
)
output = get_arg(*inputs)

assert id(inputs[1]) == id(output)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, index: int) -> None:
    super().__init__()
    self.index = index

GroupNorm

GroupNorm(
    channels: int,
    num_groups: int,
    eps: float = 1e-05,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: GroupNorm, WeightedModule

Group Normalization layer.

This layer wraps torch.nn.GroupNorm.

Receives:

    Float[Tensor, 'batch channels *normalized_shape']

Returns:

    Float[Tensor, 'batch channels *normalized_shape']
Example
groupnorm = fl.GroupNorm(channels=128, num_groups=8)

tensor = torch.randn(2, 128, 8)
output = groupnorm(tensor)

assert output.shape == (2, 128, 8)
Source code in src/refiners/fluxion/layers/norm.py
def __init__(
    self,
    channels: int,
    num_groups: int,
    eps: float = 1e-5,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__(  # type: ignore
        num_groups=num_groups,
        num_channels=channels,
        eps=eps,
        affine=True,  # otherwise not a WeightedModule
        device=device,
        dtype=dtype,
    )
    self.channels = channels
    self.num_groups = num_groups
    self.eps = eps

Identity

Identity()

Bases: Module

Identity operator layer.

This layer simply returns the input tensor.

Example
identity = fl.Identity()

tensor = torch.randn(10, 10)
output = identity(tensor)

assert torch.equal(tensor, output)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self) -> None:
    super().__init__()

InstanceNorm2d

InstanceNorm2d(
    num_features: int,
    eps: float = 1e-05,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: InstanceNorm2d, Module

Instance Normalization layer.

This layer wraps torch.nn.InstanceNorm2d.

Receives:

    Float[Tensor, 'batch channels height width']

Returns:

    Float[Tensor, 'batch channels height width']
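
A short sketch (not from the original docs), following the pattern of the other normalization examples:

instancenorm = fl.InstanceNorm2d(num_features=16)

tensor = torch.randn(2, 16, 8, 8)
output = instancenorm(tensor)

assert output.shape == (2, 16, 8, 8)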
Source code in src/refiners/fluxion/layers/norm.py
def __init__(
    self,
    num_features: int,
    eps: float = 1e-05,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__(  # type: ignore
        num_features=num_features,
        eps=eps,
        device=device,
        dtype=dtype,
    )

Interpolate

Interpolate(mode: str = 'nearest', antialias: bool = False)

Bases: Module

Interpolate layer.

This layer wraps torch.nn.functional.interpolate.

Source code in src/refiners/fluxion/layers/sampling.py
def __init__(
    self,
    mode: str = "nearest",
    antialias: bool = False,
) -> None:
    super().__init__()
    self.mode = mode
    self.antialias = antialias

Lambda

Lambda(func: Callable[..., Any])

Bases: Module

Lambda layer.

This layer wraps a Callable.

When called, it will:
  • execute the Callable with the given arguments
  • return the output of the Callable
Example
lambda_layer = fl.Lambda(lambda x: x + 1)

tensor = torch.tensor([1, 2, 3])
output = lambda_layer(tensor)

expected_output = torch.tensor([2, 3, 4])
assert torch.allclose(output, expected_output)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, func: Callable[..., Any]) -> None:
    super().__init__()
    self.func = func

LayerNorm

LayerNorm(
    normalized_shape: int | list[int],
    eps: float = 1e-05,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: LayerNorm, WeightedModule

Layer Normalization layer.

This layer wraps torch.nn.LayerNorm.

Receives:

    Float[Tensor, 'batch *normalized_shape']

Returns:

    Float[Tensor, 'batch *normalized_shape']
Example
layernorm = fl.LayerNorm(normalized_shape=128)

tensor = torch.randn(2, 128)
output = layernorm(tensor)

assert output.shape == (2, 128)
Source code in src/refiners/fluxion/layers/norm.py
def __init__(
    self,
    normalized_shape: int | list[int],
    eps: float = 0.00001,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__(  # type: ignore
        normalized_shape=normalized_shape,
        eps=eps,
        elementwise_affine=True,  # otherwise not a WeightedModule
        device=device,
        dtype=dtype,
    )

LayerNorm2d

LayerNorm2d(
    channels: int,
    eps: float = 1e-06,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: WeightedModule

2D Layer Normalization layer.

This layer applies Layer Normalization along the 2nd dimension of a 4D tensor.

Receives:

    Float[Tensor, 'batch channels height width']

Returns:

    Float[Tensor, 'batch channels height width']
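
A short sketch (not from the original docs), following the pattern of the other normalization examples:

layernorm2d = fl.LayerNorm2d(channels=16)

tensor = torch.randn(2, 16, 8, 8)
output = layernorm2d(tensor)

assert output.shape == (2, 16, 8, 8)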
Source code in src/refiners/fluxion/layers/norm.py
def __init__(
    self,
    channels: int,
    eps: float = 1e-6,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__()
    self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
    self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
    self.eps = eps

Linear

Linear(
    in_features: int,
    out_features: int,
    bias: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Linear, WeightedModule

Linear layer.

This layer wraps torch.nn.Linear.

Receives:

    Input (Float[Tensor, 'batch in_features'])

Returns:

    Output (Float[Tensor, 'batch out_features'])
Example
linear = fl.Linear(in_features=32, out_features=128)

tensor = torch.randn(2, 32)
output = linear(tensor)

assert output.shape == (2, 128)

Parameters:

    in_features (int): The number of input features. (required)
    out_features (int): The number of output features. (required)
    bias (bool): If True, adds a learnable bias to the output. Default: True
    device (device | str | None): The device to use for the linear layer. Default: None
    dtype (dtype | None): The dtype to use for the linear layer. Default: None
Source code in src/refiners/fluxion/layers/linear.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initializes the Linear layer.

    Args:
        in_features: The number of input features.
        out_features: The number of output features.
        bias: If True, adds a learnable bias to the output.
        device: The device to use for the linear layer.
        dtype: The dtype to use for the linear layer.
    """
    self.in_features = in_features
    self.out_features = out_features
    super().__init__(  # type: ignore
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        device=device,
        dtype=dtype,
    )

Matmul

Matmul(input: Module, other: Module)

Bases: Chain

Matrix multiplication layer.

This layer returns the matrix multiplication of the outputs of its two sub-modules.

Example
matmul = fl.Matmul(
    fl.Identity(),
    fl.Multiply(scale=2),
)

tensor = torch.randn(10, 10)
output = matmul(tensor)

expected_output = tensor @ (2 * tensor)
assert torch.allclose(output, expected_output)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, input: Module, other: Module) -> None:
    super().__init__(
        input,
        other,
    )

MaxPool1d

MaxPool1d(
    kernel_size: int,
    stride: int | None = None,
    padding: int = 0,
    dilation: int = 1,
    return_indices: bool = False,
    ceil_mode: bool = False,
)

Bases: MaxPool1d, Module

MaxPool1d layer.

This layer wraps torch.nn.MaxPool1d.

Receives:

    Float[Tensor, 'batch channels in_length']

Returns:

    Float[Tensor, 'batch channels out_length']

Parameters:

    kernel_size (int): The size of the sliding window. (required)
    stride (int | None): The stride of the sliding window. Default: None
    padding (int): The amount of zero-padding added to both sides of the input. Default: 0
    dilation (int): The spacing between kernel elements. Default: 1
    return_indices (bool): If True, returns the max indices along with the outputs. Default: False
    ceil_mode (bool): If True, uses ceil instead of floor to compute the output shape. Default: False
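
A short sketch (not from the original docs):

maxpool = fl.MaxPool1d(kernel_size=2)

tensor = torch.randn(2, 16, 32)
output = maxpool(tensor)

assert output.shape == (2, 16, 16)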
Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(
    self,
    kernel_size: int,
    stride: int | None = None,
    padding: int = 0,
    dilation: int = 1,
    return_indices: bool = False,
    ceil_mode: bool = False,
) -> None:
    """Initializes the MaxPool1d layer.

    Args:
        kernel_size: The size of the sliding window.
        stride: The stride of the sliding window.
        padding: The amount of zero-padding added to both sides of the input.
        dilation: The spacing between kernel elements.
        return_indices: If True, returns the max indices along with the outputs.
        ceil_mode: If True, uses ceil instead of floor to compute the output shape.
    """
    super().__init__(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        return_indices=return_indices,
        ceil_mode=ceil_mode,
    )

MaxPool2d

MaxPool2d(
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] | None = None,
    padding: int | tuple[int, int] = (0, 0),
    dilation: int | tuple[int, int] = (1, 1),
    return_indices: bool = False,
    ceil_mode: bool = False,
)

Bases: MaxPool2d, Module

MaxPool2d layer.

This layer wraps torch.nn.MaxPool2d.

Receives:

    Float[Tensor, 'batch channels in_height in_width']

Returns:

    Float[Tensor, 'batch channels out_height out_width']

Parameters:

    kernel_size (int | tuple[int, int]): The size of the sliding window. (required)
    stride (int | tuple[int, int] | None): The stride of the sliding window. Default: None
    padding (int | tuple[int, int]): The amount of zero-padding added to both sides of the input. Default: (0, 0)
    dilation (int | tuple[int, int]): The spacing between kernel elements. Default: (1, 1)
    return_indices (bool): If True, returns the max indices along with the outputs. Default: False
    ceil_mode (bool): If True, uses ceil instead of floor to compute the output shape. Default: False
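
A short sketch (not from the original docs):

maxpool = fl.MaxPool2d(kernel_size=2)

tensor = torch.randn(2, 16, 32, 32)
output = maxpool(tensor)

assert output.shape == (2, 16, 16, 16)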
Source code in src/refiners/fluxion/layers/maxpool.py
def __init__(
    self,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] | None = None,
    padding: int | tuple[int, int] = (0, 0),
    dilation: int | tuple[int, int] = (1, 1),
    return_indices: bool = False,
    ceil_mode: bool = False,
) -> None:
    """Initializes the MaxPool2d layer.

    Args:
        kernel_size: The size of the sliding window.
        stride: The stride of the sliding window.
        padding: The amount of zero-padding added to both sides of the input.
        dilation: The spacing between kernel elements.
        return_indices: If True, returns the max indices along with the outputs.
        ceil_mode: If True, uses ceil instead of floor to compute the output shape.
    """
    super().__init__(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        return_indices=return_indices,
        ceil_mode=ceil_mode,
    )

Module

Module(*args: Any, **kwargs: Any)

Bases: Module

A wrapper around torch.nn.Module.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]

basic_attributes

basic_attributes(
    init_attrs_only: bool = False,
) -> dict[str, BasicType | Sequence[BasicType]]

Return a dictionary of basic attributes of the module.

Parameters:

    init_attrs_only (bool): Whether to only return attributes that are passed to the module's constructor. Default: False
Source code in src/refiners/fluxion/layers/module.py
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
    """Return a dictionary of basic attributes of the module.

    Args:
        init_attrs_only: Whether to only return attributes that are passed to the module's constructor.
    """
    sig = signature(obj=self.__init__)
    init_params = set(sig.parameters.keys()) - {"self"}
    default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}

    def is_basic_attribute(key: str, value: Any) -> bool:
        if key.startswith("_"):
            return False

        if isinstance(value, BasicType):
            return True

        if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
            return True

        return False

    return {
        key: value
        for key, value in self.__dict__.items()
        if is_basic_attribute(key=key, value=value)
        and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
    }

get_path

get_path(
    parent: Chain | None = None, top: Module | None = None
) -> str

Get the path of the module in the chain.

Parameters:

    parent (Chain | None): The parent of the module in the chain. Default: None
    top (Module | None): The top module of the chain. If None, the path will be relative to the root of the chain. Default: None
Source code in src/refiners/fluxion/layers/module.py
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
    """Get the path of the module in the chain.

    Args:
        parent: The parent of the module in the chain.
        top: The top module of the chain.
            If None, the path will be relative to the root of the chain.
    """
    if (parent is None) or (self == top):
        return self.__class__.__name__
    for k, m in parent._modules.items():  # type: ignore
        if m is self:
            return parent.get_path(parent=parent.parent, top=top) + "." + k
    raise ValueError(f"{self} not found in {parent}")

load_from_safetensors

load_from_safetensors(
    tensors_path: str | Path, strict: bool = True
) -> T

Load the module's state from a SafeTensors file.

Parameters:

    tensors_path (str | Path): The path to the SafeTensors file. (required)
    strict (bool): Whether to raise an error if the SafeTensors file's content doesn't map perfectly to the module's state. Default: True

Returns:

    T: The module, with its state loaded from the SafeTensors file.
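
A short sketch (not from the original docs; "weights.safetensors" is a hypothetical path):

linear = fl.Linear(32, 64)
linear = linear.load_from_safetensors("weights.safetensors")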

Source code in src/refiners/fluxion/layers/module.py
def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T:
    """Load the module's state from a SafeTensors file.

    Args:
        tensors_path: The path to the SafeTensors file.
        strict: Whether to raise an error if the SafeTensors's
            content doesn't map perfectly to the module's state.

    Returns:
        The module, with its state loaded from the SafeTensors file.
    """
    state_dict = load_from_safetensors(tensors_path)
    self.load_state_dict(state_dict, strict=strict)
    return self

named_modules

named_modules(
    *args: Any, **kwargs: Any
) -> Generator[tuple[str, Module], None, None]

Get all the sub-modules of the module.

Returns:

    Generator[tuple[str, Module], None, None]: An iterator over all the sub-modules of the module.

Source code in src/refiners/fluxion/layers/module.py
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]":  # type: ignore
    """Get all the sub-modules of the module.

    Returns:
        An iterator over all the sub-modules of the module.
    """
    return super().named_modules(*args)  # type: ignore

pretty_print

pretty_print(depth: int = -1) -> None

Print the module in a tree-like format.

Parameters:

    depth (int): The maximum depth of the tree to print. If negative, the whole tree is printed. Default: -1
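
A short sketch (not from the original docs):

chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
chain.pretty_print()  # prints the sub-modules as a tree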
Source code in src/refiners/fluxion/layers/module.py
def pretty_print(self, depth: int = -1) -> None:
    """Print the module in a tree-like format.

    Args:
        depth: The maximum depth of the tree to print.
            If negative, the whole tree is printed.
    """
    tree = ModuleTree(module=self)
    print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth))  # type: ignore[reportPrivateUsage]

to

to(
    device: device | str | None = None,
    dtype: dtype | None = None,
) -> T

Move the module to the given device and cast its parameters to the given dtype.

Parameters:

    device (device | str | None): The device to move the module to. Default: None
    dtype (dtype | None): The dtype to cast the module's parameters to. Default: None

Returns:

    T: The module, moved to the given device and cast to the given dtype.
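
A short sketch (not from the original docs):

linear = fl.Linear(32, 64)
linear = linear.to(dtype=torch.float16)

assert linear.weight.dtype == torch.float16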

Source code in src/refiners/fluxion/layers/module.py
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
    """Move the module to the given device and cast its parameters to the given dtype.

    Args:
        device: The device to move the module to.
        dtype: The dtype to cast the module's parameters to.

    Returns:
        The module, moved to the given device and cast to the given dtype.
    """
    return super().to(device=device, dtype=dtype)  # type: ignore
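
A minimal sketch, assuming fl.Linear exposes its weight tensor like torch.nn.Linear:

import torch
import refiners.fluxion.layers as fl

model = fl.Linear(32, 64)
model = model.to(dtype=torch.float16)  # cast the parameters and return the module
assert model.weight.dtype == torch.float16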

MultiLinear

MultiLinear(
    input_dim: int,
    output_dim: int,
    inner_dim: int,
    num_layers: int,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Multi-layer linear network.

This layer stacks multiple Linear layers, with a ReLU activation between each consecutive pair (see the source below).

Receives:

Name Type Description
Input Float[Tensor, 'batch input_dim']

Returns:

Name Type Description
Output Float[Tensor, 'batch output_dim']
Example
linear = fl.MultiLinear(
    input_dim=32,
    output_dim=128,
    inner_dim=64,
    num_layers=3,
)

tensor = torch.randn(2, 32)
output = linear(tensor)

assert output.shape == (2, 128)

Parameters:

Name Type Description Default
input_dim int

The input dimension of the first linear layer.

required
output_dim int

The output dimension of the last linear layer.

required
inner_dim int

The output dimension of the inner linear layers.

required
num_layers int

The number of linear layers.

required
device device | str | None

The device to use for the linear layers.

None
dtype dtype | None

The dtype to use for the linear layers.

None
Source code in src/refiners/fluxion/layers/linear.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    inner_dim: int,
    num_layers: int,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initializes the MultiLinear layer.

    Args:
        input_dim: The input dimension of the first linear layer.
        output_dim: The output dimension of the last linear layer.
        inner_dim: The output dimension of the inner linear layers.
        num_layers: The number of linear layers.
        device: The device to use for the linear layers.
        dtype: The dtype to use for the linear layers.
    """
    layers: list[Module] = []
    for i in range(num_layers - 1):
        layers.append(
            Linear(
                in_features=input_dim if i == 0 else inner_dim,
                out_features=inner_dim,
                device=device,
                dtype=dtype,
            )
        )
        layers.append(
            ReLU(),
        )
    layers.append(
        Linear(
            in_features=inner_dim,
            out_features=output_dim,
            device=device,
            dtype=dtype,
        )
    )

    super().__init__(layers)

Multiply

Multiply(scale: float = 1.0, bias: float = 0.0)

Bases: Module

Multiply operator layer.

This layer scales and shifts the input tensor by the given scale and bias.

Example
multiply = fl.Multiply(scale=2, bias=1)

tensor = torch.ones(1)
output = multiply(tensor)

assert torch.allclose(output, torch.tensor([3.0]))
Source code in src/refiners/fluxion/layers/basics.py
def __init__(
    self,
    scale: float = 1.0,
    bias: float = 0.0,
) -> None:
    super().__init__()
    self.scale = scale
    self.bias = bias

Parallel

Parallel(*args: Module | Iterable[Module])

Bases: Chain

Parallel layer.

This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.

Example
parallel = fl.Parallel(
    fl.Linear(32, 64),
    fl.Identity(),
    fl.Linear(32, 128),
)

tensor = torch.randn(2, 32)
outputs = parallel(tensor)

assert len(outputs) == 3
assert outputs[0].shape == (2, 64)
assert torch.allclose(outputs[1], tensor)
assert outputs[2].shape == (2, 128)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

Parameter

Parameter(
    *dims: int,
    requires_grad: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None
)

Bases: WeightedModule

Parameter layer.

This layer simply wraps a PyTorch Parameter. When called, it returns the Parameter Tensor.

Attributes:

Name Type Description
weight Parameter

The parameter Tensor.

Source code in src/refiners/fluxion/layers/basics.py
def __init__(
    self,
    *dims: int,
    requires_grad: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    super().__init__()
    self.dims = dims
    self.weight = TorchParameter(
        requires_grad=requires_grad,
        data=torch.randn(
            *dims,
            device=device,
            dtype=dtype,
        ),
    )
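
An illustrative sketch exercising the documented weight attribute:

import refiners.fluxion.layers as fl

param = fl.Parameter(8, 16)
assert param.weight.shape == (8, 16)
assert param.weight.requires_grad  # True by default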

Passthrough

Passthrough(*args: Module | Iterable[Module])

Bases: Chain

Passthrough layer.

This layer calls its sub-modules sequentially, then returns its original inputs, like an Identity layer.

Example
passthrough = fl.Passthrough(
    fl.Linear(32, 128),
    fl.ReLU(),
    fl.Linear(128, 128),
)

tensor = torch.randn(2, 32)
output = passthrough(tensor)

assert torch.allclose(output, tensor)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

Permute

Permute(*dims: int)

Bases: Module

Permute operation layer.

This layer permutes the input tensor according to the given dimensions. See also torch.permute.

Example
permute = fl.Permute(2, 0, 1)

tensor = torch.randn(10, 20, 30)
output = permute(tensor)

assert output.shape == (30, 10, 20)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *dims: int) -> None:
    super().__init__()
    self.dims = dims

PixelUnshuffle

PixelUnshuffle(downscale_factor: int)

Bases: PixelUnshuffle, Module

Pixel Unshuffle layer.

This layer wraps torch.nn.PixelUnshuffle.

Receives:

Type Description
Float[Tensor, 'batch in_channels in_height in_width']

Returns:

Type Description
Float[Tensor, 'batch out_channels out_height out_width']
Source code in src/refiners/fluxion/layers/pixelshuffle.py
def __init__(self, downscale_factor: int):
    _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)
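
For example, a short sketch of the shape contract inherited from torch.nn.PixelUnshuffle (channels multiplied by downscale_factor², spatial dimensions divided by it):

import torch
import refiners.fluxion.layers as fl

unshuffle = fl.PixelUnshuffle(downscale_factor=2)

tensor = torch.randn(1, 4, 8, 8)
output = unshuffle(tensor)

assert output.shape == (1, 16, 4, 4)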

ReLU

ReLU()

Bases: Activation

Rectified Linear Unit activation function.

See Rectified Linear Units Improve Restricted Boltzmann Machines and Cognitron: A self-organizing multilayered neural network for more details.

Example
relu = fl.ReLU()

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = relu(tensor)

expected_output = torch.tensor([[0.0, 0.0, 1.0]])
assert torch.equal(output, expected_output)
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:
    super().__init__()

ReflectionPad2d

ReflectionPad2d(padding: int)

Bases: ReflectionPad2d, Module

Reflection padding layer.

This layer wraps torch.nn.ReflectionPad2d.

Receives:

Type Description
Float[Tensor, 'batch channels in_height in_width']

Returns:

Type Description
Float[Tensor, 'batch channels out_height out_width']
Source code in src/refiners/fluxion/layers/padding.py
def __init__(self, padding: int) -> None:
    super().__init__(padding=padding)
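
A short sketch; each spatial dimension grows by 2 × padding:

import torch
import refiners.fluxion.layers as fl

pad = fl.ReflectionPad2d(padding=1)

tensor = torch.randn(1, 3, 8, 8)
output = pad(tensor)

assert output.shape == (1, 3, 10, 10)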

Reshape

Reshape(*shape: int)

Bases: Module

Reshape operation layer.

This layer reshapes the input tensor to a specific shape (which must be compatible with the original shape). See also torch.reshape.

Warning

The first dimension (batch dimension) is always preserved.

Example
reshape = fl.Reshape(5, 2)

tensor = torch.randn(2, 10, 1)
output = reshape(tensor)

assert output.shape == (2, 5, 2)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, *shape: int) -> None:
    super().__init__()
    self.shape = shape

Residual

Residual(*args: Module | Iterable[Module])

Bases: Chain

Residual layer.

This layer calls its sub-modules sequentially, and adds the original input to the output.

Example
residual = fl.Residual(
    fl.Multiply(scale=10),
)

tensor = torch.ones(2, 32)
output = residual(tensor)

assert output.shape == (2, 32)
assert torch.allclose(output, 10 * tensor + tensor)
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

Return

Return(*args: Any, **kwargs: Any)

Bases: Module

Return layer.

This layer stops the execution of a Chain when encountered.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]
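
An illustrative sketch of the documented behavior; it assumes the chain returns the value produced just before the Return layer:

import torch
import refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.Multiply(scale=2),
    fl.Return(),
    fl.Multiply(scale=10),  # never executed: the chain stops at Return
)

tensor = torch.ones(1)
assert torch.allclose(chain(tensor), torch.tensor([2.0]))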

ScaledDotProductAttention

ScaledDotProductAttention(
    num_heads: int = 1,
    is_causal: bool = False,
    is_optimized: bool = True,
    slice_size: int | None = None,
)

Bases: Module

Scaled Dot Product Attention.

See [arXiv:1706.03762] Attention Is All You Need (Figure 2) for more details

Note

This layer simply wraps scaled_dot_product_attention inside an fl.Module.

Receives:

Name Type Description
Query Float[Tensor, 'batch num_queries embedding_dim']
Key Float[Tensor, 'batch num_keys embedding_dim']
Value Float[Tensor, 'batch num_values embedding_dim']

Returns:

Type Description
Float[Tensor, 'batch num_queries embedding_dim']
Example
attention = fl.ScaledDotProductAttention(num_heads=8)

query = torch.randn(2, 10, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)
output = attention(query, key, value)

assert output.shape == (2, 10, 128)

Parameters:

Name Type Description Default
num_heads int

The number of heads to use.

1
is_causal bool

Whether to use causal attention.

False
is_optimized bool

Whether to use optimized attention.

True
slice_size int | None

The slice size to use for the optimized attention.

None
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(
    self,
    num_heads: int = 1,
    is_causal: bool = False,
    is_optimized: bool = True,
    slice_size: int | None = None,
) -> None:
    """Initialize the Scaled Dot Product Attention layer.

    Args:
        num_heads: The number of heads to use.
        is_causal: Whether to use causal attention.
        is_optimized: Whether to use optimized attention.
        slice_size: The slice size to use for the optimized attention.
    """
    super().__init__()
    self.num_heads = num_heads
    self.is_causal = is_causal
    self.is_optimized = is_optimized
    self.slice_size = slice_size
    self.dot_product = (
        scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized
    )

SelfAttention

SelfAttention(
    embedding_dim: int,
    inner_dim: int | None = None,
    num_heads: int = 1,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Attention

Multi-Head Self-Attention layer.

This layer simply chains
  • a Parallel layer, which duplicates the input Tensor (for each Linear layer in the Attention layer)
  • an Attention layer

Receives:

Type Description
Float[Tensor, 'batch sequence_length embedding_dim']

Returns:

Type Description
Float[Tensor, 'batch sequence_length embedding_dim']
Example
self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)

tensor = torch.randn(2, 10, 128)
output = self_attention(tensor)

assert output.shape == (2, 10, 128)

Parameters:

Name Type Description Default
embedding_dim int

The embedding dimension of the input and output tensors.

required
inner_dim int | None

The inner dimension of the linear layers.

None
num_heads int

The number of heads to use.

1
use_bias bool

Whether to use bias in the linear layers.

True
is_causal bool

Whether to use causal attention.

False
is_optimized bool

Whether to use optimized attention.

True
device device | str | None

The device to use.

None
dtype dtype | None

The dtype to use.

None
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(
    self,
    embedding_dim: int,
    inner_dim: int | None = None,
    num_heads: int = 1,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the Self-Attention layer.

    Args:
        embedding_dim: The embedding dimension of the input and output tensors.
        inner_dim: The inner dimension of the linear layers.
        num_heads: The number of heads to use.
        use_bias: Whether to use bias in the linear layers.
        is_causal: Whether to use causal attention.
        is_optimized: Whether to use optimized attention.
        device: The device to use.
        dtype: The dtype to use.
    """
    super().__init__(
        embedding_dim=embedding_dim,
        inner_dim=inner_dim,
        num_heads=num_heads,
        use_bias=use_bias,
        is_causal=is_causal,
        is_optimized=is_optimized,
        device=device,
        dtype=dtype,
    )
    self.insert(
        index=0,
        module=Parallel(
            Identity(),  # Query projection's input
            Identity(),  # Key projection's input
            Identity(),  # Value projection's input
        ),
    )

SelfAttention2d

SelfAttention2d(
    channels: int,
    num_heads: int = 1,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: SelfAttention

Multi-Head 2D Self-Attention layer.

This Module simply chains
  • a Lambda layer, which transforms the input Tensor into a sequence
  • a SelfAttention layer
  • a Lambda layer, which transforms the output sequence into a 2D Tensor

Receives:

Type Description
Float[Tensor, 'batch channels height width']

Returns:

Type Description
Float[Tensor, 'batch channels height width']
Example
self_attention = fl.SelfAttention2d(channels=128, num_heads=8)

tensor = torch.randn(2, 128, 64, 64)
output = self_attention(tensor)

assert output.shape == (2, 128, 64, 64)

Parameters:

Name Type Description Default
channels int

The number of channels of the input and output tensors.

required
num_heads int

The number of heads to use.

1
use_bias bool

Whether to use bias in the linear layers.

True
is_causal bool

Whether to use causal attention.

False
is_optimized bool

Whether to use optimized attention.

True
device device | str | None

The device to use.

None
dtype dtype | None

The dtype to use.

None
Source code in src/refiners/fluxion/layers/attentions.py
def __init__(
    self,
    channels: int,
    num_heads: int = 1,
    use_bias: bool = True,
    is_causal: bool = False,
    is_optimized: bool = True,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the 2D Self-Attention layer.

    Args:
        channels: The number of channels of the input and output tensors.
        num_heads: The number of heads to use.
        use_bias: Whether to use bias in the linear layers.
        is_causal: Whether to use causal attention.
        is_optimized: Whether to use optimized attention.
        device: The device to use.
        dtype: The dtype to use.
    """
    assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
    self.channels = channels

    super().__init__(
        embedding_dim=channels,
        num_heads=num_heads,
        use_bias=use_bias,
        is_causal=is_causal,
        is_optimized=is_optimized,
        device=device,
        dtype=dtype,
    )

    self.insert(0, Lambda(self._tensor_2d_to_sequence))
    self.append(Lambda(self._sequence_to_tensor_2d))

SetContext

SetContext(
    context: str,
    key: str,
    callback: Callable[[Any, Any], Any] | None = None,
)

Bases: ContextModule

SetContext layer.

This layer writes to the ContextProvider of its parent Chain.

When called (without a callback), it will
  • Update the context with the given key and the input value
  • Return the input value
When called (with a callback), it will
  • Call the callback with the current value and the input value (the callback may update the context with a new value, or not)
  • Return the input value
Warning

The context needs to already exist in the ContextProvider.

Source code in src/refiners/fluxion/layers/chain.py
def __init__(
    self,
    context: str,
    key: str,
    callback: Callable[[Any, Any], Any] | None = None,
) -> None:
    super().__init__()
    self.context = context
    self.key = key
    self.callback = callback
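
An illustrative round trip with UseContext; this sketch assumes Chain.set_context can seed the context, since (per the warning above) the context must already exist:

import torch
import refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.SetContext("sample", "raw"),  # store the raw input
    fl.Linear(8, 8),
    fl.UseContext("sample", "raw"),  # discard the Linear output, return the stored input
)
chain.set_context("sample", {"raw": None})  # assumed way to create the context

tensor = torch.randn(2, 8)
assert torch.allclose(chain(tensor), tensor)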

SiLU

SiLU()

Bases: Activation

Sigmoid Linear Unit activation function.

See [arXiv:1702.03118] Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning for more details.

Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:
    super().__init__()
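
By definition SiLU(x) = x · sigmoid(x), which this short sketch checks:

import torch
import refiners.fluxion.layers as fl

silu = fl.SiLU()

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = silu(tensor)

assert torch.allclose(output, tensor * torch.sigmoid(tensor))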

Sigmoid

Sigmoid()

Bases: Activation

Sigmoid activation function.

Example
sigmoid = fl.Sigmoid()

tensor = torch.tensor([[-1.0, 0.0, 1.0]])
output = sigmoid(tensor)

expected_output = torch.tensor([[0.2689, 0.5000, 0.7311]])
assert torch.allclose(output, expected_output, atol=1e-4)
Source code in src/refiners/fluxion/layers/activations.py
def __init__(self) -> None:
    super().__init__()

Sin

Sin(*args: Any, **kwargs: Any)

Bases: Module

Sine operator layer.

This layer applies the sine function to the input tensor. See also torch.sin.

Example
sin = fl.Sin()

tensor = torch.tensor([0, torch.pi])
output = sin(tensor)

expected_output = torch.tensor([0.0, 0.0])
assert torch.allclose(output, expected_output, atol=1e-6)
Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]

Slicing

Slicing(
    dim: int = 0,
    start: int = 0,
    end: int | None = None,
    step: int = 1,
)

Bases: Module

Slicing operation layer.

This layer slices the input tensor at the given dimension between the given start and end indices. See also torch.index_select.

Example
slicing = fl.Slicing(dim=1, start=50)

tensor = torch.randn(10, 100)
output = slicing(tensor)

assert output.shape == (10, 50)
assert torch.allclose(output, tensor[:, 50:])
Source code in src/refiners/fluxion/layers/basics.py
def __init__(
    self,
    dim: int = 0,
    start: int = 0,
    end: int | None = None,
    step: int = 1,
) -> None:
    super().__init__()
    self.dim = dim
    self.start = start
    self.end = end
    self.step = step

Squeeze

Squeeze(dim: int)

Bases: Module

Squeeze operation layer.

This layer squeezes the input tensor at the given dimension. See also torch.squeeze.

Example
squeeze = fl.Squeeze(dim=1)

tensor = torch.randn(10, 1, 10)
output = squeeze(tensor)

assert output.shape == (10, 10)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:
    super().__init__()
    self.dim = dim

Sum

Sum(*args: Module | Iterable[Module])

Bases: Chain

Summation layer.

This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.

Example
summation = fl.Sum(
    fl.Multiply(scale=2, bias=1),
    fl.Multiply(scale=3, bias=0),
)

tensor = torch.ones(1)
output = summation(tensor)

assert torch.allclose(output, torch.tensor([6.0]))
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, *args: Module | Iterable[Module]) -> None:
    super().__init__()
    self._provider = ContextProvider()
    modules = cast(
        tuple[Module],
        (
            tuple(args[0])
            if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
            else tuple(args)
        ),
    )

    for module in modules:
        # Violating this would mean a ContextModule ends up in two chains,
        # with a single one correctly set as its parent.
        assert (
            (not isinstance(module, ContextModule))
            or (not module._can_refresh_parent)
            or (module.parent is None)
            or (module.parent == self)
        ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

    self._regenerate_keys(modules)
    self._reset_context()

    for module in self:
        if isinstance(module, ContextModule) and module.parent != self:
            module._set_parent(self)

Transpose

Transpose(dim0: int, dim1: int)

Bases: Module

Transpose operation layer.

This layer transposes the input tensor between the two given dimensions. See also torch.transpose.

Example
transpose = fl.Transpose(dim0=1, dim1=2)

tensor = torch.randn(10, 20, 30)
output = transpose(tensor)

assert output.shape == (10, 30, 20)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim0: int, dim1: int) -> None:
    super().__init__()
    self.dim0 = dim0
    self.dim1 = dim1

Unflatten

Unflatten(dim: int)

Bases: Module

Unflatten operation layer.

This layer unflattens the input tensor at the given dimension with the given sizes. See also torch.unflatten.

Example
unflatten = fl.Unflatten(dim=1)

tensor = torch.randn(10, 100)
output = unflatten(tensor, sizes=(10, 10))

assert output.shape == (10, 10, 10)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:
    super().__init__()
    self.dim = dim

Unsqueeze

Unsqueeze(dim: int)

Bases: Module

Unsqueeze operation layer.

This layer unsqueezes the input tensor at the given dimension. See also torch.unsqueeze.

Example
unsqueeze = fl.Unsqueeze(dim=1)

tensor = torch.randn(10, 10)
output = unsqueeze(tensor)

assert output.shape == (10, 1, 10)
Source code in src/refiners/fluxion/layers/basics.py
def __init__(self, dim: int) -> None:
    super().__init__()
    self.dim = dim

Upsample

Upsample(
    channels: int,
    upsample_factor: int | None = None,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Chain

Upsample layer.

This layer upsamples the input by the given scale factor.

Raises:

Type Description
RuntimeError

If the "sampling" context is not set or is empty.

Parameters:

Name Type Description Default
channels int

The number of input and output channels.

required
upsample_factor int | None

The factor by which to upsample the input. If None, the input shape is taken from the context.

None
device device | str | None

The device to use for the convolutional layer.

None
dtype dtype | None

The dtype to use for the convolutional layer.

None
Source code in src/refiners/fluxion/layers/sampling.py
def __init__(
    self,
    channels: int,
    upsample_factor: int | None = None,
    device: Device | str | None = None,
    dtype: DType | None = None,
):
    """Initializes the Upsample layer.

    Args:
        channels: The number of input and output channels.
        upsample_factor: The factor by which to upsample the input.
            If None, the input shape is taken from the context.
        device: The device to use for the convolutional layer.
        dtype: The dtype to use for the convolutional layer.
    """
    self.channels = channels
    self.upsample_factor = upsample_factor
    super().__init__(
        Parallel(
            Identity(),
            (
                Lambda(self._get_static_shape)
                if upsample_factor is not None
                else UseContext(context="sampling", key="shapes").compose(lambda x: x.pop())
            ),
        ),
        Interpolate(),
        Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            device=device,
            dtype=dtype,
        ),
    )
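
A sketch of the static-factor path (no sampling context is needed when upsample_factor is given); the 3×3 convolution with padding 1 preserves the upsampled spatial size:

import torch
import refiners.fluxion.layers as fl

upsample = fl.Upsample(channels=8, upsample_factor=2)

tensor = torch.randn(1, 8, 16, 16)
output = upsample(tensor)

assert output.shape == (1, 8, 32, 32)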

UseContext

UseContext(context: str, key: str)

Bases: ContextModule

UseContext layer.

This layer reads from the ContextProvider of its parent Chain.

When called, it will
  • Retrieve a value from the context using the given key
  • Transform the value with the given function (optional)
  • Return the value
Source code in src/refiners/fluxion/layers/chain.py
def __init__(self, context: str, key: str) -> None:
    super().__init__()
    self.context = context
    self.key = key
    self.func: Callable[[Any], Any] = lambda x: x
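
The optional transform is attached with compose, as in the Upsample source above; a minimal sketch:

import refiners.fluxion.layers as fl

# return the last shape stored in the "sampling" context instead of the raw list
use_shape = fl.UseContext(context="sampling", key="shapes").compose(lambda shapes: shapes.pop())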

WeightedModule

WeightedModule(*args: Any, **kwargs: Any)

Bases: Module

A module with a weight (Tensor) attribute.

Source code in src/refiners/fluxion/layers/module.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]

device property

device: device

Return the device of the module's weight.

dtype property

dtype: dtype

Return the dtype of the module's weight.
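
Both properties read directly from the weight tensor; a brief sketch using fl.Linear (a WeightedModule):

import torch
import refiners.fluxion.layers as fl

linear = fl.Linear(8, 8, dtype=torch.float64)
assert linear.dtype == torch.float64
assert linear.device == linear.weight.device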