init commit

This commit is contained in:
pengsongyou 2021-11-08 11:16:33 +01:00
parent 12757682f1
commit 54c2783479

View file

@ -17,7 +17,6 @@ class DPSR(nn.Module):
self.dim = len(res) self.dim = len(res)
self.denom = np.prod(res) self.denom = np.prod(res)
G = spec_gaussian_filter(res=res, sig=sig).float() G = spec_gaussian_filter(res=res, sig=sig).float()
G = G
# self.G.requires_grad = False # True, if we also make sig a learnable parameter # self.G.requires_grad = False # True, if we also make sig a learnable parameter
self.omega = fftfreqs(res, dtype=torch.float32) self.omega = fftfreqs(res, dtype=torch.float32)
self.scale = scale self.scale = scale
@ -33,11 +32,6 @@ class DPSR(nn.Module):
assert(V.shape == N.shape) # [b, nv, ndims] assert(V.shape == N.shape) # [b, nv, ndims]
ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2] ras_p = point_rasterize(V, N, self.res) # [b, n_dim, dim0, dim1, dim2]
#!!! OLD
# ras_s = torch.rfft(ras_p, signal_ndim=self.dim) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
# ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+2))+[1, self.dim+2]))
# N_ = (ras_s * self.G) # [b, n_dim, dim0, dim1, dim2/2+1, 2]
ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4)) ras_s = torch.fft.rfftn(ras_p, dim=(2,3,4))
ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1])) ras_s = ras_s.permute(*tuple([0]+list(range(2, self.dim+1))+[self.dim+1, 1]))
N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1] N_ = ras_s[..., None] * self.G # [b, dim0, dim1, dim2/2+1, n_dim, 1]
@ -46,7 +40,6 @@ class DPSR(nn.Module):
omega *= 2 * np.pi # normalize frequencies omega *= 2 * np.pi # normalize frequencies
omega = omega.to(V.device) omega = omega.to(V.device)
# DivN = torch.sum(-img(N_) * omega, dim=-2) #!!! OLD [b, dim0, dim1, dim2/2+1, 2]
DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2) DivN = torch.sum(-img(torch.view_as_real(N_[..., 0])) * omega, dim=-2)
Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1] Lap = -torch.sum(omega**2, -2) # [dim0, dim1, dim2/2+1, 1]
@ -55,7 +48,6 @@ class DPSR(nn.Module):
Phi[tuple([0] * self.dim)] = 0 Phi[tuple([0] * self.dim)] = 0
Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2] Phi = Phi.permute(*tuple([[self.dim+1] + list(range(self.dim+1))])) # [b, dim0, dim1, dim2/2+1, 2]
# phi = torch.irfft(Phi, signal_ndim=self.dim, signal_sizes=self.res)#!!! OLD [b, dim0, dim1, dim2]
phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3)) phi = torch.fft.irfftn(torch.view_as_complex(Phi), s=self.res, dim=(1,2,3))
if self.shift or self.scale: if self.shift or self.scale:
@ -70,6 +62,5 @@ class DPSR(nn.Module):
phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))])) phi = phi.permute(*tuple([[self.dim] + list(range(self.dim))]))
if self.scale: if self.scale:
# phi = phi / fv0.view(*tuple([-1] + [1] * self.dim)) * 0.5
phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5 phi = -phi / torch.abs(fv0.view(*tuple([-1]+[1] * self.dim))) *0.5
return phi return phi