mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/model_converter) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
9590703f99
commit
c31da03bad
|
@ -46,8 +46,11 @@ class ModuleArgsDict(TypedDict):
|
|||
class ConversionStage(Enum):
|
||||
"""Represents the current stage of the conversion process.
|
||||
|
||||
- `INIT`: The conversion process has not started.
|
||||
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
|
||||
Attributes:
|
||||
INIT: The conversion process has not started.
|
||||
BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
|
||||
SHAPE_AND_LAYERS_MATCH: The shape of both models agree.
|
||||
MODELS_OUTPUT_AGREE: The source and target models agree.
|
||||
"""
|
||||
|
||||
INIT = auto()
|
||||
|
@ -57,6 +60,34 @@ class ConversionStage(Enum):
|
|||
|
||||
|
||||
class ModelConverter:
|
||||
"""Converts a model's state_dict to match another model's state_dict.
|
||||
|
||||
Note: The conversion process consists of three stages
|
||||
1. Verify that the source and target models have the same number of basic layers.
|
||||
2. Find matching shapes and layers between the source and target models.
|
||||
3. Convert the source model's state_dict to match the target model's state_dict.
|
||||
4. Compare the outputs of the source and target models.
|
||||
|
||||
The conversion process can be run multiple times, and will resume from the last stage.
|
||||
|
||||
Example:
|
||||
```py
|
||||
source = ...
|
||||
target = ...
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=source,
|
||||
target_model=target,
|
||||
threshold=0.1,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
is_converted = converter(args)
|
||||
if is_converted:
|
||||
converter.save_to_safetensors(path="converted_model.pt")
|
||||
```
|
||||
"""
|
||||
|
||||
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
|
||||
stage: ConversionStage = ConversionStage.INIT
|
||||
_stored_mapping: dict[str, str] | None = None
|
||||
|
@ -73,36 +104,20 @@ class ModelConverter:
|
|||
skip_init_check: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Create a ModelConverter.
|
||||
"""Initializes the ModelConverter.
|
||||
|
||||
- `source_model`: The model to convert from.
|
||||
- `target_model`: The model to convert to.
|
||||
- `source_keys_to_skip`: A list of keys to skip when tracing the source model.
|
||||
- `target_keys_to_skip`: A list of keys to skip when tracing the target model.
|
||||
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
|
||||
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
|
||||
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
|
||||
layers.
|
||||
- `verbose`: Whether to print messages during the conversion process.
|
||||
Args:
|
||||
source_model: The model to convert from.
|
||||
target_model: The model to convert to.
|
||||
source_keys_to_skip: A list of keys to skip when tracing the source model.
|
||||
target_keys_to_skip: A list of keys to skip when tracing the target model.
|
||||
custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
|
||||
threshold: The threshold for comparing outputs between the source and target models.
|
||||
skip_output_check: Whether to skip comparing the outputs of the source and target models.
|
||||
skip_init_check: Whether to skip checking that the source and target models have the same number of basic
|
||||
layers.
|
||||
verbose: Whether to print messages during the conversion process.
|
||||
|
||||
The conversion process consists of three stages:
|
||||
|
||||
1. Verify that the source and target models have the same number of basic layers.
|
||||
2. Find matching shapes and layers between the source and target models.
|
||||
3. Convert the source model's state_dict to match the target model's state_dict.
|
||||
4. Compare the outputs of the source and target models.
|
||||
|
||||
The conversion process can be run multiple times, and will resume from the last stage.
|
||||
|
||||
### Example:
|
||||
```
|
||||
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
|
||||
is_converted = converter(args)
|
||||
if is_converted:
|
||||
converter.save_to_safetensors(path="test.pt")
|
||||
```
|
||||
"""
|
||||
self.source_model = source_model
|
||||
self.target_model = target_model
|
||||
|
@ -124,27 +139,17 @@ class ModelConverter:
|
|||
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
|
||||
|
||||
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
|
||||
"""
|
||||
Run the conversion process.
|
||||
"""Run the conversion process.
|
||||
|
||||
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
Args:
|
||||
source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
|
||||
### Returns:
|
||||
|
||||
- `True` if the conversion process is done and the models agree.
|
||||
|
||||
The conversion process consists of three stages:
|
||||
|
||||
1. Verify that the source and target models have the same number of basic layers.
|
||||
2. Find matching shapes and layers between the source and target models.
|
||||
3. Convert the source model's state_dict to match the target model's state_dict.
|
||||
4. Compare the outputs of the source and target models.
|
||||
|
||||
The conversion process can be run multiple times, and will resume from the last stage.
|
||||
Returns:
|
||||
True if the conversion process is done and the models agree.
|
||||
"""
|
||||
if target_args is None:
|
||||
target_args = source_args
|
||||
|
@ -234,14 +239,16 @@ class ModelConverter:
|
|||
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
|
||||
"""Save the converted model to a SafeTensors file.
|
||||
|
||||
This method can only be called after the conversion process is done.
|
||||
Warning:
|
||||
This method can only be called after the conversion process is done.
|
||||
|
||||
- `path`: The path to save the converted model to.
|
||||
- `metadata`: Metadata to save with the converted model.
|
||||
- `half`: Whether to save the converted model as half precision.
|
||||
Args:
|
||||
path: The path to save the converted model to.
|
||||
metadata: Metadata to save with the converted model.
|
||||
half: Whether to save the converted model as half precision.
|
||||
|
||||
### Raises:
|
||||
- `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
|
||||
Raises:
|
||||
ValueError: If the conversion process is not done yet. Run `converter` first.
|
||||
"""
|
||||
if not self:
|
||||
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
|
||||
|
@ -255,17 +262,17 @@ class ModelConverter:
|
|||
source_args: ModuleArgs,
|
||||
target_args: ModuleArgs | None = None,
|
||||
) -> dict[str, str] | None:
|
||||
"""
|
||||
Find a mapping between the source and target models' state_dicts.
|
||||
"""Find a mapping between the source and target models' state_dicts.
|
||||
|
||||
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
Args:
|
||||
source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
|
||||
### Returns:
|
||||
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
|
||||
Returns:
|
||||
A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
|
||||
"""
|
||||
if target_args is None:
|
||||
target_args = source_args
|
||||
|
@ -301,15 +308,18 @@ class ModelConverter:
|
|||
target_args: ModuleArgs | None = None,
|
||||
threshold: float = 1e-5,
|
||||
) -> bool:
|
||||
"""
|
||||
Compare the outputs of the source and target models.
|
||||
"""Compare the outputs of the source and target models.
|
||||
|
||||
- `source_args`: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||
Args:
|
||||
source_args: The arguments to pass to the source model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||
is not provided, these arguments will also be passed to the target model.
|
||||
target_args: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
threshold: The threshold for comparing outputs between the source and target models.
|
||||
|
||||
Returns:
|
||||
True if the outputs of the source and target models agree.
|
||||
"""
|
||||
if target_args is None:
|
||||
target_args = source_args
|
||||
|
@ -519,16 +529,16 @@ class ModelConverter:
|
|||
args: ModuleArgs,
|
||||
keys_to_skip: list[str],
|
||||
) -> dict[ModelTypeShape, list[str]]:
|
||||
"""
|
||||
Execute a forward pass and store the order of execution of specific sub-modules.
|
||||
"""Execute a forward pass and store the order of execution of specific sub-modules.
|
||||
|
||||
- `module`: The module to trace.
|
||||
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||
Args:
|
||||
module: The module to trace.
|
||||
args: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
keys_to_skip: A list of keys to skip when tracing the module.
|
||||
|
||||
### Returns:
|
||||
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
|
||||
Returns:
|
||||
A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
|
||||
"""
|
||||
submodule_to_key: dict[nn.Module, str] = {}
|
||||
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||
|
@ -607,19 +617,19 @@ class ModelConverter:
|
|||
def _collect_layers_outputs(
|
||||
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
|
||||
) -> list[tuple[str, Tensor]]:
|
||||
"""
|
||||
Execute a forward pass and store the output of specific sub-modules.
|
||||
"""Execute a forward pass and store the output of specific sub-modules.
|
||||
|
||||
- `module`: The module to trace.
|
||||
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||
Args:
|
||||
module: The module to trace.
|
||||
args: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||
keys_to_skip: A list of keys to skip when tracing the module.
|
||||
|
||||
### Returns:
|
||||
- A list of tuples containing the key of each sub-module and its output.
|
||||
Returns:
|
||||
A list of tuples containing the key of each sub-module and its output.
|
||||
|
||||
### Note:
|
||||
- The output of each sub-module is cloned to avoid memory leaks.
|
||||
Note:
|
||||
The output of each sub-module is cloned to avoid memory leaks.
|
||||
"""
|
||||
submodule_to_key: dict[nn.Module, str] = {}
|
||||
execution_order: list[tuple[str, Tensor]] = []
|
||||
|
|
Loading…
Reference in a new issue