Compare commits


No commits in common. "4450cad9a31feb5477e914d1283350c5d5de9ccd" and "3e3d80cff5c23a631fe5ba4ca97db3f452893ed2" have entirely different histories.

51 changed files with 1471 additions and 2075 deletions

.gitignore

@@ -1,6 +1,3 @@
-data
-checkpoints
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

.vscode/extensions.json

@@ -1,9 +0,0 @@
{
"recommendations": [
"editorconfig.editorconfig",
"eamodio.gitlens",
"ms-python.python",
"ms-python.black-formatter",
"charliermarsh.ruff",
]
}

.vscode/launch.json

@@ -1,17 +0,0 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"OMP_NUM_THREADS": "1",
"CUDA_VISIBLE_DEVICES": "1",
}
}
]
}

.vscode/settings.json

@@ -1,67 +0,0 @@
{
// nice editor settings
"editor.formatOnSave": true,
"editor.formatOnPaste": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true,
"source.fixAll": false,
},
"editor.rulers": [
120
],
// editorconfig redundancy
"files.insertFinalNewline": true,
"files.trimTrailingWhitespace": true,
// hidde unimportant files/folders
"files.exclude": {
// defaults
"**/.git": true,
"**/.svn": true,
"**/.hg": true,
"**/CVS": true,
"**/.DS_Store": true,
"**/Thumbs.db": true,
// annoying
"**/__pycache__": true,
"**/.mypy_cache": true,
"**/.ruff_cache": true,
"**/*.tmp": true,
},
// cpp /clang / cmake settings
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools",
"C_Cpp.intelliSenseEngine": "disabled",
"C_Cpp.intelliSenseEngineFallback": "enabled",
"C_Cpp.clang_format_path": "/softs/compiler/llvm/latest/bin/clang-format",
"C_Cpp.codeAnalysis.clangTidy.enabled": true,
"C_Cpp.codeAnalysis.clangTidy.path": "/softs/compiler/llvm/latest/bin/clang-tidy",
"clangd.path": "/softs/compiler/llvm/latest/bin/clangd",
"cmake.cmakePath": "/softs/cmake/latest/bin/cmake",
"cmake.preferredGenerators": [
"Ninja",
"Unix Makefiles"
],
"cmakeFormat.exePath": "/softs/conda/auto/envs/cmake-format/bin/cmake-format",
"cmake.languageSupport.dotnetPath": "/softs/conda/auto/envs/dotnet/lib/dotnet/dotnet",
// python settings
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pointmlp/bin/python",
"python.linting.enabled": true,
"python.linting.lintOnSave": true,
"python.linting.mypyEnabled": true,
// fixes for broken auto-activation on rosetta
"python.terminal.activateEnvironment": false,
"terminal.integrated.profiles.linux": {
"python": {
"path": "bash",
"icon": "rocket",
"args": [
"--init-file",
".vscode/setup.sh"
],
}
},
"terminal.integrated.env.linux": {
"PYTHONPATH": "${workspaceFolder}/src/",
"SLURM_JOB_ID": null, // unset or else lightning_logs v_num uses it
},
}

.vscode/setup.sh

@@ -1,11 +0,0 @@
#!/bin/bash
source ~/.bashrc
conda_init
conda activate pointmlp
export PS1="(pointmlp)[\\u@\\h \\W]\\$ "
module load compilers
module load mpfr

README.md

@@ -1,31 +1,37 @@
# Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework ICLR 2022
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=rethinking-network-design-and-local-geometry-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=rethinking-network-design-and-local-geometry-1)
[![github](https://img.shields.io/github/stars/ma-xu/pointMLP-pytorch?style=social)](https://github.com/ma-xu/pointMLP-pytorch)
-<img src="images/smile.png" height="70px">
-<img src="images/neu.png" height="70px">
-<img src="images/columbia.png" height="70px">
+<div align="left">
+<a><img src="images/smile.png" height="70px" ></a>
+<a><img src="images/neu.png" height="70px" ></a>
+<a><img src="images/columbia.png" height="70px" ></a>
+</div>
[open review](https://openreview.net/forum?id=3Pbra-_u76D) | [arXiv](https://arxiv.org/abs/2202.07123) | Primary contact: [Xu Ma](mailto:ma.xu1@northeastern.edu)
-![](images/overview.png)
+<div align="center">
+<img src="images/overview.png" width="650px" height="300px">
+</div>
Overview of one stage in PointMLP. Given an input point cloud, PointMLP progressively extracts local features using residual point MLP blocks. In each stage, we first transform the local point using a geometric affine module, and then local points are extracted before and after aggregation, respectively. By repeating multiple stages, PointMLP progressively enlarges the receptive field and models entire point cloud geometric information.
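(Editorial aside, not part of either commit: the "geometric affine module" mentioned in this paragraph is the normalization that appears later in this diff inside `LocalGrouper`, with the learnable `affine_alpha`/`affine_beta` parameters. A minimal standalone sketch of that transform, following the shapes and the 1e-5 epsilon visible in the code further down, looks roughly like this.)

```python
import torch
import torch.nn as nn

class GeometricAffine(nn.Module):
    """Standardize grouped neighborhood features around their anchor point,
    then rescale with learnable affine parameters (alpha, beta)."""

    def __init__(self, channel):
        super().__init__()
        self.affine_alpha = nn.Parameter(torch.ones(1, 1, 1, channel))
        self.affine_beta = nn.Parameter(torch.zeros(1, 1, 1, channel))

    def forward(self, grouped_points, anchor):
        # grouped_points: [B, npoint, k, d] neighbor features, anchor: [B, npoint, 1, d]
        B = grouped_points.shape[0]
        std = torch.std((grouped_points - anchor).reshape(B, -1), dim=-1, keepdim=True)
        std = std.unsqueeze(-1).unsqueeze(-1)  # broadcast over [npoint, k, d]
        normalized = (grouped_points - anchor) / (std + 1e-5)
        return self.affine_alpha * normalized + self.affine_beta

# quick shape check
affine = GeometricAffine(channel=64)
out = affine(torch.rand(2, 512, 24, 64), torch.rand(2, 512, 1, 64))  # -> [2, 512, 24, 64]
```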
## BibTeX
-```bibtex
@article{ma2022rethinking,
title={Rethinking network design and local geometry in point cloud: A simple residual MLP framework},
author={Ma, Xu and Qin, Can and You, Haoxuan and Ran, Haoxi and Fu, Yun},
journal={arXiv preprint arXiv:2202.07123},
year={2022}
}
-```
## Model Zoo
@@ -36,6 +42,8 @@ If you run the same codes for several times, you will get different results (eve
The best way to reproduce the results is to test with a pretrained model for ModelNet40. <br>
Also, the randomness of ModelNet40 is our motivation to experiment on ScanObjectNN, and to report the mean/std results of several runs.
------
The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026).
@@ -46,6 +54,8 @@ On ScanObjectNN, fixed pointMLP achieves a result of **84.4% mAcc** and **86.1%
Stay tuned. More elite versions and voting results will be uploaded.
## News & Updates:
- [x] fix the uncomplete utils in partseg by Mar/10, caused by error uplaoded folder.
@@ -56,6 +66,9 @@ Stay tuned. More elite versions and voting results will be uploaded.
:point_right::point_right::point_right:**NOTE:** The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026).
## Install
```bash
@@ -79,7 +92,8 @@ pip install cycler einops h5py pyyaml==5.4.1 scikit-learn==0.24.2 scipy tqdm mat
pip install pointnet2_ops_lib/.
```
-## Usage
+## Useage
### Classification ModelNet40
**Train**: The dataset will be automatically downloaded, run following command to train.
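(The fenced command block under this sentence is collapsed in the hunk. Judging from the `main.py` docstring and `parse_args` shown later in this diff, the training invocation has roughly this form; the flag values below are just the defaults listed in `parse_args`.)

```bash
python main.py --model pointMLP --msg demo --batch_size 32 --epoch 300 --learning_rate 0.1
```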
@@ -147,5 +161,10 @@ Our implementation is mainly based on the following codebases. We gratefully tha
[Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
## LICENSE
PointMLP is under the Apache-2.0 license.


@@ -1,15 +1,14 @@
-import fvcore.common
-import fvcore.nn
import torch
+import fvcore.nn
+import fvcore.common
from fvcore.nn import FlopCountAnalysis
from classification_ScanObjectNN.models import pointMLPElite
model = pointMLPElite()
model.eval()
# model = deit_tiny_patch16_224()
-inputs = torch.randn((1, 3, 1024))
+inputs = (torch.randn((1,3,1024)))
k = 1024.0
flops = FlopCountAnalysis(model, inputs).total()
print(f"Flops : {flops}")


@@ -1,37 +1,33 @@
-import glob
import os
+import glob
import h5py
import numpy as np
from torch.utils.data import Dataset
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def download():
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-DATA_DIR = os.path.join(BASE_DIR, "data")
+DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
-if not os.path.exists(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048")):
+if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
-www = "https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip"
+www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
zipfile = os.path.basename(www)
-os.system(f"wget {www} --no-check-certificate; unzip {zipfile}")
+os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
-os.system(f"mv {zipfile[:-4]} {DATA_DIR}")
+os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
-os.system("rm %s" % (zipfile))
+os.system('rm %s' % (zipfile))
def load_data(partition):
download()
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-DATA_DIR = os.path.join(BASE_DIR, "data")
+DATA_DIR = os.path.join(BASE_DIR, 'data')
all_data = []
all_label = []
-for h5_name in glob.glob(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048", "ply_data_%s*.h5" % partition)):
+for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)):
# print(f"h5_name: {h5_name}")
-f = h5py.File(h5_name, "r")
+f = h5py.File(h5_name,'r')
-data = f["data"][:].astype("float32")
+data = f['data'][:].astype('float32')
-label = f["label"][:].astype("int64")
+label = f['label'][:].astype('int64')
f.close()
all_data.append(data)
all_label.append(label)
@@ -39,25 +35,23 @@ def load_data(partition):
all_label = np.concatenate(all_label, axis=0)
return all_data, all_label
def random_point_dropout(pc, max_dropout_ratio=0.875):
-"""batch_pc: BxNx3."""
+''' batch_pc: BxNx3 '''
# for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
-drop_idx = np.where(np.random.random(pc.shape[0]) <= dropout_ratio)[0]
+drop_idx = np.where(np.random.random((pc.shape[0]))<=dropout_ratio)[0]
# print ('use random drop', len(drop_idx))
if len(drop_idx)>0:
pc[drop_idx,:] = pc[0,:] # set to the first point
return pc
def translate_pointcloud(pointcloud):
-xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
+xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
-return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
+translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
+return translated_pointcloud
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
N, C = pointcloud.shape
@@ -66,7 +60,7 @@ def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
class ModelNet40(Dataset):
-def __init__(self, num_points, partition="train"):
+def __init__(self, num_points, partition='train'):
self.data, self.label = load_data(partition)
self.num_points = num_points
self.partition = partition
@@ -74,7 +68,7 @@ class ModelNet40(Dataset):
def __getitem__(self, item):
pointcloud = self.data[item][:self.num_points]
label = self.label[item]
-if self.partition == "train":
+if self.partition == 'train':
# pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all
pointcloud = translate_pointcloud(pointcloud)
np.random.shuffle(pointcloud)
@@ -84,25 +78,19 @@ class ModelNet40(Dataset):
return self.data.shape[0]
-if __name__ == "__main__":
+if __name__ == '__main__':
train = ModelNet40(1024)
-test = ModelNet40(1024, "test")
+test = ModelNet40(1024, 'test')
# for data, label in train:
# print(data.shape)
# print(label.shape)
from torch.utils.data import DataLoader
-train_loader = DataLoader(
-ModelNet40(partition="train", num_points=1024),
-num_workers=4,
-batch_size=32,
-shuffle=True,
-drop_last=True,
-)
+train_loader = DataLoader(ModelNet40(partition='train', num_points=1024), num_workers=4,
+batch_size=32, shuffle=True, drop_last=True)
for batch_idx, (data, label) in enumerate(train_loader):
print(f"batch_idx: {batch_idx} | data shape: {data.shape} | ;lable shape: {label.shape}")
-train_set = ModelNet40(partition="train", num_points=1024)
+train_set = ModelNet40(partition='train', num_points=1024)
-test_set = ModelNet40(partition="test", num_points=1024)
+test_set = ModelNet40(partition='test', num_points=1024)
print(f"train_set size {train_set.__len__()}")
print(f"test_set size {test_set.__len__()}")


@@ -1,9 +1,9 @@
import torch
import torch.nn.functional as F
def cal_loss(pred, gold, smoothing=True):
-"""Calculate cross entropy loss, apply label smoothing if needed."""
+''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing:
@@ -16,6 +16,6 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean()
else:
-loss = F.cross_entropy(pred, gold, reduction="mean")
+loss = F.cross_entropy(pred, gold, reduction='mean')
return loss
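(Editorial note: the body of the `smoothing=True` branch is collapsed in the hunk above. In both versions it builds a smoothed one-hot target and averages the negative log-likelihood, in the spirit of the standalone sketch below. The actual smoothing factor used by `cal_loss` is not visible in this hunk; `eps=0.2` is only an assumed placeholder.)

```python
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(pred, gold, eps=0.2):
    # pred: [B, num_classes] raw logits, gold: [B] integer class labels
    n_class = pred.size(1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    # keep (1 - eps) mass on the true class, spread eps over the other classes
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = F.log_softmax(pred, dim=1)
    return -(one_hot * log_prb).sum(dim=1).mean()

loss = smoothed_cross_entropy(torch.randn(4, 40), torch.tensor([0, 3, 7, 39]))
```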


@@ -1,46 +1,41 @@
-"""Usage:
-python main.py --model PointMLP --msg demo.
+"""
+Usage:
+python main.py --model PointMLP --msg demo
"""
import argparse
-import datetime
-import logging
import os
+import logging
-import models as models
+import datetime
-import numpy as np
-import sklearn.metrics as metrics
import torch
-import torch.backends.cudnn as cudnn
import torch.nn.parallel
+import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
+from torch.utils.data import DataLoader
+import models as models
+from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
from data import ModelNet40
from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.utils.data import DataLoader
+import sklearn.metrics as metrics
-from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model
+import numpy as np
def parse_args():
"""Parameters"""
-parser = argparse.ArgumentParser("training")
+parser = argparse.ArgumentParser('training')
-parser.add_argument(
-"-c",
-"--checkpoint",
-type=str,
-metavar="PATH",
-help="path to save checkpoint (default: checkpoint)",
-)
-parser.add_argument("--msg", type=str, help="message after checkpoint")
-parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
-parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]")
-parser.add_argument("--epoch", default=300, type=int, help="number of epoch in training")
-parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
-parser.add_argument("--learning_rate", default=0.1, type=float, help="learning rate in training")
-parser.add_argument("--min_lr", default=0.005, type=float, help="min lr")
-parser.add_argument("--weight_decay", type=float, default=2e-4, help="decay rate")
-parser.add_argument("--seed", type=int, help="random seed")
-parser.add_argument("--workers", default=8, type=int, help="workers")
+parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
+help='path to save checkpoint (default: checkpoint)')
+parser.add_argument('--msg', type=str, help='message after checkpoint')
+parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
+parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
+parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training')
+parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
+parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training')
+parser.add_argument('--min_lr', default=0.005, type=float, help='min lr')
+parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate')
+parser.add_argument('--seed', type=int, help='random seed')
+parser.add_argument('--workers', default=8, type=int, help='workers')
return parser.parse_args()
@@ -51,7 +46,7 @@ def main():
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
assert torch.cuda.is_available(), "Please ensure codes are executed in cuda." assert torch.cuda.is_available(), "Please ensure codes are executed in cuda."
device = "cuda" device = 'cuda'
if args.seed is not None: if args.seed is not None:
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
@@ -60,19 +55,19 @@ def main():
torch.set_printoptions(10) torch.set_printoptions(10)
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
os.environ["PYTHONHASHSEED"] = str(args.seed) os.environ['PYTHONHASHSEED'] = str(args.seed)
time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
if args.msg is None: if args.msg is None:
message = time_str message = time_str
else: else:
message = "-" + args.msg message = "-" + args.msg
args.checkpoint = "checkpoints/" + args.model + message + "-" + str(args.seed) args.checkpoint = 'checkpoints/' + args.model + message + '-' + str(args.seed)
if not os.path.isdir(args.checkpoint): if not os.path.isdir(args.checkpoint):
mkdir_p(args.checkpoint) mkdir_p(args.checkpoint)
screen_logger = logging.getLogger("Model") screen_logger = logging.getLogger("Model")
screen_logger.setLevel(logging.INFO) screen_logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(message)s") formatter = logging.Formatter('%(message)s')
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt")) file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
@@ -84,19 +79,19 @@ def main():
# Model # Model
printf(f"args: {args}") printf(f"args: {args}")
printf("==> Building model..") printf('==> Building model..')
net = models.__dict__[args.model]() net = models.__dict__[args.model]()
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == "cuda": if device == 'cuda':
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)
cudnn.benchmark = True cudnn.benchmark = True
best_test_acc = 0.0 # best test accuracy best_test_acc = 0. # best test accuracy
best_train_acc = 0.0 best_train_acc = 0.
best_test_acc_avg = 0.0 best_test_acc_avg = 0.
best_train_acc_avg = 0.0 best_train_acc_avg = 0.
best_test_loss = float("inf") best_test_loss = float("inf")
best_train_loss = float("inf") best_train_loss = float("inf")
start_epoch = 0 # start from epoch 0 or last checkpoint epoch start_epoch = 0 # start from epoch 0 or last checkpoint epoch
@@ -104,49 +99,30 @@ def main():
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")): if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
save_args(args) save_args(args)
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model) logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
logger.set_names( logger.set_names(["Epoch-Num", 'Learning-Rate',
[ 'Train-Loss', 'Train-acc-B', 'Train-acc',
"Epoch-Num", 'Valid-Loss', 'Valid-acc-B', 'Valid-acc'])
"Learning-Rate",
"Train-Loss",
"Train-acc-B",
"Train-acc",
"Valid-Loss",
"Valid-acc-B",
"Valid-acc",
],
)
else: else:
printf(f"Resuming last checkpoint from {args.checkpoint}") printf(f"Resuming last checkpoint from {args.checkpoint}")
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth") checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint["net"]) net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint["epoch"] start_epoch = checkpoint['epoch']
best_test_acc = checkpoint["best_test_acc"] best_test_acc = checkpoint['best_test_acc']
best_train_acc = checkpoint["best_train_acc"] best_train_acc = checkpoint['best_train_acc']
best_test_acc_avg = checkpoint["best_test_acc_avg"] best_test_acc_avg = checkpoint['best_test_acc_avg']
best_train_acc_avg = checkpoint["best_train_acc_avg"] best_train_acc_avg = checkpoint['best_train_acc_avg']
best_test_loss = checkpoint["best_test_loss"] best_test_loss = checkpoint['best_test_loss']
best_train_loss = checkpoint["best_train_loss"] best_train_loss = checkpoint['best_train_loss']
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True) logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
optimizer_dict = checkpoint["optimizer"] optimizer_dict = checkpoint['optimizer']
printf("==> Preparing data..") printf('==> Preparing data..')
train_loader = DataLoader( train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=args.workers,
ModelNet40(partition="train", num_points=args.num_points), batch_size=args.batch_size, shuffle=True, drop_last=True)
num_workers=args.workers, test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=args.workers,
batch_size=args.batch_size, batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
shuffle=True,
drop_last=True,
)
test_loader = DataLoader(
ModelNet40(partition="test", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size // 2,
shuffle=False,
drop_last=False,
)
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
if optimizer_dict is not None: if optimizer_dict is not None:
@@ -154,7 +130,7 @@ def main():
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1) scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1)
for epoch in range(start_epoch, args.epoch): for epoch in range(start_epoch, args.epoch):
printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"])) printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"} train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
test_out = validate(net, test_loader, criterion, device) test_out = validate(net, test_loader, criterion, device)
scheduler.step() scheduler.step()
@@ -173,46 +149,31 @@ def main():
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
save_model( save_model(
net, net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
epoch,
path=args.checkpoint,
acc=test_out["acc"],
is_best=is_best,
best_test_acc=best_test_acc, # best test accuracy best_test_acc=best_test_acc, # best test accuracy
best_train_acc=best_train_acc, best_train_acc=best_train_acc,
best_test_acc_avg=best_test_acc_avg, best_test_acc_avg=best_test_acc_avg,
best_train_acc_avg=best_train_acc_avg, best_train_acc_avg=best_train_acc_avg,
best_test_loss=best_test_loss, best_test_loss=best_test_loss,
best_train_loss=best_train_loss, best_train_loss=best_train_loss,
optimizer=optimizer.state_dict(), optimizer=optimizer.state_dict()
)
logger.append(
[
epoch,
optimizer.param_groups[0]["lr"],
train_out["loss"],
train_out["acc_avg"],
train_out["acc"],
test_out["loss"],
test_out["acc_avg"],
test_out["acc"],
],
) )
logger.append([epoch, optimizer.param_groups[0]['lr'],
train_out["loss"], train_out["acc_avg"], train_out["acc"],
test_out["loss"], test_out["acc_avg"], test_out["acc"]])
printf( printf(
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s", f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s")
)
printf( printf(
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% " f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n", f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n")
)
logger.close() logger.close()
printf("++++++++" * 2 + "Final results" + "++++++++" * 2) printf(f"++++++++" * 2 + "Final results" + "++++++++" * 2)
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++") printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++") printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++") printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++") printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
printf("++++++++" * 5) printf(f"++++++++" * 5)
def train(net, trainloader, optimizer, criterion, device): def train(net, trainloader, optimizer, criterion, device):
@@ -241,21 +202,17 @@ def train(net, trainloader, optimizer, criterion, device):
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(trainloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
train_true = np.concatenate(train_true) train_true = np.concatenate(train_true)
train_pred = np.concatenate(train_pred) train_pred = np.concatenate(train_pred)
return { return {
"loss": float("%.3f" % (train_loss / (batch_idx + 1))), "loss": float("%.3f" % (train_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(train_true, train_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(train_true, train_pred))),
"time": time_cost, "time": time_cost
} }
@@ -279,23 +236,19 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy()) test_pred.append(preds.detach().cpu().numpy())
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
return { return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))), "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost, "time": time_cost
} }
if __name__ == "__main__": if __name__ == '__main__':
main() main()


@@ -1 +1,3 @@
+from __future__ import absolute_import
from .pointmlp import pointMLP, pointMLPElite


@@ -1,32 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch import einsum
# from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils
def get_activation(activation):
-if activation.lower() == "gelu":
+if activation.lower() == 'gelu':
return nn.GELU()
-elif activation.lower() == "rrelu":
+elif activation.lower() == 'rrelu':
return nn.RReLU(inplace=True)
-elif activation.lower() == "selu":
+elif activation.lower() == 'selu':
return nn.SELU(inplace=True)
-elif activation.lower() == "silu":
+elif activation.lower() == 'silu':
return nn.SiLU(inplace=True)
-elif activation.lower() == "hardswish":
+elif activation.lower() == 'hardswish':
return nn.Hardswish(inplace=True)
-elif activation.lower() == "leakyrelu":
+elif activation.lower() == 'leakyrelu':
return nn.LeakyReLU(inplace=True)
else:
return nn.ReLU(inplace=True)
def square_distance(src, dst):
-"""Calculate Euclid distance between each two points.
+"""
+Calculate Euclid distance between each two points.
-src^T * dst = xn * xm + yn * ym + zn * zm;
+src^T * dst = xn * xm + yn * ym + zn * zm
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@@ -35,7 +38,7 @@ def square_distance(src, dst):
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
-dist: per-point square distance, [B, N, M].
+dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
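(Editorial aside, not part of the diff: the identity in the docstring above, squared pairwise distance = sum(src^2) + sum(dst^2) - 2 * src . dst^T, is what the elided function body computes. It can be sanity-checked against `torch.cdist` with a few standalone lines.)

```python
import torch

src = torch.randn(2, 5, 3)   # [B, N, C]
dst = torch.randn(2, 7, 3)   # [B, M, C]

dist = (src ** 2).sum(-1).unsqueeze(-1) \
     + (dst ** 2).sum(-1).unsqueeze(1) \
     - 2 * torch.matmul(src, dst.transpose(1, 2))   # [B, N, M], squared distances

assert torch.allclose(dist, torch.cdist(src, dst) ** 2, atol=1e-4)
```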
@@ -46,12 +49,12 @@
def index_points(points, idx): def index_points(points, idx):
"""Input: """
Input:
points: input points data, [B, N, C] points: input points data, [B, N, C]
idx: sample index data, [B, S]. idx: sample index data, [B, S]
Return: Return:
new_points:, indexed points data, [B, S, C]. new_points:, indexed points data, [B, S, C]
""" """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
@@ -60,15 +63,17 @@ def index_points(points, idx):
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
return points[batch_indices, idx, :] new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint): def farthest_point_sample(xyz, npoint):
"""Input: """
Input:
xyz: pointcloud data, [B, N, 3] xyz: pointcloud data, [B, N, 3]
npoint: number of samples npoint: number of samples
Return: Return:
centroids: sampled pointcloud index, [B, npoint]. centroids: sampled pointcloud index, [B, npoint]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@@ -86,14 +91,14 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz): def query_ball_point(radius, nsample, xyz, new_xyz):
"""Input: """
Input:
radius: local region radius radius: local region radius
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, 3] xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]. new_xyz: query points, [B, S, 3]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@@ -109,13 +114,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz): def knn_point(nsample, xyz, new_xyz):
"""Input: """
Input:
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, C] xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C]. new_xyz: query points, [B, S, C]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
@@ -124,12 +129,13 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module): class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs): def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d] """
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number :param groups: groups number
:param kneighbors: k-nerighbors :param kneighbors: k-nerighbors
:param kwargs: others. :param kwargs: others
""" """
super().__init__() super(LocalGrouper, self).__init__()
self.groups = groups self.groups = groups
self.kneighbors = kneighbors self.kneighbors = kneighbors
self.use_xyz = use_xyz self.use_xyz = use_xyz
@@ -138,7 +144,7 @@ class LocalGrouper(nn.Module):
else: else:
self.normalize = None self.normalize = None
if self.normalize not in ["center", "anchor"]: if self.normalize not in ["center", "anchor"]:
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].") print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None self.normalize = None
if self.normalize is not None: if self.normalize is not None:
add_channel=3 if self.use_xyz else 0 add_channel=3 if self.use_xyz else 0
@@ -168,11 +174,7 @@ class LocalGrouper(nn.Module):
if self.normalize =="anchor": if self.normalize =="anchor":
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
std = ( std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
.unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = (grouped_points-mean)/(std + 1e-5)
grouped_points = self.affine_alpha*grouped_points + self.affine_beta grouped_points = self.affine_alpha*grouped_points + self.affine_beta
@@ -181,13 +183,13 @@ class LocalGrouper(nn.Module):
class ConvBNReLU1D(nn.Module): class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"): def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
super().__init__() super(ConvBNReLU1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
self.act, self.act
) )
def forward(self, x): def forward(self, x):
@@ -195,43 +197,30 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module): class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"): def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
super().__init__() super(ConvBNReLURes1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net1 = nn.Sequential( self.net1 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
in_channels=channel, kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)), nn.BatchNorm1d(int(channel * res_expansion)),
self.act, self.act
) )
if groups > 1: if groups > 1:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
self.act, self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=channel, out_channels=channel,
kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
) )
else: else:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, bias=bias),
out_channels=channel, nn.BatchNorm1d(channel)
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
) )
def forward(self, x): def forward(self, x):
@@ -239,34 +228,21 @@ class ConvBNReLURes1D(nn.Module):
class PreExtraction(nn.Module): class PreExtraction(nn.Module):
def __init__( def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
self, activation='relu', use_xyz=True):
channels, """
out_channels, input: [b,g,k,d]: output:[b,d,g]
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PreExtraction, self).__init__()
in_channels = 3+2*channels if use_xyz else 2*channels in_channels = 3+2*channels if use_xyz else 2*channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D( ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
out_channels, bias=bias, activation=activation)
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@@ -278,20 +254,22 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size() batch_size, _, _ = x.size()
x = self.operation(x) # [b, d, k] x = self.operation(x) # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
return x.reshape(b, n, -1).permute(0, 2, 1) x = x.reshape(b, n, -1).permute(0, 2, 1)
return x
class PosExtraction(nn.Module): class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"): def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
"""input[b,d,g]; output[b,d,g] """
input[b,d,g]; output[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PosExtraction, self).__init__()
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation), ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@@ -300,32 +278,17 @@ class PosExtraction(nn.Module):
class Model(nn.Module): class Model(nn.Module):
def __init__( def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
self, activation="relu", bias=True, use_xyz=True, normalize="center",
points=1024, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
class_num=40, k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs):
embed_dim=64, super(Model, self).__init__()
groups=1,
res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="center",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[2, 2, 2, 2],
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks) self.stages = len(pre_blocks)
self.class_num = class_num self.class_num = class_num
self.points = points self.points = points
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
assert ( assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion) "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList() self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList() self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList() self.pos_blocks_list = nn.ModuleList()
@@ -342,26 +305,13 @@ class Model(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
self.local_grouper_list.append(local_grouper) self.local_grouper_list.append(local_grouper)
# append pre_block_list # append pre_block_list
pre_block_module = PreExtraction( pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
last_channel,
out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion, res_expansion=res_expansion,
bias=bias, bias=bias, activation=activation, use_xyz=use_xyz)
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module) self.pre_blocks_list.append(pre_block_module)
# append pos_block_list # append pos_block_list
pos_block_module = PosExtraction( pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
out_channel, res_expansion=res_expansion, bias=bias, activation=activation)
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module) self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel last_channel = out_channel
@@ -376,7 +326,7 @@ class Model(nn.Module):
nn.BatchNorm1d(256), nn.BatchNorm1d(256),
self.act, self.act,
nn.Dropout(0.5), nn.Dropout(0.5),
nn.Linear(256, self.class_num), nn.Linear(256, self.class_num)
) )
def forward(self, x): def forward(self, x):
@@ -390,59 +340,29 @@ class Model(nn.Module):
x = self.pos_blocks_list[i](x) # [b,d,g] x = self.pos_blocks_list[i](x) # [b,d,g]
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1) x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
return self.classifier(x) x = self.classifier(x)
return x
def pointMLP(num_classes=40, **kwargs) -> Model: def pointMLP(num_classes=40, **kwargs) -> Model:
return Model( return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1.0,
points=1024, activation="relu", bias=False, use_xyz=False, normalize="anchor",
class_num=num_classes, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
embed_dim=64, k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)
groups=1,
res_expansion=1.0,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
def pointMLPElite(num_classes=40, **kwargs) -> Model: def pointMLPElite(num_classes=40, **kwargs) -> Model:
return Model( return Model(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25,
points=1024, activation="relu", bias=False, use_xyz=False, normalize="anchor",
class_num=num_classes, dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
embed_dim=32, k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
groups=1,
res_expansion=0.25,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 1],
pre_blocks=[1, 1, 2, 1],
pos_blocks=[1, 1, 2, 1],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
if __name__ == "__main__":
data = torch.rand(2, 3, 1024).cuda()
print(data.shape)
if __name__ == '__main__':
data = torch.rand(2, 3, 1024)
print("===> testing pointMLP ...") print("===> testing pointMLP ...")
model = pointMLP().cuda() model = pointMLP()
out = model(data) out = model(data)
print(out.shape) print(out.shape)
print("===> testing pointMLPElite ...")
model = pointMLPElite().cuda()
out = model(data)
print(out.shape)


@@ -1,81 +1,71 @@
"""python test.py --model pointMLP --msg 20220209053148-404.""" """
python test.py --model pointMLP --msg 20220209053148-404
"""
import argparse import argparse
import datetime
import os import os
import datetime
import models as models
import numpy as np
import sklearn.metrics as metrics
import torch import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim import torch.optim
import torch.utils.data import torch.utils.data
import torch.utils.data.distributed import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from utils import progress_bar import models as models
from utils import progress_bar, IOStream
from data import ModelNet40
import sklearn.metrics as metrics
from helper import cal_loss
import numpy as np
import torch.nn.functional as F
model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name])) model_names = sorted(name for name in models.__dict__
if callable(models.__dict__[name]))
def parse_args(): def parse_args():
"""Parameters""" """Parameters"""
parser = argparse.ArgumentParser("training") parser = argparse.ArgumentParser('training')
parser.add_argument( parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
"-c", help='path to save checkpoint (default: checkpoint)')
"--checkpoint", parser.add_argument('--msg', type=str, help='message after checkpoint')
type=str, parser.add_argument('--batch_size', type=int, default=16, help='batch size in training')
metavar="PATH", parser.add_argument('--model', default='pointMLP', help='model name [default: pointnet_cls]')
help="path to save checkpoint (default: checkpoint)", parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
) parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
parser.add_argument("--msg", type=str, help="message after checkpoint")
parser.add_argument("--batch_size", type=int, default=16, help="batch size in training")
parser.add_argument("--model", default="pointMLP", help="model name [default: pointnet_cls]")
parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
return parser.parse_args() return parser.parse_args()
def main(): def main():
args = parse_args() args = parse_args()
print(f"args: {args}") print(f"args: {args}")
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda" device = 'cuda'
else: else:
device = "cpu" device = 'cpu'
print(f"==> Using device: {device}") print(f"==> Using device: {device}")
# if args.msg is None: if args.msg is None:
# message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
# else: else:
# message = "-"+args.msg message = "-"+args.msg
# args.checkpoint = 'checkpoints/' + args.model + message args.checkpoint = 'checkpoints/' + args.model + message
if args.checkpoint is not None:
print(f"==> Using checkpoint: {args.checkpoint}")
print("==> Preparing data..") print('==> Preparing data..')
test_loader = DataLoader( test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
ModelNet40(partition="test", num_points=args.num_points), batch_size=args.batch_size, shuffle=False, drop_last=False)
num_workers=4,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
)
# Model # Model
print("==> Building model..") print('==> Building model..')
net = models.__dict__[args.model]() net = models.__dict__[args.model]()
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
# checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth') checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu")) checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == "cuda": if device == 'cuda':
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)
cudnn.benchmark = True cudnn.benchmark = True
net.load_state_dict(checkpoint["net"]) net.load_state_dict(checkpoint['net'])
test_out = validate(net, test_loader, criterion, device) test_out = validate(net, test_loader, criterion, device)
print(f"Vanilla out: {test_out}") print(f"Vanilla out: {test_out}")
@@ -101,23 +91,19 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy()) test_pred.append(preds.detach().cpu().numpy())
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
return { return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))), "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost, "time": time_cost
} }
if __name__ == "__main__": if __name__ == '__main__':
main() main()


@ -1,4 +1,5 @@
"""Useful utils.""" """Useful utils
from .logger import * """
from .misc import * from .misc import *
from .logger import *
from .progress.progress.bar import Bar as Bar from .progress.progress.bar import Bar as Bar


@ -1,50 +1,48 @@
# A simple torch style logger # A simple torch style logger
# (C) Wei YANG 2017 # (C) Wei YANG 2017
from __future__ import absolute_import
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import os
import sys
import numpy as np import numpy as np
__all__ = ["Logger", "LoggerMonitor", "savefig"] __all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None): def savefig(fname, dpi=None):
dpi = 150 if dpi is None else dpi dpi = 150 if dpi == None else dpi
plt.savefig(fname, dpi=dpi) plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None): def plot_overlap(logger, names=None):
names = logger.names if names is None else names names = logger.names if names == None else names
numbers = logger.numbers numbers = logger.numbers
for _, name in enumerate(names): for _, name in enumerate(names):
x = np.arange(len(numbers[name])) x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name])) plt.plot(x, np.asarray(numbers[name]))
return [logger.title + "(" + name + ")" for name in names] return [logger.title + '(' + name + ')' for name in names]
class Logger:
"""Save training process to log file with simple plot function."""
class Logger(object):
'''Save training process to log file with simple plot function.'''
def __init__(self, fpath, title=None, resume=False): def __init__(self, fpath, title=None, resume=False):
self.file = None self.file = None
self.resume = resume self.resume = resume
self.title = "" if title is None else title self.title = '' if title == None else title
if fpath is not None: if fpath is not None:
if resume: if resume:
self.file = open(fpath) self.file = open(fpath, 'r')
name = self.file.readline() name = self.file.readline()
self.names = name.rstrip().split("\t") self.names = name.rstrip().split('\t')
self.numbers = {} self.numbers = {}
for _, name in enumerate(self.names): for _, name in enumerate(self.names):
self.numbers[name] = [] self.numbers[name] = []
for numbers in self.file: for numbers in self.file:
numbers = numbers.rstrip().split("\t") numbers = numbers.rstrip().split('\t')
for i in range(0, len(numbers)): for i in range(0, len(numbers)):
self.numbers[self.names[i]].append(numbers[i]) self.numbers[self.names[i]].append(numbers[i])
self.file.close() self.file.close()
self.file = open(fpath, "a") self.file = open(fpath, 'a')
else: else:
self.file = open(fpath, "w") self.file = open(fpath, 'w')
def set_names(self, names): def set_names(self, names):
if self.resume: if self.resume:
@ -54,39 +52,38 @@ class Logger:
self.names = names self.names = names
for _, name in enumerate(self.names): for _, name in enumerate(self.names):
self.file.write(name) self.file.write(name)
self.file.write("\t") self.file.write('\t')
self.numbers[name] = [] self.numbers[name] = []
self.file.write("\n") self.file.write('\n')
self.file.flush() self.file.flush()
def append(self, numbers): def append(self, numbers):
assert len(self.names) == len(numbers), "Numbers do not match names" assert len(self.names) == len(numbers), 'Numbers do not match names'
for index, num in enumerate(numbers): for index, num in enumerate(numbers):
self.file.write(f"{num:.6f}") self.file.write("{0:.6f}".format(num))
self.file.write("\t") self.file.write('\t')
self.numbers[self.names[index]].append(num) self.numbers[self.names[index]].append(num)
self.file.write("\n") self.file.write('\n')
self.file.flush() self.file.flush()
def plot(self, names=None): def plot(self, names=None):
names = self.names if names is None else names names = self.names if names == None else names
numbers = self.numbers numbers = self.numbers
for _, name in enumerate(names): for _, name in enumerate(names):
x = np.arange(len(numbers[name])) x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name])) plt.plot(x, np.asarray(numbers[name]))
plt.legend([self.title + "(" + name + ")" for name in names]) plt.legend([self.title + '(' + name + ')' for name in names])
plt.grid(True) plt.grid(True)
def close(self): def close(self):
if self.file is not None: if self.file is not None:
self.file.close() self.file.close()
class LoggerMonitor(object):
class LoggerMonitor: '''Load and visualize multiple logs.'''
"""Load and visualize multiple logs."""
def __init__ (self, paths): def __init__ (self, paths):
"""Paths is a distionary with {name:filepath} pair.""" '''paths is a distionary with {name:filepath} pair'''
self.loggers = [] self.loggers = []
for title, path in paths.items(): for title, path in paths.items():
logger = Logger(path, title=title, resume=True) logger = Logger(path, title=title, resume=True)
@ -98,11 +95,10 @@ class LoggerMonitor:
legend_text = [] legend_text = []
for logger in self.loggers: for logger in self.loggers:
legend_text += plot_overlap(logger, names) legend_text += plot_overlap(logger, names)
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0) plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.grid(True) plt.grid(True)
if __name__ == '__main__':
if __name__ == "__main__":
# # Example # # Example
# logger = Logger('test.txt') # logger = Logger('test.txt')
# logger.set_names(['Train loss', 'Valid loss','Test loss']) # logger.set_names(['Train loss', 'Valid loss','Test loss'])
@ -119,13 +115,13 @@ if __name__ == "__main__":
# Example: logger monitor # Example: logger monitor
paths = { paths = {
"resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt", 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
"resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt", 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
"resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt", 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
} }
field = ["Valid Acc."] field = ['Valid Acc.']
monitor = LoggerMonitor(paths) monitor = LoggerMonitor(paths)
monitor.plot(names=field) monitor.plot(names=field)
savefig("test.eps") savefig('test.eps')


@ -1,43 +1,36 @@
"""Some helper functions for PyTorch, including: '''Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset. - get_mean_and_std: calculate the mean and std value of dataset.
- msr_init: net parameter initialization. - msr_init: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress. - progress_bar: progress bar mimic xlua.progress.
""" '''
import errno import errno
import os import os
import random
import shutil
import sys import sys
import time import time
import math
import numpy as np
import torch import torch
import torch.nn as nn import shutil
import numpy as np
import random
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.init as init
__all__ = [
"get_mean_and_std", import torch.nn as nn
"init_params", import torch.nn.init as init
"mkdir_p", from torch.autograd import Variable
"AverageMeter",
"progress_bar", __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter',
"save_model", 'progress_bar','save_model',"save_args","set_seed", "IOStream", "cal_loss"]
"save_args",
"set_seed",
"IOStream",
"cal_loss",
]
def get_mean_and_std(dataset): def get_mean_and_std(dataset):
"""Compute the mean and std value of dataset.""" '''Compute the mean and std value of dataset.'''
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
mean = torch.zeros(3) mean = torch.zeros(3)
std = torch.zeros(3) std = torch.zeros(3)
print("==> Computing mean and std..") print('==> Computing mean and std..')
for inputs, _targets in dataloader: for inputs, targets in dataloader:
for i in range(3): for i in range(3):
mean[i] += inputs[:,i,:,:].mean() mean[i] += inputs[:,i,:,:].mean()
std[i] += inputs[:,i,:,:].std() std[i] += inputs[:,i,:,:].std()
@ -45,12 +38,11 @@ def get_mean_and_std(dataset):
std.div_(len(dataset)) std.div_(len(dataset))
return mean, std return mean, std
def init_params(net): def init_params(net):
"""Init layer parameters.""" '''Init layer parameters.'''
for m in net.modules(): for m in net.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode="fan_out") init.kaiming_normal(m.weight, mode='fan_out')
if m.bias: if m.bias:
init.constant(m.bias, 0) init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
@ -61,9 +53,8 @@ def init_params(net):
if m.bias: if m.bias:
init.constant(m.bias, 0) init.constant(m.bias, 0)
def mkdir_p(path): def mkdir_p(path):
"""Make dir if not exist.""" '''make dir if not exist'''
try: try:
os.makedirs(path) os.makedirs(path)
except OSError as exc: # Python >2.5 except OSError as exc: # Python >2.5
@ -72,12 +63,10 @@ def mkdir_p(path):
else: else:
raise raise
class AverageMeter(object):
class AverageMeter:
"""Computes and stores the average and current value """Computes and stores the average and current value
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262. Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
""" """
def __init__(self): def __init__(self):
self.reset() self.reset()
@ -94,11 +83,10 @@ class AverageMeter:
self.avg = self.sum / self.count self.avg = self.sum / self.count
TOTAL_BAR_LENGTH = 65.0
TOTAL_BAR_LENGTH = 65.
last_time = time.time() last_time = time.time()
begin_time = last_time begin_time = last_time
def progress_bar(current, total, msg=None): def progress_bar(current, total, msg=None):
global last_time, begin_time global last_time, begin_time
if current == 0: if current == 0:
@ -107,13 +95,13 @@ def progress_bar(current, total, msg=None):
cur_len = int(TOTAL_BAR_LENGTH*current/total) cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(" [") sys.stdout.write(' [')
for _i in range(cur_len): for i in range(cur_len):
sys.stdout.write("=") sys.stdout.write('=')
sys.stdout.write(">") sys.stdout.write('>')
for _i in range(rest_len): for i in range(rest_len):
sys.stdout.write(".") sys.stdout.write('.')
sys.stdout.write("]") sys.stdout.write(']')
cur_time = time.time() cur_time = time.time()
step_time = cur_time - last_time step_time = cur_time - last_time
@ -121,12 +109,12 @@ def progress_bar(current, total, msg=None):
tot_time = cur_time - begin_time tot_time = cur_time - begin_time
L = [] L = []
L.append(" Step: %s" % format_time(step_time)) L.append(' Step: %s' % format_time(step_time))
L.append(" | Tot: %s" % format_time(tot_time)) L.append(' | Tot: %s' % format_time(tot_time))
if msg: if msg:
L.append(" | " + msg) L.append(' | ' + msg)
msg = "".join(L) msg = ''.join(L)
sys.stdout.write(msg) sys.stdout.write(msg)
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): # for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
# sys.stdout.write(' ') # sys.stdout.write(' ')
@ -134,12 +122,12 @@ def progress_bar(current, total, msg=None):
# Go back to the center of the bar. # Go back to the center of the bar.
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): # for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
# sys.stdout.write('\b') # sys.stdout.write('\b')
sys.stdout.write(" %d/%d " % (current + 1, total)) sys.stdout.write(' %d/%d ' % (current+1, total))
if current < total-1: if current < total-1:
sys.stdout.write("\r") sys.stdout.write('\r')
else: else:
sys.stdout.write("\n") sys.stdout.write('\n')
sys.stdout.flush() sys.stdout.flush()
@ -154,54 +142,56 @@ def format_time(seconds):
seconds = seconds - secondsf seconds = seconds - secondsf
millis = int(seconds*1000) millis = int(seconds*1000)
f = "" f = ''
i = 1 i = 1
if days > 0: if days > 0:
f += str(days) + "D" f += str(days) + 'D'
i += 1 i += 1
if hours > 0 and i <= 2: if hours > 0 and i <= 2:
f += str(hours) + "h" f += str(hours) + 'h'
i += 1 i += 1
if minutes > 0 and i <= 2: if minutes > 0 and i <= 2:
f += str(minutes) + "m" f += str(minutes) + 'm'
i += 1 i += 1
if secondsf > 0 and i <= 2: if secondsf > 0 and i <= 2:
f += str(secondsf) + "s" f += str(secondsf) + 's'
i += 1 i += 1
if millis > 0 and i <= 2: if millis > 0 and i <= 2:
f += str(millis) + "ms" f += str(millis) + 'ms'
i += 1 i += 1
if f == "": if f == '':
f = "0ms" f = '0ms'
return f return f
def save_model(net, epoch, path, acc, is_best, **kwargs): def save_model(net, epoch, path, acc, is_best, **kwargs):
state = { state = {
"net": net.state_dict(), 'net': net.state_dict(),
"epoch": epoch, 'epoch': epoch,
"acc": acc, 'acc': acc
} }
for key, value in kwargs.items(): for key, value in kwargs.items():
state[key] = value state[key] = value
filepath = os.path.join(path, "last_checkpoint.pth") filepath = os.path.join(path, "last_checkpoint.pth")
torch.save(state, filepath) torch.save(state, filepath)
if is_best: if is_best:
shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth")) shutil.copyfile(filepath, os.path.join(path, 'best_checkpoint.pth'))
def save_args(args): def save_args(args):
file = open(os.path.join(args.checkpoint, "args.txt"), "w") file = open(os.path.join(args.checkpoint, 'args.txt'), "w")
for k, v in vars(args).items(): for k, v in vars(args).items():
file.write(f"{k}:\t {v}\n") file.write(f"{k}:\t {v}\n")
file.close() file.close()
def set_seed(seed=None): def set_seed(seed=None):
if seed is None: if seed is None:
return return
random.seed(seed) random.seed(seed)
os.environ["PYTHONHASHSEED"] = "%s" % seed os.environ['PYTHONHASHSEED'] = ("%s" % seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
@ -210,14 +200,15 @@ def set_seed(seed=None):
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
# create a file and write the text into it # create a file and write the text into it
class IOStream: class IOStream():
def __init__(self, path): def __init__(self, path):
self.f = open(path, "a") self.f = open(path, 'a')
def cprint(self, text): def cprint(self, text):
print(text) print(text)
self.f.write(text + "\n") self.f.write(text+'\n')
self.f.flush() self.f.flush()
def close(self): def close(self):
@ -225,7 +216,8 @@ class IOStream:
def cal_loss(pred, gold, smoothing=True): def cal_loss(pred, gold, smoothing=True):
"""Calculate cross entropy loss, apply label smoothing if needed.""" ''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1) gold = gold.contiguous().view(-1)
if smoothing: if smoothing:
@ -238,6 +230,6 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gold, reduction="mean") loss = F.cross_entropy(pred, gold, reduction='mean')
return loss return loss
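# Illustrative, self-contained restatement of the label-smoothed objective that
# cal_loss computes when smoothing=True. The smoothing factor eps=0.2 is an
# assumption here and the helper name smoothed_ce is ours: the target keeps
# 1 - eps on the true class and spreads eps / (K - 1) over the remaining classes.
import torch
import torch.nn.functional as F

def smoothed_ce(pred, gold, eps=0.2):
    n_class = pred.size(1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = F.log_softmax(pred, dim=1)
    return -(one_hot * log_prb).sum(dim=1).mean()

# e.g. smoothed_ce(torch.randn(4, 40), torch.randint(0, 40, (4,))) for ModelNet40-sized logits.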


@ -12,6 +12,7 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import division
from collections import deque from collections import deque
from datetime import timedelta from datetime import timedelta
@ -19,10 +20,11 @@ from math import ceil
from sys import stderr from sys import stderr
from time import time from time import time
__version__ = "1.3"
__version__ = '1.3'
class Infinite: class Infinite(object):
file = stderr file = stderr
sma_window = 10 # Simple Moving Average window sma_window = 10 # Simple Moving Average window
@ -36,7 +38,7 @@ class Infinite:
setattr(self, key, val) setattr(self, key, val)
def __getitem__(self, key): def __getitem__(self, key):
if key.startswith("_"): if key.startswith('_'):
return None return None
return getattr(self, key, None) return getattr(self, key, None)
@ -81,8 +83,8 @@ class Infinite:
class Progress(Infinite): class Progress(Infinite):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super(Progress, self).__init__(*args, **kwargs)
self.max = kwargs.get("max", 100) self.max = kwargs.get('max', 100)
@property @property
def eta(self): def eta(self):


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,18 +14,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Progress from . import Progress
from .helpers import WritelnMixin from .helpers import WritelnMixin
class Bar(WritelnMixin, Progress): class Bar(WritelnMixin, Progress):
width = 32 width = 32
message = "" message = ''
suffix = "%(index)d/%(max)d" suffix = '%(index)d/%(max)d'
bar_prefix = " |" bar_prefix = ' |'
bar_suffix = "| " bar_suffix = '| '
empty_fill = " " empty_fill = ' '
fill = "#" fill = '#'
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -34,30 +37,31 @@ class Bar(WritelnMixin, Progress):
bar = self.fill * filled_length bar = self.fill * filled_length
empty = self.empty_fill * empty_length empty = self.empty_fill * empty_length
suffix = self.suffix % self suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix]) line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
suffix])
self.writeln(line) self.writeln(line)
class ChargingBar(Bar): class ChargingBar(Bar):
suffix = "%(percent)d%%" suffix = '%(percent)d%%'
bar_prefix = " " bar_prefix = ' '
bar_suffix = " " bar_suffix = ' '
empty_fill = "" empty_fill = ''
fill = "" fill = ''
class FillingSquaresBar(ChargingBar): class FillingSquaresBar(ChargingBar):
empty_fill = "" empty_fill = ''
fill = "" fill = ''
class FillingCirclesBar(ChargingBar): class FillingCirclesBar(ChargingBar):
empty_fill = "" empty_fill = ''
fill = "" fill = ''
class IncrementalBar(Bar): class IncrementalBar(Bar):
phases = (" ", "", "", "", "", "", "", "", "") phases = (' ', '', '', '', '', '', '', '', '')
def update(self): def update(self):
nphases = len(self.phases) nphases = len(self.phases)
@ -68,16 +72,17 @@ class IncrementalBar(Bar):
message = self.message % self message = self.message % self
bar = self.phases[-1] * nfull bar = self.phases[-1] * nfull
current = self.phases[phase] if phase > 0 else "" current = self.phases[phase] if phase > 0 else ''
empty = self.empty_fill * max(0, nempty - len(current)) empty = self.empty_fill * max(0, nempty - len(current))
suffix = self.suffix % self suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix]) line = ''.join([message, self.bar_prefix, bar, current, empty,
self.bar_suffix, suffix])
self.writeln(line) self.writeln(line)
class PixelBar(IncrementalBar): class PixelBar(IncrementalBar):
phases = ("", "", "", "", "", "", "", "") phases = ('', '', '', '', '', '', '', '')
class ShadyBar(IncrementalBar): class ShadyBar(IncrementalBar):
phases = (" ", "", "", "", "") phases = (' ', '', '', '', '')


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,12 +14,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Infinite, Progress from . import Infinite, Progress
from .helpers import WriteMixin from .helpers import WriteMixin
class Counter(WriteMixin, Infinite): class Counter(WriteMixin, Infinite):
message = "" message = ''
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -32,7 +35,7 @@ class Countdown(WriteMixin, Progress):
class Stack(WriteMixin, Progress): class Stack(WriteMixin, Progress):
phases = (" ", "", "", "", "", "", "", "", "") phases = (' ', '', '', '', '', '', '', '', '')
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -42,4 +45,4 @@ class Stack(WriteMixin, Progress):
class Pie(Stack): class Pie(Stack):
phases = ("", "", "", "", "") phases = ('', '', '', '', '')


@ -12,76 +12,78 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import print_function
HIDE_CURSOR = "\x1b[?25l"
SHOW_CURSOR = "\x1b[?25h"
class WriteMixin: HIDE_CURSOR = '\x1b[?25l'
SHOW_CURSOR = '\x1b[?25h'
class WriteMixin(object):
hide_cursor = False hide_cursor = False
def __init__(self, message=None, **kwargs): def __init__(self, message=None, **kwargs):
super().__init__(**kwargs) super(WriteMixin, self).__init__(**kwargs)
self._width = 0 self._width = 0
if message: if message:
self.message = message self.message = message
if self.file.isatty(): if self.file.isatty():
if self.hide_cursor: if self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file) print(HIDE_CURSOR, end='', file=self.file)
print(self.message, end="", file=self.file) print(self.message, end='', file=self.file)
self.file.flush() self.file.flush()
def write(self, s): def write(self, s):
if self.file.isatty(): if self.file.isatty():
b = "\b" * self._width b = '\b' * self._width
c = s.ljust(self._width) c = s.ljust(self._width)
print(b + c, end="", file=self.file) print(b + c, end='', file=self.file)
self._width = max(self._width, len(s)) self._width = max(self._width, len(s))
self.file.flush() self.file.flush()
def finish(self): def finish(self):
if self.file.isatty() and self.hide_cursor: if self.file.isatty() and self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file) print(SHOW_CURSOR, end='', file=self.file)
class WritelnMixin: class WritelnMixin(object):
hide_cursor = False hide_cursor = False
def __init__(self, message=None, **kwargs): def __init__(self, message=None, **kwargs):
super().__init__(**kwargs) super(WritelnMixin, self).__init__(**kwargs)
if message: if message:
self.message = message self.message = message
if self.file.isatty() and self.hide_cursor: if self.file.isatty() and self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file) print(HIDE_CURSOR, end='', file=self.file)
def clearln(self): def clearln(self):
if self.file.isatty(): if self.file.isatty():
print("\r\x1b[K", end="", file=self.file) print('\r\x1b[K', end='', file=self.file)
def writeln(self, line): def writeln(self, line):
if self.file.isatty(): if self.file.isatty():
self.clearln() self.clearln()
print(line, end="", file=self.file) print(line, end='', file=self.file)
self.file.flush() self.file.flush()
def finish(self): def finish(self):
if self.file.isatty(): if self.file.isatty():
print(file=self.file) print(file=self.file)
if self.hide_cursor: if self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file) print(SHOW_CURSOR, end='', file=self.file)
from signal import SIGINT, signal from signal import signal, SIGINT
from sys import exit from sys import exit
class SigIntMixin: class SigIntMixin(object):
"""Registers a signal handler that calls finish on SIGINT.""" """Registers a signal handler that calls finish on SIGINT"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super(SigIntMixin, self).__init__(*args, **kwargs)
signal(SIGINT, self._sigint_handler) signal(SIGINT, self._sigint_handler)
def _sigint_handler(self, signum, frame): def _sigint_handler(self, signum, frame):


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,13 +14,14 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Infinite from . import Infinite
from .helpers import WriteMixin from .helpers import WriteMixin
class Spinner(WriteMixin, Infinite): class Spinner(WriteMixin, Infinite):
message = "" message = ''
phases = ("-", "\\", "|", "/") phases = ('-', '\\', '|', '/')
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -27,16 +30,15 @@ class Spinner(WriteMixin, Infinite):
class PieSpinner(Spinner): class PieSpinner(Spinner):
phases = ["", "", "", ""] phases = ['', '', '', '']
class MoonSpinner(Spinner): class MoonSpinner(Spinner):
phases = ["", "", "", ""] phases = ['', '', '', '']
class LineSpinner(Spinner): class LineSpinner(Spinner):
phases = ["", "", "", "", "", ""] phases = ['', '', '', '', '', '']
class PixelSpinner(Spinner): class PixelSpinner(Spinner):
phases = ["", "", "", "", "", "", "", ""] phases = ['','', '', '', '', '', '', '']


@ -1,27 +1,29 @@
#!/usr/bin/env python #!/usr/bin/env python
import progress
from setuptools import setup from setuptools import setup
import progress
setup( setup(
name="progress", name='progress',
version=progress.__version__, version=progress.__version__,
description="Easy to use progress bars", description='Easy to use progress bars',
long_description=open("README.rst").read(), long_description=open('README.rst').read(),
author="Giorgos Verigakis", author='Giorgos Verigakis',
author_email="verigak@gmail.com", author_email='verigak@gmail.com',
url="http://github.com/verigak/progress/", url='http://github.com/verigak/progress/',
license="ISC", license='ISC',
packages=["progress"], packages=['progress'],
classifiers=[ classifiers=[
"Environment :: Console", 'Environment :: Console',
"Intended Audience :: Developers", 'Intended Audience :: Developers',
"License :: OSI Approved :: ISC License (ISCL)", 'License :: OSI Approved :: ISC License (ISCL)',
"Programming Language :: Python :: 2.6", 'Programming Language :: Python :: 2.6',
"Programming Language :: Python :: 2.7", 'Programming Language :: Python :: 2.7',
"Programming Language :: Python :: 3.3", 'Programming Language :: Python :: 3.3',
"Programming Language :: Python :: 3.4", 'Programming Language :: Python :: 3.4',
"Programming Language :: Python :: 3.5", 'Programming Language :: Python :: 3.5',
"Programming Language :: Python :: 3.6", 'Programming Language :: Python :: 3.6',
], ]
) )


@ -1,12 +1,16 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function
import random import random
import time import time
from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
from progress.counter import Countdown, Counter, Pie, Stack FillingCirclesBar, IncrementalBar, PixelBar,
from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner ShadyBar)
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
PixelSpinner)
from progress.counter import Counter, Countdown, Stack, Pie
def sleep(): def sleep():
@ -16,29 +20,29 @@ def sleep():
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]" suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for _i in bar.iter(range(200)): for i in bar.iter(range(200)):
sleep() sleep()
for bar_cls in (IncrementalBar, PixelBar, ShadyBar): for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]" suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for _i in bar.iter(range(200)): for i in bar.iter(range(200)):
sleep() sleep()
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
for _i in spin(spin.__name__ + " ").iter(range(100)): for i in spin(spin.__name__ + ' ').iter(range(100)):
sleep() sleep()
print() print()
for singleton in (Counter, Countdown, Stack, Pie): for singleton in (Counter, Countdown, Stack, Pie):
for _i in singleton(singleton.__name__ + " ").iter(range(100)): for i in singleton(singleton.__name__ + ' ').iter(range(100)):
sleep() sleep()
print() print()
bar = IncrementalBar("Random", suffix="%(index)d") bar = IncrementalBar('Random', suffix='%(index)d')
for _i in range(100): for i in range(100):
bar.goto(random.randint(0, 100)) bar.goto(random.randint(0, 100))
sleep() sleep()
bar.finish() bar.finish()


@ -1,52 +1,47 @@
import argparse import argparse
import datetime
import os import os
import datetime
import models as models
import numpy as np
import sklearn.metrics as metrics
import torch import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn.parallel import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim import torch.optim
import torch.utils.data import torch.utils.data
import torch.utils.data.distributed import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from utils import IOStream, progress_bar import models as models
from utils import progress_bar, IOStream
from data import ModelNet40
import sklearn.metrics as metrics
from helper import cal_loss
import numpy as np
import torch.nn.functional as F
model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name])) model_names = sorted(name for name in models.__dict__
if callable(models.__dict__[name]))
def parse_args(): def parse_args():
"""Parameters""" """Parameters"""
parser = argparse.ArgumentParser("training") parser = argparse.ArgumentParser('training')
parser.add_argument( parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
"-c", help='path to save checkpoint (default: checkpoint)')
"--checkpoint", parser.add_argument('--msg', type=str, help='message after checkpoint')
type=str, parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
metavar="PATH", parser.add_argument('--model', default='model31A', help='model name [default: pointnet_cls]')
help="path to save checkpoint (default: checkpoint)", parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
) parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
parser.add_argument("--msg", type=str, help="message after checkpoint") parser.add_argument('--seed', type=int, help='random seed (default: 1)')
parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
parser.add_argument("--model", default="model31A", help="model name [default: pointnet_cls]")
parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
parser.add_argument("--seed", type=int, help="random seed (default: 1)")
# Voting evaluation, referring to: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py # Voting evaluation, referring to: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
parser.add_argument("--NUM_PEPEAT", type=int, default=300) parser.add_argument('--NUM_PEPEAT', type=int, default=300)
parser.add_argument("--NUM_VOTE", type=int, default=10) parser.add_argument('--NUM_VOTE', type=int, default=10)
parser.add_argument("--validate", action="store_true", help="Validate the original testing result.") parser.add_argument('--validate', action='store_true', help='Validate the original testing result.')
return parser.parse_args() return parser.parse_args()
class PointcloudScale: # input random scaling class PointcloudScale(object): # input random scaling
def __init__(self, scale_low=2.0 / 3.0, scale_high=3.0 / 2.0): def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
self.scale_low = scale_low self.scale_low = scale_low
self.scale_high = scale_high self.scale_high = scale_high
@ -73,52 +68,45 @@ def main():
torch.set_printoptions(10) torch.set_printoptions(10)
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
os.environ["PYTHONHASHSEED"] = str(args.seed) os.environ['PYTHONHASHSEED'] = str(args.seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda" device = 'cuda'
else: else:
device = "cpu" device = 'cpu'
print(f"==> Using device: {device}") print(f"==> Using device: {device}")
if args.msg is None: if args.msg is None:
message = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
else: else:
message = "-" + args.msg message = "-" + args.msg
args.checkpoint = "checkpoints/" + args.model + message args.checkpoint = 'checkpoints/' + args.model + message
print("==> Preparing data..") print('==> Preparing data..')
test_loader = DataLoader( test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
ModelNet40(partition="test", num_points=args.num_points), batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
num_workers=4,
batch_size=args.batch_size // 2,
shuffle=False,
drop_last=False,
)
# Model # Model
print("==> Building model..") print('==> Building model..')
net = models.__dict__[args.model]() net = models.__dict__[args.model]()
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
checkpoint_path = os.path.join(args.checkpoint, "best_checkpoint.pth") checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == "cuda": if device == 'cuda':
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)
cudnn.benchmark = True cudnn.benchmark = True
net.load_state_dict(checkpoint["net"]) net.load_state_dict(checkpoint['net'])
if args.validate: if args.validate:
test_out = validate(net, test_loader, criterion, device) test_out = validate(net, test_loader, criterion, device)
print(f"Vanilla out: {test_out}") print(f"Vanilla out: {test_out}")
print( print(f"Note 1: Please also load the random seed parameter (if forgot, see out.txt).\n"
"Note 1: Please also load the random seed parameter (if forgot, see out.txt).\n" f"Note 2: This result may vary little on different GPUs (and number of GPUs), we tested 2080Ti, P100, and V100.\n"
"Note 2: This result may vary little on different GPUs (and number of GPUs), we tested 2080Ti, P100, and V100.\n" f"[note : Original result is achieved with V100 GPUs.]\n\n\n")
"[note : Original result is achieved with V100 GPUs.]\n\n\n",
)
# Interestingly, we get original best_test_acc on 4 V100 gpus, but this model is trained on one V100 gpu. # Interestingly, we get original best_test_acc on 4 V100 gpus, but this model is trained on one V100 gpu.
# On different GPUs, and with different numbers of GPUs, both OA and mean_acc vary a little. # On different GPUs, and with different numbers of GPUs, both OA and mean_acc vary a little.
# The batch size also affects the testing results, which we could not fully explain. # The batch size also affects the testing results, which we could not fully explain.
print("===> start voting evaluation...") print(f"===> start voting evaluation...")
voting(net, test_loader, device, args) voting(net, test_loader, device, args)
@ -142,28 +130,23 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy()) test_pred.append(preds.detach().cpu().numpy())
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
return { return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))), "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost, "time": time_cost
} }
def voting(net, testloader, device, args): def voting(net, testloader, device, args):
name = ( name = '/evaluate_voting' + str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) + 'seed_' + str(
"/evaluate_voting" + str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) + "seed_" + str(args.seed) + ".log" args.seed) + '.log'
)
io = IOStream(args.checkpoint + name) io = IOStream(args.checkpoint + name)
io.cprint(str(args)) io.cprint(str(args))
@ -178,7 +161,7 @@ def voting(net, testloader, device, args):
test_true = [] test_true = []
test_pred = [] test_pred = []
for _batch_idx, (data, label) in enumerate(testloader): for batch_idx, (data, label) in enumerate(testloader):
data, label = data.to(device), label.to(device).squeeze() data, label = data.to(device), label.to(device).squeeze()
pred = 0 pred = 0
for v in range(args.NUM_VOTE): for v in range(args.NUM_VOTE):
@ -195,24 +178,19 @@ def voting(net, testloader, device, args):
test_pred.append(pred_choice.detach().cpu().numpy()) test_pred.append(pred_choice.detach().cpu().numpy())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
test_acc = 100.0 * metrics.accuracy_score(test_true, test_pred) test_acc = 100. * metrics.accuracy_score(test_true, test_pred)
test_mean_acc = 100.0 * metrics.balanced_accuracy_score(test_true, test_pred) test_mean_acc = 100. * metrics.balanced_accuracy_score(test_true, test_pred)
if test_acc > best_acc: if test_acc > best_acc:
best_acc = test_acc best_acc = test_acc
if test_mean_acc > best_mean_acc: if test_mean_acc > best_mean_acc:
best_mean_acc = test_mean_acc best_mean_acc = test_mean_acc
outstr = "Voting %d, test acc: %.3f, test mean acc: %.3f, [current best(all_acc: %.3f mean_acc: %.3f)]" % ( outstr = 'Voting %d, test acc: %.3f, test mean acc: %.3f, [current best(all_acc: %.3f mean_acc: %.3f)]' % \
i, (i, test_acc, test_mean_acc, best_acc, best_mean_acc)
test_acc,
test_mean_acc,
best_acc,
best_mean_acc,
)
io.cprint(outstr) io.cprint(outstr)
final_outstr = "Final voting test acc: %.6f," % (best_acc * 100) final_outstr = 'Final voting test acc: %.6f,' % (best_acc * 100)
io.cprint(final_outstr) io.cprint(final_outstr)
if __name__ == "__main__": if __name__ == '__main__':
main() main()
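# Compact sketch of the voting loop above: every test cloud is forwarded
# NUM_VOTE times, each time (after the first) under a fresh random anisotropic
# scale from PointcloudScale, the softmax outputs are averaged, and the best
# overall accuracy over NUM_PEPEAT such passes is kept. The (B, N, 3) ->
# (B, 3, N) permute is an assumption about the model's expected input layout.
import torch
import torch.nn.functional as F

def vote_once(net, data, num_vote, scaler):
    pred = 0
    for v in range(num_vote):
        new_data = data if v == 0 else scaler(data)
        with torch.no_grad():
            pred += F.softmax(net(new_data.permute(0, 2, 1)), dim=1)
    pred /= num_vote
    return pred.max(dim=1)[1]  # voted class index per sample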


@ -1,7 +1,10 @@
"""ScanObjectNN download: http://103.24.77.34/scanobjectnn/h5_files.zip.""" """
ScanObjectNN download: http://103.24.77.34/scanobjectnn/h5_files.zip
"""
import os import os
import sys
import glob
import h5py import h5py
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -11,18 +14,18 @@ os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def download(): def download():
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data") DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR): if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR) os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, "h5_files")): if not os.path.exists(os.path.join(DATA_DIR, 'h5_files')):
# note that this link only contains the hardest perturbed variant (PB_T50_RS). # note that this link only contains the hardest perturbed variant (PB_T50_RS).
# for full versions, consider the following link. # for full versions, consider the following link.
www = "https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip" www = 'https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip'
# www = 'http://103.24.77.34/scanobjectnn/h5_files.zip' # www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
zipfile = os.path.basename(www) zipfile = os.path.basename(www)
os.system(f"wget {www} --no-check-certificate; unzip {zipfile}") os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
os.system(f"mv {zipfile[:-4]} {DATA_DIR}") os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
os.system("rm %s" % (zipfile)) os.system('rm %s' % (zipfile))
def load_scanobjectnn_data(partition): def load_scanobjectnn_data(partition):
@ -31,10 +34,10 @@ def load_scanobjectnn_data(partition):
all_data = [] all_data = []
all_label = [] all_label = []
h5_name = BASE_DIR + "/data/h5_files/main_split/" + partition + "_objectdataset_augmentedrot_scale75.h5" h5_name = BASE_DIR + '/data/h5_files/main_split/' + partition + '_objectdataset_augmentedrot_scale75.h5'
f = h5py.File(h5_name, mode="r") f = h5py.File(h5_name, mode="r")
data = f["data"][:].astype("float32") data = f['data'][:].astype('float32')
label = f["label"][:].astype("int64") label = f['label'][:].astype('int64')
f.close() f.close()
all_data.append(data) all_data.append(data)
all_label.append(label) all_label.append(label)
@ -44,14 +47,15 @@ def load_scanobjectnn_data(partition):
def translate_pointcloud(pointcloud): def translate_pointcloud(pointcloud):
xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3]) xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32") translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
return translated_pointcloud
class ScanObjectNN(Dataset): class ScanObjectNN(Dataset):
def __init__(self, num_points, partition="training"): def __init__(self, num_points, partition='training'):
self.data, self.label = load_scanobjectnn_data(partition) self.data, self.label = load_scanobjectnn_data(partition)
self.num_points = num_points self.num_points = num_points
self.partition = partition self.partition = partition
@ -59,7 +63,7 @@ class ScanObjectNN(Dataset):
def __getitem__(self, item): def __getitem__(self, item):
pointcloud = self.data[item][:self.num_points] pointcloud = self.data[item][:self.num_points]
label = self.label[item] label = self.label[item]
if self.partition == "training": if self.partition == 'training':
pointcloud = translate_pointcloud(pointcloud) pointcloud = translate_pointcloud(pointcloud)
np.random.shuffle(pointcloud) np.random.shuffle(pointcloud)
return pointcloud, label return pointcloud, label
@ -68,9 +72,9 @@ class ScanObjectNN(Dataset):
return self.data.shape[0] return self.data.shape[0]
if __name__ == "__main__": if __name__ == '__main__':
train = ScanObjectNN(1024) train = ScanObjectNN(1024)
test = ScanObjectNN(1024, "test") test = ScanObjectNN(1024, 'test')
for data, label in train: for data, label in train:
print(data.shape) print(data.shape)
print(label) print(label)
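# Minimal usage sketch for the dataset above, mirroring how the training script
# consumes it; batch size and worker count are placeholder values. Training
# samples are randomly rescaled/translated by translate_pointcloud and
# point-shuffled in __getitem__ before being returned.
from torch.utils.data import DataLoader

loader = DataLoader(
    ScanObjectNN(partition="training", num_points=1024),
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)
points, labels = next(iter(loader))  # points: (32, 1024, 3), labels: (32,)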


@ -1,50 +1,45 @@
"""for training with resume functions. """
for training with resume functions.
Usage: Usage:
python main.py --model PointNet --msg demo python main.py --model PointNet --msg demo
or or
CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &. CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &
""" """
import argparse import argparse
import datetime
import logging
import os import os
import logging
import models as models import datetime
import numpy as np
import sklearn.metrics as metrics
import torch import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim import torch.optim
import torch.utils.data import torch.utils.data
import torch.utils.data.distributed import torch.utils.data.distributed
from torch.utils.data import DataLoader
import models as models
from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
from ScanObjectNN import ScanObjectNN from ScanObjectNN import ScanObjectNN
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader import sklearn.metrics as metrics
from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model import numpy as np
def parse_args(): def parse_args():
"""Parameters""" """Parameters"""
parser = argparse.ArgumentParser("training") parser = argparse.ArgumentParser('training')
parser.add_argument( parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
"-c", help='path to save checkpoint (default: checkpoint)')
"--checkpoint", parser.add_argument('--msg', type=str, help='message after checkpoint')
type=str, parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
metavar="PATH", parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
help="path to save checkpoint (default: checkpoint)", parser.add_argument('--num_classes', default=15, type=int, help='default value for classes of ScanObjectNN')
) parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
parser.add_argument("--msg", type=str, help="message after checkpoint") parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
parser.add_argument("--batch_size", type=int, default=32, help="batch size in training") parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate in training')
parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]") parser.add_argument('--weight_decay', type=float, default=1e-4, help='decay rate')
parser.add_argument("--num_classes", default=15, type=int, help="default value for classes of ScanObjectNN") parser.add_argument('--smoothing', action='store_true', default=False, help='loss smoothing')
parser.add_argument("--epoch", default=200, type=int, help="number of epoch in training") parser.add_argument('--seed', type=int, help='random seed')
parser.add_argument("--num_points", type=int, default=1024, help="Point Number") parser.add_argument('--workers', default=4, type=int, help='workers')
parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate in training")
parser.add_argument("--weight_decay", type=float, default=1e-4, help="decay rate")
parser.add_argument("--smoothing", action="store_true", default=False, help="loss smoothing")
parser.add_argument("--seed", type=int, help="random seed")
parser.add_argument("--workers", default=4, type=int, help="workers")
return parser.parse_args() return parser.parse_args()
@ -54,23 +49,23 @@ def main():
if args.seed is not None: if args.seed is not None:
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda" device = 'cuda'
if args.seed is not None: if args.seed is not None:
torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed(args.seed)
else: else:
device = "cpu" device = 'cpu'
time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
if args.msg is None: if args.msg is None:
message = time_str message = time_str
else: else:
message = "-" + args.msg message = "-" + args.msg
args.checkpoint = "checkpoints/" + args.model + message args.checkpoint = 'checkpoints/' + args.model + message
if not os.path.isdir(args.checkpoint): if not os.path.isdir(args.checkpoint):
mkdir_p(args.checkpoint) mkdir_p(args.checkpoint)
screen_logger = logging.getLogger("Model") screen_logger = logging.getLogger("Model")
screen_logger.setLevel(logging.INFO) screen_logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(message)s") formatter = logging.Formatter('%(message)s')
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt")) file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
@ -82,19 +77,19 @@ def main():
# Model # Model
printf(f"args: {args}") printf(f"args: {args}")
printf("==> Building model..") printf('==> Building model..')
net = models.__dict__[args.model](num_classes=args.num_classes) net = models.__dict__[args.model](num_classes=args.num_classes)
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == "cuda": if device == 'cuda':
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)
cudnn.benchmark = True cudnn.benchmark = True
best_test_acc = 0.0 # best test accuracy best_test_acc = 0. # best test accuracy
best_train_acc = 0.0 best_train_acc = 0.
best_test_acc_avg = 0.0 best_test_acc_avg = 0.
best_train_acc_avg = 0.0 best_train_acc_avg = 0.
best_test_loss = float("inf") best_test_loss = float("inf")
best_train_loss = float("inf") best_train_loss = float("inf")
start_epoch = 0 # start from epoch 0 or last checkpoint epoch start_epoch = 0 # start from epoch 0 or last checkpoint epoch
@ -102,49 +97,30 @@ def main():
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")): if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
save_args(args) save_args(args)
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model) logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
logger.set_names( logger.set_names(["Epoch-Num", 'Learning-Rate',
[ 'Train-Loss', 'Train-acc-B', 'Train-acc',
"Epoch-Num", 'Valid-Loss', 'Valid-acc-B', 'Valid-acc'])
"Learning-Rate",
"Train-Loss",
"Train-acc-B",
"Train-acc",
"Valid-Loss",
"Valid-acc-B",
"Valid-acc",
],
)
else: else:
printf(f"Resuming last checkpoint from {args.checkpoint}") printf(f"Resuming last checkpoint from {args.checkpoint}")
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth") checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint["net"]) net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint["epoch"] start_epoch = checkpoint['epoch']
best_test_acc = checkpoint["best_test_acc"] best_test_acc = checkpoint['best_test_acc']
best_train_acc = checkpoint["best_train_acc"] best_train_acc = checkpoint['best_train_acc']
best_test_acc_avg = checkpoint["best_test_acc_avg"] best_test_acc_avg = checkpoint['best_test_acc_avg']
best_train_acc_avg = checkpoint["best_train_acc_avg"] best_train_acc_avg = checkpoint['best_train_acc_avg']
best_test_loss = checkpoint["best_test_loss"] best_test_loss = checkpoint['best_test_loss']
best_train_loss = checkpoint["best_train_loss"] best_train_loss = checkpoint['best_train_loss']
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True) logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
optimizer_dict = checkpoint["optimizer"] optimizer_dict = checkpoint['optimizer']
printf("==> Preparing data..") printf('==> Preparing data..')
train_loader = DataLoader( train_loader = DataLoader(ScanObjectNN(partition='training', num_points=args.num_points), num_workers=args.workers,
ScanObjectNN(partition="training", num_points=args.num_points), batch_size=args.batch_size, shuffle=True, drop_last=True)
num_workers=args.workers, test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points), num_workers=args.workers,
batch_size=args.batch_size, batch_size=args.batch_size, shuffle=True, drop_last=False)
shuffle=True,
drop_last=True,
)
test_loader = DataLoader(
ScanObjectNN(partition="test", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
)
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
if optimizer_dict is not None: if optimizer_dict is not None:
@ -152,7 +128,7 @@ def main():
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1) scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1)
for epoch in range(start_epoch, args.epoch): for epoch in range(start_epoch, args.epoch):
printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"])) printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"} train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
test_out = validate(net, test_loader, criterion, device) test_out = validate(net, test_loader, criterion, device)
scheduler.step() scheduler.step()
@ -171,46 +147,31 @@ def main():
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
save_model( save_model(
net, net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
epoch,
path=args.checkpoint,
acc=test_out["acc"],
is_best=is_best,
best_test_acc=best_test_acc, # best test accuracy best_test_acc=best_test_acc, # best test accuracy
best_train_acc=best_train_acc, best_train_acc=best_train_acc,
best_test_acc_avg=best_test_acc_avg, best_test_acc_avg=best_test_acc_avg,
best_train_acc_avg=best_train_acc_avg, best_train_acc_avg=best_train_acc_avg,
best_test_loss=best_test_loss, best_test_loss=best_test_loss,
best_train_loss=best_train_loss, best_train_loss=best_train_loss,
optimizer=optimizer.state_dict(), optimizer=optimizer.state_dict()
)
logger.append(
[
epoch,
optimizer.param_groups[0]["lr"],
train_out["loss"],
train_out["acc_avg"],
train_out["acc"],
test_out["loss"],
test_out["acc_avg"],
test_out["acc"],
],
) )
logger.append([epoch, optimizer.param_groups[0]['lr'],
train_out["loss"], train_out["acc_avg"], train_out["acc"],
test_out["loss"], test_out["acc_avg"], test_out["acc"]])
printf( printf(
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s", f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s")
)
printf( printf(
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% " f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n", f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n")
)
logger.close() logger.close()
printf("++++++++" * 2 + "Final results" + "++++++++" * 2) printf(f"++++++++" * 2 + "Final results" + "++++++++" * 2)
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++") printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++") printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++") printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++") printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
printf("++++++++" * 5) printf(f"++++++++" * 5)
def train(net, trainloader, optimizer, criterion, device): def train(net, trainloader, optimizer, criterion, device):
@ -238,21 +199,17 @@ def train(net, trainloader, optimizer, criterion, device):
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(trainloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
train_true = np.concatenate(train_true) train_true = np.concatenate(train_true)
train_pred = np.concatenate(train_pred) train_pred = np.concatenate(train_pred)
return { return {
"loss": float("%.3f" % (train_loss / (batch_idx + 1))), "loss": float("%.3f" % (train_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(train_true, train_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(train_true, train_pred))),
"time": time_cost, "time": time_cost
} }
@ -276,23 +233,19 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy()) test_pred.append(preds.detach().cpu().numpy())
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar( progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
batch_idx, % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
return { return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))), "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))), "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))), "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost, "time": time_cost
} }
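Note that the two accuracy figures returned by train() and validate() are different metrics: "acc" is overall accuracy, while "acc_avg" is the class-balanced (mean per-class) accuracy from scikit-learn, which matters on the class-imbalanced ScanObjectNN split. A toy example of the difference:

```python
import numpy as np
from sklearn import metrics

y_true = np.array([0, 0, 0, 0, 1])   # imbalanced labels: four of class 0, one of class 1
y_pred = np.array([0, 0, 0, 0, 0])   # a classifier that always predicts the majority class

print(metrics.accuracy_score(y_true, y_pred))           # 0.8  -> looks good
print(metrics.balanced_accuracy_score(y_true, y_pred))  # 0.5  -> reveals the missed class
```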
if __name__ == "__main__": if __name__ == '__main__':
main() main()


@ -1 +1,3 @@
from __future__ import absolute_import
from .pointmlp import pointMLP, pointMLPElite from .pointmlp import pointMLP, pointMLPElite


@ -1,32 +1,35 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# from torch import einsum # from torch import einsum
# from einops import rearrange, repeat # from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils from pointnet2_ops import pointnet2_utils
def get_activation(activation): def get_activation(activation):
if activation.lower() == "gelu": if activation.lower() == 'gelu':
return nn.GELU() return nn.GELU()
elif activation.lower() == "rrelu": elif activation.lower() == 'rrelu':
return nn.RReLU(inplace=True) return nn.RReLU(inplace=True)
elif activation.lower() == "selu": elif activation.lower() == 'selu':
return nn.SELU(inplace=True) return nn.SELU(inplace=True)
elif activation.lower() == "silu": elif activation.lower() == 'silu':
return nn.SiLU(inplace=True) return nn.SiLU(inplace=True)
elif activation.lower() == "hardswish": elif activation.lower() == 'hardswish':
return nn.Hardswish(inplace=True) return nn.Hardswish(inplace=True)
elif activation.lower() == "leakyrelu": elif activation.lower() == 'leakyrelu':
return nn.LeakyReLU(inplace=True) return nn.LeakyReLU(inplace=True)
else: else:
return nn.ReLU(inplace=True) return nn.ReLU(inplace=True)
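get_activation is just a string-to-module lookup used throughout the model below; any string it does not recognize falls back to ReLU. A tiny check, run inside this module:

```python
# Minimal usage check for get_activation.
act = get_activation("gelu")        # nn.GELU()
fallback = get_activation("foo")    # unrecognized string -> nn.ReLU(inplace=True)
print(type(act).__name__, type(fallback).__name__)   # GELU ReLU
```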
def square_distance(src, dst): def square_distance(src, dst):
"""Calculate Euclid distance between each two points. """
src^T * dst = xn * xm + yn * ym + zn * zm; Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@ -35,7 +38,7 @@ def square_distance(src, dst):
src: source points, [B, N, C] src: source points, [B, N, C]
dst: target points, [B, M, C] dst: target points, [B, M, C]
Output: Output:
dist: per-point square distance, [B, N, M]. dist: per-point square distance, [B, N, M]
""" """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
@ -46,12 +49,12 @@ def square_distance(src, dst):
def index_points(points, idx): def index_points(points, idx):
"""Input: """
Input:
points: input points data, [B, N, C] points: input points data, [B, N, C]
idx: sample index data, [B, S]. idx: sample index data, [B, S]
Return: Return:
new_points:, indexed points data, [B, S, C]. new_points:, indexed points data, [B, S, C]
""" """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
@ -60,15 +63,17 @@ def index_points(points, idx):
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
return points[batch_indices, idx, :] new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint): def farthest_point_sample(xyz, npoint):
"""Input: """
Input:
xyz: pointcloud data, [B, N, 3] xyz: pointcloud data, [B, N, 3]
npoint: number of samples npoint: number of samples
Return: Return:
centroids: sampled pointcloud index, [B, npoint]. centroids: sampled pointcloud index, [B, npoint]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
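The loop body of farthest_point_sample is cut off by this hunk. For reference, the standard iterative FPS used in PointNet++-style code, consistent with the shapes set up above but reconstructed here rather than copied from the repo, looks like this:

```python
# Reconstructed FPS body (assumption: matches the usual PointNet++ implementation).
def farthest_point_sample(xyz, npoint):
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    distance = torch.ones(B, N, device=device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest                         # keep the current farthest point
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)        # distance to the newest centroid
        mask = dist < distance
        distance[mask] = dist[mask]                        # distance to the nearest chosen centroid
        farthest = torch.max(distance, -1)[1]              # next pick: farthest from all chosen
    return centroids
```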
@ -86,14 +91,14 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz): def query_ball_point(radius, nsample, xyz, new_xyz):
"""Input: """
Input:
radius: local region radius radius: local region radius
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, 3] xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]. new_xyz: query points, [B, S, 3]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@ -109,13 +114,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz): def knn_point(nsample, xyz, new_xyz):
"""Input: """
Input:
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, C] xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C]. new_xyz: query points, [B, S, C]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
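Taken together, square_distance, index_points and knn_point implement the ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b expansion plus a top-k lookup. A quick shape check, assuming these three functions are importable from this module:

```python
import torch

B, N, S, k = 2, 128, 32, 8
xyz = torch.rand(B, N, 3)        # all points
new_xyz = torch.rand(B, S, 3)    # query points (e.g. FPS centroids)

d = square_distance(new_xyz, xyz)    # [B, S, N] pairwise squared distances
idx = knn_point(k, xyz, new_xyz)     # [B, S, k] indices of the k nearest neighbours
neigh = index_points(xyz, idx)       # [B, S, k, 3] gathered neighbour coordinates
print(d.shape, idx.shape, neigh.shape)
```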
@ -124,12 +129,13 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module): class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs): def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d] """
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number :param groups: groups number
:param kneighbors: k-neighbors :param kneighbors: k-neighbors
:param kwargs: others. :param kwargs: others
""" """
super().__init__() super(LocalGrouper, self).__init__()
self.groups = groups self.groups = groups
self.kneighbors = kneighbors self.kneighbors = kneighbors
self.use_xyz = use_xyz self.use_xyz = use_xyz
@ -138,7 +144,7 @@ class LocalGrouper(nn.Module):
else: else:
self.normalize = None self.normalize = None
if self.normalize not in ["center", "anchor"]: if self.normalize not in ["center", "anchor"]:
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].") print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None self.normalize = None
if self.normalize is not None: if self.normalize is not None:
add_channel=3 if self.use_xyz else 0 add_channel=3 if self.use_xyz else 0
@ -168,11 +174,7 @@ class LocalGrouper(nn.Module):
if self.normalize =="anchor": if self.normalize =="anchor":
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
std = ( std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
.unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = (grouped_points-mean)/(std + 1e-5)
grouped_points = self.affine_alpha*grouped_points + self.affine_beta grouped_points = self.affine_alpha*grouped_points + self.affine_beta
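With normalize="anchor", LocalGrouper standardizes each neighbourhood against its anchor point and then applies the learnable affine_alpha/affine_beta scale and shift; this is the geometric affine module described in the paper. Below is a standalone numeric sketch of the same operation on dummy tensors, with plain tensors standing in for the nn.Parameter buffers:

```python
import torch

B, G, K, D = 2, 4, 24, 16                  # batch, groups, neighbours, feature dim
grouped = torch.rand(B, G, K, D)           # grouped neighbourhood features
anchor = torch.rand(B, G, 1, D)            # per-group anchor ("mean" in the code above)
alpha = torch.ones(1, 1, 1, D)             # affine_alpha analogue
beta = torch.zeros(1, 1, 1, D)             # affine_beta analogue

std = torch.std((grouped - anchor).reshape(B, -1), dim=-1, keepdim=True)
std = std.unsqueeze(-1).unsqueeze(-1)      # [B, 1, 1, 1], one scale per sample
out = alpha * (grouped - anchor) / (std + 1e-5) + beta
print(out.shape)                           # torch.Size([2, 4, 24, 16])
```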
@ -181,13 +183,13 @@ class LocalGrouper(nn.Module):
class ConvBNReLU1D(nn.Module): class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"): def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
super().__init__() super(ConvBNReLU1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
self.act, self.act
) )
def forward(self, x): def forward(self, x):
@ -195,43 +197,30 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module): class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"): def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
super().__init__() super(ConvBNReLURes1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net1 = nn.Sequential( self.net1 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
in_channels=channel, kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)), nn.BatchNorm1d(int(channel * res_expansion)),
self.act, self.act
) )
if groups > 1: if groups > 1:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
self.act, self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=channel, out_channels=channel,
kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
) )
else: else:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, bias=bias),
out_channels=channel, nn.BatchNorm1d(channel)
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
) )
def forward(self, x): def forward(self, x):
@ -239,34 +228,21 @@ class ConvBNReLURes1D(nn.Module):
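The forward pass of ConvBNReLURes1D is not shown in this hunk. Given net1 (expand by res_expansion) and net2 (project back to channel), the block presumably computes a residual sum followed by the activation; a sketch of that assumed forward:

```python
# Assumed forward for ConvBNReLURes1D; the hunk above omits it.
def forward(self, x):                        # x: [B, channel, N]
    return self.act(self.net2(self.net1(x)) + x)
```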
class PreExtraction(nn.Module): class PreExtraction(nn.Module):
def __init__( def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
self, activation='relu', use_xyz=True):
channels, """
out_channels, input: [b,g,k,d]: output:[b,d,g]
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PreExtraction, self).__init__()
in_channels = 3+2*channels if use_xyz else 2*channels in_channels = 3+2*channels if use_xyz else 2*channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D( ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
out_channels, bias=bias, activation=activation)
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -278,20 +254,22 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size() batch_size, _, _ = x.size()
x = self.operation(x) # [b, d, k] x = self.operation(x) # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
return x.reshape(b, n, -1).permute(0, 2, 1) x = x.reshape(b, n, -1).permute(0, 2, 1)
return x
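The first half of PreExtraction.forward is also elided. The shape bookkeeping that makes the visible tail work is presumably: fold the group axis into the batch axis so every neighbourhood is processed by the same shared MLP, max-pool over the k neighbours, then unfold back. A sketch under that assumption:

```python
# Assumed full forward for PreExtraction (only the tail appears in the diff above).
def forward(self, x):                               # x: [b, n, k, d] grouped features
    b, n, s, d = x.size()
    x = x.permute(0, 1, 3, 2).reshape(-1, d, s)     # [b*n, d, k]: groups become batch items
    x = self.transfer(x)                            # 1x1 conv to out_channels
    batch_size, _, _ = x.size()
    x = self.operation(x)                           # residual blocks, still [b*n, d', k]
    x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)   # pool over the k neighbours
    return x.reshape(b, n, -1).permute(0, 2, 1)     # [b, d', n]
```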
class PosExtraction(nn.Module): class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"): def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
"""input[b,d,g]; output[b,d,g] """
input[b,d,g]; output[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PosExtraction, self).__init__()
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation), ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -300,32 +278,17 @@ class PosExtraction(nn.Module):
class Model(nn.Module): class Model(nn.Module):
def __init__( def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
self, activation="relu", bias=True, use_xyz=True, normalize="center",
points=1024, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
class_num=40, k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs):
embed_dim=64, super(Model, self).__init__()
groups=1,
res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="center",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[2, 2, 2, 2],
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks) self.stages = len(pre_blocks)
self.class_num = class_num self.class_num = class_num
self.points = points self.points = points
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
assert ( assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion) "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList() self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList() self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList() self.pos_blocks_list = nn.ModuleList()
@ -342,26 +305,13 @@ class Model(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
self.local_grouper_list.append(local_grouper) self.local_grouper_list.append(local_grouper)
# append pre_block_list # append pre_block_list
pre_block_module = PreExtraction( pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
last_channel,
out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion, res_expansion=res_expansion,
bias=bias, bias=bias, activation=activation, use_xyz=use_xyz)
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module) self.pre_blocks_list.append(pre_block_module)
# append pos_block_list # append pos_block_list
pos_block_module = PosExtraction( pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
out_channel, res_expansion=res_expansion, bias=bias, activation=activation)
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module) self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel last_channel = out_channel
@ -376,7 +326,7 @@ class Model(nn.Module):
nn.BatchNorm1d(256), nn.BatchNorm1d(256),
self.act, self.act,
nn.Dropout(0.5), nn.Dropout(0.5),
nn.Linear(256, self.class_num), nn.Linear(256, self.class_num)
) )
def forward(self, x): def forward(self, x):
@ -390,52 +340,29 @@ class Model(nn.Module):
x = self.pos_blocks_list[i](x) # [b,d,g] x = self.pos_blocks_list[i](x) # [b,d,g]
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1) x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
return self.classifier(x) x = self.classifier(x)
return x
def pointMLP(num_classes=40, **kwargs) -> Model: def pointMLP(num_classes=40, **kwargs) -> Model:
return Model( return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1.0,
points=1024, activation="relu", bias=False, use_xyz=False, normalize="anchor",
class_num=num_classes, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
embed_dim=64, k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)
groups=1,
res_expansion=1.0,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
def pointMLPElite(num_classes=40, **kwargs) -> Model: def pointMLPElite(num_classes=40, **kwargs) -> Model:
return Model( return Model(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25,
points=1024, activation="relu", bias=False, use_xyz=False, normalize="anchor",
class_num=num_classes, dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
embed_dim=32, k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
groups=1,
res_expansion=0.25,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 1],
pre_blocks=[1, 1, 2, 1],
pos_blocks=[1, 1, 2, 1],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
if __name__ == '__main__':
if __name__ == "__main__":
data = torch.rand(2, 3, 1024) data = torch.rand(2, 3, 1024)
print("===> testing pointMLP ...") print("===> testing pointMLP ...")
model = pointMLP() model = pointMLP()
out = model(data) out = model(data)
print(out.shape) print(out.shape)


@ -1,4 +1,5 @@
"""Useful utils.""" """Useful utils
from .logger import * """
from .misc import * from .misc import *
from .logger import *
from .progress.progress.bar import Bar as Bar from .progress.progress.bar import Bar as Bar


@ -1,50 +1,48 @@
# A simple torch style logger # A simple torch style logger
# (C) Wei YANG 2017 # (C) Wei YANG 2017
from __future__ import absolute_import
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import os
import sys
import numpy as np import numpy as np
__all__ = ["Logger", "LoggerMonitor", "savefig"] __all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None): def savefig(fname, dpi=None):
dpi = 150 if dpi is None else dpi dpi = 150 if dpi == None else dpi
plt.savefig(fname, dpi=dpi) plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None): def plot_overlap(logger, names=None):
names = logger.names if names is None else names names = logger.names if names == None else names
numbers = logger.numbers numbers = logger.numbers
for _, name in enumerate(names): for _, name in enumerate(names):
x = np.arange(len(numbers[name])) x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name])) plt.plot(x, np.asarray(numbers[name]))
return [logger.title + "(" + name + ")" for name in names] return [logger.title + '(' + name + ')' for name in names]
class Logger:
"""Save training process to log file with simple plot function."""
class Logger(object):
'''Save training process to log file with simple plot function.'''
def __init__(self, fpath, title=None, resume=False): def __init__(self, fpath, title=None, resume=False):
self.file = None self.file = None
self.resume = resume self.resume = resume
self.title = "" if title is None else title self.title = '' if title == None else title
if fpath is not None: if fpath is not None:
if resume: if resume:
self.file = open(fpath) self.file = open(fpath, 'r')
name = self.file.readline() name = self.file.readline()
self.names = name.rstrip().split("\t") self.names = name.rstrip().split('\t')
self.numbers = {} self.numbers = {}
for _, name in enumerate(self.names): for _, name in enumerate(self.names):
self.numbers[name] = [] self.numbers[name] = []
for numbers in self.file: for numbers in self.file:
numbers = numbers.rstrip().split("\t") numbers = numbers.rstrip().split('\t')
for i in range(0, len(numbers)): for i in range(0, len(numbers)):
self.numbers[self.names[i]].append(numbers[i]) self.numbers[self.names[i]].append(numbers[i])
self.file.close() self.file.close()
self.file = open(fpath, "a") self.file = open(fpath, 'a')
else: else:
self.file = open(fpath, "w") self.file = open(fpath, 'w')
def set_names(self, names): def set_names(self, names):
if self.resume: if self.resume:
@ -54,39 +52,38 @@ class Logger:
self.names = names self.names = names
for _, name in enumerate(self.names): for _, name in enumerate(self.names):
self.file.write(name) self.file.write(name)
self.file.write("\t") self.file.write('\t')
self.numbers[name] = [] self.numbers[name] = []
self.file.write("\n") self.file.write('\n')
self.file.flush() self.file.flush()
def append(self, numbers): def append(self, numbers):
assert len(self.names) == len(numbers), "Numbers do not match names" assert len(self.names) == len(numbers), 'Numbers do not match names'
for index, num in enumerate(numbers): for index, num in enumerate(numbers):
self.file.write(f"{num:.6f}") self.file.write("{0:.6f}".format(num))
self.file.write("\t") self.file.write('\t')
self.numbers[self.names[index]].append(num) self.numbers[self.names[index]].append(num)
self.file.write("\n") self.file.write('\n')
self.file.flush() self.file.flush()
def plot(self, names=None): def plot(self, names=None):
names = self.names if names is None else names names = self.names if names == None else names
numbers = self.numbers numbers = self.numbers
for _, name in enumerate(names): for _, name in enumerate(names):
x = np.arange(len(numbers[name])) x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name])) plt.plot(x, np.asarray(numbers[name]))
plt.legend([self.title + "(" + name + ")" for name in names]) plt.legend([self.title + '(' + name + ')' for name in names])
plt.grid(True) plt.grid(True)
def close(self): def close(self):
if self.file is not None: if self.file is not None:
self.file.close() self.file.close()
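Logger writes one tab-separated row per append() call under a header set by set_names(); the training script above relies on the two lists having the same length and order. A minimal, self-contained usage example (the file name is arbitrary and chosen for illustration only):

```python
# Minimal Logger usage; "demo_log.txt" is an arbitrary file name.
logger = Logger("demo_log.txt", title="demo")
logger.set_names(["Epoch", "LR", "Train Loss", "Test Acc"])
logger.append([0, 0.1, 2.31, 12.5])
logger.append([1, 0.1, 1.87, 34.2])
logger.close()     # demo_log.txt now holds a header row plus two data rows
```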
class LoggerMonitor(object):
class LoggerMonitor: '''Load and visualize multiple logs.'''
"""Load and visualize multiple logs."""
def __init__ (self, paths): def __init__ (self, paths):
"""Paths is a distionary with {name:filepath} pair.""" '''paths is a distionary with {name:filepath} pair'''
self.loggers = [] self.loggers = []
for title, path in paths.items(): for title, path in paths.items():
logger = Logger(path, title=title, resume=True) logger = Logger(path, title=title, resume=True)
@ -98,11 +95,10 @@ class LoggerMonitor:
legend_text = [] legend_text = []
for logger in self.loggers: for logger in self.loggers:
legend_text += plot_overlap(logger, names) legend_text += plot_overlap(logger, names)
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0) plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.grid(True) plt.grid(True)
if __name__ == '__main__':
if __name__ == "__main__":
# # Example # # Example
# logger = Logger('test.txt') # logger = Logger('test.txt')
# logger.set_names(['Train loss', 'Valid loss','Test loss']) # logger.set_names(['Train loss', 'Valid loss','Test loss'])
@ -119,13 +115,13 @@ if __name__ == "__main__":
# Example: logger monitor # Example: logger monitor
paths = { paths = {
"resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt", 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
"resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt", 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
"resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt", 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
} }
field = ["Valid Acc."] field = ['Valid Acc.']
monitor = LoggerMonitor(paths) monitor = LoggerMonitor(paths)
monitor.plot(names=field) monitor.plot(names=field)
savefig("test.eps") savefig('test.eps')


@ -1,43 +1,36 @@
"""Some helper functions for PyTorch, including: '''Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset. - get_mean_and_std: calculate the mean and std value of dataset.
- msr_init: net parameter initialization. - msr_init: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress. - progress_bar: progress bar mimic xlua.progress.
""" '''
import errno import errno
import os import os
import random
import shutil
import sys import sys
import time import time
import math
import numpy as np
import torch import torch
import torch.nn as nn import shutil
import numpy as np
import random
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.init as init
__all__ = [
"get_mean_and_std", import torch.nn as nn
"init_params", import torch.nn.init as init
"mkdir_p", from torch.autograd import Variable
"AverageMeter",
"progress_bar", __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter',
"save_model", 'progress_bar','save_model',"save_args","set_seed", "IOStream", "cal_loss"]
"save_args",
"set_seed",
"IOStream",
"cal_loss",
]
def get_mean_and_std(dataset): def get_mean_and_std(dataset):
"""Compute the mean and std value of dataset.""" '''Compute the mean and std value of dataset.'''
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
mean = torch.zeros(3) mean = torch.zeros(3)
std = torch.zeros(3) std = torch.zeros(3)
print("==> Computing mean and std..") print('==> Computing mean and std..')
for inputs, _targets in dataloader: for inputs, targets in dataloader:
for i in range(3): for i in range(3):
mean[i] += inputs[:,i,:,:].mean() mean[i] += inputs[:,i,:,:].mean()
std[i] += inputs[:,i,:,:].std() std[i] += inputs[:,i,:,:].std()
@ -45,12 +38,11 @@ def get_mean_and_std(dataset):
std.div_(len(dataset)) std.div_(len(dataset))
return mean, std return mean, std
def init_params(net): def init_params(net):
"""Init layer parameters.""" '''Init layer parameters.'''
for m in net.modules(): for m in net.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode="fan_out") init.kaiming_normal(m.weight, mode='fan_out')
if m.bias: if m.bias:
init.constant(m.bias, 0) init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
@ -61,9 +53,8 @@ def init_params(net):
if m.bias: if m.bias:
init.constant(m.bias, 0) init.constant(m.bias, 0)
def mkdir_p(path): def mkdir_p(path):
"""Make dir if not exist.""" '''make dir if not exist'''
try: try:
os.makedirs(path) os.makedirs(path)
except OSError as exc: # Python >2.5 except OSError as exc: # Python >2.5
@ -72,12 +63,10 @@ def mkdir_p(path):
else: else:
raise raise
class AverageMeter(object):
class AverageMeter:
"""Computes and stores the average and current value """Computes and stores the average and current value
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262. Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
""" """
def __init__(self): def __init__(self):
self.reset() self.reset()
@ -94,11 +83,10 @@ class AverageMeter:
self.avg = self.sum / self.count self.avg = self.sum / self.count
TOTAL_BAR_LENGTH = 65.0
TOTAL_BAR_LENGTH = 65.
last_time = time.time() last_time = time.time()
begin_time = last_time begin_time = last_time
def progress_bar(current, total, msg=None): def progress_bar(current, total, msg=None):
global last_time, begin_time global last_time, begin_time
if current == 0: if current == 0:
@ -107,13 +95,13 @@ def progress_bar(current, total, msg=None):
cur_len = int(TOTAL_BAR_LENGTH*current/total) cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(" [") sys.stdout.write(' [')
for _i in range(cur_len): for i in range(cur_len):
sys.stdout.write("=") sys.stdout.write('=')
sys.stdout.write(">") sys.stdout.write('>')
for _i in range(rest_len): for i in range(rest_len):
sys.stdout.write(".") sys.stdout.write('.')
sys.stdout.write("]") sys.stdout.write(']')
cur_time = time.time() cur_time = time.time()
step_time = cur_time - last_time step_time = cur_time - last_time
@ -121,12 +109,12 @@ def progress_bar(current, total, msg=None):
tot_time = cur_time - begin_time tot_time = cur_time - begin_time
L = [] L = []
L.append(" Step: %s" % format_time(step_time)) L.append(' Step: %s' % format_time(step_time))
L.append(" | Tot: %s" % format_time(tot_time)) L.append(' | Tot: %s' % format_time(tot_time))
if msg: if msg:
L.append(" | " + msg) L.append(' | ' + msg)
msg = "".join(L) msg = ''.join(L)
sys.stdout.write(msg) sys.stdout.write(msg)
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): # for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
# sys.stdout.write(' ') # sys.stdout.write(' ')
@ -134,12 +122,12 @@ def progress_bar(current, total, msg=None):
# Go back to the center of the bar. # Go back to the center of the bar.
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): # for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
# sys.stdout.write('\b') # sys.stdout.write('\b')
sys.stdout.write(" %d/%d " % (current + 1, total)) sys.stdout.write(' %d/%d ' % (current+1, total))
if current < total-1: if current < total-1:
sys.stdout.write("\r") sys.stdout.write('\r')
else: else:
sys.stdout.write("\n") sys.stdout.write('\n')
sys.stdout.flush() sys.stdout.flush()
@ -154,54 +142,56 @@ def format_time(seconds):
seconds = seconds - secondsf seconds = seconds - secondsf
millis = int(seconds*1000) millis = int(seconds*1000)
f = "" f = ''
i = 1 i = 1
if days > 0: if days > 0:
f += str(days) + "D" f += str(days) + 'D'
i += 1 i += 1
if hours > 0 and i <= 2: if hours > 0 and i <= 2:
f += str(hours) + "h" f += str(hours) + 'h'
i += 1 i += 1
if minutes > 0 and i <= 2: if minutes > 0 and i <= 2:
f += str(minutes) + "m" f += str(minutes) + 'm'
i += 1 i += 1
if secondsf > 0 and i <= 2: if secondsf > 0 and i <= 2:
f += str(secondsf) + "s" f += str(secondsf) + 's'
i += 1 i += 1
if millis > 0 and i <= 2: if millis > 0 and i <= 2:
f += str(millis) + "ms" f += str(millis) + 'ms'
i += 1 i += 1
if f == "": if f == '':
f = "0ms" f = '0ms'
return f return f
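progress_bar is a stdout-only helper: it draws a fixed-width bar, appends the per-step and total time via format_time, and carriage-returns until the last step. A tiny driver loop, with the import path assumed rather than taken from the diff (adjust it to wherever this misc module lives in the repo):

```python
import time
from utils.misc import progress_bar   # import path assumed, not read from the diff

n = 50
for i in range(n):
    time.sleep(0.02)                  # stand-in for a real training step
    progress_bar(i, n, "Loss: %.3f" % (1.0 / (i + 1)))
```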
def save_model(net, epoch, path, acc, is_best, **kwargs): def save_model(net, epoch, path, acc, is_best, **kwargs):
state = { state = {
"net": net.state_dict(), 'net': net.state_dict(),
"epoch": epoch, 'epoch': epoch,
"acc": acc, 'acc': acc
} }
for key, value in kwargs.items(): for key, value in kwargs.items():
state[key] = value state[key] = value
filepath = os.path.join(path, "last_checkpoint.pth") filepath = os.path.join(path, "last_checkpoint.pth")
torch.save(state, filepath) torch.save(state, filepath)
if is_best: if is_best:
shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth")) shutil.copyfile(filepath, os.path.join(path, 'best_checkpoint.pth'))
def save_args(args): def save_args(args):
file = open(os.path.join(args.checkpoint, "args.txt"), "w") file = open(os.path.join(args.checkpoint, 'args.txt'), "w")
for k, v in vars(args).items(): for k, v in vars(args).items():
file.write(f"{k}:\t {v}\n") file.write(f"{k}:\t {v}\n")
file.close() file.close()
def set_seed(seed=None): def set_seed(seed=None):
if seed is None: if seed is None:
return return
random.seed(seed) random.seed(seed)
os.environ["PYTHONHASHSEED"] = "%s" % seed os.environ['PYTHONHASHSEED'] = ("%s" % seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
@ -210,14 +200,15 @@ def set_seed(seed=None):
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
# create a file and write the text into it # create a file and write the text into it
class IOStream: class IOStream():
def __init__(self, path): def __init__(self, path):
self.f = open(path, "a") self.f = open(path, 'a')
def cprint(self, text): def cprint(self, text):
print(text) print(text)
self.f.write(text + "\n") self.f.write(text+'\n')
self.f.flush() self.f.flush()
def close(self): def close(self):
@ -225,7 +216,8 @@ class IOStream:
def cal_loss(pred, gold, smoothing=True): def cal_loss(pred, gold, smoothing=True):
"""Calculate cross entropy loss, apply label smoothing if needed.""" ''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1) gold = gold.contiguous().view(-1)
if smoothing: if smoothing:
@ -238,6 +230,6 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gold, reduction="mean") loss = F.cross_entropy(pred, gold, reduction='mean')
return loss return loss
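The smoothing branch of cal_loss is partly cut by the hunk above: only the final -(one_hot * log_prb) reduction is visible. The usual construction of the smoothed target in this family of code builds a one-hot matrix and redistributes an eps mass over the other classes; a sketch, with eps = 0.2 assumed because the actual constant is elided by the diff:

```python
# Sketch of the elided smoothing step; eps = 0.2 is an assumption, not read from the repo.
eps = 0.2
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()
```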


@ -12,6 +12,7 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import division
from collections import deque from collections import deque
from datetime import timedelta from datetime import timedelta
@ -19,10 +20,11 @@ from math import ceil
from sys import stderr from sys import stderr
from time import time from time import time
__version__ = "1.3"
__version__ = '1.3'
class Infinite: class Infinite(object):
file = stderr file = stderr
sma_window = 10 # Simple Moving Average window sma_window = 10 # Simple Moving Average window
@ -36,7 +38,7 @@ class Infinite:
setattr(self, key, val) setattr(self, key, val)
def __getitem__(self, key): def __getitem__(self, key):
if key.startswith("_"): if key.startswith('_'):
return None return None
return getattr(self, key, None) return getattr(self, key, None)
@ -81,8 +83,8 @@ class Infinite:
class Progress(Infinite): class Progress(Infinite):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super(Progress, self).__init__(*args, **kwargs)
self.max = kwargs.get("max", 100) self.max = kwargs.get('max', 100)
@property @property
def eta(self): def eta(self):


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,18 +14,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Progress from . import Progress
from .helpers import WritelnMixin from .helpers import WritelnMixin
class Bar(WritelnMixin, Progress): class Bar(WritelnMixin, Progress):
width = 32 width = 32
message = "" message = ''
suffix = "%(index)d/%(max)d" suffix = '%(index)d/%(max)d'
bar_prefix = " |" bar_prefix = ' |'
bar_suffix = "| " bar_suffix = '| '
empty_fill = " " empty_fill = ' '
fill = "#" fill = '#'
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -34,30 +37,31 @@ class Bar(WritelnMixin, Progress):
bar = self.fill * filled_length bar = self.fill * filled_length
empty = self.empty_fill * empty_length empty = self.empty_fill * empty_length
suffix = self.suffix % self suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix]) line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
suffix])
self.writeln(line) self.writeln(line)
class ChargingBar(Bar): class ChargingBar(Bar):
suffix = "%(percent)d%%" suffix = '%(percent)d%%'
bar_prefix = " " bar_prefix = ' '
bar_suffix = " " bar_suffix = ' '
empty_fill = "∙" empty_fill = '∙'
fill = "█" fill = '█'
class FillingSquaresBar(ChargingBar): class FillingSquaresBar(ChargingBar):
empty_fill = "▢" empty_fill = '▢'
fill = "▣" fill = '▣'
class FillingCirclesBar(ChargingBar): class FillingCirclesBar(ChargingBar):
empty_fill = "◯" empty_fill = '◯'
fill = "◉" fill = '◉'
class IncrementalBar(Bar): class IncrementalBar(Bar):
phases = (" ", "", "", "", "", "", "", "", "") phases = (' ', '', '', '', '', '', '', '', '')
def update(self): def update(self):
nphases = len(self.phases) nphases = len(self.phases)
@ -68,16 +72,17 @@ class IncrementalBar(Bar):
message = self.message % self message = self.message % self
bar = self.phases[-1] * nfull bar = self.phases[-1] * nfull
current = self.phases[phase] if phase > 0 else "" current = self.phases[phase] if phase > 0 else ''
empty = self.empty_fill * max(0, nempty - len(current)) empty = self.empty_fill * max(0, nempty - len(current))
suffix = self.suffix % self suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix]) line = ''.join([message, self.bar_prefix, bar, current, empty,
self.bar_suffix, suffix])
self.writeln(line) self.writeln(line)
class PixelBar(IncrementalBar): class PixelBar(IncrementalBar):
phases = ("", "", "", "", "", "", "", "") phases = ('', '', '', '', '', '', '', '')
class ShadyBar(IncrementalBar): class ShadyBar(IncrementalBar):
phases = (" ", "", "", "", "") phases = (' ', '', '', '', '')


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,12 +14,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Infinite, Progress from . import Infinite, Progress
from .helpers import WriteMixin from .helpers import WriteMixin
class Counter(WriteMixin, Infinite): class Counter(WriteMixin, Infinite):
message = "" message = ''
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -32,7 +35,7 @@ class Countdown(WriteMixin, Progress):
class Stack(WriteMixin, Progress): class Stack(WriteMixin, Progress):
phases = (" ", "", "", "", "", "", "", "", "") phases = (' ', '', '', '', '', '', '', '', '')
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -42,4 +45,4 @@ class Stack(WriteMixin, Progress):
class Pie(Stack): class Pie(Stack):
phases = ("", "", "", "", "") phases = ('', '', '', '', '')


@ -12,76 +12,78 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import print_function
HIDE_CURSOR = "\x1b[?25l"
SHOW_CURSOR = "\x1b[?25h"
class WriteMixin: HIDE_CURSOR = '\x1b[?25l'
SHOW_CURSOR = '\x1b[?25h'
class WriteMixin(object):
hide_cursor = False hide_cursor = False
def __init__(self, message=None, **kwargs): def __init__(self, message=None, **kwargs):
super().__init__(**kwargs) super(WriteMixin, self).__init__(**kwargs)
self._width = 0 self._width = 0
if message: if message:
self.message = message self.message = message
if self.file.isatty(): if self.file.isatty():
if self.hide_cursor: if self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file) print(HIDE_CURSOR, end='', file=self.file)
print(self.message, end="", file=self.file) print(self.message, end='', file=self.file)
self.file.flush() self.file.flush()
def write(self, s): def write(self, s):
if self.file.isatty(): if self.file.isatty():
b = "\b" * self._width b = '\b' * self._width
c = s.ljust(self._width) c = s.ljust(self._width)
print(b + c, end="", file=self.file) print(b + c, end='', file=self.file)
self._width = max(self._width, len(s)) self._width = max(self._width, len(s))
self.file.flush() self.file.flush()
def finish(self): def finish(self):
if self.file.isatty() and self.hide_cursor: if self.file.isatty() and self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file) print(SHOW_CURSOR, end='', file=self.file)
class WritelnMixin: class WritelnMixin(object):
hide_cursor = False hide_cursor = False
def __init__(self, message=None, **kwargs): def __init__(self, message=None, **kwargs):
super().__init__(**kwargs) super(WritelnMixin, self).__init__(**kwargs)
if message: if message:
self.message = message self.message = message
if self.file.isatty() and self.hide_cursor: if self.file.isatty() and self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file) print(HIDE_CURSOR, end='', file=self.file)
def clearln(self): def clearln(self):
if self.file.isatty(): if self.file.isatty():
print("\r\x1b[K", end="", file=self.file) print('\r\x1b[K', end='', file=self.file)
def writeln(self, line): def writeln(self, line):
if self.file.isatty(): if self.file.isatty():
self.clearln() self.clearln()
print(line, end="", file=self.file) print(line, end='', file=self.file)
self.file.flush() self.file.flush()
def finish(self): def finish(self):
if self.file.isatty(): if self.file.isatty():
print(file=self.file) print(file=self.file)
if self.hide_cursor: if self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file) print(SHOW_CURSOR, end='', file=self.file)
from signal import SIGINT, signal from signal import signal, SIGINT
from sys import exit from sys import exit
class SigIntMixin: class SigIntMixin(object):
"""Registers a signal handler that calls finish on SIGINT.""" """Registers a signal handler that calls finish on SIGINT"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super(SigIntMixin, self).__init__(*args, **kwargs)
signal(SIGINT, self._sigint_handler) signal(SIGINT, self._sigint_handler)
def _sigint_handler(self, signum, frame): def _sigint_handler(self, signum, frame):


@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com> # Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
# #
# Permission to use, copy, modify, and distribute this software for any # Permission to use, copy, modify, and distribute this software for any
@ -12,13 +14,14 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
from . import Infinite from . import Infinite
from .helpers import WriteMixin from .helpers import WriteMixin
class Spinner(WriteMixin, Infinite): class Spinner(WriteMixin, Infinite):
message = "" message = ''
phases = ("-", "\\", "|", "/") phases = ('-', '\\', '|', '/')
hide_cursor = True hide_cursor = True
def update(self): def update(self):
@ -27,16 +30,15 @@ class Spinner(WriteMixin, Infinite):
class PieSpinner(Spinner): class PieSpinner(Spinner):
phases = ["", "", "", ""] phases = ['', '', '', '']
class MoonSpinner(Spinner): class MoonSpinner(Spinner):
phases = ["", "", "", ""] phases = ['', '', '', '']
class LineSpinner(Spinner): class LineSpinner(Spinner):
phases = ["", "", "", "", "", ""] phases = ['', '', '', '', '', '']
class PixelSpinner(Spinner): class PixelSpinner(Spinner):
phases = ["", "", "", "", "", "", "", ""] phases = ['','', '', '', '', '', '', '']


@ -1,27 +1,29 @@
#!/usr/bin/env python #!/usr/bin/env python
import progress
from setuptools import setup from setuptools import setup
import progress
setup( setup(
name="progress", name='progress',
version=progress.__version__, version=progress.__version__,
description="Easy to use progress bars", description='Easy to use progress bars',
long_description=open("README.rst").read(), long_description=open('README.rst').read(),
author="Giorgos Verigakis", author='Giorgos Verigakis',
author_email="verigak@gmail.com", author_email='verigak@gmail.com',
url="http://github.com/verigak/progress/", url='http://github.com/verigak/progress/',
license="ISC", license='ISC',
packages=["progress"], packages=['progress'],
classifiers=[ classifiers=[
"Environment :: Console", 'Environment :: Console',
"Intended Audience :: Developers", 'Intended Audience :: Developers',
"License :: OSI Approved :: ISC License (ISCL)", 'License :: OSI Approved :: ISC License (ISCL)',
"Programming Language :: Python :: 2.6", 'Programming Language :: Python :: 2.6',
"Programming Language :: Python :: 2.7", 'Programming Language :: Python :: 2.7',
"Programming Language :: Python :: 3.3", 'Programming Language :: Python :: 3.3',
"Programming Language :: Python :: 3.4", 'Programming Language :: Python :: 3.4',
"Programming Language :: Python :: 3.5", 'Programming Language :: Python :: 3.5',
"Programming Language :: Python :: 3.6", 'Programming Language :: Python :: 3.6',
], ]
) )


@ -1,12 +1,16 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function
import random import random
import time import time
from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
from progress.counter import Countdown, Counter, Pie, Stack FillingCirclesBar, IncrementalBar, PixelBar,
from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner ShadyBar)
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
PixelSpinner)
from progress.counter import Counter, Countdown, Stack, Pie
def sleep(): def sleep():
@ -16,29 +20,29 @@ def sleep():
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]" suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for _i in bar.iter(range(200)): for i in bar.iter(range(200)):
sleep() sleep()
for bar_cls in (IncrementalBar, PixelBar, ShadyBar): for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]" suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for _i in bar.iter(range(200)): for i in bar.iter(range(200)):
sleep() sleep()
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
for _i in spin(spin.__name__ + " ").iter(range(100)): for i in spin(spin.__name__ + ' ').iter(range(100)):
sleep() sleep()
print() print()
for singleton in (Counter, Countdown, Stack, Pie): for singleton in (Counter, Countdown, Stack, Pie):
for _i in singleton(singleton.__name__ + " ").iter(range(100)): for i in singleton(singleton.__name__ + ' ').iter(range(100)):
sleep() sleep()
print() print()
bar = IncrementalBar("Random", suffix="%(index)d") bar = IncrementalBar('Random', suffix='%(index)d')
for _i in range(100): for i in range(100):
bar.goto(random.randint(0, 100)) bar.goto(random.randint(0, 100))
sleep() sleep()
bar.finish() bar.finish()


@ -1,33 +1,23 @@
name: pointmlp name: pointmlp
channels: channels:
- pytorch - pytorch
- nvidia - nvidia
- conda-forge - conda-forge
dependencies: dependencies:
#---# basic python # - cudatoolkit=10.2.89
- pytorch - cudatoolkit=11.1
- tqdm - cycler=0.10.0
- numpy - einops=0.3.0
- scipy - h5py=3.2.1
- scikit-learn - matplotlib=3.4.2
#---# file readers - numpy=1.20.2
- h5py - numpy-base=1.20.2
- pyyaml - pytorch=1.8.1
#---# tooling (linting, typing...) - pyyaml=5.4.1
- ruff - scikit-learn=0.24.2
- mypy - scipy=1.6.3
- black - torchvision=0.9.1
- isort - tqdm=4.61.1
#---# visu
- matplotlib
#---# pytorch
- cudatoolkit
- cycler
- einops
- torchvision
- pip - pip
- pip: - pip:
- pointnet2_ops_lib/. - pointnet2_ops_lib/.

overview.pdf: new binary file (not shown).
overview.png: new binary file, 1.9 MiB (not shown).

@ -1,46 +1,30 @@
import argparse from __future__ import print_function
import os import os
import random import argparse
from collections import defaultdict import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from util.data_util import PartNormalDataset
import torch.nn.functional as F
import torch.nn as nn
import model as models import model as models
import numpy as np import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from util.util import to_categorical, compute_overall_iou, IOStream
from tqdm import tqdm from tqdm import tqdm
from util.data_util import PartNormalDataset from collections import defaultdict
from util.util import IOStream, compute_overall_iou, to_categorical from torch.autograd import Variable
import random
classes_str = [
"aero", classes_str = ['aero','bag','cap','car','chair','ear','guitar','knife','lamp','lapt','moto','mug','Pistol','rock','stake','table']
"bag",
"cap",
"car",
"chair",
"ear",
"guitar",
"knife",
"lamp",
"lapt",
"moto",
"mug",
"Pistol",
"rock",
"stake",
"table",
]
def _init_(): def _init_():
if not os.path.exists("checkpoints"): if not os.path.exists('checkpoints'):
os.makedirs("checkpoints") os.makedirs('checkpoints')
if not os.path.exists("checkpoints/" + args.exp_name): if not os.path.exists('checkpoints/' + args.exp_name):
os.makedirs("checkpoints/" + args.exp_name) os.makedirs('checkpoints/' + args.exp_name)
def weight_init(m): def weight_init(m):
@ -65,6 +49,7 @@ def weight_init(m):
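The body of weight_init is entirely elided by this hunk. A typical initializer for a model like this (an assumption about intent, not the repo's actual code) applies Kaiming initialization to conv/linear weights and constants to batch-norm parameters:

```python
# Hypothetical weight_init body; the diff does not show the real one.
def weight_init(m):
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
```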
def train(args, io): def train(args, io):
# ============= Model =================== # ============= Model ===================
num_part = 50 num_part = 50
device = torch.device("cuda" if args.cuda else "cpu") device = torch.device("cuda" if args.cuda else "cpu")
@ -76,19 +61,16 @@ def train(args, io):
model = nn.DataParallel(model) model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!") print("Let's use", torch.cuda.device_count(), "GPUs!")
"""Resume or not""" '''Resume or not'''
if args.resume: if args.resume:
state_dict = torch.load( state_dict = torch.load("checkpoints/%s/best_insiou_model.pth" % args.exp_name,
"checkpoints/%s/best_insiou_model.pth" % args.exp_name, map_location=torch.device('cpu'))['model']
map_location=torch.device("cpu"),
)["model"]
for k in state_dict.keys(): for k in state_dict.keys():
if "module" not in k: if 'module' not in k:
from collections import OrderedDict from collections import OrderedDict
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k in state_dict: for k in state_dict:
new_state_dict["module." + k] = state_dict[k] new_state_dict['module.' + k] = state_dict[k]
state_dict = new_state_dict state_dict = new_state_dict
break break
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
@ -99,27 +81,17 @@ def train(args, io):
print("Training from scratch...") print("Training from scratch...")
# =========== Dataloader ================= # =========== Dataloader =================
train_data = PartNormalDataset(npoints=2048, split="trainval", normalize=False) train_data = PartNormalDataset(npoints=2048, split='trainval', normalize=False)
print("The number of training data is:%d", len(train_data)) print("The number of training data is:%d", len(train_data))
test_data = PartNormalDataset(npoints=2048, split="test", normalize=False) test_data = PartNormalDataset(npoints=2048, split='test', normalize=False)
print("The number of test data is:%d", len(test_data)) print("The number of test data is:%d", len(test_data))
train_loader = DataLoader( train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
train_data, drop_last=True)
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
)
test_loader = DataLoader( test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers,
test_data, drop_last=False)
batch_size=args.test_batch_size,
shuffle=False,
num_workers=args.workers,
drop_last=False,
)
# ============= Optimizer ================ # ============= Optimizer ================
if args.use_sgd: if args.use_sgd:
@ -129,7 +101,7 @@ def train(args, io):
print("Use Adam") print("Use Adam")
opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
if args.scheduler == "cos": if args.scheduler == 'cos':
print("Use CosLR") print("Use CosLR")
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100) scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100)
else: else:
@ -144,33 +116,28 @@ def train(args, io):
num_classes = 16 num_classes = 16
for epoch in range(args.epochs): for epoch in range(args.epochs):
train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io) train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io)
test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io) test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io)
# 1. when get the best accuracy, save the model: # 1. when get the best accuracy, save the model:
if test_metrics["accuracy"] > best_acc: if test_metrics['accuracy'] > best_acc:
best_acc = test_metrics["accuracy"] best_acc = test_metrics['accuracy']
io.cprint("Max Acc:%.5f" % best_acc) io.cprint('Max Acc:%.5f' % best_acc)
state = { state = {
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), 'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
"optimizer": opt.state_dict(), 'optimizer': opt.state_dict(), 'epoch': epoch, 'test_acc': best_acc}
"epoch": epoch, torch.save(state, 'checkpoints/%s/best_acc_model.pth' % args.exp_name)
"test_acc": best_acc,
}
torch.save(state, "checkpoints/%s/best_acc_model.pth" % args.exp_name)
# 2. when get the best instance_iou, save the model: # 2. when get the best instance_iou, save the model:
if test_metrics["shape_avg_iou"] > best_instance_iou: if test_metrics['shape_avg_iou'] > best_instance_iou:
best_instance_iou = test_metrics["shape_avg_iou"] best_instance_iou = test_metrics['shape_avg_iou']
io.cprint("Max instance iou:%.5f" % best_instance_iou) io.cprint('Max instance iou:%.5f' % best_instance_iou)
state = { state = {
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), 'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
"optimizer": opt.state_dict(), 'optimizer': opt.state_dict(), 'epoch': epoch, 'test_instance_iou': best_instance_iou}
"epoch": epoch, torch.save(state, 'checkpoints/%s/best_insiou_model.pth' % args.exp_name)
"test_instance_iou": best_instance_iou,
}
torch.save(state, "checkpoints/%s/best_insiou_model.pth" % args.exp_name)
# 3. when get the best class_iou, save the model: # 3. when get the best class_iou, save the model:
# first we need to calculate the average per-class iou # first we need to calculate the average per-class iou
@ -182,28 +149,22 @@ def train(args, io):
best_class_iou = avg_class_iou best_class_iou = avg_class_iou
# print the iou of each class: # print the iou of each class:
for cat_idx in range(16): for cat_idx in range(16):
io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx])) io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx]))
io.cprint("Max class iou:%.5f" % best_class_iou) io.cprint('Max class iou:%.5f' % best_class_iou)
state = { state = {
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), 'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
"optimizer": opt.state_dict(), 'optimizer': opt.state_dict(), 'epoch': epoch, 'test_class_iou': best_class_iou}
"epoch": epoch, torch.save(state, 'checkpoints/%s/best_clsiou_model.pth' % args.exp_name)
"test_class_iou": best_class_iou,
}
torch.save(state, "checkpoints/%s/best_clsiou_model.pth" % args.exp_name)
# report best acc, ins_iou, cls_iou # report best acc, ins_iou, cls_iou
io.cprint("Final Max Acc:%.5f" % best_acc) io.cprint('Final Max Acc:%.5f' % best_acc)
io.cprint("Final Max instance iou:%.5f" % best_instance_iou) io.cprint('Final Max instance iou:%.5f' % best_instance_iou)
io.cprint("Final Max class iou:%.5f" % best_class_iou) io.cprint('Final Max class iou:%.5f' % best_class_iou)
# save last model # save last model
state = { state = {
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), 'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
"optimizer": opt.state_dict(), 'optimizer': opt.state_dict(), 'epoch': args.epochs - 1, 'test_iou': best_instance_iou}
"epoch": args.epochs - 1, torch.save(state, 'checkpoints/%s/model_ep%d.pth' % (args.exp_name, args.epochs))
"test_iou": best_instance_iou,
}
torch.save(state, "checkpoints/%s/model_ep%d.pth" % (args.exp_name, args.epochs))
def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io): def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io):
@ -214,41 +175,22 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
metrics = defaultdict(lambda: list()) metrics = defaultdict(lambda: list())
model.train() model.train()
for _batch_id, (points, label, target, norm_plt) in tqdm( for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
enumerate(train_loader),
total=len(train_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = ( points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \
Variable(points.float()), Variable(norm_plt.float())
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = ( points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \
points.cuda(non_blocking=True), target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
label.squeeze(1).cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
# target: b,n # target: b,n
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50
loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0]) loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0])
# instance iou without considering the class average at each batch_size: # instance iou without considering the class average at each batch_size:
batch_shapeious = compute_overall_iou( batch_shapeious = compute_overall_iou(seg_pred, target, num_part) # list of current batch_iou:[iou1,iou2,...,iou#b_size]
seg_pred,
target,
num_part,
) # list of current batch_iou:[iou1,iou2,...,iou#b_size]
# total iou of current batch in each process: # total iou of current batch in each process:
batch_shapeious = seg_pred.new_tensor( batch_shapeious = seg_pred.new_tensor([np.sum(batch_shapeious)], dtype=torch.float64) # same device with seg_pred!!!
[np.sum(batch_shapeious)],
dtype=torch.float64,
) # same device with seg_pred!!!
# Loss backward # Loss backward
loss = torch.mean(loss) loss = torch.mean(loss)
@ -270,25 +212,21 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
# Note: We do not need to calculate per_class iou during training # Note: We do not need to calculate per_class iou during training
if args.scheduler == "cos": if args.scheduler == 'cos':
scheduler.step() scheduler.step()
elif args.scheduler == "step": elif args.scheduler == 'step':
if opt.param_groups[0]["lr"] > 0.9e-5: if opt.param_groups[0]['lr'] > 0.9e-5:
scheduler.step() scheduler.step()
if opt.param_groups[0]["lr"] < 0.9e-5: if opt.param_groups[0]['lr'] < 0.9e-5:
for param_group in opt.param_groups: for param_group in opt.param_groups:
param_group["lr"] = 0.9e-5 param_group['lr'] = 0.9e-5
io.cprint("Learning rate: %f" % opt.param_groups[0]["lr"]) io.cprint('Learning rate: %f' % opt.param_groups[0]['lr'])
metrics["accuracy"] = np.mean(accuracy) metrics['accuracy'] = np.mean(accuracy)
metrics["shape_avg_iou"] = shape_ious * 1.0 / count metrics['shape_avg_iou'] = shape_ious * 1.0 / count
outstr = "Train %d, loss: %f, train acc: %f, train ins_iou: %f" % ( outstr = 'Train %d, loss: %f, train acc: %f, train ins_iou: %f' % (epoch+1, train_loss * 1.0 / count,
epoch + 1, metrics['accuracy'], metrics['shape_avg_iou'])
train_loss * 1.0 / count,
metrics["accuracy"],
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
@ -303,26 +241,14 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
model.eval() model.eval()
# label_size: b, means each sample has one corresponding class # label_size: b, means each sample has one corresponding class
for _batch_id, (points, label, target, norm_plt) in tqdm( for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9):
enumerate(test_loader),
total=len(test_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = ( points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \
Variable(points.float()), Variable(norm_plt.float())
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = ( points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \
points.cuda(non_blocking=True), target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
label.squeeze(1).cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
# instance iou without considering the class average at each batch_size: # instance iou without considering the class average at each batch_size:
@ -355,19 +281,13 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
for cat_idx in range(16): for cat_idx in range(16):
if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending
final_total_per_cat_iou[cat_idx] = ( final_total_per_cat_iou[cat_idx] = final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx] # avg class iou across all samples
final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx]
) # avg class iou across all samples
metrics["accuracy"] = np.mean(accuracy) metrics['accuracy'] = np.mean(accuracy)
metrics["shape_avg_iou"] = shape_ious * 1.0 / count metrics['shape_avg_iou'] = shape_ious * 1.0 / count
outstr = "Test %d, loss: %f, test acc: %f test ins_iou: %f" % ( outstr = 'Test %d, loss: %f, test acc: %f test ins_iou: %f' % (epoch + 1, test_loss * 1.0 / count,
epoch + 1, metrics['accuracy'], metrics['shape_avg_iou'])
test_loss * 1.0 / count,
metrics["accuracy"],
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
@ -376,16 +296,11 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
def test(args, io): def test(args, io):
# Dataloader # Dataloader
test_data = PartNormalDataset(npoints=2048, split="test", normalize=False) test_data = PartNormalDataset(npoints=2048, split='test', normalize=False)
print("The number of test data is:%d", len(test_data)) print("The number of test data is:%d", len(test_data))
test_loader = DataLoader( test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers,
test_data, drop_last=False)
batch_size=args.test_batch_size,
shuffle=False,
num_workers=args.workers,
drop_last=False,
)
# Try to load models # Try to load models
num_part = 50 num_part = 50
@ -395,15 +310,12 @@ def test(args, io):
io.cprint(str(model)) io.cprint(str(model))
from collections import OrderedDict from collections import OrderedDict
state_dict = torch.load("checkpoints/%s/best_%s_model.pth" % (args.exp_name, args.model_type),
state_dict = torch.load( map_location=torch.device('cpu'))['model']
f"checkpoints/{args.exp_name}/best_{args.model_type}_model.pth",
map_location=torch.device("cpu"),
)["model"]
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for layer in state_dict: for layer in state_dict:
new_state_dict[layer.replace("module.", "")] = state_dict[layer] new_state_dict[layer.replace('module.', '')] = state_dict[layer]
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
model.eval() model.eval()
@ -412,29 +324,16 @@ def test(args, io):
metrics = defaultdict(lambda: list()) metrics = defaultdict(lambda: list())
hist_acc = [] hist_acc = []
shape_ious = [] shape_ious = []
total_per_cat_iou = np.zeros(16).astype(np.float32) total_per_cat_iou = np.zeros((16)).astype(np.float32)
total_per_cat_seen = np.zeros(16).astype(np.int32) total_per_cat_seen = np.zeros((16)).astype(np.int32)
for _batch_id, (points, label, target, norm_plt) in tqdm( for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9):
enumerate(test_loader),
total=len(test_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = ( points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), Variable(norm_plt.float())
Variable(points.float()),
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = ( points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze().cuda(
points.cuda(non_blocking=True), non_blocking=True), target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
label.squeeze().cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
with torch.no_grad(): with torch.no_grad():
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
@ -454,11 +353,11 @@ def test(args, io):
target = target.view(-1, 1)[:, 0] target = target.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1] pred_choice = seg_pred.data.max(1)[1]
correct = pred_choice.eq(target.data).cpu().sum() correct = pred_choice.eq(target.data).cpu().sum()
metrics["accuracy"].append(correct.item() / (batch_size * num_point)) metrics['accuracy'].append(correct.item() / (batch_size * num_point))
hist_acc += metrics["accuracy"] hist_acc += metrics['accuracy']
metrics["accuracy"] = np.mean(hist_acc) metrics['accuracy'] = np.mean(hist_acc)
metrics["shape_avg_iou"] = np.mean(shape_ious) metrics['shape_avg_iou'] = np.mean(shape_ious)
for cat_idx in range(16): for cat_idx in range(16):
if total_per_cat_seen[cat_idx] > 0: if total_per_cat_seen[cat_idx] > 0:
total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx] total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx]
@ -467,41 +366,47 @@ def test(args, io):
class_iou = 0 class_iou = 0
for cat_idx in range(16): for cat_idx in range(16):
class_iou += total_per_cat_iou[cat_idx] class_iou += total_per_cat_iou[cat_idx]
io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx])) # print the iou of each class io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx])) # print the iou of each class
avg_class_iou = class_iou / 16 avg_class_iou = class_iou / 16
outstr = "Test :: test acc: {:f} test class mIOU: {:f}, test instance mIOU: {:f}".format( outstr = 'Test :: test acc: %f test class mIOU: %f, test instance mIOU: %f' % (metrics['accuracy'], avg_class_iou, metrics['shape_avg_iou'])
metrics["accuracy"],
avg_class_iou,
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
if __name__ == "__main__": if __name__ == "__main__":
# Training settings # Training settings
parser = argparse.ArgumentParser(description="3D Shape Part Segmentation") parser = argparse.ArgumentParser(description='3D Shape Part Segmentation')
parser.add_argument("--model", type=str, default="PointMLP1") parser.add_argument('--model', type=str, default='PointMLP1')
parser.add_argument("--exp_name", type=str, default="demo1", metavar="N", help="Name of the experiment") parser.add_argument('--exp_name', type=str, default='demo1', metavar='N',
parser.add_argument("--batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)") help='Name of the experiment')
parser.add_argument("--test_batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)") parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
parser.add_argument("--epochs", type=int, default=350, metavar="N", help="number of episode to train") help='Size of batch)')
parser.add_argument("--use_sgd", type=bool, default=False, help="Use SGD") parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size',
parser.add_argument("--scheduler", type=str, default="step", help="lr scheduler") help='Size of batch)')
parser.add_argument("--step", type=int, default=40, help="lr decay step") parser.add_argument('--epochs', type=int, default=350, metavar='N',
parser.add_argument("--lr", type=float, default=0.003, metavar="LR", help="learning rate") help='number of episode to train')
parser.add_argument("--momentum", type=float, default=0.9, metavar="M", help="SGD momentum (default: 0.9)") parser.add_argument('--use_sgd', type=bool, default=False,
parser.add_argument("--no_cuda", type=bool, default=False, help="enables CUDA training") help='Use SGD')
parser.add_argument("--manual_seed", type=int, metavar="S", help="random seed (default: 1)") parser.add_argument('--scheduler', type=str, default='step',
parser.add_argument("--eval", type=bool, default=False, help="evaluate the model") help='lr scheduler')
parser.add_argument("--num_points", type=int, default=2048, help="num of points to use") parser.add_argument('--step', type=int, default=40,
parser.add_argument("--workers", type=int, default=12) help='lr decay step')
parser.add_argument("--resume", type=bool, default=False, help="Resume training or not") parser.add_argument('--lr', type=float, default=0.003, metavar='LR',
parser.add_argument( help='learning rate')
"--model_type", parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
type=str, help='SGD momentum (default: 0.9)')
default="insiou", parser.add_argument('--no_cuda', type=bool, default=False,
help="choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)", help='enables CUDA training')
) parser.add_argument('--manual_seed', type=int, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--eval', type=bool, default=False,
help='evaluate the model')
parser.add_argument('--num_points', type=int, default=2048,
help='num of points to use')
parser.add_argument('--workers', type=int, default=12)
parser.add_argument('--resume', type=bool, default=False,
help='Resume training or not')
parser.add_argument('--model_type', type=str, default='insiou',
help='choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)')
args = parser.parse_args() args = parser.parse_args()
args.exp_name = args.model+"_"+args.exp_name args.exp_name = args.model+"_"+args.exp_name
@ -509,9 +414,9 @@ if __name__ == "__main__":
_init_() _init_()
if not args.eval: if not args.eval:
io = IOStream("checkpoints/" + args.exp_name + "/%s_train.log" % (args.exp_name)) io = IOStream('checkpoints/' + args.exp_name + '/%s_train.log' % (args.exp_name))
else: else:
io = IOStream("checkpoints/" + args.exp_name + "/%s_test.log" % (args.exp_name)) io = IOStream('checkpoints/' + args.exp_name + '/%s_test.log' % (args.exp_name))
io.cprint(str(args)) io.cprint(str(args))
if args.manual_seed is not None: if args.manual_seed is not None:
@ -522,12 +427,12 @@ if __name__ == "__main__":
args.cuda = not args.no_cuda and torch.cuda.is_available() args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda: if args.cuda:
io.cprint("Using GPU") io.cprint('Using GPU')
if args.manual_seed is not None: if args.manual_seed is not None:
torch.cuda.manual_seed(args.manual_seed) torch.cuda.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed) torch.cuda.manual_seed_all(args.manual_seed)
else: else:
io.cprint("Using CPU") io.cprint('Using CPU')
if not args.eval: if not args.eval:
train(args, io) train(args, io)

View file

@ -1 +1,2 @@
from __future__ import absolute_import
from .pointMLP import pointMLP from .pointMLP import pointMLP

View file

@ -1,32 +1,34 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils from pointnet2_ops import pointnet2_utils
def get_activation(activation): def get_activation(activation):
if activation.lower() == "gelu": if activation.lower() == 'gelu':
return nn.GELU() return nn.GELU()
elif activation.lower() == "rrelu": elif activation.lower() == 'rrelu':
return nn.RReLU(inplace=True) return nn.RReLU(inplace=True)
elif activation.lower() == "selu": elif activation.lower() == 'selu':
return nn.SELU(inplace=True) return nn.SELU(inplace=True)
elif activation.lower() == "silu": elif activation.lower() == 'silu':
return nn.SiLU(inplace=True) return nn.SiLU(inplace=True)
elif activation.lower() == "hardswish": elif activation.lower() == 'hardswish':
return nn.Hardswish(inplace=True) return nn.Hardswish(inplace=True)
elif activation.lower() == "leakyrelu": elif activation.lower() == 'leakyrelu':
return nn.LeakyReLU(inplace=True) return nn.LeakyReLU(inplace=True)
elif activation.lower() == "leakyrelu0.2": elif activation.lower() == 'leakyrelu0.2':
return nn.LeakyReLU(negative_slope=0.2, inplace=True) return nn.LeakyReLU(negative_slope=0.2, inplace=True)
else: else:
return nn.ReLU(inplace=True) return nn.ReLU(inplace=True)
def square_distance(src, dst): def square_distance(src, dst):
"""Calculate Euclid distance between each two points. """
src^T * dst = xn * xm + yn * ym + zn * zm; Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@ -35,7 +37,7 @@ def square_distance(src, dst):
src: source points, [B, N, C] src: source points, [B, N, C]
dst: target points, [B, M, C] dst: target points, [B, M, C]
Output: Output:
dist: per-point square distance, [B, N, M]. dist: per-point square distance, [B, N, M]
""" """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
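As a quick sanity check on the algebraic expansion that `square_distance` relies on, the following standalone sketch (tensor shapes are illustrative assumptions, not taken from a particular call site) confirms that the direct pairwise computation matches the sum-of-squares-minus-dot-product form:

```python
# Verify ||src_n - dst_m||^2 = ||src_n||^2 + ||dst_m||^2 - 2 * <src_n, dst_m>  (illustrative check)
import torch

src = torch.rand(2, 4, 3)  # [B, N, C], assumed shapes
dst = torch.rand(2, 5, 3)  # [B, M, C]

# direct pairwise squared distances: [B, N, M]
direct = ((src.unsqueeze(2) - dst.unsqueeze(1)) ** 2).sum(-1)

# expanded form used by square_distance
expanded = (
    (src ** 2).sum(-1).unsqueeze(-1)
    + (dst ** 2).sum(-1).unsqueeze(1)
    - 2 * torch.matmul(src, dst.transpose(1, 2))
)

assert torch.allclose(direct, expanded, atol=1e-6)
```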
@ -46,12 +48,12 @@ def square_distance(src, dst):
def index_points(points, idx): def index_points(points, idx):
"""Input: """
Input:
points: input points data, [B, N, C] points: input points data, [B, N, C]
idx: sample index data, [B, S]. idx: sample index data, [B, S]
Return: Return:
new_points:, indexed points data, [B, S, C]. new_points:, indexed points data, [B, S, C]
""" """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
@ -60,15 +62,17 @@ def index_points(points, idx):
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
return points[batch_indices, idx, :] new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint): def farthest_point_sample(xyz, npoint):
"""Input: """
Input:
xyz: pointcloud data, [B, N, 3] xyz: pointcloud data, [B, N, 3]
npoint: number of samples npoint: number of samples
Return: Return:
centroids: sampled pointcloud index, [B, npoint]. centroids: sampled pointcloud index, [B, npoint]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@ -86,14 +90,14 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz): def query_ball_point(radius, nsample, xyz, new_xyz):
"""Input: """
Input:
radius: local region radius radius: local region radius
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, 3] xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]. new_xyz: query points, [B, S, 3]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@ -109,13 +113,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz): def knn_point(nsample, xyz, new_xyz):
"""Input: """
Input:
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, C] xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C]. new_xyz: query points, [B, S, C]
Return: Return:
group_idx: grouped points index, [B, S, nsample]. group_idx: grouped points index, [B, S, nsample]
""" """
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
@ -124,12 +128,13 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module): class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs): def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs):
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d] """
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number :param groups: groups number
:param kneighbors: k-neighbors :param kneighbors: k-neighbors
:param kwargs: others. :param kwargs: others
""" """
super().__init__() super(LocalGrouper, self).__init__()
self.groups = groups self.groups = groups
self.kneighbors = kneighbors self.kneighbors = kneighbors
self.use_xyz = use_xyz self.use_xyz = use_xyz
@ -138,7 +143,7 @@ class LocalGrouper(nn.Module):
else: else:
self.normalize = None self.normalize = None
if self.normalize not in ["center", "anchor"]: if self.normalize not in ["center", "anchor"]:
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].") print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None self.normalize = None
if self.normalize is not None: if self.normalize is not None:
add_channel=3 if self.use_xyz else 0 add_channel=3 if self.use_xyz else 0
@ -168,11 +173,7 @@ class LocalGrouper(nn.Module):
if self.normalize =="anchor": if self.normalize =="anchor":
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
std = ( std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
.unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points-mean)/(std + 1e-5) grouped_points = (grouped_points-mean)/(std + 1e-5)
grouped_points = self.affine_alpha*grouped_points + self.affine_beta grouped_points = self.affine_alpha*grouped_points + self.affine_beta
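A minimal sketch of the anchor-mode normalization applied above, with the learnable `affine_alpha`/`affine_beta` parameters replaced by plain ones/zeros and the optional xyz concatenation omitted; all shapes here are assumptions for illustration:

```python
import torch

B, npoint, k, d = 2, 8, 16, 32                     # assumed sizes
grouped_points = torch.rand(B, npoint, k, d)       # neighborhood features
anchor = torch.rand(B, npoint, d)                  # anchor (center) features, xyz concat omitted

mean = anchor.unsqueeze(-2)                        # [B, npoint, 1, d]
std = (
    torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
    .unsqueeze(-1)
    .unsqueeze(-1)
)                                                  # one scalar std per sample, broadcast as [B, 1, 1, 1]
normalized = (grouped_points - mean) / (std + 1e-5)

alpha = torch.ones(1, 1, 1, d)                     # learnable affine_alpha in the real module
beta = torch.zeros(1, 1, 1, d)                     # learnable affine_beta in the real module
out = alpha * normalized + beta                    # [B, npoint, k, d]
```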
@ -181,13 +182,13 @@ class LocalGrouper(nn.Module):
class ConvBNReLU1D(nn.Module): class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"): def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
super().__init__() super(ConvBNReLU1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
self.act, self.act
) )
def forward(self, x): def forward(self, x):
@ -195,43 +196,30 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module): class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"): def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
super().__init__() super(ConvBNReLURes1D, self).__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net1 = nn.Sequential( self.net1 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
in_channels=channel, kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)), nn.BatchNorm1d(int(channel * res_expansion)),
self.act, self.act
) )
if groups > 1: if groups > 1:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, groups=groups, bias=bias),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
self.act, self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=channel, out_channels=channel,
kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
) )
else: else:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
in_channels=int(channel * res_expansion), kernel_size=kernel_size, bias=bias),
out_channels=channel, nn.BatchNorm1d(channel)
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
) )
def forward(self, x): def forward(self, x):
@ -239,34 +227,21 @@ class ConvBNReLURes1D(nn.Module):
class PreExtraction(nn.Module): class PreExtraction(nn.Module):
def __init__( def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
self, activation='relu', use_xyz=True):
channels, """
out_channels, input: [b,g,k,d]: output:[b,d,g]
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PreExtraction, self).__init__()
in_channels = 3+2*channels if use_xyz else 2*channels in_channels = 3+2*channels if use_xyz else 2*channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D( ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
out_channels, bias=bias, activation=activation)
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -278,20 +253,22 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size() batch_size, _, _ = x.size()
x = self.operation(x) # [b, d, k] x = self.operation(x) # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
return x.reshape(b, n, -1).permute(0, 2, 1) x = x.reshape(b, n, -1).permute(0, 2, 1)
return x
class PosExtraction(nn.Module): class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"): def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
"""input[b,d,g]; output[b,d,g] """
input[b,d,g]; output[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super().__init__() super(PosExtraction, self).__init__()
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation), ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -300,27 +277,22 @@ class PosExtraction(nn.Module):
class PointNetFeaturePropagation(nn.Module): class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation="relu"): def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
super().__init__() super(PointNetFeaturePropagation, self).__init__()
self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias) self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias)
self.extraction = PosExtraction( self.extraction = PosExtraction(out_channel, blocks, groups=groups,
out_channel, res_expansion=res_expansion, bias=bias, activation=activation)
blocks,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
def forward(self, xyz1, xyz2, points1, points2): def forward(self, xyz1, xyz2, points1, points2):
"""Input: """
Input:
xyz1: input points position data, [B, N, 3] xyz1: input points position data, [B, N, 3]
xyz2: sampled input points position data, [B, S, 3] xyz2: sampled input points position data, [B, S, 3]
points1: input points data, [B, D', N] points1: input points data, [B, D', N]
points2: input points data, [B, D'', S]. points2: input points data, [B, D'', S]
Return: Return:
new_points: upsampled points data, [B, D''', N]. new_points: upsampled points data, [B, D''', N]
""" """
# xyz1 = xyz1.permute(0, 2, 1) # xyz1 = xyz1.permute(0, 2, 1)
# xyz2 = xyz2.permute(0, 2, 1) # xyz2 = xyz2.permute(0, 2, 1)
@ -349,40 +321,26 @@ class PointNetFeaturePropagation(nn.Module):
new_points = new_points.permute(0, 2, 1) new_points = new_points.permute(0, 2, 1)
new_points = self.fuse(new_points) new_points = self.fuse(new_points)
return self.extraction(new_points) new_points = self.extraction(new_points)
return new_points
class PointMLP(nn.Module): class PointMLP(nn.Module):
def __init__( def __init__(self, num_classes=50,points=2048, embed_dim=64, groups=1, res_expansion=1.0,
self, activation="relu", bias=True, use_xyz=True, normalize="anchor",
num_classes=50, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
points=2048, k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4],
embed_dim=64, de_dims=[512, 256, 128, 128], de_blocks=[2,2,2,2],
groups=1, gmp_dim=64,cls_dim=64, **kwargs):
res_expansion=1.0, super(PointMLP, self).__init__()
activation="relu",
bias=True,
use_xyz=True,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[4, 4, 4, 4],
de_dims=[512, 256, 128, 128],
de_blocks=[2, 2, 2, 2],
gmp_dim=64,
cls_dim=64,
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks) self.stages = len(pre_blocks)
self.class_num = num_classes self.class_num = num_classes
self.points = points self.points = points
self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation) self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation)
assert ( assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion) "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList() self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList() self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList() self.pos_blocks_list = nn.ModuleList()
@ -401,31 +359,19 @@ class PointMLP(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
self.local_grouper_list.append(local_grouper) self.local_grouper_list.append(local_grouper)
# append pre_block_list # append pre_block_list
pre_block_module = PreExtraction( pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
last_channel,
out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion, res_expansion=res_expansion,
bias=bias, bias=bias, activation=activation, use_xyz=use_xyz)
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module) self.pre_blocks_list.append(pre_block_module)
# append pos_block_list # append pos_block_list
pos_block_module = PosExtraction( pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
out_channel, res_expansion=res_expansion, bias=bias, activation=activation)
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module) self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel last_channel = out_channel
en_dims.append(last_channel) en_dims.append(last_channel)
### Building Decoder ##### ### Building Decoder #####
self.decode_list = nn.ModuleList() self.decode_list = nn.ModuleList()
en_dims.reverse() en_dims.reverse()
@ -433,15 +379,9 @@ class PointMLP(nn.Module):
assert len(en_dims) ==len(de_dims) == len(de_blocks)+1 assert len(en_dims) ==len(de_dims) == len(de_blocks)+1
for i in range(len(en_dims)-1): for i in range(len(en_dims)-1):
self.decode_list.append( self.decode_list.append(
PointNetFeaturePropagation( PointNetFeaturePropagation(de_dims[i]+en_dims[i+1], de_dims[i+1],
de_dims[i] + en_dims[i + 1], blocks=de_blocks[i], groups=groups, res_expansion=res_expansion,
de_dims[i + 1], bias=bias, activation=activation)
blocks=de_blocks[i],
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.act = get_activation(activation) self.act = get_activation(activation)
@ -449,7 +389,7 @@ class PointMLP(nn.Module):
# class label mapping # class label mapping
self.cls_map = nn.Sequential( self.cls_map = nn.Sequential(
ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation), ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation),
ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation), ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation)
) )
# global max pooling mapping # global max pooling mapping
self.gmp_map_list = nn.ModuleList() self.gmp_map_list = nn.ModuleList()
@ -462,7 +402,7 @@ class PointMLP(nn.Module):
nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias), nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias),
nn.BatchNorm1d(128), nn.BatchNorm1d(128),
nn.Dropout(), nn.Dropout(),
nn.Conv1d(128, num_classes, 1, bias=bias), nn.Conv1d(128, num_classes, 1, bias=bias)
) )
self.en_dims = en_dims self.en_dims = en_dims
@ -501,42 +441,24 @@ class PointMLP(nn.Module):
x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1) x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
x = self.classifier(x) x = self.classifier(x)
x = F.log_softmax(x, dim=1) x = F.log_softmax(x, dim=1)
return x.permute(0, 2, 1) x = x.permute(0, 2, 1)
return x
def pointMLP(num_classes=50, **kwargs) -> PointMLP: def pointMLP(num_classes=50, **kwargs) -> PointMLP:
return PointMLP( return PointMLP(num_classes=num_classes, points=2048, embed_dim=64, groups=1, res_expansion=1.0,
num_classes=num_classes, activation="relu", bias=True, use_xyz=True, normalize="anchor",
points=2048, dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
embed_dim=64, k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4],
groups=1, de_dims=[512, 256, 128, 128], de_blocks=[4,4,4,4],
res_expansion=1.0, gmp_dim=64,cls_dim=64, **kwargs)
activation="relu",
bias=True,
use_xyz=True,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[4, 4, 4, 4],
de_dims=[512, 256, 128, 128],
de_blocks=[4, 4, 4, 4],
gmp_dim=64,
cls_dim=64,
**kwargs,
)
if __name__ == "__main__": if __name__ == '__main__':
data = torch.rand(2, 3, 2048).cuda() data = torch.rand(2, 3, 2048)
norm = torch.rand(2, 3, 2048).cuda() norm = torch.rand(2, 3, 2048)
cls_label = torch.rand([2, 16]).cuda() cls_label = torch.rand([2, 16])
print(f"data shape: {data.shape}") print("===> testing modelD ...")
print(f"norm shape: {norm.shape}") model = pointMLP(50)
print(f"cls_label shape: {cls_label.shape}") out = model(data, cls_label) # [2,2048,50]
print(out.shape)
print("===> testing pointMLP (segmentation) ...")
model = pointMLP(50).cuda()
out = model(data, norm, cls_label) # [2,2048,50]
print(f"out shape: {out.shape}")

View file

@ -1,21 +1,19 @@
import glob import glob
import json
import os
import h5py import h5py
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
import os
import json
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def load_data(partition): def load_data(partition):
all_data = [] all_data = []
all_label = [] all_label = []
for h5_name in glob.glob("./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5" % partition): for h5_name in glob.glob('./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5' % partition):
f = h5py.File(h5_name) f = h5py.File(h5_name)
data = f["data"][:].astype("float32") data = f['data'][:].astype('float32')
label = f["label"][:].astype("int64") label = f['label'][:].astype('int64')
f.close() f.close()
all_data.append(data) all_data.append(data)
all_label.append(label) all_label.append(label)
@ -28,14 +26,16 @@ def pc_normalize(pc):
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
pc = pc - centroid pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
return pc / m pc = pc / m
return pc
def translate_pointcloud(pointcloud): def translate_pointcloud(pointcloud):
xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3]) xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32") translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
return translated_pointcloud
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
@ -46,7 +46,7 @@ def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
# =========== ModelNet40 ================= # =========== ModelNet40 =================
class ModelNet40(Dataset): class ModelNet40(Dataset):
def __init__(self, num_points, partition="train"): def __init__(self, num_points, partition='train'):
self.data, self.label = load_data(partition) self.data, self.label = load_data(partition)
self.num_points = num_points self.num_points = num_points
self.partition = partition # Here the new given partition will cover the 'train' self.partition = partition # Here the new given partition will cover the 'train'
@ -54,7 +54,7 @@ class ModelNet40(Dataset):
def __getitem__(self, item): # indice of the pts or label def __getitem__(self, item): # indice of the pts or label
pointcloud = self.data[item][:self.num_points] pointcloud = self.data[item][:self.num_points]
label = self.label[item] label = self.label[item]
if self.partition == "train": if self.partition == 'train':
# pointcloud = pc_normalize(pointcloud) # you can try to add it or not to train our model # pointcloud = pc_normalize(pointcloud) # you can try to add it or not to train our model
pointcloud = translate_pointcloud(pointcloud) pointcloud = translate_pointcloud(pointcloud)
np.random.shuffle(pointcloud) # shuffle the order of pts np.random.shuffle(pointcloud) # shuffle the order of pts
@ -66,46 +66,46 @@ class ModelNet40(Dataset):
# =========== ShapeNet Part ================= # =========== ShapeNet Part =================
class PartNormalDataset(Dataset): class PartNormalDataset(Dataset):
def __init__(self, npoints=2500, split="train", normalize=False): def __init__(self, npoints=2500, split='train', normalize=False):
self.npoints = npoints self.npoints = npoints
self.root = "./data/shapenetcore_partanno_segmentation_benchmark_v0_normal" self.root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal'
self.catfile = os.path.join(self.root, "synsetoffset2category.txt") self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {} self.cat = {}
self.normalize = normalize self.normalize = normalize
with open(self.catfile) as f: with open(self.catfile, 'r') as f:
for line in f: for line in f:
ls = line.strip().split() ls = line.strip().split()
self.cat[ls[0]] = ls[1] self.cat[ls[0]] = ls[1]
self.cat = {k: v for k, v in self.cat.items()} self.cat = {k: v for k, v in self.cat.items()}
self.meta = {} self.meta = {}
with open(os.path.join(self.root, "train_test_split", "shuffled_train_file_list.json")) as f: with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
train_ids = set([str(d.split("/")[2]) for d in json.load(f)]) train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, "train_test_split", "shuffled_val_file_list.json")) as f: with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
val_ids = set([str(d.split("/")[2]) for d in json.load(f)]) val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, "train_test_split", "shuffled_test_file_list.json")) as f: with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
test_ids = set([str(d.split("/")[2]) for d in json.load(f)]) test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
for item in self.cat: for item in self.cat:
self.meta[item] = [] self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item]) dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point)) fns = sorted(os.listdir(dir_point))
if split == "trainval": if split == 'trainval':
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
elif split == "train": elif split == 'train':
fns = [fn for fn in fns if fn[0:-4] in train_ids] fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == "val": elif split == 'val':
fns = [fn for fn in fns if fn[0:-4] in val_ids] fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == "test": elif split == 'test':
fns = [fn for fn in fns if fn[0:-4] in test_ids] fns = [fn for fn in fns if fn[0:-4] in test_ids]
else: else:
print("Unknown split: %s. Exiting.." % (split)) print('Unknown split: %s. Exiting..' % (split))
exit(-1) exit(-1)
for fn in fns: for fn in fns:
token = os.path.splitext(os.path.basename(fn))[0] token = (os.path.splitext(os.path.basename(fn))[0])
self.meta[item].append(os.path.join(dir_point, token + ".txt")) self.meta[item].append(os.path.join(dir_point, token + '.txt'))
self.datapath = [] self.datapath = []
for item in self.cat: for item in self.cat:
@ -114,24 +114,11 @@ class PartNormalDataset(Dataset):
self.classes = dict(zip(self.cat, range(len(self.cat)))) self.classes = dict(zip(self.cat, range(len(self.cat))))
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
self.seg_classes = { self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
"Earphone": [16, 17, 18], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
"Motorbike": [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
"Rocket": [41, 42, 43], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
"Car": [8, 9, 10, 11], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
"Laptop": [28, 29],
"Cap": [6, 7],
"Skateboard": [44, 45, 46],
"Mug": [36, 37],
"Guitar": [19, 20, 21],
"Bag": [4, 5],
"Lamp": [24, 25, 26, 27],
"Table": [47, 48, 49],
"Airplane": [0, 1, 2, 3],
"Pistol": [38, 39, 40],
"Chair": [12, 13, 14, 15],
"Knife": [22, 23],
}
self.cache = {} # from index to (point_set, cls, seg) tuple self.cache = {} # from index to (point_set, cls, seg) tuple
self.cache_size = 20000 self.cache_size = 20000
@ -169,9 +156,9 @@ class PartNormalDataset(Dataset):
return len(self.datapath) return len(self.datapath)
if __name__ == "__main__": if __name__ == '__main__':
train = PartNormalDataset(npoints=2048, split="trainval", normalize=False) train = PartNormalDataset(npoints=2048, split='trainval', normalize=False)
test = PartNormalDataset(npoints=2048, split="test", normalize=False) test = PartNormalDataset(npoints=2048, split='test', normalize=False)
for data, label, _, _ in train: for data, label, _, _ in train:
print(data.shape) print(data.shape)
print(label.shape) print(label.shape)

View file

@ -4,7 +4,8 @@ import torch.nn.functional as F
def cal_loss(pred, gold, smoothing=True): def cal_loss(pred, gold, smoothing=True):
"""Calculate cross entropy loss, apply label smoothing if needed.""" ''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1) # gold is the ground-truth label in the dataloader gold = gold.contiguous().view(-1) # gold is the ground-truth label in the dataloader
if smoothing: if smoothing:
@ -17,19 +18,19 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gold, reduction="mean") loss = F.cross_entropy(pred, gold, reduction='mean')
return loss return loss
# create a file and write the text into it: # create a file and write the text into it:
class IOStream: class IOStream():
def __init__(self, path): def __init__(self, path):
self.f = open(path, "a") self.f = open(path, 'a')
def cprint(self, text): def cprint(self, text):
print(text) print(text)
self.f.write(text + "\n") self.f.write(text+'\n')
self.f.flush() self.f.flush()
def close(self): def close(self):
@ -37,9 +38,9 @@ class IOStream:
def to_categorical(y, num_classes): def to_categorical(y, num_classes):
"""1-hot encodes a tensor.""" """ 1-hot encodes a tensor """
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
if y.is_cuda: if (y.is_cuda):
return new_y.cuda(non_blocking=True) return new_y.cuda(non_blocking=True)
return new_y return new_y
@ -52,9 +53,7 @@ def compute_overall_iou(pred, target, num_classes):
target_np = target.cpu().data.numpy() target_np = target.cpu().data.numpy()
for shape_idx in range(pred.size(0)): # sample_idx for shape_idx in range(pred.size(0)): # sample_idx
part_ious = [] part_ious = []
for part in range( for part in range(num_classes): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes
num_classes,
): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes
# for target, each point has a class no matter which category owns this point! also 50 classes!!! # for target, each point has a class no matter which category owns this point! also 50 classes!!!
# only return 1 when both belongs to this class, which means correct: # only return 1 when both belongs to this class, which means correct:
I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part)) I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
@ -66,7 +65,5 @@ def compute_overall_iou(pred, target, num_classes):
if F != 0: if F != 0:
iou = I / float(U) # iou across all points for this class iou = I / float(U) # iou across all points for this class
part_ious.append(iou) # append the iou of this class part_ious.append(iou) # append the iou of this class
shape_ious.append( shape_ious.append(np.mean(part_ious)) # each time append an average iou across all classes of this sample (sample_level!)
np.mean(part_ious),
) # each time append an average iou across all classes of this sample (sample_level!)
return shape_ious # [batch_size] return shape_ious # [batch_size]
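For reference, a standalone sketch of the per-sample IoU averaging that the comments above describe; the label arrays and the 3 part classes below are hypothetical stand-ins for the per-point predictions and the 50 part classes used in the real code:

```python
import numpy as np

pred_np = np.array([0, 0, 1, 1, 2])    # hypothetical per-point part predictions for one sample
target_np = np.array([0, 1, 1, 1, 2])  # hypothetical per-point ground-truth part labels

part_ious = []
for part in range(3):                                   # 50 part classes in the real code
    I = np.sum(np.logical_and(pred_np == part, target_np == part))
    U = np.sum(np.logical_or(pred_np == part, target_np == part))
    F = np.sum(target_np == part)
    if F != 0:                                          # only parts present in the ground truth count
        part_ious.append(I / float(U))

sample_iou = np.mean(part_ious)                         # sample-level IoU, averaged over seen parts
print(sample_iou)
```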

View file

@@ -1,15 +1,16 @@
+from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_ops import pointnet2_utils

-def build_shared_mlp(mlp_spec: list[int], bn: bool = True):
+def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    layers = []
    for i in range(1, len(mlp_spec)):
        layers.append(
-            nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn),
+            nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
        )
        if bn:
            layers.append(nn.BatchNorm2d(mlp_spec[i]))
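For reference, the helper builds a point-wise "shared MLP": a stack of 1x1 Conv2d layers (each optionally followed by BatchNorm and, in the full helper, an activation) applied to a (B, C, npoint, nsample) tensor. A minimal sketch with an invented channel spec, just to show the shape behaviour:

import torch
import torch.nn as nn

# Hypothetical spec: 3 input channels (xyz) -> 64 -> 64.
mlp_spec = [3, 64, 64]
layers = []
for i in range(1, len(mlp_spec)):
    layers.append(nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=False))
    layers.append(nn.BatchNorm2d(mlp_spec[i]))
    layers.append(nn.ReLU(inplace=True))
shared_mlp = nn.Sequential(*layers)

# Every (npoint, nsample) location is transformed by the same weights,
# so only the channel dimension changes.
x = torch.randn(2, 3, 512, 32)       # (B, C, npoint, nsample)
print(shared_mlp(x).shape)           # torch.Size([2, 64, 512, 32])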
@@ -20,37 +21,36 @@ def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
class _PointnetSAModuleBase(nn.Module):
    def __init__(self):
-        super().__init__()
+        super(_PointnetSAModuleBase, self).__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

-    def forward(
-        self,
-        xyz: torch.Tensor,
-        features: torch.Tensor | None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        r"""Parameters
+    def forward(
+        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features
-        Returns:
+        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """
        new_features_list = []
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = (
            pointnet2_utils.gather_operation(
-                xyz_flipped,
-                pointnet2_utils.furthest_point_sample(xyz, self.npoint),
+                xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            )
            .transpose(1, 2)
            .contiguous()
@@ -60,15 +60,12 @@ class _PointnetSAModuleBase(nn.Module):
        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
-                xyz,
-                new_xyz,
-                features,
+                xyz, new_xyz, features
            )  # (B, C, npoint, nsample)
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(
-                new_features,
-                kernel_size=[1, new_features.size(3)],
+                new_features, kernel_size=[1, new_features.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
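Read together, the set-abstraction forward pass is: furthest-point-sample npoint centroids, gather their coordinates, group neighbours around each centroid, run the shared MLP, then max-pool over the neighbour axis. The grouping/MLP steps need the CUDA extension, but the final pooling step is plain PyTorch; a shape-level sketch with invented sizes:

import torch
import torch.nn.functional as F

B, C, npoint, nsample = 2, 64, 512, 32
# Stand-in for the output of a grouper followed by the shared MLP.
grouped = torch.randn(B, C, npoint, nsample)

# Max over the nsample neighbours of each centroid, exactly as in the forward above.
pooled = F.max_pool2d(grouped, kernel_size=[1, nsample])   # (B, C, npoint, 1)
pooled = pooled.squeeze(-1)                                 # (B, C, npoint)
print(pooled.shape)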
@@ -78,7 +75,7 @@ class _PointnetSAModuleBase(nn.Module):
class PointnetSAModuleMSG(_PointnetSAModuleBase):
-    r"""Pointnet set abstrction layer with multiscale grouping.
+    r"""Pointnet set abstrction layer with multiscale grouping

    Parameters
    ----------
@@ -96,7 +93,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
-        super().__init__()
+        super(PointnetSAModuleMSG, self).__init__()

        assert len(radii) == len(nsamples) == len(mlps)
@@ -109,7 +106,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None
-                else pointnet2_utils.GroupAll(use_xyz),
+                else pointnet2_utils.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
@@ -119,7 +116,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
class PointnetSAModule(PointnetSAModuleMSG):
-    r"""Pointnet set abstrction layer.
+    r"""Pointnet set abstrction layer

    Parameters
    ----------
@@ -136,16 +133,10 @@ class PointnetSAModule(PointnetSAModuleMSG):
    """

    def __init__(
-        self,
-        mlp,
-        npoint=None,
-        radius=None,
-        nsample=None,
-        bn=True,
-        use_xyz=True,
+        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
-        super().__init__(
+        super(PointnetSAModule, self).__init__(
            mlps=[mlp],
            npoint=npoint,
            radii=[radius],
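For orientation, this is roughly how the single-scale and multi-scale modules are meant to be instantiated. The radii, sample counts and channel widths below are invented for illustration (they are not taken from a config in this repo); the first MLP width is the number of per-point feature channels, and 3 is added automatically when use_xyz=True:

# Single-scale grouping: one radius, one neighbour count, one shared MLP.
sa_ssg = PointnetSAModule(
    mlp=[3, 64, 64, 128],      # assumes 3 extra feature channels per point (e.g. normals)
    npoint=512, radius=0.2, nsample=32,
)

# Multi-scale grouping: one grouper + MLP per radius, outputs concatenated.
sa_msg = PointnetSAModuleMSG(
    npoint=512,
    radii=[0.1, 0.2, 0.4],
    nsamples=[16, 32, 64],
    mlps=[[3, 32, 64], [3, 64, 128], [3, 64, 128]],
)

# new_xyz, new_features = sa_msg(xyz, features)   # xyz: (B, N, 3), features: (B, C, N) or None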
@@ -156,7 +147,7 @@ class PointnetSAModule(PointnetSAModuleMSG):
class PointnetFPModule(nn.Module):
-    r"""Propigates the features of one set to another.
+    r"""Propigates the features of one set to another

    Parameters
    ----------
@@ -168,12 +159,13 @@ class PointnetFPModule(nn.Module):
    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
-        super().__init__()
+        super(PointnetFPModule, self).__init__()

        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
@@ -184,11 +176,12 @@ class PointnetFPModule(nn.Module):
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated
-        Returns:
+        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
@@ -196,19 +189,16 @@ class PointnetFPModule(nn.Module):
            weight = dist_recip / norm
            interpolated_feats = pointnet2_utils.three_interpolate(
-                known_feats,
-                idx,
-                weight,
+                known_feats, idx, weight
            )
        else:
            interpolated_feats = known_feats.expand(
-                *(known_feats.size()[0:2] + [unknown.size(1)]),
+                *(known_feats.size()[0:2] + [unknown.size(1)])
            )

        if unknow_feats is not None:
            new_features = torch.cat(
-                [interpolated_feats, unknow_feats],
-                dim=1,
+                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats
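The propagation step is plain inverse-distance weighting over the three nearest known points. A small sketch of how the weights come out, using made-up distances (the CUDA three_nn/three_interpolate ops are only described in comments):

import torch

dist = torch.tensor([[[0.1, 0.4, 0.5]]])     # (B=1, n=1, 3): distances to the 3 nearest known points
dist_recip = 1.0 / (dist + 1e-8)
norm = dist_recip.sum(dim=2, keepdim=True)
weight = dist_recip / norm                    # weights sum to 1; closer neighbours get larger weights
print(weight)

# three_interpolate then computes, for every unknown point,
#   sum_k weight[..., k] * known_feats[..., idx[..., k]]
# and the concatenated result is fed through the shared MLP.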


@@ -1,24 +1,22 @@
-import warnings
-from typing import *
import torch
import torch.nn as nn
+import warnings
from torch.autograd import Function
+from typing import *

try:
    import pointnet2_ops._ext as _ext
except ImportError:
-    import glob
-    import os
-    import os.path as osp
    from torch.utils.cpp_extension import load
+    import glob
+    import os.path as osp
+    import os

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
    _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
-        osp.join(_ext_src_root, "src", "*.cu"),
+        osp.join(_ext_src_root, "src", "*.cu")
    )
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
@@ -37,8 +35,9 @@ class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
-        r"""Uses iterative furthest point sampling to select a set of npoint features that have the largest
-        minimum distance.
+        r"""
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance

        Parameters
        ----------
@@ -47,7 +46,7 @@ class FurthestPointSampling(Function):
        npoint : int32
            number of features in the sampled set
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
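For intuition, here is a slow pure-PyTorch version of the same farthest-point-sampling loop the CUDA kernel implements: repeatedly pick the point farthest from everything selected so far. This is a reference sketch only; in practice the compiled _ext op should be used.

import torch

def fps_reference(xyz, npoint):
    # xyz: (B, N, 3) -> (B, npoint) indices of the sampled points.
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.long)
    dist = torch.full((B, N), float("inf"))
    farthest = torch.zeros(B, dtype=torch.long)   # start from point 0, like the kernel
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = xyz[torch.arange(B), farthest].unsqueeze(1)        # (B, 1, 3)
        dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))   # distance to the chosen set
        farthest = dist.argmax(dim=1)                                  # next point: the farthest one
    return idx

# sampled = fps_reference(torch.randn(2, 1024, 3), 128)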
@@ -70,7 +69,9 @@ class GatherOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
@@ -78,11 +79,12 @@ class GatherOperation(Function):
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        ctx.save_for_backward(idx, features)
        return _ext.gather_points(features, idx)
@@ -103,15 +105,16 @@ class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
-        r"""Find the three nearest neighbors of unknown in known
+        r"""
+        Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of known features
        known : torch.Tensor
-            (B, m, 3) tensor of unknown features.
+            (B, m, 3) tensor of unknown features
-        Returns:
+        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
@@ -137,7 +140,8 @@ class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
-        r"""Performs weight linear interpolation on 3 features
+        r"""
+        Performs weight linear interpolation on 3 features

        Parameters
        ----------
        features : torch.Tensor
@@ -145,9 +149,9 @@ class ThreeInterpolate(Function):
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
-            (B, n, 3) weights.
+            (B, n, 3) weights
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
@@ -159,12 +163,13 @@ class ThreeInterpolate(Function):
    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs
-        Returns:
+        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features
@@ -177,10 +182,7 @@ class ThreeInterpolate(Function):
        m = features.size(2)
        grad_features = _ext.three_interpolate_grad(
-            grad_out.contiguous(),
-            idx,
-            weight,
-            m,
+            grad_out.contiguous(), idx, weight, m
        )

        return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
@@ -193,14 +195,16 @@ class GroupingOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
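Semantically, grouping_operation is just fancy indexing: for every (npoint, nsample) slot it copies the corresponding feature column, so result[b, :, p, s] = features[b, :, idx[b, p, s]]. A plain-PyTorch equivalent for small tensors (shapes invented):

import torch

B, C, N, npoint, nsample = 2, 8, 100, 16, 4
features = torch.randn(B, C, N)
idx = torch.randint(0, N, (B, npoint, nsample))

# Gather along the point dimension, one batch element at a time.
grouped = torch.stack([features[b][:, idx[b]] for b in range(B)])
print(grouped.shape)   # torch.Size([2, 8, 16, 4]) == (B, C, npoint, nsample)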
@@ -212,12 +216,14 @@ class GroupingOperation(Function):
    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
@@ -238,7 +244,9 @@ class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        radius : float
            radius of the balls
@@ -249,7 +257,7 @@ class BallQuery(Function):
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query
-        Returns:
+        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
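ball_query returns, for each query centre, the indices of up to nsample points that lie within the radius; when fewer than nsample points fall inside the ball, the usual kernel behaviour is to pad by repeating the first neighbour found. A slow reference sketch of that behaviour (this is an illustration of the semantics, not the CUDA implementation):

import torch

def ball_query_reference(radius, nsample, xyz, new_xyz):
    # xyz: (B, N, 3), new_xyz: (B, npoint, 3) -> (B, npoint, nsample) long indices
    d2 = torch.cdist(new_xyz, xyz) ** 2                     # squared distances, (B, npoint, N)
    idx = torch.zeros(*d2.shape[:2], nsample, dtype=torch.long)
    for b in range(d2.size(0)):
        for p in range(d2.size(1)):
            inside = torch.nonzero(d2[b, p] < radius ** 2).flatten()[:nsample]
            if inside.numel() == 0:
                continue                                    # empty ball: row stays all zeros
            idx[b, p, : inside.numel()] = inside
            idx[b, p, inside.numel():] = inside[0]          # pad with the first neighbour found
    return idx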
@@ -269,7 +277,8 @@ ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
-    r"""Groups with a ball query of radius.
+    r"""
+    Groups with a ball query of radius

    Parameters
    ---------
@@ -281,12 +290,13 @@ class QueryAndGroup(nn.Module):
    def __init__(self, radius, nsample, use_xyz=True):
        # type: (QueryAndGroup, float, int, bool) -> None
-        super().__init__()
+        super(QueryAndGroup, self).__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
@@ -295,11 +305,12 @@ class QueryAndGroup(nn.Module):
        features : torch.Tensor
            Descriptors of the features (B, C, N)
-        Returns:
+        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
@@ -309,20 +320,22 @@ class QueryAndGroup(nn.Module):
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat(
-                    [grouped_xyz, grouped_features],
-                    dim=1,
+                    [grouped_xyz, grouped_features], dim=1
                )  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
-            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
+            assert (
+                self.use_xyz
+            ), "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz

        return new_features
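Putting it together, QueryAndGroup turns per-point features into per-neighbourhood tensors. An illustrative shape walk-through (sizes made up; running it requires the compiled extension and CUDA tensors, so the call is only sketched):

# xyz:      (B, N, 3)    all points
# new_xyz:  (B, S, 3)    S centroids, e.g. from furthest_point_sample + gather_operation
# features: (B, C, N)    per-point descriptors
grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
# new_features = grouper(xyz, new_xyz, features)   # (B, C + 3, S, 32)
# With use_xyz=True, the 3 relative-coordinate channels (grouped_xyz - new_xyz)
# are concatenated in front of the grouped feature channels.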
class GroupAll(nn.Module):
-    r"""Groups all features.
+    r"""
+    Groups all features

    Parameters
    ---------
@@ -330,12 +343,13 @@ class GroupAll(nn.Module):
    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
-        super().__init__()
+        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
-        r"""Parameters
+        r"""
+        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
@@ -344,18 +358,18 @@ class GroupAll(nn.Module):
        features : torch.Tensor
            Descriptors of the features (B, C, N)
-        Returns:
+        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
-                    [grouped_xyz, grouped_features],
-                    dim=1,
+                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
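GroupAll is the degenerate case used when npoint is None: the whole cloud becomes a single group, so the output is just the inputs reshaped (plus the xyz channels when requested). Sketch of the resulting shapes, assuming the inputs above:

# xyz: (B, N, 3), features: (B, C, N)
# GroupAll(use_xyz=True)(xyz, None, features) -> (B, 3 + C, 1, N)
# i.e. xyz.transpose(1, 2).unsqueeze(2) and features.unsqueeze(2), concatenated on dim 1.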


@@ -8,7 +8,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
this_dir = osp.dirname(osp.abspath(__file__))
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
-    osp.join(_ext_src_root, "src", "*.cu"),
+    osp.join(_ext_src_root, "src", "*.cu")
)
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
@@ -32,7 +32,7 @@ setup(
                "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
            },
            include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
-        ),
+        )
    ],
    cmdclass={"build_ext": BuildExtension},
    include_package_data=True,


@@ -1,64 +0,0 @@
[tool.ruff]
line-length = 120
ignore-init-module-imports = true
ignore = [
"G004", # Logging statement uses f-string
"EM102", # Exception must not use an f-string literal, assign to variable first
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"N812", # Lowercase imported as non lowercase
]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"C90", # mccabe
"COM", # flake8-commas
"D", # pydocstyle
"EM", # flake8-errmsg
"E", # pycodestyle errors
"F", # Pyflakes
"G", # flake8-logging-format
"I", # isort
"N", # pep8-naming
"PIE", # flake8-pie
"PTH", # flake8-use-pathlib
"TD", # flake8-todo
"FIX", # flake8-fixme
"RET", # flake8-return
"RUF", # ruff
"S", # flake8-bandit
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warnings
]
[tool.ruff.pydocstyle]
convention = "google"
[tool.ruff.isort]
known-first-party = ["pointnet2_ops"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"src/aube/main.py" = ["E402", "F401"]
[tool.black]
exclude = '''
/(
\.git
\.venv
)/
'''
include = '\.pyi?$'
line-length = 120
target-version = ["py310"]
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true

requirements.txt Normal file

@@ -0,0 +1,12 @@
torch
torchvision
cudatoolkit
cycler
einops
h5py
matplotlib==3.4.2
pytorch
pyyaml==5.4.1
scikit-learn==0.24.2
scipy
tqdm