make basic layers an enum and work with subtyping

limiteinductive 2023-08-15 16:06:49 +02:00 committed by Benjamin Trom
parent 9da00e6fcf
commit 17dc75421b


@@ -1,4 +1,5 @@
 from collections import defaultdict
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
 from PIL import Image
 from numpy import array, float32
@@ -7,7 +8,7 @@ from safetensors import safe_open as _safe_open # type: ignore
 from safetensors.torch import save_file as _save_file # type: ignore
 from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
 from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
-from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
+from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
 from torch.utils.hooks import RemovableHandle
 
 if TYPE_CHECKING:
@@ -78,35 +79,42 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
     _save_file(tensors, path, metadata) # type: ignore
 
 
-BASIC_LAYERS: list[str] = [
-    "Conv1d",
-    "Conv2d",
-    "Conv3d",
-    "ConvTranspose1d",
-    "ConvTranspose2d",
-    "ConvTranspose3d",
-    "Linear",
-    "BatchNorm1d",
-    "BatchNorm2d",
-    "BatchNorm3d",
-    "LayerNorm",
-    "GroupNorm",
-    "Embedding",
-    "MaxPool2d",
-    "AvgPool2d",
-    "AdaptiveAvgPool2d",
-]
+class BasicLayers(Enum):
+    Conv1d = nn.Conv1d
+    Conv2d = nn.Conv2d
+    Conv3d = nn.Conv3d
+    ConvTranspose1d = nn.ConvTranspose1d
+    ConvTranspose2d = nn.ConvTranspose2d
+    ConvTranspose3d = nn.ConvTranspose3d
+    Linear = nn.Linear
+    BatchNorm1d = nn.BatchNorm1d
+    BatchNorm2d = nn.BatchNorm2d
+    BatchNorm3d = nn.BatchNorm3d
+    LayerNorm = nn.LayerNorm
+    GroupNorm = nn.GroupNorm
+    Embedding = nn.Embedding
+    MaxPool2d = nn.MaxPool2d
+    AvgPool2d = nn.AvgPool2d
+    AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
 
 
 ModelTypeShape = tuple[str, tuple[Size, ...]]
 
 
-def is_basic_layer(module: "Module") -> bool:
-    return module.__class__.__name__ in BASIC_LAYERS
+def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
+    """Identify if the provided module matches any in the BasicLayers enum."""
+    for layer_type in BasicLayers:
+        if isinstance(module, layer_type.value):
+            return layer_type
+    return None
 
 
 def get_module_signature(module: "Module") -> ModelTypeShape:
+    """Return a tuple representing the module's type and parameter shapes."""
+    layer_type = infer_basic_layer_type(module=module)
+    assert layer_type is not None, f"Module {module} is not a basic layer"
     param_shapes = [p.shape for p in module.parameters()]
-    return (module.__class__.__name__, tuple(param_shapes))
+    return (str(object=layer_type), tuple(param_shapes))
 
 
 def forward_order_of_execution(
@@ -114,20 +122,25 @@ def forward_order_of_execution(
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[ModelTypeShape, list[str]]:
+    """
+    Determine the execution order of sub-modules during a forward pass.
+
+    Optionally skips specific modules using `key_skipper`.
+    """
     key_skipper = key_skipper or (lambda _: False)
 
     submodule_to_key: dict["Module", str] = {}
     execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
 
-    def collect_execution_order_hook(layer: "Module", *_: Any):
-        layer_signature = get_module_signature(layer)
+    def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
+        layer_signature = get_module_signature(module=layer)
         execution_order[layer_signature].append(submodule_to_key[layer])
 
     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)
 
     with no_grad():
@@ -143,7 +156,8 @@ def print_side_by_side(
     shape: ModelTypeShape,
     source_keys: list[str],
     target_keys: list[str],
-):
+) -> None:
+    """Print module keys side by side, useful for debugging shape mismatches."""
     print(f"{shape}")
     max_len = max(len(source_keys), len(target_keys))
 
     for i in range(max_len):
@@ -155,6 +169,7 @@ def print_side_by_side(
 def verify_shape_match(
     source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
 ) -> bool:
+    """Check if the sub-modules in source and target have matching shapes."""
     model_type_shapes = set(source_order.keys()) | set(target_order.keys())
     shape_missmatched = False
 
@@ -164,7 +179,7 @@ def verify_shape_match(
 
         if len(source_keys) != len(target_keys):
             shape_missmatched = True
-            print_side_by_side(model_type_shape, source_keys, target_keys)
+            print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
 
     return not shape_missmatched
 
@@ -177,13 +192,22 @@ def create_state_dict_mapping(
     source_key_skipper: Callable[[str], bool] | None = None,
     target_key_skipper: Callable[[str], bool] | None = None,
 ) -> dict[str, str] | None:
+    """
+    Create a mapping between state_dict keys of the source and target models.
+
+    This facilitates the transfer of weights when architectures have slight differences.
+    """
     if target_args is None:
         target_args = source_args
 
-    source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
-    target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
+    source_order = forward_order_of_execution(
+        module=source_model, example_args=source_args, key_skipper=source_key_skipper
+    )
+    target_order = forward_order_of_execution(
+        module=target_model, example_args=target_args, key_skipper=target_key_skipper
+    )
 
-    if not verify_shape_match(source_order, target_order):
+    if not verify_shape_match(source_order=source_order, target_order=target_order):
         return None
 
     mapping: dict[str, str] = {}
@@ -198,9 +222,10 @@ def create_state_dict_mapping(
 def convert_state_dict(
     source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
 ) -> dict[str, Tensor]:
+    """Convert source state_dict based on the provided mapping to match target state_dict structure."""
     converted_state_dict: dict[str, Tensor] = {}
     for target_key in target_state_dict:
-        target_prefix, suffix = target_key.rsplit(".", 1)
+        target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
         source_prefix = state_dict_mapping[target_prefix]
         source_key = ".".join([source_prefix, suffix])
         converted_state_dict[target_key] = source_state_dict[source_key]
@@ -213,18 +238,19 @@ def forward_store_outputs(
     example_args: tuple[Any, ...],
     key_skipper: Callable[[str], bool] | None = None,
 ) -> list[tuple[str, Tensor]]:
+    """Execute a forward pass and store outputs of specific sub-modules."""
     key_skipper = key_skipper or (lambda _: False)
 
     submodule_to_key: dict["Module", str] = {}
     execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
 
-    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
+    def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
         execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
 
     hooks: list[RemovableHandle] = []
     for name, submodule in module.named_modules():
-        if is_basic_layer(submodule) and not key_skipper(name):
+        if (infer_basic_layer_type(module=module) is not None) and not key_skipper(name):
             submodule_to_key[submodule] = name
-            hook = submodule.register_forward_hook(collect_execution_order_hook)
+            hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
             hooks.append(hook)
 
     with no_grad():
@@ -245,11 +271,16 @@ def compare_models(
     target_key_skipper: Callable[[str], bool] | None = None,
     threshold: float = 1e-5,
 ) -> bool:
+    """
+    Compare the outputs of two models given the same inputs.
+
+    Flag if any difference exceeds the given threshold.
+    """
     if target_args is None:
         target_args = source_args
 
-    source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
-    target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
+    source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
+    target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
 
     prev_source_key, prev_target_key = None, None
     for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
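
The heart of this change is swapping string matching for `isinstance` checks against an `Enum` whose values are the `nn` classes themselves: a subclass of `nn.Conv2d` now resolves to `BasicLayers.Conv2d`, whereas the old `module.__class__.__name__ in BASIC_LAYERS` test only matched exact class names. A minimal self-contained sketch of the subtyping behavior (the `CustomConv2d` subclass is hypothetical, invented here for illustration):

from enum import Enum

from torch import nn


class BasicLayers(Enum):
    Conv2d = nn.Conv2d
    Linear = nn.Linear


def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
    # isinstance accepts subclasses, unlike a comparison on __class__.__name__
    for layer_type in BasicLayers:
        if isinstance(module, layer_type.value):
            return layer_type
    return None


class CustomConv2d(nn.Conv2d):  # hypothetical subclass, for illustration only
    pass


conv = CustomConv2d(in_channels=3, out_channels=8, kernel_size=3)
assert conv.__class__.__name__ not in ["Conv2d", "Linear"]  # the old check misses it
assert infer_basic_layer_type(conv) is BasicLayers.Conv2d  # the new check matches
print(BasicLayers.Conv2d)  # BasicLayers.Conv2d

Since `get_module_signature` now builds its key from `str(layer_type)` plus parameter shapes, two modules that are different concrete subclasses of the same basic layer, with identical parameter shapes, map to the same signature.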