# refiners/tests/fluxion/test_model_converter.py
#
# 135 lines
# 4.5 KiB
# Python
# Raw Permalink Normal View History

# pyright: reportPrivateUsage=false
import pytest
import torch
from torch import Tensor, nn
import refiners.fluxion.layers as fl
# 2024-10-09 09:28:34 +00:00
from refiners.conversion.model_converter import ConversionStage, ModelConverter
from refiners.fluxion.utils import manual_seed
class CustomBasicLayer1(fl.Module):
    """Bias-free linear layer used as the custom layer type of the source model.

    Holds a single ``weight`` parameter of shape ``(out_features, in_features)``
    initialised from ``torch.randn``.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        initial_weight = torch.randn(out_features, in_features)
        self.weight = nn.Parameter(data=initial_weight)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` multiplied by the transposed weight matrix."""
        transposed_weight = self.weight.t()
        return torch.matmul(x, transposed_weight)
class CustomBasicLayer2(fl.Module):
    """Bias-free linear layer used as the custom layer type of the target model.

    Holds a single ``weight`` parameter of shape ``(out_features, in_features)``
    initialised from ``torch.randn``.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        initial_weight = torch.randn(out_features, in_features)
        self.weight = nn.Parameter(data=initial_weight)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` multiplied by the transposed weight matrix."""
        transposed_weight = self.weight.t()
        return torch.matmul(x, transposed_weight)
# Source Model
class SourceModel(fl.Module):
    """Reference network handed to the converter as the source model.

    Pipeline: linear (10 -> 2), ReLU, three ``CustomBasicLayer1`` (2 -> 2),
    flatten, dropout, reshape to ``(1, 1, -1)``, 1-D convolution, max pooling.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute assignment order is kept as-is: it fixes both the module
        # registration order and the order in which random weights are drawn.
        self.linear1 = fl.Linear(in_features=10, out_features=2)
        self.activation = fl.ReLU()
        custom_stack = [CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)]
        self.custom_layers = nn.ModuleList(modules=custom_stack)
        self.flatten = fl.Flatten()
        self.dropout = nn.Dropout(p=0.5)
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=10,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every layer in sequence and return the pooled output."""
        hidden = self.activation(self.linear1(x))
        for custom_layer in self.custom_layers:
            hidden = custom_layer(hidden)
        hidden = self.dropout(self.flatten(hidden))
        # Conv1d is declared with a single input channel, hence the (1, 1, -1) view.
        sequence = hidden.view(1, 1, -1)
        return self.pool(self.conv(sequence))
# Target Model (Purposely obfuscated but functionally equivalent)
class TargetModel(fl.Module):
    """Counterpart of the source model with renamed, reordered attributes.

    The forward pass applies the same sequence of operations as the source
    model (linear, ReLU, three ``CustomBasicLayer2``, flatten, dropout,
    reshape, 1-D convolution, max pooling); only the attribute names and
    their declaration order differ.
    """

    def __init__(self) -> None:
        super().__init__()
        # Declaration order is deliberately different from the forward order
        # and must be preserved exactly.
        self.relu = fl.ReLU()
        self.drop = nn.Dropout(p=0.5)
        custom_stack = [CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)]
        self.layers1 = nn.ModuleList(modules=custom_stack)
        self.flattenIt = fl.Flatten()
        self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.convolution = nn.Conv1d(
            in_channels=1,
            out_channels=10,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.lin = fl.Linear(in_features=10, out_features=2)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every layer in sequence and return the pooled output."""
        hidden = self.relu(self.lin(x))
        for custom_layer in self.layers1:
            hidden = custom_layer(hidden)
        hidden = self.drop(self.flattenIt(hidden))
        # Conv1d is declared with a single input channel, hence the (1, 1, -1) view.
        sequence = hidden.view(1, 1, -1)
        return self.max_pool(self.convolution(sequence))
@pytest.fixture
def source_model() -> SourceModel:
    """Provide a ``SourceModel`` built right after calling ``manual_seed(seed=2)``."""
    manual_seed(seed=2)
    seeded_source = SourceModel()
    return seeded_source
@pytest.fixture
def target_model() -> TargetModel:
    """Provide a ``TargetModel`` built right after calling ``manual_seed(seed=2)``."""
    manual_seed(seed=2)
    seeded_target = TargetModel()
    return seeded_target
@pytest.fixture
def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter:
    """Provide a verbose ``ModelConverter`` wired to the two fixture models.

    ``CustomBasicLayer1`` is declared as the source-side counterpart of
    ``CustomBasicLayer2`` through ``custom_layer_mapping``.
    """
    layer_pairs: dict[type[nn.Module], type[nn.Module]] = {
        CustomBasicLayer1: CustomBasicLayer2,
    }
    converter = ModelConverter(
        source_model=source_model,
        target_model=target_model,
        custom_layer_mapping=layer_pairs,
        verbose=True,
    )
    return converter
@pytest.fixture
def random_tensor() -> Tensor:
    """Provide a ``(1, 10)`` tensor sampled with ``torch.randn``."""
    shape = (1, 10)
    return torch.randn(shape)
@pytest.fixture
def source_args(random_tensor: Tensor) -> tuple[Tensor]:
    """Wrap the shared random tensor in a 1-tuple of source-side arguments."""
    packed: tuple[Tensor] = (random_tensor,)
    return packed
@pytest.fixture
def target_args(random_tensor: Tensor) -> tuple[Tensor]:
    """Wrap the shared random tensor in a 1-tuple of target-side arguments."""
    packed: tuple[Tensor] = (random_tensor,)
    return packed
def test_converter_stages(
    model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]
) -> None:
    """Step the converter through its stages by hand and check each transition.

    Each private stage runner must return a truthy value, and each call to
    ``_increment_stage`` must move ``stage`` to the next ``ConversionStage``.
    """
    converter = model_converter
    # Both argument-taking stage runners receive the same keyword arguments.
    forward_inputs = {"source_args": source_args, "target_args": target_args}

    # INIT -> BASIC_LAYERS_MATCH
    assert converter.stage == ConversionStage.INIT
    assert converter._run_init_stage()
    converter._increment_stage()

    # BASIC_LAYERS_MATCH -> SHAPE_AND_LAYERS_MATCH
    assert converter.stage == ConversionStage.BASIC_LAYERS_MATCH
    assert converter._run_basic_layers_match_stage(**forward_inputs)
    converter._increment_stage()

    # SHAPE_AND_LAYERS_MATCH -> MODELS_OUTPUT_AGREE
    assert converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH
    assert converter._run_shape_and_layers_match_stage(**forward_inputs)
    converter._increment_stage()

    assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
def test_run(
    model_converter: ModelConverter,
    source_args: tuple[Tensor],
    target_args: tuple[Tensor],
) -> None:
    """Check that ``run`` reports success and ends at ``MODELS_OUTPUT_AGREE``."""
    converter = model_converter
    assert converter.run(source_args=source_args, target_args=target_args)
    assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE