(doc/fluxion/model_converter) add/convert docstrings to mkdocstrings format

Laurent 2024-02-01 22:03:38 +00:00 committed by Laureηt
parent 9590703f99
commit c31da03bad


@@ -46,8 +46,11 @@ class ModuleArgsDict(TypedDict):
class ConversionStage(Enum):
"""Represents the current stage of the conversion process.
- `INIT`: The conversion process has not started.
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
Attributes:
INIT: The conversion process has not started.
BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
SHAPE_AND_LAYERS_MATCH: The shapes of both models agree.
MODELS_OUTPUT_AGREE: The outputs of the source and target models agree.
"""
INIT = auto()
@@ -57,6 +60,34 @@ class ConversionStage(Enum):
class ModelConverter:
"""Converts a model's state_dict to match another model's state_dict.
Note: The conversion process consists of four stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
Example:
```py
source = ...
target = ...
converter = ModelConverter(
source_model=source,
target_model=target,
threshold=0.1,
verbose=False
)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="converted_model.safetensors")
```
"""
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
stage: ConversionStage = ConversionStage.INIT
_stored_mapping: dict[str, str] | None = None
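For illustration, the three accepted `ModuleArgs` forms could be built as follows. This is a hedged sketch: the input shape and the keyword names are placeholders, not anything mandated by the models above.

```py
import torch

x = torch.randn(1, 3, 32, 32)  # placeholder input tensor

positional_args = (x,)   # tuple of positional arguments
keyword_args = {"x": x}  # dict of keyword arguments
mixed_args = {           # ModuleArgsDict: positional and keyword together
    "positional": (x,),
    "keyword": {"step": 0},  # hypothetical extra keyword argument
}
```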
@@ -73,36 +104,20 @@ class ModelConverter:
skip_init_check: bool = False,
verbose: bool = True,
) -> None:
"""
Create a ModelConverter.
"""Initializes the ModelConverter.
- `source_model`: The model to convert from.
- `target_model`: The model to convert to.
- `source_keys_to_skip`: A list of keys to skip when tracing the source model.
- `target_keys_to_skip`: A list of keys to skip when tracing the target model.
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
- `threshold`: The threshold for comparing outputs between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
layers.
- `verbose`: Whether to print messages during the conversion process.
Args:
source_model: The model to convert from.
target_model: The model to convert to.
source_keys_to_skip: A list of keys to skip when tracing the source model.
target_keys_to_skip: A list of keys to skip when tracing the target model.
custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
threshold: The threshold for comparing outputs between the source and target models.
skip_output_check: Whether to skip comparing the outputs of the source and target models.
skip_init_check: Whether to skip checking that the source and target models have the same number of basic
layers.
verbose: Whether to print messages during the conversion process.
The conversion process consists of three stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
### Example:
```
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="test.pt")
```
"""
self.source_model = source_model
self.target_model = target_model
@@ -124,27 +139,17 @@ class ModelConverter:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
"""
Run the conversion process.
"""Run the conversion process.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
Args:
source_args: The arguments to pass to the source model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- `True` if the conversion process is done and the models agree.
The conversion process consists of three stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
Returns:
True if the conversion process is done and the models agree.
"""
if target_args is None:
target_args = source_args
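As a usage sketch, reusing the placeholder `x` from above and assuming `converter` is a `ModelConverter` instance:

```py
# Same arguments for both models: target_args defaults to source_args.
done = converter.run(source_args=(x,))

# Distinct arguments when the two signatures differ (keyword name hypothetical).
done = converter.run(source_args=(x,), target_args={"x": x})
```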
@@ -234,14 +239,16 @@ class ModelConverter:
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
"""Save the converted model to a SafeTensors file.
This method can only be called after the conversion process is done.
Warning:
This method can only be called after the conversion process is done.
- `path`: The path to save the converted model to.
- `metadata`: Metadata to save with the converted model.
- `half`: Whether to save the converted model as half precision.
Args:
path: The path to save the converted model to.
metadata: Metadata to save with the converted model.
half: Whether to save the converted model as half precision.
### Raises:
- `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
Raises:
ValueError: If the conversion process is not done yet. Run `converter` first.
"""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
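A hedged usage sketch of saving after a successful run; the file name and metadata keys are illustrative only:

```py
if converter.run(source_args=(x,)):
    converter.save_to_safetensors(
        path="converted_model.safetensors",
        metadata={"converted_from": "source_model"},  # hypothetical metadata
        half=True,  # store the weights in half precision
    )
```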
@@ -255,17 +262,17 @@ class ModelConverter:
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
) -> dict[str, str] | None:
"""
Find a mapping between the source and target models' state_dicts.
"""Find a mapping between the source and target models' state_dicts.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
Args:
source_args: The arguments to pass to the source model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
Returns:
A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
"""
if target_args is None:
target_args = source_args
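For example, the returned mapping might be inspected like this (a sketch, reusing the placeholder arguments from earlier):

```py
mapping = converter.map_state_dicts(source_args=(x,))
if mapping is not None:
    # Keys belong to the target state_dict, values to the source state_dict.
    for target_key, source_key in mapping.items():
        print(f"{target_key} <- {source_key}")
```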
@@ -301,15 +308,18 @@ class ModelConverter:
target_args: ModuleArgs | None = None,
threshold: float = 1e-5,
) -> bool:
"""
Compare the outputs of the source and target models.
"""Compare the outputs of the source and target models.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `threshold`: The threshold for comparing outputs between the source and target models.
Args:
source_args: The arguments to pass to the source model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
target_args: The arguments to pass to the target model. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
threshold: The threshold for comparing outputs between the source and target models.
Returns:
True if the outputs of the source and target models agree.
"""
if target_args is None:
target_args = source_args
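A short sketch of a standalone comparison, with a tighter threshold than the converter-wide default purely for illustration:

```py
if converter.compare_models(source_args=(x,), threshold=1e-4):
    print("source and target outputs agree within 1e-4")
```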
@@ -519,16 +529,16 @@ class ModelConverter:
args: ModuleArgs,
keys_to_skip: list[str],
) -> dict[ModelTypeShape, list[str]]:
"""
Execute a forward pass and store the order of execution of specific sub-modules.
"""Execute a forward pass and store the order of execution of specific sub-modules.
- `module`: The module to trace.
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
Args:
module: The module to trace.
args: The arguments to pass to the module. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
### Returns:
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
Returns:
A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`.
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
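For readers unfamiliar with this kind of tracing, here is a minimal, self-contained sketch of the general technique (forward hooks recording visit order). It is an assumption-laden illustration, not the library's actual implementation; in particular, treating `nn.Linear` and `nn.Conv2d` as the "basic layers" is an assumption.

```py
import torch
from torch import nn

def trace_basic_layers(module: nn.Module, x: torch.Tensor) -> list[str]:
    order: list[str] = []
    handles = []
    for name, sub in module.named_modules():
        if isinstance(sub, (nn.Linear, nn.Conv2d)):  # "basic layers", as an assumption
            hook = lambda m, args, out, name=name: order.append(name)
            handles.append(sub.register_forward_hook(hook))
    try:
        module(x)  # run the forward pass; hooks fire in execution order
    finally:
        for h in handles:
            h.remove()  # always detach hooks, even if the forward pass fails
    return order

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
print(trace_basic_layers(model, torch.randn(1, 3, 16, 16)))  # ['0', '2']
```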
@@ -607,19 +617,19 @@ class ModelConverter:
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:
"""
Execute a forward pass and store the output of specific sub-modules.
"""Execute a forward pass and store the output of specific sub-modules.
- `module`: The module to trace.
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
Args:
module: The module to trace.
args: The arguments to pass to the module. It can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
### Returns:
- A list of tuples containing the key of each sub-module and its output.
Returns:
A list of tuples containing the key of each sub-module and its output.
### Note:
- The output of each sub-module is cloned to avoid memory leaks.
Note:
The output of each sub-module is cloned to avoid memory leaks.
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: list[tuple[str, Tensor]] = []
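And a matching sketch for collecting outputs, reusing the imports from the previous sketch. The `detach().clone()` mirrors the note above: the stored activations neither alias the live tensors nor keep the autograd graph alive. Again, this is an illustration, not the actual method.

```py
def collect_basic_layer_outputs(
    module: nn.Module, x: torch.Tensor
) -> list[tuple[str, torch.Tensor]]:
    outputs: list[tuple[str, torch.Tensor]] = []
    handles = []
    for name, sub in module.named_modules():
        if isinstance(sub, (nn.Linear, nn.Conv2d)):  # same "basic layers" assumption
            hook = lambda m, args, out, name=name: outputs.append(
                (name, out.detach().clone())  # copy the output to avoid memory leaks
            )
            handles.append(sub.register_forward_hook(hook))
    try:
        module(x)
    finally:
        for h in handles:
            h.remove()
    return outputs
```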