infer device and dtype in LoraAdapter

Authored by Cédric Deltheil on 2023-08-22 11:42:31 +02:00; committed by Cédric Deltheil
parent 8c7298f8cc
commit 71ddb55a8e
4 changed files with 3 additions and 9 deletions

@@ -183,7 +183,7 @@ from refiners.adapters.lora import LoraAdapter
 for layer in vit.layers(fl.Attention):
     for linear, parent in layer.walk(fl.Linear):
-        adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
+        adapter = LoraAdapter(target=linear, rank=64)
         adapter.inject(parent)
 # ... and load existing weights if the LoRAs are pretrained ...

@@ -33,8 +33,6 @@ class LoraConfig(BaseModel):
         adapter = LoraAdapter(
             target=linear,
             rank=self.rank,
-            device=module.device,
-            dtype=module.dtype,
         )
         adapter.inject(parent)
         for linear in adapter.Lora.layers(fl.Linear):

@@ -49,8 +49,6 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
         target: fl.Linear,
         rank: int = 16,
         scale: float = 1.0,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
     ) -> None:
         self.in_features = target.in_features
         self.out_features = target.out_features
@@ -63,8 +61,8 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
                 in_features=target.in_features,
                 out_features=target.out_features,
                 rank=rank,
-                device=device,
-                dtype=dtype,
+                device=target.device,
+                dtype=target.dtype,
             ),
         )
         self.Lora.set_scale(scale=scale)

@@ -49,8 +49,6 @@ def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale
             target=linear,
             rank=rank,
             scale=scale,
-            device=module.device,
-            dtype=module.dtype,
         )
         adapter.inject(parent)
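
For readers unfamiliar with the pattern, here is a minimal standalone sketch of what this commit does, written against plain torch.nn rather than refiners' fluxion API: the adapter reads device and dtype off the layer it wraps, so call sites no longer pass them explicitly. The name TinyLoraAdapter and every detail below are illustrative assumptions, not code from this repository.

import torch
from torch import nn


class TinyLoraAdapter(nn.Module):
    # Hypothetical illustration of device/dtype inference; not the refiners LoraAdapter.
    def __init__(self, target: nn.Linear, rank: int = 16, scale: float = 1.0) -> None:
        super().__init__()
        # Infer device and dtype from the wrapped layer instead of taking them as arguments.
        device = target.weight.device
        dtype = target.weight.dtype
        self.target = target
        self.scale = scale
        self.down = nn.Linear(target.in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, target.out_features, bias=False, device=device, dtype=dtype)
        nn.init.zeros_(self.up.weight)  # the LoRA branch starts as a no-op until trained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.target(x) + self.scale * self.up(self.down(x))


# Call sites then mirror the simplified form in this diff: no device/dtype arguments.
linear = nn.Linear(768, 768)
adapter = TinyLoraAdapter(target=linear, rank=64)
out = adapter(torch.randn(1, 768))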