refiners/tests/adapters/test_range_adapter.py
Pierre Chapuis 0f476ea18b make high-level adapters Adapters
This generalizes the Adapter abstraction to higher-level
constructs such as high-level LoRA (targeting, e.g., the
SD UNet), ControlNet and Reference-Only Control.

Some adapters now work by adapting child models with
"sub-adapters" that they inject / eject when needed.
2023-08-31 10:57:18 +02:00

25 lines
778 B
Python

import torch
from refiners.adapters.adapter import Adapter
from refiners.adapters.range_adapter import RangeEncoder
from refiners.fluxion.layers import Chain, Linear
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Minimal adapter that wraps a single ``Linear`` layer in a ``Chain``.

    It adds no computation of its own. It exists so the tests can exercise
    ``Adapter.inject`` against a real ``Linear`` target.
    """

    def __init__(self, target: Linear) -> None:
        # The parent constructor runs inside ``setup_adapter(target)``.
        # Presumably this lets ``Adapter`` record ``target`` as the adaptee
        # while the chain is built — confirm against ``Adapter.setup_adapter``.
        with self.setup_adapter(target):
            super().__init__(target)
def test_range_encoder_dtype_after_adaptation(test_device: torch.device):  # FG-433
    """Regression test for FG-433: the output dtype survives adapter injection.

    The test builds a float64 ``RangeEncoder`` inside a ``Chain`` and wraps the
    encoder's ``Linear_1`` layer in a ``DummyLinearAdapter`` injected into the
    encoder. It then checks two things:

    - the adapter's parent is the encoder;
    - a forward pass still yields float64 output.
    """
    expected_dtype = torch.float64
    chain = Chain(
        RangeEncoder(320, 1280, device=test_device, dtype=expected_dtype)
    )
    encoder = chain.RangeEncoder

    # Wrap the encoder's first linear layer and inject the wrapper into the
    # encoder (expected to take Linear_1's place there).
    adapter = DummyLinearAdapter(encoder.Linear_1).inject(encoder)
    assert adapter.parent == encoder

    # The input is an integer tensor. The output dtype must come from the
    # encoder's dtype, not the input's.
    inputs = torch.tensor([42], device=test_device)
    assert chain(inputs).dtype == expected_dtype