mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
infer device and dtype in LoraAdapter
This commit is contained in:
parent
8c7298f8cc
commit
71ddb55a8e
|
@ -183,7 +183,7 @@ from refiners.adapters.lora import LoraAdapter
|
|||
|
||||
for layer in vit.layers(fl.Attention):
|
||||
for linear, parent in layer.walk(fl.Linear):
|
||||
adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
|
||||
adapter = LoraAdapter(target=linear, rank=64)
|
||||
adapter.inject(parent)
|
||||
|
||||
# ... and load existing weights if the LoRAs are pretrained ...
|
||||
|
|
|
@ -33,8 +33,6 @@ class LoraConfig(BaseModel):
|
|||
adapter = LoraAdapter(
|
||||
target=linear,
|
||||
rank=self.rank,
|
||||
device=module.device,
|
||||
dtype=module.dtype,
|
||||
)
|
||||
adapter.inject(parent)
|
||||
for linear in adapter.Lora.layers(fl.Linear):
|
||||
|
|
|
@ -49,8 +49,6 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
|||
target: fl.Linear,
|
||||
rank: int = 16,
|
||||
scale: float = 1.0,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.in_features = target.in_features
|
||||
self.out_features = target.out_features
|
||||
|
@ -63,8 +61,8 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
|||
in_features=target.in_features,
|
||||
out_features=target.out_features,
|
||||
rank=rank,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
device=target.device,
|
||||
dtype=target.dtype,
|
||||
),
|
||||
)
|
||||
self.Lora.set_scale(scale=scale)
|
||||
|
|
|
@ -49,8 +49,6 @@ def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale
|
|||
target=linear,
|
||||
rank=rank,
|
||||
scale=scale,
|
||||
device=module.device,
|
||||
dtype=module.dtype,
|
||||
)
|
||||
adapter.inject(parent)
|
||||
|
||||
|
|
Loading…
Reference in a new issue