From f4298f87d2f6245ab7dad3ddb351034fbeb189f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Wed, 4 Oct 2023 17:36:04 +0200 Subject: [PATCH] pad: add optional padding mode --- src/refiners/fluxion/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index c4be8d0..fb1ce6f 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -23,8 +23,8 @@ def manual_seed(seed: int) -> None: _manual_seed(seed) -def pad(x: Tensor, pad: Iterable[int], value: float = 0.0) -> Tensor: - return _pad(input=x, pad=pad, value=value) # type: ignore +def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor: + return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor: