simplify interpolate function and layer

This commit is contained in:
Laurent 2024-03-29 17:41:06 +00:00 committed by Laureηt
parent 6c37e3f933
commit 0336bc78b5
4 changed files with 28 additions and 11 deletions

View file

@ -16,15 +16,26 @@ class Interpolate(Module):
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
"""
def __init__(
    self,
    mode: str = "nearest",
    antialias: bool = False,
) -> None:
    """Initialize the layer and remember its resampling options.

    Both parameters have defaults, so constructing the layer with no
    arguments keeps working.

    Args:
        mode: Interpolation mode name. Stored as `self.mode` for `forward`
            to hand to `interpolate`.
        antialias: Anti-aliasing flag. Stored as `self.antialias` for
            `forward` to hand to `interpolate`.
    """
    super().__init__()
    self.mode = mode
    self.antialias = antialias
def forward(
    self,
    x: Tensor,
    shape: Size,
) -> Tensor:
    """Resize `x` to `shape` using the options stored on the layer.

    Args:
        x: Tensor to resize.
        shape: Target size, passed to `interpolate` as its `size` argument.

    Returns:
        Whatever `interpolate` returns for these arguments.
    """
    # Pass the configured options explicitly: a bare `interpolate(x, shape)`
    # would silently fall back to that function's defaults and ignore
    # `self.mode` / `self.antialias`.
    return interpolate(
        x=x,
        size=shape,
        mode=self.mode,
        antialias=self.antialias,
    )
class Downsample(Chain):

View file

@ -40,12 +40,18 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
def interpolate(
    x: Tensor,
    size: torch.Size,
    mode: str = "nearest",
    antialias: bool = False,
) -> Tensor:
    """Resize `x` to an explicit target size.

    Thin wrapper that forwards every argument by keyword to `_interpolate`
    (imported elsewhere in this module; presumably
    `torch.nn.functional.interpolate` under an alias -- confirm against the
    import block).

    Only an explicit `size` is accepted; there is no scale-factor form.

    Args:
        x: Tensor to resize; passed as `input`.
        size: Target size; passed as `size`.
        mode: Interpolation mode name; passed as `mode`.
        antialias: Anti-aliasing flag; passed as `antialias`.

    Returns:
        The tensor returned by `_interpolate`.
    """
    return _interpolate(  # type: ignore
        input=x,
        size=size,
        mode=mode,
        antialias=antialias,
    )
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py

View file

@ -228,7 +228,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)

View file

@ -232,9 +232,9 @@ class SegmentAnything(fl.Chain):
Returns:
The postprocessed masks.
"""
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
return masks