mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
make the first diffusion step a first class property of LDM & Schedulers
This commit is contained in:
parent
42b7749630
commit
8a36c8c279
|
@ -117,10 +117,9 @@ t2i_adapter = SDXLT2IAdapter(
|
|||
|
||||
# Tune parameters
|
||||
seed = 9752
|
||||
first_step = 1
|
||||
ip_adapter.set_scale(0.85)
|
||||
t2i_adapter.set_scale(0.8)
|
||||
sdxl.set_num_inference_steps(50)
|
||||
sdxl.set_inference_steps(50, first_step=1)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
with no_grad():
|
||||
|
@ -136,11 +135,11 @@ with no_grad():
|
|||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
manual_seed(seed=seed)
|
||||
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
|
||||
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(
|
||||
device=sdxl.device, dtype=sdxl.dtype
|
||||
)
|
||||
|
||||
for step in sdxl.steps[first_step:]:
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
|
|
|
@ -32,21 +32,21 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def set_num_inference_steps(self, num_inference_steps: int) -> None:
|
||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
||||
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
||||
device, dtype = self.scheduler.device, self.scheduler.dtype
|
||||
self.scheduler = self.scheduler.__class__(
|
||||
num_inference_steps,
|
||||
num_inference_steps=num_steps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
first_inference_step=first_step,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
def init_latents(
|
||||
self,
|
||||
size: tuple[int, int],
|
||||
init_image: Image.Image | None = None,
|
||||
first_step: int = 0,
|
||||
noise: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
height, width = size
|
||||
|
@ -59,11 +59,15 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
if init_image is None:
|
||||
return noise
|
||||
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
|
||||
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
|
||||
return self.scheduler.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
step=self.scheduler.first_inference_step,
|
||||
)
|
||||
|
||||
@property
|
||||
def steps(self) -> list[int]:
|
||||
return self.scheduler.steps
|
||||
return self.scheduler.inference_steps
|
||||
|
||||
@abstractmethod
|
||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||
|
|
|
@ -11,6 +11,7 @@ class DDIM(Scheduler):
|
|||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
|
@ -20,6 +21,7 @@ class DDIM(Scheduler):
|
|||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -35,6 +37,8 @@ class DDIM(Scheduler):
|
|||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
timestep, previous_timestep = (
|
||||
self.timesteps[step],
|
||||
(
|
||||
|
|
|
@ -15,6 +15,7 @@ class DDPM(Scheduler):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
@ -22,6 +23,7 @@ class DDPM(Scheduler):
|
|||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
|
|||
final_diffusion_rate: float = 1.2e-2,
|
||||
last_step_first_order: bool = False,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
|
@ -33,6 +34,7 @@ class DPMSolver(Scheduler):
|
|||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -100,12 +102,14 @@ class DPMSolver(Scheduler):
|
|||
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
||||
(ODEs).
|
||||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
current_timestep = self.timesteps[step]
|
||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
||||
self.estimated_data.append(estimated_denoised_data)
|
||||
|
||||
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
|
||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
|
|
|
@ -13,6 +13,7 @@ class EulerScheduler(Scheduler):
|
|||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
|
@ -24,6 +25,7 @@ class EulerScheduler(Scheduler):
|
|||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -64,6 +66,8 @@ class EulerScheduler(Scheduler):
|
|||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
) -> Tensor:
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
sigma = self.sigmas[step]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
||||
|
|
|
@ -33,6 +33,7 @@ class Scheduler(ABC):
|
|||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
):
|
||||
|
@ -43,6 +44,7 @@ class Scheduler(ABC):
|
|||
self.initial_diffusion_rate = initial_diffusion_rate
|
||||
self.final_diffusion_rate = final_diffusion_rate
|
||||
self.noise_schedule = noise_schedule
|
||||
self.first_inference_step = first_inference_step
|
||||
self.scale_factors = self.sample_noise_schedule()
|
||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
|
@ -68,9 +70,13 @@ class Scheduler(ABC):
|
|||
...
|
||||
|
||||
@property
|
||||
def steps(self) -> list[int]:
|
||||
def all_steps(self) -> list[int]:
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
@property
|
||||
def inference_steps(self) -> list[int]:
|
||||
return self.all_steps[self.first_inference_step :]
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with schedulers that need to scale the input according to the current timestep.
|
||||
|
|
|
@ -594,13 +594,12 @@ def test_diffusion_std_random_init(
|
|||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
@ -624,13 +623,12 @@ def test_diffusion_std_random_init_euler(
|
|||
sd15 = sd15_euler
|
||||
euler_scheduler = sd15_euler.scheduler
|
||||
assert isinstance(euler_scheduler, EulerScheduler)
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
@ -678,14 +676,13 @@ def test_diffusion_std_random_init_float16(
|
|||
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_std_float16
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
assert clip_text_embedding.dtype == torch.float16
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
@ -707,13 +704,12 @@ def test_diffusion_std_random_init_sag(
|
|||
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
sd15.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -738,19 +734,17 @@ def test_diffusion_std_init_image(
|
|||
expected_image_std_init_image: Image.Image,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 35
|
||||
first_step = 5
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(35, first_step=5)
|
||||
|
||||
manual_seed(2)
|
||||
x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
|
||||
x = sd15.init_latents((512, 512), cutecat_init)
|
||||
|
||||
for step in sd15.steps[first_step:]:
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
|
@ -786,13 +780,12 @@ def test_diffusion_inpainting(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_inpainting
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -820,14 +813,13 @@ def test_diffusion_inpainting_float16(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_inpainting_float16
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
assert clip_text_embedding.dtype == torch.float16
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -853,7 +845,6 @@ def test_diffusion_controlnet(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data
|
||||
|
||||
|
@ -865,7 +856,7 @@ def test_diffusion_controlnet(
|
|||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
|
@ -897,7 +888,6 @@ def test_diffusion_controlnet_structural_copy(
|
|||
):
|
||||
sd15_base = sd15_std
|
||||
sd15 = sd15_base.structural_copy()
|
||||
n_steps = 30
|
||||
|
||||
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
|
||||
|
||||
|
@ -909,7 +899,7 @@ def test_diffusion_controlnet_structural_copy(
|
|||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
|
@ -940,7 +930,6 @@ def test_diffusion_controlnet_float16(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std_float16
|
||||
n_steps = 30
|
||||
|
||||
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_canny
|
||||
|
||||
|
@ -952,7 +941,7 @@ def test_diffusion_controlnet_float16(
|
|||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
|
@ -985,7 +974,6 @@ def test_diffusion_controlnet_stack(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
|
||||
_, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
|
||||
|
@ -1002,7 +990,7 @@ def test_diffusion_controlnet_stack(
|
|||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
depth_controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
|
||||
|
@ -1038,14 +1026,13 @@ def test_diffusion_lora(
|
|||
test_device: torch.device,
|
||||
) -> None:
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
expected_image, lora_weights = lora_data_pokemon
|
||||
|
||||
prompt = "a cute cat"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
SDLoraManager(sd15).load(lora_weights, scale=1)
|
||||
|
||||
|
@ -1074,7 +1061,6 @@ def test_diffusion_sdxl_lora(
|
|||
|
||||
# parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA
|
||||
# except that we are using DDIM instead of sde-dpmsolver++
|
||||
n_steps = 40
|
||||
seed = 12341234123
|
||||
guidance_scale = 7.5
|
||||
lora_scale = 1.4
|
||||
|
@ -1088,7 +1074,7 @@ def test_diffusion_sdxl_lora(
|
|||
)
|
||||
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
sdxl.set_inference_steps(40)
|
||||
|
||||
manual_seed(seed=seed)
|
||||
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
@ -1155,14 +1141,13 @@ def test_diffusion_inpainting_refonly(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_inpainting
|
||||
n_steps = 30
|
||||
|
||||
prompt = "" # unconditional
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||
|
||||
guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
|
||||
|
@ -1203,12 +1188,10 @@ def test_diffusion_textual_inversion_random_init(
|
|||
conceptExtender.add_concept("<gta5-artwork>", text_embedding_textual_inversion)
|
||||
conceptExtender.inject()
|
||||
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat on a <gta5-artwork>"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
@ -1235,7 +1218,6 @@ def test_diffusion_ip_adapter(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
|
||||
n_steps = 50
|
||||
|
||||
# See tencent-ailab/IP-Adapter best practices section:
|
||||
#
|
||||
|
@ -1254,7 +1236,7 @@ def test_diffusion_ip_adapter(
|
|||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
@ -1281,7 +1263,6 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sdxl = sdxl_ddim.to(dtype=torch.float16)
|
||||
n_steps = 30
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
@ -1298,7 +1279,7 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
sdxl.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||
|
@ -1332,7 +1313,6 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim.to(dtype=torch.float16)
|
||||
n_steps = 50
|
||||
input_image, _ = lora_data_pokemon # use the Pokemon LoRA output as input
|
||||
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
|
||||
|
||||
|
@ -1360,7 +1340,7 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(50)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
@ -1388,7 +1368,6 @@ def test_diffusion_ip_adapter_plus(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim_lda_ft_mse.to(dtype=torch.float16)
|
||||
n_steps = 50
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
@ -1403,7 +1382,7 @@ def test_diffusion_ip_adapter_plus(
|
|||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))
|
||||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(42) # seed=42 is used in the official IP-Adapter demo
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
@ -1430,7 +1409,6 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sdxl = sdxl_ddim.to(dtype=torch.float16)
|
||||
n_steps = 30
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
@ -1448,7 +1426,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
ip_adapter.set_clip_image_embedding(clip_image_embedding)
|
||||
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
sdxl.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||
|
@ -1474,7 +1452,6 @@ def test_sdxl_random_init(
|
|||
) -> None:
|
||||
sdxl = sdxl_ddim
|
||||
expected_image = expected_sdxl_ddim_random_init
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
@ -1484,7 +1461,7 @@ def test_sdxl_random_init(
|
|||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
|
||||
sdxl.set_inference_steps(30)
|
||||
|
||||
manual_seed(seed=2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device)
|
||||
|
@ -1509,7 +1486,6 @@ def test_sdxl_random_init_sag(
|
|||
) -> None:
|
||||
sdxl = sdxl_ddim
|
||||
expected_image = expected_sdxl_ddim_random_init_sag
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
@ -1519,7 +1495,7 @@ def test_sdxl_random_init_sag(
|
|||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
sdxl.set_num_inference_steps(num_inference_steps=n_steps)
|
||||
sdxl.set_inference_steps(30)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manual_seed(seed=2)
|
||||
|
@ -1577,7 +1553,6 @@ def test_t2i_adapter_depth(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
name, condition_image, expected_image, weights_path = t2i_adapter_data_depth
|
||||
|
||||
|
@ -1589,7 +1564,7 @@ def test_t2i_adapter_depth(
|
|||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
t2i_adapter = SD1T2IAdapter(target=sd15.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
|
||||
|
||||
|
@ -1618,7 +1593,6 @@ def test_t2i_adapter_xl_canny(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sdxl = sdxl_ddim
|
||||
n_steps = 30
|
||||
|
||||
name, condition_image, expected_image, weights_path = t2i_adapter_xl_data_canny
|
||||
|
||||
|
@ -1635,7 +1609,7 @@ def test_t2i_adapter_xl_canny(
|
|||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
|
||||
sdxl.set_num_inference_steps(n_steps)
|
||||
sdxl.set_inference_steps(30)
|
||||
|
||||
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
|
||||
t2i_adapter.set_scale(0.8)
|
||||
|
@ -1667,14 +1641,13 @@ def test_restart(
|
|||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(30)
|
||||
restart = Restart(ldm=sd15)
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -1706,23 +1679,21 @@ def test_freeu(
|
|||
expected_freeu: Image.Image,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 50
|
||||
first_step = 1
|
||||
|
||||
prompt = "best quality, high quality cute cat"
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inference_steps(50, first_step=1)
|
||||
|
||||
SDFreeUAdapter(
|
||||
sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
|
||||
).inject()
|
||||
|
||||
manual_seed(9752)
|
||||
x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype)
|
||||
x = sd15.init_latents((512, 512)).to(device=sd15.device, dtype=sd15.dtype)
|
||||
|
||||
for step in sd15.steps[first_step:]:
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
|
@ -1770,17 +1741,14 @@ def test_hello_world(
|
|||
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
first_step = 1
|
||||
ip_adapter.set_scale(0.85)
|
||||
t2i_adapter.set_scale(0.8)
|
||||
sdxl.set_num_inference_steps(50)
|
||||
sdxl.set_inference_steps(50, first_step=1)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
manual_seed(9752)
|
||||
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image, first_step=first_step).to(
|
||||
device=sdxl.device, dtype=sdxl.dtype
|
||||
)
|
||||
for step in sdxl.steps[first_step:]:
|
||||
x = sdxl.init_latents(size=(1024, 1024), init_image=init_image).to(device=sdxl.device, dtype=sdxl.dtype)
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
|
|
|
@ -133,5 +133,5 @@ def test_scheduler_device(test_device: Device):
|
|||
scheduler = DDIM(num_inference_steps=30, device=test_device)
|
||||
x = randn(1, 4, 32, 32, device=test_device)
|
||||
noise = randn(1, 4, 32, 32, device=test_device)
|
||||
noised = scheduler.add_noise(x, noise, scheduler.steps[0])
|
||||
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
|
||||
assert noised.device == test_device
|
||||
|
|
Loading…
Reference in a new issue