make basic layers an enum and work with subtyping

limiteinductive 2023-08-15 16:06:49 +02:00 committed by Benjamin Trom
parent 9da00e6fcf
commit 17dc75421b


@@ -1,4 +1,5 @@
 from collections import defaultdict
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
 from PIL import Image
 from numpy import array, float32
@@ -7,7 +8,7 @@ from safetensors import safe_open as _safe_open # type: ignore
 from safetensors.torch import save_file as _save_file # type: ignore
 from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
 from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
-from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
+from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
 from torch.utils.hooks import RemovableHandle

 if TYPE_CHECKING:
@@ -78,35 +79,42 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
     _save_file(tensors, path, metadata)  # type: ignore

-BASIC_LAYERS: list[str] = [
-    "Conv1d",
-    "Conv2d",
-    "Conv3d",
-    "ConvTranspose1d",
-    "ConvTranspose2d",
-    "ConvTranspose3d",
-    "Linear",
-    "BatchNorm1d",
-    "BatchNorm2d",
-    "BatchNorm3d",
-    "LayerNorm",
-    "GroupNorm",
-    "Embedding",
-    "MaxPool2d",
-    "AvgPool2d",
-    "AdaptiveAvgPool2d",
-]
+class BasicLayers(Enum):
+    Conv1d = nn.Conv1d
+    Conv2d = nn.Conv2d
+    Conv3d = nn.Conv3d
+    ConvTranspose1d = nn.ConvTranspose1d
+    ConvTranspose2d = nn.ConvTranspose2d
+    ConvTranspose3d = nn.ConvTranspose3d
+    Linear = nn.Linear
+    BatchNorm1d = nn.BatchNorm1d
+    BatchNorm2d = nn.BatchNorm2d
+    BatchNorm3d = nn.BatchNorm3d
+    LayerNorm = nn.LayerNorm
+    GroupNorm = nn.GroupNorm
+    Embedding = nn.Embedding
+    MaxPool2d = nn.MaxPool2d
+    AvgPool2d = nn.AvgPool2d
+    AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d

 ModelTypeShape = tuple[str, tuple[Size, ...]]

-def is_basic_layer(module: "Module") -> bool:
-    return module.__class__.__name__ in BASIC_LAYERS
+def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
+    """Identify if the provided module matches any in the BasicLayers enum."""
+    for layer_type in BasicLayers:
+        if isinstance(module, layer_type.value):
+            return layer_type
+    return None

 def get_module_signature(module: "Module") -> ModelTypeShape:
+    """Return a tuple representing the module's type and parameter shapes."""
+    layer_type = infer_basic_layer_type(module=module)
+    assert layer_type is not None, f"Module {module} is not a basic layer"
     param_shapes = [p.shape for p in module.parameters()]
-    return (module.__class__.__name__, tuple(param_shapes))
+    return (str(object=layer_type), tuple(param_shapes))
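The point of the switch shows up in infer_basic_layer_type: the old is_basic_layer compared class names, so a user-defined subclass of a basic layer went unrecognized, while isinstance follows the subtype relation (hence "work with subtyping" in the commit title). A trimmed-down, standalone sketch of the new matching logic; MyConv2d is a hypothetical subclass, not part of this commit:

from enum import Enum

from torch import nn

class BasicLayers(Enum):
    Conv2d = nn.Conv2d
    Linear = nn.Linear

def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
    # isinstance is subtype-aware, unlike a __class__.__name__ lookup
    for layer_type in BasicLayers:
        if isinstance(module, layer_type.value):
            return layer_type
    return None

class MyConv2d(nn.Conv2d):  # hypothetical user subclass
    pass

assert infer_basic_layer_type(MyConv2d(3, 8, kernel_size=3)) is BasicLayers.Conv2d
assert infer_basic_layer_type(nn.ReLU()) is None  # still not a basic layer

The first isinstance match wins, so definition order would matter if one entry's class subclassed another's; the classes collected in BasicLayers are unrelated to each other, so order is immaterial here.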
@@ -114,20 +122,25 @@ def forward_order_of_execution(
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[ModelTypeShape, list[str]]:
+    """
+    Determine the execution order of sub-modules during a forward pass.
+    Optionally skips specific modules using `key_skipper`.
+    """
     key_skipper = key_skipper or (lambda _: False)
     submodule_to_key: dict["Module", str] = {}
     execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)

-    def collect_execution_order_hook(layer: "Module", *_: Any):
-        layer_signature = get_module_signature(layer)
+    def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
+        layer_signature = get_module_signature(module=layer)
         execution_order[layer_signature].append(submodule_to_key[layer])

     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)

     with no_grad():
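For orientation, a hedged usage sketch of forward_order_of_execution on a toy model. The tail of the function is truncated in this hunk; the sketch assumes it runs the forward pass under no_grad, removes the hooks, and returns execution_order, and the result shown in the comment is an illustration, not output from the commit:

from torch import nn, randn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
order = forward_order_of_execution(module=model, example_args=(randn(2, 4),))
# Keys pair str(BasicLayers.X) with parameter shapes; values list submodule
# names in the order they fired. The ReLU is skipped (not a basic layer):
#   {("BasicLayers.Linear", (Size([8, 4]), Size([8]))): ["0"],
#    ("BasicLayers.Linear", (Size([2, 8]), Size([2]))): ["2"]}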
@@ -143,7 +156,8 @@ def print_side_by_side(
     shape: ModelTypeShape,
     source_keys: list[str],
     target_keys: list[str],
-):
+) -> None:
+    """Print module keys side by side, useful for debugging shape mismatches."""
     print(f"{shape}")
     max_len = max(len(source_keys), len(target_keys))
     for i in range(max_len):
@@ -155,6 +169,7 @@ def print_side_by_side(
 def verify_shape_match(
     source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
 ) -> bool:
+    """Check if the sub-modules in source and target have matching shapes."""
     model_type_shapes = set(source_order.keys()) | set(target_order.keys())
     shape_missmatched = False
@@ -164,7 +179,7 @@
         if len(source_keys) != len(target_keys):
             shape_missmatched = True
-            print_side_by_side(model_type_shape, source_keys, target_keys)
+            print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)

     return not shape_missmatched
@@ -177,13 +192,22 @@ def create_state_dict_mapping(
     source_key_skipper: Callable[[str], bool] | None = None,
     target_key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[str, str] | None:
+    """
+    Create a mapping between state_dict keys of the source and target models.
+    This facilitates the transfer of weights when architectures have slight differences.
+    """
     if target_args is None:
         target_args = source_args

-    source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
-    target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
+    source_order = forward_order_of_execution(
+        module=source_model, example_args=source_args, key_skipper=source_key_skipper
+    )
+    target_order = forward_order_of_execution(
+        module=target_model, example_args=target_args, key_skipper=target_key_skipper
+    )

-    if not verify_shape_match(source_order, target_order):
+    if not verify_shape_match(source_order=source_order, target_order=target_order):
         return None

     mapping: dict[str, str] = {}
@@ -198,9 +222,10 @@
 def convert_state_dict(
     source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
 ) -> dict[str, Tensor]:
+    """Convert source state_dict based on the provided mapping to match target state_dict structure."""
     converted_state_dict: dict[str, Tensor] = {}
     for target_key in target_state_dict:
-        target_prefix, suffix = target_key.rsplit(".", 1)
+        target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
         source_prefix = state_dict_mapping[target_prefix]
         source_key = ".".join([source_prefix, suffix])
         converted_state_dict[target_key] = source_state_dict[source_key]
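Combined, create_state_dict_mapping and convert_state_dict support a weight-transfer flow roughly like the sketch below. It assumes the truncated leading parameters of create_state_dict_mapping are named source_model, target_model, and source_args, as its body suggests, and uses two structurally identical toy models for simplicity (the helpers earn their keep when key names differ between architectures):

from torch import nn, randn

source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

mapping = create_state_dict_mapping(
    source_model=source, target_model=target, source_args=(randn(2, 4),)
)
assert mapping is not None  # sub-module shapes matched
target.load_state_dict(
    convert_state_dict(
        source_state_dict=source.state_dict(),
        target_state_dict=target.state_dict(),
        state_dict_mapping=mapping,
    )
)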
@@ -213,18 +238,19 @@ def forward_store_outputs(
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> list[tuple[str, Tensor]]:
+    """Execute a forward pass and store outputs of specific sub-modules."""
     key_skipper = key_skipper or (lambda _: False)
     submodule_to_key: dict["Module", str] = {}
     execution_order: list[tuple[str, Tensor]] = []  # Store outputs in a list

-    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
+    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
         execution_order.append((submodule_to_key[layer], output.clone()))  # Store a copy of the output

     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)

     with no_grad():
@@ -245,11 +271,16 @@ def compare_models(
     target_key_skipper: Callable[[str], bool] | None = None,
     threshold: float = 1e-5,
 ) -> bool:
+    """
+    Compare the outputs of two models given the same inputs.
+    Flag if any difference exceeds the given threshold.
+    """
     if target_args is None:
         target_args = source_args

-    source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
-    target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
+    source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
+    target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)

     prev_source_key, prev_target_key = None, None
     for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
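Continuing the sketch above, compare_models can then validate the transfer end to end; its comparison loop is cut off in this hunk, but the visible signature is enough for a hypothetical call:

# Same weights and same inputs should agree within the threshold.
assert compare_models(
    source_model=source,
    target_model=target,
    source_args=(randn(2, 4),),
    threshold=1e-5,
)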