fix broken dtypes in tiled autoencoders

Laurent 2024-07-11 13:06:19 +00:00 committed by Laureηt
parent f44ae150a7
commit 6ddd5435b8

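Context for the hunks below: the tiled encode/decode paths allocated their
blending masks, accumulators, and image tensors without an explicit dtype, so
everything came out as PyTorch's default float32 even when the autoencoder
itself runs in half precision. A minimal sketch of the failure mode
(illustrative only, not the actual library code):

    import torch

    tile = torch.rand(1, 4, 64, 64, dtype=torch.float16)  # tile from an fp16 autoencoder
    mask = torch.ones(1, 4, 64, 64)                       # blending mask, fp32 by default
    result = torch.zeros(1, 4, 128, 128, dtype=torch.float16)

    blended = tile * mask                  # fp16 * fp32 silently promotes to fp32
    assert blended.dtype == torch.float32

    # The in-place accumulation then refuses to downcast:
    # RuntimeError: result type Float can't be cast to the desired output type Half
    result[:, :, :64, :64] += blended

The fix below threads the module's dtype through every allocation so that
masks, tiles, and accumulators always agree.
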
@@ -256,14 +256,15 @@ def _create_blending_mask(
     blending: int,
     num_channels: int,
     device: torch.device | None = None,
+    dtype: torch.dtype | None = None,
     is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
 ) -> torch.Tensor:
-    mask = torch.ones(size, device=device)
+    mask = torch.ones(size, device=device, dtype=dtype)
     if blending == 0:
         return mask
     blending = min(blending, min(size) // 2)
-    ramp = torch.linspace(0, 1, steps=blending, device=device)
+    ramp = torch.linspace(0, 1, steps=blending, device=device, dtype=dtype)
     # Apply ramps only if not at the corresponding edge
     if not is_edge[0]:  # top
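
With the new dtype parameter, the helper can produce a mask that matches the
module's precision. A hypothetical call site (parameter names are taken from
the hunks; the values are made up):

    mask = _create_blending_mask(
        size=(64, 64),
        blending=8,
        num_channels=4,
        device=torch.device("cpu"),
        dtype=torch.float16,
        is_edge=(True, False, False, False),
    )
    assert mask.dtype == torch.float16  # previously always float32
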
@@ -445,8 +446,8 @@ class LatentDiffusionAutoencoder(Chain):
         downscaled_image = image.resize((inference_size.width, inference_size.height))  # type: ignore
-        image_tensor = image_to_tensor(image, device=self.device)
-        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
+        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
+        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device, dtype=self.dtype)
         downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
         std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
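
The conversions above previously produced float32 tensors, and feeding those
into half-precision weights fails outright. A standalone illustration of that
mismatch (a stand-in Conv2d, not the actual autoencoder):

    import torch

    vae_conv = torch.nn.Conv2d(3, 4, kernel_size=3).half()  # fp16 weights, like a half-precision VAE
    x = torch.rand(1, 3, 64, 64)                            # fp32 input, the pre-fix default
    # RuntimeError: Input type (torch.FloatTensor) and weight type (torch.HalfTensor)
    # should be the same
    vae_conv(x)

Passing dtype=self.dtype at conversion time removes the mismatch.
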
@@ -481,7 +482,7 @@ class LatentDiffusionAutoencoder(Chain):
         if len(tiles) == 1:
             return self.encode(image_tensor)
-        result = torch.zeros((1, 4, *latent_size), device=self.device)
+        result = torch.zeros((1, 4, *latent_size), device=self.device, dtype=self.dtype)
         weights = torch.zeros_like(result)
         for latent_tile in tiles:
@@ -509,6 +510,7 @@ class LatentDiffusionAutoencoder(Chain):
                 blending // 8,
                 num_channels=4,
                 device=self.device,
+                dtype=self.dtype,
                 is_edge=is_edge,
             )
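
The accumulator from the previous hunk and the mask created here meet in the
weighted blend, so they must share self.dtype. A rough reconstruction of that
accumulation (the actual loop body is elided from the diff; the iterable and
offsets are assumptions):

    for tile, mask, (y, x) in tiles_with_masks:  # hypothetical per-tile iterable
        h, w = tile.shape[-2:]
        result[:, :, y : y + h, x : x + w] += tile * mask
        weights[:, :, y : y + h, x : x + w] += mask
    blended = result / weights  # uniform dtype: no promotion, no in-place cast errors
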
@@ -543,7 +545,7 @@ class LatentDiffusionAutoencoder(Chain):
         if len(tiles) == 1:
             return self.decode(latents)
-        result = torch.zeros((1, 3, *pixel_size), device=self.device)
+        result = torch.zeros((1, 3, *pixel_size), device=self.device, dtype=self.dtype)
         weights = torch.zeros_like(result)
         for latent_tile in tiles:
@@ -570,7 +572,12 @@ class LatentDiffusionAutoencoder(Chain):
             )
             pixel_tile_mask = _create_blending_mask(
-                pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
+                size=pixel_tile_size,
+                blending=blending,
+                num_channels=3,
+                device=self.device,
+                dtype=self.dtype,
+                is_edge=is_edge,
             )
             result[
                 :,