From 71ddb55a8e2e615b21dd4209b10d33bee82d22d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Tue, 22 Aug 2023 11:42:31 +0200
Subject: [PATCH] infer device and dtype in LoraAdapter

---
 README.md                                           | 2 +-
 scripts/training/finetune-ldm-lora.py               | 2 --
 src/refiners/adapters/lora.py                       | 6 ++----
 src/refiners/foundationals/latent_diffusion/lora.py | 2 --
 4 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 0e2b59e..7fb75c2 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,7 @@ from refiners.adapters.lora import LoraAdapter
 
 for layer in vit.layers(fl.Attention):
     for linear, parent in layer.walk(fl.Linear):
-        adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
+        adapter = LoraAdapter(target=linear, rank=64)
         adapter.inject(parent)
 
 # ... and load existing weights if the LoRAs are pretrained ...
diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py
index d910d51..abffc46 100644
--- a/scripts/training/finetune-ldm-lora.py
+++ b/scripts/training/finetune-ldm-lora.py
@@ -33,8 +33,6 @@ class LoraConfig(BaseModel):
                 adapter = LoraAdapter(
                     target=linear,
                     rank=self.rank,
-                    device=module.device,
-                    dtype=module.dtype,
                 )
                 adapter.inject(parent)
                 for linear in adapter.Lora.layers(fl.Linear):
diff --git a/src/refiners/adapters/lora.py b/src/refiners/adapters/lora.py
index 21a9e98..d7f167d 100644
--- a/src/refiners/adapters/lora.py
+++ b/src/refiners/adapters/lora.py
@@ -49,8 +49,6 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
         target: fl.Linear,
         rank: int = 16,
         scale: float = 1.0,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
     ) -> None:
         self.in_features = target.in_features
         self.out_features = target.out_features
@@ -63,8 +61,8 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
                     in_features=target.in_features,
                     out_features=target.out_features,
                     rank=rank,
-                    device=device,
-                    dtype=dtype,
+                    device=target.device,
+                    dtype=target.dtype,
                 ),
             )
         self.Lora.set_scale(scale=scale)
diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index 9a90b96..ac29631 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -49,8 +49,6 @@ def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale
             target=linear,
             rank=rank,
             scale=scale,
-            device=module.device,
-            dtype=module.dtype,
         )
         adapter.inject(parent)
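
With this patch applied, LoraAdapter reads device and dtype from the fl.Linear it wraps, so call sites no longer pass them. A minimal usage sketch of the resulting call site, mirroring the README hunk above; it assumes vit and the fl layers alias are already defined, exactly as in that README example:

    from refiners.adapters.lora import LoraAdapter

    # Wrap every Linear inside each Attention layer with a LoRA adapter.
    # Device and dtype are now inferred from the adapted target (linear),
    # so no device/dtype arguments are passed anymore.
    for layer in vit.layers(fl.Attention):
        for linear, parent in layer.walk(fl.Linear):
            adapter = LoraAdapter(target=linear, rank=64)
            adapter.inject(parent)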