# PointFlow/metrics/pytorch_structural_losses/match_cost.py
# 2019-07-13 21:32:26 -07:00
# (46 lines, 1.8 KiB, Python)

import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
class MatchCostFunction(Function):
    """Autograd wrapper around the StructuralLosses approximate-matching cost.

    ``forward`` computes a matching between two point sets with the backend
    ``ApproxMatch`` op and returns the cost of that matching from
    ``MatchCost``. ``backward`` obtains the gradients of that cost with
    respect to both point sets from ``MatchCostGrad``.

    The backend ops are opaque from this module. The shape notes below come
    from the original author's comments; NOTE(review): confirm them against
    the backend.

    * ``seta``: ``(batch_size, #dataset_points, 3)``
    * ``setb``: ``(batch_size, #query_points, 3)``
    * matching: ``(batch_size, #query_points, #dataset_points)``
    * cost: presumably one value per batch element. ``backward`` reshapes
      ``grad_output`` with two ``unsqueeze`` calls before scaling the
      3-D gradients.
    """

    @staticmethod
    def forward(ctx, seta, setb):
        """Return the matching cost between ``seta`` and ``setb``.

        Args:
            ctx: autograd context. Holds the inputs and the matching for
                ``backward``.
            seta: first point set (see the class docstring for shape notes).
            setb: second point set.

        Returns:
            The cost tensor produced by ``MatchCost``.
        """
        ctx.save_for_backward(seta, setb)
        # ApproxMatch returns two values; only the matching is used here.
        match, _ = ApproxMatch(seta, setb)
        # The matching is an intermediate (neither an input nor an output).
        # It is kept on ``ctx`` as in the original code so that backward can
        # reuse it instead of recomputing it.
        # NOTE(review): current PyTorch docs recommend save_for_backward for
        # all tensors; consider switching once the minimum supported PyTorch
        # version is confirmed.
        ctx.match = match
        return MatchCost(seta, setb, match)

    @staticmethod
    # MatchCostGrad is not differentiated by autograd here (presumably a
    # compiled extension op), so a second derivative through this backward
    # would be silently wrong. once_differentiable makes that case raise
    # instead.
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to both point sets.

        Args:
            ctx: context populated by ``forward``.
            grad_output: gradient with respect to the cost returned by
                ``forward``.

        Returns:
            ``(grad_seta, grad_setb)``. An entry is ``None`` when the
            corresponding input does not require a gradient.
        """
        seta, setb = ctx.saved_tensors
        grad_seta, grad_setb = MatchCostGrad(seta, setb, ctx.match)
        # Add two trailing singleton dims so the incoming gradient broadcasts
        # over the point and coordinate axes of the 3-D gradients. This
        # assumes the cost is 1-D (one value per batch element); TODO confirm.
        scale = grad_output.unsqueeze(1).unsqueeze(2)
        need_seta, need_setb = ctx.needs_input_grad[:2]
        return (
            grad_seta * scale if need_seta else None,
            grad_setb * scale if need_setb else None,
        )
# Functional entry point: call as ``match_cost(seta, setb)``. Custom autograd
# Functions are invoked through their ``apply`` classmethod rather than by
# instantiating the class.
match_cost = MatchCostFunction.apply