Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 13:48:46 +00:00)
simplify interpolate function and layer
This commit is contained in:
parent 6c37e3f933
commit 0336bc78b5
@@ -16,15 +16,26 @@ class Interpolate(Module):
     This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
     """
 
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        mode: str = "nearest",
+        antialias: bool = False,
+    ) -> None:
         super().__init__()
+        self.mode = mode
+        self.antialias = antialias
 
     def forward(
         self,
         x: Tensor,
         shape: Size,
     ) -> Tensor:
-        return interpolate(x, shape)
+        return interpolate(
+            x=x,
+            size=shape,
+            mode=self.mode,
+            antialias=self.antialias,
+        )
 
 
 class Downsample(Chain):
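For reference, a minimal usage sketch of the reworked Interpolate layer (not part of the diff; the import path is an assumption, and Size is taken to be torch.Size): the interpolation mode and antialiasing are now fixed at construction time, while the target shape is still passed per call.

import torch
from refiners.fluxion.layers import Interpolate  # assumed import path

x = torch.randn(1, 3, 64, 64)
resize = Interpolate(mode="bilinear", antialias=True)  # mode/antialias chosen once, at init
y = resize(x, torch.Size((32, 32)))                    # target shape still given at call time
assert y.shape == (1, 3, 32, 32)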
@@ -40,12 +40,18 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant
     return _pad(input=x, pad=pad, value=value, mode=mode)  # type: ignore
 
 
-def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
-    return (
-        _interpolate(x, scale_factor=factor, mode=mode)
-        if isinstance(factor, float | int)
-        else _interpolate(x, size=factor, mode=mode)
-    )  # type: ignore
+def interpolate(
+    x: Tensor,
+    size: torch.Size,
+    mode: str = "nearest",
+    antialias: bool = False,
+) -> Tensor:
+    return _interpolate(  # type: ignore
+        input=x,
+        size=size,
+        mode=mode,
+        antialias=antialias,
+    )
 
 
 # Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
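As a standalone illustration of the change (plain torch and a placeholder tensor, nothing from the repository): the helper no longer dispatches on a scale factor and is now a thin pass-through to torch.nn.functional.interpolate with an explicit output size, so callers that used to pass factor=2.0 must compute the target size themselves.

import torch
from torch.nn.functional import interpolate as torch_interpolate

x = torch.randn(1, 3, 8, 8)
# the old scale-factor call expressed with the new size-only contract:
size = torch.Size((x.shape[-2] * 2, x.shape[-1] * 2))
y = torch_interpolate(input=x, size=size, mode="bilinear", antialias=True)
assert y.shape == (1, 3, 16, 16)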
@@ -228,7 +228,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
 
         mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
         mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
-        self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
+        self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))
 
         init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
         masked_init_image = init_image_tensor * (1 - mask_tensor)
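A self-contained sketch of the mask preparation around the updated call (mask, latents_size, device and dtype are placeholders, not values from the repository): the inpainting mask is binarized, given batch and channel dimensions, and downsampled to the latent grid with the default nearest mode.

import numpy as np
import torch
from torch.nn.functional import interpolate as torch_interpolate
from PIL import Image

mask = Image.new("L", (512, 512), color=255)   # placeholder inpainting mask
latents_size = (64, 64)                        # placeholder latent height/width
mask_tensor = torch.tensor(np.array(mask).astype(np.float32) / 255.0)
mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=torch.float32)
mask_latents = torch_interpolate(mask_tensor, size=torch.Size(latents_size))  # nearest by default
assert mask_latents.shape == (1, 1, 64, 64)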
@@ -232,9 +232,9 @@ class SegmentAnything(fl.Chain):
         Returns:
             The postprocessed masks.
         """
-        masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
+        masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
         masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
-        masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
+        masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
         return masks
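A standalone sketch of the postprocessing path above (image_size, target_size, original_size and the mask shape are placeholder values): the low-resolution masks are upsampled to the model's square input resolution, the padding added during preprocessing is cropped, and the result is resized to the original image size.

import torch
from torch.nn.functional import interpolate as torch_interpolate

image_size = 1024                     # stands in for self.image_size
target_size = (1024, 768)             # extent of the un-padded content, placeholder
original_size = (2048, 1536)          # original image height/width, placeholder
masks = torch.randn(1, 3, 256, 256)   # low-resolution mask logits, placeholder shape

masks = torch_interpolate(masks, size=torch.Size((image_size, image_size)), mode="bilinear")
masks = masks[..., : target_size[0], : target_size[1]]  # drop the padding added at preprocess time
masks = torch_interpolate(masks, size=torch.Size(original_size), mode="bilinear")
assert masks.shape[-2:] == original_size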