Compare commits

...

10 commits

Author SHA1 Message Date
Laurent FAINSIN 4450cad9a3 fix images 2023-08-03 17:06:30 +02:00
Laurent FAINSIN f9ca08faa7 format README.md 2023-08-03 17:02:38 +02:00
Laurent FAINSIN e9735d1e36 useless files 2023-08-03 16:54:21 +02:00
Laurent FAINSIN f680d6f7a8 auto-formatting 2023-08-03 16:40:14 +02:00
Laurent FAINSIN cc00fa7215 add a bunch of configs 2023-08-03 16:38:02 +02:00
Laurent FAINSIN caca863e02 removed useless requirements.txt 2023-08-03 16:02:09 +02:00
Laurent FAINSIN c50897112d JESUS CHRIST DID THEY CHECK THEIR CODE ? 2023-08-03 15:51:53 +02:00
Laurent FAINSIN 62e57ecfc0 what in tarnation is this shit ? 2023-08-03 15:39:35 +02:00
Laurent FAINSIN 4fcb30ca42 ignore checkpoints and data folders 2023-08-03 15:33:13 +02:00
Laurent FAINSIN fae2f13bbf remove deps constraints 2023-08-03 14:11:24 +02:00
51 changed files with 2071 additions and 1467 deletions

3
.gitignore vendored
View file

@ -1,3 +1,6 @@
data
checkpoints
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

9
.vscode/extensions.json vendored Normal file
View file

@ -0,0 +1,9 @@
{
"recommendations": [
"editorconfig.editorconfig",
"eamodio.gitlens",
"ms-python.python",
"ms-python.black-formatter",
"charliermarsh.ruff",
]
}

17
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,17 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"OMP_NUM_THREADS": "1",
"CUDA_VISIBLE_DEVICES": "1",
}
}
]
}

67
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,67 @@
{
// nice editor settings
"editor.formatOnSave": true,
"editor.formatOnPaste": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true,
"source.fixAll": false,
},
"editor.rulers": [
120
],
// editorconfig redundancy
"files.insertFinalNewline": true,
"files.trimTrailingWhitespace": true,
// hide unimportant files/folders
"files.exclude": {
// defaults
"**/.git": true,
"**/.svn": true,
"**/.hg": true,
"**/CVS": true,
"**/.DS_Store": true,
"**/Thumbs.db": true,
// annoying
"**/__pycache__": true,
"**/.mypy_cache": true,
"**/.ruff_cache": true,
"**/*.tmp": true,
},
// cpp / clang / cmake settings
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools",
"C_Cpp.intelliSenseEngine": "disabled",
"C_Cpp.intelliSenseEngineFallback": "enabled",
"C_Cpp.clang_format_path": "/softs/compiler/llvm/latest/bin/clang-format",
"C_Cpp.codeAnalysis.clangTidy.enabled": true,
"C_Cpp.codeAnalysis.clangTidy.path": "/softs/compiler/llvm/latest/bin/clang-tidy",
"clangd.path": "/softs/compiler/llvm/latest/bin/clangd",
"cmake.cmakePath": "/softs/cmake/latest/bin/cmake",
"cmake.preferredGenerators": [
"Ninja",
"Unix Makefiles"
],
"cmakeFormat.exePath": "/softs/conda/auto/envs/cmake-format/bin/cmake-format",
"cmake.languageSupport.dotnetPath": "/softs/conda/auto/envs/dotnet/lib/dotnet/dotnet",
// python settings
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pointmlp/bin/python",
"python.linting.enabled": true,
"python.linting.lintOnSave": true,
"python.linting.mypyEnabled": true,
// fixes for broken auto-activation on rosetta
"python.terminal.activateEnvironment": false,
"terminal.integrated.profiles.linux": {
"python": {
"path": "bash",
"icon": "rocket",
"args": [
"--init-file",
".vscode/setup.sh"
],
}
},
"terminal.integrated.env.linux": {
"PYTHONPATH": "${workspaceFolder}/src/",
"SLURM_JOB_ID": null, // unset or else lightning_logs v_num uses it
},
}

11
.vscode/setup.sh vendored Executable file
View file

@ -0,0 +1,11 @@
#!/bin/bash
source ~/.bashrc
conda_init
conda activate pointmlp
export PS1="(pointmlp)[\\u@\\h \\W]\\$ "
module load compilers
module load mpfr

View file

@ -1,48 +1,40 @@
# Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework ICLR 2022 # Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework ICLR 2022
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=rethinking-network-design-and-local-geometry-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=rethinking-network-design-and-local-geometry-1)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=rethinking-network-design-and-local-geometry-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=rethinking-network-design-and-local-geometry-1)
[![github](https://img.shields.io/github/stars/ma-xu/pointMLP-pytorch?style=social)](https://github.com/ma-xu/pointMLP-pytorch) [![github](https://img.shields.io/github/stars/ma-xu/pointMLP-pytorch?style=social)](https://github.com/ma-xu/pointMLP-pytorch)
<img src="images/smile.png" height="70px">
<img src="images/neu.png" height="70px">
<img src="images/columbia.png" height="70px">
<div align="left"> [open review](https://openreview.net/forum?id=3Pbra-_u76D) | [arXiv](https://arxiv.org/abs/2202.07123) | Primary contact: [Xu Ma](mailto:ma.xu1@northeastern.edu)
<a><img src="images/smile.png" height="70px" ></a>
<a><img src="images/neu.png" height="70px" ></a>
<a><img src="images/columbia.png" height="70px" ></a>
</div>
[open review](https://openreview.net/forum?id=3Pbra-_u76D) | [arXiv](https://arxiv.org/abs/2202.07123) | Primary contact: [Xu Ma](mailto:ma.xu1@northeastern.edu) ![](images/overview.png)
<div align="center">
<img src="images/overview.png" width="650px" height="300px">
</div>
Overview of one stage in PointMLP. Given an input point cloud, PointMLP progressively extracts local features using residual point MLP blocks. In each stage, we first transform the local point using a geometric affine module, and then local points are extracted before and after aggregation, respectively. By repeating multiple stages, PointMLP progressively enlarges the receptive field and models entire point cloud geometric information. Overview of one stage in PointMLP. Given an input point cloud, PointMLP progressively extracts local features using residual point MLP blocks. In each stage, we first transform the local point using a geometric affine module, and then local points are extracted before and after aggregation, respectively. By repeating multiple stages, PointMLP progressively enlarges the receptive field and models entire point cloud geometric information.
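The geometric affine module described above is implemented by the LocalGrouper class further down in this compare: each local neighborhood is normalized by its mean and a per-cloud standard deviation, then rescaled with learnable parameters. Below is a minimal sketch of the "center" variant, assuming grouped features of shape [B, npoint, k, d]; the function name and toy shapes are illustrative only.

```python
import torch

def geometric_affine(grouped_points, alpha, beta, eps=1e-5):
    """Normalize each local neighborhood, then apply a learnable affine ("center" mode)."""
    B = grouped_points.shape[0]
    mean = grouped_points.mean(dim=2, keepdim=True)                         # [B, npoint, 1, d]
    std = (grouped_points - mean).reshape(B, -1).std(dim=-1, keepdim=True)  # one std per cloud
    std = std.unsqueeze(-1).unsqueeze(-1)                                   # [B, 1, 1, 1]
    normalized = (grouped_points - mean) / (std + eps)
    return alpha * normalized + beta

# toy shapes: 2 clouds, 64 groups of 24 neighbors, 32-d features
feats = torch.randn(2, 64, 24, 32)
alpha = torch.ones(1, 1, 1, 32)   # plays the role of affine_alpha in LocalGrouper
beta = torch.zeros(1, 1, 1, 32)   # plays the role of affine_beta
print(geometric_affine(feats, alpha, beta).shape)  # torch.Size([2, 64, 24, 32])
```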
## BibTeX ## BibTeX
@article{ma2022rethinking, ```bibtex
title={Rethinking network design and local geometry in point cloud: A simple residual MLP framework}, @article{ma2022rethinking,
author={Ma, Xu and Qin, Can and You, Haoxuan and Ran, Haoxi and Fu, Yun}, title={Rethinking network design and local geometry in point cloud: A simple residual MLP framework},
journal={arXiv preprint arXiv:2202.07123}, author={Ma, Xu and Qin, Can and You, Haoxuan and Ran, Haoxi and Fu, Yun},
year={2022} journal={arXiv preprint arXiv:2202.07123},
} year={2022}
}
```
## Model Zoo ## Model Zoo
**Questions on ModelNet40 classification results (a common issue for ModelNet40 dataset in the community)** **Questions on ModelNet40 classification results (a common issue for ModelNet40 dataset in the community)**
The performance on ModelNet40 of almost all methods is not stable; see (https://github.com/CVMI-Lab/PAConv/issues/9#issuecomment-873371422).<br>
If you run the same code several times, you will get different results (even with a fixed seed).<br>
The best way to reproduce the results is to test with a pretrained model for ModelNet40.<br>
Also, the randomness of ModelNet40 is our motivation to experiment on ScanObjectNN, and to report the mean/std results of several runs.
The performance on ModelNet40 of almost all methods is not stable; see (https://github.com/CVMI-Lab/PAConv/issues/9#issuecomment-873371422).<br>
If you run the same code several times, you will get different results (even with a fixed seed).<br>
The best way to reproduce the results is to test with a pretrained model for ModelNet40.<br>
Also, the randomness of ModelNet40 is our motivation to experiment on ScanObjectNN, and to report the mean/std results of several runs.
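For reference, the seeding done in the training script later in this compare (main.py) condenses to the sketch below; the helper name is illustrative, not part of the repository.

```python
import os

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # condensed from the seeding block in main.py
    torch.manual_seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # trade speed for repeatability
```

Even with these flags, as the note above says, repeated runs can still differ, so evaluating a pretrained checkpoint remains the most reliable way to reproduce the reported ModelNet40 numbers.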
------ ------
@ -54,8 +46,6 @@ On ScanObjectNN, fixed pointMLP achieves a result of **84.4% mAcc** and **86.1%
Stay tuned. More elite versions and voting results will be uploaded. Stay tuned. More elite versions and voting results will be uploaded.
## News & Updates: ## News & Updates:
- [x] fix the incomplete utils in partseg by Mar/10, caused by an erroneously uploaded folder. - [x] fix the incomplete utils in partseg by Mar/10, caused by an erroneously uploaded folder.
@ -66,9 +56,6 @@ Stay tuned. More elite versions and voting results will be uploaded.
:point_right::point_right::point_right:**NOTE:** The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026). :point_right::point_right::point_right:**NOTE:** The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026).
## Install ## Install
```bash ```bash
@ -92,8 +79,7 @@ pip install cycler einops h5py pyyaml==5.4.1 scikit-learn==0.24.2 scipy tqdm mat
pip install pointnet2_ops_lib/. pip install pointnet2_ops_lib/.
``` ```
## Usage
## Useage
### Classification ModelNet40 ### Classification ModelNet40
**Train**: The dataset will be automatically downloaded, run following command to train. **Train**: The dataset will be automatically downloaded, run following command to train.
@ -120,7 +106,7 @@ python voting.py --model pointMLP --msg demo
The dataset will be automatically downloaded The dataset will be automatically downloaded
- Train pointMLP/pointMLPElite - Train pointMLP/pointMLPElite
```bash ```bash
cd classification_ScanObjectNN cd classification_ScanObjectNN
# train pointMLP # train pointMLP
@ -161,10 +147,5 @@ Our implementation is mainly based on the following codebases. We gratefully tha
[Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
## LICENSE ## LICENSE
PointMLP is under the Apache-2.0 license.
PointMLP is under the Apache-2.0 license.

View file

@ -1,20 +1,21 @@
import torch
import fvcore.nn
import fvcore.common import fvcore.common
import fvcore.nn
import torch
from fvcore.nn import FlopCountAnalysis from fvcore.nn import FlopCountAnalysis
from classification_ScanObjectNN.models import pointMLPElite from classification_ScanObjectNN.models import pointMLPElite
model = pointMLPElite() model = pointMLPElite()
model.eval() model.eval()
# model = deit_tiny_patch16_224() # model = deit_tiny_patch16_224()
inputs = (torch.randn((1,3,1024))) inputs = torch.randn((1, 3, 1024))
k = 1024.0 k = 1024.0
flops = FlopCountAnalysis(model, inputs).total() flops = FlopCountAnalysis(model, inputs).total()
print(f"Flops : {flops}") print(f"Flops : {flops}")
flops = flops/(k**3) flops = flops / (k**3)
print(f"Flops : {flops:.1f}G") print(f"Flops : {flops:.1f}G")
params = fvcore.nn.parameter_count(model)[""] params = fvcore.nn.parameter_count(model)[""]
print(f"Params : {params}") print(f"Params : {params}")
params = params/(k**2) params = params / (k**2)
print(f"Params : {params:.1f}M") print(f"Params : {params:.1f}M")

View file

@ -1,33 +1,37 @@
import os
import glob import glob
import os
import h5py import h5py
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def download(): def download():
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, 'data') DATA_DIR = os.path.join(BASE_DIR, "data")
if not os.path.exists(DATA_DIR): if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR) os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): if not os.path.exists(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048")):
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' www = "https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip"
zipfile = os.path.basename(www) zipfile = os.path.basename(www)
os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) os.system(f"wget {www} --no-check-certificate; unzip {zipfile}")
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) os.system(f"mv {zipfile[:-4]} {DATA_DIR}")
os.system('rm %s' % (zipfile)) os.system("rm %s" % (zipfile))
def load_data(partition): def load_data(partition):
download() download()
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, 'data') DATA_DIR = os.path.join(BASE_DIR, "data")
all_data = [] all_data = []
all_label = [] all_label = []
for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)): for h5_name in glob.glob(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048", "ply_data_%s*.h5" % partition)):
# print(f"h5_name: {h5_name}") # print(f"h5_name: {h5_name}")
f = h5py.File(h5_name,'r') f = h5py.File(h5_name, "r")
data = f['data'][:].astype('float32') data = f["data"][:].astype("float32")
label = f['label'][:].astype('int64') label = f["label"][:].astype("int64")
f.close() f.close()
all_data.append(data) all_data.append(data)
all_label.append(label) all_label.append(label)
@ -35,40 +39,42 @@ def load_data(partition):
all_label = np.concatenate(all_label, axis=0) all_label = np.concatenate(all_label, axis=0)
return all_data, all_label return all_data, all_label
def random_point_dropout(pc, max_dropout_ratio=0.875): def random_point_dropout(pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 ''' """batch_pc: BxNx3."""
# for b in range(batch_pc.shape[0]): # for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((pc.shape[0]))<=dropout_ratio)[0] drop_idx = np.where(np.random.random(pc.shape[0]) <= dropout_ratio)[0]
# print ('use random drop', len(drop_idx)) # print ('use random drop', len(drop_idx))
if len(drop_idx)>0: if len(drop_idx) > 0:
pc[drop_idx,:] = pc[0,:] # set to the first point pc[drop_idx, :] = pc[0, :] # set to the first point
return pc return pc
def translate_pointcloud(pointcloud): def translate_pointcloud(pointcloud):
xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
return translated_pointcloud
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
N, C = pointcloud.shape N, C = pointcloud.shape
pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
return pointcloud return pointcloud
class ModelNet40(Dataset): class ModelNet40(Dataset):
def __init__(self, num_points, partition='train'): def __init__(self, num_points, partition="train"):
self.data, self.label = load_data(partition) self.data, self.label = load_data(partition)
self.num_points = num_points self.num_points = num_points
self.partition = partition self.partition = partition
def __getitem__(self, item): def __getitem__(self, item):
pointcloud = self.data[item][:self.num_points] pointcloud = self.data[item][: self.num_points]
label = self.label[item] label = self.label[item]
if self.partition == 'train': if self.partition == "train":
# pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all # pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all
pointcloud = translate_pointcloud(pointcloud) pointcloud = translate_pointcloud(pointcloud)
np.random.shuffle(pointcloud) np.random.shuffle(pointcloud)
@ -78,19 +84,25 @@ class ModelNet40(Dataset):
return self.data.shape[0] return self.data.shape[0]
if __name__ == '__main__': if __name__ == "__main__":
train = ModelNet40(1024) train = ModelNet40(1024)
test = ModelNet40(1024, 'test') test = ModelNet40(1024, "test")
# for data, label in train: # for data, label in train:
# print(data.shape) # print(data.shape)
# print(label.shape) # print(label.shape)
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
train_loader = DataLoader(ModelNet40(partition='train', num_points=1024), num_workers=4,
batch_size=32, shuffle=True, drop_last=True) train_loader = DataLoader(
ModelNet40(partition="train", num_points=1024),
num_workers=4,
batch_size=32,
shuffle=True,
drop_last=True,
)
for batch_idx, (data, label) in enumerate(train_loader): for batch_idx, (data, label) in enumerate(train_loader):
print(f"batch_idx: {batch_idx} | data shape: {data.shape} | label shape: {label.shape}") print(f"batch_idx: {batch_idx} | data shape: {data.shape} | label shape: {label.shape}")
train_set = ModelNet40(partition='train', num_points=1024) train_set = ModelNet40(partition="train", num_points=1024)
test_set = ModelNet40(partition='test', num_points=1024) test_set = ModelNet40(partition="test", num_points=1024)
print(f"train_set size {train_set.__len__()}") print(f"train_set size {train_set.__len__()}")
print(f"test_set size {test_set.__len__()}") print(f"test_set size {test_set.__len__()}")

View file

@ -1,9 +1,9 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
def cal_loss(pred, gold, smoothing=True):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
def cal_loss(pred, gold, smoothing=True):
"""Calculate cross entropy loss, apply label smoothing if needed."""
gold = gold.contiguous().view(-1) gold = gold.contiguous().view(-1)
if smoothing: if smoothing:
@ -16,6 +16,6 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gold, reduction='mean') loss = F.cross_entropy(pred, gold, reduction="mean")
return loss return loss
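cal_loss above builds a smoothed one-hot target and averages the negative log-likelihood; the hunk elides the smoothing constant, so the value below is an assumed placeholder. A self-contained sketch of the pattern:

```python
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(pred, gold, smoothing=True, eps=0.2):
    # pred: [B, num_classes] logits, gold: [B] integer class labels
    # NOTE: eps=0.2 is an assumed smoothing factor, not taken from the diff
    gold = gold.contiguous().view(-1)
    if smoothing:
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)  # spread eps mass
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction="mean")
    return loss

print(smoothed_cross_entropy(torch.randn(4, 40), torch.randint(0, 40, (4,))))
```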

View file

@ -1,41 +1,46 @@
""" """Usage:
Usage: python main.py --model PointMLP --msg demo.
python main.py --model PointMLP --msg demo
""" """
import argparse import argparse
import os
import logging
import datetime import datetime
import logging
import os
import models as models
import numpy as np
import sklearn.metrics as metrics
import torch import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim import torch.optim
import torch.utils.data import torch.utils.data
import torch.utils.data.distributed import torch.utils.data.distributed
from torch.utils.data import DataLoader
import models as models
from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
from data import ModelNet40 from data import ModelNet40
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import sklearn.metrics as metrics from torch.utils.data import DataLoader
import numpy as np from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model
def parse_args(): def parse_args():
"""Parameters""" """Parameters"""
parser = argparse.ArgumentParser('training') parser = argparse.ArgumentParser("training")
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH', parser.add_argument(
help='path to save checkpoint (default: checkpoint)') "-c",
parser.add_argument('--msg', type=str, help='message after checkpoint') "--checkpoint",
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training') type=str,
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]') metavar="PATH",
parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training') help="path to save checkpoint (default: checkpoint)",
parser.add_argument('--num_points', type=int, default=1024, help='Point Number') )
parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training') parser.add_argument("--msg", type=str, help="message after checkpoint")
parser.add_argument('--min_lr', default=0.005, type=float, help='min lr') parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate') parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]")
parser.add_argument('--seed', type=int, help='random seed') parser.add_argument("--epoch", default=300, type=int, help="number of epoch in training")
parser.add_argument('--workers', default=8, type=int, help='workers') parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
parser.add_argument("--learning_rate", default=0.1, type=float, help="learning rate in training")
parser.add_argument("--min_lr", default=0.005, type=float, help="min lr")
parser.add_argument("--weight_decay", type=float, default=2e-4, help="decay rate")
parser.add_argument("--seed", type=int, help="random seed")
parser.add_argument("--workers", default=8, type=int, help="workers")
return parser.parse_args() return parser.parse_args()
@ -46,7 +51,7 @@ def main():
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
assert torch.cuda.is_available(), "Please ensure codes are executed in cuda." assert torch.cuda.is_available(), "Please ensure codes are executed in cuda."
device = 'cuda' device = "cuda"
if args.seed is not None: if args.seed is not None:
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
@ -55,19 +60,19 @@ def main():
torch.set_printoptions(10) torch.set_printoptions(10)
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
os.environ['PYTHONHASHSEED'] = str(args.seed) os.environ["PYTHONHASHSEED"] = str(args.seed)
time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
if args.msg is None: if args.msg is None:
message = time_str message = time_str
else: else:
message = "-" + args.msg message = "-" + args.msg
args.checkpoint = 'checkpoints/' + args.model + message + '-' + str(args.seed) args.checkpoint = "checkpoints/" + args.model + message + "-" + str(args.seed)
if not os.path.isdir(args.checkpoint): if not os.path.isdir(args.checkpoint):
mkdir_p(args.checkpoint) mkdir_p(args.checkpoint)
screen_logger = logging.getLogger("Model") screen_logger = logging.getLogger("Model")
screen_logger.setLevel(logging.INFO) screen_logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt")) file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
@ -79,19 +84,19 @@ def main():
# Model # Model
printf(f"args: {args}") printf(f"args: {args}")
printf('==> Building model..') printf("==> Building model..")
net = models.__dict__[args.model]() net = models.__dict__[args.model]()
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == 'cuda': if device == "cuda":
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)
cudnn.benchmark = True cudnn.benchmark = True
best_test_acc = 0. # best test accuracy best_test_acc = 0.0 # best test accuracy
best_train_acc = 0. best_train_acc = 0.0
best_test_acc_avg = 0. best_test_acc_avg = 0.0
best_train_acc_avg = 0. best_train_acc_avg = 0.0
best_test_loss = float("inf") best_test_loss = float("inf")
best_train_loss = float("inf") best_train_loss = float("inf")
start_epoch = 0 # start from epoch 0 or last checkpoint epoch start_epoch = 0 # start from epoch 0 or last checkpoint epoch
@ -99,30 +104,49 @@ def main():
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")): if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
save_args(args) save_args(args)
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model) logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model)
logger.set_names(["Epoch-Num", 'Learning-Rate', logger.set_names(
'Train-Loss', 'Train-acc-B', 'Train-acc', [
'Valid-Loss', 'Valid-acc-B', 'Valid-acc']) "Epoch-Num",
"Learning-Rate",
"Train-Loss",
"Train-acc-B",
"Train-acc",
"Valid-Loss",
"Valid-acc-B",
"Valid-acc",
],
)
else: else:
printf(f"Resuming last checkpoint from {args.checkpoint}") printf(f"Resuming last checkpoint from {args.checkpoint}")
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth") checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint['net']) net.load_state_dict(checkpoint["net"])
start_epoch = checkpoint['epoch'] start_epoch = checkpoint["epoch"]
best_test_acc = checkpoint['best_test_acc'] best_test_acc = checkpoint["best_test_acc"]
best_train_acc = checkpoint['best_train_acc'] best_train_acc = checkpoint["best_train_acc"]
best_test_acc_avg = checkpoint['best_test_acc_avg'] best_test_acc_avg = checkpoint["best_test_acc_avg"]
best_train_acc_avg = checkpoint['best_train_acc_avg'] best_train_acc_avg = checkpoint["best_train_acc_avg"]
best_test_loss = checkpoint['best_test_loss'] best_test_loss = checkpoint["best_test_loss"]
best_train_loss = checkpoint['best_train_loss'] best_train_loss = checkpoint["best_train_loss"]
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True) logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True)
optimizer_dict = checkpoint['optimizer'] optimizer_dict = checkpoint["optimizer"]
printf('==> Preparing data..') printf("==> Preparing data..")
train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=args.workers, train_loader = DataLoader(
batch_size=args.batch_size, shuffle=True, drop_last=True) ModelNet40(partition="train", num_points=args.num_points),
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=args.workers, num_workers=args.workers,
batch_size=args.batch_size // 2, shuffle=False, drop_last=False) batch_size=args.batch_size,
shuffle=True,
drop_last=True,
)
test_loader = DataLoader(
ModelNet40(partition="test", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size // 2,
shuffle=False,
drop_last=False,
)
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
if optimizer_dict is not None: if optimizer_dict is not None:
@ -130,7 +154,7 @@ def main():
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1) scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1)
for epoch in range(start_epoch, args.epoch): for epoch in range(start_epoch, args.epoch):
printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr'])) printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"]))
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"} train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
test_out = validate(net, test_loader, criterion, device) test_out = validate(net, test_loader, criterion, device)
scheduler.step() scheduler.step()
@ -149,31 +173,46 @@ def main():
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
save_model( save_model(
net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best, net,
epoch,
path=args.checkpoint,
acc=test_out["acc"],
is_best=is_best,
best_test_acc=best_test_acc, # best test accuracy best_test_acc=best_test_acc, # best test accuracy
best_train_acc=best_train_acc, best_train_acc=best_train_acc,
best_test_acc_avg=best_test_acc_avg, best_test_acc_avg=best_test_acc_avg,
best_train_acc_avg=best_train_acc_avg, best_train_acc_avg=best_train_acc_avg,
best_test_loss=best_test_loss, best_test_loss=best_test_loss,
best_train_loss=best_train_loss, best_train_loss=best_train_loss,
optimizer=optimizer.state_dict() optimizer=optimizer.state_dict(),
)
logger.append(
[
epoch,
optimizer.param_groups[0]["lr"],
train_out["loss"],
train_out["acc_avg"],
train_out["acc"],
test_out["loss"],
test_out["acc_avg"],
test_out["acc"],
],
) )
logger.append([epoch, optimizer.param_groups[0]['lr'],
train_out["loss"], train_out["acc_avg"], train_out["acc"],
test_out["loss"], test_out["acc_avg"], test_out["acc"]])
printf( printf(
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s") f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s",
)
printf( printf(
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% " f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n") f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n",
)
logger.close() logger.close()
printf(f"++++++++" * 2 + "Final results" + "++++++++" * 2) printf("++++++++" * 2 + "Final results" + "++++++++" * 2)
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++") printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++") printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++") printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++") printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
printf(f"++++++++" * 5) printf("++++++++" * 5)
def train(net, trainloader, optimizer, criterion, device): def train(net, trainloader, optimizer, criterion, device):
@ -202,17 +241,21 @@ def train(net, trainloader, optimizer, criterion, device):
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' progress_bar(
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) batch_idx,
len(trainloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
train_true = np.concatenate(train_true) train_true = np.concatenate(train_true)
train_pred = np.concatenate(train_pred) train_pred = np.concatenate(train_pred)
return { return {
"loss": float("%.3f" % (train_loss / (batch_idx + 1))), "loss": float("%.3f" % (train_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100. * metrics.accuracy_score(train_true, train_pred))), "acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))),
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(train_true, train_pred))), "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))),
"time": time_cost "time": time_cost,
} }
@ -236,19 +279,23 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy()) test_pred.append(preds.detach().cpu().numpy())
total += label.size(0) total += label.size(0)
correct += preds.eq(label).sum().item() correct += preds.eq(label).sum().item()
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' progress_bar(
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) batch_idx,
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true) test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred) test_pred = np.concatenate(test_pred)
return { return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))), "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))), "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))), "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost "time": time_cost,
} }
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -1,3 +1 @@
from __future__ import absolute_import
from .pointmlp import pointMLP, pointMLPElite from .pointmlp import pointMLP, pointMLPElite

View file

@ -1,35 +1,32 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# from torch import einsum # from torch import einsum
# from einops import rearrange, repeat # from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils from pointnet2_ops import pointnet2_utils
def get_activation(activation): def get_activation(activation):
if activation.lower() == 'gelu': if activation.lower() == "gelu":
return nn.GELU() return nn.GELU()
elif activation.lower() == 'rrelu': elif activation.lower() == "rrelu":
return nn.RReLU(inplace=True) return nn.RReLU(inplace=True)
elif activation.lower() == 'selu': elif activation.lower() == "selu":
return nn.SELU(inplace=True) return nn.SELU(inplace=True)
elif activation.lower() == 'silu': elif activation.lower() == "silu":
return nn.SiLU(inplace=True) return nn.SiLU(inplace=True)
elif activation.lower() == 'hardswish': elif activation.lower() == "hardswish":
return nn.Hardswish(inplace=True) return nn.Hardswish(inplace=True)
elif activation.lower() == 'leakyrelu': elif activation.lower() == "leakyrelu":
return nn.LeakyReLU(inplace=True) return nn.LeakyReLU(inplace=True)
else: else:
return nn.ReLU(inplace=True) return nn.ReLU(inplace=True)
def square_distance(src, dst): def square_distance(src, dst):
""" """Calculate Euclid distance between each two points.
Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm;
src^T * dst = xn * xm + yn * ym + zn * zm
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@ -38,23 +35,23 @@ def square_distance(src, dst):
src: source points, [B, N, C] src: source points, [B, N, C]
dst: target points, [B, M, C] dst: target points, [B, M, C]
Output: Output:
dist: per-point square distance, [B, N, M] dist: per-point square distance, [B, N, M].
""" """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M) dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist return dist
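The docstring of square_distance spells out the identity ||xn - xm||^2 = sum(xn^2) + sum(xm^2) - 2 * xn . xm that the function exploits; a quick self-check with arbitrary shapes:

```python
import torch

B, N, M, C = 2, 5, 7, 3
src, dst = torch.randn(B, N, C), torch.randn(B, M, C)

# expansion used by square_distance
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)

# brute-force pairwise squared distances for comparison
ref = ((src[:, :, None, :] - dst[:, None, :, :]) ** 2).sum(-1)
assert torch.allclose(dist, ref, atol=1e-4)
```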
def index_points(points, idx): def index_points(points, idx):
""" """Input:
Input:
points: input points data, [B, N, C] points: input points data, [B, N, C]
idx: sample index data, [B, S] idx: sample index data, [B, S].
Return: Return:
new_points:, indexed points data, [B, S, C] new_points:, indexed points data, [B, S, C].
""" """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
@ -63,17 +60,15 @@ def index_points(points, idx):
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :] return points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint): def farthest_point_sample(xyz, npoint):
""" """Input:
Input:
xyz: pointcloud data, [B, N, 3] xyz: pointcloud data, [B, N, 3]
npoint: number of samples npoint: number of samples
Return: Return:
centroids: sampled pointcloud index, [B, npoint] centroids: sampled pointcloud index, [B, npoint].
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@ -91,21 +86,21 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz): def query_ball_point(radius, nsample, xyz, new_xyz):
""" """Input:
Input:
radius: local region radius radius: local region radius
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, 3] xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3] new_xyz: query points, [B, S, 3].
Return: Return:
group_idx: grouped points index, [B, S, nsample] group_idx: grouped points index, [B, S, nsample].
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
_, S, _ = new_xyz.shape _, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N group_idx[sqrdists > radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N mask = group_idx == N
@ -114,13 +109,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz): def knn_point(nsample, xyz, new_xyz):
""" """Input:
Input:
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, C] xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C] new_xyz: query points, [B, S, C].
Return: Return:
group_idx: grouped points index, [B, S, nsample] group_idx: grouped points index, [B, S, nsample].
""" """
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
@ -129,13 +124,12 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module): class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs): def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
""" """Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number :param groups: groups number
:param kneighbors: k-neighbors :param kneighbors: k-neighbors
:param kwargs: others :param kwargs: others.
""" """
super(LocalGrouper, self).__init__() super().__init__()
self.groups = groups self.groups = groups
self.kneighbors = kneighbors self.kneighbors = kneighbors
self.use_xyz = use_xyz self.use_xyz = use_xyz
@ -144,11 +138,11 @@ class LocalGrouper(nn.Module):
else: else:
self.normalize = None self.normalize = None
if self.normalize not in ["center", "anchor"]: if self.normalize not in ["center", "anchor"]:
print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].") print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None self.normalize = None
if self.normalize is not None: if self.normalize is not None:
add_channel=3 if self.use_xyz else 0 add_channel = 3 if self.use_xyz else 0
self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel])) self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel])) self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
def forward(self, xyz, points): def forward(self, xyz, points):
@ -167,29 +161,33 @@ class LocalGrouper(nn.Module):
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3] grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
grouped_points = index_points(points, idx) # [B, npoint, k, d] grouped_points = index_points(points, idx) # [B, npoint, k, d]
if self.use_xyz: if self.use_xyz:
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) # [B, npoint, k, d+3]
if self.normalize is not None: if self.normalize is not None:
if self.normalize =="center": if self.normalize == "center":
mean = torch.mean(grouped_points, dim=2, keepdim=True) mean = torch.mean(grouped_points, dim=2, keepdim=True)
if self.normalize =="anchor": if self.normalize == "anchor":
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) std = (
grouped_points = (grouped_points-mean)/(std + 1e-5) torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
grouped_points = self.affine_alpha*grouped_points + self.affine_beta .unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points - mean) / (std + 1e-5)
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1) new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
return new_xyz, new_points return new_xyz, new_points
class ConvBNReLU1D(nn.Module): class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'): def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
super(ConvBNReLU1D, self).__init__() super().__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
self.act self.act,
) )
def forward(self, x): def forward(self, x):
@ -197,30 +195,43 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module): class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'): def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
super(ConvBNReLURes1D, self).__init__() super().__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net1 = nn.Sequential( self.net1 = nn.Sequential(
nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion), nn.Conv1d(
kernel_size=kernel_size, groups=groups, bias=bias), in_channels=channel,
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)), nn.BatchNorm1d(int(channel * res_expansion)),
self.act self.act,
) )
if groups > 1: if groups > 1:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, nn.Conv1d(
kernel_size=kernel_size, groups=groups, bias=bias), in_channels=int(channel * res_expansion),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
self.act, self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
) )
else: else:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, nn.Conv1d(
kernel_size=kernel_size, bias=bias), in_channels=int(channel * res_expansion),
nn.BatchNorm1d(channel) out_channels=channel,
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
) )
def forward(self, x): def forward(self, x):
@ -228,21 +239,34 @@ class ConvBNReLURes1D(nn.Module):
class PreExtraction(nn.Module): class PreExtraction(nn.Module):
def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True, def __init__(
activation='relu', use_xyz=True): self,
""" channels,
input: [b,g,k,d]: output:[b,d,g] out_channels,
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super(PreExtraction, self).__init__() super().__init__()
in_channels = 3+2*channels if use_xyz else 2*channels in_channels = 3 + 2 * channels if use_xyz else 2 * channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion, ConvBNReLURes1D(
bias=bias, activation=activation) out_channels,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -254,22 +278,20 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size() batch_size, _, _ = x.size()
x = self.operation(x) # [b, d, k] x = self.operation(x) # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
x = x.reshape(b, n, -1).permute(0, 2, 1) return x.reshape(b, n, -1).permute(0, 2, 1)
return x
class PosExtraction(nn.Module): class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'): def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
""" """input[b,d,g]; output[b,d,g]
input[b,d,g]; output[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super(PosExtraction, self).__init__() super().__init__()
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation) ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -278,17 +300,32 @@ class PosExtraction(nn.Module):
class Model(nn.Module): class Model(nn.Module):
def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0, def __init__(
activation="relu", bias=True, use_xyz=True, normalize="center", self,
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], points=1024,
k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs): class_num=40,
super(Model, self).__init__() embed_dim=64,
groups=1,
res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="center",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[2, 2, 2, 2],
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks) self.stages = len(pre_blocks)
self.class_num = class_num self.class_num = class_num
self.points = points self.points = points
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \ assert (
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers." len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList() self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList() self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList() self.pos_blocks_list = nn.ModuleList()
@ -305,13 +342,26 @@ class Model(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
self.local_grouper_list.append(local_grouper) self.local_grouper_list.append(local_grouper)
# append pre_block_list # append pre_block_list
pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups, pre_block_module = PreExtraction(
res_expansion=res_expansion, last_channel,
bias=bias, activation=activation, use_xyz=use_xyz) out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module) self.pre_blocks_list.append(pre_block_module)
# append pos_block_list # append pos_block_list
pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups, pos_block_module = PosExtraction(
res_expansion=res_expansion, bias=bias, activation=activation) out_channel,
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module) self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel last_channel = out_channel
@ -326,7 +376,7 @@ class Model(nn.Module):
nn.BatchNorm1d(256), nn.BatchNorm1d(256),
self.act, self.act,
nn.Dropout(0.5), nn.Dropout(0.5),
nn.Linear(256, self.class_num) nn.Linear(256, self.class_num),
) )
def forward(self, x): def forward(self, x):
@ -340,29 +390,59 @@ class Model(nn.Module):
x = self.pos_blocks_list[i](x) # [b,d,g] x = self.pos_blocks_list[i](x) # [b,d,g]
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1) x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
x = self.classifier(x) return self.classifier(x)
return x
def pointMLP(num_classes=40, **kwargs) -> Model: def pointMLP(num_classes=40, **kwargs) -> Model:
return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1.0, return Model(
activation="relu", bias=False, use_xyz=False, normalize="anchor", points=1024,
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], class_num=num_classes,
k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs) embed_dim=64,
groups=1,
res_expansion=1.0,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
def pointMLPElite(num_classes=40, **kwargs) -> Model: def pointMLPElite(num_classes=40, **kwargs) -> Model:
return Model(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25, return Model(
activation="relu", bias=False, use_xyz=False, normalize="anchor", points=1024,
dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1], class_num=num_classes,
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs) embed_dim=32,
groups=1,
res_expansion=0.25,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 1],
pre_blocks=[1, 1, 2, 1],
pos_blocks=[1, 1, 2, 1],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
if __name__ == "__main__":
data = torch.rand(2, 3, 1024).cuda()
print(data.shape)
if __name__ == '__main__':
data = torch.rand(2, 3, 1024)
print("===> testing pointMLP ...") print("===> testing pointMLP ...")
model = pointMLP() model = pointMLP().cuda()
out = model(data) out = model(data)
print(out.shape) print(out.shape)
print("===> testing pointMLPElite ...")
model = pointMLPElite().cuda()
out = model(data)
print(out.shape)

View file

@@ -1,71 +1,81 @@
"""python test.py --model pointMLP --msg 20220209053148-404."""
import argparse
import datetime
import os

import models as models
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
from torch.utils.data import DataLoader
from utils import progress_bar

model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name]))


def parse_args():
    """Parameters"""
    parser = argparse.ArgumentParser("training")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        metavar="PATH",
        help="path to save checkpoint (default: checkpoint)",
    )
    parser.add_argument("--msg", type=str, help="message after checkpoint")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size in training")
    parser.add_argument("--model", default="pointMLP", help="model name [default: pointnet_cls]")
    parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
    parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
    return parser.parse_args()


def main():
    args = parse_args()
    print(f"args: {args}")
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"==> Using device: {device}")

    # if args.msg is None:
    #     message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
    # else:
    #     message = "-"+args.msg
    # args.checkpoint = 'checkpoints/' + args.model + message
    if args.checkpoint is not None:
        print(f"==> Using checkpoint: {args.checkpoint}")

    print("==> Preparing data..")
    test_loader = DataLoader(
        ModelNet40(partition="test", num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print("==> Building model..")
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)
    # checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
    checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
    # criterion = criterion.to(device)
    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net.load_state_dict(checkpoint["net"])

    test_out = validate(net, test_loader, criterion, device)
    print(f"Vanilla out: {test_out}")

@@ -91,19 +101,23 @@ def validate(net, testloader, criterion, device):
            test_pred.append(preds.detach().cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
            )

    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }


if __name__ == "__main__":
    main()
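Editor's note: the substantive change in this file is that --checkpoint now points directly at a .pth file instead of a directory from which best_checkpoint.pth used to be joined. A hedged sketch of the new loading flow; paths are illustrative.

# Sketch only; the checkpoint path is made up.
import torch

import models  # repo-local package, as used via models.__dict__[args.model]()

net = models.pointMLP()
checkpoint = torch.load("checkpoints/pointMLP-demo/best_checkpoint.pth", map_location=torch.device("cpu"))
net = torch.nn.DataParallel(net)  # the saved state dict carries "module." prefixes
net.load_state_dict(checkpoint["net"])
net.eval()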
View file
@@ -1,5 +1,4 @@
"""Useful utils."""
from .logger import *
from .misc import *
from .progress.progress.bar import Bar as Bar
View file
@@ -1,89 +1,92 @@
# A simple torch style logger
# (C) Wei YANG 2017
import matplotlib.pyplot as plt
import numpy as np

__all__ = ["Logger", "LoggerMonitor", "savefig"]


def savefig(fname, dpi=None):
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    names = logger.names if names is None else names
    numbers = logger.numbers
    for _, name in enumerate(names):
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + "(" + name + ")" for name in names]


class Logger:
    """Save training process to log file with simple plot function."""

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        if fpath is not None:
            if resume:
                self.file = open(fpath)
                name = self.file.readline()
                self.names = name.rstrip().split("\t")
                self.numbers = {}
                for _, name in enumerate(self.names):
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split("\t")
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(numbers[i])
                self.file.close()
                self.file = open(fpath, "a")
            else:
                self.file = open(fpath, "w")

    def set_names(self, names):
        if self.resume:
            pass
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.file.write("\t")
            self.numbers[name] = []
        self.file.write("\n")
        self.file.flush()

    def append(self, numbers):
        assert len(self.names) == len(numbers), "Numbers do not match names"
        for index, num in enumerate(numbers):
            self.file.write(f"{num:.6f}")
            self.file.write("\t")
            self.numbers[self.names[index]].append(num)
        self.file.write("\n")
        self.file.flush()

    def plot(self, names=None):
        names = self.names if names is None else names
        numbers = self.numbers
        for _, name in enumerate(names):
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()


class LoggerMonitor:
    """Load and visualize multiple logs."""

    def __init__(self, paths):
        """Paths is a dictionary with {name: filepath} pairs."""
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
@@ -95,10 +98,11 @@ class LoggerMonitor(object):
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        plt.grid(True)


if __name__ == "__main__":
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])
@@ -115,13 +119,13 @@ if __name__ == '__main__':
    # Example: logger monitor
    paths = {
        "resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
        "resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
        "resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
    }

    field = ["Valid Acc."]

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig("test.eps")
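Editor's note: a short usage sketch for Logger as re-exported through utils; the file name and column names are invented for illustration.

# Hedged example: write a tab-separated training log that stays plottable.
from utils import Logger  # re-exported by utils/__init__.py

logger = Logger("log.txt", title="demo")
logger.set_names(["Epoch-Num", "Train-Loss", "Valid-acc"])
for epoch in range(3):
    logger.append([epoch, 1.0 / (epoch + 1), 80.0 + epoch])  # one float per registered column
logger.close()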
View file
@@ -1,48 +1,56 @@
"""Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
"""
import errno
import os
import random
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = [
    "get_mean_and_std",
    "init_params",
    "mkdir_p",
    "AverageMeter",
    "progress_bar",
    "save_model",
    "save_args",
    "set_seed",
    "IOStream",
    "cal_loss",
]


def get_mean_and_std(dataset):
    """Compute the mean and std value of dataset."""
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print("==> Computing mean and std..")
    for inputs, _targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    """Init layer parameters."""
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode="fan_out")
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
@@ -53,8 +61,9 @@ def init_params(net):
            if m.bias:
                init.constant(m.bias, 0)


def mkdir_p(path):
    """Make dir if not exist."""
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
@@ -63,10 +72,12 @@ def mkdir_p(path):
        else:
            raise


class AverageMeter:
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.
    """

    def __init__(self):
        self.reset()
@@ -83,25 +94,26 @@ class AverageMeter(object):
        self.avg = self.sum / self.count


TOTAL_BAR_LENGTH = 65.0
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(" [")
    for _i in range(cur_len):
        sys.stdout.write("=")
    sys.stdout.write(">")
    for _i in range(rest_len):
        sys.stdout.write(".")
    sys.stdout.write("]")

    cur_time = time.time()
    step_time = cur_time - last_time
@@ -109,12 +121,12 @@ def progress_bar(current, total, msg=None):
    tot_time = cur_time - begin_time

    L = []
    L.append(" Step: %s" % format_time(step_time))
    L.append(" | Tot: %s" % format_time(tot_time))
    if msg:
        L.append(" | " + msg)

    msg = "".join(L)
    sys.stdout.write(msg)
    # for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
    #     sys.stdout.write(' ')
@@ -122,76 +134,74 @@ def progress_bar(current, total, msg=None):
    # Go back to the center of the bar.
    # for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
    #     sys.stdout.write('\b')
    sys.stdout.write(" %d/%d " % (current + 1, total))

    if current < total - 1:
        sys.stdout.write("\r")
    else:
        sys.stdout.write("\n")
    sys.stdout.flush()


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ""
    i = 1
    if days > 0:
        f += str(days) + "D"
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + "h"
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + "m"
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + "s"
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + "ms"
        i += 1
    if f == "":
        f = "0ms"
    return f


def save_model(net, epoch, path, acc, is_best, **kwargs):
    state = {
        "net": net.state_dict(),
        "epoch": epoch,
        "acc": acc,
    }
    for key, value in kwargs.items():
        state[key] = value
    filepath = os.path.join(path, "last_checkpoint.pth")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth"))


def save_args(args):
    file = open(os.path.join(args.checkpoint, "args.txt"), "w")
    for k, v in vars(args).items():
        file.write(f"{k}:\t {v}\n")
    file.close()


def set_seed(seed=None):
    if seed is None:
        return
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = "%s" % seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
@@ -200,15 +210,14 @@ def set_seed(seed=None):
    torch.backends.cudnn.deterministic = True


# create a file and write the text into it
class IOStream:
    def __init__(self, path):
        self.f = open(path, "a")

    def cprint(self, text):
        print(text)
        self.f.write(text + "\n")
        self.f.flush()

    def close(self):
@@ -216,8 +225,7 @@ class IOStream():
def cal_loss(pred, gold, smoothing=True):
    """Calculate cross entropy loss, apply label smoothing if needed."""
    gold = gold.contiguous().view(-1)

    if smoothing:
@@ -230,6 +238,6 @@ def cal_loss(pred, gold, smoothing=True):
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction="mean")

    return loss
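Editor's note: the smoothing branch of cal_loss is elided by the hunks above; the sketch below shows the usual label-smoothing formulation it corresponds to, with eps=0.2 assumed rather than read from this diff.

# Assumed reconstruction of the label-smoothing branch (eps value is a guess).
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(pred, gold, eps=0.2):
    gold = gold.contiguous().view(-1)
    n_class = pred.size(1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = F.log_softmax(pred, dim=1)
    return -(one_hot * log_prb).sum(dim=1).mean()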
View file
@@ -12,7 +12,6 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from collections import deque
from datetime import timedelta
@@ -20,13 +19,12 @@ from math import ceil
from sys import stderr
from time import time

__version__ = "1.3"


class Infinite:
    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
@@ -38,7 +36,7 @@ class Infinite(object):
            setattr(self, key, val)

    def __getitem__(self, key):
        if key.startswith("_"):
            return None
        return getattr(self, key, None)
@@ -83,8 +81,8 @@ class Infinite(object):
class Progress(Infinite):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max = kwargs.get("max", 100)

    @property
    def eta(self):
View file
@@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@@ -14,19 +12,18 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from . import Progress
from .helpers import WritelnMixin


class Bar(WritelnMixin, Progress):
    width = 32
    message = ""
    suffix = "%(index)d/%(max)d"
    bar_prefix = " |"
    bar_suffix = "| "
    empty_fill = " "
    fill = "#"
    hide_cursor = True

    def update(self):
@@ -37,52 +34,50 @@ class Bar(WritelnMixin, Progress):
        bar = self.fill * filled_length
        empty = self.empty_fill * empty_length
        suffix = self.suffix % self
        line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix])
        self.writeln(line)


class ChargingBar(Bar):
    suffix = "%(percent)d%%"
    bar_prefix = " "
    bar_suffix = " "
    empty_fill = "∙"
    fill = "█"


class FillingSquaresBar(ChargingBar):
    empty_fill = "▢"
    fill = "▣"


class FillingCirclesBar(ChargingBar):
    empty_fill = "◯"
    fill = "◉"


class IncrementalBar(Bar):
    phases = (" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█")

    def update(self):
        nphases = len(self.phases)
        filled_len = self.width * self.progress
        nfull = int(filled_len)  # Number of full chars
        phase = int((filled_len - nfull) * nphases)  # Phase of last char
        nempty = self.width - nfull  # Number of empty chars

        message = self.message % self
        bar = self.phases[-1] * nfull
        current = self.phases[phase] if phase > 0 else ""
        empty = self.empty_fill * max(0, nempty - len(current))
        suffix = self.suffix % self
        line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix])
        self.writeln(line)


class PixelBar(IncrementalBar):
    phases = ("⡀", "⡄", "⡆", "⡇", "⣇", "⣧", "⣷", "⣿")


class ShadyBar(IncrementalBar):
    phases = (" ", "░", "▒", "▓", "█")
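Editor's note: a tiny usage sketch for the vendored bars; the import path follows the nested utils/progress/progress layout re-exported in utils/__init__.py, and the loop is illustrative.

# Hedged example; only draws when stdout is a TTY.
import time

from utils.progress.progress.bar import IncrementalBar  # assumed nested path

bar = IncrementalBar("Processing", max=20, suffix="%(percent)d%%")
for _ in range(20):
    time.sleep(0.05)  # stand-in for real work
    bar.next()
bar.finish()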
View file
@@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@@ -14,13 +12,12 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from . import Infinite, Progress
from .helpers import WriteMixin


class Counter(WriteMixin, Infinite):
    message = ""
    hide_cursor = True

    def update(self):
@@ -35,7 +32,7 @@ class Countdown(WriteMixin, Progress):
class Stack(WriteMixin, Progress):
    phases = (" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█")
    hide_cursor = True

    def update(self):
@@ -45,4 +42,4 @@ class Stack(WriteMixin, Progress):
class Pie(Stack):
    phases = ("○", "◔", "◑", "◕", "●")
View file
@@ -12,78 +12,76 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

HIDE_CURSOR = "\x1b[?25l"
SHOW_CURSOR = "\x1b[?25h"


class WriteMixin:
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super().__init__(**kwargs)
        self._width = 0
        if message:
            self.message = message

        if self.file.isatty():
            if self.hide_cursor:
                print(HIDE_CURSOR, end="", file=self.file)
            print(self.message, end="", file=self.file)
            self.file.flush()

    def write(self, s):
        if self.file.isatty():
            b = "\b" * self._width
            c = s.ljust(self._width)
            print(b + c, end="", file=self.file)
            self._width = max(self._width, len(s))
            self.file.flush()

    def finish(self):
        if self.file.isatty() and self.hide_cursor:
            print(SHOW_CURSOR, end="", file=self.file)


class WritelnMixin:
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super().__init__(**kwargs)
        if message:
            self.message = message

        if self.file.isatty() and self.hide_cursor:
            print(HIDE_CURSOR, end="", file=self.file)

    def clearln(self):
        if self.file.isatty():
            print("\r\x1b[K", end="", file=self.file)

    def writeln(self, line):
        if self.file.isatty():
            self.clearln()
            print(line, end="", file=self.file)
            self.file.flush()

    def finish(self):
        if self.file.isatty():
            print(file=self.file)
            if self.hide_cursor:
                print(SHOW_CURSOR, end="", file=self.file)


from signal import SIGINT, signal
from sys import exit


class SigIntMixin:
    """Registers a signal handler that calls finish on SIGINT."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        signal(SIGINT, self._sigint_handler)

    def _sigint_handler(self, signum, frame):
View file
@@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@@ -14,14 +12,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from . import Infinite
from .helpers import WriteMixin


class Spinner(WriteMixin, Infinite):
    message = ""
    phases = ("-", "\\", "|", "/")
    hide_cursor = True

    def update(self):
@@ -30,15 +27,16 @@ class Spinner(WriteMixin, Infinite):
class PieSpinner(Spinner):
    phases = ["◷", "◶", "◵", "◴"]


class MoonSpinner(Spinner):
    phases = ["◑", "◒", "◐", "◓"]


class LineSpinner(Spinner):
    phases = ["⎺", "⎻", "⎼", "⎽", "⎼", "⎻"]


class PixelSpinner(Spinner):
    phases = ["⣾", "⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽"]
View file
@@ -1,29 +1,27 @@
#!/usr/bin/env python

import progress
from setuptools import setup

setup(
    name="progress",
    version=progress.__version__,
    description="Easy to use progress bars",
    long_description=open("README.rst").read(),
    author="Giorgos Verigakis",
    author_email="verigak@gmail.com",
    url="http://github.com/verigak/progress/",
    license="ISC",
    packages=["progress"],
    classifiers=[
        "Environment :: Console",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: ISC License (ISCL)",
        "Programming Language :: Python :: 2.6",
        "Programming Language :: Python :: 2.7",
        "Programming Language :: Python :: 3.3",
        "Programming Language :: Python :: 3.4",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
    ],
)
View file
@@ -1,16 +1,12 @@
#!/usr/bin/env python

import random
import time

from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar
from progress.counter import Countdown, Counter, Pie, Stack
from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner


def sleep():
@@ -20,29 +16,29 @@ def sleep():
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
    suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]"
    bar = bar_cls(bar_cls.__name__, suffix=suffix)
    for _i in bar.iter(range(200)):
        sleep()

for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
    suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]"
    bar = bar_cls(bar_cls.__name__, suffix=suffix)
    for _i in bar.iter(range(200)):
        sleep()

for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
    for _i in spin(spin.__name__ + " ").iter(range(100)):
        sleep()
    print()

for singleton in (Counter, Countdown, Stack, Pie):
    for _i in singleton(singleton.__name__ + " ").iter(range(100)):
        sleep()
    print()

bar = IncrementalBar("Random", suffix="%(index)d")
for _i in range(100):
    bar.goto(random.randint(0, 100))
    sleep()
bar.finish()
View file
@@ -1,47 +1,52 @@
import argparse
import datetime
import os

import models as models
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
from torch.utils.data import DataLoader
from utils import IOStream, progress_bar

model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name]))


def parse_args():
    """Parameters"""
    parser = argparse.ArgumentParser("training")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        metavar="PATH",
        help="path to save checkpoint (default: checkpoint)",
    )
    parser.add_argument("--msg", type=str, help="message after checkpoint")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
    parser.add_argument("--model", default="model31A", help="model name [default: pointnet_cls]")
    parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
    parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
    parser.add_argument("--seed", type=int, help="random seed (default: 1)")
    # Voting evaluation, referring: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
    parser.add_argument("--NUM_PEPEAT", type=int, default=300)
    parser.add_argument("--NUM_VOTE", type=int, default=10)
    parser.add_argument("--validate", action="store_true", help="Validate the original testing result.")
    return parser.parse_args()


class PointcloudScale:  # input random scaling
    def __init__(self, scale_low=2.0 / 3.0, scale_high=3.0 / 2.0):
        self.scale_low = scale_low
        self.scale_high = scale_high
@@ -68,45 +73,52 @@ def main():
    torch.set_printoptions(10)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"==> Using device: {device}")
    if args.msg is None:
        message = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
    else:
        message = "-" + args.msg
    args.checkpoint = "checkpoints/" + args.model + message

    print("==> Preparing data..")
    test_loader = DataLoader(
        ModelNet40(partition="test", num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size // 2,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print("==> Building model..")
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)
    checkpoint_path = os.path.join(args.checkpoint, "best_checkpoint.pth")
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # criterion = criterion.to(device)
    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net.load_state_dict(checkpoint["net"])

    if args.validate:
        test_out = validate(net, test_loader, criterion, device)
        print(f"Vanilla out: {test_out}")
        print(
            "Note 1: Please also load the random seed parameter (if forgot, see out.txt).\n"
            "Note 2: This result may vary little on different GPUs (and number of GPUs), we tested 2080Ti, P100, and V100.\n"
            "[note : Original result is achieved with V100 GPUs.]\n\n\n",
        )
        # Interestingly, we get original best_test_acc on 4 V100 gpus, but this model is trained on one V100 gpu.
        # On different GPUs, and different number of GPUs, both OA and mean_acc vary a little.
        # Also, the batch size also affect the testing results, could not understand.

    print("===> start voting evaluation...")
    voting(net, test_loader, device, args)

@@ -130,23 +142,28 @@ def validate(net, testloader, criterion, device):
            test_pred.append(preds.detach().cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
            )

    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }


def voting(net, testloader, device, args):
    name = (
        "/evaluate_voting" + str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) + "seed_" + str(args.seed) + ".log"
    )
    io = IOStream(args.checkpoint + name)
    io.cprint(str(args))
@@ -161,7 +178,7 @@ def voting(net, testloader, device, args):
        test_true = []
        test_pred = []

        for _batch_idx, (data, label) in enumerate(testloader):
            data, label = data.to(device), label.to(device).squeeze()
            pred = 0
            for v in range(args.NUM_VOTE):
@@ -178,19 +195,24 @@ def voting(net, testloader, device, args):
            test_pred.append(pred_choice.detach().cpu().numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        test_acc = 100.0 * metrics.accuracy_score(test_true, test_pred)
        test_mean_acc = 100.0 * metrics.balanced_accuracy_score(test_true, test_pred)
        if test_acc > best_acc:
            best_acc = test_acc
        if test_mean_acc > best_mean_acc:
            best_mean_acc = test_mean_acc
        outstr = "Voting %d, test acc: %.3f, test mean acc: %.3f, [current best(all_acc: %.3f mean_acc: %.3f)]" % (
            i,
            test_acc,
            test_mean_acc,
            best_acc,
            best_mean_acc,
        )
        io.cprint(outstr)

    final_outstr = "Final voting test acc: %.6f," % (best_acc * 100)
    io.cprint(final_outstr)


if __name__ == "__main__":
    main()
View file
@@ -1,10 +1,7 @@
"""ScanObjectNN download: http://103.24.77.34/scanobjectnn/h5_files.zip."""
import os

import h5py
import numpy as np
from torch.utils.data import Dataset
@@ -14,18 +11,18 @@ os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def download():
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, "data")
    if not os.path.exists(DATA_DIR):
        os.mkdir(DATA_DIR)
    if not os.path.exists(os.path.join(DATA_DIR, "h5_files")):
        # note that this link only contains the hardest perturbed variant (PB_T50_RS).
        # for full versions, consider the following link.
        www = "https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip"
        # www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
        zipfile = os.path.basename(www)
        os.system(f"wget {www} --no-check-certificate; unzip {zipfile}")
        os.system(f"mv {zipfile[:-4]} {DATA_DIR}")
        os.system("rm %s" % (zipfile))


def load_scanobjectnn_data(partition):
@@ -34,10 +31,10 @@ def load_scanobjectnn_data(partition):
    all_data = []
    all_label = []

    h5_name = BASE_DIR + "/data/h5_files/main_split/" + partition + "_objectdataset_augmentedrot_scale75.h5"
    f = h5py.File(h5_name, mode="r")
    data = f["data"][:].astype("float32")
    label = f["label"][:].astype("int64")
    f.close()
    all_data.append(data)
    all_label.append(label)
@@ -47,23 +44,22 @@ def load_scanobjectnn_data(partition):
def translate_pointcloud(pointcloud):
    xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
    xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
    return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")


class ScanObjectNN(Dataset):
    def __init__(self, num_points, partition="training"):
        self.data, self.label = load_scanobjectnn_data(partition)
        self.num_points = num_points
        self.partition = partition

    def __getitem__(self, item):
        pointcloud = self.data[item][: self.num_points]
        label = self.label[item]
        if self.partition == "training":
            pointcloud = translate_pointcloud(pointcloud)
            np.random.shuffle(pointcloud)
        return pointcloud, label
@@ -72,9 +68,9 @@ class ScanObjectNN(Dataset):
        return self.data.shape[0]


if __name__ == "__main__":
    train = ScanObjectNN(1024)
    test = ScanObjectNN(1024, "test")
    for data, label in train:
        print(data.shape)
        print(label)
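Editor's note: a hedged DataLoader wrapper around the dataset above; batch size and worker count are arbitrary, and the h5 files are assumed to be downloaded already.

from torch.utils.data import DataLoader

from ScanObjectNN import ScanObjectNN  # module defined above

train_loader = DataLoader(
    ScanObjectNN(partition="training", num_points=1024),
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)
points, labels = next(iter(train_loader))
print(points.shape, labels.shape)  # roughly (32, 1024, 3) and (32,)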
View file
@ -1,45 +1,50 @@
""" """for training with resume functions.
for training with resume functions.
Usage: Usage:
python main.py --model PointNet --msg demo python main.py --model PointNet --msg demo
or or
CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out & CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &.
""" """
import argparse import argparse
import os
import logging
import datetime import datetime
import logging
import os
import models as models
import numpy as np
import sklearn.metrics as metrics
import torch import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim import torch.optim
import torch.utils.data import torch.utils.data
import torch.utils.data.distributed import torch.utils.data.distributed
from torch.utils.data import DataLoader
import models as models
from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
from ScanObjectNN import ScanObjectNN from ScanObjectNN import ScanObjectNN
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import sklearn.metrics as metrics from torch.utils.data import DataLoader
import numpy as np from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model
def parse_args(): def parse_args():
"""Parameters""" """Parameters"""
parser = argparse.ArgumentParser('training') parser = argparse.ArgumentParser("training")
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH', parser.add_argument(
help='path to save checkpoint (default: checkpoint)') "-c",
parser.add_argument('--msg', type=str, help='message after checkpoint') "--checkpoint",
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training') type=str,
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]') metavar="PATH",
parser.add_argument('--num_classes', default=15, type=int, help='default value for classes of ScanObjectNN') help="path to save checkpoint (default: checkpoint)",
parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training') )
parser.add_argument('--num_points', type=int, default=1024, help='Point Number') parser.add_argument("--msg", type=str, help="message after checkpoint")
parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate in training') parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
parser.add_argument('--weight_decay', type=float, default=1e-4, help='decay rate') parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]")
parser.add_argument('--smoothing', action='store_true', default=False, help='loss smoothing') parser.add_argument("--num_classes", default=15, type=int, help="default value for classes of ScanObjectNN")
parser.add_argument('--seed', type=int, help='random seed') parser.add_argument("--epoch", default=200, type=int, help="number of epoch in training")
parser.add_argument('--workers', default=4, type=int, help='workers') parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate in training")
parser.add_argument("--weight_decay", type=float, default=1e-4, help="decay rate")
parser.add_argument("--smoothing", action="store_true", default=False, help="loss smoothing")
parser.add_argument("--seed", type=int, help="random seed")
parser.add_argument("--workers", default=4, type=int, help="workers")
return parser.parse_args() return parser.parse_args()
@ -49,23 +54,23 @@ def main():
if args.seed is not None: if args.seed is not None:
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
device = 'cuda' device = "cuda"
if args.seed is not None: if args.seed is not None:
torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed(args.seed)
else: else:
device = 'cpu' device = "cpu"
time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
if args.msg is None: if args.msg is None:
message = time_str message = time_str
else: else:
message = "-" + args.msg message = "-" + args.msg
args.checkpoint = 'checkpoints/' + args.model + message args.checkpoint = "checkpoints/" + args.model + message
if not os.path.isdir(args.checkpoint): if not os.path.isdir(args.checkpoint):
mkdir_p(args.checkpoint) mkdir_p(args.checkpoint)
screen_logger = logging.getLogger("Model") screen_logger = logging.getLogger("Model")
screen_logger.setLevel(logging.INFO) screen_logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt")) file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
@ -77,19 +82,19 @@ def main():
# Model
printf(f"args: {args}")
printf("==> Building model..")
net = models.__dict__[args.model](num_classes=args.num_classes)
criterion = cal_loss
net = net.to(device)
# criterion = criterion.to(device)
if device == "cuda":
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
best_test_acc = 0.0  # best test accuracy
best_train_acc = 0.0
best_test_acc_avg = 0.0
best_train_acc_avg = 0.0
best_test_loss = float("inf")
best_train_loss = float("inf")
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
@ -97,30 +102,49 @@ def main():
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
save_args(args)
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model)
logger.set_names(
[
"Epoch-Num",
"Learning-Rate",
"Train-Loss",
"Train-acc-B",
"Train-acc",
"Valid-Loss",
"Valid-acc-B",
"Valid-acc",
],
)
else:
printf(f"Resuming last checkpoint from {args.checkpoint}")
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint["net"])
start_epoch = checkpoint["epoch"]
best_test_acc = checkpoint["best_test_acc"]
best_train_acc = checkpoint["best_train_acc"]
best_test_acc_avg = checkpoint["best_test_acc_avg"]
best_train_acc_avg = checkpoint["best_train_acc_avg"]
best_test_loss = checkpoint["best_test_loss"]
best_train_loss = checkpoint["best_train_loss"]
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True)
optimizer_dict = checkpoint["optimizer"]
printf("==> Preparing data..")
train_loader = DataLoader(
ScanObjectNN(partition="training", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
)
test_loader = DataLoader(
ScanObjectNN(partition="test", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
)
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
if optimizer_dict is not None:
@ -128,7 +152,7 @@ def main():
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1)
for epoch in range(start_epoch, args.epoch):
printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"]))
train_out = train(net, train_loader, optimizer, criterion, device)  # {"loss", "acc", "acc_avg", "time"}
test_out = validate(net, test_loader, criterion, device)
scheduler.step()
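For intuition, the CosineAnnealingLR configured above decays the learning rate from its initial value toward eta_min over args.epoch epochs. A small illustrative sketch, not part of this changeset; the values come from the argparse defaults and the closed-form cosine schedule is assumed from the PyTorch documentation:

import math
lr, epochs = 0.01, 200            # argparse defaults above
eta_min = lr / 100                # as passed to CosineAnnealingLR
for t in (0, 50, 100, 150, 199):
    eta_t = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * t / epochs)) / 2
    print(t, round(eta_t, 5))     # 0.01 at epoch 0, decaying toward ~0.0001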
@ -147,31 +171,46 @@ def main():
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
save_model(
net,
epoch,
path=args.checkpoint,
acc=test_out["acc"],
is_best=is_best,
best_test_acc=best_test_acc,  # best test accuracy
best_train_acc=best_train_acc,
best_test_acc_avg=best_test_acc_avg,
best_train_acc_avg=best_train_acc_avg,
best_test_loss=best_test_loss,
best_train_loss=best_train_loss,
optimizer=optimizer.state_dict(),
)
logger.append(
[
epoch,
optimizer.param_groups[0]["lr"],
train_out["loss"],
train_out["acc_avg"],
train_out["acc"],
test_out["loss"],
test_out["acc_avg"],
test_out["acc"],
],
)
printf(
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s",
)
printf(
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n",
)
logger.close()
printf("++++++++" * 2 + "Final results" + "++++++++" * 2)
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
printf("++++++++" * 5)
def train(net, trainloader, optimizer, criterion, device):
@ -199,17 +238,21 @@ def train(net, trainloader, optimizer, criterion, device):
total += label.size(0)
correct += preds.eq(label).sum().item()
progress_bar(
batch_idx,
len(trainloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
train_true = np.concatenate(train_true)
train_pred = np.concatenate(train_pred)
return {
"loss": float("%.3f" % (train_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))),
"time": time_cost,
}
@ -233,19 +276,23 @@ def validate(net, testloader, criterion, device):
test_pred.append(preds.detach().cpu().numpy())
total += label.size(0)
correct += preds.eq(label).sum().item()
progress_bar(
batch_idx,
len(testloader),
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
)
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred)
return {
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
"time": time_cost,
}
if __name__ == "__main__":
main()
View file
@ -1,3 +1 @@
from .pointmlp import pointMLP, pointMLPElite
View file
@ -1,35 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch import einsum
# from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils
def get_activation(activation):
if activation.lower() == "gelu":
return nn.GELU()
elif activation.lower() == "rrelu":
return nn.RReLU(inplace=True)
elif activation.lower() == "selu":
return nn.SELU(inplace=True)
elif activation.lower() == "silu":
return nn.SiLU(inplace=True)
elif activation.lower() == "hardswish":
return nn.Hardswish(inplace=True)
elif activation.lower() == "leakyrelu":
return nn.LeakyReLU(inplace=True)
else:
return nn.ReLU(inplace=True)
def square_distance(src, dst):
"""Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@ -38,23 +35,23 @@ def square_distance(src, dst):
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M].
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
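As a quick sanity check of the expansion in the docstring (an editor illustration, not part of this changeset), square_distance can be compared against torch.cdist:

src = torch.rand(2, 5, 3)                  # [B, N, C]
dst = torch.rand(2, 7, 3)                  # [B, M, C]
d1 = square_distance(src, dst)             # -2*src.dst + ||src||^2 + ||dst||^2
d2 = torch.cdist(src, dst) ** 2            # reference squared Euclidean distance
assert torch.allclose(d1, d2, atol=1e-5)   # both are [B, N, M]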
def index_points(points, idx):
"""Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S].
Return:
new_points:, indexed points data, [B, S, C].
"""
device = points.device
B = points.shape[0]
@ -63,17 +60,15 @@ def index_points(points, idx):
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
return points[batch_indices, idx, :]
def farthest_point_sample(xyz, npoint):
"""Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint].
"""
device = xyz.device
B, N, C = xyz.shape
@ -91,21 +86,21 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz):
"""Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3].
Return:
group_idx: grouped points index, [B, S, nsample].
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
@ -114,13 +109,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz):
"""Input:
nsample: max sample number in local region
xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C].
Return:
group_idx: grouped points index, [B, S, nsample].
"""
sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
@ -129,13 +124,12 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number
:param kneighbors: k-neighbors
:param kwargs: others.
"""
super().__init__()
self.groups = groups
self.kneighbors = kneighbors
self.use_xyz = use_xyz
@ -144,11 +138,11 @@ class LocalGrouper(nn.Module):
else:
self.normalize = None
if self.normalize not in ["center", "anchor"]:
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None
if self.normalize is not None:
add_channel = 3 if self.use_xyz else 0
self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
def forward(self, xyz, points):
@ -167,29 +161,33 @@ class LocalGrouper(nn.Module):
grouped_xyz = index_points(xyz, idx)  # [B, npoint, k, 3]
grouped_points = index_points(points, idx)  # [B, npoint, k, d]
if self.use_xyz:
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)  # [B, npoint, k, d+3]
if self.normalize is not None:
if self.normalize == "center":
mean = torch.mean(grouped_points, dim=2, keepdim=True)
if self.normalize == "anchor":
mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
std = (
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
.unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points - mean) / (std + 1e-5)
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
return new_xyz, new_points
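The normalization above is a per-group standardization followed by a learned scale and shift. A minimal numeric sketch of the "center" mode, added for illustration only (the shapes below are assumptions, not values from the diff):

B, G, K, D = 2, 4, 8, 16                  # batch, groups, neighbors, feature dim
grouped = torch.rand(B, G, K, D)          # grouped neighbor features
mean = grouped.mean(dim=2, keepdim=True)  # "center": mean over the k neighbors
std = torch.std((grouped - mean).reshape(B, -1), dim=-1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
alpha, beta = torch.ones(1, 1, 1, D), torch.zeros(1, 1, 1, D)  # stand-ins for affine_alpha / affine_beta
out = alpha * (grouped - mean) / (std + 1e-5) + beta           # same shape as grouped: [2, 4, 8, 16]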
class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
super().__init__()
self.act = get_activation(activation)
self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels),
self.act,
)
def forward(self, x):
@ -197,30 +195,43 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
super().__init__()
self.act = get_activation(activation)
self.net1 = nn.Sequential(
nn.Conv1d(
in_channels=channel,
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)),
self.act,
)
if groups > 1:
self.net2 = nn.Sequential(
nn.Conv1d(
in_channels=int(channel * res_expansion),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel),
self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel),
)
else:
self.net2 = nn.Sequential(
nn.Conv1d(
in_channels=int(channel * res_expansion),
out_channels=channel,
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
)
def forward(self, x):
@ -228,21 +239,34 @@ class ConvBNReLURes1D(nn.Module):
class PreExtraction(nn.Module):
def __init__(
self,
channels,
out_channels,
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels:
:param blocks:
"""
super().__init__()
in_channels = 3 + 2 * channels if use_xyz else 2 * channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = []
for _ in range(blocks):
operation.append(
ConvBNReLURes1D(
out_channels,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
)
self.operation = nn.Sequential(*operation)
@ -254,22 +278,20 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size()
x = self.operation(x)  # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
return x.reshape(b, n, -1).permute(0, 2, 1)
class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
"""input[b,d,g]; output[b,d,g]
:param channels:
:param blocks:
"""
super().__init__()
operation = []
for _ in range(blocks):
operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
)
self.operation = nn.Sequential(*operation)
@ -278,17 +300,32 @@ class PosExtraction(nn.Module):
class Model(nn.Module):
def __init__(
self,
points=1024,
class_num=40,
embed_dim=64,
groups=1,
res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="center",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[2, 2, 2, 2],
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks)
self.class_num = class_num
self.points = points
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
assert (
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList()
@ -305,13 +342,26 @@ class Model(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize)  # [b,g,k,d]
self.local_grouper_list.append(local_grouper)
# append pre_block_list
pre_block_module = PreExtraction(
last_channel,
out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module)
# append pos_block_list
pos_block_module = PosExtraction(
out_channel,
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel
@ -326,7 +376,7 @@ class Model(nn.Module):
nn.BatchNorm1d(256),
self.act,
nn.Dropout(0.5),
nn.Linear(256, self.class_num),
)
def forward(self, x):
@ -340,29 +390,52 @@ class Model(nn.Module):
x = self.pos_blocks_list[i](x)  # [b,d,g]
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
return self.classifier(x)
def pointMLP(num_classes=40, **kwargs) -> Model:
return Model(
points=1024,
class_num=num_classes,
embed_dim=64,
groups=1,
res_expansion=1.0,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
def pointMLPElite(num_classes=40, **kwargs) -> Model:
return Model(
points=1024,
class_num=num_classes,
embed_dim=32,
groups=1,
res_expansion=0.25,
activation="relu",
bias=False,
use_xyz=False,
normalize="anchor",
dim_expansion=[2, 2, 2, 1],
pre_blocks=[1, 1, 2, 1],
pos_blocks=[1, 1, 2, 1],
k_neighbors=[24, 24, 24, 24],
reducers=[2, 2, 2, 2],
**kwargs,
)
if __name__ == "__main__":
data = torch.rand(2, 3, 1024)
print("===> testing pointMLP ...")
model = pointMLP()
out = model(data)
print(out.shape)
View file
@ -1,5 +1,4 @@
"""Useful utils """Useful utils."""
"""
from .misc import *
from .logger import * from .logger import *
from .misc import *
from .progress.progress.bar import Bar as Bar from .progress.progress.bar import Bar as Bar
View file
@ -1,89 +1,92 @@
# A simple torch style logger
# (C) Wei YANG 2017
import matplotlib.pyplot as plt
import numpy as np
__all__ = ["Logger", "LoggerMonitor", "savefig"]
def savefig(fname, dpi=None):
dpi = 150 if dpi is None else dpi
plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
names = logger.names if names is None else names
numbers = logger.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
return [logger.title + "(" + name + ")" for name in names]
class Logger:
"""Save training process to log file with simple plot function."""
def __init__(self, fpath, title=None, resume=False):
self.file = None
self.resume = resume
self.title = "" if title is None else title
if fpath is not None:
if resume:
self.file = open(fpath)
name = self.file.readline()
self.names = name.rstrip().split("\t")
self.numbers = {}
for _, name in enumerate(self.names):
self.numbers[name] = []
for numbers in self.file:
numbers = numbers.rstrip().split("\t")
for i in range(0, len(numbers)):
self.numbers[self.names[i]].append(numbers[i])
self.file.close()
self.file = open(fpath, "a")
else:
self.file = open(fpath, "w")
def set_names(self, names):
if self.resume:
pass
# initialize numbers as empty list
self.numbers = {}
self.names = names
for _, name in enumerate(self.names):
self.file.write(name)
self.file.write("\t")
self.numbers[name] = []
self.file.write("\n")
self.file.flush()
def append(self, numbers):
assert len(self.names) == len(numbers), "Numbers do not match names"
for index, num in enumerate(numbers):
self.file.write(f"{num:.6f}")
self.file.write("\t")
self.numbers[self.names[index]].append(num)
self.file.write("\n")
self.file.flush()
def plot(self, names=None):
names = self.names if names is None else names
numbers = self.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
plt.legend([self.title + "(" + name + ")" for name in names])
plt.grid(True)
def close(self):
if self.file is not None:
self.file.close()
class LoggerMonitor:
"""Load and visualize multiple logs."""
def __init__(self, paths):
"""Paths is a dictionary with {name: filepath} pairs."""
self.loggers = []
for title, path in paths.items():
logger = Logger(path, title=title, resume=True)
@ -95,10 +98,11 @@ class LoggerMonitor(object):
legend_text = []
for logger in self.loggers:
legend_text += plot_overlap(logger, names)
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
plt.grid(True)
if __name__ == "__main__":
# # Example
# logger = Logger('test.txt')
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
@ -115,13 +119,13 @@ if __name__ == '__main__':
# Example: logger monitor
paths = {
"resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
"resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
"resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
}
field = ["Valid Acc."]
monitor = LoggerMonitor(paths)
monitor.plot(names=field)
savefig("test.eps")
View file
@ -1,48 +1,56 @@
"""Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- msr_init: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress.
"""
import errno
import os
import random
import shutil
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
__all__ = [
"get_mean_and_std",
"init_params",
"mkdir_p",
"AverageMeter",
"progress_bar",
"save_model",
"save_args",
"set_seed",
"IOStream",
"cal_loss",
]
def get_mean_and_std(dataset):
"""Compute the mean and std value of dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
mean = torch.zeros(3)
std = torch.zeros(3)
print("==> Computing mean and std..")
for inputs, _targets in dataloader:
for i in range(3):
mean[i] += inputs[:, i, :, :].mean()
std[i] += inputs[:, i, :, :].std()
mean.div_(len(dataset))
std.div_(len(dataset))
return mean, std
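A hypothetical usage sketch, added for illustration only; torchvision and the CIFAR-10 dataset are assumptions, not part of this changeset:

import torchvision
import torchvision.transforms as transforms
dataset = torchvision.datasets.CIFAR10(".", train=True, download=True, transform=transforms.ToTensor())
mean, std = get_mean_and_std(dataset)   # two tensors of shape [3], one entry per channel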
def init_params(net):
"""Init layer parameters."""
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode="fan_out")
if m.bias:
init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
@ -53,8 +61,9 @@ def init_params(net):
if m.bias:
init.constant(m.bias, 0)
def mkdir_p(path):
"""Make dir if not exist."""
try:
os.makedirs(path)
except OSError as exc:  # Python >2.5
@ -63,10 +72,12 @@ def mkdir_p(path):
else:
raise
class AverageMeter:
"""Computes and stores the average and current value
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.
"""
def __init__(self):
self.reset()
@ -83,25 +94,26 @@ class AverageMeter(object):
self.avg = self.sum / self.count
TOTAL_BAR_LENGTH = 65.0
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time()  # Reset for new bar.
cur_len = int(TOTAL_BAR_LENGTH * current / total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(" [")
for _i in range(cur_len):
sys.stdout.write("=")
sys.stdout.write(">")
for _i in range(rest_len):
sys.stdout.write(".")
sys.stdout.write("]")
cur_time = time.time()
step_time = cur_time - last_time
@ -109,12 +121,12 @@ def progress_bar(current, total, msg=None):
tot_time = cur_time - begin_time
L = []
L.append(" Step: %s" % format_time(step_time))
L.append(" | Tot: %s" % format_time(tot_time))
if msg:
L.append(" | " + msg)
msg = "".join(L)
sys.stdout.write(msg)
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
#     sys.stdout.write(' ')
@ -122,76 +134,74 @@ def progress_bar(current, total, msg=None):
# Go back to the center of the bar.
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
#     sys.stdout.write('\b')
sys.stdout.write(" %d/%d " % (current + 1, total))
if current < total - 1:
sys.stdout.write("\r")
else:
sys.stdout.write("\n")
sys.stdout.flush()
def format_time(seconds):
days = int(seconds / 3600 / 24)
seconds = seconds - days * 3600 * 24
hours = int(seconds / 3600)
seconds = seconds - hours * 3600
minutes = int(seconds / 60)
seconds = seconds - minutes * 60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds * 1000)
f = ""
i = 1
if days > 0:
f += str(days) + "D"
i += 1
if hours > 0 and i <= 2:
f += str(hours) + "h"
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + "m"
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + "s"
i += 1
if millis > 0 and i <= 2:
f += str(millis) + "ms"
i += 1
if f == "":
f = "0ms"
return f
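For reference, format_time keeps at most the two largest non-zero units; a few illustrative calls (editor example, not from this changeset):

print(format_time(3725.004))  # "1h2m"  (seconds and millis dropped once two fields are set)
print(format_time(0.5))       # "500ms"
print(format_time(0))         # "0ms"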
def save_model(net, epoch, path, acc, is_best, **kwargs):
state = {
"net": net.state_dict(),
"epoch": epoch,
"acc": acc,
}
for key, value in kwargs.items():
state[key] = value
filepath = os.path.join(path, "last_checkpoint.pth")
torch.save(state, filepath)
if is_best:
shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth"))
def save_args(args):
file = open(os.path.join(args.checkpoint, "args.txt"), "w")
for k, v in vars(args).items():
file.write(f"{k}:\t {v}\n")
file.close()
def set_seed(seed=None):
if seed is None:
return
random.seed(seed)
os.environ["PYTHONHASHSEED"] = "%s" % seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@ -200,15 +210,14 @@ def set_seed(seed=None):
torch.backends.cudnn.deterministic = True
# create a file and write the text into it
class IOStream:
def __init__(self, path):
self.f = open(path, "a")
def cprint(self, text):
print(text)
self.f.write(text + "\n")
self.f.flush()
def close(self):
@ -216,8 +225,7 @@ class IOStream():
def cal_loss(pred, gold, smoothing=True):
"""Calculate cross entropy loss, apply label smoothing if needed."""
gold = gold.contiguous().view(-1)
if smoothing:
@ -230,6 +238,6 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean()
else:
loss = F.cross_entropy(pred, gold, reduction="mean")
return loss
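The smoothing branch elided by the hunk above builds a softened one-hot target before taking the log-softmax. A hedged sketch for illustration: the smoothing factor 0.2 and the scatter-based construction are assumptions, only the final loss line is visible in this diff:

eps, n_class = 0.2, 15                                    # assumed smoothing factor and class count
pred = torch.randn(4, n_class)                            # raw logits
gold = torch.randint(0, n_class, (4,))
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()             # matches the visible return path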
View file
@ -12,7 +12,6 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from collections import deque
from datetime import timedelta
@ -20,13 +19,12 @@ from math import ceil
from sys import stderr
from time import time
__version__ = "1.3"
class Infinite:
file = stderr
sma_window = 10  # Simple Moving Average window
def __init__(self, *args, **kwargs):
self.index = 0
@ -38,7 +36,7 @@ class Infinite(object):
setattr(self, key, val)
def __getitem__(self, key):
if key.startswith("_"):
return None
return getattr(self, key, None)
@ -83,8 +81,8 @@ class Infinite(object):
class Progress(Infinite):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max = kwargs.get("max", 100)
@property
def eta(self):
View file
@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@ -14,19 +12,18 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from . import Progress
from .helpers import WritelnMixin
class Bar(WritelnMixin, Progress):
width = 32
message = ""
suffix = "%(index)d/%(max)d"
bar_prefix = " |"
bar_suffix = "| "
empty_fill = " "
fill = "#"
hide_cursor = True
def update(self):
@ -37,52 +34,50 @@ class Bar(WritelnMixin, Progress):
bar = self.fill * filled_length
empty = self.empty_fill * empty_length
suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix])
self.writeln(line)
class ChargingBar(Bar):
suffix = "%(percent)d%%"
bar_prefix = " "
bar_suffix = " "
empty_fill = "∙"
fill = "█"
class FillingSquaresBar(ChargingBar):
empty_fill = "▢"
fill = "▣"
class FillingCirclesBar(ChargingBar):
empty_fill = "◯"
fill = "◉"
class IncrementalBar(Bar):
phases = (" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█")
def update(self):
nphases = len(self.phases)
filled_len = self.width * self.progress
nfull = int(filled_len)  # Number of full chars
phase = int((filled_len - nfull) * nphases)  # Phase of last char
nempty = self.width - nfull  # Number of empty chars
message = self.message % self
bar = self.phases[-1] * nfull
current = self.phases[phase] if phase > 0 else ""
empty = self.empty_fill * max(0, nempty - len(current))
suffix = self.suffix % self
line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix])
self.writeln(line)
class PixelBar(IncrementalBar):
phases = ("⡀", "⡄", "⡆", "⡇", "⣇", "⣧", "⣷", "⣿")
class ShadyBar(IncrementalBar):
phases = (" ", "░", "▒", "▓", "█")
View file
@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@ -14,13 +12,12 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from . import Infinite, Progress
from .helpers import WriteMixin
class Counter(WriteMixin, Infinite):
message = ""
hide_cursor = True
def update(self):
@ -35,7 +32,7 @@ class Countdown(WriteMixin, Progress):
class Stack(WriteMixin, Progress):
phases = (" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█")
hide_cursor = True
def update(self):
@ -45,4 +42,4 @@ class Stack(WriteMixin, Progress):
class Pie(Stack):
phases = ("○", "◔", "◑", "◕", "●")
View file
@ -12,78 +12,76 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
HIDE_CURSOR = "\x1b[?25l"
SHOW_CURSOR = "\x1b[?25h"
class WriteMixin:
hide_cursor = False
def __init__(self, message=None, **kwargs):
super().__init__(**kwargs)
self._width = 0
if message:
self.message = message
if self.file.isatty():
if self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file)
print(self.message, end="", file=self.file)
self.file.flush()
def write(self, s):
if self.file.isatty():
b = "\b" * self._width
c = s.ljust(self._width)
print(b + c, end="", file=self.file)
self._width = max(self._width, len(s))
self.file.flush()
def finish(self):
if self.file.isatty() and self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file)
class WritelnMixin:
hide_cursor = False
def __init__(self, message=None, **kwargs):
super().__init__(**kwargs)
if message:
self.message = message
if self.file.isatty() and self.hide_cursor:
print(HIDE_CURSOR, end="", file=self.file)
def clearln(self):
if self.file.isatty():
print("\r\x1b[K", end="", file=self.file)
def writeln(self, line):
if self.file.isatty():
self.clearln()
print(line, end="", file=self.file)
self.file.flush()
def finish(self):
if self.file.isatty():
print(file=self.file)
if self.hide_cursor:
print(SHOW_CURSOR, end="", file=self.file)
from signal import SIGINT, signal
from sys import exit
class SigIntMixin:
"""Registers a signal handler that calls finish on SIGINT."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
signal(SIGINT, self._sigint_handler)
def _sigint_handler(self, signum, frame):
View file
@ -1,5 +1,3 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
@ -14,14 +12,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from . import Infinite
from .helpers import WriteMixin
class Spinner(WriteMixin, Infinite):
message = ""
phases = ("-", "\\", "|", "/")
hide_cursor = True
def update(self):
@ -30,15 +27,16 @@ class Spinner(WriteMixin, Infinite):
class PieSpinner(Spinner): class PieSpinner(Spinner):
phases = ['◷', '◶', '◵', '◴'] phases = ["◷", "◶", "◵", "◴"]
class MoonSpinner(Spinner): class MoonSpinner(Spinner):
phases = ['◑', '◒', '◐', '◓'] phases = ["◑", "◒", "◐", "◓"]
class LineSpinner(Spinner): class LineSpinner(Spinner):
phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] phases = ["⎺", "⎻", "⎼", "⎽", "⎼", "⎻"]
class PixelSpinner(Spinner): class PixelSpinner(Spinner):
phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] phases = ["⣾", "⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽"]
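For context, a minimal usage sketch of the spinners defined above, assuming the public progress API (Infinite.next() / finish()) that this vendored copy keeps unchanged:

from progress.spinner import PieSpinner

spinner = PieSpinner("Processing ")
for _ in range(100):
    # ... do one unit of work here ...
    spinner.next()   # advance to the next phase glyph
spinner.finish()     # restore the cursor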
@ -1,29 +1,27 @@
#!/usr/bin/env python #!/usr/bin/env python
import progress
from setuptools import setup from setuptools import setup
import progress
setup( setup(
name='progress', name="progress",
version=progress.__version__, version=progress.__version__,
description='Easy to use progress bars', description="Easy to use progress bars",
long_description=open('README.rst').read(), long_description=open("README.rst").read(),
author='Giorgos Verigakis', author="Giorgos Verigakis",
author_email='verigak@gmail.com', author_email="verigak@gmail.com",
url='http://github.com/verigak/progress/', url="http://github.com/verigak/progress/",
license='ISC', license="ISC",
packages=['progress'], packages=["progress"],
classifiers=[ classifiers=[
'Environment :: Console', "Environment :: Console",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: ISC License (ISCL)', "License :: OSI Approved :: ISC License (ISCL)",
'Programming Language :: Python :: 2.6', "Programming Language :: Python :: 2.6",
'Programming Language :: Python :: 2.7', "Programming Language :: Python :: 2.7",
'Programming Language :: Python :: 3.3', "Programming Language :: Python :: 3.3",
'Programming Language :: Python :: 3.4', "Programming Language :: Python :: 3.4",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
] ],
) )
@ -1,16 +1,12 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function
import random import random
import time import time
from progress.bar import (Bar, ChargingBar, FillingSquaresBar, from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar
FillingCirclesBar, IncrementalBar, PixelBar, from progress.counter import Countdown, Counter, Pie, Stack
ShadyBar) from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
PixelSpinner)
from progress.counter import Counter, Countdown, Stack, Pie
def sleep(): def sleep():
@ -20,29 +16,29 @@ def sleep():
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]"
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for i in bar.iter(range(200)): for _i in bar.iter(range(200)):
sleep() sleep()
for bar_cls in (IncrementalBar, PixelBar, ShadyBar): for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]"
bar = bar_cls(bar_cls.__name__, suffix=suffix) bar = bar_cls(bar_cls.__name__, suffix=suffix)
for i in bar.iter(range(200)): for _i in bar.iter(range(200)):
sleep() sleep()
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
for i in spin(spin.__name__ + ' ').iter(range(100)): for _i in spin(spin.__name__ + " ").iter(range(100)):
sleep() sleep()
print() print()
for singleton in (Counter, Countdown, Stack, Pie): for singleton in (Counter, Countdown, Stack, Pie):
for i in singleton(singleton.__name__ + ' ').iter(range(100)): for _i in singleton(singleton.__name__ + " ").iter(range(100)):
sleep() sleep()
print() print()
bar = IncrementalBar('Random', suffix='%(index)d') bar = IncrementalBar("Random", suffix="%(index)d")
for i in range(100): for _i in range(100):
bar.goto(random.randint(0, 100)) bar.goto(random.randint(0, 100))
sleep() sleep()
bar.finish() bar.finish()
@ -1,23 +1,33 @@
name: pointmlp name: pointmlp
channels: channels:
- pytorch - pytorch
- nvidia - nvidia
- conda-forge - conda-forge
dependencies: dependencies:
# - cudatoolkit=10.2.89 #---# basic python
- cudatoolkit=11.1 - pytorch
- cycler=0.10.0 - tqdm
- einops=0.3.0 - numpy
- h5py=3.2.1 - scipy
- matplotlib=3.4.2 - scikit-learn
- numpy=1.20.2 #---# file readers
- numpy-base=1.20.2 - h5py
- pytorch=1.8.1 - pyyaml
- pyyaml=5.4.1 #---# tooling (linting, typing...)
- scikit-learn=0.24.2 - ruff
- scipy=1.6.3 - mypy
- torchvision=0.9.1 - black
- tqdm=4.61.1 - isort
#---# visu
- matplotlib
#---# pytorch
- cudatoolkit
- cycler
- einops
- torchvision
- pip - pip
- pip: - pip:
- pointnet2_ops_lib/. - pointnet2_ops_lib/.
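A hedged sketch of building this environment from Python; it assumes conda is on PATH and that the file above is saved as environment.yml (the diff view omits the actual file name):

import subprocess

# Create the "pointmlp" environment described above; use "conda env update -f ... --prune"
# instead if the environment already exists.
subprocess.run(["conda", "env", "create", "-f", "environment.yml"], check=True)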
[Two binary files not shown; one image (Before: 1.9 MiB) not rendered]
@ -1,30 +1,46 @@
from __future__ import print_function
import os
import argparse import argparse
import torch import os
import torch.optim as optim import random
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR from collections import defaultdict
from util.data_util import PartNormalDataset
import torch.nn.functional as F
import torch.nn as nn
import model as models import model as models
import numpy as np import numpy as np
from torch.utils.data import DataLoader import torch
from util.util import to_categorical, compute_overall_iou, IOStream import torch.nn as nn
from tqdm import tqdm import torch.nn.functional as F
from collections import defaultdict import torch.optim as optim
from torch.autograd import Variable from torch.autograd import Variable
import random from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from util.data_util import PartNormalDataset
from util.util import IOStream, compute_overall_iou, to_categorical
classes_str = [
classes_str = ['aero','bag','cap','car','chair','ear','guitar','knife','lamp','lapt','moto','mug','Pistol','rock','stake','table'] "aero",
"bag",
"cap",
"car",
"chair",
"ear",
"guitar",
"knife",
"lamp",
"lapt",
"moto",
"mug",
"Pistol",
"rock",
"stake",
"table",
]
def _init_(): def _init_():
if not os.path.exists('checkpoints'): if not os.path.exists("checkpoints"):
os.makedirs('checkpoints') os.makedirs("checkpoints")
if not os.path.exists('checkpoints/' + args.exp_name): if not os.path.exists("checkpoints/" + args.exp_name):
os.makedirs('checkpoints/' + args.exp_name) os.makedirs("checkpoints/" + args.exp_name)
def weight_init(m): def weight_init(m):
@ -49,7 +65,6 @@ def weight_init(m):
def train(args, io): def train(args, io):
# ============= Model =================== # ============= Model ===================
num_part = 50 num_part = 50
device = torch.device("cuda" if args.cuda else "cpu") device = torch.device("cuda" if args.cuda else "cpu")
@ -61,16 +76,19 @@ def train(args, io):
model = nn.DataParallel(model) model = nn.DataParallel(model)
print("Let's use", torch.cuda.device_count(), "GPUs!") print("Let's use", torch.cuda.device_count(), "GPUs!")
'''Resume or not''' """Resume or not"""
if args.resume: if args.resume:
state_dict = torch.load("checkpoints/%s/best_insiou_model.pth" % args.exp_name, state_dict = torch.load(
map_location=torch.device('cpu'))['model'] "checkpoints/%s/best_insiou_model.pth" % args.exp_name,
map_location=torch.device("cpu"),
)["model"]
for k in state_dict.keys(): for k in state_dict.keys():
if 'module' not in k: if "module" not in k:
from collections import OrderedDict from collections import OrderedDict
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k in state_dict: for k in state_dict:
new_state_dict['module.' + k] = state_dict[k] new_state_dict["module." + k] = state_dict[k]
state_dict = new_state_dict state_dict = new_state_dict
break break
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
@ -81,27 +99,37 @@ def train(args, io):
print("Training from scratch...") print("Training from scratch...")
# =========== Dataloader ================= # =========== Dataloader =================
train_data = PartNormalDataset(npoints=2048, split='trainval', normalize=False) train_data = PartNormalDataset(npoints=2048, split="trainval", normalize=False)
print("The number of training data is:%d", len(train_data)) print("The number of training data is:%d", len(train_data))
test_data = PartNormalDataset(npoints=2048, split='test', normalize=False) test_data = PartNormalDataset(npoints=2048, split="test", normalize=False)
print("The number of test data is:%d", len(test_data)) print("The number of test data is:%d", len(test_data))
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, train_loader = DataLoader(
drop_last=True) train_data,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
drop_last=True,
)
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers, test_loader = DataLoader(
drop_last=False) test_data,
batch_size=args.test_batch_size,
shuffle=False,
num_workers=args.workers,
drop_last=False,
)
# ============= Optimizer ================ # ============= Optimizer ================
if args.use_sgd: if args.use_sgd:
print("Use SGD") print("Use SGD")
opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=0) opt = optim.SGD(model.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=0)
else: else:
print("Use Adam") print("Use Adam")
opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
if args.scheduler == 'cos': if args.scheduler == "cos":
print("Use CosLR") print("Use CosLR")
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100) scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100)
else: else:
@ -116,28 +144,33 @@ def train(args, io):
num_classes = 16 num_classes = 16
for epoch in range(args.epochs): for epoch in range(args.epochs):
train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io) train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io)
test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io) test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io)
# 1. when get the best accuracy, save the model: # 1. when get the best accuracy, save the model:
if test_metrics['accuracy'] > best_acc: if test_metrics["accuracy"] > best_acc:
best_acc = test_metrics['accuracy'] best_acc = test_metrics["accuracy"]
io.cprint('Max Acc:%.5f' % best_acc) io.cprint("Max Acc:%.5f" % best_acc)
state = { state = {
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), "model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_acc': best_acc} "optimizer": opt.state_dict(),
torch.save(state, 'checkpoints/%s/best_acc_model.pth' % args.exp_name) "epoch": epoch,
"test_acc": best_acc,
}
torch.save(state, "checkpoints/%s/best_acc_model.pth" % args.exp_name)
# 2. when get the best instance_iou, save the model: # 2. when get the best instance_iou, save the model:
if test_metrics['shape_avg_iou'] > best_instance_iou: if test_metrics["shape_avg_iou"] > best_instance_iou:
best_instance_iou = test_metrics['shape_avg_iou'] best_instance_iou = test_metrics["shape_avg_iou"]
io.cprint('Max instance iou:%.5f' % best_instance_iou) io.cprint("Max instance iou:%.5f" % best_instance_iou)
state = { state = {
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), "model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_instance_iou': best_instance_iou} "optimizer": opt.state_dict(),
torch.save(state, 'checkpoints/%s/best_insiou_model.pth' % args.exp_name) "epoch": epoch,
"test_instance_iou": best_instance_iou,
}
torch.save(state, "checkpoints/%s/best_insiou_model.pth" % args.exp_name)
# 3. when get the best class_iou, save the model: # 3. when get the best class_iou, save the model:
# first we need to calculate the average per-class iou # first we need to calculate the average per-class iou
@ -149,22 +182,28 @@ def train(args, io):
best_class_iou = avg_class_iou best_class_iou = avg_class_iou
# print the iou of each class: # print the iou of each class:
for cat_idx in range(16): for cat_idx in range(16):
io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx])) io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx]))
io.cprint('Max class iou:%.5f' % best_class_iou) io.cprint("Max class iou:%.5f" % best_class_iou)
state = { state = {
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), "model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_class_iou': best_class_iou} "optimizer": opt.state_dict(),
torch.save(state, 'checkpoints/%s/best_clsiou_model.pth' % args.exp_name) "epoch": epoch,
"test_class_iou": best_class_iou,
}
torch.save(state, "checkpoints/%s/best_clsiou_model.pth" % args.exp_name)
# report best acc, ins_iou, cls_iou # report best acc, ins_iou, cls_iou
io.cprint('Final Max Acc:%.5f' % best_acc) io.cprint("Final Max Acc:%.5f" % best_acc)
io.cprint('Final Max instance iou:%.5f' % best_instance_iou) io.cprint("Final Max instance iou:%.5f" % best_instance_iou)
io.cprint('Final Max class iou:%.5f' % best_class_iou) io.cprint("Final Max class iou:%.5f" % best_class_iou)
# save last model # save last model
state = { state = {
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), "model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
'optimizer': opt.state_dict(), 'epoch': args.epochs - 1, 'test_iou': best_instance_iou} "optimizer": opt.state_dict(),
torch.save(state, 'checkpoints/%s/model_ep%d.pth' % (args.exp_name, args.epochs)) "epoch": args.epochs - 1,
"test_iou": best_instance_iou,
}
torch.save(state, "checkpoints/%s/model_ep%d.pth" % (args.exp_name, args.epochs))
def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io): def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io):
@ -175,22 +214,41 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
metrics = defaultdict(lambda: list()) metrics = defaultdict(lambda: list())
model.train() model.train()
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9): for _batch_id, (points, label, target, norm_plt) in tqdm(
enumerate(train_loader),
total=len(train_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \ points, label, target, norm_plt = (
Variable(norm_plt.float()) Variable(points.float()),
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \ points, label, target, norm_plt = (
target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True) points.cuda(non_blocking=True),
label.squeeze(1).cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
# target: b,n # target: b,n
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50
loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0]) loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0])
# instance iou without considering the class average at each batch_size: # instance iou without considering the class average at each batch_size:
batch_shapeious = compute_overall_iou(seg_pred, target, num_part) # list of current batch_iou:[iou1,iou2,...,iou#b_size] batch_shapeious = compute_overall_iou(
seg_pred,
target,
num_part,
) # list of current batch_iou:[iou1,iou2,...,iou#b_size]
# total iou of current batch in each process: # total iou of current batch in each process:
batch_shapeious = seg_pred.new_tensor([np.sum(batch_shapeious)], dtype=torch.float64) # same device with seg_pred!!! batch_shapeious = seg_pred.new_tensor(
[np.sum(batch_shapeious)],
dtype=torch.float64,
) # same device with seg_pred!!!
# Loss backward # Loss backward
loss = torch.mean(loss) loss = torch.mean(loss)
@ -200,33 +258,37 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
# accuracy # accuracy
seg_pred = seg_pred.contiguous().view(-1, num_part) # b*n,50 seg_pred = seg_pred.contiguous().view(-1, num_part) # b*n,50
target = target.view(-1, 1)[:, 0] # b*n target = target.view(-1, 1)[:, 0] # b*n
pred_choice = seg_pred.contiguous().data.max(1)[1] # b*n pred_choice = seg_pred.contiguous().data.max(1)[1] # b*n
correct = pred_choice.eq(target.contiguous().data).sum() # torch.int64: total number of correct-predict pts correct = pred_choice.eq(target.contiguous().data).sum() # torch.int64: total number of correct-predict pts
# sum # sum
shape_ious += batch_shapeious.item() # count the sum of ious in each iteration shape_ious += batch_shapeious.item() # count the sum of ious in each iteration
count += batch_size # count the total number of samples in each iteration count += batch_size # count the total number of samples in each iteration
train_loss += loss.item() * batch_size train_loss += loss.item() * batch_size
accuracy.append(correct.item()/(batch_size * num_point)) # append the accuracy of each iteration accuracy.append(correct.item() / (batch_size * num_point)) # append the accuracy of each iteration
# Note: We do not need to calculate per_class iou during training # Note: We do not need to calculate per_class iou during training
if args.scheduler == 'cos': if args.scheduler == "cos":
scheduler.step() scheduler.step()
elif args.scheduler == 'step': elif args.scheduler == "step":
if opt.param_groups[0]['lr'] > 0.9e-5: if opt.param_groups[0]["lr"] > 0.9e-5:
scheduler.step() scheduler.step()
if opt.param_groups[0]['lr'] < 0.9e-5: if opt.param_groups[0]["lr"] < 0.9e-5:
for param_group in opt.param_groups: for param_group in opt.param_groups:
param_group['lr'] = 0.9e-5 param_group["lr"] = 0.9e-5
io.cprint('Learning rate: %f' % opt.param_groups[0]['lr']) io.cprint("Learning rate: %f" % opt.param_groups[0]["lr"])
metrics['accuracy'] = np.mean(accuracy) metrics["accuracy"] = np.mean(accuracy)
metrics['shape_avg_iou'] = shape_ious * 1.0 / count metrics["shape_avg_iou"] = shape_ious * 1.0 / count
outstr = 'Train %d, loss: %f, train acc: %f, train ins_iou: %f' % (epoch+1, train_loss * 1.0 / count, outstr = "Train %d, loss: %f, train acc: %f, train ins_iou: %f" % (
metrics['accuracy'], metrics['shape_avg_iou']) epoch + 1,
train_loss * 1.0 / count,
metrics["accuracy"],
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
@ -241,14 +303,26 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
model.eval() model.eval()
# label_size: b, means each sample has one corresponding class # label_size: b, means each sample has one corresponding class
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9): for _batch_id, (points, label, target, norm_plt) in tqdm(
enumerate(test_loader),
total=len(test_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \ points, label, target, norm_plt = (
Variable(norm_plt.float()) Variable(points.float()),
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \ points, label, target, norm_plt = (
target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True) points.cuda(non_blocking=True),
label.squeeze(1).cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
# instance iou without considering the class average at each batch_size: # instance iou without considering the class average at each batch_size:
@ -281,13 +355,19 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
for cat_idx in range(16): for cat_idx in range(16):
if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending
final_total_per_cat_iou[cat_idx] = final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx] # avg class iou across all samples final_total_per_cat_iou[cat_idx] = (
final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx]
) # avg class iou across all samples
metrics['accuracy'] = np.mean(accuracy) metrics["accuracy"] = np.mean(accuracy)
metrics['shape_avg_iou'] = shape_ious * 1.0 / count metrics["shape_avg_iou"] = shape_ious * 1.0 / count
outstr = 'Test %d, loss: %f, test acc: %f test ins_iou: %f' % (epoch + 1, test_loss * 1.0 / count, outstr = "Test %d, loss: %f, test acc: %f test ins_iou: %f" % (
metrics['accuracy'], metrics['shape_avg_iou']) epoch + 1,
test_loss * 1.0 / count,
metrics["accuracy"],
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
@ -296,11 +376,16 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
def test(args, io): def test(args, io):
# Dataloader # Dataloader
test_data = PartNormalDataset(npoints=2048, split='test', normalize=False) test_data = PartNormalDataset(npoints=2048, split="test", normalize=False)
print("The number of test data is:%d", len(test_data)) print("The number of test data is:%d", len(test_data))
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers, test_loader = DataLoader(
drop_last=False) test_data,
batch_size=args.test_batch_size,
shuffle=False,
num_workers=args.workers,
drop_last=False,
)
# Try to load models # Try to load models
num_part = 50 num_part = 50
@ -310,12 +395,15 @@ def test(args, io):
io.cprint(str(model)) io.cprint(str(model))
from collections import OrderedDict from collections import OrderedDict
state_dict = torch.load("checkpoints/%s/best_%s_model.pth" % (args.exp_name, args.model_type),
map_location=torch.device('cpu'))['model'] state_dict = torch.load(
f"checkpoints/{args.exp_name}/best_{args.model_type}_model.pth",
map_location=torch.device("cpu"),
)["model"]
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for layer in state_dict: for layer in state_dict:
new_state_dict[layer.replace('module.', '')] = state_dict[layer] new_state_dict[layer.replace("module.", "")] = state_dict[layer]
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
model.eval() model.eval()
@ -324,16 +412,29 @@ def test(args, io):
metrics = defaultdict(lambda: list()) metrics = defaultdict(lambda: list())
hist_acc = [] hist_acc = []
shape_ious = [] shape_ious = []
total_per_cat_iou = np.zeros((16)).astype(np.float32) total_per_cat_iou = np.zeros(16).astype(np.float32)
total_per_cat_seen = np.zeros((16)).astype(np.int32) total_per_cat_seen = np.zeros(16).astype(np.int32)
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9): for _batch_id, (points, label, target, norm_plt) in tqdm(
enumerate(test_loader),
total=len(test_loader),
smoothing=0.9,
):
batch_size, num_point, _ = points.size() batch_size, num_point, _ = points.size()
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), Variable(norm_plt.float()) points, label, target, norm_plt = (
Variable(points.float()),
Variable(label.long()),
Variable(target.long()),
Variable(norm_plt.float()),
)
points = points.transpose(2, 1) points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1) norm_plt = norm_plt.transpose(2, 1)
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze().cuda( points, label, target, norm_plt = (
non_blocking=True), target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True) points.cuda(non_blocking=True),
label.squeeze().cuda(non_blocking=True),
target.cuda(non_blocking=True),
norm_plt.cuda(non_blocking=True),
)
with torch.no_grad(): with torch.no_grad():
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50 seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
@ -353,11 +454,11 @@ def test(args, io):
target = target.view(-1, 1)[:, 0] target = target.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1] pred_choice = seg_pred.data.max(1)[1]
correct = pred_choice.eq(target.data).cpu().sum() correct = pred_choice.eq(target.data).cpu().sum()
metrics['accuracy'].append(correct.item() / (batch_size * num_point)) metrics["accuracy"].append(correct.item() / (batch_size * num_point))
hist_acc += metrics['accuracy'] hist_acc += metrics["accuracy"]
metrics['accuracy'] = np.mean(hist_acc) metrics["accuracy"] = np.mean(hist_acc)
metrics['shape_avg_iou'] = np.mean(shape_ious) metrics["shape_avg_iou"] = np.mean(shape_ious)
for cat_idx in range(16): for cat_idx in range(16):
if total_per_cat_seen[cat_idx] > 0: if total_per_cat_seen[cat_idx] > 0:
total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx] total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx]
@ -366,57 +467,51 @@ def test(args, io):
class_iou = 0 class_iou = 0
for cat_idx in range(16): for cat_idx in range(16):
class_iou += total_per_cat_iou[cat_idx] class_iou += total_per_cat_iou[cat_idx]
io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx])) # print the iou of each class io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx])) # print the iou of each class
avg_class_iou = class_iou / 16 avg_class_iou = class_iou / 16
outstr = 'Test :: test acc: %f test class mIOU: %f, test instance mIOU: %f' % (metrics['accuracy'], avg_class_iou, metrics['shape_avg_iou']) outstr = "Test :: test acc: {:f} test class mIOU: {:f}, test instance mIOU: {:f}".format(
metrics["accuracy"],
avg_class_iou,
metrics["shape_avg_iou"],
)
io.cprint(outstr) io.cprint(outstr)
if __name__ == "__main__": if __name__ == "__main__":
# Training settings # Training settings
parser = argparse.ArgumentParser(description='3D Shape Part Segmentation') parser = argparse.ArgumentParser(description="3D Shape Part Segmentation")
parser.add_argument('--model', type=str, default='PointMLP1') parser.add_argument("--model", type=str, default="PointMLP1")
parser.add_argument('--exp_name', type=str, default='demo1', metavar='N', parser.add_argument("--exp_name", type=str, default="demo1", metavar="N", help="Name of the experiment")
help='Name of the experiment') parser.add_argument("--batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)")
parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', parser.add_argument("--test_batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)")
help='Size of batch)') parser.add_argument("--epochs", type=int, default=350, metavar="N", help="number of episode to train")
parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size', parser.add_argument("--use_sgd", type=bool, default=False, help="Use SGD")
help='Size of batch)') parser.add_argument("--scheduler", type=str, default="step", help="lr scheduler")
parser.add_argument('--epochs', type=int, default=350, metavar='N', parser.add_argument("--step", type=int, default=40, help="lr decay step")
help='number of episode to train') parser.add_argument("--lr", type=float, default=0.003, metavar="LR", help="learning rate")
parser.add_argument('--use_sgd', type=bool, default=False, parser.add_argument("--momentum", type=float, default=0.9, metavar="M", help="SGD momentum (default: 0.9)")
help='Use SGD') parser.add_argument("--no_cuda", type=bool, default=False, help="enables CUDA training")
parser.add_argument('--scheduler', type=str, default='step', parser.add_argument("--manual_seed", type=int, metavar="S", help="random seed (default: 1)")
help='lr scheduler') parser.add_argument("--eval", type=bool, default=False, help="evaluate the model")
parser.add_argument('--step', type=int, default=40, parser.add_argument("--num_points", type=int, default=2048, help="num of points to use")
help='lr decay step') parser.add_argument("--workers", type=int, default=12)
parser.add_argument('--lr', type=float, default=0.003, metavar='LR', parser.add_argument("--resume", type=bool, default=False, help="Resume training or not")
help='learning rate') parser.add_argument(
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', "--model_type",
help='SGD momentum (default: 0.9)') type=str,
parser.add_argument('--no_cuda', type=bool, default=False, default="insiou",
help='enables CUDA training') help="choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)",
parser.add_argument('--manual_seed', type=int, metavar='S', )
help='random seed (default: 1)')
parser.add_argument('--eval', type=bool, default=False,
help='evaluate the model')
parser.add_argument('--num_points', type=int, default=2048,
help='num of points to use')
parser.add_argument('--workers', type=int, default=12)
parser.add_argument('--resume', type=bool, default=False,
help='Resume training or not')
parser.add_argument('--model_type', type=str, default='insiou',
help='choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)')
args = parser.parse_args() args = parser.parse_args()
args.exp_name = args.model+"_"+args.exp_name args.exp_name = args.model + "_" + args.exp_name
_init_() _init_()
if not args.eval: if not args.eval:
io = IOStream('checkpoints/' + args.exp_name + '/%s_train.log' % (args.exp_name)) io = IOStream("checkpoints/" + args.exp_name + "/%s_train.log" % (args.exp_name))
else: else:
io = IOStream('checkpoints/' + args.exp_name + '/%s_test.log' % (args.exp_name)) io = IOStream("checkpoints/" + args.exp_name + "/%s_test.log" % (args.exp_name))
io.cprint(str(args)) io.cprint(str(args))
if args.manual_seed is not None: if args.manual_seed is not None:
@ -427,12 +522,12 @@ if __name__ == "__main__":
args.cuda = not args.no_cuda and torch.cuda.is_available() args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda: if args.cuda:
io.cprint('Using GPU') io.cprint("Using GPU")
if args.manual_seed is not None: if args.manual_seed is not None:
torch.cuda.manual_seed(args.manual_seed) torch.cuda.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed) torch.cuda.manual_seed_all(args.manual_seed)
else: else:
io.cprint('Using CPU') io.cprint("Using CPU")
if not args.eval: if not args.eval:
train(args, io) train(args, io)
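For orientation, a hedged launch sketch for this training entry point. The script name below is an assumption (the diff view omits file paths); the flags are taken from the argparse definitions above:

import subprocess

subprocess.run(
    [
        "python", "main_partseg.py",  # hypothetical file name for the script shown above
        "--exp_name", "demo1",
        "--batch_size", "32",
        "--scheduler", "step",
        "--epochs", "350",
    ],
    check=True,
)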
@ -1,2 +1 @@
from __future__ import absolute_import
from .pointMLP import pointMLP from .pointMLP import pointMLP
@ -1,34 +1,32 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat
from pointnet2_ops import pointnet2_utils from pointnet2_ops import pointnet2_utils
def get_activation(activation): def get_activation(activation):
if activation.lower() == 'gelu': if activation.lower() == "gelu":
return nn.GELU() return nn.GELU()
elif activation.lower() == 'rrelu': elif activation.lower() == "rrelu":
return nn.RReLU(inplace=True) return nn.RReLU(inplace=True)
elif activation.lower() == 'selu': elif activation.lower() == "selu":
return nn.SELU(inplace=True) return nn.SELU(inplace=True)
elif activation.lower() == 'silu': elif activation.lower() == "silu":
return nn.SiLU(inplace=True) return nn.SiLU(inplace=True)
elif activation.lower() == 'hardswish': elif activation.lower() == "hardswish":
return nn.Hardswish(inplace=True) return nn.Hardswish(inplace=True)
elif activation.lower() == 'leakyrelu': elif activation.lower() == "leakyrelu":
return nn.LeakyReLU(inplace=True) return nn.LeakyReLU(inplace=True)
elif activation.lower() == 'leakyrelu0.2': elif activation.lower() == "leakyrelu0.2":
return nn.LeakyReLU(negative_slope=0.2, inplace=True) return nn.LeakyReLU(negative_slope=0.2, inplace=True)
else: else:
return nn.ReLU(inplace=True) return nn.ReLU(inplace=True)
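A quick illustration of the factory above (relies on the definitions in this file; unknown names fall through to plain ReLU):

act = get_activation("leakyrelu0.2")   # -> nn.LeakyReLU(negative_slope=0.2, inplace=True)
fallback = get_activation("whatever")  # -> nn.ReLU(inplace=True), the default branch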
def square_distance(src, dst): def square_distance(src, dst):
""" """Calculate Euclid distance between each two points.
Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm;
src^T * dst = xn * xm + yn * ym + zn * zm
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
@ -37,23 +35,23 @@ def square_distance(src, dst):
src: source points, [B, N, C] src: source points, [B, N, C]
dst: target points, [B, M, C] dst: target points, [B, M, C]
Output: Output:
dist: per-point square distance, [B, N, M] dist: per-point square distance, [B, N, M].
""" """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M) dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist return dist
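A small self-check of the expansion used above, dist = sum(src^2) + sum(dst^2) - 2 * src @ dst^T, against torch.cdist (a sketch relying on the square_distance defined above):

import torch

src = torch.rand(2, 16, 3)               # [B, N, C] source points
dst = torch.rand(2, 24, 3)               # [B, M, C] target points
d_expanded = square_distance(src, dst)   # [B, N, M], from the function above
d_reference = torch.cdist(src, dst) ** 2 # pairwise Euclidean distances, squared
assert torch.allclose(d_expanded, d_reference, atol=1e-5)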
def index_points(points, idx): def index_points(points, idx):
""" """Input:
Input:
points: input points data, [B, N, C] points: input points data, [B, N, C]
idx: sample index data, [B, S] idx: sample index data, [B, S].
Return: Return:
new_points:, indexed points data, [B, S, C] new_points:, indexed points data, [B, S, C].
""" """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
@ -62,17 +60,15 @@ def index_points(points, idx):
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :] return points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint): def farthest_point_sample(xyz, npoint):
""" """Input:
Input:
xyz: pointcloud data, [B, N, 3] xyz: pointcloud data, [B, N, 3]
npoint: number of samples npoint: number of samples
Return: Return:
centroids: sampled pointcloud index, [B, npoint] centroids: sampled pointcloud index, [B, npoint].
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
@ -90,21 +86,21 @@ def farthest_point_sample(xyz, npoint):
def query_ball_point(radius, nsample, xyz, new_xyz): def query_ball_point(radius, nsample, xyz, new_xyz):
""" """Input:
Input:
radius: local region radius radius: local region radius
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, 3] xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3] new_xyz: query points, [B, S, 3].
Return: Return:
group_idx: grouped points index, [B, S, nsample] group_idx: grouped points index, [B, S, nsample].
""" """
device = xyz.device device = xyz.device
B, N, C = xyz.shape B, N, C = xyz.shape
_, S, _ = new_xyz.shape _, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N group_idx[sqrdists > radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N mask = group_idx == N
@ -113,13 +109,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
def knn_point(nsample, xyz, new_xyz): def knn_point(nsample, xyz, new_xyz):
""" """Input:
Input:
nsample: max sample number in local region nsample: max sample number in local region
xyz: all points, [B, N, C] xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C] new_xyz: query points, [B, S, C].
Return: Return:
group_idx: grouped points index, [B, S, nsample] group_idx: grouped points index, [B, S, nsample].
""" """
sqrdists = square_distance(new_xyz, xyz) sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
@ -128,13 +124,12 @@ def knn_point(nsample, xyz, new_xyz):
class LocalGrouper(nn.Module): class LocalGrouper(nn.Module):
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs): def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs):
""" """Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
:param groups: groups number :param groups: groups number
:param kneighbors: k-nerighbors :param kneighbors: k-nerighbors
:param kwargs: others :param kwargs: others.
""" """
super(LocalGrouper, self).__init__() super().__init__()
self.groups = groups self.groups = groups
self.kneighbors = kneighbors self.kneighbors = kneighbors
self.use_xyz = use_xyz self.use_xyz = use_xyz
@ -143,11 +138,11 @@ class LocalGrouper(nn.Module):
else: else:
self.normalize = None self.normalize = None
if self.normalize not in ["center", "anchor"]: if self.normalize not in ["center", "anchor"]:
print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].") print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
self.normalize = None self.normalize = None
if self.normalize is not None: if self.normalize is not None:
add_channel=3 if self.use_xyz else 0 add_channel = 3 if self.use_xyz else 0
self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel])) self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel])) self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
def forward(self, xyz, points): def forward(self, xyz, points):
@ -166,29 +161,33 @@ class LocalGrouper(nn.Module):
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3] grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
grouped_points = index_points(points, idx) # [B, npoint, k, d] grouped_points = index_points(points, idx) # [B, npoint, k, d]
if self.use_xyz: if self.use_xyz:
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) # [B, npoint, k, d+3]
if self.normalize is not None: if self.normalize is not None:
if self.normalize =="center": if self.normalize == "center":
mean = torch.mean(grouped_points, dim=2, keepdim=True) mean = torch.mean(grouped_points, dim=2, keepdim=True)
if self.normalize =="anchor": if self.normalize == "anchor":
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) std = (
grouped_points = (grouped_points-mean)/(std + 1e-5) torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
grouped_points = self.affine_alpha*grouped_points + self.affine_beta .unsqueeze(dim=-1)
.unsqueeze(dim=-1)
)
grouped_points = (grouped_points - mean) / (std + 1e-5)
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1) new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
return new_xyz, new_points return new_xyz, new_points
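A shape walk-through for the grouper above (a sketch assuming the defaults use_xyz=True and normalize="anchor"; it needs the compiled pointnet2_ops extension, so tensors are moved to the GPU):

import torch

grouper = LocalGrouper(channel=64, groups=512, kneighbors=32, use_xyz=True, normalize="anchor").cuda()
xyz = torch.rand(2, 2048, 3).cuda()     # [B, N, 3] coordinates
feats = torch.rand(2, 2048, 64).cuda()  # [B, N, D] per-point features
new_xyz, new_feats = grouper(xyz, feats)
print(new_xyz.shape)    # torch.Size([2, 512, 3])       -> one anchor point per group
print(new_feats.shape)  # torch.Size([2, 512, 32, 131]) -> 2*D + 3 channels per neighbour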
class ConvBNReLU1D(nn.Module): class ConvBNReLU1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'): def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
super(ConvBNReLU1D, self).__init__() super().__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
self.act self.act,
) )
def forward(self, x): def forward(self, x):
@ -196,30 +195,43 @@ class ConvBNReLU1D(nn.Module):
class ConvBNReLURes1D(nn.Module): class ConvBNReLURes1D(nn.Module):
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'): def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
super(ConvBNReLURes1D, self).__init__() super().__init__()
self.act = get_activation(activation) self.act = get_activation(activation)
self.net1 = nn.Sequential( self.net1 = nn.Sequential(
nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion), nn.Conv1d(
kernel_size=kernel_size, groups=groups, bias=bias), in_channels=channel,
out_channels=int(channel * res_expansion),
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(int(channel * res_expansion)), nn.BatchNorm1d(int(channel * res_expansion)),
self.act self.act,
) )
if groups > 1: if groups > 1:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, nn.Conv1d(
kernel_size=kernel_size, groups=groups, bias=bias), in_channels=int(channel * res_expansion),
out_channels=channel,
kernel_size=kernel_size,
groups=groups,
bias=bias,
),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
self.act, self.act,
nn.Conv1d(in_channels=channel, out_channels=channel, nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
kernel_size=kernel_size, bias=bias),
nn.BatchNorm1d(channel), nn.BatchNorm1d(channel),
) )
else: else:
self.net2 = nn.Sequential( self.net2 = nn.Sequential(
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, nn.Conv1d(
kernel_size=kernel_size, bias=bias), in_channels=int(channel * res_expansion),
nn.BatchNorm1d(channel) out_channels=channel,
kernel_size=kernel_size,
bias=bias,
),
nn.BatchNorm1d(channel),
) )
def forward(self, x): def forward(self, x):
@ -227,21 +239,34 @@ class ConvBNReLURes1D(nn.Module):
class PreExtraction(nn.Module): class PreExtraction(nn.Module):
def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True, def __init__(
activation='relu', use_xyz=True): self,
""" channels,
input: [b,g,k,d]: output:[b,d,g] out_channels,
blocks=1,
groups=1,
res_expansion=1,
bias=True,
activation="relu",
use_xyz=True,
):
"""input: [b,g,k,d]: output:[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super(PreExtraction, self).__init__() super().__init__()
in_channels = 3+2*channels if use_xyz else 2*channels in_channels = 3 + 2 * channels if use_xyz else 2 * channels
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion, ConvBNReLURes1D(
bias=bias, activation=activation) out_channels,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -253,22 +278,20 @@ class PreExtraction(nn.Module):
batch_size, _, _ = x.size() batch_size, _, _ = x.size()
x = self.operation(x) # [b, d, k] x = self.operation(x) # [b, d, k]
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
x = x.reshape(b, n, -1).permute(0, 2, 1) return x.reshape(b, n, -1).permute(0, 2, 1)
return x
class PosExtraction(nn.Module): class PosExtraction(nn.Module):
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'): def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
""" """input[b,d,g]; output[b,d,g]
input[b,d,g]; output[b,d,g]
:param channels: :param channels:
:param blocks: :param blocks:
""" """
super(PosExtraction, self).__init__() super().__init__()
operation = [] operation = []
for _ in range(blocks): for _ in range(blocks):
operation.append( operation.append(
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation) ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
) )
self.operation = nn.Sequential(*operation) self.operation = nn.Sequential(*operation)
@ -277,22 +300,27 @@ class PosExtraction(nn.Module):
class PointNetFeaturePropagation(nn.Module): class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation='relu'): def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
super(PointNetFeaturePropagation, self).__init__() super().__init__()
self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias) self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias)
self.extraction = PosExtraction(out_channel, blocks, groups=groups, self.extraction = PosExtraction(
res_expansion=res_expansion, bias=bias, activation=activation) out_channel,
blocks,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
def forward(self, xyz1, xyz2, points1, points2): def forward(self, xyz1, xyz2, points1, points2):
""" """Input:
Input:
xyz1: input points position data, [B, N, 3] xyz1: input points position data, [B, N, 3]
xyz2: sampled input points position data, [B, S, 3] xyz2: sampled input points position data, [B, S, 3]
points1: input points data, [B, D', N] points1: input points data, [B, D', N]
points2: input points data, [B, D'', S] points2: input points data, [B, D'', S].
Return: Return:
new_points: upsampled points data, [B, D''', N] new_points: upsampled points data, [B, D''', N].
""" """
# xyz1 = xyz1.permute(0, 2, 1) # xyz1 = xyz1.permute(0, 2, 1)
# xyz2 = xyz2.permute(0, 2, 1) # xyz2 = xyz2.permute(0, 2, 1)
@ -321,26 +349,40 @@ class PointNetFeaturePropagation(nn.Module):
new_points = new_points.permute(0, 2, 1) new_points = new_points.permute(0, 2, 1)
new_points = self.fuse(new_points) new_points = self.fuse(new_points)
new_points = self.extraction(new_points) return self.extraction(new_points)
return new_points
class PointMLP(nn.Module): class PointMLP(nn.Module):
def __init__(self, num_classes=50,points=2048, embed_dim=64, groups=1, res_expansion=1.0, def __init__(
activation="relu", bias=True, use_xyz=True, normalize="anchor", self,
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], num_classes=50,
k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4], points=2048,
de_dims=[512, 256, 128, 128], de_blocks=[2,2,2,2], embed_dim=64,
gmp_dim=64,cls_dim=64, **kwargs): groups=1,
super(PointMLP, self).__init__() res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[4, 4, 4, 4],
de_dims=[512, 256, 128, 128],
de_blocks=[2, 2, 2, 2],
gmp_dim=64,
cls_dim=64,
**kwargs,
):
super().__init__()
self.stages = len(pre_blocks) self.stages = len(pre_blocks)
self.class_num = num_classes self.class_num = num_classes
self.points = points self.points = points
self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation) self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation)
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \ assert (
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers." len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList() self.local_grouper_list = nn.ModuleList()
self.pre_blocks_list = nn.ModuleList() self.pre_blocks_list = nn.ModuleList()
self.pos_blocks_list = nn.ModuleList() self.pos_blocks_list = nn.ModuleList()
@ -359,29 +401,47 @@ class PointMLP(nn.Module):
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
self.local_grouper_list.append(local_grouper) self.local_grouper_list.append(local_grouper)
# append pre_block_list # append pre_block_list
pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups, pre_block_module = PreExtraction(
res_expansion=res_expansion, last_channel,
bias=bias, activation=activation, use_xyz=use_xyz) out_channel,
pre_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
use_xyz=use_xyz,
)
self.pre_blocks_list.append(pre_block_module) self.pre_blocks_list.append(pre_block_module)
# append pos_block_list # append pos_block_list
pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups, pos_block_module = PosExtraction(
res_expansion=res_expansion, bias=bias, activation=activation) out_channel,
pos_block_num,
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
)
self.pos_blocks_list.append(pos_block_module) self.pos_blocks_list.append(pos_block_module)
last_channel = out_channel last_channel = out_channel
en_dims.append(last_channel) en_dims.append(last_channel)
### Building Decoder ##### ### Building Decoder #####
self.decode_list = nn.ModuleList() self.decode_list = nn.ModuleList()
en_dims.reverse() en_dims.reverse()
de_dims.insert(0,en_dims[0]) de_dims.insert(0, en_dims[0])
assert len(en_dims) ==len(de_dims) == len(de_blocks)+1 assert len(en_dims) == len(de_dims) == len(de_blocks) + 1
for i in range(len(en_dims)-1): for i in range(len(en_dims) - 1):
self.decode_list.append( self.decode_list.append(
PointNetFeaturePropagation(de_dims[i]+en_dims[i+1], de_dims[i+1], PointNetFeaturePropagation(
blocks=de_blocks[i], groups=groups, res_expansion=res_expansion, de_dims[i] + en_dims[i + 1],
bias=bias, activation=activation) de_dims[i + 1],
blocks=de_blocks[i],
groups=groups,
res_expansion=res_expansion,
bias=bias,
activation=activation,
),
) )
self.act = get_activation(activation) self.act = get_activation(activation)
@ -389,26 +449,26 @@ class PointMLP(nn.Module):
# class label mapping # class label mapping
self.cls_map = nn.Sequential( self.cls_map = nn.Sequential(
ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation), ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation),
ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation) ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation),
) )
# global max pooling mapping # global max pooling mapping
self.gmp_map_list = nn.ModuleList() self.gmp_map_list = nn.ModuleList()
for en_dim in en_dims: for en_dim in en_dims:
self.gmp_map_list.append(ConvBNReLU1D(en_dim, gmp_dim, bias=bias, activation=activation)) self.gmp_map_list.append(ConvBNReLU1D(en_dim, gmp_dim, bias=bias, activation=activation))
self.gmp_map_end = ConvBNReLU1D(gmp_dim*len(en_dims), gmp_dim, bias=bias, activation=activation) self.gmp_map_end = ConvBNReLU1D(gmp_dim * len(en_dims), gmp_dim, bias=bias, activation=activation)
# classifier # classifier
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias), nn.Conv1d(gmp_dim + cls_dim + de_dims[-1], 128, 1, bias=bias),
nn.BatchNorm1d(128), nn.BatchNorm1d(128),
nn.Dropout(), nn.Dropout(),
nn.Conv1d(128, num_classes, 1, bias=bias) nn.Conv1d(128, num_classes, 1, bias=bias),
) )
self.en_dims = en_dims self.en_dims = en_dims
def forward(self, x, norm_plt, cls_label): def forward(self, x, norm_plt, cls_label):
xyz = x.permute(0, 2, 1) xyz = x.permute(0, 2, 1)
x = torch.cat([x,norm_plt],dim=1) x = torch.cat([x, norm_plt], dim=1)
x = self.embedding(x) # B,D,N x = self.embedding(x) # B,D,N
xyz_list = [xyz] # [B, N, 3] xyz_list = [xyz] # [B, N, 3]
@ -428,37 +488,55 @@ class PointMLP(nn.Module):
x_list.reverse() x_list.reverse()
x = x_list[0] x = x_list[0]
for i in range(len(self.decode_list)): for i in range(len(self.decode_list)):
x = self.decode_list[i](xyz_list[i+1], xyz_list[i], x_list[i+1],x) x = self.decode_list[i](xyz_list[i + 1], xyz_list[i], x_list[i + 1], x)
# here is the global context # here is the global context
gmp_list = [] gmp_list = []
for i in range(len(x_list)): for i in range(len(x_list)):
gmp_list.append(F.adaptive_max_pool1d(self.gmp_map_list[i](x_list[i]), 1)) gmp_list.append(F.adaptive_max_pool1d(self.gmp_map_list[i](x_list[i]), 1))
global_context = self.gmp_map_end(torch.cat(gmp_list, dim=1)) # [b, gmp_dim, 1] global_context = self.gmp_map_end(torch.cat(gmp_list, dim=1)) # [b, gmp_dim, 1]
#here is the cls_token # here is the cls_token
cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1] cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1]
x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1) x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
x = self.classifier(x) x = self.classifier(x)
x = F.log_softmax(x, dim=1) x = F.log_softmax(x, dim=1)
x = x.permute(0, 2, 1) return x.permute(0, 2, 1)
return x
def pointMLP(num_classes=50, **kwargs) -> PointMLP: def pointMLP(num_classes=50, **kwargs) -> PointMLP:
return PointMLP(num_classes=num_classes, points=2048, embed_dim=64, groups=1, res_expansion=1.0, return PointMLP(
activation="relu", bias=True, use_xyz=True, normalize="anchor", num_classes=num_classes,
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], points=2048,
k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4], embed_dim=64,
de_dims=[512, 256, 128, 128], de_blocks=[4,4,4,4], groups=1,
gmp_dim=64,cls_dim=64, **kwargs) res_expansion=1.0,
activation="relu",
bias=True,
use_xyz=True,
normalize="anchor",
dim_expansion=[2, 2, 2, 2],
pre_blocks=[2, 2, 2, 2],
pos_blocks=[2, 2, 2, 2],
k_neighbors=[32, 32, 32, 32],
reducers=[4, 4, 4, 4],
de_dims=[512, 256, 128, 128],
de_blocks=[4, 4, 4, 4],
gmp_dim=64,
cls_dim=64,
**kwargs,
)
if __name__ == '__main__': if __name__ == "__main__":
data = torch.rand(2, 3, 2048) data = torch.rand(2, 3, 2048).cuda()
norm = torch.rand(2, 3, 2048) norm = torch.rand(2, 3, 2048).cuda()
cls_label = torch.rand([2, 16]) cls_label = torch.rand([2, 16]).cuda()
print("===> testing modelD ...") print(f"data shape: {data.shape}")
model = pointMLP(50) print(f"norm shape: {norm.shape}")
out = model(data, cls_label) # [2,2048,50] print(f"cls_label shape: {cls_label.shape}")
print(out.shape)
print("===> testing pointMLP (segmentation) ...")
model = pointMLP(50).cuda()
out = model(data, norm, cls_label) # [2,2048,50]
print(f"out shape: {out.shape}")
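In the smoke test above cls_label is just random noise; in real use it is presumably a one-hot encoding of the 16 ShapeNet object categories, since cls_map expects 16 input channels (to_categorical in util.py serves the same purpose). A small sketch of building such a label; category_idx is a made-up example:
import torch
import torch.nn.functional as F
category_idx = torch.tensor([3, 7])                            # hypothetical category ids for a batch of 2
cls_label = F.one_hot(category_idx, num_classes=16).float()    # (2, 16), drop-in for the random tensor above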
View file
@ -1,19 +1,21 @@
import glob import glob
import json
import os
import h5py import h5py
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
import os
import json
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
def load_data(partition): def load_data(partition):
all_data = [] all_data = []
all_label = [] all_label = []
for h5_name in glob.glob('./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5' % partition): for h5_name in glob.glob("./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5" % partition):
f = h5py.File(h5_name) f = h5py.File(h5_name)
data = f['data'][:].astype('float32') data = f["data"][:].astype("float32")
label = f['label'][:].astype('int64') label = f["label"][:].astype("int64")
f.close() f.close()
all_data.append(data) all_data.append(data)
all_label.append(label) all_label.append(label)
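load_data globs every ply_data_<partition>*.h5 file and concatenates their "data" and "label" datasets. A minimal sketch of writing a compatible dummy file, assuming the usual ModelNet40 HDF5 layout of float32 points (num_shapes, 2048, 3) and int64 labels (num_shapes, 1):
import h5py
import numpy as np
with h5py.File("./data/modelnet40_ply_hdf5_2048/ply_data_train_dummy.h5", "w") as f:  # hypothetical file name
    f.create_dataset("data", data=np.random.rand(8, 2048, 3).astype("float32"))
    f.create_dataset("label", data=np.random.randint(0, 40, size=(8, 1)).astype("int64"))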
@ -25,36 +27,34 @@ def load_data(partition):
def pc_normalize(pc): def pc_normalize(pc):
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
pc = pc - centroid pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m return pc / m
return pc
def translate_pointcloud(pointcloud): def translate_pointcloud(pointcloud):
xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
return translated_pointcloud
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
N, C = pointcloud.shape N, C = pointcloud.shape
pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
return pointcloud return pointcloud
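pc_normalize centres the cloud on its centroid and scales it into the unit sphere; translate_pointcloud applies a random per-axis scale in [2/3, 3/2] plus a shift in [-0.2, 0.2]; jitter_pointcloud adds clipped Gaussian noise. A quick sanity check of the normalization property (assumes pc_normalize from this module is in scope):
import numpy as np
pc = np.random.rand(1024, 3).astype("float32") * 5.0
pc = pc_normalize(pc)
assert np.allclose(pc.mean(axis=0), 0.0, atol=1e-5)         # centred on the origin
assert np.max(np.linalg.norm(pc, axis=1)) <= 1.0 + 1e-6     # fits inside the unit sphere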
# =========== ModelNet40 ================= # =========== ModelNet40 =================
class ModelNet40(Dataset): class ModelNet40(Dataset):
def __init__(self, num_points, partition='train'): def __init__(self, num_points, partition="train"):
self.data, self.label = load_data(partition) self.data, self.label = load_data(partition)
self.num_points = num_points self.num_points = num_points
self.partition = partition # the given partition overrides the default 'train' self.partition = partition # the given partition overrides the default 'train'
def __getitem__(self, item): # index of the points or label def __getitem__(self, item): # index of the points or label
pointcloud = self.data[item][:self.num_points] pointcloud = self.data[item][: self.num_points]
label = self.label[item] label = self.label[item]
if self.partition == 'train': if self.partition == "train":
# pointcloud = pc_normalize(pointcloud) # optional: normalization can be enabled when training the model # pointcloud = pc_normalize(pointcloud) # optional: normalization can be enabled when training the model
pointcloud = translate_pointcloud(pointcloud) pointcloud = translate_pointcloud(pointcloud)
np.random.shuffle(pointcloud) # shuffle the order of pts np.random.shuffle(pointcloud) # shuffle the order of pts
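ModelNet40 yields one (pointcloud, label) pair per index, with translation augmentation and point shuffling applied only on the training split. A minimal DataLoader sketch, assuming the HDF5 files are in place and that __getitem__ returns the (pointcloud, label) pair as in the upstream implementation:
from torch.utils.data import DataLoader
train_loader = DataLoader(ModelNet40(num_points=1024, partition="train"), batch_size=32, shuffle=True, drop_last=True)
for points, labels in train_loader:
    print(points.shape, labels.shape)   # expected: (32, 1024, 3) and (32, 1)
    break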
@ -66,46 +66,46 @@ class ModelNet40(Dataset):
# =========== ShapeNet Part ================= # =========== ShapeNet Part =================
class PartNormalDataset(Dataset): class PartNormalDataset(Dataset):
def __init__(self, npoints=2500, split='train', normalize=False): def __init__(self, npoints=2500, split="train", normalize=False):
self.npoints = npoints self.npoints = npoints
self.root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal' self.root = "./data/shapenetcore_partanno_segmentation_benchmark_v0_normal"
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') self.catfile = os.path.join(self.root, "synsetoffset2category.txt")
self.cat = {} self.cat = {}
self.normalize = normalize self.normalize = normalize
with open(self.catfile, 'r') as f: with open(self.catfile) as f:
for line in f: for line in f:
ls = line.strip().split() ls = line.strip().split()
self.cat[ls[0]] = ls[1] self.cat[ls[0]] = ls[1]
self.cat = {k: v for k, v in self.cat.items()} self.cat = {k: v for k, v in self.cat.items()}
self.meta = {} self.meta = {}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: with open(os.path.join(self.root, "train_test_split", "shuffled_train_file_list.json")) as f:
train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) train_ids = set([str(d.split("/")[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: with open(os.path.join(self.root, "train_test_split", "shuffled_val_file_list.json")) as f:
val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) val_ids = set([str(d.split("/")[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: with open(os.path.join(self.root, "train_test_split", "shuffled_test_file_list.json")) as f:
test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) test_ids = set([str(d.split("/")[2]) for d in json.load(f)])
for item in self.cat: for item in self.cat:
self.meta[item] = [] self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item]) dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point)) fns = sorted(os.listdir(dir_point))
if split == 'trainval': if split == "trainval":
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
elif split == 'train': elif split == "train":
fns = [fn for fn in fns if fn[0:-4] in train_ids] fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == 'val': elif split == "val":
fns = [fn for fn in fns if fn[0:-4] in val_ids] fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == 'test': elif split == "test":
fns = [fn for fn in fns if fn[0:-4] in test_ids] fns = [fn for fn in fns if fn[0:-4] in test_ids]
else: else:
print('Unknown split: %s. Exiting..' % (split)) print("Unknown split: %s. Exiting.." % (split))
exit(-1) exit(-1)
for fn in fns: for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0]) token = os.path.splitext(os.path.basename(fn))[0]
self.meta[item].append(os.path.join(dir_point, token + '.txt')) self.meta[item].append(os.path.join(dir_point, token + ".txt"))
self.datapath = [] self.datapath = []
for item in self.cat: for item in self.cat:
@ -114,11 +114,24 @@ class PartNormalDataset(Dataset):
self.classes = dict(zip(self.cat, range(len(self.cat)))) self.classes = dict(zip(self.cat, range(len(self.cat))))
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], self.seg_classes = {
'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], "Earphone": [16, 17, 18],
'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], "Motorbike": [30, 31, 32, 33, 34, 35],
'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], "Rocket": [41, 42, 43],
'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} "Car": [8, 9, 10, 11],
"Laptop": [28, 29],
"Cap": [6, 7],
"Skateboard": [44, 45, 46],
"Mug": [36, 37],
"Guitar": [19, 20, 21],
"Bag": [4, 5],
"Lamp": [24, 25, 26, 27],
"Table": [47, 48, 49],
"Airplane": [0, 1, 2, 3],
"Pistol": [38, 39, 40],
"Chair": [12, 13, 14, 15],
"Knife": [22, 23],
}
self.cache = {} # from index to (point_set, cls, seg) tuple self.cache = {} # from index to (point_set, cls, seg) tuple
self.cache_size = 20000 self.cache_size = 20000
@ -156,9 +169,9 @@ class PartNormalDataset(Dataset):
return len(self.datapath) return len(self.datapath)
if __name__ == '__main__': if __name__ == "__main__":
train = PartNormalDataset(npoints=2048, split='trainval', normalize=False) train = PartNormalDataset(npoints=2048, split="trainval", normalize=False)
test = PartNormalDataset(npoints=2048, split='test', normalize=False) test = PartNormalDataset(npoints=2048, split="test", normalize=False)
for data, label, _, _ in train: for data, label, _, _ in train:
print(data.shape) print(data.shape)
print(label.shape) print(label.shape)
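seg_classes maps each of the 16 categories to its slice of the 50 part labels. At evaluation time it is common to restrict each point's argmax to the parts of the known category; a small illustrative sketch of that convention (not code from this repository):
import numpy as np
seg_classes = {"Chair": [12, 13, 14, 15]}                        # excerpt of the mapping above
def restrict_to_category(point_logits, category):
    """point_logits: (num_points, 50). Return per-point part ids limited to the category's parts."""
    parts = seg_classes[category]
    best = np.argmax(point_logits[:, parts], axis=1)             # index within the restricted columns
    return np.asarray(parts)[best]                               # map back to global part ids
preds = restrict_to_category(np.random.rand(2048, 50), "Chair")  # values all lie in {12, 13, 14, 15}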
View file
@ -4,9 +4,8 @@ import torch.nn.functional as F
def cal_loss(pred, gold, smoothing=True): def cal_loss(pred, gold, smoothing=True):
''' Calculate cross entropy loss, apply label smoothing if needed. ''' """Calculate cross entropy loss, apply label smoothing if needed."""
gold = gold.contiguous().view(-1) # gold is the ground-truth label in the dataloader
gold = gold.contiguous().view(-1) # gold is the ground-truth label in the dataloader
if smoothing: if smoothing:
eps = 0.2 eps = 0.2
@ -18,19 +17,19 @@ def cal_loss(pred, gold, smoothing=True):
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gold, reduction='mean') loss = F.cross_entropy(pred, gold, reduction="mean")
return loss return loss
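With smoothing enabled, the hard one-hot target is softened to (1 - eps) on the true class and eps / (n_class - 1) on every other class before the negative log-likelihood is taken; without smoothing the function falls back to plain cross-entropy. The hunk hides the exact construction, but a standard smoothed target looks like this (eps = 0.2 and 50 part classes, as above):
import torch
eps, n_class = 0.2, 50
gold = torch.tensor([3])                                             # one ground-truth label
one_hot = torch.zeros(1, n_class).scatter_(1, gold.view(-1, 1), 1.0)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
print(one_hot[0, 3], one_hot[0, 0])                                  # tensor(0.8000) tensor(0.0041)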
# create a file and write the text into it: # create a file and write the text into it:
class IOStream(): class IOStream:
def __init__(self, path): def __init__(self, path):
self.f = open(path, 'a') self.f = open(path, "a")
def cprint(self, text): def cprint(self, text):
print(text) print(text)
self.f.write(text+'\n') self.f.write(text + "\n")
self.f.flush() self.f.flush()
def close(self): def close(self):
@ -38,22 +37,24 @@ class IOStream():
def to_categorical(y, num_classes): def to_categorical(y, num_classes):
""" 1-hot encodes a tensor """ """1-hot encodes a tensor."""
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
if (y.is_cuda): if y.is_cuda:
return new_y.cuda(non_blocking=True) return new_y.cuda(non_blocking=True)
return new_y return new_y
def compute_overall_iou(pred, target, num_classes): def compute_overall_iou(pred, target, num_classes):
shape_ious = [] shape_ious = []
pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample
pred_np = pred.cpu().data.numpy() pred_np = pred.cpu().data.numpy()
target_np = target.cpu().data.numpy() target_np = target.cpu().data.numpy()
for shape_idx in range(pred.size(0)): # sample_idx for shape_idx in range(pred.size(0)): # sample_idx
part_ious = [] part_ious = []
for part in range(num_classes): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes for part in range(
num_classes,
): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes
# for target, each point has a class no matter which category owns this point! also 50 classes!!! # for target, each point has a class no matter which category owns this point! also 50 classes!!!
# only return 1 when both belongs to this class, which means correct: # only return 1 when both belongs to this class, which means correct:
I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part)) I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
@ -63,7 +64,9 @@ def compute_overall_iou(pred, target, num_classes):
F = np.sum(target_np[shape_idx] == part) F = np.sum(target_np[shape_idx] == part)
if F != 0: if F != 0:
iou = I / float(U) # iou across all points for this class iou = I / float(U) # iou across all points for this class
part_ious.append(iou) # append the iou of this class part_ious.append(iou) # append the iou of this class
shape_ious.append(np.mean(part_ious)) # each time append an average iou across all classes of this sample (sample_level!) shape_ious.append(
return shape_ious # [batch_size] np.mean(part_ious),
) # each time append an average iou across all classes of this sample (sample_level!)
return shape_ious # [batch_size]
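Per sample, the IoU of a part class is I / U, where I counts the points both predicted and labelled as that class and U counts the points carrying it in either; the per-sample score is the mean over the classes considered, and the function returns one score per sample. A tiny worked example with made-up predictions:
import numpy as np
pred = np.array([0, 0, 1, 1, 1])
target = np.array([0, 1, 1, 1, 1])
for part in (0, 1):
    I = np.sum(np.logical_and(pred == part, target == part))
    U = np.sum(np.logical_or(pred == part, target == part))
    print(part, I / U)   # part 0: 1/2 = 0.5, part 1: 3/4 = 0.75 -> sample IoU = 0.625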
View file
@ -1,16 +1,15 @@
from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from pointnet2_ops import pointnet2_utils from pointnet2_ops import pointnet2_utils
def build_shared_mlp(mlp_spec: List[int], bn: bool = True): def build_shared_mlp(mlp_spec: list[int], bn: bool = True):
layers = [] layers = []
for i in range(1, len(mlp_spec)): for i in range(1, len(mlp_spec)):
layers.append( layers.append(
nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn),
) )
if bn: if bn:
layers.append(nn.BatchNorm2d(mlp_spec[i])) layers.append(nn.BatchNorm2d(mlp_spec[i]))
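build_shared_mlp turns a channel spec into a stack of 1x1 Conv2d layers, the "shared MLP" applied independently to every (centroid, neighbour) pair; the tail of the function, hidden by this hunk, presumably appends an activation after each BatchNorm and wraps the layers in nn.Sequential. A minimal sketch of that pattern under those assumptions (standalone, not the repository function):
import torch.nn as nn
def shared_mlp(spec, bn=True):
    layers = []
    for i in range(1, len(spec)):
        layers.append(nn.Conv2d(spec[i - 1], spec[i], kernel_size=1, bias=not bn))
        if bn:
            layers.append(nn.BatchNorm2d(spec[i]))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
print(shared_mlp([3, 64, 128]))   # Conv2d(3->64) -> BN -> ReLU -> Conv2d(64->128) -> BN -> ReLU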
@ -21,36 +20,37 @@ def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
class _PointnetSAModuleBase(nn.Module): class _PointnetSAModuleBase(nn.Module):
def __init__(self): def __init__(self):
super(_PointnetSAModuleBase, self).__init__() super().__init__()
self.npoint = None self.npoint = None
self.groupers = None self.groupers = None
self.mlps = None self.mlps = None
def forward( def forward(
self, xyz: torch.Tensor, features: Optional[torch.Tensor] self,
) -> Tuple[torch.Tensor, torch.Tensor]: xyz: torch.Tensor,
r""" features: torch.Tensor | None,
Parameters ) -> tuple[torch.Tensor, torch.Tensor]:
r"""Parameters
---------- ----------
xyz : torch.Tensor xyz : torch.Tensor
(B, N, 3) tensor of the xyz coordinates of the features (B, N, 3) tensor of the xyz coordinates of the features
features : torch.Tensor features : torch.Tensor
(B, C, N) tensor of the descriptors of the features (B, C, N) tensor of the descriptors of the features
Returns Returns:
------- -------
new_xyz : torch.Tensor new_xyz : torch.Tensor
(B, npoint, 3) tensor of the new features' xyz (B, npoint, 3) tensor of the new features' xyz
new_features : torch.Tensor new_features : torch.Tensor
(B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
""" """
new_features_list = [] new_features_list = []
xyz_flipped = xyz.transpose(1, 2).contiguous() xyz_flipped = xyz.transpose(1, 2).contiguous()
new_xyz = ( new_xyz = (
pointnet2_utils.gather_operation( pointnet2_utils.gather_operation(
xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint),
) )
.transpose(1, 2) .transpose(1, 2)
.contiguous() .contiguous()
@ -60,12 +60,15 @@ class _PointnetSAModuleBase(nn.Module):
for i in range(len(self.groupers)): for i in range(len(self.groupers)):
new_features = self.groupers[i]( new_features = self.groupers[i](
xyz, new_xyz, features xyz,
new_xyz,
features,
) # (B, C, npoint, nsample) ) # (B, C, npoint, nsample)
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
new_features = F.max_pool2d( new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)] new_features,
kernel_size=[1, new_features.size(3)],
) # (B, mlp[-1], npoint, 1) ) # (B, mlp[-1], npoint, 1)
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
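The set-abstraction forward pass is: furthest-point-sample npoint centroids, gather their coordinates, group neighbours around each centroid, run the shared MLP, then max-pool over the neighbour dimension. These wrappers need the compiled pointnet2_ops CUDA extension and CUDA tensors. A hedged usage sketch of the single-scale PointnetSAModule defined further down, assuming the import path pointnet2_ops.pointnet2_modules and assuming the hidden part of the constructor widens the first MLP layer by 3 when use_xyz is set (the usual upstream behaviour):
import torch
from pointnet2_ops.pointnet2_modules import PointnetSAModule   # assumed import path
sa = PointnetSAModule(npoint=512, radius=0.2, nsample=32, mlp=[0, 64, 64, 128]).cuda()
xyz = torch.rand(8, 4096, 3).cuda()          # (B, N, 3) point coordinates
new_xyz, new_features = sa(xyz, None)        # no extra descriptors: use_xyz supplies the 3 coords
print(new_xyz.shape, new_features.shape)     # expected: (8, 512, 3) and (8, 128, 512)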
@ -75,7 +78,7 @@ class _PointnetSAModuleBase(nn.Module):
class PointnetSAModuleMSG(_PointnetSAModuleBase): class PointnetSAModuleMSG(_PointnetSAModuleBase):
r"""Pointnet set abstraction layer with multiscale grouping r"""Pointnet set abstraction layer with multiscale grouping.
Parameters Parameters
---------- ----------
@ -93,7 +96,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
# type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
super(PointnetSAModuleMSG, self).__init__() super().__init__()
assert len(radii) == len(nsamples) == len(mlps) assert len(radii) == len(nsamples) == len(mlps)
@ -106,7 +109,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
self.groupers.append( self.groupers.append(
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
if npoint is not None if npoint is not None
else pointnet2_utils.GroupAll(use_xyz) else pointnet2_utils.GroupAll(use_xyz),
) )
mlp_spec = mlps[i] mlp_spec = mlps[i]
if use_xyz: if use_xyz:
@ -116,7 +119,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
class PointnetSAModule(PointnetSAModuleMSG): class PointnetSAModule(PointnetSAModuleMSG):
r"""Pointnet set abstraction layer r"""Pointnet set abstraction layer.
Parameters Parameters
---------- ----------
@ -133,10 +136,16 @@ class PointnetSAModule(PointnetSAModuleMSG):
""" """
def __init__( def __init__(
self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True self,
mlp,
npoint=None,
radius=None,
nsample=None,
bn=True,
use_xyz=True,
): ):
# type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
super(PointnetSAModule, self).__init__( super().__init__(
mlps=[mlp], mlps=[mlp],
npoint=npoint, npoint=npoint,
radii=[radius], radii=[radius],
@ -147,7 +156,7 @@ class PointnetSAModule(PointnetSAModuleMSG):
class PointnetFPModule(nn.Module): class PointnetFPModule(nn.Module):
r"""Propagates the features of one set to another r"""Propagates the features of one set to another.
Parameters Parameters
---------- ----------
@ -159,13 +168,12 @@ class PointnetFPModule(nn.Module):
def __init__(self, mlp, bn=True): def __init__(self, mlp, bn=True):
# type: (PointnetFPModule, List[int], bool) -> None # type: (PointnetFPModule, List[int], bool) -> None
super(PointnetFPModule, self).__init__() super().__init__()
self.mlp = build_shared_mlp(mlp, bn=bn) self.mlp = build_shared_mlp(mlp, bn=bn)
def forward(self, unknown, known, unknow_feats, known_feats): def forward(self, unknown, known, unknow_feats, known_feats):
# type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
r""" r"""Parameters
Parameters
---------- ----------
unknown : torch.Tensor unknown : torch.Tensor
(B, n, 3) tensor of the xyz positions of the unknown features (B, n, 3) tensor of the xyz positions of the unknown features
@ -176,12 +184,11 @@ class PointnetFPModule(nn.Module):
known_feats : torch.Tensor known_feats : torch.Tensor
(B, C2, m) tensor of features to be propagated (B, C2, m) tensor of features to be propagated
Returns Returns:
------- -------
new_features : torch.Tensor new_features : torch.Tensor
(B, mlp[-1], n) tensor of the features of the unknown features (B, mlp[-1], n) tensor of the features of the unknown features
""" """
if known is not None: if known is not None:
dist, idx = pointnet2_utils.three_nn(unknown, known) dist, idx = pointnet2_utils.three_nn(unknown, known)
dist_recip = 1.0 / (dist + 1e-8) dist_recip = 1.0 / (dist + 1e-8)
@ -189,16 +196,19 @@ class PointnetFPModule(nn.Module):
weight = dist_recip / norm weight = dist_recip / norm
interpolated_feats = pointnet2_utils.three_interpolate( interpolated_feats = pointnet2_utils.three_interpolate(
known_feats, idx, weight known_feats,
idx,
weight,
) )
else: else:
interpolated_feats = known_feats.expand( interpolated_feats = known_feats.expand(
*(known_feats.size()[0:2] + [unknown.size(1)]) *(known_feats.size()[0:2] + [unknown.size(1)]),
) )
if unknow_feats is not None: if unknow_feats is not None:
new_features = torch.cat( new_features = torch.cat(
[interpolated_feats, unknow_feats], dim=1 [interpolated_feats, unknow_feats],
dim=1,
) # (B, C2 + C1, n) ) # (B, C2 + C1, n)
else: else:
new_features = interpolated_feats new_features = interpolated_feats
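Feature propagation interpolates the m known features onto the n query points with inverse-distance weights over their three nearest neighbours: w_k = (1/d_k) / sum_j (1/d_j). A tiny sketch of the weight computation on made-up distances, mirroring the lines above:
import torch
dist = torch.tensor([[[0.1, 0.2, 0.4]]])                     # (B, n, 3) distances to the 3 nearest known points
dist_recip = 1.0 / (dist + 1e-8)
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
print(weight)                                                # ~[[[0.571, 0.286, 0.143]]], sums to 1 per point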
View file
@ -1,22 +1,24 @@
import warnings
from typing import *
import torch import torch
import torch.nn as nn import torch.nn as nn
import warnings
from torch.autograd import Function from torch.autograd import Function
from typing import *
try: try:
import pointnet2_ops._ext as _ext import pointnet2_ops._ext as _ext
except ImportError: except ImportError:
from torch.utils.cpp_extension import load
import glob import glob
import os.path as osp
import os import os
import os.path as osp
from torch.utils.cpp_extension import load
warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.") warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
_ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
osp.join(_ext_src_root, "src", "*.cu") osp.join(_ext_src_root, "src", "*.cu"),
) )
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
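When the prebuilt _ext module cannot be imported, this fallback collects the .cpp/.cu sources and JIT-compiles them at import time. The actual call sits outside the visible hunk; a minimal sketch of what such a build typically looks like with torch.utils.cpp_extension.load, reusing the names defined just above (the extension name here is hypothetical):
from torch.utils.cpp_extension import load
_ext = load(
    name="pointnet2_ext",                                        # hypothetical module name
    sources=_ext_sources,
    extra_include_paths=[osp.join(_ext_src_root, "include")],
    with_cuda=True,
)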
@ -35,9 +37,8 @@ class FurthestPointSampling(Function):
@staticmethod @staticmethod
def forward(ctx, xyz, npoint): def forward(ctx, xyz, npoint):
# type: (Any, torch.Tensor, int) -> torch.Tensor # type: (Any, torch.Tensor, int) -> torch.Tensor
r""" r"""Uses iterative furthest point sampling to select a set of npoint features that have the largest
Uses iterative furthest point sampling to select a set of npoint features that have the largest minimum distance.
minimum distance
Parameters Parameters
---------- ----------
@ -46,7 +47,7 @@ class FurthestPointSampling(Function):
npoint : int32 npoint : int32
number of features in the sampled set number of features in the sampled set
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, npoint) tensor containing the set (B, npoint) tensor containing the set
@ -69,9 +70,7 @@ class GatherOperation(Function):
@staticmethod @staticmethod
def forward(ctx, features, idx): def forward(ctx, features, idx):
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
r""" r"""Parameters
Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
(B, C, N) tensor (B, C, N) tensor
@ -79,12 +78,11 @@ class GatherOperation(Function):
idx : torch.Tensor idx : torch.Tensor
(B, npoint) tensor of the features to gather (B, npoint) tensor of the features to gather
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, C, npoint) tensor (B, C, npoint) tensor
""" """
ctx.save_for_backward(idx, features) ctx.save_for_backward(idx, features)
return _ext.gather_points(features, idx) return _ext.gather_points(features, idx)
@ -105,16 +103,15 @@ class ThreeNN(Function):
@staticmethod @staticmethod
def forward(ctx, unknown, known): def forward(ctx, unknown, known):
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
r""" r"""Find the three nearest neighbors of unknown in known
Find the three nearest neighbors of unknown in known
Parameters Parameters
---------- ----------
unknown : torch.Tensor unknown : torch.Tensor
(B, n, 3) tensor of the query (unknown) features (B, n, 3) tensor of the query (unknown) features
known : torch.Tensor known : torch.Tensor
(B, m, 3) tensor of the reference (known) features (B, m, 3) tensor of the reference (known) features.
Returns Returns:
------- -------
dist : torch.Tensor dist : torch.Tensor
(B, n, 3) l2 distance to the three nearest neighbors (B, n, 3) l2 distance to the three nearest neighbors
@ -140,8 +137,7 @@ class ThreeInterpolate(Function):
@staticmethod @staticmethod
def forward(ctx, features, idx, weight): def forward(ctx, features, idx, weight):
# type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
r""" r"""Performs weighted linear interpolation on 3 features
Performs weighted linear interpolation on 3 features
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
@ -149,9 +145,9 @@ class ThreeInterpolate(Function):
idx : torch.Tensor idx : torch.Tensor
(B, n, 3) three nearest neighbors of the target features in features (B, n, 3) three nearest neighbors of the target features in features
weight : torch.Tensor weight : torch.Tensor
(B, n, 3) weights (B, n, 3) weights.
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, c, n) tensor of the interpolated features (B, c, n) tensor of the interpolated features
@ -163,13 +159,12 @@ class ThreeInterpolate(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
# type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
r""" r"""Parameters
Parameters
---------- ----------
grad_out : torch.Tensor grad_out : torch.Tensor
(B, c, n) tensor with gradients of outputs (B, c, n) tensor with gradients of outputs
Returns Returns:
------- -------
grad_features : torch.Tensor grad_features : torch.Tensor
(B, c, m) tensor with gradients of features (B, c, m) tensor with gradients of features
@ -182,7 +177,10 @@ class ThreeInterpolate(Function):
m = features.size(2) m = features.size(2)
grad_features = _ext.three_interpolate_grad( grad_features = _ext.three_interpolate_grad(
grad_out.contiguous(), idx, weight, m grad_out.contiguous(),
idx,
weight,
m,
) )
return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
@ -195,16 +193,14 @@ class GroupingOperation(Function):
@staticmethod @staticmethod
def forward(ctx, features, idx): def forward(ctx, features, idx):
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
r""" r"""Parameters
Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
(B, C, N) tensor of features to group (B, C, N) tensor of features to group
idx : torch.Tensor idx : torch.Tensor
(B, npoint, nsample) tensor containing the indices of features to group with (B, npoint, nsample) tensor containing the indices of features to group with
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, C, npoint, nsample) tensor (B, C, npoint, nsample) tensor
@ -216,14 +212,12 @@ class GroupingOperation(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
# type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
r""" r"""Parameters
Parameters
---------- ----------
grad_out : torch.Tensor grad_out : torch.Tensor
(B, C, npoint, nsample) tensor of the gradients of the output from forward (B, C, npoint, nsample) tensor of the gradients of the output from forward
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, C, N) gradient of the features (B, C, N) gradient of the features
@ -244,9 +238,7 @@ class BallQuery(Function):
@staticmethod @staticmethod
def forward(ctx, radius, nsample, xyz, new_xyz): def forward(ctx, radius, nsample, xyz, new_xyz):
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
r""" r"""Parameters
Parameters
---------- ----------
radius : float radius : float
radius of the balls radius of the balls
@ -257,7 +249,7 @@ class BallQuery(Function):
new_xyz : torch.Tensor new_xyz : torch.Tensor
(B, npoint, 3) centers of the ball query (B, npoint, 3) centers of the ball query
Returns Returns:
------- -------
torch.Tensor torch.Tensor
(B, npoint, nsample) tensor with the indices of the features that form the query balls (B, npoint, nsample) tensor with the indices of the features that form the query balls
@ -277,8 +269,7 @@ ball_query = BallQuery.apply
class QueryAndGroup(nn.Module): class QueryAndGroup(nn.Module):
r""" r"""Groups with a ball query of radius.
Groups with a ball query of radius
Parameters Parameters
--------- ---------
@ -290,13 +281,12 @@ class QueryAndGroup(nn.Module):
def __init__(self, radius, nsample, use_xyz=True): def __init__(self, radius, nsample, use_xyz=True):
# type: (QueryAndGroup, float, int, bool) -> None # type: (QueryAndGroup, float, int, bool) -> None
super(QueryAndGroup, self).__init__() super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self, xyz, new_xyz, features=None): def forward(self, xyz, new_xyz, features=None):
# type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
r""" r"""Parameters
Parameters
---------- ----------
xyz : torch.Tensor xyz : torch.Tensor
xyz coordinates of the features (B, N, 3) xyz coordinates of the features (B, N, 3)
@ -305,12 +295,11 @@ class QueryAndGroup(nn.Module):
features : torch.Tensor features : torch.Tensor
Descriptors of the features (B, C, N) Descriptors of the features (B, C, N)
Returns Returns:
------- -------
new_features : torch.Tensor new_features : torch.Tensor
(B, 3 + C, npoint, nsample) tensor (B, 3 + C, npoint, nsample) tensor
""" """
idx = ball_query(self.radius, self.nsample, xyz, new_xyz) idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
xyz_trans = xyz.transpose(1, 2).contiguous() xyz_trans = xyz.transpose(1, 2).contiguous()
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
@ -320,22 +309,20 @@ class QueryAndGroup(nn.Module):
grouped_features = grouping_operation(features, idx) grouped_features = grouping_operation(features, idx)
if self.use_xyz: if self.use_xyz:
new_features = torch.cat( new_features = torch.cat(
[grouped_xyz, grouped_features], dim=1 [grouped_xyz, grouped_features],
dim=1,
) # (B, C + 3, npoint, nsample) ) # (B, C + 3, npoint, nsample)
else: else:
new_features = grouped_features new_features = grouped_features
else: else:
assert ( assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
self.use_xyz
), "Cannot have not features and not use xyz as a feature!"
new_features = grouped_xyz new_features = grouped_xyz
return new_features return new_features
class GroupAll(nn.Module): class GroupAll(nn.Module):
r""" r"""Groups all features.
Groups all features
Parameters Parameters
--------- ---------
@ -343,13 +330,12 @@ class GroupAll(nn.Module):
def __init__(self, use_xyz=True): def __init__(self, use_xyz=True):
# type: (GroupAll, bool) -> None # type: (GroupAll, bool) -> None
super(GroupAll, self).__init__() super().__init__()
self.use_xyz = use_xyz self.use_xyz = use_xyz
def forward(self, xyz, new_xyz, features=None): def forward(self, xyz, new_xyz, features=None):
# type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
r""" r"""Parameters
Parameters
---------- ----------
xyz : torch.Tensor xyz : torch.Tensor
xyz coordinates of the features (B, N, 3) xyz coordinates of the features (B, N, 3)
@ -358,18 +344,18 @@ class GroupAll(nn.Module):
features : torch.Tensor features : torch.Tensor
Descriptors of the features (B, C, N) Descriptors of the features (B, C, N)
Returns Returns:
------- -------
new_features : torch.Tensor new_features : torch.Tensor
(B, C + 3, 1, N) tensor (B, C + 3, 1, N) tensor
""" """
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
if features is not None: if features is not None:
grouped_features = features.unsqueeze(2) grouped_features = features.unsqueeze(2)
if self.use_xyz: if self.use_xyz:
new_features = torch.cat( new_features = torch.cat(
[grouped_xyz, grouped_features], dim=1 [grouped_xyz, grouped_features],
dim=1,
) # (B, 3 + C, 1, N) ) # (B, 3 + C, 1, N)
else: else:
new_features = grouped_features new_features = grouped_features
View file
@ -8,7 +8,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
this_dir = osp.dirname(osp.abspath(__file__)) this_dir = osp.dirname(osp.abspath(__file__))
_ext_src_root = osp.join("pointnet2_ops", "_ext-src") _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
osp.join(_ext_src_root, "src", "*.cu") osp.join(_ext_src_root, "src", "*.cu"),
) )
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
@ -32,7 +32,7 @@ setup(
"nvcc": ["-O3", "-Xfatbin", "-compress-all"], "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
}, },
include_dirs=[osp.join(this_dir, _ext_src_root, "include")], include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
) ),
], ],
cmdclass={"build_ext": BuildExtension}, cmdclass={"build_ext": BuildExtension},
include_package_data=True, include_package_data=True,
64
pyproject.toml Normal file
View file
@ -0,0 +1,64 @@
[tool.ruff]
line-length = 120
ignore-init-module-imports = true
ignore = [
"G004", # Logging statement uses f-string
"EM102", # Exception must not use an f-string literal, assign to variable first
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"N812", # Lowercase imported as non lowercase
]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"C90", # mccabe
"COM", # flake8-commas
"D", # pydocstyle
"EM", # flake8-errmsg
"E", # pycodestyle errors
"F", # Pyflakes
"G", # flake8-logging-format
"I", # isort
"N", # pep8-naming
"PIE", # flake8-pie
"PTH", # flake8-use-pathlib
"TD", # flake8-todo
"FIX", # flake8-fixme
"RET", # flake8-return
"RUF", # ruff
"S", # flake8-bandit
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warnings
]
[tool.ruff.pydocstyle]
convention = "google"
[tool.ruff.isort]
known-first-party = ["pointnet2_ops"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"src/aube/main.py" = ["E402", "F401"]
[tool.black]
exclude = '''
/(
\.git
\.venv
)/
'''
include = '\.pyi?$'
line-length = 120
target-version = ["py310"]
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
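The new pyproject.toml centralises the tooling: Ruff with a broad rule selection (including pathlib and docstring checks), Black and isort at 120 characters, Google-style docstrings, pointnet2_ops treated as first-party for import sorting, and Python 3.10 as the Black/mypy target. A short, purely illustrative snippet written to those conventions:
from pathlib import Path
def count_checkpoints(directory: Path) -> int:
    """Count checkpoint files in a directory.

    Args:
        directory: Folder holding ".pth" checkpoint files.

    Returns:
        Number of checkpoints found.
    """
    return len(list(directory.glob("*.pth")))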
View file
@ -1,12 +0,0 @@
torch
torchvision
cudatoolkit
cycler
einops
h5py
matplotlib==3.4.2
pytorch
pyyaml==5.4.1
scikit-learn==0.24.2
scipy
tqdm