fix model comparison with custom layers

This commit is contained in:
limiteinductive 2023-08-29 14:38:58 +02:00 committed by Benjamin Trom
parent 7651daa01f
commit 88efa117bf
4 changed files with 300 additions and 71 deletions

View file

@ -50,7 +50,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
) )
target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[]) target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[])
broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320]))) broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320])))
expected_source_order = [ expected_source_order = [
"down_blocks.0.attentions.0.proj_in", "down_blocks.0.attentions.0.proj_in",
@ -89,7 +89,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order source_order[broken_k] = fixed_source_order
broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640]))) broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640])))
expected_source_order = [ expected_source_order = [
"down_blocks.1.attentions.0.proj_in", "down_blocks.1.attentions.0.proj_in",
@ -125,7 +125,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
assert target_order[broken_k] == expected_target_order assert target_order[broken_k] == expected_target_order
source_order[broken_k] = fixed_source_order source_order[broken_k] = fixed_source_order
broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280])))
expected_source_order = [ expected_source_order = [
"down_blocks.2.attentions.0.proj_in", "down_blocks.2.attentions.0.proj_in",

View file

@ -28,7 +28,7 @@ TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
] ]
ModelTypeShape = tuple[str, tuple[torch.Size, ...]] ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]
class ModuleArgsDict(TypedDict): class ModuleArgsDict(TypedDict):
@ -69,6 +69,7 @@ class ModelConverter:
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None, custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
threshold: float = 1e-5, threshold: float = 1e-5,
skip_output_check: bool = False, skip_output_check: bool = False,
skip_init_check: bool = False,
verbose: bool = True, verbose: bool = True,
) -> None: ) -> None:
""" """
@ -81,6 +82,8 @@ class ModelConverter:
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models. - `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
- `threshold`: The threshold for comparing outputs between the source and target models. - `threshold`: The threshold for comparing outputs between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models. - `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
layers.
- `verbose`: Whether to print messages during the conversion process. - `verbose`: Whether to print messages during the conversion process.
The conversion process consists of three stages: The conversion process consists of three stages:
@ -107,8 +110,18 @@ class ModelConverter:
self.custom_layer_mapping = custom_layer_mapping or {} self.custom_layer_mapping = custom_layer_mapping or {}
self.threshold = threshold self.threshold = threshold
self.skip_output_check = skip_output_check self.skip_output_check = skip_output_check
self.skip_init_check = skip_init_check
self.verbose = verbose self.verbose = verbose
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def __bool__(self) -> bool:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool: def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
""" """
Run the conversion process. Run the conversion process.
@ -137,38 +150,72 @@ class ModelConverter:
match self.stage: match self.stage:
case ConversionStage.MODELS_OUTPUT_AGREE: case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`") self._increment_stage()
return True return True
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage(): case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
self.stage = ConversionStage.MODELS_OUTPUT_AGREE source_args=source_args, target_args=target_args
):
self._increment_stage()
return True return True
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage( case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
source_args=source_args, target_args=target_args source_args=source_args, target_args=target_args
): ):
self.stage = ( self._increment_stage()
ConversionStage.SHAPE_AND_LAYERS_MATCH
if not self.skip_output_check
else ConversionStage.MODELS_OUTPUT_AGREE
)
return self.run(source_args=source_args, target_args=target_args) return self.run(source_args=source_args, target_args=target_args)
case ConversionStage.INIT if self._run_init_stage(): case ConversionStage.INIT if self._run_init_stage():
self.stage = ConversionStage.BASIC_LAYERS_MATCH self._increment_stage()
return self.run(source_args=source_args, target_args=target_args) return self.run(source_args=source_args, target_args=target_args)
case _: case _:
self._log(message=f"Conversion failed at stage {self.stage.value}")
return False return False
def __repr__(self) -> str: def _increment_stage(self) -> None:
return ( """Increment the stage of the conversion process."""
f"ModelConverter(source_model={self.source_model.__class__.__name__}," match self.stage:
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})" case ConversionStage.INIT:
) self.stage = ConversionStage.BASIC_LAYERS_MATCH
self._log(
message=(
"Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
" layers..."
)
)
case ConversionStage.BASIC_LAYERS_MATCH:
self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
self._log(
message=(
"Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
" models..."
)
)
def __bool__(self) -> bool: case ConversionStage.SHAPE_AND_LAYERS_MATCH:
return self.stage == ConversionStage.MODELS_OUTPUT_AGREE if self.skip_output_check:
self._log(
message=(
"Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
" `skip_output_check` to `False`"
)
)
else:
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
self._log(
message=(
"Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
" converted model using `save_to_safetensors`"
)
)
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(
message=(
"Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
" the converted model using `save_to_safetensors`"
)
)
def get_state_dict(self) -> dict[str, Tensor]: def get_state_dict(self) -> dict[str, Tensor]:
"""Get the converted state_dict.""" """Get the converted state_dict."""
@ -233,9 +280,16 @@ class ModelConverter:
return None return None
mapping: dict[str, str] = {} mapping: dict[str, str] = {}
for model_type_shape in source_order: for source_type_shape in source_order:
source_keys = source_order[model_type_shape] source_keys = source_order[source_type_shape]
target_keys = target_order[model_type_shape] target_type_shape = source_type_shape
if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
if source_custom_type == source_type_shape[0]:
target_type_shape = (target_custom_type, source_type_shape[1])
break
target_keys = target_order[target_type_shape]
mapping.update(zip(target_keys, source_keys)) mapping.update(zip(target_keys, source_keys))
return mapping return mapping
@ -266,9 +320,9 @@ class ModelConverter:
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
) )
prev_source_key, prev_target_key = None, None diff, prev_source_key, prev_target_key = None, None, None
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs): for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
diff = norm(source_output - target_output).item() diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
if diff > threshold: if diff > threshold:
self._log( self._log(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and" f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
@ -277,10 +331,21 @@ class ModelConverter:
return False return False
prev_source_key, prev_target_key = source_key, target_key prev_source_key, prev_target_key = source_key, target_key
self._log(message=f"Models agree. Difference in norm: {diff}")
return True return True
def _run_init_stage(self) -> bool: def _run_init_stage(self) -> bool:
"""Run the init stage of the conversion process.""" """Run the init stage of the conversion process."""
if self.skip_init_check:
self._log(
message=(
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
" `False`"
)
)
return True
is_count_correct = self._verify_basic_layers_count() is_count_correct = self._verify_basic_layers_count()
is_not_missing_layers = self._verify_missing_basic_layers() is_not_missing_layers = self._verify_missing_basic_layers()
@ -288,16 +353,12 @@ class ModelConverter:
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the basic layers match stage of the conversion process.""" """Run the basic layers match stage of the conversion process."""
self._log(message="Finding matching shapes and layers...")
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args) mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
self._stored_mapping = mapping self._stored_mapping = mapping
if mapping is None: if mapping is None:
self._log(message="Models do not have matching shapes.") self._log(message="Models do not have matching shapes.")
return False return False
self._log(message="Found matching shapes and layers. Converting state_dict...")
source_state_dict = self.source_model.state_dict() source_state_dict = self.source_model.state_dict()
target_state_dict = self.target_model.state_dict() target_state_dict = self.target_model.state_dict()
converted_state_dict = self._convert_state_dict( converted_state_dict = self._convert_state_dict(
@ -309,17 +370,22 @@ class ModelConverter:
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the shape and layers match stage of the conversion process.""" """Run the shape and layers match stage of the conversion process."""
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold): if self.skip_output_check:
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`") self._log(
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
)
return True return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
def _run_models_output_agree_stage(self) -> bool: try:
"""Run the models output agree stage of the conversion process.""" if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`") self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
return True return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
except Exception as e:
self._log(message=f"An error occurred while comparing the models: {e}")
return False
def _log(self, message: str) -> None: def _log(self, message: str) -> None:
"""Print a message if `verbose` is `True`.""" """Print a message if `verbose` is `True`."""
@ -354,16 +420,19 @@ class ModelConverter:
return positional_args, keyword_args return positional_args, keyword_args
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
"""Check if a module type is a subclass of a torch basic layer."""
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None: def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
"""Infer the type of a basic layer.""" """Infer the type of a basic layer."""
for layer_type in TORCH_BASIC_LAYERS: layer_types = (
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
)
for layer_type in layer_types:
if isinstance(module, layer_type): if isinstance(module, layer_type):
return layer_type return layer_type
for source_type in self.custom_layer_mapping.keys():
if isinstance(module, source_type):
return source_type
return None return None
def get_module_signature(self, module: nn.Module) -> ModelTypeShape: def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
@ -371,7 +440,7 @@ class ModelConverter:
layer_type = self._infer_basic_layer_type(module=module) layer_type = self._infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer" assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()] param_shapes = [p.shape for p in module.parameters()]
return (str(object=layer_type), tuple(param_shapes)) return (layer_type, tuple(param_shapes))
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]: def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
"""Count the number of basic layers in a module.""" """Count the number of basic layers in a module."""
@ -388,10 +457,20 @@ class ModelConverter:
source_layers = self._count_basic_layers(module=self.source_model) source_layers = self._count_basic_layers(module=self.source_model)
target_layers = self._count_basic_layers(module=self.target_model) target_layers = self._count_basic_layers(module=self.target_model)
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
diff: dict[type[nn.Module], tuple[int, int]] = {} diff: dict[type[nn.Module], tuple[int, int]] = {}
for layer_type in set(source_layers.keys()) | set(target_layers.keys()): for layer_type, source_count in source_layers.items():
source_count = source_layers.get(layer_type, 0) target_type = self.custom_layer_mapping.get(layer_type, layer_type)
target_count = target_layers.get(layer_type, 0) target_count = target_layers.get(target_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
for layer_type, target_count in target_layers.items():
source_type = reverse_mapping.get(layer_type, layer_type)
source_count = source_layers.get(source_type, 0)
if source_count != target_count: if source_count != target_count:
diff[layer_type] = (source_count, target_count) diff[layer_type] = (source_count, target_count)
@ -399,7 +478,7 @@ class ModelConverter:
message = "Models do not have the same number of basic layers:\n" message = "Models do not have the same number of basic layers:\n"
for layer_type, counts in diff.items(): for layer_type, counts in diff.items():
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n" message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
self._log(message=message.rstrip()) self._log(message=message.strip())
return False return False
return True return True
@ -424,8 +503,8 @@ class ModelConverter:
if missing_source_layers or missing_target_layers: if missing_source_layers or missing_target_layers:
self._log( self._log(
message=( message=(
"Models might have missing basic layers. You can either pass them into keys to skip or set" "Models might have missing basic layers. If you want to skip this check, set"
f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}" f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
) )
) )
return False return False
@ -478,17 +557,36 @@ class ModelConverter:
) -> bool: ) -> bool:
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned.""" """Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys()) model_type_shapes = set(source_order.keys()) | set(target_order.keys())
shape_missmatched = False
for model_type_shape in model_type_shapes: default_type_shapes = [
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
]
shape_mismatched = False
for model_type_shape in default_type_shapes:
source_keys = source_order.get(model_type_shape, []) source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, []) target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys): if len(source_keys) != len(target_keys):
shape_missmatched = True shape_mismatched = True
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys) self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_missmatched for source_custom_type in self.custom_layer_mapping.keys():
# iterate over all type_shapes that have the same type as source_custom_type
for source_type_shape in [
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
]:
source_keys = source_order.get(source_type_shape, [])
target_custom_type = self.custom_layer_mapping[source_custom_type]
target_type_shape = (target_custom_type, source_type_shape[1])
target_keys = target_order.get(target_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_mismatched
@staticmethod @staticmethod
def _convert_state_dict( def _convert_state_dict(

View file

@ -0,0 +1,133 @@
# pyright: reportPrivateUsage=false
import pytest
import torch
from torch import nn, Tensor
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.model_converter import ModelConverter, ConversionStage
import refiners.fluxion.layers as fl
class CustomBasicLayer1(fl.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
def forward(self, x: Tensor) -> Tensor:
return x @ self.weight.t()
class CustomBasicLayer2(fl.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.weight = nn.Parameter(data=torch.randn(out_features, in_features))
def forward(self, x: Tensor) -> Tensor:
return x @ self.weight.t()
# Source Model
class SourceModel(fl.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = fl.Linear(in_features=10, out_features=2)
self.activation = fl.ReLU()
self.custom_layers = nn.ModuleList(modules=[CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)])
self.flatten = fl.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.conv = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
def forward(self, x: Tensor) -> Tensor:
x = self.linear1(x)
x = self.activation(x)
for layer in self.custom_layers:
x = layer(x)
x = self.flatten(x)
x = self.dropout(x)
x = x.view(1, 1, -1)
x = self.conv(x)
x = self.pool(x)
return x
# Target Model (Purposely obfuscated but functionally equivalent)
class TargetModel(fl.Module):
def __init__(self) -> None:
super().__init__()
self.relu = fl.ReLU()
self.drop = nn.Dropout(0.5)
self.layers1 = nn.ModuleList(modules=[CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)])
self.flattenIt = fl.Flatten()
self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
self.convolution = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
self.lin = fl.Linear(in_features=10, out_features=2)
def forward(self, x: Tensor) -> Tensor:
x = self.lin(x)
x = self.relu(x)
for layer in self.layers1:
x = layer(x)
x = self.flattenIt(x)
x = self.drop(x)
x = x.view(1, 1, -1)
x = self.convolution(x)
x = self.max_pool(x)
return x
@pytest.fixture
def source_model() -> SourceModel:
manual_seed(seed=2)
return SourceModel()
@pytest.fixture
def target_model() -> TargetModel:
manual_seed(seed=2)
return TargetModel()
@pytest.fixture
def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter:
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] = {CustomBasicLayer1: CustomBasicLayer2}
return ModelConverter(
source_model=source_model, target_model=target_model, custom_layer_mapping=custom_layer_mapping, verbose=True
)
@pytest.fixture
def random_tensor() -> Tensor:
return torch.randn(1, 10)
@pytest.fixture
def source_args(random_tensor: Tensor) -> tuple[Tensor]:
return (random_tensor,)
@pytest.fixture
def target_args(random_tensor: Tensor) -> tuple[Tensor]:
return (random_tensor,)
def test_converter_stages(
model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]
) -> None:
assert model_converter.stage == ConversionStage.INIT
assert model_converter._run_init_stage()
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.BASIC_LAYERS_MATCH
assert model_converter._run_basic_layers_match_stage(source_args=source_args, target_args=target_args)
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH
assert model_converter._run_shape_and_layers_match_stage(source_args=source_args, target_args=target_args)
model_converter._increment_stage()
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE
def test_run(model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]) -> None:
assert model_converter.run(source_args=source_args, target_args=target_args)
assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE

View file

@ -6,7 +6,7 @@ import torch
from refiners.fluxion.utils import manual_seed from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.model_converter import ConversionStage, ModelConverter
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -31,18 +31,8 @@ def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def sdxl_unet_weights_std(test_weights_path: Path) -> Path: def refiners_sdxl_unet() -> SDXLUNet:
unet_weights_std = test_weights_path / "sdxl-unet.safetensors"
if not unet_weights_std.is_file():
warn(message=f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
unet = SDXLUNet(in_channels=4) unet = SDXLUNet(in_channels=4)
unet.load_from_safetensors(tensors_path=sdxl_unet_weights_std)
return unet return unet
@ -57,19 +47,27 @@ def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> No
clip_text_embeddings = torch.randn(1, 77, 2048) clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,) target_args = (x,)
source_args = { source_args = {
"positional": (x, timestep, clip_text_embeddings), "positional": (x, timestep, clip_text_embeddings),
"keyword": {"added_cond_kwargs": added_cond_kwargs}, "keyword": {"added_cond_kwargs": added_cond_kwargs},
} }
converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2) old_forward = target.forward
def forward_with_context(self: Any, *args: Any, **kwargs: Any) -> Any:
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
return old_forward(self, *args, **kwargs)
target.forward = forward_with_context
converter = ModelConverter(source_model=source, target_model=target, verbose=True, threshold=1e-2)
assert converter.run( assert converter.run(
source_args=source_args, source_args=source_args,
target_args=target_args, target_args=target_args,
) )
assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE