mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
471ef91d1c
PyTorch chose to make it `Any` because they expect its users' code to be "highly dynamic" (https://github.com/pytorch/pytorch/pull/104321). This is not the case for us: in Refiners, untyped code goes against one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor`, but in practice this is not always correct either (https://github.com/pytorch/pytorch/pull/115074). I also moved Residuals-related code from SD1 to latent_diffusion because SDXL should not depend on SD1.
27 lines
878 B
Python
27 lines
878 B
Python
import torch
|
|
|
|
from refiners.fluxion.adapters.adapter import Adapter
|
|
from refiners.fluxion.layers import Chain, Linear
|
|
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
|
|
|
|
|
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Minimal test adapter targeting a single `Linear` layer.

    The adapter is a `Chain` built with the adapted `Linear` as its only
    constructor argument; it defines nothing beyond this constructor, so any
    behavior difference observed in tests comes from the injection mechanics
    rather than from the adapter itself.
    """

    def __init__(self, target: Linear) -> None:
        # Chain initialization runs inside the Adapter mixin's `setup_adapter`
        # context, mirroring the usual Refiners adapter pattern (presumably so
        # the target is recorded before children are registered — confirm in
        # `Adapter.setup_adapter`).
        adapter_setup = self.setup_adapter(target)
        with adapter_setup:
            super().__init__(target)
|
|
|
|
|
|
def test_range_encoder_dtype_after_adaptation(test_device: torch.device):  # FG-433
    """Regression test for FG-433.

    Build a float64 `RangeEncoder` inside a `Chain`, inject a dummy adapter
    over the encoder's `Linear_1` layer, then check that the adapter's parent
    is the encoder and that the chain's output is still float64 even though
    the input tensor is created without an explicit dtype.
    """
    expected_dtype = torch.float64
    model = Chain(RangeEncoder(320, 1280, device=test_device, dtype=expected_dtype))

    # Typed lookup of the encoder instead of relying on untyped attribute access.
    encoder = model.layer("RangeEncoder", RangeEncoder)

    # Adapt the encoder's `Linear_1` layer and inject the adapter into the encoder.
    adapter = DummyLinearAdapter(encoder.layer("Linear_1", Linear)).inject(encoder)
    assert adapter.parent == model.RangeEncoder

    # The input carries no explicit dtype; the output must still match the
    # dtype the encoder was constructed with.
    output = model(torch.tensor([42], device=test_device))
    assert output.dtype == expected_dtype
|