make the first diffusion step a first-class property of LDM & Schedulers

Pierre Chapuis 2024-01-19 10:55:04 +01:00
parent 42b7749630
commit 8a36c8c279
9 changed files with 68 additions and 77 deletions
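
In short: callers no longer track the first diffusion step themselves. `set_num_inference_steps(n)` becomes `set_inference_steps(n, first_step=...)`, `init_latents` loses its `first_step` argument, and the model's `steps` property now starts at the scheduler's `first_inference_step`. A minimal sketch of the new call pattern (not part of the diff; `sd`, `init_image` and `clip_text_embedding` are illustrative placeholders):

    # Hedged usage sketch of the API after this commit.
    # `sd` is any LatentDiffusionModel (SD 1.5 or SDXL); conditioning kwargs are illustrative.
    sd.set_inference_steps(35, first_step=5)  # was: sd.set_num_inference_steps(35) plus a manual first_step

    x = sd.init_latents(size=(512, 512), init_image=init_image)  # no first_step argument anymore
    for step in sd.steps:  # already starts at the scheduler's first_inference_step
        x = sd(x, step=step, clip_text_embedding=clip_text_embedding)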

View file

@@ -117,10 +117,9 @@ t2i_adapter = SDXLT2IAdapter(
 # Tune parameters
 seed = 9752
-first_step = 1
 ip_adapter.set_scale(0.85)
 t2i_adapter.set_scale(0.8)
-sdxl.set_num_inference_steps(50)
+sdxl.set_inference_steps(50, first_step=1)
 sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 with no_grad():
@@ -136,11 +135,11 @@ with no_grad():
     t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
     manual_seed(seed=seed)
-    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(
         device=sdxl.device, dtype=sdxl.dtype
     )
-    for step in sdxl.steps[first_step:]:
+    for step in sdxl.steps:
         x = sdxl(
             x,
             step=step,

View file

@@ -32,21 +32,21 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
         self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
-    def set_num_inference_steps(self, num_inference_steps: int) -> None:
+    def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
         initial_diffusion_rate = self.scheduler.initial_diffusion_rate
         final_diffusion_rate = self.scheduler.final_diffusion_rate
         device, dtype = self.scheduler.device, self.scheduler.dtype
         self.scheduler = self.scheduler.__class__(
-            num_inference_steps,
+            num_inference_steps=num_steps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            first_inference_step=first_step,
         ).to(device=device, dtype=dtype)
     def init_latents(
         self,
         size: tuple[int, int],
         init_image: Image.Image | None = None,
-        first_step: int = 0,
         noise: Tensor | None = None,
     ) -> Tensor:
         height, width = size
@@ -59,11 +59,15 @@ class LatentDiffusionModel(fl.Module, ABC):
         if init_image is None:
             return noise
         encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
-        return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
+        return self.scheduler.add_noise(
+            x=encoded_image,
+            noise=noise,
+            step=self.scheduler.first_inference_step,
+        )
     @property
     def steps(self) -> list[int]:
-        return self.scheduler.steps
+        return self.scheduler.inference_steps
     @abstractmethod
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:

View file

@@ -11,6 +11,7 @@ class DDIM(Scheduler):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:
@@ -20,6 +21,7 @@ class DDIM(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
+            first_inference_step=first_inference_step,
             device=device,
             dtype=dtype,
         )
@@ -35,6 +37,8 @@ class DDIM(Scheduler):
         return timesteps.flip(0)
     def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+        assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
         timestep, previous_timestep = (
             self.timesteps[step],
             (

View file

@@ -15,6 +15,7 @@ class DDPM(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
     ) -> None:
         super().__init__(
@@ -22,6 +23,7 @@ class DDPM(Scheduler):
             num_train_timesteps=num_train_timesteps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            first_inference_step=first_inference_step,
             device=device,
         )

View file

@@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
         final_diffusion_rate: float = 1.2e-2,
         last_step_first_order: bool = False,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
@@ -33,6 +34,7 @@ class DPMSolver(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
+            first_inference_step=first_inference_step,
             device=device,
             dtype=dtype,
         )
@@ -100,12 +102,14 @@ class DPMSolver(Scheduler):
         backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
         (ODEs).
         """
+        assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
         current_timestep = self.timesteps[step]
         scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
         estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
         self.estimated_data.append(estimated_denoised_data)
-        if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
+        if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
             return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
         return self.multistep_dpm_solver_second_order_update(x=x, step=step)

View file

@@ -13,6 +13,7 @@ class EulerScheduler(Scheduler):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
@@ -24,6 +25,7 @@ class EulerScheduler(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
+            first_inference_step=first_inference_step,
             device=device,
             dtype=dtype,
         )
@@ -64,6 +66,8 @@ class EulerScheduler(Scheduler):
         s_tmax: float = float("inf"),
         s_noise: float = 1.0,
     ) -> Tensor:
+        assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
         sigma = self.sigmas[step]
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

View file

@@ -33,6 +33,7 @@ class Scheduler(ABC):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: DType = float32,
     ):
@@ -43,6 +44,7 @@ class Scheduler(ABC):
         self.initial_diffusion_rate = initial_diffusion_rate
         self.final_diffusion_rate = final_diffusion_rate
         self.noise_schedule = noise_schedule
+        self.first_inference_step = first_inference_step
         self.scale_factors = self.sample_noise_schedule()
         self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
@@ -68,9 +70,13 @@ class Scheduler(ABC):
         ...
     @property
-    def steps(self) -> list[int]:
+    def all_steps(self) -> list[int]:
         return list(range(self.num_inference_steps))
+    @property
+    def inference_steps(self) -> list[int]:
+        return self.all_steps[self.first_inference_step :]
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
         """
         For compatibility with schedulers that need to scale the input according to the current timestep.
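
For reference, how the new step properties relate, shown as a small illustrative snippet (the import path is assumed, the values arbitrary):

    from refiners.foundationals.latent_diffusion.schedulers import DDIM  # import path assumed

    scheduler = DDIM(num_inference_steps=30, first_inference_step=5)
    assert scheduler.all_steps == list(range(30))            # the full schedule
    assert scheduler.inference_steps == list(range(5, 30))   # the steps actually denoised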

View file

@@ -594,13 +594,12 @@ def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -624,13 +623,12 @@ def test_diffusion_std_random_init_euler(
     sd15 = sd15_euler
     euler_scheduler = sd15_euler.scheduler
     assert isinstance(euler_scheduler, EulerScheduler)
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -678,14 +676,13 @@ def test_diffusion_std_random_init_float16(
     sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std_float16
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -707,13 +704,12 @@ def test_diffusion_std_random_init_sag(
     sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_self_attention_guidance(enable=True, scale=0.75)
     manual_seed(2)
@@ -738,19 +734,17 @@ def test_diffusion_std_init_image(
     expected_image_std_init_image: Image.Image,
 ):
     sd15 = sd15_std
-    n_steps = 35
-    first_step = 5
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(35, first_step=5)
     manual_seed(2)
-    x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
+    x = sd15.init_latents((512, 512), cutecat_init)
-    for step in sd15.steps[first_step:]:
+    for step in sd15.steps:
         x = sd15(
             x,
             step=step,
@@ -786,13 +780,12 @@ def test_diffusion_inpainting(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting
-    n_steps = 30
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
     manual_seed(2)
@@ -820,14 +813,13 @@ def test_diffusion_inpainting_float16(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting_float16
-    n_steps = 30
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
     manual_seed(2)
@@ -853,7 +845,6 @@ def test_diffusion_controlnet(
     test_device: torch.device,
 ):
     sd15 = sd15_std
-    n_steps = 30
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
@@ -865,7 +856,7 @@ def test_diffusion_controlnet(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -897,7 +888,6 @@ def test_diffusion_controlnet_structural_copy(
 ):
     sd15_base = sd15_std
     sd15 = sd15_base.structural_copy()
-    n_steps = 30
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
@@ -909,7 +899,7 @@ def test_diffusion_controlnet_structural_copy(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -940,7 +930,6 @@ def test_diffusion_controlnet_float16(
     test_device: torch.device,
 ):
     sd15 = sd15_std_float16
-    n_steps = 30
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
@@ -952,7 +941,7 @@ def test_diffusion_controlnet_float16(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -985,7 +974,6 @@ def test_diffusion_controlnet_stack(
     test_device: torch.device,
 ):
     sd15 = sd15_std
-    n_steps = 30
     _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
     _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
@@ -1002,7 +990,7 @@ def test_diffusion_controlnet_stack(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     depth_controlnet = SD1ControlnetAdapter(
         sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
@@ -1038,14 +1026,13 @@ def test_diffusion_lora(
     test_device: torch.device,
 ) -> None:
     sd15 = sd15_std
-    n_steps = 30
     expected_image, lora_weights = lora_data_pokemon
     prompt = "a cute cat"
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     SDLoraManager(sd15).load(lora_weights, scale=1)
@@ -1074,7 +1061,6 @@ def test_diffusion_sdxl_lora(
     # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
     # except that we are using DDIM instead of sde-dpmsolver++
-    n_steps = 40
     seed = 12341234123
     guidance_scale = 7.5
     lora_scale = 1.4
@@ -1088,7 +1074,7 @@ def test_diffusion_sdxl_lora(
     )
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(40)
     manual_seed(seed=seed)
     x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
@@ -1155,14 +1141,13 @@ def test_diffusion_inpainting_refonly(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting
-    n_steps = 30
     prompt = ""  # unconditional
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
     guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
@@ -1203,12 +1188,10 @@ def test_diffusion_textual_inversion_random_init(
     conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
     conceptExtender.inject()
-    n_steps = 30
     prompt = "a cute cat on a <gta5-artwork>"
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -1235,7 +1218,6 @@ def test_diffusion_ip_adapter(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
-    n_steps = 50
     # See tencent-ailab/IP-Adapter best practices section:
     #
@@ -1254,7 +1236,7 @@ def test_diffusion_ip_adapter(
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -1281,7 +1263,6 @@ def test_diffusion_sdxl_ip_adapter(
     test_device: torch.device,
 ):
     sdxl = sdxl_ddim.to(dtype=torch.float16)
-    n_steps = 30
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@@ -1298,7 +1279,7 @@ def test_diffusion_sdxl_ip_adapter(
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
@@ -1332,7 +1313,6 @@ def test_diffusion_ip_adapter_controlnet(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim.to(dtype=torch.float16)
-    n_steps = 50
     input_image, _ = lora_data_pokemon  # use the Pokemon LoRA output as input
     _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
@@ -1360,7 +1340,7 @@ def test_diffusion_ip_adapter_controlnet(
         dtype=torch.float16,
     )
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(50)
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -1388,7 +1368,6 @@ def test_diffusion_ip_adapter_plus(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
-    n_steps = 50
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@@ -1403,7 +1382,7 @@ def test_diffusion_ip_adapter_plus(
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     manual_seed(42)  # seed=42 is used in the official IP-Adapter demo
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -1430,7 +1409,6 @@ def test_diffusion_sdxl_ip_adapter_plus(
     test_device: torch.device,
 ):
     sdxl = sdxl_ddim.to(dtype=torch.float16)
-    n_steps = 30
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@@ -1448,7 +1426,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(30)
     manual_seed(2)
     x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
@@ -1474,7 +1452,6 @@ def test_sdxl_random_init(
 ) -> None:
     sdxl = sdxl_ddim
     expected_image = expected_sdxl_ddim_random_init
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@@ -1484,7 +1461,7 @@ def test_sdxl_random_init(
     )
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(num_inference_steps=n_steps)
+    sdxl.set_inference_steps(30)
     manual_seed(seed=2)
     x = torch.randn(1, 4, 128, 128, device=test_device)
@@ -1509,7 +1486,6 @@ def test_sdxl_random_init_sag(
 ) -> None:
     sdxl = sdxl_ddim
     expected_image = expected_sdxl_ddim_random_init_sag
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
@@ -1519,7 +1495,7 @@ def test_sdxl_random_init_sag(
     )
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(num_inference_steps=n_steps)
+    sdxl.set_inference_steps(30)
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
     manual_seed(seed=2)
@@ -1577,7 +1553,6 @@ def test_t2i_adapter_depth(
     test_device: torch.device,
 ):
     sd15 = sd15_std
-    n_steps = 30
     name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
@@ -1589,7 +1564,7 @@ def test_t2i_adapter_depth(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
@@ -1618,7 +1593,6 @@ def test_t2i_adapter_xl_canny(
     test_device: torch.device,
 ):
     sdxl = sdxl_ddim
-    n_steps = 30
     name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
@@ -1635,7 +1609,7 @@ def test_t2i_adapter_xl_canny(
     )
     time_ids = sdxl.default_time_ids
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(30)
     t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
     t2i_adapter.set_scale(0.8)
@@ -1667,14 +1641,13 @@ def test_restart(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim
-    n_steps = 30
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     restart = Restart(ldm=sd15)
     manual_seed(2)
@@ -1706,23 +1679,21 @@ def test_freeu(
     expected_freeu: Image.Image,
 ):
     sd15 = sd15_std
-    n_steps = 50
-    first_step = 1
     prompt = "best quality, high quality cute cat"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(50, first_step=1)
     SDFreeUAdapter(
         sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
     ).inject()
     manual_seed(9752)
-    x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype)
+    x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)
-    for step in sd15.steps[first_step:]:
+    for step in sd15.steps:
         x = sd15(
             x,
             step=step,
@@ -1770,17 +1741,14 @@ def test_hello_world(
     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
     t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
-    first_step = 1
     ip_adapter.set_scale(0.85)
     t2i_adapter.set_scale(0.8)
-    sdxl.set_num_inference_steps(50)
+    sdxl.set_inference_steps(50, first_step=1)
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
     manual_seed(9752)
-    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
-        device=sdxl.device, dtype=sdxl.dtype
-    )
-    for step in sdxl.steps[first_step:]:
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)
+    for step in sdxl.steps:
         x = sdxl(
             x,
             step=step,

View file

@@ -133,5 +133,5 @@ def test_scheduler_device(test_device: Device):
    scheduler = DDIM(num_inference_steps=30, device=test_device)
    x = randn(1, 4, 32, 32, device=test_device)
    noise = randn(1, 4, 32, 32, device=test_device)
-    noised = scheduler.add_noise(x, noise, scheduler.steps[0])
+    noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
    assert noised.device == test_device
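
Usage sketch consistent with the updated test above (import path assumed): latents are now noised at the scheduler's own first inference step instead of indexing `steps[0]` by hand.

    import torch
    from refiners.foundationals.latent_diffusion.schedulers import DDIM  # import path assumed

    scheduler = DDIM(num_inference_steps=30, first_inference_step=3)
    x = torch.randn(1, 4, 32, 32)
    noise = torch.randn(1, 4, 32, 32)
    noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)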