add converter layer + tests

This commit is contained in:
limiteinductive 2023-08-21 11:30:42 +02:00 committed by Benjamin Trom
parent 4526d58cd5
commit 108fa8f26a
3 changed files with 108 additions and 0 deletions

View file

@ -34,6 +34,7 @@ from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule
from refiners.fluxion.layers.padding import ReflectionPad2d from refiners.fluxion.layers.padding import ReflectionPad2d
from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate
from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.embedding import Embedding
from refiners.fluxion.layers.converter import Converter
__all__ = [ __all__ = [
"Embedding", "Embedding",
@ -84,4 +85,5 @@ __all__ = [
"ContextModule", "ContextModule",
"Interpolate", "Interpolate",
"ReflectionPad2d", "ReflectionPad2d",
"Converter",
] ]

View file

@ -0,0 +1,44 @@
from refiners.fluxion.layers.module import ContextModule
from torch import Tensor
class Converter(ContextModule):
    """Move and/or cast input tensors to match the parent module.

    The target device and dtype are read from the parent module (obtained through
    `ensure_parent`) each time the layer is called.

    Attributes:
        set_device (bool): If True, inputs are moved to the parent's device.
        set_dtype (bool): If True, inputs are cast to the parent's dtype.

    Note:
        The parent must report a non-None `device` (resp. `dtype`) when `set_device`
        (resp. `set_dtype`) is True, otherwise `forward` raises `ValueError`.
    """

    def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
        super().__init__()
        self.set_device = set_device
        self.set_dtype = set_dtype

    def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
        """Convert every input tensor and return them, in order, as a tuple.

        Args:
            *inputs: Tensors to convert. May be empty, in which case an empty tuple
                is returned and the parent's device/dtype are never inspected.

        Returns:
            A tuple with one converted tensor per input.

        Raises:
            ValueError: If a requested property (`device` or `dtype`) is None on the parent.
        """
        # NOTE(review): `ensure_parent` is defined on ContextModule (not visible here);
        # presumably it fails when the module is not attached to a parent - confirm.
        parent = self.ensure_parent
        converted_tensors: list[Tensor] = []
        for x in inputs:
            if self.set_device:
                device = parent.device
                # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
                # which would turn a missing device into a silent no-op.
                if device is None:
                    raise ValueError("parent has no device")
                x = x.to(device=device)
            if self.set_dtype:
                dtype = parent.dtype
                if dtype is None:
                    raise ValueError("parent has no dtype")
                x = x.to(dtype=dtype)
            converted_tensors.append(x)
        return tuple(converted_tensors)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(set_device={self.set_device}, set_dtype={self.set_dtype})"

View file

@ -0,0 +1,62 @@
import torch
import pytest
import refiners.fluxion.layers as fl
from refiners.fluxion.layers.chain import Distribute
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_converter_device_single_tensor() -> None:
    """A tensor created without an explicit device leaves the chain on `cuda:0`."""
    # The Linear layer is the only module given a device in this chain.
    device_only_converter = fl.Converter(set_device=True, set_dtype=False)
    cuda_linear = fl.Linear(in_features=1, out_features=1, device="cuda:0")
    chain = fl.Chain(device_only_converter, cuda_linear)

    output = chain(torch.randn(1, 1))

    expected_device = torch.device(device="cuda:0")
    assert output.device == expected_device
def test_converter_dtype_single_tensor() -> None:
    """A float32 tensor leaves the chain as float64, the dtype given to the Linear layer."""
    dtype_only_converter = fl.Converter(set_device=False, set_dtype=True)
    float64_linear = fl.Linear(in_features=1, out_features=1, dtype=torch.float64)
    chain = fl.Chain(dtype_only_converter, float64_linear)

    # Start from an explicitly float32 input so the cast to float64 is observable.
    float32_input = torch.randn(1, 1).to(dtype=torch.float32)
    output = chain(float32_input)

    assert output.dtype == torch.float64
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_converter_multiple_tensors() -> None:
    """Both tensors passed to the chain come out on `cuda:0` with dtype float64."""
    converter = fl.Converter(set_device=True, set_dtype=True)
    # One Linear per input tensor, both created on cuda:0 in float64.
    branches = Distribute(
        fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64),
        fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64),
    )
    chain = fl.Chain(converter, branches)

    first_input = torch.randn(1, 1)
    second_input = torch.randn(1, 1)
    # Unpacking into exactly two names also checks that two outputs are produced.
    first_output, second_output = chain(first_input, second_input)

    for output in (first_output, second_output):
        assert output.device == torch.device(device="cuda:0")
        assert output.dtype == torch.float64
def test_converter_no_parent_device_or_dtype() -> None:
    """Calling a chain whose layers are given no device is expected to raise `ValueError`."""
    # Neither layer below is constructed with a device or dtype argument.
    passthrough = fl.Lambda(func=(lambda x: x))
    device_only_converter = fl.Converter(set_device=True, set_dtype=False)
    chain = fl.Chain(passthrough, device_only_converter)

    sample = torch.randn(1, 1)

    with pytest.raises(expected_exception=ValueError):
        chain(sample)