Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
commit 71ddb55a8e
parent 8c7298f8cc

    infer device and dtype in LoraAdapter
@@ -183,7 +183,7 @@ from refiners.adapters.lora import LoraAdapter
 for layer in vit.layers(fl.Attention):
     for linear, parent in layer.walk(fl.Linear):
-        adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
+        adapter = LoraAdapter(target=linear, rank=64)
         adapter.inject(parent)

 # ... and load existing weights if the LoRAs are pretrained ...
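For context on why the explicit arguments can be dropped at the call sites: the target linear already carries a device and dtype of its own, so the adapter can read them from the layer it wraps. A minimal plain-PyTorch illustration of that fact (not refiners code; the feature size and float16 choice are arbitrary):

import torch
from torch import nn

# Any linear layer already knows where its parameters live and in which precision,
# so a wrapper built around it does not need explicit device/dtype arguments.
linear = nn.Linear(768, 768, dtype=torch.float16)

inferred_device = linear.weight.device  # cpu here; cuda:0 if the model was moved there
inferred_dtype = linear.weight.dtype    # torch.float16
print(inferred_device, inferred_dtype)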
@@ -33,8 +33,6 @@ class LoraConfig(BaseModel):
             adapter = LoraAdapter(
                 target=linear,
                 rank=self.rank,
-                device=module.device,
-                dtype=module.dtype,
             )
             adapter.inject(parent)
             for linear in adapter.Lora.layers(fl.Linear):
@@ -49,8 +49,6 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
         target: fl.Linear,
         rank: int = 16,
         scale: float = 1.0,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
     ) -> None:
         self.in_features = target.in_features
         self.out_features = target.out_features
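The hunk below shows where the removed parameters now come from: the inner Lora block is built with target.device and target.dtype. As a rough sketch of the same pattern outside the library (a hypothetical TinyLoraAdapter in plain PyTorch, not the refiners implementation), a constructor can derive both values from the wrapped linear's weight:

import torch
from torch import nn

class TinyLoraAdapter(nn.Module):
    # Hypothetical sketch of the pattern applied by this commit: only the target and
    # the LoRA hyperparameters are accepted; device and dtype are read off the target.
    def __init__(self, target: nn.Linear, rank: int = 16, scale: float = 1.0) -> None:
        super().__init__()
        self.target = target
        self.scale = scale
        device, dtype = target.weight.device, target.weight.dtype
        # The low-rank factors are created directly on the target's device and dtype.
        self.down = nn.Linear(target.in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, target.out_features, bias=False, device=device, dtype=dtype)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base path plus scaled low-rank path, mirroring the fl.Sum structure
        # of LoraAdapter (target output + LoRA output).
        return self.target(x) + self.scale * self.up(self.down(x))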
@@ -63,8 +61,8 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
                 in_features=target.in_features,
                 out_features=target.out_features,
                 rank=rank,
-                device=device,
-                dtype=dtype,
+                device=target.device,
+                dtype=target.dtype,
             ),
         )
         self.Lora.set_scale(scale=scale)
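Continuing the hypothetical sketch above, a quick check that the inference behaves as intended: the low-rank weights land on the same device and precision as the wrapped linear, with no extra constructor arguments (sizes are arbitrary):

base = nn.Linear(320, 1280, dtype=torch.float16)
adapter = TinyLoraAdapter(base, rank=8, scale=1.0)

assert adapter.down.weight.dtype == base.weight.dtype == torch.float16
assert adapter.down.weight.device == base.weight.device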
@@ -49,8 +49,6 @@ def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale
                 target=linear,
                 rank=rank,
                 scale=scale,
-                device=module.device,
-                dtype=module.dtype,
             )
             adapter.inject(parent)