mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix model comparison with custom layers
This commit is contained in:
parent
7651daa01f
commit
88efa117bf
|
@ -50,7 +50,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
|
|||
)
|
||||
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
|
||||
|
||||
broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320])))
|
||||
broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320])))
|
||||
|
||||
expected_source_order = [
|
||||
"down_blocks.0.attentions.0.proj_in",
|
||||
|
@ -89,7 +89,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
|
|||
assert target_order[broken_k] == expected_target_order
|
||||
source_order[broken_k] = fixed_source_order
|
||||
|
||||
broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640])))
|
||||
broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640])))
|
||||
|
||||
expected_source_order = [
|
||||
"down_blocks.1.attentions.0.proj_in",
|
||||
|
@ -125,7 +125,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
|
|||
assert target_order[broken_k] == expected_target_order
|
||||
source_order[broken_k] = fixed_source_order
|
||||
|
||||
broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
|
||||
broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
|
||||
|
||||
expected_source_order = [
|
||||
"down_blocks.2.attentions.0.proj_in",
|
||||
|
|
|
@ -28,7 +28,7 @@ TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
|
|||
]
|
||||
|
||||
|
||||
ModelTypeShape = tuple[str, tuple[torch.Size, ...]]
|
||||
ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]
|
||||
|
||||
|
||||
class ModuleArgsDict(TypedDict):
|
||||
|
@ -69,6 +69,7 @@ class ModelConverter:
|
|||
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
|
||||
threshold: float = 1e-5,
|
||||
skip_output_check: bool = False,
|
||||
skip_init_check: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -81,6 +82,8 @@ class ModelConverter:
|
|||
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
|
||||
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
|
||||
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
|
||||
layers.
|
||||
- `verbose`: Whether to print messages during the conversion process.
|
||||
|
||||
The conversion process consists of three stages:
|
||||
|
@ -107,8 +110,18 @@ class ModelConverter:
|
|||
self.custom_layer_mapping = custom_layer_mapping or {}
|
||||
self.threshold = threshold
|
||||
self.skip_output_check = skip_output_check
|
||||
self.skip_init_check = skip_init_check
|
||||
self.verbose = verbose
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
|
||||
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
|
||||
|
||||
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
|
||||
"""
|
||||
Run the conversion process.
|
||||
|
@ -137,38 +150,72 @@ class ModelConverter:
|
|||
|
||||
match self.stage:
|
||||
case ConversionStage.MODELS_OUTPUT_AGREE:
|
||||
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
|
||||
self._increment_stage()
|
||||
return True
|
||||
|
||||
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage():
|
||||
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
|
||||
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
|
||||
source_args=source_args, target_args=target_args
|
||||
):
|
||||
self._increment_stage()
|
||||
return True
|
||||
|
||||
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
|
||||
source_args=source_args, target_args=target_args
|
||||
):
|
||||
self.stage = (
|
||||
ConversionStage.SHAPE_AND_LAYERS_MATCH
|
||||
if not self.skip_output_check
|
||||
else ConversionStage.MODELS_OUTPUT_AGREE
|
||||
)
|
||||
self._increment_stage()
|
||||
return self.run(source_args=source_args, target_args=target_args)
|
||||
|
||||
case ConversionStage.INIT if self._run_init_stage():
|
||||
self.stage = ConversionStage.BASIC_LAYERS_MATCH
|
||||
self._increment_stage()
|
||||
return self.run(source_args=source_args, target_args=target_args)
|
||||
|
||||
case _:
|
||||
self._log(message=f"Conversion failed at stage {self.stage.value}")
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
|
||||
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
|
||||
def _increment_stage(self) -> None:
|
||||
"""Increment the stage of the conversion process."""
|
||||
match self.stage:
|
||||
case ConversionStage.INIT:
|
||||
self.stage = ConversionStage.BASIC_LAYERS_MATCH
|
||||
self._log(
|
||||
message=(
|
||||
"Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
|
||||
" layers..."
|
||||
)
|
||||
)
|
||||
case ConversionStage.BASIC_LAYERS_MATCH:
|
||||
self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
|
||||
self._log(
|
||||
message=(
|
||||
"Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
|
||||
" models..."
|
||||
)
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
||||
case ConversionStage.SHAPE_AND_LAYERS_MATCH:
|
||||
if self.skip_output_check:
|
||||
self._log(
|
||||
message=(
|
||||
"Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
|
||||
" `skip_output_check` to `False`"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
|
||||
self._log(
|
||||
message=(
|
||||
"Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
|
||||
" converted model using `save_to_safetensors`"
|
||||
)
|
||||
)
|
||||
case ConversionStage.MODELS_OUTPUT_AGREE:
|
||||
self._log(
|
||||
message=(
|
||||
"Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
|
||||
" the converted model using `save_to_safetensors`"
|
||||
)
|
||||
)
|
||||
|
||||
def get_state_dict(self) -> dict[str, Tensor]:
|
||||
"""Get the converted state_dict."""
|
||||
|
@ -233,9 +280,16 @@ class ModelConverter:
|
|||
return None
|
||||
|
||||
mapping: dict[str, str] = {}
|
||||
for model_type_shape in source_order:
|
||||
source_keys = source_order[model_type_shape]
|
||||
target_keys = target_order[model_type_shape]
|
||||
for source_type_shape in source_order:
|
||||
source_keys = source_order[source_type_shape]
|
||||
target_type_shape = source_type_shape
|
||||
if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
|
||||
for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
|
||||
if source_custom_type == source_type_shape[0]:
|
||||
target_type_shape = (target_custom_type, source_type_shape[1])
|
||||
break
|
||||
|
||||
target_keys = target_order[target_type_shape]
|
||||
mapping.update(zip(target_keys, source_keys))
|
||||
|
||||
return mapping
|
||||
|
@ -266,9 +320,9 @@ class ModelConverter:
|
|||
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
|
||||
)
|
||||
|
||||
prev_source_key, prev_target_key = None, None
|
||||
diff, prev_source_key, prev_target_key = None, None, None
|
||||
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
|
||||
diff = norm(source_output - target_output).item()
|
||||
diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
|
||||
if diff > threshold:
|
||||
self._log(
|
||||
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
|
||||
|
@ -277,10 +331,21 @@ class ModelConverter:
|
|||
return False
|
||||
prev_source_key, prev_target_key = source_key, target_key
|
||||
|
||||
self._log(message=f"Models agree. Difference in norm: {diff}")
|
||||
|
||||
return True
|
||||
|
||||
def _run_init_stage(self) -> bool:
|
||||
"""Run the init stage of the conversion process."""
|
||||
if self.skip_init_check:
|
||||
self._log(
|
||||
message=(
|
||||
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
|
||||
" `False`"
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
is_count_correct = self._verify_basic_layers_count()
|
||||
is_not_missing_layers = self._verify_missing_basic_layers()
|
||||
|
||||
|
@ -288,16 +353,12 @@ class ModelConverter:
|
|||
|
||||
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||
"""Run the basic layers match stage of the conversion process."""
|
||||
self._log(message="Finding matching shapes and layers...")
|
||||
|
||||
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
|
||||
self._stored_mapping = mapping
|
||||
if mapping is None:
|
||||
self._log(message="Models do not have matching shapes.")
|
||||
return False
|
||||
|
||||
self._log(message="Found matching shapes and layers. Converting state_dict...")
|
||||
|
||||
source_state_dict = self.source_model.state_dict()
|
||||
target_state_dict = self.target_model.state_dict()
|
||||
converted_state_dict = self._convert_state_dict(
|
||||
|
@ -309,17 +370,22 @@ class ModelConverter:
|
|||
|
||||
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||
"""Run the shape and layers match stage of the conversion process."""
|
||||
if self.skip_output_check:
|
||||
self._log(
|
||||
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
|
||||
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
|
||||
return True
|
||||
else:
|
||||
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
|
||||
return False
|
||||
|
||||
def _run_models_output_agree_stage(self) -> bool:
|
||||
"""Run the models output agree stage of the conversion process."""
|
||||
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log(message=f"An error occurred while comparing the models: {e}")
|
||||
return False
|
||||
|
||||
def _log(self, message: str) -> None:
|
||||
"""Print a message if `verbose` is `True`."""
|
||||
|
@ -354,16 +420,19 @@ class ModelConverter:
|
|||
|
||||
return positional_args, keyword_args
|
||||
|
||||
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
|
||||
"""Check if a module type is a subclass of a torch basic layer."""
|
||||
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
|
||||
|
||||
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
|
||||
"""Infer the type of a basic layer."""
|
||||
for layer_type in TORCH_BASIC_LAYERS:
|
||||
layer_types = (
|
||||
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
|
||||
)
|
||||
for layer_type in layer_types:
|
||||
if isinstance(module, layer_type):
|
||||
return layer_type
|
||||
|
||||
for source_type in self.custom_layer_mapping.keys():
|
||||
if isinstance(module, source_type):
|
||||
return source_type
|
||||
|
||||
return None
|
||||
|
||||
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
|
||||
|
@ -371,7 +440,7 @@ class ModelConverter:
|
|||
layer_type = self._infer_basic_layer_type(module=module)
|
||||
assert layer_type is not None, f"Module {module} is not a basic layer"
|
||||
param_shapes = [p.shape for p in module.parameters()]
|
||||
return (str(object=layer_type), tuple(param_shapes))
|
||||
return (layer_type, tuple(param_shapes))
|
||||
|
||||
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
|
||||
"""Count the number of basic layers in a module."""
|
||||
|
@ -388,10 +457,20 @@ class ModelConverter:
|
|||
source_layers = self._count_basic_layers(module=self.source_model)
|
||||
target_layers = self._count_basic_layers(module=self.target_model)
|
||||
|
||||
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
|
||||
|
||||
diff: dict[type[nn.Module], tuple[int, int]] = {}
|
||||
for layer_type in set(source_layers.keys()) | set(target_layers.keys()):
|
||||
source_count = source_layers.get(layer_type, 0)
|
||||
target_count = target_layers.get(layer_type, 0)
|
||||
for layer_type, source_count in source_layers.items():
|
||||
target_type = self.custom_layer_mapping.get(layer_type, layer_type)
|
||||
target_count = target_layers.get(target_type, 0)
|
||||
|
||||
if source_count != target_count:
|
||||
diff[layer_type] = (source_count, target_count)
|
||||
|
||||
for layer_type, target_count in target_layers.items():
|
||||
source_type = reverse_mapping.get(layer_type, layer_type)
|
||||
source_count = source_layers.get(source_type, 0)
|
||||
|
||||
if source_count != target_count:
|
||||
diff[layer_type] = (source_count, target_count)
|
||||
|
||||
|
@ -399,7 +478,7 @@ class ModelConverter:
|
|||
message = "Models do not have the same number of basic layers:\n"
|
||||
for layer_type, counts in diff.items():
|
||||
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
|
||||
self._log(message=message.rstrip())
|
||||
self._log(message=message.strip())
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@ -424,8 +503,8 @@ class ModelConverter:
|
|||
if missing_source_layers or missing_target_layers:
|
||||
self._log(
|
||||
message=(
|
||||
"Models might have missing basic layers. You can either pass them into keys to skip or set"
|
||||
f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}"
|
||||
"Models might have missing basic layers. If you want to skip this check, set"
|
||||
f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
@ -478,17 +557,36 @@ class ModelConverter:
|
|||
) -> bool:
|
||||
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
|
||||
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||
shape_missmatched = False
|
||||
|
||||
for model_type_shape in model_type_shapes:
|
||||
default_type_shapes = [
|
||||
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
|
||||
]
|
||||
|
||||
shape_mismatched = False
|
||||
|
||||
for model_type_shape in default_type_shapes:
|
||||
source_keys = source_order.get(model_type_shape, [])
|
||||
target_keys = target_order.get(model_type_shape, [])
|
||||
|
||||
if len(source_keys) != len(target_keys):
|
||||
shape_missmatched = True
|
||||
shape_mismatched = True
|
||||
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||
|
||||
return not shape_missmatched
|
||||
for source_custom_type in self.custom_layer_mapping.keys():
|
||||
# iterate over all type_shapes that have the same type as source_custom_type
|
||||
for source_type_shape in [
|
||||
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
|
||||
]:
|
||||
source_keys = source_order.get(source_type_shape, [])
|
||||
target_custom_type = self.custom_layer_mapping[source_custom_type]
|
||||
target_type_shape = (target_custom_type, source_type_shape[1])
|
||||
target_keys = target_order.get(target_type_shape, [])
|
||||
|
||||
if len(source_keys) != len(target_keys):
|
||||
shape_mismatched = True
|
||||
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||
|
||||
return not shape_mismatched
|
||||
|
||||
@staticmethod
|
||||
def _convert_state_dict(
|
||||
|
|
133
tests/fluxion/test_model_converter.py
Normal file
133
tests/fluxion/test_model_converter.py
Normal file
|
@ -0,0 +1,133 @@
|
|||
# pyright: reportPrivateUsage=false
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
from refiners.fluxion.model_converter import ModelConverter, ConversionStage
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class CustomBasicLayer1(fl.Module):
|
||||
def __init__(self, in_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x @ self.weight.t()
|
||||
|
||||
|
||||
class CustomBasicLayer2(fl.Module):
|
||||
def __init__(self, in_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x @ self.weight.t()
|
||||
|
||||
|
||||
# Source Model
|
||||
class SourceModel(fl.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear1 = fl.Linear(in_features=10, out_features=2)
|
||||
self.activation = fl.ReLU()
|
||||
self.custom_layers = nn.ModuleList(modules=[CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)])
|
||||
self.flatten = fl.Flatten()
|
||||
self.dropout = nn.Dropout(p=0.5)
|
||||
self.conv = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
|
||||
self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.activation(x)
|
||||
for layer in self.custom_layers:
|
||||
x = layer(x)
|
||||
x = self.flatten(x)
|
||||
x = self.dropout(x)
|
||||
x = x.view(1, 1, -1)
|
||||
x = self.conv(x)
|
||||
x = self.pool(x)
|
||||
return x
|
||||
|
||||
|
||||
# Target Model (Purposely obfuscated but functionally equivalent)
|
||||
class TargetModel(fl.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.relu = fl.ReLU()
|
||||
self.drop = nn.Dropout(0.5)
|
||||
self.layers1 = nn.ModuleList(modules=[CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)])
|
||||
self.flattenIt = fl.Flatten()
|
||||
self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
|
||||
self.convolution = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
|
||||
self.lin = fl.Linear(in_features=10, out_features=2)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.lin(x)
|
||||
x = self.relu(x)
|
||||
for layer in self.layers1:
|
||||
x = layer(x)
|
||||
x = self.flattenIt(x)
|
||||
x = self.drop(x)
|
||||
x = x.view(1, 1, -1)
|
||||
x = self.convolution(x)
|
||||
x = self.max_pool(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def source_model() -> SourceModel:
|
||||
manual_seed(seed=2)
|
||||
return SourceModel()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_model() -> TargetModel:
|
||||
manual_seed(seed=2)
|
||||
return TargetModel()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter:
|
||||
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] = {CustomBasicLayer1: CustomBasicLayer2}
|
||||
return ModelConverter(
|
||||
source_model=source_model, target_model=target_model, custom_layer_mapping=custom_layer_mapping, verbose=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_tensor() -> Tensor:
|
||||
return torch.randn(1, 10)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def source_args(random_tensor: Tensor) -> tuple[Tensor]:
|
||||
return (random_tensor,)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_args(random_tensor: Tensor) -> tuple[Tensor]:
|
||||
return (random_tensor,)
|
||||
|
||||
|
||||
def test_converter_stages(
|
||||
model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]
|
||||
) -> None:
|
||||
assert model_converter.stage == ConversionStage.INIT
|
||||
assert model_converter._run_init_stage()
|
||||
model_converter._increment_stage()
|
||||
|
||||
assert model_converter.stage == ConversionStage.BASIC_LAYERS_MATCH
|
||||
assert model_converter._run_basic_layers_match_stage(source_args=source_args, target_args=target_args)
|
||||
model_converter._increment_stage()
|
||||
|
||||
assert model_converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH
|
||||
assert model_converter._run_shape_and_layers_match_stage(source_args=source_args, target_args=target_args)
|
||||
model_converter._increment_stage()
|
||||
|
||||
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
||||
|
||||
|
||||
def test_run(model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]) -> None:
|
||||
assert model_converter.run(source_args=source_args, target_args=target_args)
|
||||
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from refiners.fluxion.utils import manual_seed
|
||||
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -31,18 +31,8 @@ def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sdxl_unet_weights_std(test_weights_path: Path) -> Path:
|
||||
unet_weights_std = test_weights_path / "sdxl-unet.safetensors"
|
||||
if not unet_weights_std.is_file():
|
||||
warn(message=f"could not find weights at {unet_weights_std}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return unet_weights_std
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
|
||||
def refiners_sdxl_unet() -> SDXLUNet:
|
||||
unet = SDXLUNet(in_channels=4)
|
||||
unet.load_from_safetensors(tensors_path=sdxl_unet_weights_std)
|
||||
return unet
|
||||
|
||||
|
||||
|
@ -57,19 +47,27 @@ def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> No
|
|||
clip_text_embeddings = torch.randn(1, 77, 2048)
|
||||
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
||||
|
||||
target.set_timestep(timestep=timestep)
|
||||
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||
target_args = (x,)
|
||||
source_args = {
|
||||
"positional": (x, timestep, clip_text_embeddings),
|
||||
"keyword": {"added_cond_kwargs": added_cond_kwargs},
|
||||
}
|
||||
|
||||
converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2)
|
||||
old_forward = target.forward
|
||||
|
||||
def forward_with_context(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
target.set_timestep(timestep=timestep)
|
||||
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||
return old_forward(self, *args, **kwargs)
|
||||
|
||||
target.forward = forward_with_context
|
||||
|
||||
converter = ModelConverter(source_model=source, target_model=target, verbose=True, threshold=1e-2)
|
||||
|
||||
assert converter.run(
|
||||
source_args=source_args,
|
||||
target_args=target_args,
|
||||
)
|
||||
assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
|
||||
|
|
Loading…
Reference in a new issue