mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)

(doc/fluxion/model_converter) add/convert docstrings to mkdocstrings format

commit c31da03bad (parent 9590703f99)
@@ -46,8 +46,11 @@ class ModuleArgsDict(TypedDict):
 class ConversionStage(Enum):
     """Represents the current stage of the conversion process.
 
-    - `INIT`: The conversion process has not started.
-    - `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
+    Attributes:
+        INIT: The conversion process has not started.
+        BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
+        SHAPE_AND_LAYERS_MATCH: The shapes of both models agree.
+        MODELS_OUTPUT_AGREE: The outputs of the source and target models agree.
     """
 
     INIT = auto()
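Because the members are declared with `auto()`, their values increase in declaration order, which is what lets the converter resume work by comparing stages numerically (see `self.stage.value >= 2` in a later hunk). A minimal, self-contained sketch of that pattern, independent of refiners:

```py
from enum import Enum, auto

class ConversionStage(Enum):
    INIT = auto()                    # value 1
    BASIC_LAYERS_MATCH = auto()      # value 2
    SHAPE_AND_LAYERS_MATCH = auto()  # value 3
    MODELS_OUTPUT_AGREE = auto()     # value 4

stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
# Later stages imply earlier ones, so progress checks are plain comparisons:
assert stage.value >= ConversionStage.BASIC_LAYERS_MATCH.value
```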
@@ -57,6 +60,34 @@ class ConversionStage(Enum):
 
 
 class ModelConverter:
+    """Converts a model's state_dict to match another model's state_dict.
+
+    Note: The conversion process consists of four stages:
+        1. Verify that the source and target models have the same number of basic layers.
+        2. Find matching shapes and layers between the source and target models.
+        3. Convert the source model's state_dict to match the target model's state_dict.
+        4. Compare the outputs of the source and target models.
+
+        The conversion process can be run multiple times, and will resume from the last stage.
+
+    Example:
+        ```py
+        source = ...
+        target = ...
+
+        converter = ModelConverter(
+            source_model=source,
+            target_model=target,
+            threshold=0.1,
+            verbose=False
+        )
+
+        is_converted = converter(args)
+        if is_converted:
+            converter.save_to_safetensors(path="converted_model.pt")
+        ```
+    """
+
     ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
     stage: ConversionStage = ConversionStage.INIT
     _stored_mapping: dict[str, str] | None = None
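The `ModuleArgs` alias accepts three shapes: a tuple of positional arguments, a dictionary of keyword arguments, or a `ModuleArgsDict` with `positional` and `keyword` keys (per the `run` docstring below). A hedged sketch of how such an input could be normalized; the helper name `unpack_module_args` is hypothetical, not the library's actual function:

```py
from typing import Any

# Hypothetical helper (not from refiners): normalize the three accepted
# ModuleArgs shapes into a (positional, keyword) pair.
def unpack_module_args(args: tuple[Any, ...] | dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]:
    if isinstance(args, tuple):
        return args, {}  # plain positional arguments
    if "positional" in args or "keyword" in args:
        # ModuleArgsDict form: {"positional": (...), "keyword": {...}}
        return tuple(args.get("positional", ())), dict(args.get("keyword", {}))
    return (), args  # plain keyword arguments

assert unpack_module_args((1, 2)) == ((1, 2), {})
assert unpack_module_args({"x": 1}) == ((), {"x": 1})
assert unpack_module_args({"positional": (1,), "keyword": {"y": 2}}) == ((1,), {"y": 2})
```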
@@ -73,36 +104,20 @@ class ModelConverter:
         skip_init_check: bool = False,
         verbose: bool = True,
     ) -> None:
-        """
-        Create a ModelConverter.
-
-        - `source_model`: The model to convert from.
-        - `target_model`: The model to convert to.
-        - `source_keys_to_skip`: A list of keys to skip when tracing the source model.
-        - `target_keys_to_skip`: A list of keys to skip when tracing the target model.
-        - `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
-        - `threshold`: The threshold for comparing outputs between the source and target models.
-        - `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
-        - `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
-            layers.
-        - `verbose`: Whether to print messages during the conversion process.
-
-        The conversion process consists of three stages:
-
-        1. Verify that the source and target models have the same number of basic layers.
-        2. Find matching shapes and layers between the source and target models.
-        3. Convert the source model's state_dict to match the target model's state_dict.
-        4. Compare the outputs of the source and target models.
-
-        The conversion process can be run multiple times, and will resume from the last stage.
-
-        ### Example:
-        ```
-        converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
-        is_converted = converter(args)
-        if is_converted:
-            converter.save_to_safetensors(path="test.pt")
-        ```
-        """
+        """Initializes the ModelConverter.
+
+        Args:
+            source_model: The model to convert from.
+            target_model: The model to convert to.
+            source_keys_to_skip: A list of keys to skip when tracing the source model.
+            target_keys_to_skip: A list of keys to skip when tracing the target model.
+            custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
+            threshold: The threshold for comparing outputs between the source and target models.
+            skip_output_check: Whether to skip comparing the outputs of the source and target models.
+            skip_init_check: Whether to skip checking that the source and target models have the same number of basic
+                layers.
+            verbose: Whether to print messages during the conversion process.
+        """
         self.source_model = source_model
         self.target_model = target_model
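The new `Args:` section above maps onto a constructor call like the following hedged sketch (toy `torch.nn` models; the import path is assumed from the repository layout):

```py
import torch.nn as nn

from refiners.fluxion.model_converter import ModelConverter

# Two structurally equivalent toy models; in practice these would be, e.g.,
# a reference implementation and its refiners port.
source = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
target = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

converter = ModelConverter(
    source_model=source,
    target_model=target,
    threshold=1e-5,  # tolerance used when comparing outputs
    verbose=False,
)
```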
@@ -124,27 +139,17 @@ class ModelConverter:
         return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
 
     def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
-        """
-        Run the conversion process.
-
-        - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
-            is not provided, these arguments will also be passed to the target model.
-        - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-
-        ### Returns:
-
-        - `True` if the conversion process is done and the models agree.
-
-        The conversion process consists of three stages:
-
-        1. Verify that the source and target models have the same number of basic layers.
-        2. Find matching shapes and layers between the source and target models.
-        3. Convert the source model's state_dict to match the target model's state_dict.
-        4. Compare the outputs of the source and target models.
-
-        The conversion process can be run multiple times, and will resume from the last stage.
-        """
+        """Run the conversion process.
+
+        Args:
+            source_args: The arguments to pass to the source model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
+                is not provided, these arguments will also be passed to the target model.
+            target_args: The arguments to pass to the target model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+
+        Returns:
+            True if the conversion process is done and the models agree.
+        """
         if target_args is None:
             target_args = source_args
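Continuing the toy example, the same call can take its arguments in any of the three accepted shapes (the keyword name `input` matches `nn.Sequential.forward` and is otherwise an assumption):

```py
import torch

x = torch.randn(2, 8)

# Equivalent ways to pass arguments to `run`:
converter.run(source_args=(x,))                                 # tuple of positionals
converter.run(source_args={"input": x})                         # dict of keywords
converter.run(source_args={"positional": (x,), "keyword": {}})  # ModuleArgsDict
```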
@@ -234,14 +239,16 @@ class ModelConverter:
     def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
         """Save the converted model to a SafeTensors file.
 
-        This method can only be called after the conversion process is done.
+        Warning:
+            This method can only be called after the conversion process is done.
 
-        - `path`: The path to save the converted model to.
-        - `metadata`: Metadata to save with the converted model.
-        - `half`: Whether to save the converted model as half precision.
+        Args:
+            path: The path to save the converted model to.
+            metadata: Metadata to save with the converted model.
+            half: Whether to save the converted model in half precision.
 
-        ### Raises:
-        - `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
+        Raises:
+            ValueError: If the conversion process is not done yet. Run `converter` first.
         """
         if not self:
             raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
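Following the warning above, usage looks like this hedged sketch: the converter's truthiness tracks whether conversion is done (hence the `if not self` guard), and calling the converter directly, as in the class docstring's example, runs or resumes the conversion. The exact `__call__` argument shape is assumed to mirror `run`:

```py
if converter((x,)):  # runs (or resumes) the conversion; truthy once the models agree
    converter.save_to_safetensors(
        path="converted_model.safetensors",
        metadata={"note": "toy conversion"},  # illustrative metadata
        half=True,                            # store weights in float16
    )
```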
@@ -255,17 +262,17 @@ class ModelConverter:
         source_args: ModuleArgs,
         target_args: ModuleArgs | None = None,
     ) -> dict[str, str] | None:
-        """
-        Find a mapping between the source and target models' state_dicts.
-
-        - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
-            is not provided, these arguments will also be passed to the target model.
-        - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-
-        ### Returns:
-        - A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
-        """
+        """Find a mapping between the source and target models' state_dicts.
+
+        Args:
+            source_args: The arguments to pass to the source model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
+                is not provided, these arguments will also be passed to the target model.
+            target_args: The arguments to pass to the target model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+
+        Returns:
+            A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
+        """
         if target_args is None:
             target_args = source_args
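A hedged sketch of typical use; the method name `map_state_dicts` is not visible in this hunk and is taken to be the one this docstring documents in the refiners source:

```py
mapping = converter.map_state_dicts(source_args=(x,))
if mapping is not None:
    # Keys belong to the target state_dict, values to the source state_dict.
    for target_key, source_key in list(mapping.items())[:3]:
        print(f"{target_key} <- {source_key}")
```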
@@ -301,15 +308,18 @@ class ModelConverter:
         target_args: ModuleArgs | None = None,
         threshold: float = 1e-5,
     ) -> bool:
-        """
-        Compare the outputs of the source and target models.
-
-        - `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
-            is not provided, these arguments will also be passed to the target model.
-        - `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-        - `threshold`: The threshold for comparing outputs between the source and target models.
-        """
+        """Compare the outputs of the source and target models.
+
+        Args:
+            source_args: The arguments to pass to the source model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
+                is not provided, these arguments will also be passed to the target model.
+            target_args: The arguments to pass to the target model; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+            threshold: The threshold for comparing outputs between the source and target models.
+
+        Returns:
+            True if the outputs of the source and target models agree.
+        """
         if target_args is None:
             target_args = source_args
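Likewise for the output comparison (the method name `compare_models` is assumed on the same basis); a looser `threshold` tolerates small numerical drift, e.g. after converting weights to half precision:

```py
outputs_agree = converter.compare_models(source_args=(x,), threshold=1e-4)
print("source and target outputs agree:", outputs_agree)
```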
@@ -519,16 +529,16 @@ class ModelConverter:
         args: ModuleArgs,
         keys_to_skip: list[str],
     ) -> dict[ModelTypeShape, list[str]]:
-        """
-        Execute a forward pass and store the order of execution of specific sub-modules.
-
-        - `module`: The module to trace.
-        - `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-        - `keys_to_skip`: A list of keys to skip when tracing the module.
-
-        ### Returns:
-        - A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
-        """
+        """Execute a forward pass and store the order of execution of specific sub-modules.
+
+        Args:
+            module: The module to trace.
+            args: The arguments to pass to the module; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+            keys_to_skip: A list of keys to skip when tracing the module.
+
+        Returns:
+            A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`.
+        """
         submodule_to_key: dict[nn.Module, str] = {}
         execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
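The docstring summarizes a standard PyTorch tracing technique: register forward hooks on the layers of interest, run one forward pass, and record the order in which they fire. A generic, self-contained sketch of it (plain `torch` hooks, not the actual refiners implementation; the "basic layer" set here is illustrative):

```py
from collections import defaultdict

import torch
import torch.nn as nn

def trace_execution_order(module: nn.Module, *args: torch.Tensor) -> dict[tuple[type, torch.Size], list[str]]:
    execution_order: defaultdict[tuple[type, torch.Size], list[str]] = defaultdict(list)
    hooks = []
    for name, submodule in module.named_modules():
        if isinstance(submodule, (nn.Linear, nn.Conv2d)):  # illustrative "basic layers"
            def hook(m: nn.Module, inputs: tuple, output: torch.Tensor, name: str = name) -> None:
                # Key each call by layer type and weight shape; record the module key.
                execution_order[(type(m), m.weight.shape)].append(name)
            hooks.append(submodule.register_forward_hook(hook))
    try:
        module(*args)  # one forward pass populates the ordering
    finally:
        for handle in hooks:
            handle.remove()  # always detach hooks, even if the forward fails
    return dict(execution_order)

order = trace_execution_order(nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4)), torch.randn(2, 8))
```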
@@ -607,19 +617,19 @@ class ModelConverter:
     def _collect_layers_outputs(
         self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
     ) -> list[tuple[str, Tensor]]:
-        """
-        Execute a forward pass and store the output of specific sub-modules.
-
-        - `module`: The module to trace.
-        - `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
-            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
-        - `keys_to_skip`: A list of keys to skip when tracing the module.
-
-        ### Returns:
-        - A list of tuples containing the key of each sub-module and its output.
-
-        ### Note:
-        - The output of each sub-module is cloned to avoid memory leaks.
-        """
+        """Execute a forward pass and store the output of specific sub-modules.
+
+        Args:
+            module: The module to trace.
+            args: The arguments to pass to the module; it can be either a tuple of positional arguments,
+                a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
+            keys_to_skip: A list of keys to skip when tracing the module.
+
+        Returns:
+            A list of tuples containing the key of each sub-module and its output.
+
+        Note:
+            The output of each sub-module is cloned to avoid memory leaks.
+        """
         submodule_to_key: dict[nn.Module, str] = {}
         execution_order: list[tuple[str, Tensor]] = []
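The `Note:` about cloning reflects a common hook pitfall: the tensor a forward hook receives is live, so storing it directly can pin the autograd graph or change under in-place ops. A generic sketch of the pattern (again not the actual refiners implementation):

```py
import torch
import torch.nn as nn

collected: list[tuple[str, torch.Tensor]] = []

def save_output(m: nn.Module, inputs: tuple, output: torch.Tensor, name: str = "layer") -> None:
    # clone() stores a snapshot rather than a reference into the live graph,
    # which is the "memory leak" the docstring's note refers to.
    collected.append((name, output.clone()))

layer = nn.Linear(8, 4)
handle = layer.register_forward_hook(save_output)
layer(torch.randn(2, 8))
handle.remove()
```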