unet: get rid of clip_embedding attribute for SD1

It is implicitly defined by the underlying cross-attention layer. This
also makes it consistent with SDXL.
This commit is contained in:
Cédric Deltheil 2023-08-31 17:22:57 +02:00 committed by Cédric Deltheil
parent 134ee7b754
commit b933fabf31
11 changed files with 13 additions and 22 deletions

View file

@ -21,7 +21,7 @@ class Args(argparse.Namespace):
@torch.no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = unet.Controlnet

View file

@ -44,7 +44,7 @@ def process(args: Args) -> None:
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
refiners_model = SD1UNet(in_channels=4)
target = LoraTarget.CrossAttention
metadata = {"unet_targets": "CrossAttentionBlock2d"}
rank = diffusers_state_dict[

View file

@ -22,9 +22,7 @@ def setup_converter(args: Args) -> ModelConverter:
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
target = (
SDXLUNet(in_channels=source_in_channels)
if source_has_time_ids
else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
)
x = torch.randn(1, source_in_channels, 32, 32)

View file

@ -62,7 +62,7 @@ def main() -> None:
for meta_key, meta_value in metadata.items():
match meta_key:
case "unet_targets":
model = SD1UNet(in_channels=4, clip_embedding_dim=768)
model = SD1UNet(in_channels=4)
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
key_prefix = "unet."
lora_prefix = "lora_unet_"

View file

@ -24,7 +24,7 @@ class StableDiffusion_1(LatentDiffusionModel):
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
unet = unet or SD1UNet(in_channels=4)
lda = lda or LatentDiffusionAutoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)

View file

@ -243,17 +243,10 @@ class ResidualConcatenator(fl.Chain):
class SD1UNet(fl.Chain):
structural_attrs = ["in_channels", "clip_embedding_dim"]
structural_attrs = ["in_channels"]
def __init__(
self,
in_channels: int,
clip_embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
self.clip_embedding_dim = clip_embedding_dim
super().__init__(
TimestepEncoder(device=device, dtype=dtype),
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),

View file

@ -128,7 +128,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
@cached_property
def unet(self) -> SD1UNet:
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
return SD1UNet(in_channels=4, device=self.device).to(device=self.device)
@cached_property
def text_encoder(self) -> CLIPTextEncoderL:

View file

@ -224,7 +224,7 @@ def sd15_inpainting(
warn("not running on CPU, skipping")
pytest.skip()
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
@ -242,7 +242,7 @@ def sd15_inpainting_float16(
warn("not running on CPU, skipping")
pytest.skip()
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))

View file

@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
with_parent: bool = request.param
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9)
if with_parent:
fl.Chain(unet)
yield unet

View file

@ -14,7 +14,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
@torch.no_grad()
def test_sai_inject_eject() -> None:
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9)
sai = ReferenceOnlyControlAdapter(unet)
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))

View file

@ -9,7 +9,7 @@ def test_unet_context_flush():
timestep = torch.randint(0, 999, size=(1, 1))
x = torch.randn(1, 4, 32, 32)
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with torch.no_grad():