mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
unet: get rid of clip_embedding attribute for SD1
It is implicitly defined by the underlying cross-attention layer. This also makes it consistent with SDXL.
This commit is contained in:
parent
134ee7b754
commit
b933fabf31
|
@ -21,7 +21,7 @@ class Args(argparse.Namespace):
|
|||
@torch.no_grad()
|
||||
def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
|
||||
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=4)
|
||||
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
|
||||
controlnet = unet.Controlnet
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ def process(args: Args) -> None:
|
|||
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
|
||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
||||
|
||||
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
refiners_model = SD1UNet(in_channels=4)
|
||||
target = LoraTarget.CrossAttention
|
||||
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
||||
rank = diffusers_state_dict[
|
||||
|
|
|
@ -22,9 +22,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
||||
target = (
|
||||
SDXLUNet(in_channels=source_in_channels)
|
||||
if source_has_time_ids
|
||||
else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
|
||||
SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
|
||||
)
|
||||
|
||||
x = torch.randn(1, source_in_channels, 32, 32)
|
||||
|
|
|
@ -62,7 +62,7 @@ def main() -> None:
|
|||
for meta_key, meta_value in metadata.items():
|
||||
match meta_key:
|
||||
case "unet_targets":
|
||||
model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
model = SD1UNet(in_channels=4)
|
||||
create_mapping = partial(get_unet_mapping, source_path=args.sd15)
|
||||
key_prefix = "unet."
|
||||
lora_prefix = "lora_unet_"
|
||||
|
|
|
@ -24,7 +24,7 @@ class StableDiffusion_1(LatentDiffusionModel):
|
|||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
unet = unet or SD1UNet(in_channels=4)
|
||||
lda = lda or LatentDiffusionAutoencoder()
|
||||
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
||||
|
|
|
@ -243,17 +243,10 @@ class ResidualConcatenator(fl.Chain):
|
|||
|
||||
|
||||
class SD1UNet(fl.Chain):
|
||||
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
||||
structural_attrs = ["in_channels"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
clip_embedding_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
):
|
||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
self.in_channels = in_channels
|
||||
self.clip_embedding_dim = clip_embedding_dim
|
||||
super().__init__(
|
||||
TimestepEncoder(device=device, dtype=dtype),
|
||||
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
|
||||
|
|
|
@ -128,7 +128,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
@cached_property
|
||||
def unet(self) -> SD1UNet:
|
||||
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
|
||||
return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
|
||||
return SD1UNet(in_channels=4, device=self.device).to(device=self.device)
|
||||
|
||||
@cached_property
|
||||
def text_encoder(self) -> CLIPTextEncoderL:
|
||||
|
|
|
@ -224,7 +224,7 @@ def sd15_inpainting(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=9)
|
||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
|
@ -242,7 +242,7 @@ def sd15_inpainting_float16(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=9)
|
||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
|
|
|
@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
|
|||
@pytest.fixture(scope="module", params=[True, False])
|
||||
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
|
||||
with_parent: bool = request.param
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=9)
|
||||
if with_parent:
|
||||
fl.Chain(unet)
|
||||
yield unet
|
||||
|
|
|
@ -14,7 +14,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
|
|||
|
||||
@torch.no_grad()
|
||||
def test_sai_inject_eject() -> None:
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=9)
|
||||
sai = ReferenceOnlyControlAdapter(unet)
|
||||
|
||||
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
|
||||
|
|
|
@ -9,7 +9,7 @@ def test_unet_context_flush():
|
|||
timestep = torch.randint(0, 999, size=(1, 1))
|
||||
x = torch.randn(1, 4, 32, 32)
|
||||
|
||||
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
unet = SD1UNet(in_channels=4)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
Loading…
Reference in a new issue