mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
134 lines
4.5 KiB
Python
134 lines
4.5 KiB
Python
|
# pyright: reportPrivateUsage=false
|
||
|
import pytest
|
||
|
import torch
|
||
|
from torch import nn, Tensor
|
||
|
from refiners.fluxion.utils import manual_seed
|
||
|
from refiners.fluxion.model_converter import ModelConverter, ConversionStage
|
||
|
import refiners.fluxion.layers as fl
|
||
|
|
||
|
|
||
|
class CustomBasicLayer1(fl.Module):
    """Bias-free linear projection used as the custom layer of the source model.

    Holds one learnable `weight` of shape `(out_features, in_features)` drawn
    from `torch.randn`.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        initial_weight = torch.randn(out_features, in_features)
        self.weight = nn.Parameter(data=initial_weight)

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` multiplied by the transpose of `weight`."""
        return torch.matmul(x, self.weight.t())
|
||
|
|
||
|
|
||
|
class CustomBasicLayer2(fl.Module):
    """Bias-free linear projection used as the custom layer of the target model.

    Holds one learnable `weight` of shape `(out_features, in_features)` drawn
    from `torch.randn`.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        initial_weight = torch.randn(out_features, in_features)
        self.weight = nn.Parameter(data=initial_weight)

    def forward(self, x: Tensor) -> Tensor:
        """Return `x` multiplied by the transpose of `weight`."""
        return torch.matmul(x, self.weight.t())
|
||
|
|
||
|
|
||
|
# Source Model
|
||
|
class SourceModel(fl.Module):
    """Model whose weights are converted from.

    Pipeline: linear (10 -> 2), ReLU, three `CustomBasicLayer1` (2 -> 2),
    flatten, dropout, reshape to `(1, 1, -1)`, 1-D convolution, max-pooling.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: submodules are created in this exact order on purpose; it fixes
        # both the registration order and the order random weights are drawn.
        self.linear1 = fl.Linear(in_features=10, out_features=2)
        self.activation = fl.ReLU()
        custom_stack = [CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)]
        self.custom_layers = nn.ModuleList(modules=custom_stack)
        self.flatten = fl.Flatten()
        self.dropout = nn.Dropout(p=0.5)
        self.conv = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x: Tensor) -> Tensor:
        """Run `x` through every submodule in sequence and return the pooled result."""
        hidden = self.activation(self.linear1(x))
        for custom_layer in self.custom_layers:
            hidden = custom_layer(hidden)
        hidden = self.dropout(self.flatten(hidden))
        # Collapse everything into a single one-channel sequence for the convolution.
        sequence = hidden.view(1, 1, -1)
        return self.pool(self.conv(sequence))
|
||
|
|
||
|
|
||
|
# Target Model (Purposely obfuscated but functionally equivalent)
|
||
|
class TargetModel(fl.Module):
    """Model whose weights are converted to.

    Computes the same pipeline as the source model, but its submodules use
    different attribute names and are declared in a different order.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: the declaration order below is deliberately shuffled relative to
        # the forward pass; keep it as is.
        self.relu = fl.ReLU()
        self.drop = nn.Dropout(p=0.5)
        custom_stack = [CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)]
        self.layers1 = nn.ModuleList(modules=custom_stack)
        self.flattenIt = fl.Flatten()
        self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.convolution = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.lin = fl.Linear(in_features=10, out_features=2)

    def forward(self, x: Tensor) -> Tensor:
        """Run `x` through every submodule in sequence and return the pooled result."""
        hidden = self.relu(self.lin(x))
        for custom_layer in self.layers1:
            hidden = custom_layer(hidden)
        hidden = self.drop(self.flattenIt(hidden))
        # Collapse everything into a single one-channel sequence for the convolution.
        sequence = hidden.view(1, 1, -1)
        return self.max_pool(self.convolution(sequence))
|
||
|
|
||
|
|
||
|
@pytest.fixture
def source_model() -> SourceModel:
    """Provide a `SourceModel` built right after seeding the RNG with 2."""
    manual_seed(seed=2)
    model = SourceModel()
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture
def target_model() -> TargetModel:
    """Provide a `TargetModel` built right after seeding the RNG with 2."""
    manual_seed(seed=2)
    model = TargetModel()
    return model
|
||
|
|
||
|
|
||
|
@pytest.fixture
def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter:
    """Provide a verbose `ModelConverter` for the source/target model pair.

    `CustomBasicLayer1` is mapped to `CustomBasicLayer2` through the
    converter's `custom_layer_mapping` argument.
    """
    layer_pairs: dict[type[nn.Module], type[nn.Module]] = {CustomBasicLayer1: CustomBasicLayer2}
    converter = ModelConverter(
        source_model=source_model,
        target_model=target_model,
        custom_layer_mapping=layer_pairs,
        verbose=True,
    )
    return converter
|
||
|
|
||
|
|
||
|
@pytest.fixture
def random_tensor() -> Tensor:
    """Provide a `(1, 10)` tensor drawn from `torch.randn`."""
    shape = (1, 10)
    return torch.randn(shape)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def source_args(random_tensor: Tensor) -> tuple[Tensor]:
    """Wrap `random_tensor` in a one-element tuple of source-model arguments."""
    args = (random_tensor,)
    return args
|
||
|
|
||
|
|
||
|
@pytest.fixture
def target_args(random_tensor: Tensor) -> tuple[Tensor]:
    """Wrap `random_tensor` in a one-element tuple of target-model arguments."""
    args = (random_tensor,)
    return args
|
||
|
|
||
|
|
||
|
def test_converter_stages(
    model_converter: ModelConverter,
    source_args: tuple[Tensor],
    target_args: tuple[Tensor],
) -> None:
    """Drive the converter one stage at a time and check each transition.

    Each stage runner must return a truthy value before the stage is
    incremented; the converter must end at `MODELS_OUTPUT_AGREE`.
    """
    # Stage 1: INIT
    assert model_converter.stage == ConversionStage.INIT
    init_ok = model_converter._run_init_stage()
    assert init_ok
    model_converter._increment_stage()

    # Stage 2: BASIC_LAYERS_MATCH
    assert model_converter.stage == ConversionStage.BASIC_LAYERS_MATCH
    basic_ok = model_converter._run_basic_layers_match_stage(
        source_args=source_args,
        target_args=target_args,
    )
    assert basic_ok
    model_converter._increment_stage()

    # Stage 3: SHAPE_AND_LAYERS_MATCH
    assert model_converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH
    shape_ok = model_converter._run_shape_and_layers_match_stage(
        source_args=source_args,
        target_args=target_args,
    )
    assert shape_ok
    model_converter._increment_stage()

    # Final stage reached after the last increment.
    assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
||
|
|
||
|
|
||
|
def test_run(
    model_converter: ModelConverter,
    source_args: tuple[Tensor],
    target_args: tuple[Tensor],
) -> None:
    """Check that a full `run` succeeds and ends at `MODELS_OUTPUT_AGREE`."""
    succeeded = model_converter.run(source_args=source_args, target_args=target_args)
    assert succeeded
    assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
|