mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
make basic layers an enum and work with subtyping
This commit is contained in:
parent
9da00e6fcf
commit
17dc75421b
|
@ -1,4 +1,5 @@
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Literal, TypeVar
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from numpy import array, float32
|
from numpy import array, float32
|
||||||
|
@ -7,7 +8,7 @@ from safetensors import safe_open as _safe_open # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
from safetensors.torch import save_file as _save_file # type: ignore
|
||||||
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
from torch import norm as _norm, manual_seed as _manual_seed # type: ignore
|
||||||
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
from torch.nn.functional import pad as _pad, interpolate as _interpolate # type: ignore
|
||||||
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType
|
from torch import Size, Tensor, tensor, no_grad, device as Device, dtype as DType, nn
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -78,35 +79,42 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
|
||||||
_save_file(tensors, path, metadata) # type: ignore
|
_save_file(tensors, path, metadata) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
BASIC_LAYERS: list[str] = [
|
class BasicLayers(Enum):
|
||||||
"Conv1d",
|
Conv1d = nn.Conv1d
|
||||||
"Conv2d",
|
Conv2d = nn.Conv2d
|
||||||
"Conv3d",
|
Conv3d = nn.Conv3d
|
||||||
"ConvTranspose1d",
|
ConvTranspose1d = nn.ConvTranspose1d
|
||||||
"ConvTranspose2d",
|
ConvTranspose2d = nn.ConvTranspose2d
|
||||||
"ConvTranspose3d",
|
ConvTranspose3d = nn.ConvTranspose3d
|
||||||
"Linear",
|
Linear = nn.Linear
|
||||||
"BatchNorm1d",
|
BatchNorm1d = nn.BatchNorm1d
|
||||||
"BatchNorm2d",
|
BatchNorm2d = nn.BatchNorm2d
|
||||||
"BatchNorm3d",
|
BatchNorm3d = nn.BatchNorm3d
|
||||||
"LayerNorm",
|
LayerNorm = nn.LayerNorm
|
||||||
"GroupNorm",
|
GroupNorm = nn.GroupNorm
|
||||||
"Embedding",
|
Embedding = nn.Embedding
|
||||||
"MaxPool2d",
|
MaxPool2d = nn.MaxPool2d
|
||||||
"AvgPool2d",
|
AvgPool2d = nn.AvgPool2d
|
||||||
"AdaptiveAvgPool2d",
|
AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d
|
||||||
]
|
|
||||||
|
|
||||||
ModelTypeShape = tuple[str, tuple[Size, ...]]
|
ModelTypeShape = tuple[str, tuple[Size, ...]]
|
||||||
|
|
||||||
|
|
||||||
def is_basic_layer(module: "Module") -> bool:
|
def infer_basic_layer_type(module: nn.Module) -> BasicLayers | None:
|
||||||
return module.__class__.__name__ in BASIC_LAYERS
|
"""Identify if the provided module matches any in the BasicLayers enum."""
|
||||||
|
for layer_type in BasicLayers:
|
||||||
|
if isinstance(module, layer_type.value):
|
||||||
|
return layer_type
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_module_signature(module: "Module") -> ModelTypeShape:
|
def get_module_signature(module: "Module") -> ModelTypeShape:
|
||||||
|
"""Return a tuple representing the module's type and parameter shapes."""
|
||||||
|
layer_type = infer_basic_layer_type(module=module)
|
||||||
|
assert layer_type is not None, f"Module {module} is not a basic layer"
|
||||||
param_shapes = [p.shape for p in module.parameters()]
|
param_shapes = [p.shape for p in module.parameters()]
|
||||||
return (module.__class__.__name__, tuple(param_shapes))
|
return (str(object=layer_type), tuple(param_shapes))
|
||||||
|
|
||||||
|
|
||||||
def forward_order_of_execution(
|
def forward_order_of_execution(
|
||||||
|
@ -114,20 +122,25 @@ def forward_order_of_execution(
|
||||||
example_args: tuple[Any, ...],
|
example_args: tuple[Any, ...],
|
||||||
key_skipper: Callable[[str], bool] | None = None,
|
key_skipper: Callable[[str], bool] | None = None,
|
||||||
) -> dict[ModelTypeShape, list[str]]:
|
) -> dict[ModelTypeShape, list[str]]:
|
||||||
|
"""
|
||||||
|
Determine the execution order of sub-modules during a forward pass.
|
||||||
|
|
||||||
|
Optionally skips specific modules using `key_skipper`.
|
||||||
|
"""
|
||||||
key_skipper = key_skipper or (lambda _: False)
|
key_skipper = key_skipper or (lambda _: False)
|
||||||
|
|
||||||
submodule_to_key: dict["Module", str] = {}
|
submodule_to_key: dict["Module", str] = {}
|
||||||
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
def collect_execution_order_hook(layer: "Module", *_: Any):
|
def collect_execution_order_hook(layer: "Module", *_: Any) -> None:
|
||||||
layer_signature = get_module_signature(layer)
|
layer_signature = get_module_signature(module=layer)
|
||||||
execution_order[layer_signature].append(submodule_to_key[layer])
|
execution_order[layer_signature].append(submodule_to_key[layer])
|
||||||
|
|
||||||
hooks: list[RemovableHandle] = []
|
hooks: list[RemovableHandle] = []
|
||||||
for name, submodule in module.named_modules():
|
for name, submodule in module.named_modules():
|
||||||
if is_basic_layer(submodule) and not key_skipper(name):
|
if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
|
||||||
submodule_to_key[submodule] = name
|
submodule_to_key[submodule] = name
|
||||||
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
hooks.append(hook)
|
hooks.append(hook)
|
||||||
|
|
||||||
with no_grad():
|
with no_grad():
|
||||||
|
@ -143,7 +156,8 @@ def print_side_by_side(
|
||||||
shape: ModelTypeShape,
|
shape: ModelTypeShape,
|
||||||
source_keys: list[str],
|
source_keys: list[str],
|
||||||
target_keys: list[str],
|
target_keys: list[str],
|
||||||
):
|
) -> None:
|
||||||
|
"""Print module keys side by side, useful for debugging shape mismatches."""
|
||||||
print(f"{shape}")
|
print(f"{shape}")
|
||||||
max_len = max(len(source_keys), len(target_keys))
|
max_len = max(len(source_keys), len(target_keys))
|
||||||
for i in range(max_len):
|
for i in range(max_len):
|
||||||
|
@ -155,6 +169,7 @@ def print_side_by_side(
|
||||||
def verify_shape_match(
|
def verify_shape_match(
|
||||||
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""Check if the sub-modules in source and target have matching shapes."""
|
||||||
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||||
shape_missmatched = False
|
shape_missmatched = False
|
||||||
|
|
||||||
|
@ -164,7 +179,7 @@ def verify_shape_match(
|
||||||
|
|
||||||
if len(source_keys) != len(target_keys):
|
if len(source_keys) != len(target_keys):
|
||||||
shape_missmatched = True
|
shape_missmatched = True
|
||||||
print_side_by_side(model_type_shape, source_keys, target_keys)
|
print_side_by_side(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||||
|
|
||||||
return not shape_missmatched
|
return not shape_missmatched
|
||||||
|
|
||||||
|
@ -177,13 +192,22 @@ def create_state_dict_mapping(
|
||||||
source_key_skipper: Callable[[str], bool] | None = None,
|
source_key_skipper: Callable[[str], bool] | None = None,
|
||||||
target_key_skipper: Callable[[str], bool] | None = None,
|
target_key_skipper: Callable[[str], bool] | None = None,
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
"""
|
||||||
|
Create a mapping between state_dict keys of the source and target models.
|
||||||
|
|
||||||
|
This facilitates the transfer of weights when architectures have slight differences.
|
||||||
|
"""
|
||||||
if target_args is None:
|
if target_args is None:
|
||||||
target_args = source_args
|
target_args = source_args
|
||||||
|
|
||||||
source_order = forward_order_of_execution(source_model, source_args, source_key_skipper)
|
source_order = forward_order_of_execution(
|
||||||
target_order = forward_order_of_execution(target_model, target_args, target_key_skipper)
|
module=source_model, example_args=source_args, key_skipper=source_key_skipper
|
||||||
|
)
|
||||||
|
target_order = forward_order_of_execution(
|
||||||
|
module=target_model, example_args=target_args, key_skipper=target_key_skipper
|
||||||
|
)
|
||||||
|
|
||||||
if not verify_shape_match(source_order, target_order):
|
if not verify_shape_match(source_order=source_order, target_order=target_order):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mapping: dict[str, str] = {}
|
mapping: dict[str, str] = {}
|
||||||
|
@ -198,9 +222,10 @@ def create_state_dict_mapping(
|
||||||
def convert_state_dict(
|
def convert_state_dict(
|
||||||
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
||||||
) -> dict[str, Tensor]:
|
) -> dict[str, Tensor]:
|
||||||
|
"""Convert source state_dict based on the provided mapping to match target state_dict structure."""
|
||||||
converted_state_dict: dict[str, Tensor] = {}
|
converted_state_dict: dict[str, Tensor] = {}
|
||||||
for target_key in target_state_dict:
|
for target_key in target_state_dict:
|
||||||
target_prefix, suffix = target_key.rsplit(".", 1)
|
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
|
||||||
source_prefix = state_dict_mapping[target_prefix]
|
source_prefix = state_dict_mapping[target_prefix]
|
||||||
source_key = ".".join([source_prefix, suffix])
|
source_key = ".".join([source_prefix, suffix])
|
||||||
converted_state_dict[target_key] = source_state_dict[source_key]
|
converted_state_dict[target_key] = source_state_dict[source_key]
|
||||||
|
@ -213,18 +238,19 @@ def forward_store_outputs(
|
||||||
example_args: tuple[Any, ...],
|
example_args: tuple[Any, ...],
|
||||||
key_skipper: Callable[[str], bool] | None = None,
|
key_skipper: Callable[[str], bool] | None = None,
|
||||||
) -> list[tuple[str, Tensor]]:
|
) -> list[tuple[str, Tensor]]:
|
||||||
|
"""Execute a forward pass and store outputs of specific sub-modules."""
|
||||||
key_skipper = key_skipper or (lambda _: False)
|
key_skipper = key_skipper or (lambda _: False)
|
||||||
submodule_to_key: dict["Module", str] = {}
|
submodule_to_key: dict["Module", str] = {}
|
||||||
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
|
execution_order: list[tuple[str, Tensor]] = [] # Store outputs in a list
|
||||||
|
|
||||||
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor):
|
def collect_execution_order_hook(layer: "Module", _: Any, output: Tensor) -> None:
|
||||||
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
|
execution_order.append((submodule_to_key[layer], output.clone())) # Store a copy of the output
|
||||||
|
|
||||||
hooks: list[RemovableHandle] = []
|
hooks: list[RemovableHandle] = []
|
||||||
for name, submodule in module.named_modules():
|
for name, submodule in module.named_modules():
|
||||||
if is_basic_layer(submodule) and not key_skipper(name):
|
if (infer_basic_layer_type(module=submodule) is not None) and not key_skipper(name):
|
||||||
submodule_to_key[submodule] = name
|
submodule_to_key[submodule] = name
|
||||||
hook = submodule.register_forward_hook(collect_execution_order_hook)
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
hooks.append(hook)
|
hooks.append(hook)
|
||||||
|
|
||||||
with no_grad():
|
with no_grad():
|
||||||
|
@ -245,11 +271,16 @@ def compare_models(
|
||||||
target_key_skipper: Callable[[str], bool] | None = None,
|
target_key_skipper: Callable[[str], bool] | None = None,
|
||||||
threshold: float = 1e-5,
|
threshold: float = 1e-5,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Compare the outputs of two models given the same inputs.
|
||||||
|
|
||||||
|
Flag if any difference exceeds the given threshold.
|
||||||
|
"""
|
||||||
if target_args is None:
|
if target_args is None:
|
||||||
target_args = source_args
|
target_args = source_args
|
||||||
|
|
||||||
source_order = forward_store_outputs(source_model, source_args, source_key_skipper)
|
source_order = forward_store_outputs(module=source_model, example_args=source_args, key_skipper=source_key_skipper)
|
||||||
target_order = forward_store_outputs(target_model, target_args, target_key_skipper)
|
target_order = forward_store_outputs(module=target_model, example_args=target_args, key_skipper=target_key_skipper)
|
||||||
|
|
||||||
prev_source_key, prev_target_key = None, None
|
prev_source_key, prev_target_key = None, None
|
||||||
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
|
for (source_key, source_output), (target_key, target_output) in zip(source_order, target_order):
|
||||||
|
|
Loading…
Reference in a new issue