2021-10-19 20:54:46 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
import modules.functional as PF
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
class FrustumPointNetLoss(nn.Module):
    """Training loss for Frustum PointNet: point segmentation plus 3D box estimation.

    The returned loss is::

        mask_loss + box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss
            + heading_residual_loss_weight * heading_residual_normalized_loss
            + size_residual_loss_weight * size_residual_normalized_loss
            + corners_loss_weight * corners_loss
        )

    :param num_heading_angle_bins: number of heading-angle bins (NH); bin ``k`` is
        centered at ``k * 2 * pi / NH`` for ``k = 0 .. NH - 1``.
    :param num_size_templates: number of size templates (NS).
    :param size_templates: NS * 3 template sizes. Anything ``torch.as_tensor`` accepts
        (tensor, numpy array, nested sequence) that reshapes to ``(NS, 3)``.
    :param box_loss_weight: weight applied to the whole box-estimation term.
    :param corners_loss_weight: weight of the corner-distance loss inside the box term.
    :param heading_residual_loss_weight: weight of the normalized heading-residual loss.
    :param size_residual_loss_weight: weight of the normalized size-residual loss.
    """

    def __init__(
        self,
        num_heading_angle_bins,
        num_size_templates,
        size_templates,
        box_loss_weight=1.0,
        corners_loss_weight=10.0,
        heading_residual_loss_weight=20.0,
        size_residual_loss_weight=20.0,
    ):
        super().__init__()
        self.box_loss_weight = box_loss_weight
        self.corners_loss_weight = corners_loss_weight
        self.heading_residual_loss_weight = heading_residual_loss_weight
        self.size_residual_loss_weight = size_residual_loss_weight

        self.num_heading_angle_bins = num_heading_angle_bins
        self.num_size_templates = num_size_templates
        # as_tensor also accepts numpy arrays / sequences: ndarray.view would treat the shape
        # arguments as a dtype. reshape (unlike view) also copes with non-contiguous input.
        self.register_buffer(
            "size_templates", torch.as_tensor(size_templates).reshape(self.num_size_templates, 3)
        )
        # Build the bin centers from an integer range so there are exactly NH of them.
        # torch.arange with a non-integer float step can emit one extra element when
        # ceil((end - start) / step) is perturbed by rounding.
        bin_width = 2 * np.pi / self.num_heading_angle_bins
        heading_angle_bin_centers = torch.arange(self.num_heading_angle_bins, dtype=torch.float64) * bin_width
        self.register_buffer(
            "heading_angle_bin_centers", heading_angle_bin_centers.to(torch.get_default_dtype())
        )

    def forward(self, inputs, targets):
        """Compute the combined training loss.

        :param inputs: dict of network outputs with keys
            ``mask_logits`` (B, 2, N), ``center_reg`` (B, 3), ``center`` (B, 3),
            ``heading_scores`` (B, NH), ``heading_residuals_normalized`` (B, NH),
            ``heading_residuals`` (B, NH), ``size_scores`` (B, NS),
            ``size_residuals_normalized`` (B, NS, 3) and ``size_residuals`` (B, NS, 3).
        :param targets: dict of ground truth with keys
            ``mask_logits`` (B, N) per-point labels, ``center`` (B, 3),
            ``heading_bin_id`` (B,), ``heading_residual`` (B,),
            ``size_template_id`` (B,) and ``size_residual`` (B, 3).
        :return: the weighted sum of all loss terms (see the class docstring).
        """
        mask_logits = inputs["mask_logits"]  # (B, 2, N)
        center_reg = inputs["center_reg"]  # (B, 3)
        center = inputs["center"]  # (B, 3)
        heading_scores = inputs["heading_scores"]  # (B, NH)
        heading_residuals_normalized = inputs["heading_residuals_normalized"]  # (B, NH)
        heading_residuals = inputs["heading_residuals"]  # (B, NH)
        size_scores = inputs["size_scores"]  # (B, NS)
        size_residuals_normalized = inputs["size_residuals_normalized"]  # (B, NS, 3)
        size_residuals = inputs["size_residuals"]  # (B, NS, 3)

        mask_logits_target = targets["mask_logits"]  # (B, N)
        center_target = targets["center"]  # (B, 3)
        heading_bin_id_target = targets["heading_bin_id"]  # (B, )
        heading_residual_target = targets["heading_residual"]  # (B, )
        size_template_id_target = targets["size_template_id"]  # (B, )
        size_residual_target = targets["size_residual"]  # (B, 3)

        batch_size = center.size(0)
        batch_id = torch.arange(batch_size, device=center.device)

        # Basic classification and regression losses
        mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
        heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
        size_loss = F.cross_entropy(size_scores, size_template_id_target)
        center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
        center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)

        # Refinement losses for size/heading: only the residual predicted for the
        # ground-truth bin / template of each sample is supervised.
        heading_residual_normalized_pred = heading_residuals_normalized[batch_id, heading_bin_id_target]  # (B, )
        # Heading residuals are normalized by half the bin width (pi / NH).
        heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
        heading_residual_normalized_loss = PF.huber_loss(
            heading_residual_normalized_pred - heading_residual_normalized_target, delta=1.0
        )
        size_residual_normalized_pred = size_residuals_normalized[batch_id, size_template_id_target]  # (B, 3)
        # Size residuals are normalized by the ground-truth template's size.
        size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
        size_residual_normalized_loss = PF.huber_loss(
            torch.norm(size_residual_normalized_target - size_residual_normalized_pred, dim=-1), delta=1.0
        )

        # Bounding box losses. The predicted box uses the predicted center and the residuals
        # predicted for the ground-truth heading bin / size template.
        heading = (
            heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
        )  # (B, )
        # Warning: in the original code, size_residuals are added twice
        # (issues #43 and #49 in charlesq34/frustum-pointnets).
        size = (
            size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
        )  # (B, 3)
        corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
        heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
        size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
        corners_target, corners_target_flip = get_box_corners_3d(
            centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
        )  # (B, 3, 8)
        # For each corner, take the distance to whichever is closer: the ground-truth box's
        # corner or the matching corner of that box rotated by pi. A prediction matching
        # the flipped box is therefore not penalized.
        corners_loss = PF.huber_loss(
            torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
            delta=1.0,
        )

        # Summing up
        loss = mask_loss + self.box_loss_weight * (
            center_loss
            + center_reg_loss
            + heading_loss
            + size_loss
            + self.heading_residual_loss_weight * heading_residual_normalized_loss
            + self.size_residual_loss_weight * size_residual_normalized_loss
            + self.corners_loss_weight * corners_loss
        )

        return loss
|
|
|
|
|
|
|
|
|
|
|
|
def get_box_corners_3d(centers, headings, sizes, with_flip=False):
    """
    Compute the 8 corners of 3D boxes rotated about the y axis.

    :param centers: coords of box centers, FloatTensor[N, 3]
    :param headings: heading angles (rotation about the y axis), FloatTensor[N, ]
    :param sizes: box sizes ordered (l, w, h), FloatTensor[N, 3]; in the box frame l spans x,
        h spans y and w spans z
    :param with_flip: bool, whether to also return the box rotated by an extra pi
        (i.e. headings + np.pi)
    :return:
        coords of box corners, FloatTensor[N, 3, 8]. If ``with_flip`` is True, a tuple
        ``(corners, flipped_corners)`` of two such tensors.
        Corner order in the box frame (before rotation/translation): corners 0-3 lie at
        y = +h/2 and corners 4-7 at y = -h/2. Within each group of four, the (x, z) footprint
        visits (+l/2, +w/2), (+l/2, -w/2), (-l/2, -w/2), (-l/2, +w/2).
        An empty batch (N = 0) yields tensors of shape (0, 3, 8).
    """
    # Use an explicit box count: view(-1, 3, 3) cannot infer -1 when N == 0.
    num_boxes = headings.shape[0]
    length = sizes[:, 0]  # (N,)
    width = sizes[:, 1]  # (N,)
    height = sizes[:, 2]  # (N,)
    x_corners = torch.stack(
        [length / 2, length / 2, -length / 2, -length / 2, length / 2, length / 2, -length / 2, -length / 2], dim=1
    )  # (N, 8)
    y_corners = torch.stack(
        [height / 2, height / 2, height / 2, height / 2, -height / 2, -height / 2, -height / 2, -height / 2], dim=1
    )  # (N, 8)
    z_corners = torch.stack(
        [width / 2, -width / 2, -width / 2, width / 2, width / 2, -width / 2, -width / 2, width / 2], dim=1
    )  # (N, 8)

    c = torch.cos(headings)  # (N,)
    s = torch.sin(headings)  # (N,)
    o = torch.ones_like(headings)  # (N,)
    z = torch.zeros_like(headings)  # (N,)

    centers = centers.unsqueeze(-1)  # (N, 3, 1)
    corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    # Rotation about y by the heading: [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(num_boxes, 3, 3)  # (N, 3, 3)
    if with_flip:
        # Rotation by heading + pi: cos -> -c and sin -> -s.
        R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(num_boxes, 3, 3)  # (N, 3, 3)
        return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
    return torch.matmul(R, corners) + centers
|