From 8a36c8c279cf30f36e44c7d771eadbc0ac410a2f Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Fri, 19 Jan 2024 10:55:04 +0100
Subject: [PATCH] make the first diffusion step a first class property of LDM
 & Schedulers

---
 README.md                                      |  7 +-
 .../foundationals/latent_diffusion/model.py    | 14 ++-
 .../latent_diffusion/schedulers/ddim.py        |  4 +
 .../latent_diffusion/schedulers/ddpm.py        |  2 +
 .../latent_diffusion/schedulers/dpm_solver.py  |  6 +-
 .../latent_diffusion/schedulers/euler.py       |  4 +
 .../latent_diffusion/schedulers/scheduler.py   |  8 +-
 tests/e2e/test_diffusion.py                    | 98 +++++++------------
 .../latent_diffusion/test_schedulers.py        |  2 +-
 9 files changed, 68 insertions(+), 77 deletions(-)

diff --git a/README.md b/README.md
index 5d0c1e8..98ed4fe 100644
--- a/README.md
+++ b/README.md
@@ -117,10 +117,9 @@ t2i_adapter = SDXLT2IAdapter(
 
 # Tune parameters
 seed = 9752
-first_step = 1
 ip_adapter.set_scale(0.85)
 t2i_adapter.set_scale(0.8)
-sdxl.set_num_inference_steps(50)
+sdxl.set_inference_steps(50, first_step=1)
 sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
 with no_grad():
@@ -136,11 +135,11 @@ with no_grad():
     t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
 
     manual_seed(seed=seed)
-    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(
         device=sdxl.device, dtype=sdxl.dtype
     )
 
-    for step in sdxl.steps[first_step:]:
+    for step in sdxl.steps:
         x = sdxl(
             x,
             step=step,
diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index 22618d7..44ec1cc 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -32,21 +32,21 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
         self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
 
-    def set_num_inference_steps(self, num_inference_steps: int) -> None:
+    def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
         initial_diffusion_rate = self.scheduler.initial_diffusion_rate
         final_diffusion_rate = self.scheduler.final_diffusion_rate
         device, dtype = self.scheduler.device, self.scheduler.dtype
         self.scheduler = self.scheduler.__class__(
-            num_inference_steps,
+            num_inference_steps=num_steps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            first_inference_step=first_step,
         ).to(device=device, dtype=dtype)
 
     def init_latents(
         self,
         size: tuple[int, int],
         init_image: Image.Image | None = None,
-        first_step: int = 0,
         noise: Tensor | None = None,
     ) -> Tensor:
         height, width = size
@@ -59,11 +59,15 @@ class LatentDiffusionModel(fl.Module, ABC):
         if init_image is None:
             return noise
         encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
-        return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
+        return self.scheduler.add_noise(
+            x=encoded_image,
+            noise=noise,
+            step=self.scheduler.first_inference_step,
+        )
 
     @property
     def steps(self) -> list[int]:
-        return self.scheduler.steps
+        return self.scheduler.inference_steps
 
     @abstractmethod
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
index 3ff1d83..125dada 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
@@ -11,6 +11,7 @@ class DDIM(Scheduler):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:
@@ -20,6 +21,7 @@ class DDIM(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
+            first_inference_step=first_inference_step,
             device=device,
             dtype=dtype,
         )
@@ -35,6 +37,8 @@ class DDIM(Scheduler):
         return timesteps.flip(0)
 
     def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+        assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
+
         timestep, previous_timestep = (
             self.timesteps[step],
             (
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
index 52ff5e9..e5d312e 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
@@ -15,6 +15,7 @@ class DDPM(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
     ) -> None:
         super().__init__(
@@ -22,6 +23,7 @@ class DDPM(Scheduler):
             num_train_timesteps=num_train_timesteps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            first_inference_step=first_inference_step,
             device=device,
         )
 
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
index c2c4e16..0711fd9 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
@@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
         final_diffusion_rate: float = 1.2e-2,
         last_step_first_order: bool = False,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        first_inference_step: int = 0,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
@@ -33,6 +34,7 @@ class DPMSolver(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
+            first_inference_step=first_inference_step,
             device=device,
             dtype=dtype,
         )
@@ -100,12 +102,14 @@ class DPMSolver(Scheduler):
         backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
         (ODEs).
""" + assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" + current_timestep = self.timesteps[step] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] estimated_denoised_data = (x - noise_ratio * noise) / scale_factor self.estimated_data.append(estimated_denoised_data) - if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1): + if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1): return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) return self.multistep_dpm_solver_second_order_update(x=x, step=step) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py index 2a69c8c..28c7c2b 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py @@ -13,6 +13,7 @@ class EulerScheduler(Scheduler): initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, + first_inference_step: int = 0, device: Device | str = "cpu", dtype: Dtype = float32, ): @@ -24,6 +25,7 @@ class EulerScheduler(Scheduler): initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, noise_schedule=noise_schedule, + first_inference_step=first_inference_step, device=device, dtype=dtype, ) @@ -64,6 +66,8 @@ class EulerScheduler(Scheduler): s_tmax: float = float("inf"), s_noise: float = 1.0, ) -> Tensor: + assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" + sigma = self.sigmas[step] gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index f64a4cc..37f9beb 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -33,6 +33,7 @@ class Scheduler(ABC): initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, + first_inference_step: int = 0, device: Device | str = "cpu", dtype: DType = float32, ): @@ -43,6 +44,7 @@ class Scheduler(ABC): self.initial_diffusion_rate = initial_diffusion_rate self.final_diffusion_rate = final_diffusion_rate self.noise_schedule = noise_schedule + self.first_inference_step = first_inference_step self.scale_factors = self.sample_noise_schedule() self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0)) self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) @@ -68,9 +70,13 @@ class Scheduler(ABC): ... @property - def steps(self) -> list[int]: + def all_steps(self) -> list[int]: return list(range(self.num_inference_steps)) + @property + def inference_steps(self) -> list[int]: + return self.all_steps[self.first_inference_step :] + def scale_model_input(self, x: Tensor, step: int) -> Tensor: """ For compatibility with schedulers that need to scale the input according to the current timestep. 
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 5a69255..bb3ec4a 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -594,13 +594,12 @@ def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std
-    n_steps = 30
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -624,13 +623,12 @@ def test_diffusion_std_random_init_euler(
     sd15 = sd15_euler
     euler_scheduler = sd15_euler.scheduler
     assert isinstance(euler_scheduler, EulerScheduler)
-    n_steps = 30
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -678,14 +676,13 @@ def test_diffusion_std_random_init_float16(
     sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std_float16
-    n_steps = 30
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -707,13 +704,12 @@ def test_diffusion_std_random_init_sag(
     sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
 ):
     sd15 = sd15_std
-    n_steps = 30
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_self_attention_guidance(enable=True, scale=0.75)
 
     manual_seed(2)
@@ -738,19 +734,17 @@ def test_diffusion_std_init_image(
     expected_image_std_init_image: Image.Image,
 ):
     sd15 = sd15_std
-    n_steps = 35
-    first_step = 5
 
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(35, first_step=5)
 
     manual_seed(2)
-    x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
+    x = sd15.init_latents((512, 512), cutecat_init)
 
-    for step in sd15.steps[first_step:]:
+    for step in sd15.steps:
         x = sd15(
             x,
             step=step,
@@ -786,13 +780,12 @@ def test_diffusion_inpainting(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting
-    n_steps = 30
 
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
 
     manual_seed(2)
@@ -820,14 +813,13 @@ def test_diffusion_inpainting_float16(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting_float16
-    n_steps = 30
 
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
 
     manual_seed(2)
@@ -853,7 +845,6 @@ def test_diffusion_controlnet(
     test_device: torch.device,
 ):
     sd15 = sd15_std
-    n_steps = 30
 
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
 
@@ -865,7 +856,7 @@ def test_diffusion_controlnet(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -897,7 +888,6 @@ def test_diffusion_controlnet_structural_copy(
 ):
     sd15_base = sd15_std
     sd15 = sd15_base.structural_copy()
-    n_steps = 30
 
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
 
@@ -909,7 +899,7 @@ def test_diffusion_controlnet_structural_copy(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -940,7 +930,6 @@ def test_diffusion_controlnet_float16(
     test_device: torch.device,
 ):
     sd15 = sd15_std_float16
-    n_steps = 30
 
     cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
 
@@ -952,7 +941,7 @@ def test_diffusion_controlnet_float16(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     controlnet = SD1ControlnetAdapter(
         sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
@@ -985,7 +974,6 @@ def test_diffusion_controlnet_stack(
     test_device: torch.device,
 ):
     sd15 = sd15_std
-    n_steps = 30
 
     _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
     _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
@@ -1002,7 +990,7 @@ def test_diffusion_controlnet_stack(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     depth_controlnet = SD1ControlnetAdapter(
         sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
@@ -1038,14 +1026,13 @@ def test_diffusion_lora(
     test_device: torch.device,
 ) -> None:
     sd15 = sd15_std
-    n_steps = 30
 
     expected_image, lora_weights = lora_data_pokemon
 
     prompt = "a cute cat"
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     SDLoraManager(sd15).load(lora_weights, scale=1)
 
@@ -1074,7 +1061,6 @@ def test_diffusion_sdxl_lora(
 
     # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
     # except that we are using DDIM instead of sde-dpmsolver++
-    n_steps = 40
     seed = 12341234123
     guidance_scale = 7.5
     lora_scale = 1.4
@@ -1088,7 +1074,7 @@ def test_diffusion_sdxl_lora(
     )
     time_ids = sdxl.default_time_ids
 
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(40)
 
     manual_seed(seed=seed)
     x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
@@ -1155,14 +1141,13 @@ def test_diffusion_inpainting_refonly(
     test_device: torch.device,
 ):
     sd15 = sd15_inpainting
-    n_steps = 30
 
     prompt = ""  # unconditional
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
     refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
 
     guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
@@ -1203,12 +1188,10 @@ def test_diffusion_textual_inversion_random_init(
     conceptExtender.add_concept("", text_embedding_textual_inversion)
     conceptExtender.inject()
 
-    n_steps = 30
-
     prompt = "a cute cat on a "
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -1235,7 +1218,6 @@ def test_diffusion_ip_adapter(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
-    n_steps = 50
 
     # See tencent-ailab/IP-Adapter best practices section:
     #
@@ -1254,7 +1236,7 @@ def test_diffusion_ip_adapter(
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(50)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -1281,7 +1263,6 @@ def test_diffusion_sdxl_ip_adapter(
     test_device: torch.device,
 ):
     sdxl = sdxl_ddim.to(dtype=torch.float16)
-    n_steps = 30
 
     prompt = "best quality, high quality"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
@@ -1298,7 +1279,7 @@ def test_diffusion_sdxl_ip_adapter(
     ip_adapter.set_clip_image_embedding(clip_image_embedding)
     time_ids = sdxl.default_time_ids
 
-    sdxl.set_num_inference_steps(n_steps)
+    sdxl.set_inference_steps(30)
 
     manual_seed(2)
     x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
@@ -1332,7 +1313,6 @@ def test_diffusion_ip_adapter_controlnet(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim.to(dtype=torch.float16)
-    n_steps = 50
 
     input_image, _ = lora_data_pokemon  # use the Pokemon LoRA output as input
     _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
@@ -1360,7 +1340,7 @@ def test_diffusion_ip_adapter_controlnet(
         dtype=torch.float16,
     )
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(50)
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
@@ -1388,7 +1368,6 @@ def test_diffusion_ip_adapter_plus(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
-    n_steps = 50
 
= "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" @@ -1403,7 +1382,7 @@ def test_diffusion_ip_adapter_plus( clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image)) ip_adapter.set_clip_image_embedding(clip_image_embedding) - sd15.set_num_inference_steps(n_steps) + sd15.set_inference_steps(30) manual_seed(42) # seed=42 is used in the official IP-Adapter demo x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) @@ -1430,7 +1409,6 @@ def test_diffusion_sdxl_ip_adapter_plus( test_device: torch.device, ): sdxl = sdxl_ddim.to(dtype=torch.float16) - n_steps = 30 prompt = "best quality, high quality" negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" @@ -1448,7 +1426,7 @@ def test_diffusion_sdxl_ip_adapter_plus( ip_adapter.set_clip_image_embedding(clip_image_embedding) time_ids = sdxl.default_time_ids - sdxl.set_num_inference_steps(n_steps) + sdxl.set_inference_steps(30) manual_seed(2) x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16) @@ -1474,7 +1452,6 @@ def test_sdxl_random_init( ) -> None: sdxl = sdxl_ddim expected_image = expected_sdxl_ddim_random_init - n_steps = 30 prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" @@ -1484,7 +1461,7 @@ def test_sdxl_random_init( ) time_ids = sdxl.default_time_ids - sdxl.set_num_inference_steps(num_inference_steps=n_steps) + sdxl.set_inference_steps(30) manual_seed(seed=2) x = torch.randn(1, 4, 128, 128, device=test_device) @@ -1509,7 +1486,6 @@ def test_sdxl_random_init_sag( ) -> None: sdxl = sdxl_ddim expected_image = expected_sdxl_ddim_random_init_sag - n_steps = 30 prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" @@ -1519,7 +1495,7 @@ def test_sdxl_random_init_sag( ) time_ids = sdxl.default_time_ids - sdxl.set_num_inference_steps(num_inference_steps=n_steps) + sdxl.set_inference_steps(30) sdxl.set_self_attention_guidance(enable=True, scale=0.75) manual_seed(seed=2) @@ -1577,7 +1553,6 @@ def test_t2i_adapter_depth( test_device: torch.device, ): sd15 = sd15_std - n_steps = 30 name, condition_image, expected_image, weights_path = t2i_adapter_data_depth @@ -1589,7 +1564,7 @@ def test_t2i_adapter_depth( negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) - sd15.set_num_inference_steps(n_steps) + sd15.set_inference_steps(30) t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject() @@ -1618,7 +1593,6 @@ def test_t2i_adapter_xl_canny( test_device: torch.device, ): sdxl = sdxl_ddim - n_steps = 30 name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny @@ -1635,7 +1609,7 @@ def test_t2i_adapter_xl_canny( ) time_ids = sdxl.default_time_ids - sdxl.set_num_inference_steps(n_steps) + sdxl.set_inference_steps(30) t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject() t2i_adapter.set_scale(0.8) @@ -1667,14 +1641,13 @@ def test_restart( test_device: torch.device, ): sd15 = sd15_ddim - n_steps = 30 prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" clip_text_embedding = 
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(30)
     restart = Restart(ldm=sd15)
 
     manual_seed(2)
@@ -1706,23 +1679,21 @@ def test_freeu(
     expected_freeu: Image.Image,
 ):
     sd15 = sd15_std
-    n_steps = 50
-    first_step = 1
 
     prompt = "best quality, high quality cute cat"
     negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
 
-    sd15.set_num_inference_steps(n_steps)
+    sd15.set_inference_steps(50, first_step=1)
 
     SDFreeUAdapter(
         sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
     ).inject()
 
     manual_seed(9752)
-    x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype)
+    x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)
 
-    for step in sd15.steps[first_step:]:
+    for step in sd15.steps:
         x = sd15(
             x,
             step=step,
@@ -1770,17 +1741,14 @@ def test_hello_world(
     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
     t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
 
-    first_step = 1
     ip_adapter.set_scale(0.85)
     t2i_adapter.set_scale(0.8)
-    sdxl.set_num_inference_steps(50)
+    sdxl.set_inference_steps(50, first_step=1)
     sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
     manual_seed(9752)
-    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
-        device=sdxl.device, dtype=sdxl.dtype
-    )
-    for step in sdxl.steps[first_step:]:
+    x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)
+    for step in sdxl.steps:
         x = sdxl(
             x,
             step=step,
diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py
index 91956ca..3737c26 100644
--- a/tests/foundationals/latent_diffusion/test_schedulers.py
+++ b/tests/foundationals/latent_diffusion/test_schedulers.py
@@ -133,5 +133,5 @@ def test_scheduler_device(test_device: Device):
     scheduler = DDIM(num_inference_steps=30, device=test_device)
     x = randn(1, 4, 32, 32, device=test_device)
     noise = randn(1, 4, 32, 32, device=test_device)
-    noised = scheduler.add_noise(x, noise, scheduler.steps[0])
+    noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
    assert noised.device == test_device
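
For context, a minimal sketch of how the scheduler-level behaviour introduced by this patch fits together; the import path, tensor shapes and the empty loop body below are illustrative assumptions, not taken from the patch itself:

    from torch import randn

    from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM

    # A 30-step schedule whose denoising starts at step 5 (image-to-image style).
    scheduler = DDIM(num_inference_steps=30, first_inference_step=5)
    assert scheduler.all_steps == list(range(30))
    assert scheduler.inference_steps == list(range(5, 30))

    # Noise the latents at the first inference step, then iterate over the
    # remaining steps, mirroring what LatentDiffusionModel.init_latents and
    # LatentDiffusionModel.steps now do.
    x = randn(1, 4, 32, 32)
    noise = randn(1, 4, 32, 32)
    x = scheduler.add_noise(x, noise, scheduler.first_inference_step)
    for step in scheduler.inference_steps:
        ...  # denoise x at `step`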