From abc917654cc39a7355ac87e9c95af403bb413926 Mon Sep 17 00:00:00 2001 From: Xu Ma Date: Wed, 16 Feb 2022 01:08:56 -0500 Subject: [PATCH] fix the std bug leads to unstable testing --- classification_ModelNet40/main.py | 7 ++++++- classification_ModelNet40/models/pointmlp.py | 4 ++-- classification_ScanObjectNN/models/pointmlp.py | 4 ++-- part_segmentation/models/pointMLP.py | 6 +++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/classification_ModelNet40/main.py b/classification_ModelNet40/main.py index 330a52f..2d0ef92 100644 --- a/classification_ModelNet40/main.py +++ b/classification_ModelNet40/main.py @@ -1,3 +1,7 @@ +""" +Usage: +python main.py --model PointMLP --msg demo +""" import argparse import os import logging @@ -28,6 +32,7 @@ def parse_args(): parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training') parser.add_argument('--num_points', type=int, default=1024, help='Point Number') parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training') + parser.add_argument('--min_lr', default=0.005, type=float, help='min lr') parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate') parser.add_argument('--seed', type=int, help='random seed') parser.add_argument('--workers', default=8, type=int, help='workers') @@ -122,7 +127,7 @@ def main(): optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) if optimizer_dict is not None: optimizer.load_state_dict(optimizer_dict) - scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=1e-3, last_epoch=start_epoch - 1) + scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1) for epoch in range(start_epoch, args.epoch): printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr'])) diff --git a/classification_ModelNet40/models/pointmlp.py b/classification_ModelNet40/models/pointmlp.py index c6dce17..597ba46 100644 --- a/classification_ModelNet40/models/pointmlp.py +++ b/classification_ModelNet40/models/pointmlp.py @@ -170,11 +170,11 @@ class LocalGrouper(nn.Module): grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] if self.normalize is not None: if self.normalize =="center": - std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True) + mean = torch.mean(grouped_points, dim=2, keepdim=True) if self.normalize =="anchor": mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] - std = torch.std(grouped_points-mean) + std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = self.affine_alpha*grouped_points + self.affine_beta diff --git a/classification_ScanObjectNN/models/pointmlp.py b/classification_ScanObjectNN/models/pointmlp.py index c6dce17..597ba46 100644 --- a/classification_ScanObjectNN/models/pointmlp.py +++ b/classification_ScanObjectNN/models/pointmlp.py @@ -170,11 +170,11 @@ class LocalGrouper(nn.Module): grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] if self.normalize is not None: if self.normalize =="center": - std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True) + mean = torch.mean(grouped_points, dim=2, keepdim=True) if self.normalize =="anchor": mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] - std = torch.std(grouped_points-mean) + std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = self.affine_alpha*grouped_points + self.affine_beta diff --git a/part_segmentation/models/pointMLP.py b/part_segmentation/models/pointMLP.py index 5ce90fb..d5c5256 100644 --- a/part_segmentation/models/pointMLP.py +++ b/part_segmentation/models/pointMLP.py @@ -169,11 +169,11 @@ class LocalGrouper(nn.Module): grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] if self.normalize is not None: if self.normalize =="center": - std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True) + mean = torch.mean(grouped_points, dim=2, keepdim=True) if self.normalize =="anchor": mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points - mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] - std = torch.std(grouped_points-mean) + mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] + std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = self.affine_alpha*grouped_points + self.affine_beta