mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
simplify interpolate function and layer
This commit is contained in:
parent
6c37e3f933
commit
0336bc78b5
|
@ -16,15 +16,26 @@ class Interpolate(Module):
|
||||||
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
|
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode: str = "nearest",
|
||||||
|
antialias: bool = False,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.mode = mode
|
||||||
|
self.antialias = antialias
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
shape: Size,
|
shape: Size,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
return interpolate(x, shape)
|
return interpolate(
|
||||||
|
x=x,
|
||||||
|
size=shape,
|
||||||
|
mode=self.mode,
|
||||||
|
antialias=self.antialias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Downsample(Chain):
|
class Downsample(Chain):
|
||||||
|
|
|
@ -40,12 +40,18 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant
|
||||||
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
|
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
|
def interpolate(
|
||||||
return (
|
x: Tensor,
|
||||||
_interpolate(x, scale_factor=factor, mode=mode)
|
size: torch.Size,
|
||||||
if isinstance(factor, float | int)
|
mode: str = "nearest",
|
||||||
else _interpolate(x, size=factor, mode=mode)
|
antialias: bool = False,
|
||||||
) # type: ignore
|
) -> Tensor:
|
||||||
|
return _interpolate( # type: ignore
|
||||||
|
input=x,
|
||||||
|
size=size,
|
||||||
|
mode=mode,
|
||||||
|
antialias=antialias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
|
||||||
|
|
|
@ -228,7 +228,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
|
|
||||||
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
|
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
|
||||||
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
|
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
|
||||||
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
|
self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))
|
||||||
|
|
||||||
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
|
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
|
||||||
masked_init_image = init_image_tensor * (1 - mask_tensor)
|
masked_init_image = init_image_tensor * (1 - mask_tensor)
|
||||||
|
|
|
@ -232,9 +232,9 @@ class SegmentAnything(fl.Chain):
|
||||||
Returns:
|
Returns:
|
||||||
The postprocessed masks.
|
The postprocessed masks.
|
||||||
"""
|
"""
|
||||||
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
||||||
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
||||||
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
|
masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
|
||||||
return masks
|
return masks
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue