diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index ba8b3f2..4ccefc4 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -256,14 +256,15 @@ def _create_blending_mask(
     blending: int,
     num_channels: int,
     device: torch.device | None = None,
+    dtype: torch.dtype | None = None,
     is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
 ) -> torch.Tensor:
-    mask = torch.ones(size, device=device)
+    mask = torch.ones(size, device=device, dtype=dtype)
     if blending == 0:
         return mask
     blending = min(blending, min(size) // 2)
 
-    ramp = torch.linspace(0, 1, steps=blending, device=device)
+    ramp = torch.linspace(0, 1, steps=blending, device=device, dtype=dtype)
 
     # Apply ramps only if not at the corresponding edge
     if not is_edge[0]:  # top
@@ -445,8 +446,8 @@ class LatentDiffusionAutoencoder(Chain):
 
         downscaled_image = image.resize((inference_size.width, inference_size.height))  # type: ignore
 
-        image_tensor = image_to_tensor(image, device=self.device)
-        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
+        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
+        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device, dtype=self.dtype)
         downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
 
         std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
@@ -481,7 +482,7 @@ class LatentDiffusionAutoencoder(Chain):
         if len(tiles) == 1:
             return self.encode(image_tensor)
 
-        result = torch.zeros((1, 4, *latent_size), device=self.device)
+        result = torch.zeros((1, 4, *latent_size), device=self.device, dtype=self.dtype)
         weights = torch.zeros_like(result)
 
         for latent_tile in tiles:
@@ -509,6 +510,7 @@ class LatentDiffusionAutoencoder(Chain):
                 blending // 8,
                 num_channels=4,
                 device=self.device,
+                dtype=self.dtype,
                 is_edge=is_edge,
             )
 
@@ -543,7 +545,7 @@ class LatentDiffusionAutoencoder(Chain):
         if len(tiles) == 1:
             return self.decode(latents)
 
-        result = torch.zeros((1, 3, *pixel_size), device=self.device)
+        result = torch.zeros((1, 3, *pixel_size), device=self.device, dtype=self.dtype)
         weights = torch.zeros_like(result)
 
        for latent_tile in tiles:
@@ -570,7 +572,12 @@ class LatentDiffusionAutoencoder(Chain):
             )
 
             pixel_tile_mask = _create_blending_mask(
-                pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
+                size=pixel_tile_size,
+                blending=blending,
+                num_channels=3,
+                device=self.device,
+                dtype=self.dtype,
+                is_edge=is_edge,
             )
             result[
                 :,
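
Note on intent: this patch threads the model's dtype through every tensor the tiled encode/decode path allocates (blending masks, ramps, accumulators), so a half-precision autoencoder no longer mixes in float32 intermediates. A minimal sketch of the expected behavior of the updated `_create_blending_mask` (the tile size, blending width, device, and float16 choice below are illustrative assumptions, not values from the patch):

    import torch

    # Hypothetical call: a 64x64 latent-tile blending mask built directly in
    # float16 to match a half-precision autoencoder. Values are made up.
    mask = _create_blending_mask(
        size=(64, 64),
        blending=8,
        num_channels=4,
        device=torch.device("cpu"),
        dtype=torch.float16,
        is_edge=(True, False, False, False),  # tile sits on the top edge: no top ramp
    )
    # Before the patch, torch.ones/torch.linspace fell back to torch's default
    # dtype (typically float32) here; now the mask matches the requested dtype.
    assert mask.dtype == torch.float16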