From 17dc75421b7b59f1d1e86230a06e3367ae4c92ed Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Tue, 15 Aug 2023 16:06:49 +0200
Subject: [PATCH] make basic layers an enum and work with subtyping

---
 src/refiners/fluxion/utils.py | 105 ++++++++++++++++++++++------------
 1 file changed, 68 insertions(+), 37 deletions(-)

diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index 8b9b198..75061ce 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
 from PIL import Image
 from numpy import array, float32
@@ -7,7 +8,7 @@ from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
 from torch import norm as _norm, manual_seed as _manual_seed  # type: ignore
 from torch.nn.functional import pad as _pad, interpolate as _interpolate  # type: ignore
-from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
+from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
 from torch.utils.hooks import RemovableHandle
 
 if TYPE_CHECKING:
@@ -78,35 +79,42 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
     _save_file(tensors, path, metadata)  # type: ignore
 
 
-BASIC_LAYERS: list[str] = [
-    "Conv1d",
-    "Conv2d",
-    "Conv3d",
-    "ConvTranspose1d",
-    "ConvTranspose2d",
-    "ConvTranspose3d",
-    "Linear",
-    "BatchNorm1d",
-    "BatchNorm2d",
-    "BatchNorm3d",
-    "LayerNorm",
-    "GroupNorm",
-    "Embedding",
-    "MaxPool2d",
-    "AvgPool2d",
-    "AdaptiveAvgPool2d",
-]
+class BasicLayers(Enum):
+    Conv1d = nn.Conv1d
+    Conv2d = nn.Conv2d
+    Conv3d = nn.Conv3d
+    ConvTranspose1d = nn.ConvTranspose1d
+    ConvTranspose2d = nn.ConvTranspose2d
+    ConvTranspose3d = nn.ConvTranspose3d
+    Linear = nn.Linear
+    BatchNorm1d = nn.BatchNorm1d
+    BatchNorm2d = nn.BatchNorm2d
+    BatchNorm3d = nn.BatchNorm3d
+    LayerNorm = nn.LayerNorm
+    GroupNorm = nn.GroupNorm
+    Embedding = nn.Embedding
+    MaxPool2d = nn.MaxPool2d
+    AvgPool2d = nn.AvgPool2d
+    AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
+
 
 
 ModelTypeShape = tuple[str, tuple[Size, ...]]
 
 
-def is_basic_layer(module: "Module") -> bool:
-    return module.__class__.__name__ in BASIC_LAYERS
+def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
+    """Identify if the provided module matches any in the BasicLayers enum."""
+    for layer_type in BasicLayers:
+        if isinstance(module, layer_type.value):
+            return layer_type
+    return None
 
 
 def get_module_signature(module: "Module") -> ModelTypeShape:
+    """Return a tuple representing the module's type and parameter shapes."""
+    layer_type = infer_basic_layer_type(module=module)
+    assert layer_type is not None, f"Module {module} is not a basic layer"
     param_shapes = [p.shape for p in module.parameters()]
-    return (module.__class__.__name__, tuple(param_shapes))
+    return (str(object=layer_type), tuple(param_shapes))
 
 
@@ -114,20 +122,25 @@ def forward_order_of_execution(
     module: "Module",
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[ModelTypeShape, list[str]]:
+    """
+    Determine the execution order of sub-modules during a forward pass.
+
+    Optionally skips specific modules using `key_skipper`.
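+
+    A minimal usage sketch (illustrative only; `my_model` and `x` are hypothetical
+    placeholders for a model and an example input):
+
+        order = forward_order_of_execution(module=my_model, example_args=(x,))
+        # order maps (layer type, parameter shapes) -> submodule keys in call order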
+    """
     key_skipper = key_skipper or (lambda _: False)
     submodule_to_key: dict["Module", str] = {}
     execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
 
-    def collect_execution_order_hook(layer: "Module", *_: Any):
-        layer_signature = get_module_signature(layer)
+    def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
+        layer_signature = get_module_signature(module=layer)
         execution_order[layer_signature].append(submodule_to_key[layer])
 
     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)
 
     with no_grad():
@@ -143,7 +156,8 @@ def print_side_by_side(
     shape: ModelTypeShape,
     source_keys: list[str],
     target_keys: list[str],
-):
+) -> None:
+    """Print module keys side by side, useful for debugging shape mismatches."""
     print(f"{shape}")
     max_len = max(len(source_keys), len(target_keys))
     for i in range(max_len):
@@ -155,6 +169,7 @@ def print_side_by_side(
 def verify_shape_match(
     source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
 ) -> bool:
+    """Check if the sub-modules in source and target have matching shapes."""
     model_type_shapes = set(source_order.keys()) | set(target_order.keys())
 
     shape_missmatched = False
@@ -164,7 +179,7 @@ def verify_shape_match(
 
         if len(source_keys) != len(target_keys):
             shape_missmatched = True
-            print_side_by_side(model_type_shape, source_keys, target_keys)
+            print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
 
     return not shape_missmatched
 
@@ -177,13 +192,22 @@ def create_state_dict_mapping(
     source_key_skipper: Callable[[str], bool] | None = None,
     target_key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[str, str] | None:
+    """
+    Create a mapping between state_dict keys of the source and target models.
+
+    This facilitates the transfer of weights when architectures have slight differences.
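+
+    Hypothetical sketch (`src`, `dst` and `args` stand in for two equivalent models and
+    their example inputs; they are not defined in this module):
+
+        mapping = create_state_dict_mapping(source_model=src, target_model=dst, source_args=args)
+        if mapping is not None:
+            state_dict = convert_state_dict(
+                source_state_dict=src.state_dict(),
+                target_state_dict=dst.state_dict(),
+                state_dict_mapping=mapping,
+            )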
+    """
     if target_args is None:
         target_args = source_args
 
-    source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
-    target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
+    source_order = forward_order_of_execution(
+        module=source_model, example_args=source_args, key_skipper=source_key_skipper
+    )
+    target_order = forward_order_of_execution(
+        module=target_model, example_args=target_args, key_skipper=target_key_skipper
+    )
 
-    if not verify_shape_match(source_order, target_order):
+    if not verify_shape_match(source_order=source_order, target_order=target_order):
         return None
 
     mapping: dict[str, str] = {}
@@ -198,9 +222,10 @@ def create_state_dict_mapping(
 def convert_state_dict(
     source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
 ) -> dict[str, Tensor]:
+    """Convert source state_dict based on the provided mapping to match target state_dict structure."""
     converted_state_dict: dict[str, Tensor] = {}
     for target_key in target_state_dict:
-        target_prefix, suffix = target_key.rsplit(".", 1)
+        target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
         source_prefix = state_dict_mapping[target_prefix]
         source_key = ".".join([source_prefix, suffix])
         converted_state_dict[target_key] = source_state_dict[source_key]
@@ -213,18 +238,19 @@ def forward_store_outputs(
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> list[tuple[str, Tensor]]:
+    """Execute a forward pass and store outputs of specific sub-modules."""
     key_skipper = key_skipper or (lambda _: False)
     submodule_to_key: dict["Module", str] = {}
     execution_order: list[tuple[str, Tensor]] = []  # Store outputs in a list
 
-    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
+    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
         execution_order.append((submodule_to_key[layer], output.clone()))  # Store a copy of the output
 
     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)
 
     with no_grad():
@@ -245,11 +271,16 @@ def compare_models(
     target_key_skipper: Callable[[str], bool] | None = None,
     threshold: float = 1e-5,
 ) -> bool:
+    """
+    Compare the outputs of two models given the same inputs.
+
+    Flag if any difference exceeds the given threshold.
+    """
     if target_args is None:
         target_args = source_args
 
-    source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
-    target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
+    source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
+    target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
 
     prev_source_key, prev_target_key = None, None
     for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):