init commit
This commit is contained in:
parent
12757682f1
commit
54c2783479
|
@ -17,7 +17,6 @@ class DPSR(nn.Module):
|
||||||
self.dim = len(res)
|
self.dim = len(res)
|
||||||
self.denom = np.prod(res)
|
self.denom = np.prod(res)
|
||||||
G = spec_gaussian_filter(res=res, sig=sig).float()
|
G = spec_gaussian_filter(res=res, sig=sig).float()
|
||||||
G = G
|
|
||||||
# self.G.requires_grad = False # True, if we also make sig a learnable parameter
|
# self.G.requires_grad = False # True, if we also make sig a learnable parameter
|
||||||
self.omega = fftfreqs(res, dtype=torch.float32)
|
self.omega = fftfreqs(res, dtype=torch.float32)
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
@ -33,11 +32,6 @@ class DPSR(nn.Module):
|
||||||
assert(V.shape == N.shape) # [b, nv, ndims]
|
assert(V.shape == N.shape) # [b, nv, ndims]
|
||||||
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
|
||||||
|
|
||||||
#!!! OLD
|
|
||||||
# ras_s = torch.rfft(ras_p, signal_ndim=self.dim) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
|
|
||||||
# ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+2))+[1, self.dim+2]))
|
|
||||||
# N_ = (ras_s * self.G) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
|
|
||||||
|
|
||||||
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
|
||||||
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
|
||||||
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
|
||||||
|
@ -46,7 +40,6 @@ class DPSR(nn.Module):
|
||||||
omega *= 2 * np.pi # normalize frequencies
|
omega *= 2 * np.pi # normalize frequencies
|
||||||
omega = omega.to(V.device)
|
omega = omega.to(V.device)
|
||||||
|
|
||||||
# DivN = torch.sum(-img(N_) * omega, dim=-2) #!!! OLD [b, dim0, dim1, dim2/2+1, 2]
|
|
||||||
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
|
||||||
|
|
||||||
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
|
||||||
|
@ -55,7 +48,6 @@ class DPSR(nn.Module):
|
||||||
Phi[tuple([0] * self.dim)] = 0
|
Phi[tuple([0] * self.dim)] = 0
|
||||||
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
|
||||||
|
|
||||||
# phi = torch.irfft(Phi, signal_ndim=self.dim, signal_sizes=self.res)#!!! OLD [b, dim0, dim1, dim2]
|
|
||||||
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
|
||||||
|
|
||||||
if self.shift or self.scale:
|
if self.shift or self.scale:
|
||||||
|
@ -70,6 +62,5 @@ class DPSR(nn.Module):
|
||||||
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
# phi = phi / fv0.view(*tuple([-1] + [1] * self.dim)) * 0.5
|
|
||||||
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
|
||||||
return phi
|
return phi
|
Loading…
Reference in a new issue