diff --git a/src/refiners/fluxion/layers/converter.py b/src/refiners/fluxion/layers/converter.py index 8826a68..efcb31e 100644 --- a/src/refiners/fluxion/layers/converter.py +++ b/src/refiners/fluxion/layers/converter.py @@ -7,18 +7,21 @@ class Converter(ContextModule): """ A Converter class that adjusts tensor properties based on a parent module's settings. - This class inherits from `ContextModule` and provides functionality to adjust - the device and dtype of input tensor(s) to match the parent module's attributes. - - Attributes: - set_device (bool): If True, matches the device of the input tensor(s) to the parent's device. - set_dtype (bool): If True, matches the dtype of the input tensor(s) to the parent's dtype. + This class inherits from [`ContextModule`][refiners.fluxion.layers.ContextModule] + and provides functionality to adjust the device and dtype + of input tensor(s) to match the parent module's attributes. Note: Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True. """ def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None: + """Initializes the Converter layer. + + Args: + set_device: If True, matches the device of the input tensor(s) to the parent's device. + set_dtype: If True, matches the dtype of the input tensor(s) to the parent's dtype. + """ super().__init__() self.set_device = set_device self.set_dtype = set_dtype