Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-24 23:28:45 +00:00)
unet: get rid of clip_embedding attribute for SD1
It is implicitly defined by the underlying cross-attention layer. This also makes it consistent with SDXL.
parent 134ee7b754
commit b933fabf31
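
In practice the change is the same one-argument diff at every SD1UNet call site. A minimal sketch of the new constructor usage, assuming the import path below matches this repo's layout:

    from refiners.foundationals.latent_diffusion import SD1UNet  # assumed import path

    # Before this commit, callers passed the CLIP embedding dimension explicitly:
    #     unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
    # After it, the dimension (768 for SD1) is fixed by the cross-attention
    # layers built inside the UNet, which matches the SDXLUNet constructor.
    unet = SD1UNet(in_channels=4)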
@@ -21,7 +21,7 @@ class Args(argparse.Namespace):
 @torch.no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path)  # type: ignore
-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4)
     adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
     controlnet = unet.Controlnet

@@ -44,7 +44,7 @@ def process(args: Args) -> None:
     diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model)  # type: ignore
     diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore

-    refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    refiners_model = SD1UNet(in_channels=4)
     target = LoraTarget.CrossAttention
     metadata = {"unet_targets": "CrossAttentionBlock2d"}
     rank = diffusers_state_dict[

@@ -22,9 +22,7 @@ def setup_converter(args: Args) -> ModelConverter:
     source_clip_embedding_dim: int = source.config.cross_attention_dim  # type: ignore
     source_has_time_ids: bool = source.config.addition_embed_type == "text_time"  # type: ignore
     target = (
-        SDXLUNet(in_channels=source_in_channels)
-        if source_has_time_ids
-        else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
+        SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
     )

     x = torch.randn(1, source_in_channels, 32, 32)

@@ -62,7 +62,7 @@ def main() -> None:
     for meta_key, meta_value in metadata.items():
         match meta_key:
             case "unet_targets":
-                model = SD1UNet(in_channels=4, clip_embedding_dim=768)
+                model = SD1UNet(in_channels=4)
                 create_mapping = partial(get_unet_mapping, source_path=args.sd15)
                 key_prefix = "unet."
                 lora_prefix = "lora_unet_"

@@ -24,7 +24,7 @@ class StableDiffusion_1(LatentDiffusionModel):
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
-        unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
+        unet = unet or SD1UNet(in_channels=4)
         lda = lda or LatentDiffusionAutoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
         scheduler = scheduler or DPMSolver(num_inference_steps=30)

@@ -243,17 +243,10 @@ class ResidualConcatenator(fl.Chain):


 class SD1UNet(fl.Chain):
-    structural_attrs = ["in_channels", "clip_embedding_dim"]
+    structural_attrs = ["in_channels"]

-    def __init__(
-        self,
-        in_channels: int,
-        clip_embedding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ):
+    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
-        self.clip_embedding_dim = clip_embedding_dim
         super().__init__(
             TimestepEncoder(device=device, dtype=dtype),
             DownBlocks(in_channels=in_channels, device=device, dtype=dtype),

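Downstream code that previously read the removed unet.clip_embedding_dim attribute can still recover the width from the cross-attention layers themselves. A minimal sketch, assuming the Chain.walk API exercised in the tests below; the context_embedding_dim attribute name is an assumption, not something this commit defines:

    from refiners.foundationals.latent_diffusion import SD1UNet  # assumed import path
    from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock

    unet = SD1UNet(in_channels=4)
    # walk(CrossAttentionBlock) yields matching submodules (as in
    # test_sai_inject_eject below), assumed here as (module, parent) pairs.
    block, _parent = next(iter(unet.walk(CrossAttentionBlock)))
    # Hypothetical attribute; fall back to SD1's known width (768) if absent.
    clip_embedding_dim = getattr(block, "context_embedding_dim", 768)
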
@@ -128,7 +128,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
     @cached_property
     def unet(self) -> SD1UNet:
         assert self.config.models["unet"] is not None, "The config must contain a unet entry."
-        return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
+        return SD1UNet(in_channels=4, device=self.device).to(device=self.device)

     @cached_property
     def text_encoder(self) -> CLIPTextEncoderL:

@@ -224,7 +224,7 @@ def sd15_inpainting(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))

@@ -242,7 +242,7 @@ def sd15_inpainting_float16(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))

@@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
 @pytest.fixture(scope="module", params=[True, False])
 def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
     with_parent: bool = request.param
-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     if with_parent:
         fl.Chain(unet)
     yield unet

@@ -14,7 +14,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti

 @torch.no_grad()
 def test_sai_inject_eject() -> None:
-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sai = ReferenceOnlyControlAdapter(unet)

     nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))

@@ -9,7 +9,7 @@ def test_unet_context_flush():
     timestep = torch.randint(0, 999, size=(1, 1))
     x = torch.randn(1, 4, 32, 32)

-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s

     with torch.no_grad():