From 0336bc78b5d6964ed0c2ff5150dba377823fcc31 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 29 Mar 2024 17:41:06 +0000 Subject: [PATCH] simplify interpolate function and layer --- src/refiners/fluxion/layers/sampling.py | 15 +++++++++++++-- src/refiners/fluxion/utils.py | 18 ++++++++++++------ .../stable_diffusion_1/model.py | 2 +- .../foundationals/segment_anything/model.py | 4 ++-- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py index 135f82a..dffd813 100644 --- a/src/refiners/fluxion/layers/sampling.py +++ b/src/refiners/fluxion/layers/sampling.py @@ -16,15 +16,26 @@ class Interpolate(Module): This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate]. """ - def __init__(self) -> None: + def __init__( + self, + mode: str = "nearest", + antialias: bool = False, + ) -> None: super().__init__() + self.mode = mode + self.antialias = antialias def forward( self, x: Tensor, shape: Size, ) -> Tensor: - return interpolate(x, shape) + return interpolate( + x=x, + size=shape, + mode=self.mode, + antialias=self.antialias, + ) class Downsample(Chain): diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 7184a4a..c37df2a 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -40,12 +40,18 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore -def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor: - return ( - _interpolate(x, scale_factor=factor, mode=mode) - if isinstance(factor, float | int) - else _interpolate(x, size=factor, mode=mode) - ) # type: ignore +def interpolate( + x: Tensor, + size: torch.Size, + mode: str = "nearest", + antialias: bool = False, +) -> Tensor: + return _interpolate( # type: ignore + input=x, + size=size, + mode=mode, + antialias=antialias, + ) # Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 873fc4b..8c33d44 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -228,7 +228,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device) mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype) - self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size)) + self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size)) init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1 masked_init_image = init_image_tensor * (1 - mask_tensor) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 7a50ddf..b8a50d8 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -232,9 +232,9 @@ class SegmentAnything(fl.Chain): Returns: The postprocessed masks. """ - masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear") + masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear") masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time - masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear") + masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear") return masks