fix model comparison with custom layers

This commit is contained in:
limiteinductive 2023-08-29 14:38:58 +02:00 committed by Benjamin Trom
parent 7651daa01f
commit 88efa117bf
4 changed files with 300 additions and 71 deletions

View file

@ -50,7 +50,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
)
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320])))
broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320])))
expected_source_order = [
"down_blocks.0.attentions.0.proj_in",
@ -89,7 +89,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640])))
broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640])))
expected_source_order = [
"down_blocks.1.attentions.0.proj_in",
@ -125,7 +125,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order
broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
expected_source_order = [
"down_blocks.2.attentions.0.proj_in",

View file

@ -28,7 +28,7 @@ TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
]
ModelTypeShape = tuple[str, tuple[torch.Size, ...]]
ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]
class ModuleArgsDict(TypedDict):
@ -69,6 +69,7 @@ class ModelConverter:
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
threshold: float = 1e-5,
skip_output_check: bool = False,
skip_init_check: bool = False,
verbose: bool = True,
) -> None:
"""
@ -81,6 +82,8 @@ class ModelConverter:
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
- `threshold`: The threshold for comparing outputs between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
layers.
- `verbose`: Whether to print messages during the conversion process.
The conversion process consists of three stages:
@ -107,8 +110,18 @@ class ModelConverter:
self.custom_layer_mapping = custom_layer_mapping or {}
self.threshold = threshold
self.skip_output_check = skip_output_check
self.skip_init_check = skip_init_check
self.verbose = verbose
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def __bool__(self) -> bool:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
"""
Run the conversion process.
@ -137,38 +150,72 @@ class ModelConverter:
match self.stage:
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
self._increment_stage()
return True
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage():
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
source_args=source_args, target_args=target_args
):
self._increment_stage()
return True
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
source_args=source_args, target_args=target_args
):
self.stage = (
ConversionStage.SHAPE_AND_LAYERS_MATCH
if not self.skip_output_check
else ConversionStage.MODELS_OUTPUT_AGREE
)
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case ConversionStage.INIT if self._run_init_stage():
self.stage = ConversionStage.BASIC_LAYERS_MATCH
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case _:
self._log(message=f"Conversion failed at stage {self.stage.value}")
return False
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def _increment_stage(self) -> None:
"""Increment the stage of the conversion process."""
match self.stage:
case ConversionStage.INIT:
self.stage = ConversionStage.BASIC_LAYERS_MATCH
self._log(
message=(
"Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
" layers..."
)
)
case ConversionStage.BASIC_LAYERS_MATCH:
self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
self._log(
message=(
"Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
" models..."
)
)
def __bool__(self) -> bool:
return self.stage == ConversionStage.MODELS_OUTPUT_AGREE
case ConversionStage.SHAPE_AND_LAYERS_MATCH:
if self.skip_output_check:
self._log(
message=(
"Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
" `skip_output_check` to `False`"
)
)
else:
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
self._log(
message=(
"Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
" converted model using `save_to_safetensors`"
)
)
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(
message=(
"Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
" the converted model using `save_to_safetensors`"
)
)
def get_state_dict(self) -> dict[str, Tensor]:
"""Get the converted state_dict."""
@ -233,9 +280,16 @@ class ModelConverter:
return None
mapping: dict[str, str] = {}
for model_type_shape in source_order:
source_keys = source_order[model_type_shape]
target_keys = target_order[model_type_shape]
for source_type_shape in source_order:
source_keys = source_order[source_type_shape]
target_type_shape = source_type_shape
if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
if source_custom_type == source_type_shape[0]:
target_type_shape = (target_custom_type, source_type_shape[1])
break
target_keys = target_order[target_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
@ -266,9 +320,9 @@ class ModelConverter:
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
prev_source_key, prev_target_key = None, None
diff, prev_source_key, prev_target_key = None, None, None
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
diff = norm(source_output - target_output).item()
diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
if diff > threshold:
self._log(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
@ -277,10 +331,21 @@ class ModelConverter:
return False
prev_source_key, prev_target_key = source_key, target_key
self._log(message=f"Models agree. Difference in norm: {diff}")
return True
def _run_init_stage(self) -> bool:
"""Run the init stage of the conversion process."""
if self.skip_init_check:
self._log(
message=(
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
" `False`"
)
)
return True
is_count_correct = self._verify_basic_layers_count()
is_not_missing_layers = self._verify_missing_basic_layers()
@ -288,16 +353,12 @@ class ModelConverter:
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the basic layers match stage of the conversion process."""
self._log(message="Finding matching shapes and layers...")
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
self._stored_mapping = mapping
if mapping is None:
self._log(message="Models do not have matching shapes.")
return False
self._log(message="Found matching shapes and layers. Converting state_dict...")
source_state_dict = self.source_model.state_dict()
target_state_dict = self.target_model.state_dict()
converted_state_dict = self._convert_state_dict(
@ -309,17 +370,22 @@ class ModelConverter:
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the shape and layers match stage of the conversion process."""
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
if self.skip_output_check:
self._log(
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
)
return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
def _run_models_output_agree_stage(self) -> bool:
"""Run the models output agree stage of the conversion process."""
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`")
return True
try:
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
except Exception as e:
self._log(message=f"An error occurred while comparing the models: {e}")
return False
def _log(self, message: str) -> None:
"""Print a message if `verbose` is `True`."""
@ -354,16 +420,19 @@ class ModelConverter:
return positional_args, keyword_args
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
"""Check if a module type is a subclass of a torch basic layer."""
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
"""Infer the type of a basic layer."""
for layer_type in TORCH_BASIC_LAYERS:
layer_types = (
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
)
for layer_type in layer_types:
if isinstance(module, layer_type):
return layer_type
for source_type in self.custom_layer_mapping.keys():
if isinstance(module, source_type):
return source_type
return None
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
@ -371,7 +440,7 @@ class ModelConverter:
layer_type = self._infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()]
return (str(object=layer_type), tuple(param_shapes))
return (layer_type, tuple(param_shapes))
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
"""Count the number of basic layers in a module."""
@ -388,10 +457,20 @@ class ModelConverter:
source_layers = self._count_basic_layers(module=self.source_model)
target_layers = self._count_basic_layers(module=self.target_model)
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
diff: dict[type[nn.Module], tuple[int, int]] = {}
for layer_type in set(source_layers.keys()) | set(target_layers.keys()):
source_count = source_layers.get(layer_type, 0)
target_count = target_layers.get(layer_type, 0)
for layer_type, source_count in source_layers.items():
target_type = self.custom_layer_mapping.get(layer_type, layer_type)
target_count = target_layers.get(target_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
for layer_type, target_count in target_layers.items():
source_type = reverse_mapping.get(layer_type, layer_type)
source_count = source_layers.get(source_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
@ -399,7 +478,7 @@ class ModelConverter:
message = "Models do not have the same number of basic layers:\n"
for layer_type, counts in diff.items():
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
self._log(message=message.rstrip())
self._log(message=message.strip())
return False
return True
@ -424,8 +503,8 @@ class ModelConverter:
if missing_source_layers or missing_target_layers:
self._log(
message=(
"Models might have missing basic layers. You can either pass them into keys to skip or set"
f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}"
"Models might have missing basic layers. If you want to skip this check, set"
f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
)
)
return False
@ -478,17 +557,36 @@ class ModelConverter:
) -> bool:
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
shape_missmatched = False
for model_type_shape in model_type_shapes:
default_type_shapes = [
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
]
shape_mismatched = False
for model_type_shape in default_type_shapes:
source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys):
shape_missmatched = True
shape_mismatched = True
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_missmatched
for source_custom_type in self.custom_layer_mapping.keys():
# iterate over all type_shapes that have the same type as source_custom_type
for source_type_shape in [
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
]:
source_keys = source_order.get(source_type_shape, [])
target_custom_type = self.custom_layer_mapping[source_custom_type]
target_type_shape = (target_custom_type, source_type_shape[1])
target_keys = target_order.get(target_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_mismatched
@staticmethod
def _convert_state_dict(

View file

@ -0,0 +1,133 @@
# pyright: reportPrivateUsage=false
import pytest
import torch
from torch import nn, Tensor
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.model_converter import ModelConverter, ConversionStage
import refiners.fluxion.layers as fl
class CustomBasicLayer1(fl.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
def forward(self, x: Tensor) -> Tensor:
return x @ self.weight.t()
class CustomBasicLayer2(fl.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
def forward(self, x: Tensor) -> Tensor:
return x @ self.weight.t()
# Source Model
class SourceModel(fl.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = fl.Linear(in_features=10, out_features=2)
self.activation = fl.ReLU()
self.custom_layers = nn.ModuleList(modules=[CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)])
self.flatten = fl.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.conv = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
def forward(self, x: Tensor) -> Tensor:
x = self.linear1(x)
x = self.activation(x)
for layer in self.custom_layers:
x = layer(x)
x = self.flatten(x)
x = self.dropout(x)
x = x.view(1, 1, -1)
x = self.conv(x)
x = self.pool(x)
return x
# Target Model (Purposely obfuscated but functionally equivalent)
class TargetModel(fl.Module):
def __init__(self) -> None:
super().__init__()
self.relu = fl.ReLU()
self.drop = nn.Dropout(0.5)
self.layers1 = nn.ModuleList(modules=[CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)])
self.flattenIt = fl.Flatten()
self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
self.convolution = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
self.lin = fl.Linear(in_features=10, out_features=2)
def forward(self, x: Tensor) -> Tensor:
x = self.lin(x)
x = self.relu(x)
for layer in self.layers1:
x = layer(x)
x = self.flattenIt(x)
x = self.drop(x)
x = x.view(1, 1, -1)
x = self.convolution(x)
x = self.max_pool(x)
return x
@pytest.fixture
def source_model() -> SourceModel:
manual_seed(seed=2)
return SourceModel()
@pytest.fixture
def target_model() -> TargetModel:
manual_seed(seed=2)
return TargetModel()
@pytest.fixture
def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter:
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] = {CustomBasicLayer1: CustomBasicLayer2}
return ModelConverter(
source_model=source_model, target_model=target_model, custom_layer_mapping=custom_layer_mapping, verbose=True
)
@pytest.fixture
def random_tensor() -> Tensor:
return torch.randn(1, 10)
@pytest.fixture
def source_args(random_tensor: Tensor) -> tuple[Tensor]:
return (random_tensor,)
@pytest.fixture
def target_args(random_tensor: Tensor) -> tuple[Tensor]:
return (random_tensor,)
def test_converter_stages(
model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]
) -> None:
assert model_converter.stage == ConversionStage.INIT
assert model_converter._run_init_stage()
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.BASIC_LAYERS_MATCH
assert model_converter._run_basic_layers_match_stage(source_args=source_args, target_args=target_args)
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH
assert model_converter._run_shape_and_layers_match_stage(source_args=source_args, target_args=target_args)
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
def test_run(model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]) -> None:
assert model_converter.run(source_args=source_args, target_args=target_args)
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE

View file

@ -6,7 +6,7 @@ import torch
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
@pytest.fixture(scope="module")
@ -31,18 +31,8 @@ def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
@pytest.fixture(scope="module")
def sdxl_unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "sdxl-unet.safetensors"
if not unet_weights_std.is_file():
warn(message=f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
def refiners_sdxl_unet() -> SDXLUNet:
unet = SDXLUNet(in_channels=4)
unet.load_from_safetensors(tensors_path=sdxl_unet_weights_std)
return unet
@ -57,19 +47,27 @@ def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> No
clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
source_args = {
"positional": (x, timestep, clip_text_embeddings),
"keyword": {"added_cond_kwargs": added_cond_kwargs},
}
converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2)
old_forward = target.forward
def forward_with_context(self: Any, *args: Any, **kwargs: Any) -> Any:
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
return old_forward(self, *args, **kwargs)
target.forward = forward_with_context
converter = ModelConverter(source_model=source, target_model=target, verbose=True, threshold=1e-2)
assert converter.run(
source_args=source_args,
target_args=target_args,
)
assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE