mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
0f476ea18b
This generalizes the Adapter abstraction to higher-level constructs such as high-level LoRA (targeting, e.g., the SD UNet), ControlNet, and Reference-Only Control. Some adapters now work by adapting child models with "sub-adapters" that they inject and eject when needed.
25 lines
778 B
Python
25 lines
778 B
Python
import torch
|
|
from refiners.adapters.adapter import Adapter
|
|
from refiners.adapters.range_adapter import RangeEncoder
|
|
from refiners.fluxion.layers import Chain, Linear
|
|
|
|
|
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Do-nothing adapter whose only child is the ``Linear`` layer it targets.

    The ``Chain`` is built with ``target`` as its single child and nothing
    else, so the adapter contributes no layers of its own. It exists so that
    tests can exercise adapter injection without any extra computation.
    """

    def __init__(self, target: Linear) -> None:
        # The Chain constructor has to run while the adapter setup context is
        # active. NOTE(review): presumably this lets Adapter record the target
        # before Chain registers it as a child. Confirm against
        # Adapter.setup_adapter.
        adapter_context = self.setup_adapter(target)
        with adapter_context:
            super().__init__(target)
|
|
|
|
|
|
def test_range_encoder_dtype_after_adaptation(test_device: torch.device) -> None:  # FG-433
    """Check that a float64 RangeEncoder still yields float64 after adaptation.

    Regression test for FG-433. A ``DummyLinearAdapter`` wraps the encoder's
    ``Linear_1`` child and is injected back into the encoder. The test then
    asserts that the adapter's parent is the encoder, runs the enclosing chain
    on an integer tensor, and compares the output dtype with the encoder's
    dtype.
    """
    expected_dtype = torch.float64
    model = Chain(RangeEncoder(320, 1280, device=test_device, dtype=expected_dtype))
    encoder = model.RangeEncoder

    # Wrap the first Linear of the encoder, then put the wrapper in its place.
    wrapped_linear = encoder.Linear_1
    injected = DummyLinearAdapter(wrapped_linear).inject(encoder)
    assert injected.parent == encoder

    # torch.tensor([42]) infers an integer dtype. The chain's output must still
    # come out in the encoder's float64 dtype.
    output = model(torch.tensor([42], device=test_device))
    assert output.dtype == expected_dtype
|