diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index faba547..92d9b0a 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -50,7 +50,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]: ) target_order = converter._trace_module_execution_order(module=controlnet, args=(x,), keys_to_skip=[]) - broken_k = (str(object=nn.Conv2d), (torch.Size([320, 320, 1, 1]), torch.Size([320]))) + broken_k = (nn.Conv2d, (torch.Size([320, 320, 1, 1]), torch.Size([320]))) expected_source_order = [ "down_blocks.0.attentions.0.proj_in", @@ -89,7 +89,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]: assert target_order[broken_k] == expected_target_order source_order[broken_k] = fixed_source_order - broken_k = (str(object=nn.Conv2d), (torch.Size([640, 640, 1, 1]), torch.Size([640]))) + broken_k = (nn.Conv2d, (torch.Size([640, 640, 1, 1]), torch.Size([640]))) expected_source_order = [ "down_blocks.1.attentions.0.proj_in", @@ -125,7 +125,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]: assert target_order[broken_k] == expected_target_order source_order[broken_k] = fixed_source_order - broken_k = (str(object=nn.Conv2d), (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) + broken_k = (nn.Conv2d, (torch.Size([1280, 1280, 1, 1]), torch.Size([1280]))) expected_source_order = [ "down_blocks.2.attentions.0.proj_in", diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py index b5c6a57..4c827f4 100644 --- a/src/refiners/fluxion/model_converter.py +++ b/src/refiners/fluxion/model_converter.py @@ -28,7 +28,7 @@ TORCH_BASIC_LAYERS: list[type[nn.Module]] = [ ] -ModelTypeShape = tuple[str, tuple[torch.Size, ...]] +ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]] class ModuleArgsDict(TypedDict): @@ -69,6 +69,7 @@ class ModelConverter: custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None, threshold: float = 1e-5, skip_output_check: bool = False, + skip_init_check: bool = False, verbose: bool = True, ) -> None: """ @@ -81,6 +82,8 @@ class ModelConverter: - `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models. - `threshold`: The threshold for comparing outputs between the source and target models. - `skip_output_check`: Whether to skip comparing the outputs of the source and target models. + - `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic + layers. - `verbose`: Whether to print messages during the conversion process. The conversion process consists of three stages: @@ -107,8 +110,18 @@ class ModelConverter: self.custom_layer_mapping = custom_layer_mapping or {} self.threshold = threshold self.skip_output_check = skip_output_check + self.skip_init_check = skip_init_check self.verbose = verbose + def __repr__(self) -> str: + return ( + f"ModelConverter(source_model={self.source_model.__class__.__name__}," + f" target_model={self.target_model.__class__.__name__}, stage={self.stage})" + ) + + def __bool__(self) -> bool: + return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3 + def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool: """ Run the conversion process. @@ -137,38 +150,72 @@ class ModelConverter: match self.stage: case ConversionStage.MODELS_OUTPUT_AGREE: - self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`") + self._increment_stage() return True - case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_models_output_agree_stage(): - self.stage = ConversionStage.MODELS_OUTPUT_AGREE + case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage( + source_args=source_args, target_args=target_args + ): + self._increment_stage() return True case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage( source_args=source_args, target_args=target_args ): - self.stage = ( - ConversionStage.SHAPE_AND_LAYERS_MATCH - if not self.skip_output_check - else ConversionStage.MODELS_OUTPUT_AGREE - ) + self._increment_stage() return self.run(source_args=source_args, target_args=target_args) case ConversionStage.INIT if self._run_init_stage(): - self.stage = ConversionStage.BASIC_LAYERS_MATCH + self._increment_stage() return self.run(source_args=source_args, target_args=target_args) case _: + self._log(message=f"Conversion failed at stage {self.stage.value}") return False - def __repr__(self) -> str: - return ( - f"ModelConverter(source_model={self.source_model.__class__.__name__}," - f" target_model={self.target_model.__class__.__name__}, stage={self.stage})" - ) + def _increment_stage(self) -> None: + """Increment the stage of the conversion process.""" + match self.stage: + case ConversionStage.INIT: + self.stage = ConversionStage.BASIC_LAYERS_MATCH + self._log( + message=( + "Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and" + " layers..." + ) + ) + case ConversionStage.BASIC_LAYERS_MATCH: + self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH + self._log( + message=( + "Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing" + " models..." + ) + ) - def __bool__(self) -> bool: - return self.stage == ConversionStage.MODELS_OUTPUT_AGREE + case ConversionStage.SHAPE_AND_LAYERS_MATCH: + if self.skip_output_check: + self._log( + message=( + "Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set" + " `skip_output_check` to `False`" + ) + ) + else: + self.stage = ConversionStage.MODELS_OUTPUT_AGREE + self._log( + message=( + "Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the" + " converted model using `save_to_safetensors`" + ) + ) + case ConversionStage.MODELS_OUTPUT_AGREE: + self._log( + message=( + "Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export" + " the converted model using `save_to_safetensors`" + ) + ) def get_state_dict(self) -> dict[str, Tensor]: """Get the converted state_dict.""" @@ -233,9 +280,16 @@ class ModelConverter: return None mapping: dict[str, str] = {} - for model_type_shape in source_order: - source_keys = source_order[model_type_shape] - target_keys = target_order[model_type_shape] + for source_type_shape in source_order: + source_keys = source_order[source_type_shape] + target_type_shape = source_type_shape + if not self._is_torch_basic_layer(module_type=source_type_shape[0]): + for source_custom_type, target_custom_type in self.custom_layer_mapping.items(): + if source_custom_type == source_type_shape[0]: + target_type_shape = (target_custom_type, source_type_shape[1]) + break + + target_keys = target_order[target_type_shape] mapping.update(zip(target_keys, source_keys)) return mapping @@ -266,9 +320,9 @@ class ModelConverter: module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip ) - prev_source_key, prev_target_key = None, None + diff, prev_source_key, prev_target_key = None, None, None for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs): - diff = norm(source_output - target_output).item() + diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item() if diff > threshold: self._log( f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and" @@ -277,10 +331,21 @@ class ModelConverter: return False prev_source_key, prev_target_key = source_key, target_key + self._log(message=f"Models agree. Difference in norm: {diff}") + return True def _run_init_stage(self) -> bool: """Run the init stage of the conversion process.""" + if self.skip_init_check: + self._log( + message=( + "Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to" + " `False`" + ) + ) + return True + is_count_correct = self._verify_basic_layers_count() is_not_missing_layers = self._verify_missing_basic_layers() @@ -288,16 +353,12 @@ class ModelConverter: def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: """Run the basic layers match stage of the conversion process.""" - self._log(message="Finding matching shapes and layers...") - mapping = self.map_state_dicts(source_args=source_args, target_args=target_args) self._stored_mapping = mapping if mapping is None: self._log(message="Models do not have matching shapes.") return False - self._log(message="Found matching shapes and layers. Converting state_dict...") - source_state_dict = self.source_model.state_dict() target_state_dict = self.target_model.state_dict() converted_state_dict = self._convert_state_dict( @@ -309,17 +370,22 @@ class ModelConverter: def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool: """Run the shape and layers match stage of the conversion process.""" - if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold): - self._log(message="Models agree. You can export the converted model using `save_to_safetensors`") + if self.skip_output_check: + self._log( + message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`" + ) return True - else: - self._log(message="Models do not agree. Try to increase the threshold or modify the models.") - return False - def _run_models_output_agree_stage(self) -> bool: - """Run the models output agree stage of the conversion process.""" - self._log(message="Conversion is done: you can export the converted model using `save_to_safetensors`") - return True + try: + if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold): + self._log(message="Models agree. You can export the converted model using `save_to_safetensors`") + return True + else: + self._log(message="Models do not agree. Try to increase the threshold or modify the models.") + return False + except Exception as e: + self._log(message=f"An error occurred while comparing the models: {e}") + return False def _log(self, message: str) -> None: """Print a message if `verbose` is `True`.""" @@ -354,16 +420,19 @@ class ModelConverter: return positional_args, keyword_args + def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool: + """Check if a module type is a subclass of a torch basic layer.""" + return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS) + def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None: """Infer the type of a basic layer.""" - for layer_type in TORCH_BASIC_LAYERS: + layer_types = ( + set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS) + ) + for layer_type in layer_types: if isinstance(module, layer_type): return layer_type - for source_type in self.custom_layer_mapping.keys(): - if isinstance(module, source_type): - return source_type - return None def get_module_signature(self, module: nn.Module) -> ModelTypeShape: @@ -371,7 +440,7 @@ class ModelConverter: layer_type = self._infer_basic_layer_type(module=module) assert layer_type is not None, f"Module {module} is not a basic layer" param_shapes = [p.shape for p in module.parameters()] - return (str(object=layer_type), tuple(param_shapes)) + return (layer_type, tuple(param_shapes)) def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]: """Count the number of basic layers in a module.""" @@ -388,10 +457,20 @@ class ModelConverter: source_layers = self._count_basic_layers(module=self.source_model) target_layers = self._count_basic_layers(module=self.target_model) + reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()} + diff: dict[type[nn.Module], tuple[int, int]] = {} - for layer_type in set(source_layers.keys()) | set(target_layers.keys()): - source_count = source_layers.get(layer_type, 0) - target_count = target_layers.get(layer_type, 0) + for layer_type, source_count in source_layers.items(): + target_type = self.custom_layer_mapping.get(layer_type, layer_type) + target_count = target_layers.get(target_type, 0) + + if source_count != target_count: + diff[layer_type] = (source_count, target_count) + + for layer_type, target_count in target_layers.items(): + source_type = reverse_mapping.get(layer_type, layer_type) + source_count = source_layers.get(source_type, 0) + if source_count != target_count: diff[layer_type] = (source_count, target_count) @@ -399,7 +478,7 @@ class ModelConverter: message = "Models do not have the same number of basic layers:\n" for layer_type, counts in diff.items(): message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n" - self._log(message=message.rstrip()) + self._log(message=message.strip()) return False return True @@ -424,8 +503,8 @@ class ModelConverter: if missing_source_layers or missing_target_layers: self._log( message=( - "Models might have missing basic layers. You can either pass them into keys to skip or set" - f" `check_missing_basic_layer` to `False`: {missing_source_layers}, {missing_target_layers}" + "Models might have missing basic layers. If you want to skip this check, set" + f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}" ) ) return False @@ -478,17 +557,36 @@ class ModelConverter: ) -> bool: """Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned.""" model_type_shapes = set(source_order.keys()) | set(target_order.keys()) - shape_missmatched = False - for model_type_shape in model_type_shapes: + default_type_shapes = [ + type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0]) + ] + + shape_mismatched = False + + for model_type_shape in default_type_shapes: source_keys = source_order.get(model_type_shape, []) target_keys = target_order.get(model_type_shape, []) if len(source_keys) != len(target_keys): - shape_missmatched = True + shape_mismatched = True self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys) - return not shape_missmatched + for source_custom_type in self.custom_layer_mapping.keys(): + # iterate over all type_shapes that have the same type as source_custom_type + for source_type_shape in [ + type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type + ]: + source_keys = source_order.get(source_type_shape, []) + target_custom_type = self.custom_layer_mapping[source_custom_type] + target_type_shape = (target_custom_type, source_type_shape[1]) + target_keys = target_order.get(target_type_shape, []) + + if len(source_keys) != len(target_keys): + shape_mismatched = True + self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys) + + return not shape_mismatched @staticmethod def _convert_state_dict( diff --git a/tests/fluxion/test_model_converter.py b/tests/fluxion/test_model_converter.py new file mode 100644 index 0000000..2e5936c --- /dev/null +++ b/tests/fluxion/test_model_converter.py @@ -0,0 +1,133 @@ +# pyright: reportPrivateUsage=false +import pytest +import torch +from torch import nn, Tensor +from refiners.fluxion.utils import manual_seed +from refiners.fluxion.model_converter import ModelConverter, ConversionStage +import refiners.fluxion.layers as fl + + +class CustomBasicLayer1(fl.Module): + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.weight = nn.Parameter(data=torch.randn(out_features, in_features)) + + def forward(self, x: Tensor) -> Tensor: + return x @ self.weight.t() + + +class CustomBasicLayer2(fl.Module): + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.weight = nn.Parameter(data=torch.randn(out_features, in_features)) + + def forward(self, x: Tensor) -> Tensor: + return x @ self.weight.t() + + +# Source Model +class SourceModel(fl.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = fl.Linear(in_features=10, out_features=2) + self.activation = fl.ReLU() + self.custom_layers = nn.ModuleList(modules=[CustomBasicLayer1(in_features=2, out_features=2) for _ in range(3)]) + self.flatten = fl.Flatten() + self.dropout = nn.Dropout(p=0.5) + self.conv = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool1d(kernel_size=2, stride=2) + + def forward(self, x: Tensor) -> Tensor: + x = self.linear1(x) + x = self.activation(x) + for layer in self.custom_layers: + x = layer(x) + x = self.flatten(x) + x = self.dropout(x) + x = x.view(1, 1, -1) + x = self.conv(x) + x = self.pool(x) + return x + + +# Target Model (Purposely obfuscated but functionally equivalent) +class TargetModel(fl.Module): + def __init__(self) -> None: + super().__init__() + self.relu = fl.ReLU() + self.drop = nn.Dropout(0.5) + self.layers1 = nn.ModuleList(modules=[CustomBasicLayer2(in_features=2, out_features=2) for _ in range(3)]) + self.flattenIt = fl.Flatten() + self.max_pool = nn.MaxPool1d(kernel_size=2, stride=2) + self.convolution = nn.Conv1d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1) + self.lin = fl.Linear(in_features=10, out_features=2) + + def forward(self, x: Tensor) -> Tensor: + x = self.lin(x) + x = self.relu(x) + for layer in self.layers1: + x = layer(x) + x = self.flattenIt(x) + x = self.drop(x) + x = x.view(1, 1, -1) + x = self.convolution(x) + x = self.max_pool(x) + return x + + +@pytest.fixture +def source_model() -> SourceModel: + manual_seed(seed=2) + return SourceModel() + + +@pytest.fixture +def target_model() -> TargetModel: + manual_seed(seed=2) + return TargetModel() + + +@pytest.fixture +def model_converter(source_model: SourceModel, target_model: TargetModel) -> ModelConverter: + custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] = {CustomBasicLayer1: CustomBasicLayer2} + return ModelConverter( + source_model=source_model, target_model=target_model, custom_layer_mapping=custom_layer_mapping, verbose=True + ) + + +@pytest.fixture +def random_tensor() -> Tensor: + return torch.randn(1, 10) + + +@pytest.fixture +def source_args(random_tensor: Tensor) -> tuple[Tensor]: + return (random_tensor,) + + +@pytest.fixture +def target_args(random_tensor: Tensor) -> tuple[Tensor]: + return (random_tensor,) + + +def test_converter_stages( + model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor] +) -> None: + assert model_converter.stage == ConversionStage.INIT + assert model_converter._run_init_stage() + model_converter._increment_stage() + + assert model_converter.stage == ConversionStage.BASIC_LAYERS_MATCH + assert model_converter._run_basic_layers_match_stage(source_args=source_args, target_args=target_args) + model_converter._increment_stage() + + assert model_converter.stage == ConversionStage.SHAPE_AND_LAYERS_MATCH + assert model_converter._run_shape_and_layers_match_stage(source_args=source_args, target_args=target_args) + model_converter._increment_stage() + + assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE + + +def test_run(model_converter: ModelConverter, source_args: tuple[Tensor], target_args: tuple[Tensor]) -> None: + assert model_converter.run(source_args=source_args, target_args=target_args) + assert model_converter.stage == ConversionStage.MODELS_OUTPUT_AGREE diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index 92af508..0cf9993 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -6,7 +6,7 @@ import torch from refiners.fluxion.utils import manual_seed from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet -from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.model_converter import ConversionStage, ModelConverter @pytest.fixture(scope="module") @@ -31,18 +31,8 @@ def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any: @pytest.fixture(scope="module") -def sdxl_unet_weights_std(test_weights_path: Path) -> Path: - unet_weights_std = test_weights_path / "sdxl-unet.safetensors" - if not unet_weights_std.is_file(): - warn(message=f"could not find weights at {unet_weights_std}, skipping") - pytest.skip(allow_module_level=True) - return unet_weights_std - - -@pytest.fixture(scope="module") -def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet: +def refiners_sdxl_unet() -> SDXLUNet: unet = SDXLUNet(in_channels=4) - unet.load_from_safetensors(tensors_path=sdxl_unet_weights_std) return unet @@ -57,19 +47,27 @@ def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> No clip_text_embeddings = torch.randn(1, 77, 2048) added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} - target.set_timestep(timestep=timestep) - target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) - target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) - target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) target_args = (x,) source_args = { "positional": (x, timestep, clip_text_embeddings), "keyword": {"added_cond_kwargs": added_cond_kwargs}, } - converter = ModelConverter(source_model=source, target_model=target, verbose=False, threshold=1e-2) + old_forward = target.forward + + def forward_with_context(self: Any, *args: Any, **kwargs: Any) -> Any: + target.set_timestep(timestep=timestep) + target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) + target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) + target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) + return old_forward(self, *args, **kwargs) + + target.forward = forward_with_context + + converter = ModelConverter(source_model=source, target_model=target, verbose=True, threshold=1e-2) assert converter.run( source_args=source_args, target_args=target_args, ) + assert converter.stage == ConversionStage.MODELS_OUTPUT_AGREE