mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
make basic layers an enum and work with subtyping
This commit is contained in:
parent
9da00e6fcf
commit
17dc75421b
|
@ -1,4 +1,5 @@
|
|||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
|
||||
from PIL import Image
|
||||
from numpy import array, float32
|
||||
|
@ -7,7 +8,7 @@ from safetensors import safe_open as _safe_open # type: ignore
|
|||
from safetensors.torch import save_file as _save_file # type: ignore
|
||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
|
||||
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -78,35 +79,42 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
|
|||
_save_file(tensors, path, metadata) # type: ignore
|
||||
|
||||
|
||||
BASIC_LAYERS: list[str] = [
|
||||
"Conv1d",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"ConvTranspose1d",
|
||||
"ConvTranspose2d",
|
||||
"ConvTranspose3d",
|
||||
"Linear",
|
||||
"BatchNorm1d",
|
||||
"BatchNorm2d",
|
||||
"BatchNorm3d",
|
||||
"LayerNorm",
|
||||
"GroupNorm",
|
||||
"Embedding",
|
||||
"MaxPool2d",
|
||||
"AvgPool2d",
|
||||
"AdaptiveAvgPool2d",
|
||||
]
|
||||
class BasicLayers(Enum):
|
||||
Conv1d = nn.Conv1d
|
||||
Conv2d = nn.Conv2d
|
||||
Conv3d = nn.Conv3d
|
||||
ConvTranspose1d = nn.ConvTranspose1d
|
||||
ConvTranspose2d = nn.ConvTranspose2d
|
||||
ConvTranspose3d = nn.ConvTranspose3d
|
||||
Linear = nn.Linear
|
||||
BatchNorm1d = nn.BatchNorm1d
|
||||
BatchNorm2d = nn.BatchNorm2d
|
||||
BatchNorm3d = nn.BatchNorm3d
|
||||
LayerNorm = nn.LayerNorm
|
||||
GroupNorm = nn.GroupNorm
|
||||
Embedding = nn.Embedding
|
||||
MaxPool2d = nn.MaxPool2d
|
||||
AvgPool2d = nn.AvgPool2d
|
||||
AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
|
||||
|
||||
|
||||
ModelTypeShape = tuple[str, tuple[Size, ...]]
|
||||
|
||||
|
||||
def is_basic_layer(module: "Module") -> bool:
|
||||
return module.__class__.__name__ in BASIC_LAYERS
|
||||
def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
|
||||
"""Identify if the provided module matches any in the BasicLayers enum."""
|
||||
for layer_type in BasicLayers:
|
||||
if isinstance(module, layer_type.value):
|
||||
return layer_type
|
||||
return None
|
||||
|
||||
|
||||
def get_module_signature(module: "Module") -> ModelTypeShape:
|
||||
"""Return a tuple representing the module's type and parameter shapes."""
|
||||
layer_type = infer_basic_layer_type(module=module)
|
||||
assert layer_type is not None, f"Module {module} is not a basic layer"
|
||||
param_shapes = [p.shape for p in module.parameters()]
|
||||
return (module.__class__.__name__, tuple(param_shapes))
|
||||
return (str(object=layer_type), tuple(param_shapes))
|
||||
|
||||
|
||||
def forward_order_of_execution(
|
||||
|
@ -114,20 +122,25 @@ def forward_order_of_execution(
|
|||
example_args: tuple[Any, ...],
|
||||
key_skipper: Callable[[str], bool] | None = None,
|
||||
) -> dict[ModelTypeShape, list[str]]:
|
||||
"""
|
||||
Determine the execution order of sub-modules during a forward pass.
|
||||
|
||||
Optionally skips specific modules using `key_skipper`.
|
||||
"""
|
||||
key_skipper = key_skipper or (lambda _: False)
|
||||
|
||||
submodule_to_key: dict["Module", str] = {}
|
||||
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||
|
||||
def collect_execution_order_hook(layer: "Module", *_: Any):
|
||||
layer_signature = get_module_signature(layer)
|
||||
def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
|
||||
layer_signature = get_module_signature(module=layer)
|
||||
execution_order[layer_signature].append(submodule_to_key[layer])
|
||||
|
||||
hooks: list[RemovableHandle] = []
|
||||
for name, submodule in module.named_modules():
|
||||
if is_basic_layer(submodule) and not key_skipper(name):
|
||||
if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
|
||||
submodule_to_key[submodule] = name
|
||||
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
||||
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||
hooks.append(hook)
|
||||
|
||||
with no_grad():
|
||||
|
@ -143,7 +156,8 @@ def print_side_by_side(
|
|||
shape: ModelTypeShape,
|
||||
source_keys: list[str],
|
||||
target_keys: list[str],
|
||||
):
|
||||
) -> None:
|
||||
"""Print module keys side by side, useful for debugging shape mismatches."""
|
||||
print(f"{shape}")
|
||||
max_len = max(len(source_keys), len(target_keys))
|
||||
for i in range(max_len):
|
||||
|
@ -155,6 +169,7 @@ def print_side_by_side(
|
|||
def verify_shape_match(
|
||||
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
||||
) -> bool:
|
||||
"""Check if the sub-modules in source and target have matching shapes."""
|
||||
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||
shape_missmatched = False
|
||||
|
||||
|
@ -164,7 +179,7 @@ def verify_shape_match(
|
|||
|
||||
if len(source_keys) != len(target_keys):
|
||||
shape_missmatched = True
|
||||
print_side_by_side(model_type_shape, source_keys, target_keys)
|
||||
print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||
|
||||
return not shape_missmatched
|
||||
|
||||
|
@ -177,13 +192,22 @@ def create_state_dict_mapping(
|
|||
source_key_skipper: Callable[[str], bool] | None = None,
|
||||
target_key_skipper: Callable[[str], bool] | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
Create a mapping between state_dict keys of the source and target models.
|
||||
|
||||
This facilitates the transfer of weights when architectures have slight differences.
|
||||
"""
|
||||
if target_args is None:
|
||||
target_args = source_args
|
||||
|
||||
source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
|
||||
target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
|
||||
source_order = forward_order_of_execution(
|
||||
module=source_model, example_args=source_args, key_skipper=source_key_skipper
|
||||
)
|
||||
target_order = forward_order_of_execution(
|
||||
module=target_model, example_args=target_args, key_skipper=target_key_skipper
|
||||
)
|
||||
|
||||
if not verify_shape_match(source_order, target_order):
|
||||
if not verify_shape_match(source_order=source_order, target_order=target_order):
|
||||
return None
|
||||
|
||||
mapping: dict[str, str] = {}
|
||||
|
@ -198,9 +222,10 @@ def create_state_dict_mapping(
|
|||
def convert_state_dict(
|
||||
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert source state_dict based on the provided mapping to match target state_dict structure."""
|
||||
converted_state_dict: dict[str, Tensor] = {}
|
||||
for target_key in target_state_dict:
|
||||
target_prefix, suffix = target_key.rsplit(".", 1)
|
||||
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
|
||||
source_prefix = state_dict_mapping[target_prefix]
|
||||
source_key = ".".join([source_prefix, suffix])
|
||||
converted_state_dict[target_key] = source_state_dict[source_key]
|
||||
|
@ -213,18 +238,19 @@ def forward_store_outputs(
|
|||
example_args: tuple[Any, ...],
|
||||
key_skipper: Callable[[str], bool] | None = None,
|
||||
) -> list[tuple[str, Tensor]]:
|
||||
"""Execute a forward pass and store outputs of specific sub-modules."""
|
||||
key_skipper = key_skipper or (lambda _: False)
|
||||
submodule_to_key: dict["Module", str] = {}
|
||||
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
|
||||
|
||||
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
|
||||
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
|
||||
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
|
||||
|
||||
hooks: list[RemovableHandle] = []
|
||||
for name, submodule in module.named_modules():
|
||||
if is_basic_layer(submodule) and not key_skipper(name):
|
||||
if (infer_basic_layer_type(module=module) is not None) and not key_skipper(name):
|
||||
submodule_to_key[submodule] = name
|
||||
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
||||
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||
hooks.append(hook)
|
||||
|
||||
with no_grad():
|
||||
|
@ -245,11 +271,16 @@ def compare_models(
|
|||
target_key_skipper: Callable[[str], bool] | None = None,
|
||||
threshold: float = 1e-5,
|
||||
) -> bool:
|
||||
"""
|
||||
Compare the outputs of two models given the same inputs.
|
||||
|
||||
Flag if any difference exceeds the given threshold.
|
||||
"""
|
||||
if target_args is None:
|
||||
target_args = source_args
|
||||
|
||||
source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
|
||||
target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
|
||||
source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
|
||||
target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
|
||||
|
||||
prev_source_key, prev_target_key = None, None
|
||||
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
|
||||
|
|
Loading…
Reference in a new issue