(doc/fluxion/model_converter) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:03:38 +00:00 committed by Laureηt
parent 9590703f99
commit c31da03bad

View file

@ -46,8 +46,11 @@ class ModuleArgsDict(TypedDict):
class ConversionStage(Enum): class ConversionStage(Enum):
"""Represents the current stage of the conversion process. """Represents the current stage of the conversion process.
- `INIT`: The conversion process has not started. Attributes:
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers. INIT: The conversion process has not started.
BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
SHAPE_AND_LAYERS_MATCH: The shape of both models agree.
MODELS_OUTPUT_AGREE: The source and target models agree.
""" """
INIT = auto() INIT = auto()
@ -57,6 +60,34 @@ class ConversionStage(Enum):
class ModelConverter: class ModelConverter:
"""Converts a model's state_dict to match another model's state_dict.
Note: The conversion process consists of four stages
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
Example:
```py
source = ...
target = ...
converter = ModelConverter(
source_model=source,
target_model=target,
threshold=0.1,
verbose=False
)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="converted_model.pt")
```
"""
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
stage: ConversionStage = ConversionStage.INIT stage: ConversionStage = ConversionStage.INIT
_stored_mapping: dict[str, str] | None = None _stored_mapping: dict[str, str] | None = None
@ -73,36 +104,20 @@ class ModelConverter:
skip_init_check: bool = False, skip_init_check: bool = False,
verbose: bool = True, verbose: bool = True,
) -> None: ) -> None:
""" """Initializes the ModelConverter.
Create a ModelConverter.
- `source_model`: The model to convert from. Args:
- `target_model`: The model to convert to. source_model: The model to convert from.
- `source_keys_to_skip`: A list of keys to skip when tracing the source model. target_model: The model to convert to.
- `target_keys_to_skip`: A list of keys to skip when tracing the target model. source_keys_to_skip: A list of keys to skip when tracing the source model.
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models. target_keys_to_skip: A list of keys to skip when tracing the target model.
- `threshold`: The threshold for comparing outputs between the source and target models. custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models. threshold: The threshold for comparing outputs between the source and target models.
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic skip_output_check: Whether to skip comparing the outputs of the source and target models.
layers. skip_init_check: Whether to skip checking that the source and target models have the same number of basic
- `verbose`: Whether to print messages during the conversion process. layers.
verbose: Whether to print messages during the conversion process.
The conversion process consists of three stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
### Example:
```
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
is_converted = converter(args)
if is_converted:
converter.save_to_safetensors(path="test.pt")
```
""" """
self.source_model = source_model self.source_model = source_model
self.target_model = target_model self.target_model = target_model
@ -124,27 +139,17 @@ class ModelConverter:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3 return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool: def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
""" """Run the conversion process.
Run the conversion process.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, Args:
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
is not provided, these arguments will also be passed to the target model. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, is not provided, these arguments will also be passed to the target model.
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns: Returns:
True if the conversion process is done and the models agree.
- `True` if the conversion process is done and the models agree.
The conversion process consists of three stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
""" """
if target_args is None: if target_args is None:
target_args = source_args target_args = source_args
@ -234,14 +239,16 @@ class ModelConverter:
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None: def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
"""Save the converted model to a SafeTensors file. """Save the converted model to a SafeTensors file.
This method can only be called after the conversion process is done. Warning:
This method can only be called after the conversion process is done.
- `path`: The path to save the converted model to. Args:
- `metadata`: Metadata to save with the converted model. path: The path to save the converted model to.
- `half`: Whether to save the converted model as half precision. metadata: Metadata to save with the converted model.
half: Whether to save the converted model as half precision.
### Raises: Raises:
- `ValueError` if the conversion process is not done yet. Run `converter(args)` first. ValueError: If the conversion process is not done yet. Run `converter(args)` first.
""" """
if not self: if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.") raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
@ -255,17 +262,17 @@ class ModelConverter:
source_args: ModuleArgs, source_args: ModuleArgs,
target_args: ModuleArgs | None = None, target_args: ModuleArgs | None = None,
) -> dict[str, str] | None: ) -> dict[str, str] | None:
""" """Find a mapping between the source and target models' state_dicts.
Find a mapping between the source and target models' state_dicts.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, Args:
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
is not provided, these arguments will also be passed to the target model. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, is not provided, these arguments will also be passed to the target model.
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns: Returns:
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict. A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
""" """
if target_args is None: if target_args is None:
target_args = source_args target_args = source_args
@ -301,15 +308,18 @@ class ModelConverter:
target_args: ModuleArgs | None = None, target_args: ModuleArgs | None = None,
threshold: float = 1e-5, threshold: float = 1e-5,
) -> bool: ) -> bool:
""" """Compare the outputs of the source and target models.
Compare the outputs of the source and target models.
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments, Args:
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args` source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
is not provided, these arguments will also be passed to the target model. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments, is not provided, these arguments will also be passed to the target model.
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
- `threshold`: The threshold for comparing outputs between the source and target models. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
threshold: The threshold for comparing outputs between the source and target models.
Returns:
True if the outputs of the source and target models agree.
""" """
if target_args is None: if target_args is None:
target_args = source_args target_args = source_args
@ -519,16 +529,16 @@ class ModelConverter:
args: ModuleArgs, args: ModuleArgs,
keys_to_skip: list[str], keys_to_skip: list[str],
) -> dict[ModelTypeShape, list[str]]: ) -> dict[ModelTypeShape, list[str]]:
""" """Execute a forward pass and store the order of execution of specific sub-modules.
Execute a forward pass and store the order of execution of specific sub-modules.
- `module`: The module to trace. Args:
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments, module: The module to trace.
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. args: The arguments to pass to the module it can be either a tuple of positional arguments,
- `keys_to_skip`: A list of keys to skip when tracing the module. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
### Returns: Returns:
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules` A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
""" """
submodule_to_key: dict[nn.Module, str] = {} submodule_to_key: dict[nn.Module, str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list) execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
@ -607,19 +617,19 @@ class ModelConverter:
def _collect_layers_outputs( def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str] self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]: ) -> list[tuple[str, Tensor]]:
""" """Execute a forward pass and store the output of specific sub-modules.
Execute a forward pass and store the output of specific sub-modules.
- `module`: The module to trace. Args:
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments, module: The module to trace.
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. args: The arguments to pass to the module it can be either a tuple of positional arguments,
- `keys_to_skip`: A list of keys to skip when tracing the module. a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
keys_to_skip: A list of keys to skip when tracing the module.
### Returns: Returns:
- A list of tuples containing the key of each sub-module and its output. A list of tuples containing the key of each sub-module and its output.
### Note: Note:
- The output of each sub-module is cloned to avoid memory leaks. The output of each sub-module is cloned to avoid memory leaks.
""" """
submodule_to_key: dict[nn.Module, str] = {} submodule_to_key: dict[nn.Module, str] = {}
execution_order: list[tuple[str, Tensor]] = [] execution_order: list[tuple[str, Tensor]] = []