Make the first diffusion step a first-class property of LDM & Schedulers

This commit is contained in:
Pierre Chapuis 2024-01-19 10:55:04 +01:00
parent 42b7749630
commit 8a36c8c279
9 changed files with 68 additions and 77 deletions

View file

@ -117,10 +117,9 @@ t2i_adapter = SDXLT2IAdapter(
# Tune parameters
seed = 9752
first_step = 1
ip_adapter.set_scale(0.85)
t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_inference_steps(50, first_step=1)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
with no_grad():
@ -136,11 +135,11 @@ with no_grad():
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
manual_seed(seed=seed)
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(
device=sdxl.device, dtype=sdxl.dtype
)
for step in sdxl.steps[first_step:]:
for step in sdxl.steps:
x = sdxl(
x,
step=step,

View file

@ -32,21 +32,21 @@ class LatentDiffusionModel(fl.Module, ABC):
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int) -> None:
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps,
num_inference_steps=num_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_step,
).to(device=device, dtype=dtype)
def init_latents(
self,
size: tuple[int, int],
init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None,
) -> Tensor:
height, width = size
@ -59,11 +59,15 @@ class LatentDiffusionModel(fl.Module, ABC):
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
return self.scheduler.add_noise(
x=encoded_image,
noise=noise,
step=self.scheduler.first_inference_step,
)
@property
def steps(self) -> list[int]:
return self.scheduler.steps
return self.scheduler.inference_steps
@abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:

View file

@ -11,6 +11,7 @@ class DDIM(Scheduler):
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
@ -20,6 +21,7 @@ class DDIM(Scheduler):
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device,
dtype=dtype,
)
@ -35,6 +37,8 @@ class DDIM(Scheduler):
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
# Guard: only steps in [first_inference_step, num_inference_steps) are valid.
# The message must be an f-string, otherwise "{step}" is printed literally.
assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
timestep, previous_timestep = (
self.timesteps[step],
(

View file

@ -15,6 +15,7 @@ class DDPM(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
first_inference_step: int = 0,
device: Device | str = "cpu",
) -> None:
super().__init__(
@ -22,6 +23,7 @@ class DDPM(Scheduler):
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_inference_step,
device=device,
)

View file

@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
@ -33,6 +34,7 @@ class DPMSolver(Scheduler):
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device,
dtype=dtype,
)
@ -100,12 +102,14 @@ class DPMSolver(Scheduler):
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs).
"""
# Guard: only steps in [first_inference_step, num_inference_steps) are valid.
# The message must be an f-string, otherwise "{step}" is printed literally.
assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step)

View file

@ -13,6 +13,7 @@ class EulerScheduler(Scheduler):
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
@ -24,6 +25,7 @@ class EulerScheduler(Scheduler):
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device,
dtype=dtype,
)
@ -64,6 +66,8 @@ class EulerScheduler(Scheduler):
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
# Guard: only steps in [first_inference_step, num_inference_steps) are valid.
# The message must be an f-string, otherwise "{step}" is printed literally.
assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

View file

@ -33,6 +33,7 @@ class Scheduler(ABC):
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: DType = float32,
):
@ -43,6 +44,7 @@ class Scheduler(ABC):
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
self.first_inference_step = first_inference_step
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
@ -68,9 +70,13 @@ class Scheduler(ABC):
...
@property
def steps(self) -> list[int]:
def all_steps(self) -> list[int]:
return list(range(self.num_inference_steps))
# Steps that are actually iterated at inference time: `all_steps` sliced so that
# everything before `first_inference_step` is dropped.
@property
def inference_steps(self) -> list[int]:
return self.all_steps[self.first_inference_step :]
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.

View file

@ -594,13 +594,12 @@ def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@ -624,13 +623,12 @@ def test_diffusion_std_random_init_euler(
sd15 = sd15_euler
euler_scheduler = sd15_euler.scheduler
assert isinstance(euler_scheduler, EulerScheduler)
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@ -678,14 +676,13 @@ def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_float16
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@ -707,13 +704,12 @@ def test_diffusion_std_random_init_sag(
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
sd15 = sd15_std
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
sd15.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(2)
@ -738,19 +734,17 @@ def test_diffusion_std_init_image(
expected_image_std_init_image: Image.Image,
):
sd15 = sd15_std
n_steps = 35
first_step = 5
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(35, first_step=5)
manual_seed(2)
x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
x = sd15.init_latents((512, 512), cutecat_init)
for step in sd15.steps[first_step:]:
for step in sd15.steps:
x = sd15(
x,
step=step,
@ -786,13 +780,12 @@ def test_diffusion_inpainting(
test_device: torch.device,
):
sd15 = sd15_inpainting
n_steps = 30
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
@ -820,14 +813,13 @@ def test_diffusion_inpainting_float16(
test_device: torch.device,
):
sd15 = sd15_inpainting_float16
n_steps = 30
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
manual_seed(2)
@ -853,7 +845,6 @@ def test_diffusion_controlnet(
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
@ -865,7 +856,7 @@ def test_diffusion_controlnet(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@ -897,7 +888,6 @@ def test_diffusion_controlnet_structural_copy(
):
sd15_base = sd15_std
sd15 = sd15_base.structural_copy()
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
@ -909,7 +899,7 @@ def test_diffusion_controlnet_structural_copy(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@ -940,7 +930,6 @@ def test_diffusion_controlnet_float16(
test_device: torch.device,
):
sd15 = sd15_std_float16
n_steps = 30
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
@ -952,7 +941,7 @@ def test_diffusion_controlnet_float16(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@ -985,7 +974,6 @@ def test_diffusion_controlnet_stack(
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
_, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
@ -1002,7 +990,7 @@ def test_diffusion_controlnet_stack(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
depth_controlnet = SD1ControlnetAdapter(
sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
@ -1038,14 +1026,13 @@ def test_diffusion_lora(
test_device: torch.device,
) -> None:
sd15 = sd15_std
n_steps = 30
expected_image, lora_weights = lora_data_pokemon
prompt = "a cute cat"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
SDLoraManager(sd15).load(lora_weights, scale=1)
@ -1074,7 +1061,6 @@ def test_diffusion_sdxl_lora(
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
# except that we are using DDIM instead of sde-dpmsolver++
n_steps = 40
seed = 12341234123
guidance_scale = 7.5
lora_scale = 1.4
@ -1088,7 +1074,7 @@ def test_diffusion_sdxl_lora(
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
sdxl.set_inference_steps(40)
manual_seed(seed=seed)
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
@ -1155,14 +1141,13 @@ def test_diffusion_inpainting_refonly(
test_device: torch.device,
):
sd15 = sd15_inpainting
n_steps = 30
prompt = "" # unconditional
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
@ -1203,12 +1188,10 @@ def test_diffusion_textual_inversion_random_init(
conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
conceptExtender.inject()
n_steps = 30
prompt = "a cute cat on a <gta5-artwork>"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@ -1235,7 +1218,6 @@ def test_diffusion_ip_adapter(
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
n_steps = 50
# See tencent-ailab/IP-Adapter best practices section:
#
@ -1254,7 +1236,7 @@ def test_diffusion_ip_adapter(
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_num_inference_steps(n_steps)
# Keep the 50 steps this test used before the refactor (formerly `n_steps = 50`).
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@ -1281,7 +1263,6 @@ def test_diffusion_sdxl_ip_adapter(
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
n_steps = 30
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@ -1298,7 +1279,7 @@ def test_diffusion_sdxl_ip_adapter(
ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
sdxl.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
@ -1332,7 +1313,6 @@ def test_diffusion_ip_adapter_controlnet(
test_device: torch.device,
):
sd15 = sd15_ddim.to(dtype=torch.float16)
n_steps = 50
input_image, _ = lora_data_pokemon # use the Pokemon LoRA output as input
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
@ -1360,7 +1340,7 @@ def test_diffusion_ip_adapter_controlnet(
dtype=torch.float16,
)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@ -1388,7 +1368,6 @@ def test_diffusion_ip_adapter_plus(
test_device: torch.device,
):
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
n_steps = 50
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@ -1403,7 +1382,7 @@ def test_diffusion_ip_adapter_plus(
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)
sd15.set_num_inference_steps(n_steps)
# Keep the 50 steps this test used before the refactor (formerly `n_steps = 50`).
sd15.set_inference_steps(50)
manual_seed(42) # seed=42 is used in the official IP-Adapter demo
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@ -1430,7 +1409,6 @@ def test_diffusion_sdxl_ip_adapter_plus(
test_device: torch.device,
):
sdxl = sdxl_ddim.to(dtype=torch.float16)
n_steps = 30
prompt = "best quality, high quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@ -1448,7 +1426,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
sdxl.set_inference_steps(30)
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
@ -1474,7 +1452,6 @@ def test_sdxl_random_init(
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_ddim_random_init
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@ -1484,7 +1461,7 @@ def test_sdxl_random_init(
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
sdxl.set_inference_steps(30)
manual_seed(seed=2)
x = torch.randn(1, 4, 128, 128, device=test_device)
@ -1509,7 +1486,6 @@ def test_sdxl_random_init_sag(
) -> None:
sdxl = sdxl_ddim
expected_image = expected_sdxl_ddim_random_init_sag
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@ -1519,7 +1495,7 @@ def test_sdxl_random_init_sag(
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
sdxl.set_inference_steps(30)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(seed=2)
@ -1577,7 +1553,6 @@ def test_t2i_adapter_depth(
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
@ -1589,7 +1564,7 @@ def test_t2i_adapter_depth(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
@ -1618,7 +1593,6 @@ def test_t2i_adapter_xl_canny(
test_device: torch.device,
):
sdxl = sdxl_ddim
n_steps = 30
name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
@ -1635,7 +1609,7 @@ def test_t2i_adapter_xl_canny(
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)
sdxl.set_inference_steps(30)
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
t2i_adapter.set_scale(0.8)
@ -1667,14 +1641,13 @@ def test_restart(
test_device: torch.device,
):
sd15 = sd15_ddim
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(30)
restart = Restart(ldm=sd15)
manual_seed(2)
@ -1706,23 +1679,21 @@ def test_freeu(
expected_freeu: Image.Image,
):
sd15 = sd15_std
n_steps = 50
first_step = 1
prompt = "best quality, high quality cute cat"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inference_steps(50, first_step=1)
SDFreeUAdapter(
sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
).inject()
manual_seed(9752)
x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype)
x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)
for step in sd15.steps[first_step:]:
for step in sd15.steps:
x = sd15(
x,
step=step,
@ -1770,17 +1741,14 @@ def test_hello_world(
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
first_step = 1
ip_adapter.set_scale(0.85)
t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_inference_steps(50, first_step=1)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
manual_seed(9752)
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
device=sdxl.device, dtype=sdxl.dtype
)
for step in sdxl.steps[first_step:]:
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)
for step in sdxl.steps:
x = sdxl(
x,
step=step,

View file

@ -133,5 +133,5 @@ def test_scheduler_device(test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device)
x = randn(1, 4, 32, 32, device=test_device)
noise = randn(1, 4, 32, 32, device=test_device)
noised = scheduler.add_noise(x, noise, scheduler.steps[0])
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
assert noised.device == test_device