mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix broken dtypes in tiled auto encoders
This commit is contained in:
parent
f44ae150a7
commit
6ddd5435b8
|
@ -256,14 +256,15 @@ def _create_blending_mask(
|
|||
blending: int,
|
||||
num_channels: int,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
|
||||
) -> torch.Tensor:
|
||||
mask = torch.ones(size, device=device)
|
||||
mask = torch.ones(size, device=device, dtype=dtype)
|
||||
if blending == 0:
|
||||
return mask
|
||||
blending = min(blending, min(size) // 2)
|
||||
|
||||
ramp = torch.linspace(0, 1, steps=blending, device=device)
|
||||
ramp = torch.linspace(0, 1, steps=blending, device=device, dtype=dtype)
|
||||
|
||||
# Apply ramps only if not at the corresponding edge
|
||||
if not is_edge[0]: # top
|
||||
|
@ -445,8 +446,8 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
|
||||
downscaled_image = image.resize((inference_size.width, inference_size.height)) # type: ignore
|
||||
|
||||
image_tensor = image_to_tensor(image, device=self.device)
|
||||
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
|
||||
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
||||
downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device, dtype=self.dtype)
|
||||
downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
|
||||
|
||||
std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
|
||||
|
@ -481,7 +482,7 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
if len(tiles) == 1:
|
||||
return self.encode(image_tensor)
|
||||
|
||||
result = torch.zeros((1, 4, *latent_size), device=self.device)
|
||||
result = torch.zeros((1, 4, *latent_size), device=self.device, dtype=self.dtype)
|
||||
weights = torch.zeros_like(result)
|
||||
|
||||
for latent_tile in tiles:
|
||||
|
@ -509,6 +510,7 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
blending // 8,
|
||||
num_channels=4,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
is_edge=is_edge,
|
||||
)
|
||||
|
||||
|
@ -543,7 +545,7 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
if len(tiles) == 1:
|
||||
return self.decode(latents)
|
||||
|
||||
result = torch.zeros((1, 3, *pixel_size), device=self.device)
|
||||
result = torch.zeros((1, 3, *pixel_size), device=self.device, dtype=self.dtype)
|
||||
weights = torch.zeros_like(result)
|
||||
|
||||
for latent_tile in tiles:
|
||||
|
@ -570,7 +572,12 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
)
|
||||
|
||||
pixel_tile_mask = _create_blending_mask(
|
||||
pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
|
||||
size=pixel_tile_size,
|
||||
blending=blending,
|
||||
num_channels=3,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
is_edge=is_edge,
|
||||
)
|
||||
result[
|
||||
:,
|
||||
|
|
Loading…
Reference in a new issue