make basic adapters a part of Fluxion

This commit is contained in:
Pierre Chapuis 2023-09-01 16:50:41 +02:00
parent 31785f2059
commit d389d11a06
15 changed files with 14 additions and 14 deletions

View file

@ -179,7 +179,7 @@ The `Adapter` API lets you **easily patch models** by injecting parameters in ta
E.g. to inject LoRA layers in all attention's linear layers: E.g. to inject LoRA layers in all attention's linear layers:
```python ```python
from refiners.adapters.lora import SingleLoraAdapter from refiners.fluxion.adapters.lora import SingleLoraAdapter
for layer in vit.layers(fl.Attention): for layer in vit.layers(fl.Attention):
for linear, parent in layer.walk(fl.Linear): for linear, parent in layer.walk(fl.Linear):

View file

@ -12,7 +12,7 @@ from diffusers import DiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors from refiners.fluxion.utils import save_to_safetensors
from refiners.adapters.lora import Lora, LoraAdapter from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

View file

@ -1,7 +1,7 @@
from typing import Iterable, Generic, TypeVar, Any from typing import Iterable, Generic, TypeVar, Any
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter from torch.nn import Parameter as TorchParameter

View file

@ -1,4 +1,4 @@
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer from refiners.foundationals.clip.tokenizer import CLIPTokenizer
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl

View file

@ -7,8 +7,8 @@ from torch import Tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.adapters.lora import SingleLoraAdapter, LoraAdapter from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d

View file

@ -2,7 +2,7 @@ import math
from torch import Tensor, arange, float32, exp, sin, cat, cos, device as Device, dtype as DType from torch import Tensor, arange, float32, exp, sin, cat, cos, device as Device, dtype as DType
from jaxtyping import Float, Int from jaxtyping import Float, Int
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl

View file

@ -9,7 +9,7 @@ from refiners.fluxion.layers import (
Identity, Identity,
Parallel, Parallel,
) )
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from torch import Tensor from torch import Tensor

View file

@ -7,7 +7,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
ResidualBlock, ResidualBlock,
TimestepEncoder, TimestepEncoder,
) )
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d
from typing import cast, Iterable from typing import cast, Iterable
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType

View file

@ -1,6 +1,6 @@
from typing import cast from typing import cast
from torch import device as Device, dtype as DType, Tensor, cat from torch import device as Device, dtype as DType, Tensor, cat
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL

View file

@ -5,7 +5,7 @@ from torch.nn import Dropout as TorchDropout
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.training_utils.callback import Callback from refiners.training_utils.callback import Callback
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
if TYPE_CHECKING: if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig from refiners.training_utils.config import BaseConfig

View file

@ -1,5 +1,5 @@
import pytest import pytest
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.layers import Chain, Linear from refiners.fluxion.layers import Chain, Linear

View file

@ -1,4 +1,4 @@
from refiners.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter from refiners.fluxion.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter
from torch import randn, allclose from torch import randn, allclose
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl

View file

@ -1,5 +1,5 @@
import torch import torch
from refiners.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
from refiners.fluxion.layers import Chain, Linear from refiners.fluxion.layers import Chain, Linear