mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
fix broken dtypes in tiled auto encoders
This commit is contained in:
parent
436fb091ed
commit
15b1ff0e2e
|
@ -256,14 +256,15 @@ def _create_blending_mask(
|
||||||
blending: int,
|
blending: int,
|
||||||
num_channels: int,
|
num_channels: int,
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
|
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
mask = torch.ones(size, device=device)
|
mask = torch.ones(size, device=device, dtype=dtype)
|
||||||
if blending == 0:
|
if blending == 0:
|
||||||
return mask
|
return mask
|
||||||
blending = min(blending, min(size) // 2)
|
blending = min(blending, min(size) // 2)
|
||||||
|
|
||||||
ramp = torch.linspace(0, 1, steps=blending, device=device)
|
ramp = torch.linspace(0, 1, steps=blending, device=device, dtype=dtype)
|
||||||
|
|
||||||
# Apply ramps only if not at the corresponding edge
|
# Apply ramps only if not at the corresponding edge
|
||||||
if not is_edge[0]: # top
|
if not is_edge[0]: # top
|
||||||
|
@ -445,8 +446,8 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
|
|
||||||
downscaled_image = image.resize((inference_size.width, inference_size.height)) # type: ignore
|
downscaled_image = image.resize((inference_size.width, inference_size.height)) # type: ignore
|
||||||
|
|
||||||
image_tensor = image_to_tensor(image, device=self.device)
|
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
||||||
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
|
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device, dtype=self.dtype)
|
||||||
downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
|
downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
|
||||||
|
|
||||||
std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
|
std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
|
||||||
|
@ -481,7 +482,7 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
if len(tiles) == 1:
|
if len(tiles) == 1:
|
||||||
return self.encode(image_tensor)
|
return self.encode(image_tensor)
|
||||||
|
|
||||||
result = torch.zeros((1, 4, *latent_size), device=self.device)
|
result = torch.zeros((1, 4, *latent_size), device=self.device, dtype=self.dtype)
|
||||||
weights = torch.zeros_like(result)
|
weights = torch.zeros_like(result)
|
||||||
|
|
||||||
for latent_tile in tiles:
|
for latent_tile in tiles:
|
||||||
|
@ -509,6 +510,7 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
blending // 8,
|
blending // 8,
|
||||||
num_channels=4,
|
num_channels=4,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
is_edge=is_edge,
|
is_edge=is_edge,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -543,7 +545,7 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
if len(tiles) == 1:
|
if len(tiles) == 1:
|
||||||
return self.decode(latents)
|
return self.decode(latents)
|
||||||
|
|
||||||
result = torch.zeros((1, 3, *pixel_size), device=self.device)
|
result = torch.zeros((1, 3, *pixel_size), device=self.device, dtype=self.dtype)
|
||||||
weights = torch.zeros_like(result)
|
weights = torch.zeros_like(result)
|
||||||
|
|
||||||
for latent_tile in tiles:
|
for latent_tile in tiles:
|
||||||
|
@ -570,7 +572,12 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
pixel_tile_mask = _create_blending_mask(
|
pixel_tile_mask = _create_blending_mask(
|
||||||
pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
|
size=pixel_tile_size,
|
||||||
|
blending=blending,
|
||||||
|
num_channels=3,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
is_edge=is_edge,
|
||||||
)
|
)
|
||||||
result[
|
result[
|
||||||
:,
|
:,
|
||||||
|
|
Loading…
Reference in a new issue