fix broken dtypes in tiled auto encoders

This commit is contained in:
Laurent 2024-07-11 13:06:19 +00:00 committed by Laureηt
parent f44ae150a7
commit 6ddd5435b8

View file

@ -256,14 +256,15 @@ def _create_blending_mask(
blending: int,
num_channels: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
) -> torch.Tensor:
mask = torch.ones(size, device=device)
mask = torch.ones(size, device=device, dtype=dtype)
if blending == 0:
return mask
blending = min(blending, min(size) // 2)
ramp = torch.linspace(0, 1, steps=blending, device=device)
ramp = torch.linspace(0, 1, steps=blending, device=device, dtype=dtype)
# Apply ramps only if not at the corresponding edge
if not is_edge[0]: # top
@ -445,8 +446,8 @@ class LatentDiffusionAutoencoder(Chain):
downscaled_image = image.resize((inference_size.width, inference_size.height)) # type: ignore
image_tensor = image_to_tensor(image, device=self.device)
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device, dtype=self.dtype)
downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
@ -481,7 +482,7 @@ class LatentDiffusionAutoencoder(Chain):
if len(tiles) == 1:
return self.encode(image_tensor)
result = torch.zeros((1, 4, *latent_size), device=self.device)
result = torch.zeros((1, 4, *latent_size), device=self.device, dtype=self.dtype)
weights = torch.zeros_like(result)
for latent_tile in tiles:
@ -509,6 +510,7 @@ class LatentDiffusionAutoencoder(Chain):
blending // 8,
num_channels=4,
device=self.device,
dtype=self.dtype,
is_edge=is_edge,
)
@ -543,7 +545,7 @@ class LatentDiffusionAutoencoder(Chain):
if len(tiles) == 1:
return self.decode(latents)
result = torch.zeros((1, 3, *pixel_size), device=self.device)
result = torch.zeros((1, 3, *pixel_size), device=self.device, dtype=self.dtype)
weights = torch.zeros_like(result)
for latent_tile in tiles:
@ -570,7 +572,12 @@ class LatentDiffusionAutoencoder(Chain):
)
pixel_tile_mask = _create_blending_mask(
pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
size=pixel_tile_size,
blending=blending,
num_channels=3,
device=self.device,
dtype=self.dtype,
is_edge=is_edge,
)
result[
:,