fix the std bug

The old std computation leads to unstable testing.
Xu Ma 2022-02-16 01:08:56 -05:00
parent ceb37f825d
commit abc917654c
4 changed files with 13 additions and 8 deletions

View file

@@ -1,3 +1,7 @@
+"""
+Usage:
+python main.py --model PointMLP --msg demo
+"""
 import argparse
 import os
 import logging
@@ -28,6 +32,7 @@ def parse_args():
     parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training')
     parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
     parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training')
+    parser.add_argument('--min_lr', default=0.005, type=float, help='min lr')
     parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate')
     parser.add_argument('--seed', type=int, help='random seed')
     parser.add_argument('--workers', default=8, type=int, help='workers')
@@ -122,7 +127,7 @@ def main():
     optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
     if optimizer_dict is not None:
         optimizer.load_state_dict(optimizer_dict)
-    scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=1e-3, last_epoch=start_epoch - 1)
+    scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1)
     for epoch in range(start_epoch, args.epoch):
         printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
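
The new --min_lr flag is what the scheduler change above consumes: eta_min was hard-coded to 1e-3 and now defaults to 0.005 via args.min_lr. A minimal standalone sketch of the resulting schedule (the dummy parameter and the fixed values are illustrative, not taken from the commit):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# dummy parameter so the optimizer has something to manage
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1, momentum=0.9, weight_decay=2e-4)

# eta_min now comes from args.min_lr (default 0.005) instead of the hard-coded 1e-3
scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=0.005)

for epoch in range(300):
    optimizer.step()      # stands in for one epoch of training
    scheduler.step()

# after a full cosine cycle the learning rate has decayed to eta_min
print(optimizer.param_groups[0]['lr'])   # ~0.005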

View file

@@ -170,11 +170,11 @@ class LocalGrouper(nn.Module):
             grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1)  # [B, npoint, k, d+3]
         if self.normalize is not None:
             if self.normalize =="center":
-                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
+                mean = torch.mean(grouped_points, dim=2, keepdim=True)
             if self.normalize =="anchor":
                 mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
                 mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
-            std = torch.std(grouped_points-mean)
+            std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
             grouped_points = (grouped_points-mean)/(std + 1e-5)
             grouped_points = self.affine_alpha*grouped_points + self.affine_beta
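
This is the std bug: the old line reduced over the whole [B, npoint, k, d+3] tensor, so every local group was normalized by one batch-wide scalar, and that scalar shifts with whatever samples happen to share the batch, which is presumably the unstable testing the commit message refers to. The new line keeps one std per sample. A minimal sketch of the difference (tensor sizes are illustrative, not from the repository):

import torch

B, npoint, k, d = 4, 64, 24, 67                                 # illustrative sizes only
grouped_points = torch.randn(B, npoint, k, d)
mean = torch.mean(grouped_points, dim=2, keepdim=True)          # [B, npoint, 1, d]

# old: a single scalar std over the entire batch -> depends on batch composition
std_old = torch.std(grouped_points - mean)                      # shape []

# new: one std per sample, reshaped so it broadcasts as [B, 1, 1, 1]
std_new = torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
std_new = std_new.unsqueeze(dim=-1).unsqueeze(dim=-1)           # [B, 1, 1, 1]

normalized = (grouped_points - mean) / (std_new + 1e-5)
print(std_old.shape, std_new.shape, normalized.shape)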

View file

@@ -170,11 +170,11 @@ class LocalGrouper(nn.Module):
             grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1)  # [B, npoint, k, d+3]
         if self.normalize is not None:
             if self.normalize =="center":
-                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
+                mean = torch.mean(grouped_points, dim=2, keepdim=True)
             if self.normalize =="anchor":
                 mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
                 mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
-            std = torch.std(grouped_points-mean)
+            std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
             grouped_points = (grouped_points-mean)/(std + 1e-5)
             grouped_points = self.affine_alpha*grouped_points + self.affine_beta

View file

@@ -169,11 +169,11 @@ class LocalGrouper(nn.Module):
             grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1)  # [B, npoint, k, d+3]
         if self.normalize is not None:
             if self.normalize =="center":
-                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
+                mean = torch.mean(grouped_points, dim=2, keepdim=True)
             if self.normalize =="anchor":
                 mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
                 mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
-            std = torch.std(grouped_points-mean)
+            std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
             grouped_points = (grouped_points-mean)/(std + 1e-5)
             grouped_points = self.affine_alpha*grouped_points + self.affine_beta