Compare commits
No commits in common. "4450cad9a31feb5477e914d1283350c5d5de9ccd" and "3e3d80cff5c23a631fe5ba4ca97db3f452893ed2" have entirely different histories.
4450cad9a3
...
3e3d80cff5
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,6 +1,3 @@
|
||||||
data
|
|
||||||
checkpoints
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|
9
.vscode/extensions.json
vendored
9
.vscode/extensions.json
vendored
|
@ -1,9 +0,0 @@
|
||||||
{
|
|
||||||
"recommendations": [
|
|
||||||
"editorconfig.editorconfig",
|
|
||||||
"eamodio.gitlens",
|
|
||||||
"ms-python.python",
|
|
||||||
"ms-python.black-formatter",
|
|
||||||
"charliermarsh.ruff",
|
|
||||||
]
|
|
||||||
}
|
|
17
.vscode/launch.json
vendored
17
.vscode/launch.json
vendored
|
@ -1,17 +0,0 @@
|
||||||
{
|
|
||||||
"version": "0.2.0",
|
|
||||||
"configurations": [
|
|
||||||
{
|
|
||||||
"name": "Python: Current File",
|
|
||||||
"type": "python",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "${file}",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"justMyCode": false,
|
|
||||||
"env": {
|
|
||||||
"OMP_NUM_THREADS": "1",
|
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
67
.vscode/settings.json
vendored
67
.vscode/settings.json
vendored
|
@ -1,67 +0,0 @@
|
||||||
{
|
|
||||||
// nice editor settings
|
|
||||||
"editor.formatOnSave": true,
|
|
||||||
"editor.formatOnPaste": true,
|
|
||||||
"editor.codeActionsOnSave": {
|
|
||||||
"source.organizeImports": true,
|
|
||||||
"source.fixAll": false,
|
|
||||||
},
|
|
||||||
"editor.rulers": [
|
|
||||||
120
|
|
||||||
],
|
|
||||||
// editorconfig redundancy
|
|
||||||
"files.insertFinalNewline": true,
|
|
||||||
"files.trimTrailingWhitespace": true,
|
|
||||||
// hidde unimportant files/folders
|
|
||||||
"files.exclude": {
|
|
||||||
// defaults
|
|
||||||
"**/.git": true,
|
|
||||||
"**/.svn": true,
|
|
||||||
"**/.hg": true,
|
|
||||||
"**/CVS": true,
|
|
||||||
"**/.DS_Store": true,
|
|
||||||
"**/Thumbs.db": true,
|
|
||||||
// annoying
|
|
||||||
"**/__pycache__": true,
|
|
||||||
"**/.mypy_cache": true,
|
|
||||||
"**/.ruff_cache": true,
|
|
||||||
"**/*.tmp": true,
|
|
||||||
},
|
|
||||||
// cpp /clang / cmake settings
|
|
||||||
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools",
|
|
||||||
"C_Cpp.intelliSenseEngine": "disabled",
|
|
||||||
"C_Cpp.intelliSenseEngineFallback": "enabled",
|
|
||||||
"C_Cpp.clang_format_path": "/softs/compiler/llvm/latest/bin/clang-format",
|
|
||||||
"C_Cpp.codeAnalysis.clangTidy.enabled": true,
|
|
||||||
"C_Cpp.codeAnalysis.clangTidy.path": "/softs/compiler/llvm/latest/bin/clang-tidy",
|
|
||||||
"clangd.path": "/softs/compiler/llvm/latest/bin/clangd",
|
|
||||||
"cmake.cmakePath": "/softs/cmake/latest/bin/cmake",
|
|
||||||
"cmake.preferredGenerators": [
|
|
||||||
"Ninja",
|
|
||||||
"Unix Makefiles"
|
|
||||||
],
|
|
||||||
"cmakeFormat.exePath": "/softs/conda/auto/envs/cmake-format/bin/cmake-format",
|
|
||||||
"cmake.languageSupport.dotnetPath": "/softs/conda/auto/envs/dotnet/lib/dotnet/dotnet",
|
|
||||||
// python settings
|
|
||||||
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
|
|
||||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pointmlp/bin/python",
|
|
||||||
"python.linting.enabled": true,
|
|
||||||
"python.linting.lintOnSave": true,
|
|
||||||
"python.linting.mypyEnabled": true,
|
|
||||||
// fixes for broken auto-activation on rosetta
|
|
||||||
"python.terminal.activateEnvironment": false,
|
|
||||||
"terminal.integrated.profiles.linux": {
|
|
||||||
"python": {
|
|
||||||
"path": "bash",
|
|
||||||
"icon": "rocket",
|
|
||||||
"args": [
|
|
||||||
"--init-file",
|
|
||||||
".vscode/setup.sh"
|
|
||||||
],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"terminal.integrated.env.linux": {
|
|
||||||
"PYTHONPATH": "${workspaceFolder}/src/",
|
|
||||||
"SLURM_JOB_ID": null, // unset or else lightning_logs v_num uses it
|
|
||||||
},
|
|
||||||
}
|
|
11
.vscode/setup.sh
vendored
11
.vscode/setup.sh
vendored
|
@ -1,11 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source ~/.bashrc
|
|
||||||
|
|
||||||
conda_init
|
|
||||||
conda activate pointmlp
|
|
||||||
|
|
||||||
export PS1="(pointmlp)[\\u@\\h \\W]\\$ "
|
|
||||||
|
|
||||||
module load compilers
|
|
||||||
module load mpfr
|
|
59
README.md
59
README.md
|
@ -1,40 +1,48 @@
|
||||||
# Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework (ICLR 2022)
|
# Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework (ICLR 2022)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=rethinking-network-design-and-local-geometry-1)
|
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=rethinking-network-design-and-local-geometry-1)
|
||||||
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=rethinking-network-design-and-local-geometry-1)
|
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rethinking-network-design-and-local-geometry-1/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=rethinking-network-design-and-local-geometry-1)
|
||||||
|
|
||||||
|
|
||||||
[![github](https://img.shields.io/github/stars/ma-xu/pointMLP-pytorch?style=social)](https://github.com/ma-xu/pointMLP-pytorch)
|
[![github](https://img.shields.io/github/stars/ma-xu/pointMLP-pytorch?style=social)](https://github.com/ma-xu/pointMLP-pytorch)
|
||||||
|
|
||||||
<img src="images/smile.png" height="70px">
|
|
||||||
<img src="images/neu.png" height="70px">
|
|
||||||
<img src="images/columbia.png" height="70px">
|
|
||||||
|
|
||||||
[open review](https://openreview.net/forum?id=3Pbra-_u76D) | [arXiv](https://arxiv.org/abs/2202.07123) | Primary contact: [Xu Ma](mailto:ma.xu1@northeastern.edu)
|
<div align="left">
|
||||||
|
<a><img src="images/smile.png" height="70px" ></a>
|
||||||
|
<a><img src="images/neu.png" height="70px" ></a>
|
||||||
|
<a><img src="images/columbia.png" height="70px" ></a>
|
||||||
|
</div>
|
||||||
|
|
||||||
![](images/overview.png)
|
[open review](https://openreview.net/forum?id=3Pbra-_u76D) | [arXiv](https://arxiv.org/abs/2202.07123) | Primary contact: [Xu Ma](mailto:ma.xu1@northeastern.edu)
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="images/overview.png" width="650px" height="300px">
|
||||||
|
</div>
|
||||||
|
|
||||||
Overview of one stage in PointMLP. Given an input point cloud, PointMLP progressively extracts local features using residual point MLP blocks. In each stage, we first transform the local point using a geometric affine module, and then local points are extracted before and after aggregation, respectively. By repeating multiple stages, PointMLP progressively enlarges the receptive field and models entire point cloud geometric information.
|
Overview of one stage in PointMLP. Given an input point cloud, PointMLP progressively extracts local features using residual point MLP blocks. In each stage, we first transform the local point using a geometric affine module, and then local points are extracted before and after aggregation, respectively. By repeating multiple stages, PointMLP progressively enlarges the receptive field and models entire point cloud geometric information.
|
||||||
|
|
||||||
|
|
||||||
## BibTeX
|
## BibTeX
|
||||||
|
|
||||||
```bibtex
|
@article{ma2022rethinking,
|
||||||
@article{ma2022rethinking,
|
title={Rethinking network design and local geometry in point cloud: A simple residual MLP framework},
|
||||||
title={Rethinking network design and local geometry in point cloud: A simple residual MLP framework},
|
author={Ma, Xu and Qin, Can and You, Haoxuan and Ran, Haoxi and Fu, Yun},
|
||||||
author={Ma, Xu and Qin, Can and You, Haoxuan and Ran, Haoxi and Fu, Yun},
|
journal={arXiv preprint arXiv:2202.07123},
|
||||||
journal={arXiv preprint arXiv:2202.07123},
|
year={2022}
|
||||||
year={2022}
|
}
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Model Zoo
|
## Model Zoo
|
||||||
|
|
||||||
**Questions on ModelNet40 classification results (a common issue for ModelNet40 dataset in the community)**
|
**Questions on ModelNet40 classification results (a common issue for ModelNet40 dataset in the community)**
|
||||||
|
|
||||||
|
The performance on ModelNet40 of almost all methods are not stable, see (https://github.com/CVMI-Lab/PAConv/issues/9#issuecomment-873371422).<br>
|
||||||
|
If you run the same codes for several times, you will get different results (even with fixed seed).<br>
|
||||||
|
The best way to reproduce the results is to test with a pretrained model for ModelNet40. <br>
|
||||||
|
Also, the randomness of ModelNet40 is our motivation to experiment on ScanObjectNN, and to report the mean/std results of several runs.
|
||||||
|
|
||||||
|
|
||||||
The performance on ModelNet40 of almost all methods are not stable, see (https://github.com/CVMI-Lab/PAConv/issues/9#issuecomment-873371422).<br>
|
|
||||||
If you run the same codes for several times, you will get different results (even with fixed seed).<br>
|
|
||||||
The best way to reproduce the results is to test with a pretrained model for ModelNet40. <br>
|
|
||||||
Also, the randomness of ModelNet40 is our motivation to experiment on ScanObjectNN, and to report the mean/std results of several runs.
|
|
||||||
|
|
||||||
------
|
------
|
||||||
|
|
||||||
|
@ -46,6 +54,8 @@ On ScanObjectNN, fixed pointMLP achieves a result of **84.4% mAcc** and **86.1%
|
||||||
|
|
||||||
Stay tuned. More elite versions and voting results will be uploaded.
|
Stay tuned. More elite versions and voting results will be uploaded.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## News & Updates:
|
## News & Updates:
|
||||||
|
|
||||||
- [x] fix the uncomplete utils in partseg by Mar/10, caused by error uplaoded folder.
|
- [x] fix the uncomplete utils in partseg by Mar/10, caused by error uplaoded folder.
|
||||||
|
@ -56,6 +66,9 @@ Stay tuned. More elite versions and voting results will be uploaded.
|
||||||
|
|
||||||
:point_right::point_right::point_right:**NOTE:** The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026).
|
:point_right::point_right::point_right:**NOTE:** The codes/models/logs for submission version (without bug fixed) can be found here [commit:d2b8dbaa](http://github.com/13952522076/pointMLP-pytorch/tree/d2b8dbaa06eb6176b222dcf2ad248f8438582026).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -79,7 +92,8 @@ pip install cycler einops h5py pyyaml==5.4.1 scikit-learn==0.24.2 scipy tqdm mat
|
||||||
pip install pointnet2_ops_lib/.
|
pip install pointnet2_ops_lib/.
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
## Useage
|
||||||
|
|
||||||
### Classification ModelNet40
|
### Classification ModelNet40
|
||||||
**Train**: The dataset will be automatically downloaded, run following command to train.
|
**Train**: The dataset will be automatically downloaded, run following command to train.
|
||||||
|
@ -147,5 +161,10 @@ Our implementation is mainly based on the following codebases. We gratefully tha
|
||||||
[Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
|
[Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
|
||||||
|
|
||||||
## LICENSE
|
## LICENSE
|
||||||
|
|
||||||
PointMLP is under the Apache-2.0 license.
|
PointMLP is under the Apache-2.0 license.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
11
analysis.py
11
analysis.py
|
@ -1,21 +1,20 @@
|
||||||
import fvcore.common
|
|
||||||
import fvcore.nn
|
|
||||||
import torch
|
import torch
|
||||||
|
import fvcore.nn
|
||||||
|
import fvcore.common
|
||||||
from fvcore.nn import FlopCountAnalysis
|
from fvcore.nn import FlopCountAnalysis
|
||||||
|
|
||||||
from classification_ScanObjectNN.models import pointMLPElite
|
from classification_ScanObjectNN.models import pointMLPElite
|
||||||
|
|
||||||
model = pointMLPElite()
|
model = pointMLPElite()
|
||||||
model.eval()
|
model.eval()
|
||||||
# model = deit_tiny_patch16_224()
|
# model = deit_tiny_patch16_224()
|
||||||
|
|
||||||
inputs = torch.randn((1, 3, 1024))
|
inputs = (torch.randn((1,3,1024)))
|
||||||
k = 1024.0
|
k = 1024.0
|
||||||
flops = FlopCountAnalysis(model, inputs).total()
|
flops = FlopCountAnalysis(model, inputs).total()
|
||||||
print(f"Flops : {flops}")
|
print(f"Flops : {flops}")
|
||||||
flops = flops / (k**3)
|
flops = flops/(k**3)
|
||||||
print(f"Flops : {flops:.1f}G")
|
print(f"Flops : {flops:.1f}G")
|
||||||
params = fvcore.nn.parameter_count(model)[""]
|
params = fvcore.nn.parameter_count(model)[""]
|
||||||
print(f"Params : {params}")
|
print(f"Params : {params}")
|
||||||
params = params / (k**2)
|
params = params/(k**2)
|
||||||
print(f"Params : {params:.1f}M")
|
print(f"Params : {params:.1f}M")
|
||||||
|
|
|
@ -1,37 +1,33 @@
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
|
import glob
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
|
|
||||||
def download():
|
def download():
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||||
if not os.path.exists(DATA_DIR):
|
if not os.path.exists(DATA_DIR):
|
||||||
os.mkdir(DATA_DIR)
|
os.mkdir(DATA_DIR)
|
||||||
if not os.path.exists(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048")):
|
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
|
||||||
www = "https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip"
|
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
|
||||||
zipfile = os.path.basename(www)
|
zipfile = os.path.basename(www)
|
||||||
os.system(f"wget {www} --no-check-certificate; unzip {zipfile}")
|
os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
|
||||||
os.system(f"mv {zipfile[:-4]} {DATA_DIR}")
|
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
|
||||||
os.system("rm %s" % (zipfile))
|
os.system('rm %s' % (zipfile))
|
||||||
|
|
||||||
|
|
||||||
def load_data(partition):
|
def load_data(partition):
|
||||||
download()
|
download()
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||||
all_data = []
|
all_data = []
|
||||||
all_label = []
|
all_label = []
|
||||||
for h5_name in glob.glob(os.path.join(DATA_DIR, "modelnet40_ply_hdf5_2048", "ply_data_%s*.h5" % partition)):
|
for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)):
|
||||||
# print(f"h5_name: {h5_name}")
|
# print(f"h5_name: {h5_name}")
|
||||||
f = h5py.File(h5_name, "r")
|
f = h5py.File(h5_name,'r')
|
||||||
data = f["data"][:].astype("float32")
|
data = f['data'][:].astype('float32')
|
||||||
label = f["label"][:].astype("int64")
|
label = f['label'][:].astype('int64')
|
||||||
f.close()
|
f.close()
|
||||||
all_data.append(data)
|
all_data.append(data)
|
||||||
all_label.append(label)
|
all_label.append(label)
|
||||||
|
@ -39,42 +35,40 @@ def load_data(partition):
|
||||||
all_label = np.concatenate(all_label, axis=0)
|
all_label = np.concatenate(all_label, axis=0)
|
||||||
return all_data, all_label
|
return all_data, all_label
|
||||||
|
|
||||||
|
|
||||||
def random_point_dropout(pc, max_dropout_ratio=0.875):
|
def random_point_dropout(pc, max_dropout_ratio=0.875):
|
||||||
"""batch_pc: BxNx3."""
|
''' batch_pc: BxNx3 '''
|
||||||
# for b in range(batch_pc.shape[0]):
|
# for b in range(batch_pc.shape[0]):
|
||||||
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
|
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
|
||||||
drop_idx = np.where(np.random.random(pc.shape[0]) <= dropout_ratio)[0]
|
drop_idx = np.where(np.random.random((pc.shape[0]))<=dropout_ratio)[0]
|
||||||
# print ('use random drop', len(drop_idx))
|
# print ('use random drop', len(drop_idx))
|
||||||
|
|
||||||
if len(drop_idx) > 0:
|
if len(drop_idx)>0:
|
||||||
pc[drop_idx, :] = pc[0, :] # set to the first point
|
pc[drop_idx,:] = pc[0,:] # set to the first point
|
||||||
return pc
|
return pc
|
||||||
|
|
||||||
|
|
||||||
def translate_pointcloud(pointcloud):
|
def translate_pointcloud(pointcloud):
|
||||||
xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
|
xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
|
||||||
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
||||||
|
|
||||||
return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
|
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
|
||||||
|
return translated_pointcloud
|
||||||
|
|
||||||
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
|
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
|
||||||
N, C = pointcloud.shape
|
N, C = pointcloud.shape
|
||||||
pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
|
pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip)
|
||||||
return pointcloud
|
return pointcloud
|
||||||
|
|
||||||
|
|
||||||
class ModelNet40(Dataset):
|
class ModelNet40(Dataset):
|
||||||
def __init__(self, num_points, partition="train"):
|
def __init__(self, num_points, partition='train'):
|
||||||
self.data, self.label = load_data(partition)
|
self.data, self.label = load_data(partition)
|
||||||
self.num_points = num_points
|
self.num_points = num_points
|
||||||
self.partition = partition
|
self.partition = partition
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
pointcloud = self.data[item][: self.num_points]
|
pointcloud = self.data[item][:self.num_points]
|
||||||
label = self.label[item]
|
label = self.label[item]
|
||||||
if self.partition == "train":
|
if self.partition == 'train':
|
||||||
# pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all
|
# pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all
|
||||||
pointcloud = translate_pointcloud(pointcloud)
|
pointcloud = translate_pointcloud(pointcloud)
|
||||||
np.random.shuffle(pointcloud)
|
np.random.shuffle(pointcloud)
|
||||||
|
@ -84,25 +78,19 @@ class ModelNet40(Dataset):
|
||||||
return self.data.shape[0]
|
return self.data.shape[0]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train = ModelNet40(1024)
|
train = ModelNet40(1024)
|
||||||
test = ModelNet40(1024, "test")
|
test = ModelNet40(1024, 'test')
|
||||||
# for data, label in train:
|
# for data, label in train:
|
||||||
# print(data.shape)
|
# print(data.shape)
|
||||||
# print(label.shape)
|
# print(label.shape)
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
train_loader = DataLoader(ModelNet40(partition='train', num_points=1024), num_workers=4,
|
||||||
train_loader = DataLoader(
|
batch_size=32, shuffle=True, drop_last=True)
|
||||||
ModelNet40(partition="train", num_points=1024),
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=32,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
for batch_idx, (data, label) in enumerate(train_loader):
|
for batch_idx, (data, label) in enumerate(train_loader):
|
||||||
print(f"batch_idx: {batch_idx} | data shape: {data.shape} | ;lable shape: {label.shape}")
|
print(f"batch_idx: {batch_idx} | data shape: {data.shape} | ;lable shape: {label.shape}")
|
||||||
|
|
||||||
train_set = ModelNet40(partition="train", num_points=1024)
|
train_set = ModelNet40(partition='train', num_points=1024)
|
||||||
test_set = ModelNet40(partition="test", num_points=1024)
|
test_set = ModelNet40(partition='test', num_points=1024)
|
||||||
print(f"train_set size {train_set.__len__()}")
|
print(f"train_set size {train_set.__len__()}")
|
||||||
print(f"test_set size {test_set.__len__()}")
|
print(f"test_set size {test_set.__len__()}")
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def cal_loss(pred, gold, smoothing=True):
|
def cal_loss(pred, gold, smoothing=True):
|
||||||
"""Calculate cross entropy loss, apply label smoothing if needed."""
|
''' Calculate cross entropy loss, apply label smoothing if needed. '''
|
||||||
|
|
||||||
gold = gold.contiguous().view(-1)
|
gold = gold.contiguous().view(-1)
|
||||||
|
|
||||||
if smoothing:
|
if smoothing:
|
||||||
|
@ -16,6 +16,6 @@ def cal_loss(pred, gold, smoothing=True):
|
||||||
|
|
||||||
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
||||||
else:
|
else:
|
||||||
loss = F.cross_entropy(pred, gold, reduction="mean")
|
loss = F.cross_entropy(pred, gold, reduction='mean')
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -1,46 +1,41 @@
|
||||||
"""Usage:
|
"""
|
||||||
python main.py --model PointMLP --msg demo.
|
Usage:
|
||||||
|
python main.py --model PointMLP --msg demo
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
import models as models
|
import datetime
|
||||||
import numpy as np
|
|
||||||
import sklearn.metrics as metrics
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.utils.data.distributed
|
import torch.utils.data.distributed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import models as models
|
||||||
|
from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
|
||||||
from data import ModelNet40
|
from data import ModelNet40
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from torch.utils.data import DataLoader
|
import sklearn.metrics as metrics
|
||||||
from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parameters"""
|
"""Parameters"""
|
||||||
parser = argparse.ArgumentParser("training")
|
parser = argparse.ArgumentParser('training')
|
||||||
parser.add_argument(
|
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
|
||||||
"-c",
|
help='path to save checkpoint (default: checkpoint)')
|
||||||
"--checkpoint",
|
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||||
type=str,
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
||||||
metavar="PATH",
|
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
|
||||||
help="path to save checkpoint (default: checkpoint)",
|
parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training')
|
||||||
)
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument("--msg", type=str, help="message after checkpoint")
|
parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training')
|
||||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
|
parser.add_argument('--min_lr', default=0.005, type=float, help='min lr')
|
||||||
parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]")
|
parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate')
|
||||||
parser.add_argument("--epoch", default=300, type=int, help="number of epoch in training")
|
parser.add_argument('--seed', type=int, help='random seed')
|
||||||
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
|
parser.add_argument('--workers', default=8, type=int, help='workers')
|
||||||
parser.add_argument("--learning_rate", default=0.1, type=float, help="learning rate in training")
|
|
||||||
parser.add_argument("--min_lr", default=0.005, type=float, help="min lr")
|
|
||||||
parser.add_argument("--weight_decay", type=float, default=2e-4, help="decay rate")
|
|
||||||
parser.add_argument("--seed", type=int, help="random seed")
|
|
||||||
parser.add_argument("--workers", default=8, type=int, help="workers")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +46,7 @@ def main():
|
||||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
assert torch.cuda.is_available(), "Please ensure codes are executed in cuda."
|
assert torch.cuda.is_available(), "Please ensure codes are executed in cuda."
|
||||||
device = "cuda"
|
device = 'cuda'
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
@ -60,19 +55,19 @@ def main():
|
||||||
torch.set_printoptions(10)
|
torch.set_printoptions(10)
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
os.environ["PYTHONHASHSEED"] = str(args.seed)
|
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||||
time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
|
time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
if args.msg is None:
|
if args.msg is None:
|
||||||
message = time_str
|
message = time_str
|
||||||
else:
|
else:
|
||||||
message = "-" + args.msg
|
message = "-" + args.msg
|
||||||
args.checkpoint = "checkpoints/" + args.model + message + "-" + str(args.seed)
|
args.checkpoint = 'checkpoints/' + args.model + message + '-' + str(args.seed)
|
||||||
if not os.path.isdir(args.checkpoint):
|
if not os.path.isdir(args.checkpoint):
|
||||||
mkdir_p(args.checkpoint)
|
mkdir_p(args.checkpoint)
|
||||||
|
|
||||||
screen_logger = logging.getLogger("Model")
|
screen_logger = logging.getLogger("Model")
|
||||||
screen_logger.setLevel(logging.INFO)
|
screen_logger.setLevel(logging.INFO)
|
||||||
formatter = logging.Formatter("%(message)s")
|
formatter = logging.Formatter('%(message)s')
|
||||||
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
|
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
|
||||||
file_handler.setLevel(logging.INFO)
|
file_handler.setLevel(logging.INFO)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
@ -84,19 +79,19 @@ def main():
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
printf(f"args: {args}")
|
printf(f"args: {args}")
|
||||||
printf("==> Building model..")
|
printf('==> Building model..')
|
||||||
net = models.__dict__[args.model]()
|
net = models.__dict__[args.model]()
|
||||||
criterion = cal_loss
|
criterion = cal_loss
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
# criterion = criterion.to(device)
|
# criterion = criterion.to(device)
|
||||||
if device == "cuda":
|
if device == 'cuda':
|
||||||
net = torch.nn.DataParallel(net)
|
net = torch.nn.DataParallel(net)
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
best_test_acc = 0.0 # best test accuracy
|
best_test_acc = 0. # best test accuracy
|
||||||
best_train_acc = 0.0
|
best_train_acc = 0.
|
||||||
best_test_acc_avg = 0.0
|
best_test_acc_avg = 0.
|
||||||
best_train_acc_avg = 0.0
|
best_train_acc_avg = 0.
|
||||||
best_test_loss = float("inf")
|
best_test_loss = float("inf")
|
||||||
best_train_loss = float("inf")
|
best_train_loss = float("inf")
|
||||||
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
||||||
|
@ -104,49 +99,30 @@ def main():
|
||||||
|
|
||||||
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
||||||
save_args(args)
|
save_args(args)
|
||||||
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
|
||||||
logger.set_names(
|
logger.set_names(["Epoch-Num", 'Learning-Rate',
|
||||||
[
|
'Train-Loss', 'Train-acc-B', 'Train-acc',
|
||||||
"Epoch-Num",
|
'Valid-Loss', 'Valid-acc-B', 'Valid-acc'])
|
||||||
"Learning-Rate",
|
|
||||||
"Train-Loss",
|
|
||||||
"Train-acc-B",
|
|
||||||
"Train-acc",
|
|
||||||
"Valid-Loss",
|
|
||||||
"Valid-acc-B",
|
|
||||||
"Valid-acc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
printf(f"Resuming last checkpoint from {args.checkpoint}")
|
printf(f"Resuming last checkpoint from {args.checkpoint}")
|
||||||
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
|
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
|
||||||
checkpoint = torch.load(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
net.load_state_dict(checkpoint["net"])
|
net.load_state_dict(checkpoint['net'])
|
||||||
start_epoch = checkpoint["epoch"]
|
start_epoch = checkpoint['epoch']
|
||||||
best_test_acc = checkpoint["best_test_acc"]
|
best_test_acc = checkpoint['best_test_acc']
|
||||||
best_train_acc = checkpoint["best_train_acc"]
|
best_train_acc = checkpoint['best_train_acc']
|
||||||
best_test_acc_avg = checkpoint["best_test_acc_avg"]
|
best_test_acc_avg = checkpoint['best_test_acc_avg']
|
||||||
best_train_acc_avg = checkpoint["best_train_acc_avg"]
|
best_train_acc_avg = checkpoint['best_train_acc_avg']
|
||||||
best_test_loss = checkpoint["best_test_loss"]
|
best_test_loss = checkpoint['best_test_loss']
|
||||||
best_train_loss = checkpoint["best_train_loss"]
|
best_train_loss = checkpoint['best_train_loss']
|
||||||
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
|
||||||
optimizer_dict = checkpoint["optimizer"]
|
optimizer_dict = checkpoint['optimizer']
|
||||||
|
|
||||||
printf("==> Preparing data..")
|
printf('==> Preparing data..')
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=args.workers,
|
||||||
ModelNet40(partition="train", num_points=args.num_points),
|
batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||||
num_workers=args.workers,
|
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=args.workers,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
test_loader = DataLoader(
|
|
||||||
ModelNet40(partition="test", num_points=args.num_points),
|
|
||||||
num_workers=args.workers,
|
|
||||||
batch_size=args.batch_size // 2,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
||||||
if optimizer_dict is not None:
|
if optimizer_dict is not None:
|
||||||
|
@ -154,7 +130,7 @@ def main():
|
||||||
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1)
|
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.min_lr, last_epoch=start_epoch - 1)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epoch):
|
for epoch in range(start_epoch, args.epoch):
|
||||||
printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"]))
|
printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
|
||||||
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
|
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
|
||||||
test_out = validate(net, test_loader, criterion, device)
|
test_out = validate(net, test_loader, criterion, device)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
@ -173,46 +149,31 @@ def main():
|
||||||
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
||||||
|
|
||||||
save_model(
|
save_model(
|
||||||
net,
|
net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
|
||||||
epoch,
|
|
||||||
path=args.checkpoint,
|
|
||||||
acc=test_out["acc"],
|
|
||||||
is_best=is_best,
|
|
||||||
best_test_acc=best_test_acc, # best test accuracy
|
best_test_acc=best_test_acc, # best test accuracy
|
||||||
best_train_acc=best_train_acc,
|
best_train_acc=best_train_acc,
|
||||||
best_test_acc_avg=best_test_acc_avg,
|
best_test_acc_avg=best_test_acc_avg,
|
||||||
best_train_acc_avg=best_train_acc_avg,
|
best_train_acc_avg=best_train_acc_avg,
|
||||||
best_test_loss=best_test_loss,
|
best_test_loss=best_test_loss,
|
||||||
best_train_loss=best_train_loss,
|
best_train_loss=best_train_loss,
|
||||||
optimizer=optimizer.state_dict(),
|
optimizer=optimizer.state_dict()
|
||||||
)
|
|
||||||
logger.append(
|
|
||||||
[
|
|
||||||
epoch,
|
|
||||||
optimizer.param_groups[0]["lr"],
|
|
||||||
train_out["loss"],
|
|
||||||
train_out["acc_avg"],
|
|
||||||
train_out["acc"],
|
|
||||||
test_out["loss"],
|
|
||||||
test_out["acc_avg"],
|
|
||||||
test_out["acc"],
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
logger.append([epoch, optimizer.param_groups[0]['lr'],
|
||||||
|
train_out["loss"], train_out["acc_avg"], train_out["acc"],
|
||||||
|
test_out["loss"], test_out["acc_avg"], test_out["acc"]])
|
||||||
printf(
|
printf(
|
||||||
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s",
|
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s")
|
||||||
)
|
|
||||||
printf(
|
printf(
|
||||||
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
|
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
|
||||||
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n",
|
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n")
|
||||||
)
|
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
printf("++++++++" * 2 + "Final results" + "++++++++" * 2)
|
printf(f"++++++++" * 2 + "Final results" + "++++++++" * 2)
|
||||||
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
|
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
|
||||||
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
|
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
|
||||||
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
|
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
|
||||||
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
|
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
|
||||||
printf("++++++++" * 5)
|
printf(f"++++++++" * 5)
|
||||||
|
|
||||||
|
|
||||||
def train(net, trainloader, optimizer, criterion, device):
|
def train(net, trainloader, optimizer, criterion, device):
|
||||||
|
@ -241,21 +202,17 @@ def train(net, trainloader, optimizer, criterion, device):
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
|
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(trainloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
train_true = np.concatenate(train_true)
|
train_true = np.concatenate(train_true)
|
||||||
train_pred = np.concatenate(train_pred)
|
train_pred = np.concatenate(train_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (train_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (train_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(train_true, train_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(train_true, train_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,23 +236,19 @@ def validate(net, testloader, criterion, device):
|
||||||
test_pred.append(preds.detach().cpu().numpy())
|
test_pred.append(preds.detach().cpu().numpy())
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(testloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
test_true = np.concatenate(test_true)
|
test_true = np.concatenate(test_true)
|
||||||
test_pred = np.concatenate(test_pred)
|
test_pred = np.concatenate(test_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1 +1,3 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from .pointmlp import pointMLP, pointMLPElite
|
from .pointmlp import pointMLP, pointMLPElite
|
||||||
|
|
|
@ -1,32 +1,35 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# from torch import einsum
|
# from torch import einsum
|
||||||
# from einops import rearrange, repeat
|
# from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
from pointnet2_ops import pointnet2_utils
|
from pointnet2_ops import pointnet2_utils
|
||||||
|
|
||||||
|
|
||||||
def get_activation(activation):
|
def get_activation(activation):
|
||||||
if activation.lower() == "gelu":
|
if activation.lower() == 'gelu':
|
||||||
return nn.GELU()
|
return nn.GELU()
|
||||||
elif activation.lower() == "rrelu":
|
elif activation.lower() == 'rrelu':
|
||||||
return nn.RReLU(inplace=True)
|
return nn.RReLU(inplace=True)
|
||||||
elif activation.lower() == "selu":
|
elif activation.lower() == 'selu':
|
||||||
return nn.SELU(inplace=True)
|
return nn.SELU(inplace=True)
|
||||||
elif activation.lower() == "silu":
|
elif activation.lower() == 'silu':
|
||||||
return nn.SiLU(inplace=True)
|
return nn.SiLU(inplace=True)
|
||||||
elif activation.lower() == "hardswish":
|
elif activation.lower() == 'hardswish':
|
||||||
return nn.Hardswish(inplace=True)
|
return nn.Hardswish(inplace=True)
|
||||||
elif activation.lower() == "leakyrelu":
|
elif activation.lower() == 'leakyrelu':
|
||||||
return nn.LeakyReLU(inplace=True)
|
return nn.LeakyReLU(inplace=True)
|
||||||
else:
|
else:
|
||||||
return nn.ReLU(inplace=True)
|
return nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
|
||||||
def square_distance(src, dst):
|
def square_distance(src, dst):
|
||||||
"""Calculate Euclid distance between each two points.
|
"""
|
||||||
src^T * dst = xn * xm + yn * ym + zn * zm;
|
Calculate Euclid distance between each two points.
|
||||||
|
src^T * dst = xn * xm + yn * ym + zn * zm;
|
||||||
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
||||||
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
||||||
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
||||||
|
@ -35,23 +38,23 @@ def square_distance(src, dst):
|
||||||
src: source points, [B, N, C]
|
src: source points, [B, N, C]
|
||||||
dst: target points, [B, M, C]
|
dst: target points, [B, M, C]
|
||||||
Output:
|
Output:
|
||||||
dist: per-point square distance, [B, N, M].
|
dist: per-point square distance, [B, N, M]
|
||||||
"""
|
"""
|
||||||
B, N, _ = src.shape
|
B, N, _ = src.shape
|
||||||
_, M, _ = dst.shape
|
_, M, _ = dst.shape
|
||||||
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
||||||
dist += torch.sum(src**2, -1).view(B, N, 1)
|
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
||||||
dist += torch.sum(dst**2, -1).view(B, 1, M)
|
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
def index_points(points, idx):
|
def index_points(points, idx):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
points: input points data, [B, N, C]
|
points: input points data, [B, N, C]
|
||||||
idx: sample index data, [B, S].
|
idx: sample index data, [B, S]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
new_points:, indexed points data, [B, S, C].
|
new_points:, indexed points data, [B, S, C]
|
||||||
"""
|
"""
|
||||||
device = points.device
|
device = points.device
|
||||||
B = points.shape[0]
|
B = points.shape[0]
|
||||||
|
@ -60,15 +63,17 @@ def index_points(points, idx):
|
||||||
repeat_shape = list(idx.shape)
|
repeat_shape = list(idx.shape)
|
||||||
repeat_shape[0] = 1
|
repeat_shape[0] = 1
|
||||||
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
||||||
return points[batch_indices, idx, :]
|
new_points = points[batch_indices, idx, :]
|
||||||
|
return new_points
|
||||||
|
|
||||||
|
|
||||||
def farthest_point_sample(xyz, npoint):
|
def farthest_point_sample(xyz, npoint):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
xyz: pointcloud data, [B, N, 3]
|
xyz: pointcloud data, [B, N, 3]
|
||||||
npoint: number of samples
|
npoint: number of samples
|
||||||
Return:
|
Return:
|
||||||
centroids: sampled pointcloud index, [B, npoint].
|
centroids: sampled pointcloud index, [B, npoint]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
|
@ -86,21 +91,21 @@ def farthest_point_sample(xyz, npoint):
|
||||||
|
|
||||||
|
|
||||||
def query_ball_point(radius, nsample, xyz, new_xyz):
|
def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
radius: local region radius
|
radius: local region radius
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, 3]
|
xyz: all points, [B, N, 3]
|
||||||
new_xyz: query points, [B, S, 3].
|
new_xyz: query points, [B, S, 3]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
_, S, _ = new_xyz.shape
|
_, S, _ = new_xyz.shape
|
||||||
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
group_idx[sqrdists > radius**2] = N
|
group_idx[sqrdists > radius ** 2] = N
|
||||||
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
||||||
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
||||||
mask = group_idx == N
|
mask = group_idx == N
|
||||||
|
@ -109,13 +114,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
|
|
||||||
|
|
||||||
def knn_point(nsample, xyz, new_xyz):
|
def knn_point(nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, C]
|
xyz: all points, [B, N, C]
|
||||||
new_xyz: query points, [B, S, C].
|
new_xyz: query points, [B, S, C]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
||||||
|
@ -124,12 +129,13 @@ def knn_point(nsample, xyz, new_xyz):
|
||||||
|
|
||||||
class LocalGrouper(nn.Module):
|
class LocalGrouper(nn.Module):
|
||||||
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
|
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
|
||||||
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
"""
|
||||||
|
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
||||||
:param groups: groups number
|
:param groups: groups number
|
||||||
:param kneighbors: k-nerighbors
|
:param kneighbors: k-nerighbors
|
||||||
:param kwargs: others.
|
:param kwargs: others
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(LocalGrouper, self).__init__()
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
self.kneighbors = kneighbors
|
self.kneighbors = kneighbors
|
||||||
self.use_xyz = use_xyz
|
self.use_xyz = use_xyz
|
||||||
|
@ -138,11 +144,11 @@ class LocalGrouper(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize not in ["center", "anchor"]:
|
if self.normalize not in ["center", "anchor"]:
|
||||||
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
add_channel = 3 if self.use_xyz else 0
|
add_channel=3 if self.use_xyz else 0
|
||||||
self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
|
self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel]))
|
||||||
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
||||||
|
|
||||||
def forward(self, xyz, points):
|
def forward(self, xyz, points):
|
||||||
|
@ -161,33 +167,29 @@ class LocalGrouper(nn.Module):
|
||||||
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
||||||
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
||||||
if self.use_xyz:
|
if self.use_xyz:
|
||||||
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) # [B, npoint, k, d+3]
|
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3]
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
if self.normalize == "center":
|
if self.normalize =="center":
|
||||||
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
||||||
if self.normalize == "anchor":
|
if self.normalize =="anchor":
|
||||||
mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
|
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
|
||||||
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
||||||
std = (
|
std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
|
||||||
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
|
grouped_points = (grouped_points-mean)/(std + 1e-5)
|
||||||
.unsqueeze(dim=-1)
|
grouped_points = self.affine_alpha*grouped_points + self.affine_beta
|
||||||
.unsqueeze(dim=-1)
|
|
||||||
)
|
|
||||||
grouped_points = (grouped_points - mean) / (std + 1e-5)
|
|
||||||
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
|
|
||||||
|
|
||||||
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
||||||
return new_xyz, new_points
|
return new_xyz, new_points
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLU1D(nn.Module):
|
class ConvBNReLU1D(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
|
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLU1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.BatchNorm1d(out_channels),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -195,43 +197,30 @@ class ConvBNReLU1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLURes1D(nn.Module):
|
class ConvBNReLURes1D(nn.Module):
|
||||||
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
|
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLURes1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net1 = nn.Sequential(
|
self.net1 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
|
||||||
in_channels=channel,
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=int(channel * res_expansion),
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(int(channel * res_expansion)),
|
nn.BatchNorm1d(int(channel * res_expansion)),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
if groups > 1:
|
if groups > 1:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=channel,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
self.act,
|
self.act,
|
||||||
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=channel, out_channels=channel,
|
||||||
|
kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, bias=bias),
|
||||||
out_channels=channel,
|
nn.BatchNorm1d(channel)
|
||||||
kernel_size=kernel_size,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -239,34 +228,21 @@ class ConvBNReLURes1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PreExtraction(nn.Module):
|
class PreExtraction(nn.Module):
|
||||||
def __init__(
|
def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
|
||||||
self,
|
activation='relu', use_xyz=True):
|
||||||
channels,
|
"""
|
||||||
out_channels,
|
input: [b,g,k,d]: output:[b,d,g]
|
||||||
blocks=1,
|
|
||||||
groups=1,
|
|
||||||
res_expansion=1,
|
|
||||||
bias=True,
|
|
||||||
activation="relu",
|
|
||||||
use_xyz=True,
|
|
||||||
):
|
|
||||||
"""input: [b,g,k,d]: output:[b,d,g]
|
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PreExtraction, self).__init__()
|
||||||
in_channels = 3 + 2 * channels if use_xyz else 2 * channels
|
in_channels = 3+2*channels if use_xyz else 2*channels
|
||||||
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(
|
ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
|
||||||
out_channels,
|
bias=bias, activation=activation)
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -278,20 +254,22 @@ class PreExtraction(nn.Module):
|
||||||
batch_size, _, _ = x.size()
|
batch_size, _, _ = x.size()
|
||||||
x = self.operation(x) # [b, d, k]
|
x = self.operation(x) # [b, d, k]
|
||||||
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
||||||
return x.reshape(b, n, -1).permute(0, 2, 1)
|
x = x.reshape(b, n, -1).permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class PosExtraction(nn.Module):
|
class PosExtraction(nn.Module):
|
||||||
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
|
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
|
||||||
"""input[b,d,g]; output[b,d,g]
|
"""
|
||||||
|
input[b,d,g]; output[b,d,g]
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PosExtraction, self).__init__()
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
|
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -300,32 +278,17 @@ class PosExtraction(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(
|
def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
self,
|
activation="relu", bias=True, use_xyz=True, normalize="center",
|
||||||
points=1024,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
class_num=40,
|
k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs):
|
||||||
embed_dim=64,
|
super(Model, self).__init__()
|
||||||
groups=1,
|
|
||||||
res_expansion=1.0,
|
|
||||||
activation="relu",
|
|
||||||
bias=True,
|
|
||||||
use_xyz=True,
|
|
||||||
normalize="center",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[32, 32, 32, 32],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.stages = len(pre_blocks)
|
self.stages = len(pre_blocks)
|
||||||
self.class_num = class_num
|
self.class_num = class_num
|
||||||
self.points = points
|
self.points = points
|
||||||
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
|
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
|
||||||
assert (
|
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
|
||||||
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
|
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
||||||
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
|
||||||
self.local_grouper_list = nn.ModuleList()
|
self.local_grouper_list = nn.ModuleList()
|
||||||
self.pre_blocks_list = nn.ModuleList()
|
self.pre_blocks_list = nn.ModuleList()
|
||||||
self.pos_blocks_list = nn.ModuleList()
|
self.pos_blocks_list = nn.ModuleList()
|
||||||
|
@ -342,26 +305,13 @@ class Model(nn.Module):
|
||||||
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
||||||
self.local_grouper_list.append(local_grouper)
|
self.local_grouper_list.append(local_grouper)
|
||||||
# append pre_block_list
|
# append pre_block_list
|
||||||
pre_block_module = PreExtraction(
|
pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
|
||||||
last_channel,
|
res_expansion=res_expansion,
|
||||||
out_channel,
|
bias=bias, activation=activation, use_xyz=use_xyz)
|
||||||
pre_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
use_xyz=use_xyz,
|
|
||||||
)
|
|
||||||
self.pre_blocks_list.append(pre_block_module)
|
self.pre_blocks_list.append(pre_block_module)
|
||||||
# append pos_block_list
|
# append pos_block_list
|
||||||
pos_block_module = PosExtraction(
|
pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
|
||||||
out_channel,
|
res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
pos_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
self.pos_blocks_list.append(pos_block_module)
|
self.pos_blocks_list.append(pos_block_module)
|
||||||
|
|
||||||
last_channel = out_channel
|
last_channel = out_channel
|
||||||
|
@ -376,7 +326,7 @@ class Model(nn.Module):
|
||||||
nn.BatchNorm1d(256),
|
nn.BatchNorm1d(256),
|
||||||
self.act,
|
self.act,
|
||||||
nn.Dropout(0.5),
|
nn.Dropout(0.5),
|
||||||
nn.Linear(256, self.class_num),
|
nn.Linear(256, self.class_num)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -390,59 +340,29 @@ class Model(nn.Module):
|
||||||
x = self.pos_blocks_list[i](x) # [b,d,g]
|
x = self.pos_blocks_list[i](x) # [b,d,g]
|
||||||
|
|
||||||
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
|
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
|
||||||
return self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def pointMLP(num_classes=40, **kwargs) -> Model:
|
def pointMLP(num_classes=40, **kwargs) -> Model:
|
||||||
return Model(
|
return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
points=1024,
|
activation="relu", bias=False, use_xyz=False, normalize="anchor",
|
||||||
class_num=num_classes,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
embed_dim=64,
|
k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)
|
||||||
groups=1,
|
|
||||||
res_expansion=1.0,
|
|
||||||
activation="relu",
|
|
||||||
bias=False,
|
|
||||||
use_xyz=False,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[24, 24, 24, 24],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
||||||
return Model(
|
return Model(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25,
|
||||||
points=1024,
|
activation="relu", bias=False, use_xyz=False, normalize="anchor",
|
||||||
class_num=num_classes,
|
dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
|
||||||
embed_dim=32,
|
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
|
||||||
groups=1,
|
|
||||||
res_expansion=0.25,
|
|
||||||
activation="relu",
|
|
||||||
bias=False,
|
|
||||||
use_xyz=False,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 1],
|
|
||||||
pre_blocks=[1, 1, 2, 1],
|
|
||||||
pos_blocks=[1, 1, 2, 1],
|
|
||||||
k_neighbors=[24, 24, 24, 24],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
data = torch.rand(2, 3, 1024).cuda()
|
|
||||||
print(data.shape)
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
data = torch.rand(2, 3, 1024)
|
||||||
print("===> testing pointMLP ...")
|
print("===> testing pointMLP ...")
|
||||||
model = pointMLP().cuda()
|
model = pointMLP()
|
||||||
out = model(data)
|
out = model(data)
|
||||||
print(out.shape)
|
print(out.shape)
|
||||||
|
|
||||||
print("===> testing pointMLPElite ...")
|
|
||||||
model = pointMLPElite().cuda()
|
|
||||||
out = model(data)
|
|
||||||
print(out.shape)
|
|
||||||
|
|
|
@ -1,81 +1,71 @@
|
||||||
"""python test.py --model pointMLP --msg 20220209053148-404."""
|
"""
|
||||||
|
python test.py --model pointMLP --msg 20220209053148-404
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
|
import datetime
|
||||||
import models as models
|
|
||||||
import numpy as np
|
|
||||||
import sklearn.metrics as metrics
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.utils.data.distributed
|
import torch.utils.data.distributed
|
||||||
from data import ModelNet40
|
|
||||||
from helper import cal_loss
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils import progress_bar
|
import models as models
|
||||||
|
from utils import progress_bar, IOStream
|
||||||
|
from data import ModelNet40
|
||||||
|
import sklearn.metrics as metrics
|
||||||
|
from helper import cal_loss
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name]))
|
model_names = sorted(name for name in models.__dict__
|
||||||
|
if callable(models.__dict__[name]))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parameters"""
|
"""Parameters"""
|
||||||
parser = argparse.ArgumentParser("training")
|
parser = argparse.ArgumentParser('training')
|
||||||
parser.add_argument(
|
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
|
||||||
"-c",
|
help='path to save checkpoint (default: checkpoint)')
|
||||||
"--checkpoint",
|
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||||
type=str,
|
parser.add_argument('--batch_size', type=int, default=16, help='batch size in training')
|
||||||
metavar="PATH",
|
parser.add_argument('--model', default='pointMLP', help='model name [default: pointnet_cls]')
|
||||||
help="path to save checkpoint (default: checkpoint)",
|
parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
|
||||||
)
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument("--msg", type=str, help="message after checkpoint")
|
|
||||||
parser.add_argument("--batch_size", type=int, default=16, help="batch size in training")
|
|
||||||
parser.add_argument("--model", default="pointMLP", help="model name [default: pointnet_cls]")
|
|
||||||
parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
|
|
||||||
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
print(f"args: {args}")
|
print(f"args: {args}")
|
||||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = 'cuda'
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = 'cpu'
|
||||||
print(f"==> Using device: {device}")
|
print(f"==> Using device: {device}")
|
||||||
# if args.msg is None:
|
if args.msg is None:
|
||||||
# message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
# else:
|
else:
|
||||||
# message = "-"+args.msg
|
message = "-"+args.msg
|
||||||
# args.checkpoint = 'checkpoints/' + args.model + message
|
args.checkpoint = 'checkpoints/' + args.model + message
|
||||||
if args.checkpoint is not None:
|
|
||||||
print(f"==> Using checkpoint: {args.checkpoint}")
|
|
||||||
|
|
||||||
print("==> Preparing data..")
|
print('==> Preparing data..')
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
||||||
ModelNet40(partition="test", num_points=args.num_points),
|
batch_size=args.batch_size, shuffle=False, drop_last=False)
|
||||||
num_workers=4,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
# Model
|
# Model
|
||||||
print("==> Building model..")
|
print('==> Building model..')
|
||||||
net = models.__dict__[args.model]()
|
net = models.__dict__[args.model]()
|
||||||
criterion = cal_loss
|
criterion = cal_loss
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
# checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
||||||
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
# criterion = criterion.to(device)
|
# criterion = criterion.to(device)
|
||||||
if device == "cuda":
|
if device == 'cuda':
|
||||||
net = torch.nn.DataParallel(net)
|
net = torch.nn.DataParallel(net)
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
net.load_state_dict(checkpoint["net"])
|
net.load_state_dict(checkpoint['net'])
|
||||||
|
|
||||||
test_out = validate(net, test_loader, criterion, device)
|
test_out = validate(net, test_loader, criterion, device)
|
||||||
print(f"Vanilla out: {test_out}")
|
print(f"Vanilla out: {test_out}")
|
||||||
|
@ -101,23 +91,19 @@ def validate(net, testloader, criterion, device):
|
||||||
test_pred.append(preds.detach().cpu().numpy())
|
test_pred.append(preds.detach().cpu().numpy())
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(testloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
test_true = np.concatenate(test_true)
|
test_true = np.concatenate(test_true)
|
||||||
test_pred = np.concatenate(test_pred)
|
test_pred = np.concatenate(test_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Useful utils."""
|
"""Useful utils
|
||||||
from .logger import *
|
"""
|
||||||
from .misc import *
|
from .misc import *
|
||||||
|
from .logger import *
|
||||||
from .progress.progress.bar import Bar as Bar
|
from .progress.progress.bar import Bar as Bar
|
||||||
|
|
|
@ -1,50 +1,48 @@
|
||||||
# A simple torch style logger
|
# A simple torch style logger
|
||||||
# (C) Wei YANG 2017
|
# (C) Wei YANG 2017
|
||||||
|
from __future__ import absolute_import
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
__all__ = ["Logger", "LoggerMonitor", "savefig"]
|
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
|
||||||
|
|
||||||
|
|
||||||
def savefig(fname, dpi=None):
|
def savefig(fname, dpi=None):
|
||||||
dpi = 150 if dpi is None else dpi
|
dpi = 150 if dpi == None else dpi
|
||||||
plt.savefig(fname, dpi=dpi)
|
plt.savefig(fname, dpi=dpi)
|
||||||
|
|
||||||
|
|
||||||
def plot_overlap(logger, names=None):
|
def plot_overlap(logger, names=None):
|
||||||
names = logger.names if names is None else names
|
names = logger.names if names == None else names
|
||||||
numbers = logger.numbers
|
numbers = logger.numbers
|
||||||
for _, name in enumerate(names):
|
for _, name in enumerate(names):
|
||||||
x = np.arange(len(numbers[name]))
|
x = np.arange(len(numbers[name]))
|
||||||
plt.plot(x, np.asarray(numbers[name]))
|
plt.plot(x, np.asarray(numbers[name]))
|
||||||
return [logger.title + "(" + name + ")" for name in names]
|
return [logger.title + '(' + name + ')' for name in names]
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
|
||||||
"""Save training process to log file with simple plot function."""
|
|
||||||
|
|
||||||
|
class Logger(object):
|
||||||
|
'''Save training process to log file with simple plot function.'''
|
||||||
def __init__(self, fpath, title=None, resume=False):
|
def __init__(self, fpath, title=None, resume=False):
|
||||||
self.file = None
|
self.file = None
|
||||||
self.resume = resume
|
self.resume = resume
|
||||||
self.title = "" if title is None else title
|
self.title = '' if title == None else title
|
||||||
if fpath is not None:
|
if fpath is not None:
|
||||||
if resume:
|
if resume:
|
||||||
self.file = open(fpath)
|
self.file = open(fpath, 'r')
|
||||||
name = self.file.readline()
|
name = self.file.readline()
|
||||||
self.names = name.rstrip().split("\t")
|
self.names = name.rstrip().split('\t')
|
||||||
self.numbers = {}
|
self.numbers = {}
|
||||||
for _, name in enumerate(self.names):
|
for _, name in enumerate(self.names):
|
||||||
self.numbers[name] = []
|
self.numbers[name] = []
|
||||||
|
|
||||||
for numbers in self.file:
|
for numbers in self.file:
|
||||||
numbers = numbers.rstrip().split("\t")
|
numbers = numbers.rstrip().split('\t')
|
||||||
for i in range(0, len(numbers)):
|
for i in range(0, len(numbers)):
|
||||||
self.numbers[self.names[i]].append(numbers[i])
|
self.numbers[self.names[i]].append(numbers[i])
|
||||||
self.file.close()
|
self.file.close()
|
||||||
self.file = open(fpath, "a")
|
self.file = open(fpath, 'a')
|
||||||
else:
|
else:
|
||||||
self.file = open(fpath, "w")
|
self.file = open(fpath, 'w')
|
||||||
|
|
||||||
def set_names(self, names):
|
def set_names(self, names):
|
||||||
if self.resume:
|
if self.resume:
|
||||||
|
@ -54,39 +52,38 @@ class Logger:
|
||||||
self.names = names
|
self.names = names
|
||||||
for _, name in enumerate(self.names):
|
for _, name in enumerate(self.names):
|
||||||
self.file.write(name)
|
self.file.write(name)
|
||||||
self.file.write("\t")
|
self.file.write('\t')
|
||||||
self.numbers[name] = []
|
self.numbers[name] = []
|
||||||
self.file.write("\n")
|
self.file.write('\n')
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
|
|
||||||
def append(self, numbers):
|
def append(self, numbers):
|
||||||
assert len(self.names) == len(numbers), "Numbers do not match names"
|
assert len(self.names) == len(numbers), 'Numbers do not match names'
|
||||||
for index, num in enumerate(numbers):
|
for index, num in enumerate(numbers):
|
||||||
self.file.write(f"{num:.6f}")
|
self.file.write("{0:.6f}".format(num))
|
||||||
self.file.write("\t")
|
self.file.write('\t')
|
||||||
self.numbers[self.names[index]].append(num)
|
self.numbers[self.names[index]].append(num)
|
||||||
self.file.write("\n")
|
self.file.write('\n')
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def plot(self, names=None):
|
def plot(self, names=None):
|
||||||
names = self.names if names is None else names
|
names = self.names if names == None else names
|
||||||
numbers = self.numbers
|
numbers = self.numbers
|
||||||
for _, name in enumerate(names):
|
for _, name in enumerate(names):
|
||||||
x = np.arange(len(numbers[name]))
|
x = np.arange(len(numbers[name]))
|
||||||
plt.plot(x, np.asarray(numbers[name]))
|
plt.plot(x, np.asarray(numbers[name]))
|
||||||
plt.legend([self.title + "(" + name + ")" for name in names])
|
plt.legend([self.title + '(' + name + ')' for name in names])
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.file is not None:
|
if self.file is not None:
|
||||||
self.file.close()
|
self.file.close()
|
||||||
|
|
||||||
|
class LoggerMonitor(object):
|
||||||
class LoggerMonitor:
|
'''Load and visualize multiple logs.'''
|
||||||
"""Load and visualize multiple logs."""
|
def __init__ (self, paths):
|
||||||
|
'''paths is a distionary with {name:filepath} pair'''
|
||||||
def __init__(self, paths):
|
|
||||||
"""Paths is a distionary with {name:filepath} pair."""
|
|
||||||
self.loggers = []
|
self.loggers = []
|
||||||
for title, path in paths.items():
|
for title, path in paths.items():
|
||||||
logger = Logger(path, title=title, resume=True)
|
logger = Logger(path, title=title, resume=True)
|
||||||
|
@ -98,11 +95,10 @@ class LoggerMonitor:
|
||||||
legend_text = []
|
legend_text = []
|
||||||
for logger in self.loggers:
|
for logger in self.loggers:
|
||||||
legend_text += plot_overlap(logger, names)
|
legend_text += plot_overlap(logger, names)
|
||||||
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
|
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
if __name__ == "__main__":
|
|
||||||
# # Example
|
# # Example
|
||||||
# logger = Logger('test.txt')
|
# logger = Logger('test.txt')
|
||||||
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
|
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
|
||||||
|
@ -119,13 +115,13 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Example: logger monitor
|
# Example: logger monitor
|
||||||
paths = {
|
paths = {
|
||||||
"resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
|
'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
|
||||||
"resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
|
'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
|
||||||
"resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
|
'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
|
||||||
}
|
}
|
||||||
|
|
||||||
field = ["Valid Acc."]
|
field = ['Valid Acc.']
|
||||||
|
|
||||||
monitor = LoggerMonitor(paths)
|
monitor = LoggerMonitor(paths)
|
||||||
monitor.plot(names=field)
|
monitor.plot(names=field)
|
||||||
savefig("test.eps")
|
savefig('test.eps')
|
|
@ -1,56 +1,48 @@
|
||||||
"""Some helper functions for PyTorch, including:
|
'''Some helper functions for PyTorch, including:
|
||||||
- get_mean_and_std: calculate the mean and std value of dataset.
|
- get_mean_and_std: calculate the mean and std value of dataset.
|
||||||
- msr_init: net parameter initialization.
|
- msr_init: net parameter initialization.
|
||||||
- progress_bar: progress bar mimic xlua.progress.
|
- progress_bar: progress bar mimic xlua.progress.
|
||||||
"""
|
'''
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import shutil
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"get_mean_and_std",
|
import torch.nn as nn
|
||||||
"init_params",
|
import torch.nn.init as init
|
||||||
"mkdir_p",
|
from torch.autograd import Variable
|
||||||
"AverageMeter",
|
|
||||||
"progress_bar",
|
__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter',
|
||||||
"save_model",
|
'progress_bar','save_model',"save_args","set_seed", "IOStream", "cal_loss"]
|
||||||
"save_args",
|
|
||||||
"set_seed",
|
|
||||||
"IOStream",
|
|
||||||
"cal_loss",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_std(dataset):
|
def get_mean_and_std(dataset):
|
||||||
"""Compute the mean and std value of dataset."""
|
'''Compute the mean and std value of dataset.'''
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
|
dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
|
||||||
|
|
||||||
mean = torch.zeros(3)
|
mean = torch.zeros(3)
|
||||||
std = torch.zeros(3)
|
std = torch.zeros(3)
|
||||||
print("==> Computing mean and std..")
|
print('==> Computing mean and std..')
|
||||||
for inputs, _targets in dataloader:
|
for inputs, targets in dataloader:
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
mean[i] += inputs[:, i, :, :].mean()
|
mean[i] += inputs[:,i,:,:].mean()
|
||||||
std[i] += inputs[:, i, :, :].std()
|
std[i] += inputs[:,i,:,:].std()
|
||||||
mean.div_(len(dataset))
|
mean.div_(len(dataset))
|
||||||
std.div_(len(dataset))
|
std.div_(len(dataset))
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
def init_params(net):
|
def init_params(net):
|
||||||
"""Init layer parameters."""
|
'''Init layer parameters.'''
|
||||||
for m in net.modules():
|
for m in net.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
init.kaiming_normal(m.weight, mode="fan_out")
|
init.kaiming_normal(m.weight, mode='fan_out')
|
||||||
if m.bias:
|
if m.bias:
|
||||||
init.constant(m.bias, 0)
|
init.constant(m.bias, 0)
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
@ -61,9 +53,8 @@ def init_params(net):
|
||||||
if m.bias:
|
if m.bias:
|
||||||
init.constant(m.bias, 0)
|
init.constant(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
def mkdir_p(path):
|
def mkdir_p(path):
|
||||||
"""Make dir if not exist."""
|
'''make dir if not exist'''
|
||||||
try:
|
try:
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
except OSError as exc: # Python >2.5
|
except OSError as exc: # Python >2.5
|
||||||
|
@ -72,12 +63,10 @@ def mkdir_p(path):
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
class AverageMeter(object):
|
||||||
class AverageMeter:
|
|
||||||
"""Computes and stores the average and current value
|
"""Computes and stores the average and current value
|
||||||
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.
|
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
@ -94,26 +83,25 @@ class AverageMeter:
|
||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
TOTAL_BAR_LENGTH = 65.0
|
|
||||||
|
TOTAL_BAR_LENGTH = 65.
|
||||||
last_time = time.time()
|
last_time = time.time()
|
||||||
begin_time = last_time
|
begin_time = last_time
|
||||||
|
|
||||||
|
|
||||||
def progress_bar(current, total, msg=None):
|
def progress_bar(current, total, msg=None):
|
||||||
global last_time, begin_time
|
global last_time, begin_time
|
||||||
if current == 0:
|
if current == 0:
|
||||||
begin_time = time.time() # Reset for new bar.
|
begin_time = time.time() # Reset for new bar.
|
||||||
|
|
||||||
cur_len = int(TOTAL_BAR_LENGTH * current / total)
|
cur_len = int(TOTAL_BAR_LENGTH*current/total)
|
||||||
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
||||||
|
|
||||||
sys.stdout.write(" [")
|
sys.stdout.write(' [')
|
||||||
for _i in range(cur_len):
|
for i in range(cur_len):
|
||||||
sys.stdout.write("=")
|
sys.stdout.write('=')
|
||||||
sys.stdout.write(">")
|
sys.stdout.write('>')
|
||||||
for _i in range(rest_len):
|
for i in range(rest_len):
|
||||||
sys.stdout.write(".")
|
sys.stdout.write('.')
|
||||||
sys.stdout.write("]")
|
sys.stdout.write(']')
|
||||||
|
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
step_time = cur_time - last_time
|
step_time = cur_time - last_time
|
||||||
|
@ -121,12 +109,12 @@ def progress_bar(current, total, msg=None):
|
||||||
tot_time = cur_time - begin_time
|
tot_time = cur_time - begin_time
|
||||||
|
|
||||||
L = []
|
L = []
|
||||||
L.append(" Step: %s" % format_time(step_time))
|
L.append(' Step: %s' % format_time(step_time))
|
||||||
L.append(" | Tot: %s" % format_time(tot_time))
|
L.append(' | Tot: %s' % format_time(tot_time))
|
||||||
if msg:
|
if msg:
|
||||||
L.append(" | " + msg)
|
L.append(' | ' + msg)
|
||||||
|
|
||||||
msg = "".join(L)
|
msg = ''.join(L)
|
||||||
sys.stdout.write(msg)
|
sys.stdout.write(msg)
|
||||||
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
|
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
|
||||||
# sys.stdout.write(' ')
|
# sys.stdout.write(' ')
|
||||||
|
@ -134,74 +122,76 @@ def progress_bar(current, total, msg=None):
|
||||||
# Go back to the center of the bar.
|
# Go back to the center of the bar.
|
||||||
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
|
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
|
||||||
# sys.stdout.write('\b')
|
# sys.stdout.write('\b')
|
||||||
sys.stdout.write(" %d/%d " % (current + 1, total))
|
sys.stdout.write(' %d/%d ' % (current+1, total))
|
||||||
|
|
||||||
if current < total - 1:
|
if current < total-1:
|
||||||
sys.stdout.write("\r")
|
sys.stdout.write('\r')
|
||||||
else:
|
else:
|
||||||
sys.stdout.write("\n")
|
sys.stdout.write('\n')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds):
|
def format_time(seconds):
|
||||||
days = int(seconds / 3600 / 24)
|
days = int(seconds / 3600/24)
|
||||||
seconds = seconds - days * 3600 * 24
|
seconds = seconds - days*3600*24
|
||||||
hours = int(seconds / 3600)
|
hours = int(seconds / 3600)
|
||||||
seconds = seconds - hours * 3600
|
seconds = seconds - hours*3600
|
||||||
minutes = int(seconds / 60)
|
minutes = int(seconds / 60)
|
||||||
seconds = seconds - minutes * 60
|
seconds = seconds - minutes*60
|
||||||
secondsf = int(seconds)
|
secondsf = int(seconds)
|
||||||
seconds = seconds - secondsf
|
seconds = seconds - secondsf
|
||||||
millis = int(seconds * 1000)
|
millis = int(seconds*1000)
|
||||||
|
|
||||||
f = ""
|
f = ''
|
||||||
i = 1
|
i = 1
|
||||||
if days > 0:
|
if days > 0:
|
||||||
f += str(days) + "D"
|
f += str(days) + 'D'
|
||||||
i += 1
|
i += 1
|
||||||
if hours > 0 and i <= 2:
|
if hours > 0 and i <= 2:
|
||||||
f += str(hours) + "h"
|
f += str(hours) + 'h'
|
||||||
i += 1
|
i += 1
|
||||||
if minutes > 0 and i <= 2:
|
if minutes > 0 and i <= 2:
|
||||||
f += str(minutes) + "m"
|
f += str(minutes) + 'm'
|
||||||
i += 1
|
i += 1
|
||||||
if secondsf > 0 and i <= 2:
|
if secondsf > 0 and i <= 2:
|
||||||
f += str(secondsf) + "s"
|
f += str(secondsf) + 's'
|
||||||
i += 1
|
i += 1
|
||||||
if millis > 0 and i <= 2:
|
if millis > 0 and i <= 2:
|
||||||
f += str(millis) + "ms"
|
f += str(millis) + 'ms'
|
||||||
i += 1
|
i += 1
|
||||||
if f == "":
|
if f == '':
|
||||||
f = "0ms"
|
f = '0ms'
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def save_model(net, epoch, path, acc, is_best, **kwargs):
|
def save_model(net, epoch, path, acc, is_best, **kwargs):
|
||||||
state = {
|
state = {
|
||||||
"net": net.state_dict(),
|
'net': net.state_dict(),
|
||||||
"epoch": epoch,
|
'epoch': epoch,
|
||||||
"acc": acc,
|
'acc': acc
|
||||||
}
|
}
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
state[key] = value
|
state[key] = value
|
||||||
filepath = os.path.join(path, "last_checkpoint.pth")
|
filepath = os.path.join(path, "last_checkpoint.pth")
|
||||||
torch.save(state, filepath)
|
torch.save(state, filepath)
|
||||||
if is_best:
|
if is_best:
|
||||||
shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth"))
|
shutil.copyfile(filepath, os.path.join(path, 'best_checkpoint.pth'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_args(args):
|
def save_args(args):
|
||||||
file = open(os.path.join(args.checkpoint, "args.txt"), "w")
|
file = open(os.path.join(args.checkpoint, 'args.txt'), "w")
|
||||||
for k, v in vars(args).items():
|
for k, v in vars(args).items():
|
||||||
file.write(f"{k}:\t {v}\n")
|
file.write(f"{k}:\t {v}\n")
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed=None):
|
def set_seed(seed=None):
|
||||||
if seed is None:
|
if seed is None:
|
||||||
return
|
return
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
os.environ["PYTHONHASHSEED"] = "%s" % seed
|
os.environ['PYTHONHASHSEED'] = ("%s" % seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
@ -210,14 +200,15 @@ def set_seed(seed=None):
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# create a file and write the text into it
|
# create a file and write the text into it
|
||||||
class IOStream:
|
class IOStream():
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.f = open(path, "a")
|
self.f = open(path, 'a')
|
||||||
|
|
||||||
def cprint(self, text):
|
def cprint(self, text):
|
||||||
print(text)
|
print(text)
|
||||||
self.f.write(text + "\n")
|
self.f.write(text+'\n')
|
||||||
self.f.flush()
|
self.f.flush()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -225,7 +216,8 @@ class IOStream:
|
||||||
|
|
||||||
|
|
||||||
def cal_loss(pred, gold, smoothing=True):
|
def cal_loss(pred, gold, smoothing=True):
|
||||||
"""Calculate cross entropy loss, apply label smoothing if needed."""
|
''' Calculate cross entropy loss, apply label smoothing if needed. '''
|
||||||
|
|
||||||
gold = gold.contiguous().view(-1)
|
gold = gold.contiguous().view(-1)
|
||||||
|
|
||||||
if smoothing:
|
if smoothing:
|
||||||
|
@ -238,6 +230,6 @@ def cal_loss(pred, gold, smoothing=True):
|
||||||
|
|
||||||
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
||||||
else:
|
else:
|
||||||
loss = F.cross_entropy(pred, gold, reduction="mean")
|
loss = F.cross_entropy(pred, gold, reduction='mean')
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
@ -19,12 +20,13 @@ from math import ceil
|
||||||
from sys import stderr
|
from sys import stderr
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
__version__ = "1.3"
|
|
||||||
|
__version__ = '1.3'
|
||||||
|
|
||||||
|
|
||||||
class Infinite:
|
class Infinite(object):
|
||||||
file = stderr
|
file = stderr
|
||||||
sma_window = 10 # Simple Moving Average window
|
sma_window = 10 # Simple Moving Average window
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.index = 0
|
self.index = 0
|
||||||
|
@ -36,7 +38,7 @@ class Infinite:
|
||||||
setattr(self, key, val)
|
setattr(self, key, val)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if key.startswith("_"):
|
if key.startswith('_'):
|
||||||
return None
|
return None
|
||||||
return getattr(self, key, None)
|
return getattr(self, key, None)
|
||||||
|
|
||||||
|
@ -81,8 +83,8 @@ class Infinite:
|
||||||
|
|
||||||
class Progress(Infinite):
|
class Progress(Infinite):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super(Progress, self).__init__(*args, **kwargs)
|
||||||
self.max = kwargs.get("max", 100)
|
self.max = kwargs.get('max', 100)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eta(self):
|
def eta(self):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,18 +14,19 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Progress
|
from . import Progress
|
||||||
from .helpers import WritelnMixin
|
from .helpers import WritelnMixin
|
||||||
|
|
||||||
|
|
||||||
class Bar(WritelnMixin, Progress):
|
class Bar(WritelnMixin, Progress):
|
||||||
width = 32
|
width = 32
|
||||||
message = ""
|
message = ''
|
||||||
suffix = "%(index)d/%(max)d"
|
suffix = '%(index)d/%(max)d'
|
||||||
bar_prefix = " |"
|
bar_prefix = ' |'
|
||||||
bar_suffix = "| "
|
bar_suffix = '| '
|
||||||
empty_fill = " "
|
empty_fill = ' '
|
||||||
fill = "#"
|
fill = '#'
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -34,50 +37,52 @@ class Bar(WritelnMixin, Progress):
|
||||||
bar = self.fill * filled_length
|
bar = self.fill * filled_length
|
||||||
empty = self.empty_fill * empty_length
|
empty = self.empty_fill * empty_length
|
||||||
suffix = self.suffix % self
|
suffix = self.suffix % self
|
||||||
line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix])
|
line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
|
||||||
|
suffix])
|
||||||
self.writeln(line)
|
self.writeln(line)
|
||||||
|
|
||||||
|
|
||||||
class ChargingBar(Bar):
|
class ChargingBar(Bar):
|
||||||
suffix = "%(percent)d%%"
|
suffix = '%(percent)d%%'
|
||||||
bar_prefix = " "
|
bar_prefix = ' '
|
||||||
bar_suffix = " "
|
bar_suffix = ' '
|
||||||
empty_fill = "∙"
|
empty_fill = '∙'
|
||||||
fill = "█"
|
fill = '█'
|
||||||
|
|
||||||
|
|
||||||
class FillingSquaresBar(ChargingBar):
|
class FillingSquaresBar(ChargingBar):
|
||||||
empty_fill = "▢"
|
empty_fill = '▢'
|
||||||
fill = "▣"
|
fill = '▣'
|
||||||
|
|
||||||
|
|
||||||
class FillingCirclesBar(ChargingBar):
|
class FillingCirclesBar(ChargingBar):
|
||||||
empty_fill = "◯"
|
empty_fill = '◯'
|
||||||
fill = "◉"
|
fill = '◉'
|
||||||
|
|
||||||
|
|
||||||
class IncrementalBar(Bar):
|
class IncrementalBar(Bar):
|
||||||
phases = (" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█")
|
phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
nphases = len(self.phases)
|
nphases = len(self.phases)
|
||||||
filled_len = self.width * self.progress
|
filled_len = self.width * self.progress
|
||||||
nfull = int(filled_len) # Number of full chars
|
nfull = int(filled_len) # Number of full chars
|
||||||
phase = int((filled_len - nfull) * nphases) # Phase of last char
|
phase = int((filled_len - nfull) * nphases) # Phase of last char
|
||||||
nempty = self.width - nfull # Number of empty chars
|
nempty = self.width - nfull # Number of empty chars
|
||||||
|
|
||||||
message = self.message % self
|
message = self.message % self
|
||||||
bar = self.phases[-1] * nfull
|
bar = self.phases[-1] * nfull
|
||||||
current = self.phases[phase] if phase > 0 else ""
|
current = self.phases[phase] if phase > 0 else ''
|
||||||
empty = self.empty_fill * max(0, nempty - len(current))
|
empty = self.empty_fill * max(0, nempty - len(current))
|
||||||
suffix = self.suffix % self
|
suffix = self.suffix % self
|
||||||
line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix])
|
line = ''.join([message, self.bar_prefix, bar, current, empty,
|
||||||
|
self.bar_suffix, suffix])
|
||||||
self.writeln(line)
|
self.writeln(line)
|
||||||
|
|
||||||
|
|
||||||
class PixelBar(IncrementalBar):
|
class PixelBar(IncrementalBar):
|
||||||
phases = ("⡀", "⡄", "⡆", "⡇", "⣇", "⣧", "⣷", "⣿")
|
phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')
|
||||||
|
|
||||||
|
|
||||||
class ShadyBar(IncrementalBar):
|
class ShadyBar(IncrementalBar):
|
||||||
phases = (" ", "░", "▒", "▓", "█")
|
phases = (' ', '░', '▒', '▓', '█')
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,12 +14,13 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Infinite, Progress
|
from . import Infinite, Progress
|
||||||
from .helpers import WriteMixin
|
from .helpers import WriteMixin
|
||||||
|
|
||||||
|
|
||||||
class Counter(WriteMixin, Infinite):
|
class Counter(WriteMixin, Infinite):
|
||||||
message = ""
|
message = ''
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -32,7 +35,7 @@ class Countdown(WriteMixin, Progress):
|
||||||
|
|
||||||
|
|
||||||
class Stack(WriteMixin, Progress):
|
class Stack(WriteMixin, Progress):
|
||||||
phases = (" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█")
|
phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -42,4 +45,4 @@ class Stack(WriteMixin, Progress):
|
||||||
|
|
||||||
|
|
||||||
class Pie(Stack):
|
class Pie(Stack):
|
||||||
phases = ("○", "◔", "◑", "◕", "●")
|
phases = ('○', '◔', '◑', '◕', '●')
|
||||||
|
|
|
@ -12,76 +12,78 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
HIDE_CURSOR = "\x1b[?25l"
|
|
||||||
SHOW_CURSOR = "\x1b[?25h"
|
|
||||||
|
|
||||||
|
|
||||||
class WriteMixin:
|
HIDE_CURSOR = '\x1b[?25l'
|
||||||
|
SHOW_CURSOR = '\x1b[?25h'
|
||||||
|
|
||||||
|
|
||||||
|
class WriteMixin(object):
|
||||||
hide_cursor = False
|
hide_cursor = False
|
||||||
|
|
||||||
def __init__(self, message=None, **kwargs):
|
def __init__(self, message=None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super(WriteMixin, self).__init__(**kwargs)
|
||||||
self._width = 0
|
self._width = 0
|
||||||
if message:
|
if message:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
if self.hide_cursor:
|
if self.hide_cursor:
|
||||||
print(HIDE_CURSOR, end="", file=self.file)
|
print(HIDE_CURSOR, end='', file=self.file)
|
||||||
print(self.message, end="", file=self.file)
|
print(self.message, end='', file=self.file)
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def write(self, s):
|
def write(self, s):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
b = "\b" * self._width
|
b = '\b' * self._width
|
||||||
c = s.ljust(self._width)
|
c = s.ljust(self._width)
|
||||||
print(b + c, end="", file=self.file)
|
print(b + c, end='', file=self.file)
|
||||||
self._width = max(self._width, len(s))
|
self._width = max(self._width, len(s))
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.file.isatty() and self.hide_cursor:
|
if self.file.isatty() and self.hide_cursor:
|
||||||
print(SHOW_CURSOR, end="", file=self.file)
|
print(SHOW_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
|
|
||||||
class WritelnMixin:
|
class WritelnMixin(object):
|
||||||
hide_cursor = False
|
hide_cursor = False
|
||||||
|
|
||||||
def __init__(self, message=None, **kwargs):
|
def __init__(self, message=None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super(WritelnMixin, self).__init__(**kwargs)
|
||||||
if message:
|
if message:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
if self.file.isatty() and self.hide_cursor:
|
if self.file.isatty() and self.hide_cursor:
|
||||||
print(HIDE_CURSOR, end="", file=self.file)
|
print(HIDE_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
def clearln(self):
|
def clearln(self):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
print("\r\x1b[K", end="", file=self.file)
|
print('\r\x1b[K', end='', file=self.file)
|
||||||
|
|
||||||
def writeln(self, line):
|
def writeln(self, line):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
self.clearln()
|
self.clearln()
|
||||||
print(line, end="", file=self.file)
|
print(line, end='', file=self.file)
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
print(file=self.file)
|
print(file=self.file)
|
||||||
if self.hide_cursor:
|
if self.hide_cursor:
|
||||||
print(SHOW_CURSOR, end="", file=self.file)
|
print(SHOW_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
|
|
||||||
from signal import SIGINT, signal
|
from signal import signal, SIGINT
|
||||||
from sys import exit
|
from sys import exit
|
||||||
|
|
||||||
|
|
||||||
class SigIntMixin:
|
class SigIntMixin(object):
|
||||||
"""Registers a signal handler that calls finish on SIGINT."""
|
"""Registers a signal handler that calls finish on SIGINT"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super(SigIntMixin, self).__init__(*args, **kwargs)
|
||||||
signal(SIGINT, self._sigint_handler)
|
signal(SIGINT, self._sigint_handler)
|
||||||
|
|
||||||
def _sigint_handler(self, signum, frame):
|
def _sigint_handler(self, signum, frame):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,13 +14,14 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Infinite
|
from . import Infinite
|
||||||
from .helpers import WriteMixin
|
from .helpers import WriteMixin
|
||||||
|
|
||||||
|
|
||||||
class Spinner(WriteMixin, Infinite):
|
class Spinner(WriteMixin, Infinite):
|
||||||
message = ""
|
message = ''
|
||||||
phases = ("-", "\\", "|", "/")
|
phases = ('-', '\\', '|', '/')
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -27,16 +30,15 @@ class Spinner(WriteMixin, Infinite):
|
||||||
|
|
||||||
|
|
||||||
class PieSpinner(Spinner):
|
class PieSpinner(Spinner):
|
||||||
phases = ["◷", "◶", "◵", "◴"]
|
phases = ['◷', '◶', '◵', '◴']
|
||||||
|
|
||||||
|
|
||||||
class MoonSpinner(Spinner):
|
class MoonSpinner(Spinner):
|
||||||
phases = ["◑", "◒", "◐", "◓"]
|
phases = ['◑', '◒', '◐', '◓']
|
||||||
|
|
||||||
|
|
||||||
class LineSpinner(Spinner):
|
class LineSpinner(Spinner):
|
||||||
phases = ["⎺", "⎻", "⎼", "⎽", "⎼", "⎻"]
|
phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']
|
||||||
|
|
||||||
|
|
||||||
class PixelSpinner(Spinner):
|
class PixelSpinner(Spinner):
|
||||||
phases = ["⣾", "⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽"]
|
phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']
|
||||||
|
|
|
@ -1,27 +1,29 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import progress
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
|
import progress
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="progress",
|
name='progress',
|
||||||
version=progress.__version__,
|
version=progress.__version__,
|
||||||
description="Easy to use progress bars",
|
description='Easy to use progress bars',
|
||||||
long_description=open("README.rst").read(),
|
long_description=open('README.rst').read(),
|
||||||
author="Giorgos Verigakis",
|
author='Giorgos Verigakis',
|
||||||
author_email="verigak@gmail.com",
|
author_email='verigak@gmail.com',
|
||||||
url="http://github.com/verigak/progress/",
|
url='http://github.com/verigak/progress/',
|
||||||
license="ISC",
|
license='ISC',
|
||||||
packages=["progress"],
|
packages=['progress'],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Environment :: Console",
|
'Environment :: Console',
|
||||||
"Intended Audience :: Developers",
|
'Intended Audience :: Developers',
|
||||||
"License :: OSI Approved :: ISC License (ISCL)",
|
'License :: OSI Approved :: ISC License (ISCL)',
|
||||||
"Programming Language :: Python :: 2.6",
|
'Programming Language :: Python :: 2.6',
|
||||||
"Programming Language :: Python :: 2.7",
|
'Programming Language :: Python :: 2.7',
|
||||||
"Programming Language :: Python :: 3.3",
|
'Programming Language :: Python :: 3.3',
|
||||||
"Programming Language :: Python :: 3.4",
|
'Programming Language :: Python :: 3.4',
|
||||||
"Programming Language :: Python :: 3.5",
|
'Programming Language :: Python :: 3.5',
|
||||||
"Programming Language :: Python :: 3.6",
|
'Programming Language :: Python :: 3.6',
|
||||||
],
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar
|
from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
|
||||||
from progress.counter import Countdown, Counter, Pie, Stack
|
FillingCirclesBar, IncrementalBar, PixelBar,
|
||||||
from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner
|
ShadyBar)
|
||||||
|
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
|
||||||
|
PixelSpinner)
|
||||||
|
from progress.counter import Counter, Countdown, Stack, Pie
|
||||||
|
|
||||||
|
|
||||||
def sleep():
|
def sleep():
|
||||||
|
@ -16,29 +20,29 @@ def sleep():
|
||||||
|
|
||||||
|
|
||||||
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
|
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
|
||||||
suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]"
|
suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
|
||||||
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
||||||
for _i in bar.iter(range(200)):
|
for i in bar.iter(range(200)):
|
||||||
sleep()
|
sleep()
|
||||||
|
|
||||||
for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
|
for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
|
||||||
suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]"
|
suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
|
||||||
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
||||||
for _i in bar.iter(range(200)):
|
for i in bar.iter(range(200)):
|
||||||
sleep()
|
sleep()
|
||||||
|
|
||||||
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
|
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
|
||||||
for _i in spin(spin.__name__ + " ").iter(range(100)):
|
for i in spin(spin.__name__ + ' ').iter(range(100)):
|
||||||
sleep()
|
sleep()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
for singleton in (Counter, Countdown, Stack, Pie):
|
for singleton in (Counter, Countdown, Stack, Pie):
|
||||||
for _i in singleton(singleton.__name__ + " ").iter(range(100)):
|
for i in singleton(singleton.__name__ + ' ').iter(range(100)):
|
||||||
sleep()
|
sleep()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
bar = IncrementalBar("Random", suffix="%(index)d")
|
bar = IncrementalBar('Random', suffix='%(index)d')
|
||||||
for _i in range(100):
|
for i in range(100):
|
||||||
bar.goto(random.randint(0, 100))
|
bar.goto(random.randint(0, 100))
|
||||||
sleep()
|
sleep()
|
||||||
bar.finish()
|
bar.finish()
|
||||||
|
|
|
@ -1,52 +1,47 @@
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
|
import datetime
|
||||||
import models as models
|
|
||||||
import numpy as np
|
|
||||||
import sklearn.metrics as metrics
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.utils.data.distributed
|
import torch.utils.data.distributed
|
||||||
from data import ModelNet40
|
|
||||||
from helper import cal_loss
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils import IOStream, progress_bar
|
import models as models
|
||||||
|
from utils import progress_bar, IOStream
|
||||||
|
from data import ModelNet40
|
||||||
|
import sklearn.metrics as metrics
|
||||||
|
from helper import cal_loss
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name]))
|
model_names = sorted(name for name in models.__dict__
|
||||||
|
if callable(models.__dict__[name]))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parameters"""
|
"""Parameters"""
|
||||||
parser = argparse.ArgumentParser("training")
|
parser = argparse.ArgumentParser('training')
|
||||||
parser.add_argument(
|
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
|
||||||
"-c",
|
help='path to save checkpoint (default: checkpoint)')
|
||||||
"--checkpoint",
|
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||||
type=str,
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
||||||
metavar="PATH",
|
parser.add_argument('--model', default='model31A', help='model name [default: pointnet_cls]')
|
||||||
help="path to save checkpoint (default: checkpoint)",
|
parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
|
||||||
)
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument("--msg", type=str, help="message after checkpoint")
|
parser.add_argument('--seed', type=int, help='random seed (default: 1)')
|
||||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
|
|
||||||
parser.add_argument("--model", default="model31A", help="model name [default: pointnet_cls]")
|
|
||||||
parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
|
|
||||||
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
|
|
||||||
parser.add_argument("--seed", type=int, help="random seed (default: 1)")
|
|
||||||
|
|
||||||
# Voting evaluation, referring: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
|
# Voting evaluation, referring: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
|
||||||
parser.add_argument("--NUM_PEPEAT", type=int, default=300)
|
parser.add_argument('--NUM_PEPEAT', type=int, default=300)
|
||||||
parser.add_argument("--NUM_VOTE", type=int, default=10)
|
parser.add_argument('--NUM_VOTE', type=int, default=10)
|
||||||
|
|
||||||
parser.add_argument("--validate", action="store_true", help="Validate the original testing result.")
|
parser.add_argument('--validate', action='store_true', help='Validate the original testing result.')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
class PointcloudScale: # input random scaling
|
class PointcloudScale(object): # input random scaling
|
||||||
def __init__(self, scale_low=2.0 / 3.0, scale_high=3.0 / 2.0):
|
def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
|
||||||
self.scale_low = scale_low
|
self.scale_low = scale_low
|
||||||
self.scale_high = scale_high
|
self.scale_high = scale_high
|
||||||
|
|
||||||
|
@ -73,52 +68,45 @@ def main():
|
||||||
torch.set_printoptions(10)
|
torch.set_printoptions(10)
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
os.environ["PYTHONHASHSEED"] = str(args.seed)
|
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = 'cuda'
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = 'cpu'
|
||||||
print(f"==> Using device: {device}")
|
print(f"==> Using device: {device}")
|
||||||
if args.msg is None:
|
if args.msg is None:
|
||||||
message = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
|
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
else:
|
else:
|
||||||
message = "-" + args.msg
|
message = "-" + args.msg
|
||||||
args.checkpoint = "checkpoints/" + args.model + message
|
args.checkpoint = 'checkpoints/' + args.model + message
|
||||||
|
|
||||||
print("==> Preparing data..")
|
print('==> Preparing data..')
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
||||||
ModelNet40(partition="test", num_points=args.num_points),
|
batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
|
||||||
num_workers=4,
|
|
||||||
batch_size=args.batch_size // 2,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
# Model
|
# Model
|
||||||
print("==> Building model..")
|
print('==> Building model..')
|
||||||
net = models.__dict__[args.model]()
|
net = models.__dict__[args.model]()
|
||||||
criterion = cal_loss
|
criterion = cal_loss
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
checkpoint_path = os.path.join(args.checkpoint, "best_checkpoint.pth")
|
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
# criterion = criterion.to(device)
|
# criterion = criterion.to(device)
|
||||||
if device == "cuda":
|
if device == 'cuda':
|
||||||
net = torch.nn.DataParallel(net)
|
net = torch.nn.DataParallel(net)
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
net.load_state_dict(checkpoint["net"])
|
net.load_state_dict(checkpoint['net'])
|
||||||
|
|
||||||
if args.validate:
|
if args.validate:
|
||||||
test_out = validate(net, test_loader, criterion, device)
|
test_out = validate(net, test_loader, criterion, device)
|
||||||
print(f"Vanilla out: {test_out}")
|
print(f"Vanilla out: {test_out}")
|
||||||
print(
|
print(f"Note 1: Please also load the random seed parameter (if forgot, see out.txt).\n"
|
||||||
"Note 1: Please also load the random seed parameter (if forgot, see out.txt).\n"
|
f"Note 2: This result may vary little on different GPUs (and number of GPUs), we tested 2080Ti, P100, and V100.\n"
|
||||||
"Note 2: This result may vary little on different GPUs (and number of GPUs), we tested 2080Ti, P100, and V100.\n"
|
f"[note : Original result is achieved with V100 GPUs.]\n\n\n")
|
||||||
"[note : Original result is achieved with V100 GPUs.]\n\n\n",
|
|
||||||
)
|
|
||||||
# Interestingly, we get original best_test_acc on 4 V100 gpus, but this model is trained on one V100 gpu.
|
# Interestingly, we get original best_test_acc on 4 V100 gpus, but this model is trained on one V100 gpu.
|
||||||
# On different GPUs, and different number of GPUs, both OA and mean_acc vary a little.
|
# On different GPUs, and different number of GPUs, both OA and mean_acc vary a little.
|
||||||
# Also, the batch size also affect the testing results, could not understand.
|
# Also, the batch size also affect the testing results, could not understand.
|
||||||
|
|
||||||
print("===> start voting evaluation...")
|
print(f"===> start voting evaluation...")
|
||||||
voting(net, test_loader, device, args)
|
voting(net, test_loader, device, args)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,28 +130,23 @@ def validate(net, testloader, criterion, device):
|
||||||
test_pred.append(preds.detach().cpu().numpy())
|
test_pred.append(preds.detach().cpu().numpy())
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(testloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
test_true = np.concatenate(test_true)
|
test_true = np.concatenate(test_true)
|
||||||
test_pred = np.concatenate(test_pred)
|
test_pred = np.concatenate(test_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def voting(net, testloader, device, args):
|
def voting(net, testloader, device, args):
|
||||||
name = (
|
name = '/evaluate_voting' + str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) + 'seed_' + str(
|
||||||
"/evaluate_voting" + str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S")) + "seed_" + str(args.seed) + ".log"
|
args.seed) + '.log'
|
||||||
)
|
|
||||||
io = IOStream(args.checkpoint + name)
|
io = IOStream(args.checkpoint + name)
|
||||||
io.cprint(str(args))
|
io.cprint(str(args))
|
||||||
|
|
||||||
|
@ -178,7 +161,7 @@ def voting(net, testloader, device, args):
|
||||||
test_true = []
|
test_true = []
|
||||||
test_pred = []
|
test_pred = []
|
||||||
|
|
||||||
for _batch_idx, (data, label) in enumerate(testloader):
|
for batch_idx, (data, label) in enumerate(testloader):
|
||||||
data, label = data.to(device), label.to(device).squeeze()
|
data, label = data.to(device), label.to(device).squeeze()
|
||||||
pred = 0
|
pred = 0
|
||||||
for v in range(args.NUM_VOTE):
|
for v in range(args.NUM_VOTE):
|
||||||
|
@ -195,24 +178,19 @@ def voting(net, testloader, device, args):
|
||||||
test_pred.append(pred_choice.detach().cpu().numpy())
|
test_pred.append(pred_choice.detach().cpu().numpy())
|
||||||
test_true = np.concatenate(test_true)
|
test_true = np.concatenate(test_true)
|
||||||
test_pred = np.concatenate(test_pred)
|
test_pred = np.concatenate(test_pred)
|
||||||
test_acc = 100.0 * metrics.accuracy_score(test_true, test_pred)
|
test_acc = 100. * metrics.accuracy_score(test_true, test_pred)
|
||||||
test_mean_acc = 100.0 * metrics.balanced_accuracy_score(test_true, test_pred)
|
test_mean_acc = 100. * metrics.balanced_accuracy_score(test_true, test_pred)
|
||||||
if test_acc > best_acc:
|
if test_acc > best_acc:
|
||||||
best_acc = test_acc
|
best_acc = test_acc
|
||||||
if test_mean_acc > best_mean_acc:
|
if test_mean_acc > best_mean_acc:
|
||||||
best_mean_acc = test_mean_acc
|
best_mean_acc = test_mean_acc
|
||||||
outstr = "Voting %d, test acc: %.3f, test mean acc: %.3f, [current best(all_acc: %.3f mean_acc: %.3f)]" % (
|
outstr = 'Voting %d, test acc: %.3f, test mean acc: %.3f, [current best(all_acc: %.3f mean_acc: %.3f)]' % \
|
||||||
i,
|
(i, test_acc, test_mean_acc, best_acc, best_mean_acc)
|
||||||
test_acc,
|
|
||||||
test_mean_acc,
|
|
||||||
best_acc,
|
|
||||||
best_mean_acc,
|
|
||||||
)
|
|
||||||
io.cprint(outstr)
|
io.cprint(outstr)
|
||||||
|
|
||||||
final_outstr = "Final voting test acc: %.6f," % (best_acc * 100)
|
final_outstr = 'Final voting test acc: %.6f,' % (best_acc * 100)
|
||||||
io.cprint(final_outstr)
|
io.cprint(final_outstr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
"""ScanObjectNN download: http://103.24.77.34/scanobjectnn/h5_files.zip."""
|
"""
|
||||||
|
ScanObjectNN download: http://103.24.77.34/scanobjectnn/h5_files.zip
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import glob
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
@ -11,18 +14,18 @@ os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
def download():
|
def download():
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||||
if not os.path.exists(DATA_DIR):
|
if not os.path.exists(DATA_DIR):
|
||||||
os.mkdir(DATA_DIR)
|
os.mkdir(DATA_DIR)
|
||||||
if not os.path.exists(os.path.join(DATA_DIR, "h5_files")):
|
if not os.path.exists(os.path.join(DATA_DIR, 'h5_files')):
|
||||||
# note that this link only contains the hardest perturbed variant (PB_T50_RS).
|
# note that this link only contains the hardest perturbed variant (PB_T50_RS).
|
||||||
# for full versions, consider the following link.
|
# for full versions, consider the following link.
|
||||||
www = "https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip"
|
www = 'https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip'
|
||||||
# www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
|
# www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
|
||||||
zipfile = os.path.basename(www)
|
zipfile = os.path.basename(www)
|
||||||
os.system(f"wget {www} --no-check-certificate; unzip {zipfile}")
|
os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
|
||||||
os.system(f"mv {zipfile[:-4]} {DATA_DIR}")
|
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
|
||||||
os.system("rm %s" % (zipfile))
|
os.system('rm %s' % (zipfile))
|
||||||
|
|
||||||
|
|
||||||
def load_scanobjectnn_data(partition):
|
def load_scanobjectnn_data(partition):
|
||||||
|
@ -31,10 +34,10 @@ def load_scanobjectnn_data(partition):
|
||||||
all_data = []
|
all_data = []
|
||||||
all_label = []
|
all_label = []
|
||||||
|
|
||||||
h5_name = BASE_DIR + "/data/h5_files/main_split/" + partition + "_objectdataset_augmentedrot_scale75.h5"
|
h5_name = BASE_DIR + '/data/h5_files/main_split/' + partition + '_objectdataset_augmentedrot_scale75.h5'
|
||||||
f = h5py.File(h5_name, mode="r")
|
f = h5py.File(h5_name, mode="r")
|
||||||
data = f["data"][:].astype("float32")
|
data = f['data'][:].astype('float32')
|
||||||
label = f["label"][:].astype("int64")
|
label = f['label'][:].astype('int64')
|
||||||
f.close()
|
f.close()
|
||||||
all_data.append(data)
|
all_data.append(data)
|
||||||
all_label.append(label)
|
all_label.append(label)
|
||||||
|
@ -44,22 +47,23 @@ def load_scanobjectnn_data(partition):
|
||||||
|
|
||||||
|
|
||||||
def translate_pointcloud(pointcloud):
|
def translate_pointcloud(pointcloud):
|
||||||
xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
|
xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
|
||||||
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
||||||
|
|
||||||
return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
|
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
|
||||||
|
return translated_pointcloud
|
||||||
|
|
||||||
|
|
||||||
class ScanObjectNN(Dataset):
|
class ScanObjectNN(Dataset):
|
||||||
def __init__(self, num_points, partition="training"):
|
def __init__(self, num_points, partition='training'):
|
||||||
self.data, self.label = load_scanobjectnn_data(partition)
|
self.data, self.label = load_scanobjectnn_data(partition)
|
||||||
self.num_points = num_points
|
self.num_points = num_points
|
||||||
self.partition = partition
|
self.partition = partition
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
pointcloud = self.data[item][: self.num_points]
|
pointcloud = self.data[item][:self.num_points]
|
||||||
label = self.label[item]
|
label = self.label[item]
|
||||||
if self.partition == "training":
|
if self.partition == 'training':
|
||||||
pointcloud = translate_pointcloud(pointcloud)
|
pointcloud = translate_pointcloud(pointcloud)
|
||||||
np.random.shuffle(pointcloud)
|
np.random.shuffle(pointcloud)
|
||||||
return pointcloud, label
|
return pointcloud, label
|
||||||
|
@ -68,9 +72,9 @@ class ScanObjectNN(Dataset):
|
||||||
return self.data.shape[0]
|
return self.data.shape[0]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train = ScanObjectNN(1024)
|
train = ScanObjectNN(1024)
|
||||||
test = ScanObjectNN(1024, "test")
|
test = ScanObjectNN(1024, 'test')
|
||||||
for data, label in train:
|
for data, label in train:
|
||||||
print(data.shape)
|
print(data.shape)
|
||||||
print(label)
|
print(label)
|
||||||
|
|
|
@ -1,50 +1,45 @@
|
||||||
"""for training with resume functions.
|
"""
|
||||||
|
for training with resume functions.
|
||||||
Usage:
|
Usage:
|
||||||
python main.py --model PointNet --msg demo
|
python main.py --model PointNet --msg demo
|
||||||
or
|
or
|
||||||
CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &.
|
CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
import models as models
|
import datetime
|
||||||
import numpy as np
|
|
||||||
import sklearn.metrics as metrics
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.optim
|
import torch.optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import torch.utils.data.distributed
|
import torch.utils.data.distributed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import models as models
|
||||||
|
from utils import Logger, mkdir_p, progress_bar, save_model, save_args, cal_loss
|
||||||
from ScanObjectNN import ScanObjectNN
|
from ScanObjectNN import ScanObjectNN
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from torch.utils.data import DataLoader
|
import sklearn.metrics as metrics
|
||||||
from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parameters"""
|
"""Parameters"""
|
||||||
parser = argparse.ArgumentParser("training")
|
parser = argparse.ArgumentParser('training')
|
||||||
parser.add_argument(
|
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
|
||||||
"-c",
|
help='path to save checkpoint (default: checkpoint)')
|
||||||
"--checkpoint",
|
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||||
type=str,
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
||||||
metavar="PATH",
|
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
|
||||||
help="path to save checkpoint (default: checkpoint)",
|
parser.add_argument('--num_classes', default=15, type=int, help='default value for classes of ScanObjectNN')
|
||||||
)
|
parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
|
||||||
parser.add_argument("--msg", type=str, help="message after checkpoint")
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
|
parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate in training')
|
||||||
parser.add_argument("--model", default="PointNet", help="model name [default: pointnet_cls]")
|
parser.add_argument('--weight_decay', type=float, default=1e-4, help='decay rate')
|
||||||
parser.add_argument("--num_classes", default=15, type=int, help="default value for classes of ScanObjectNN")
|
parser.add_argument('--smoothing', action='store_true', default=False, help='loss smoothing')
|
||||||
parser.add_argument("--epoch", default=200, type=int, help="number of epoch in training")
|
parser.add_argument('--seed', type=int, help='random seed')
|
||||||
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
|
parser.add_argument('--workers', default=4, type=int, help='workers')
|
||||||
parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate in training")
|
|
||||||
parser.add_argument("--weight_decay", type=float, default=1e-4, help="decay rate")
|
|
||||||
parser.add_argument("--smoothing", action="store_true", default=False, help="loss smoothing")
|
|
||||||
parser.add_argument("--seed", type=int, help="random seed")
|
|
||||||
parser.add_argument("--workers", default=4, type=int, help="workers")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,23 +49,23 @@ def main():
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = 'cuda'
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
torch.cuda.manual_seed(args.seed)
|
torch.cuda.manual_seed(args.seed)
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = 'cpu'
|
||||||
time_str = str(datetime.datetime.now().strftime("-%Y%m%d%H%M%S"))
|
time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
if args.msg is None:
|
if args.msg is None:
|
||||||
message = time_str
|
message = time_str
|
||||||
else:
|
else:
|
||||||
message = "-" + args.msg
|
message = "-" + args.msg
|
||||||
args.checkpoint = "checkpoints/" + args.model + message
|
args.checkpoint = 'checkpoints/' + args.model + message
|
||||||
if not os.path.isdir(args.checkpoint):
|
if not os.path.isdir(args.checkpoint):
|
||||||
mkdir_p(args.checkpoint)
|
mkdir_p(args.checkpoint)
|
||||||
|
|
||||||
screen_logger = logging.getLogger("Model")
|
screen_logger = logging.getLogger("Model")
|
||||||
screen_logger.setLevel(logging.INFO)
|
screen_logger.setLevel(logging.INFO)
|
||||||
formatter = logging.Formatter("%(message)s")
|
formatter = logging.Formatter('%(message)s')
|
||||||
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
|
file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
|
||||||
file_handler.setLevel(logging.INFO)
|
file_handler.setLevel(logging.INFO)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
@ -82,19 +77,19 @@ def main():
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
printf(f"args: {args}")
|
printf(f"args: {args}")
|
||||||
printf("==> Building model..")
|
printf('==> Building model..')
|
||||||
net = models.__dict__[args.model](num_classes=args.num_classes)
|
net = models.__dict__[args.model](num_classes=args.num_classes)
|
||||||
criterion = cal_loss
|
criterion = cal_loss
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
# criterion = criterion.to(device)
|
# criterion = criterion.to(device)
|
||||||
if device == "cuda":
|
if device == 'cuda':
|
||||||
net = torch.nn.DataParallel(net)
|
net = torch.nn.DataParallel(net)
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
best_test_acc = 0.0 # best test accuracy
|
best_test_acc = 0. # best test accuracy
|
||||||
best_train_acc = 0.0
|
best_train_acc = 0.
|
||||||
best_test_acc_avg = 0.0
|
best_test_acc_avg = 0.
|
||||||
best_train_acc_avg = 0.0
|
best_train_acc_avg = 0.
|
||||||
best_test_loss = float("inf")
|
best_test_loss = float("inf")
|
||||||
best_train_loss = float("inf")
|
best_train_loss = float("inf")
|
||||||
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
||||||
|
@ -102,49 +97,30 @@ def main():
|
||||||
|
|
||||||
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
||||||
save_args(args)
|
save_args(args)
|
||||||
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
|
||||||
logger.set_names(
|
logger.set_names(["Epoch-Num", 'Learning-Rate',
|
||||||
[
|
'Train-Loss', 'Train-acc-B', 'Train-acc',
|
||||||
"Epoch-Num",
|
'Valid-Loss', 'Valid-acc-B', 'Valid-acc'])
|
||||||
"Learning-Rate",
|
|
||||||
"Train-Loss",
|
|
||||||
"Train-acc-B",
|
|
||||||
"Train-acc",
|
|
||||||
"Valid-Loss",
|
|
||||||
"Valid-acc-B",
|
|
||||||
"Valid-acc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
printf(f"Resuming last checkpoint from {args.checkpoint}")
|
printf(f"Resuming last checkpoint from {args.checkpoint}")
|
||||||
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
|
checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
|
||||||
checkpoint = torch.load(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
net.load_state_dict(checkpoint["net"])
|
net.load_state_dict(checkpoint['net'])
|
||||||
start_epoch = checkpoint["epoch"]
|
start_epoch = checkpoint['epoch']
|
||||||
best_test_acc = checkpoint["best_test_acc"]
|
best_test_acc = checkpoint['best_test_acc']
|
||||||
best_train_acc = checkpoint["best_train_acc"]
|
best_train_acc = checkpoint['best_train_acc']
|
||||||
best_test_acc_avg = checkpoint["best_test_acc_avg"]
|
best_test_acc_avg = checkpoint['best_test_acc_avg']
|
||||||
best_train_acc_avg = checkpoint["best_train_acc_avg"]
|
best_train_acc_avg = checkpoint['best_train_acc_avg']
|
||||||
best_test_loss = checkpoint["best_test_loss"]
|
best_test_loss = checkpoint['best_test_loss']
|
||||||
best_train_loss = checkpoint["best_train_loss"]
|
best_train_loss = checkpoint['best_train_loss']
|
||||||
logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ModelNet" + args.model, resume=True)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
|
||||||
optimizer_dict = checkpoint["optimizer"]
|
optimizer_dict = checkpoint['optimizer']
|
||||||
|
|
||||||
printf("==> Preparing data..")
|
printf('==> Preparing data..')
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(ScanObjectNN(partition='training', num_points=args.num_points), num_workers=args.workers,
|
||||||
ScanObjectNN(partition="training", num_points=args.num_points),
|
batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||||
num_workers=args.workers,
|
test_loader = DataLoader(ScanObjectNN(partition='test', num_points=args.num_points), num_workers=args.workers,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, shuffle=True, drop_last=False)
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
test_loader = DataLoader(
|
|
||||||
ScanObjectNN(partition="test", num_points=args.num_points),
|
|
||||||
num_workers=args.workers,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
||||||
if optimizer_dict is not None:
|
if optimizer_dict is not None:
|
||||||
|
@ -152,7 +128,7 @@ def main():
|
||||||
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1)
|
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epoch):
|
for epoch in range(start_epoch, args.epoch):
|
||||||
printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"]))
|
printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
|
||||||
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
|
train_out = train(net, train_loader, optimizer, criterion, device) # {"loss", "acc", "acc_avg", "time"}
|
||||||
test_out = validate(net, test_loader, criterion, device)
|
test_out = validate(net, test_loader, criterion, device)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
@ -171,46 +147,31 @@ def main():
|
||||||
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
||||||
|
|
||||||
save_model(
|
save_model(
|
||||||
net,
|
net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
|
||||||
epoch,
|
|
||||||
path=args.checkpoint,
|
|
||||||
acc=test_out["acc"],
|
|
||||||
is_best=is_best,
|
|
||||||
best_test_acc=best_test_acc, # best test accuracy
|
best_test_acc=best_test_acc, # best test accuracy
|
||||||
best_train_acc=best_train_acc,
|
best_train_acc=best_train_acc,
|
||||||
best_test_acc_avg=best_test_acc_avg,
|
best_test_acc_avg=best_test_acc_avg,
|
||||||
best_train_acc_avg=best_train_acc_avg,
|
best_train_acc_avg=best_train_acc_avg,
|
||||||
best_test_loss=best_test_loss,
|
best_test_loss=best_test_loss,
|
||||||
best_train_loss=best_train_loss,
|
best_train_loss=best_train_loss,
|
||||||
optimizer=optimizer.state_dict(),
|
optimizer=optimizer.state_dict()
|
||||||
)
|
|
||||||
logger.append(
|
|
||||||
[
|
|
||||||
epoch,
|
|
||||||
optimizer.param_groups[0]["lr"],
|
|
||||||
train_out["loss"],
|
|
||||||
train_out["acc_avg"],
|
|
||||||
train_out["acc"],
|
|
||||||
test_out["loss"],
|
|
||||||
test_out["acc_avg"],
|
|
||||||
test_out["acc"],
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
logger.append([epoch, optimizer.param_groups[0]['lr'],
|
||||||
|
train_out["loss"], train_out["acc_avg"], train_out["acc"],
|
||||||
|
test_out["loss"], test_out["acc_avg"], test_out["acc"]])
|
||||||
printf(
|
printf(
|
||||||
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s",
|
f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% acc:{train_out['acc']}% time:{train_out['time']}s")
|
||||||
)
|
|
||||||
printf(
|
printf(
|
||||||
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
|
f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
|
||||||
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n",
|
f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%] \n\n")
|
||||||
)
|
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
printf("++++++++" * 2 + "Final results" + "++++++++" * 2)
|
printf(f"++++++++" * 2 + "Final results" + "++++++++" * 2)
|
||||||
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
|
printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
|
||||||
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
|
printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
|
||||||
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
|
printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
|
||||||
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
|
printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
|
||||||
printf("++++++++" * 5)
|
printf(f"++++++++" * 5)
|
||||||
|
|
||||||
|
|
||||||
def train(net, trainloader, optimizer, criterion, device):
|
def train(net, trainloader, optimizer, criterion, device):
|
||||||
|
@ -238,21 +199,17 @@ def train(net, trainloader, optimizer, criterion, device):
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
|
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(trainloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
train_true = np.concatenate(train_true)
|
train_true = np.concatenate(train_true)
|
||||||
train_pred = np.concatenate(train_pred)
|
train_pred = np.concatenate(train_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (train_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (train_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(train_true, train_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(train_true, train_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,23 +233,19 @@ def validate(net, testloader, criterion, device):
|
||||||
test_pred.append(preds.detach().cpu().numpy())
|
test_pred.append(preds.detach().cpu().numpy())
|
||||||
total += label.size(0)
|
total += label.size(0)
|
||||||
correct += preds.eq(label).sum().item()
|
correct += preds.eq(label).sum().item()
|
||||||
progress_bar(
|
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
batch_idx,
|
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||||
len(testloader),
|
|
||||||
"Loss: %.3f | Acc: %.3f%% (%d/%d)"
|
|
||||||
% (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
|
|
||||||
)
|
|
||||||
|
|
||||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||||
test_true = np.concatenate(test_true)
|
test_true = np.concatenate(test_true)
|
||||||
test_pred = np.concatenate(test_pred)
|
test_pred = np.concatenate(test_pred)
|
||||||
return {
|
return {
|
||||||
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
||||||
"acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
|
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
|
||||||
"acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
|
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
|
||||||
"time": time_cost,
|
"time": time_cost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1 +1,3 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from .pointmlp import pointMLP, pointMLPElite
|
from .pointmlp import pointMLP, pointMLPElite
|
||||||
|
|
|
@ -1,32 +1,35 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# from torch import einsum
|
# from torch import einsum
|
||||||
# from einops import rearrange, repeat
|
# from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
from pointnet2_ops import pointnet2_utils
|
from pointnet2_ops import pointnet2_utils
|
||||||
|
|
||||||
|
|
||||||
def get_activation(activation):
|
def get_activation(activation):
|
||||||
if activation.lower() == "gelu":
|
if activation.lower() == 'gelu':
|
||||||
return nn.GELU()
|
return nn.GELU()
|
||||||
elif activation.lower() == "rrelu":
|
elif activation.lower() == 'rrelu':
|
||||||
return nn.RReLU(inplace=True)
|
return nn.RReLU(inplace=True)
|
||||||
elif activation.lower() == "selu":
|
elif activation.lower() == 'selu':
|
||||||
return nn.SELU(inplace=True)
|
return nn.SELU(inplace=True)
|
||||||
elif activation.lower() == "silu":
|
elif activation.lower() == 'silu':
|
||||||
return nn.SiLU(inplace=True)
|
return nn.SiLU(inplace=True)
|
||||||
elif activation.lower() == "hardswish":
|
elif activation.lower() == 'hardswish':
|
||||||
return nn.Hardswish(inplace=True)
|
return nn.Hardswish(inplace=True)
|
||||||
elif activation.lower() == "leakyrelu":
|
elif activation.lower() == 'leakyrelu':
|
||||||
return nn.LeakyReLU(inplace=True)
|
return nn.LeakyReLU(inplace=True)
|
||||||
else:
|
else:
|
||||||
return nn.ReLU(inplace=True)
|
return nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
|
||||||
def square_distance(src, dst):
|
def square_distance(src, dst):
|
||||||
"""Calculate Euclid distance between each two points.
|
"""
|
||||||
src^T * dst = xn * xm + yn * ym + zn * zm;
|
Calculate Euclid distance between each two points.
|
||||||
|
src^T * dst = xn * xm + yn * ym + zn * zm;
|
||||||
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
||||||
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
||||||
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
||||||
|
@ -35,23 +38,23 @@ def square_distance(src, dst):
|
||||||
src: source points, [B, N, C]
|
src: source points, [B, N, C]
|
||||||
dst: target points, [B, M, C]
|
dst: target points, [B, M, C]
|
||||||
Output:
|
Output:
|
||||||
dist: per-point square distance, [B, N, M].
|
dist: per-point square distance, [B, N, M]
|
||||||
"""
|
"""
|
||||||
B, N, _ = src.shape
|
B, N, _ = src.shape
|
||||||
_, M, _ = dst.shape
|
_, M, _ = dst.shape
|
||||||
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
||||||
dist += torch.sum(src**2, -1).view(B, N, 1)
|
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
||||||
dist += torch.sum(dst**2, -1).view(B, 1, M)
|
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
def index_points(points, idx):
|
def index_points(points, idx):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
points: input points data, [B, N, C]
|
points: input points data, [B, N, C]
|
||||||
idx: sample index data, [B, S].
|
idx: sample index data, [B, S]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
new_points:, indexed points data, [B, S, C].
|
new_points:, indexed points data, [B, S, C]
|
||||||
"""
|
"""
|
||||||
device = points.device
|
device = points.device
|
||||||
B = points.shape[0]
|
B = points.shape[0]
|
||||||
|
@ -60,15 +63,17 @@ def index_points(points, idx):
|
||||||
repeat_shape = list(idx.shape)
|
repeat_shape = list(idx.shape)
|
||||||
repeat_shape[0] = 1
|
repeat_shape[0] = 1
|
||||||
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
||||||
return points[batch_indices, idx, :]
|
new_points = points[batch_indices, idx, :]
|
||||||
|
return new_points
|
||||||
|
|
||||||
|
|
||||||
def farthest_point_sample(xyz, npoint):
|
def farthest_point_sample(xyz, npoint):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
xyz: pointcloud data, [B, N, 3]
|
xyz: pointcloud data, [B, N, 3]
|
||||||
npoint: number of samples
|
npoint: number of samples
|
||||||
Return:
|
Return:
|
||||||
centroids: sampled pointcloud index, [B, npoint].
|
centroids: sampled pointcloud index, [B, npoint]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
|
@ -86,21 +91,21 @@ def farthest_point_sample(xyz, npoint):
|
||||||
|
|
||||||
|
|
||||||
def query_ball_point(radius, nsample, xyz, new_xyz):
|
def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
radius: local region radius
|
radius: local region radius
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, 3]
|
xyz: all points, [B, N, 3]
|
||||||
new_xyz: query points, [B, S, 3].
|
new_xyz: query points, [B, S, 3]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
_, S, _ = new_xyz.shape
|
_, S, _ = new_xyz.shape
|
||||||
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
group_idx[sqrdists > radius**2] = N
|
group_idx[sqrdists > radius ** 2] = N
|
||||||
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
||||||
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
||||||
mask = group_idx == N
|
mask = group_idx == N
|
||||||
|
@ -109,13 +114,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
|
|
||||||
|
|
||||||
def knn_point(nsample, xyz, new_xyz):
|
def knn_point(nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, C]
|
xyz: all points, [B, N, C]
|
||||||
new_xyz: query points, [B, S, C].
|
new_xyz: query points, [B, S, C]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
||||||
|
@ -124,12 +129,13 @@ def knn_point(nsample, xyz, new_xyz):
|
||||||
|
|
||||||
class LocalGrouper(nn.Module):
|
class LocalGrouper(nn.Module):
|
||||||
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
|
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
|
||||||
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
"""
|
||||||
|
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
||||||
:param groups: groups number
|
:param groups: groups number
|
||||||
:param kneighbors: k-nerighbors
|
:param kneighbors: k-nerighbors
|
||||||
:param kwargs: others.
|
:param kwargs: others
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(LocalGrouper, self).__init__()
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
self.kneighbors = kneighbors
|
self.kneighbors = kneighbors
|
||||||
self.use_xyz = use_xyz
|
self.use_xyz = use_xyz
|
||||||
|
@ -138,11 +144,11 @@ class LocalGrouper(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize not in ["center", "anchor"]:
|
if self.normalize not in ["center", "anchor"]:
|
||||||
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
add_channel = 3 if self.use_xyz else 0
|
add_channel=3 if self.use_xyz else 0
|
||||||
self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
|
self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel]))
|
||||||
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
||||||
|
|
||||||
def forward(self, xyz, points):
|
def forward(self, xyz, points):
|
||||||
|
@ -161,33 +167,29 @@ class LocalGrouper(nn.Module):
|
||||||
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
||||||
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
||||||
if self.use_xyz:
|
if self.use_xyz:
|
||||||
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) # [B, npoint, k, d+3]
|
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3]
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
if self.normalize == "center":
|
if self.normalize =="center":
|
||||||
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
||||||
if self.normalize == "anchor":
|
if self.normalize =="anchor":
|
||||||
mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
|
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
|
||||||
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
||||||
std = (
|
std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
|
||||||
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
|
grouped_points = (grouped_points-mean)/(std + 1e-5)
|
||||||
.unsqueeze(dim=-1)
|
grouped_points = self.affine_alpha*grouped_points + self.affine_beta
|
||||||
.unsqueeze(dim=-1)
|
|
||||||
)
|
|
||||||
grouped_points = (grouped_points - mean) / (std + 1e-5)
|
|
||||||
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
|
|
||||||
|
|
||||||
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
||||||
return new_xyz, new_points
|
return new_xyz, new_points
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLU1D(nn.Module):
|
class ConvBNReLU1D(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
|
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLU1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.BatchNorm1d(out_channels),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -195,43 +197,30 @@ class ConvBNReLU1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLURes1D(nn.Module):
|
class ConvBNReLURes1D(nn.Module):
|
||||||
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
|
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLURes1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net1 = nn.Sequential(
|
self.net1 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
|
||||||
in_channels=channel,
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=int(channel * res_expansion),
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(int(channel * res_expansion)),
|
nn.BatchNorm1d(int(channel * res_expansion)),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
if groups > 1:
|
if groups > 1:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=channel,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
self.act,
|
self.act,
|
||||||
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=channel, out_channels=channel,
|
||||||
|
kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, bias=bias),
|
||||||
out_channels=channel,
|
nn.BatchNorm1d(channel)
|
||||||
kernel_size=kernel_size,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -239,34 +228,21 @@ class ConvBNReLURes1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PreExtraction(nn.Module):
|
class PreExtraction(nn.Module):
|
||||||
def __init__(
|
def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
|
||||||
self,
|
activation='relu', use_xyz=True):
|
||||||
channels,
|
"""
|
||||||
out_channels,
|
input: [b,g,k,d]: output:[b,d,g]
|
||||||
blocks=1,
|
|
||||||
groups=1,
|
|
||||||
res_expansion=1,
|
|
||||||
bias=True,
|
|
||||||
activation="relu",
|
|
||||||
use_xyz=True,
|
|
||||||
):
|
|
||||||
"""input: [b,g,k,d]: output:[b,d,g]
|
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PreExtraction, self).__init__()
|
||||||
in_channels = 3 + 2 * channels if use_xyz else 2 * channels
|
in_channels = 3+2*channels if use_xyz else 2*channels
|
||||||
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(
|
ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
|
||||||
out_channels,
|
bias=bias, activation=activation)
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -278,20 +254,22 @@ class PreExtraction(nn.Module):
|
||||||
batch_size, _, _ = x.size()
|
batch_size, _, _ = x.size()
|
||||||
x = self.operation(x) # [b, d, k]
|
x = self.operation(x) # [b, d, k]
|
||||||
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
||||||
return x.reshape(b, n, -1).permute(0, 2, 1)
|
x = x.reshape(b, n, -1).permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class PosExtraction(nn.Module):
|
class PosExtraction(nn.Module):
|
||||||
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
|
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
|
||||||
"""input[b,d,g]; output[b,d,g]
|
"""
|
||||||
|
input[b,d,g]; output[b,d,g]
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PosExtraction, self).__init__()
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
|
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -300,32 +278,17 @@ class PosExtraction(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(
|
def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
self,
|
activation="relu", bias=True, use_xyz=True, normalize="center",
|
||||||
points=1024,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
class_num=40,
|
k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs):
|
||||||
embed_dim=64,
|
super(Model, self).__init__()
|
||||||
groups=1,
|
|
||||||
res_expansion=1.0,
|
|
||||||
activation="relu",
|
|
||||||
bias=True,
|
|
||||||
use_xyz=True,
|
|
||||||
normalize="center",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[32, 32, 32, 32],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.stages = len(pre_blocks)
|
self.stages = len(pre_blocks)
|
||||||
self.class_num = class_num
|
self.class_num = class_num
|
||||||
self.points = points
|
self.points = points
|
||||||
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
|
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
|
||||||
assert (
|
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
|
||||||
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
|
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
||||||
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
|
||||||
self.local_grouper_list = nn.ModuleList()
|
self.local_grouper_list = nn.ModuleList()
|
||||||
self.pre_blocks_list = nn.ModuleList()
|
self.pre_blocks_list = nn.ModuleList()
|
||||||
self.pos_blocks_list = nn.ModuleList()
|
self.pos_blocks_list = nn.ModuleList()
|
||||||
|
@ -342,26 +305,13 @@ class Model(nn.Module):
|
||||||
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
||||||
self.local_grouper_list.append(local_grouper)
|
self.local_grouper_list.append(local_grouper)
|
||||||
# append pre_block_list
|
# append pre_block_list
|
||||||
pre_block_module = PreExtraction(
|
pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
|
||||||
last_channel,
|
res_expansion=res_expansion,
|
||||||
out_channel,
|
bias=bias, activation=activation, use_xyz=use_xyz)
|
||||||
pre_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
use_xyz=use_xyz,
|
|
||||||
)
|
|
||||||
self.pre_blocks_list.append(pre_block_module)
|
self.pre_blocks_list.append(pre_block_module)
|
||||||
# append pos_block_list
|
# append pos_block_list
|
||||||
pos_block_module = PosExtraction(
|
pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
|
||||||
out_channel,
|
res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
pos_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
self.pos_blocks_list.append(pos_block_module)
|
self.pos_blocks_list.append(pos_block_module)
|
||||||
|
|
||||||
last_channel = out_channel
|
last_channel = out_channel
|
||||||
|
@ -376,7 +326,7 @@ class Model(nn.Module):
|
||||||
nn.BatchNorm1d(256),
|
nn.BatchNorm1d(256),
|
||||||
self.act,
|
self.act,
|
||||||
nn.Dropout(0.5),
|
nn.Dropout(0.5),
|
||||||
nn.Linear(256, self.class_num),
|
nn.Linear(256, self.class_num)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -390,52 +340,29 @@ class Model(nn.Module):
|
||||||
x = self.pos_blocks_list[i](x) # [b,d,g]
|
x = self.pos_blocks_list[i](x) # [b,d,g]
|
||||||
|
|
||||||
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
|
x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
|
||||||
return self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def pointMLP(num_classes=40, **kwargs) -> Model:
|
def pointMLP(num_classes=40, **kwargs) -> Model:
|
||||||
return Model(
|
return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
points=1024,
|
activation="relu", bias=False, use_xyz=False, normalize="anchor",
|
||||||
class_num=num_classes,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
embed_dim=64,
|
k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)
|
||||||
groups=1,
|
|
||||||
res_expansion=1.0,
|
|
||||||
activation="relu",
|
|
||||||
bias=False,
|
|
||||||
use_xyz=False,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[24, 24, 24, 24],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
||||||
return Model(
|
return Model(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25,
|
||||||
points=1024,
|
activation="relu", bias=False, use_xyz=False, normalize="anchor",
|
||||||
class_num=num_classes,
|
dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
|
||||||
embed_dim=32,
|
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
|
||||||
groups=1,
|
|
||||||
res_expansion=0.25,
|
|
||||||
activation="relu",
|
|
||||||
bias=False,
|
|
||||||
use_xyz=False,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 1],
|
|
||||||
pre_blocks=[1, 1, 2, 1],
|
|
||||||
pos_blocks=[1, 1, 2, 1],
|
|
||||||
k_neighbors=[24, 24, 24, 24],
|
|
||||||
reducers=[2, 2, 2, 2],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
if __name__ == "__main__":
|
|
||||||
data = torch.rand(2, 3, 1024)
|
data = torch.rand(2, 3, 1024)
|
||||||
print("===> testing pointMLP ...")
|
print("===> testing pointMLP ...")
|
||||||
model = pointMLP()
|
model = pointMLP()
|
||||||
out = model(data)
|
out = model(data)
|
||||||
print(out.shape)
|
print(out.shape)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Useful utils."""
|
"""Useful utils
|
||||||
from .logger import *
|
"""
|
||||||
from .misc import *
|
from .misc import *
|
||||||
|
from .logger import *
|
||||||
from .progress.progress.bar import Bar as Bar
|
from .progress.progress.bar import Bar as Bar
|
||||||
|
|
|
@ -1,50 +1,48 @@
|
||||||
# A simple torch style logger
|
# A simple torch style logger
|
||||||
# (C) Wei YANG 2017
|
# (C) Wei YANG 2017
|
||||||
|
from __future__ import absolute_import
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
__all__ = ["Logger", "LoggerMonitor", "savefig"]
|
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
|
||||||
|
|
||||||
|
|
||||||
def savefig(fname, dpi=None):
|
def savefig(fname, dpi=None):
|
||||||
dpi = 150 if dpi is None else dpi
|
dpi = 150 if dpi == None else dpi
|
||||||
plt.savefig(fname, dpi=dpi)
|
plt.savefig(fname, dpi=dpi)
|
||||||
|
|
||||||
|
|
||||||
def plot_overlap(logger, names=None):
|
def plot_overlap(logger, names=None):
|
||||||
names = logger.names if names is None else names
|
names = logger.names if names == None else names
|
||||||
numbers = logger.numbers
|
numbers = logger.numbers
|
||||||
for _, name in enumerate(names):
|
for _, name in enumerate(names):
|
||||||
x = np.arange(len(numbers[name]))
|
x = np.arange(len(numbers[name]))
|
||||||
plt.plot(x, np.asarray(numbers[name]))
|
plt.plot(x, np.asarray(numbers[name]))
|
||||||
return [logger.title + "(" + name + ")" for name in names]
|
return [logger.title + '(' + name + ')' for name in names]
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
|
||||||
"""Save training process to log file with simple plot function."""
|
|
||||||
|
|
||||||
|
class Logger(object):
|
||||||
|
'''Save training process to log file with simple plot function.'''
|
||||||
def __init__(self, fpath, title=None, resume=False):
|
def __init__(self, fpath, title=None, resume=False):
|
||||||
self.file = None
|
self.file = None
|
||||||
self.resume = resume
|
self.resume = resume
|
||||||
self.title = "" if title is None else title
|
self.title = '' if title == None else title
|
||||||
if fpath is not None:
|
if fpath is not None:
|
||||||
if resume:
|
if resume:
|
||||||
self.file = open(fpath)
|
self.file = open(fpath, 'r')
|
||||||
name = self.file.readline()
|
name = self.file.readline()
|
||||||
self.names = name.rstrip().split("\t")
|
self.names = name.rstrip().split('\t')
|
||||||
self.numbers = {}
|
self.numbers = {}
|
||||||
for _, name in enumerate(self.names):
|
for _, name in enumerate(self.names):
|
||||||
self.numbers[name] = []
|
self.numbers[name] = []
|
||||||
|
|
||||||
for numbers in self.file:
|
for numbers in self.file:
|
||||||
numbers = numbers.rstrip().split("\t")
|
numbers = numbers.rstrip().split('\t')
|
||||||
for i in range(0, len(numbers)):
|
for i in range(0, len(numbers)):
|
||||||
self.numbers[self.names[i]].append(numbers[i])
|
self.numbers[self.names[i]].append(numbers[i])
|
||||||
self.file.close()
|
self.file.close()
|
||||||
self.file = open(fpath, "a")
|
self.file = open(fpath, 'a')
|
||||||
else:
|
else:
|
||||||
self.file = open(fpath, "w")
|
self.file = open(fpath, 'w')
|
||||||
|
|
||||||
def set_names(self, names):
|
def set_names(self, names):
|
||||||
if self.resume:
|
if self.resume:
|
||||||
|
@ -54,39 +52,38 @@ class Logger:
|
||||||
self.names = names
|
self.names = names
|
||||||
for _, name in enumerate(self.names):
|
for _, name in enumerate(self.names):
|
||||||
self.file.write(name)
|
self.file.write(name)
|
||||||
self.file.write("\t")
|
self.file.write('\t')
|
||||||
self.numbers[name] = []
|
self.numbers[name] = []
|
||||||
self.file.write("\n")
|
self.file.write('\n')
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
|
|
||||||
def append(self, numbers):
|
def append(self, numbers):
|
||||||
assert len(self.names) == len(numbers), "Numbers do not match names"
|
assert len(self.names) == len(numbers), 'Numbers do not match names'
|
||||||
for index, num in enumerate(numbers):
|
for index, num in enumerate(numbers):
|
||||||
self.file.write(f"{num:.6f}")
|
self.file.write("{0:.6f}".format(num))
|
||||||
self.file.write("\t")
|
self.file.write('\t')
|
||||||
self.numbers[self.names[index]].append(num)
|
self.numbers[self.names[index]].append(num)
|
||||||
self.file.write("\n")
|
self.file.write('\n')
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def plot(self, names=None):
|
def plot(self, names=None):
|
||||||
names = self.names if names is None else names
|
names = self.names if names == None else names
|
||||||
numbers = self.numbers
|
numbers = self.numbers
|
||||||
for _, name in enumerate(names):
|
for _, name in enumerate(names):
|
||||||
x = np.arange(len(numbers[name]))
|
x = np.arange(len(numbers[name]))
|
||||||
plt.plot(x, np.asarray(numbers[name]))
|
plt.plot(x, np.asarray(numbers[name]))
|
||||||
plt.legend([self.title + "(" + name + ")" for name in names])
|
plt.legend([self.title + '(' + name + ')' for name in names])
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.file is not None:
|
if self.file is not None:
|
||||||
self.file.close()
|
self.file.close()
|
||||||
|
|
||||||
|
class LoggerMonitor(object):
|
||||||
class LoggerMonitor:
|
'''Load and visualize multiple logs.'''
|
||||||
"""Load and visualize multiple logs."""
|
def __init__ (self, paths):
|
||||||
|
'''paths is a distionary with {name:filepath} pair'''
|
||||||
def __init__(self, paths):
|
|
||||||
"""Paths is a distionary with {name:filepath} pair."""
|
|
||||||
self.loggers = []
|
self.loggers = []
|
||||||
for title, path in paths.items():
|
for title, path in paths.items():
|
||||||
logger = Logger(path, title=title, resume=True)
|
logger = Logger(path, title=title, resume=True)
|
||||||
|
@ -98,11 +95,10 @@ class LoggerMonitor:
|
||||||
legend_text = []
|
legend_text = []
|
||||||
for logger in self.loggers:
|
for logger in self.loggers:
|
||||||
legend_text += plot_overlap(logger, names)
|
legend_text += plot_overlap(logger, names)
|
||||||
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
|
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
if __name__ == "__main__":
|
|
||||||
# # Example
|
# # Example
|
||||||
# logger = Logger('test.txt')
|
# logger = Logger('test.txt')
|
||||||
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
|
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
|
||||||
|
@ -119,13 +115,13 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Example: logger monitor
|
# Example: logger monitor
|
||||||
paths = {
|
paths = {
|
||||||
"resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
|
'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
|
||||||
"resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
|
'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
|
||||||
"resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
|
'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
|
||||||
}
|
}
|
||||||
|
|
||||||
field = ["Valid Acc."]
|
field = ['Valid Acc.']
|
||||||
|
|
||||||
monitor = LoggerMonitor(paths)
|
monitor = LoggerMonitor(paths)
|
||||||
monitor.plot(names=field)
|
monitor.plot(names=field)
|
||||||
savefig("test.eps")
|
savefig('test.eps')
|
|
@ -1,56 +1,48 @@
|
||||||
"""Some helper functions for PyTorch, including:
|
'''Some helper functions for PyTorch, including:
|
||||||
- get_mean_and_std: calculate the mean and std value of dataset.
|
- get_mean_and_std: calculate the mean and std value of dataset.
|
||||||
- msr_init: net parameter initialization.
|
- msr_init: net parameter initialization.
|
||||||
- progress_bar: progress bar mimic xlua.progress.
|
- progress_bar: progress bar mimic xlua.progress.
|
||||||
"""
|
'''
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import shutil
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import shutil
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"get_mean_and_std",
|
import torch.nn as nn
|
||||||
"init_params",
|
import torch.nn.init as init
|
||||||
"mkdir_p",
|
from torch.autograd import Variable
|
||||||
"AverageMeter",
|
|
||||||
"progress_bar",
|
__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter',
|
||||||
"save_model",
|
'progress_bar','save_model',"save_args","set_seed", "IOStream", "cal_loss"]
|
||||||
"save_args",
|
|
||||||
"set_seed",
|
|
||||||
"IOStream",
|
|
||||||
"cal_loss",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_std(dataset):
|
def get_mean_and_std(dataset):
|
||||||
"""Compute the mean and std value of dataset."""
|
'''Compute the mean and std value of dataset.'''
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
|
dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
|
||||||
|
|
||||||
mean = torch.zeros(3)
|
mean = torch.zeros(3)
|
||||||
std = torch.zeros(3)
|
std = torch.zeros(3)
|
||||||
print("==> Computing mean and std..")
|
print('==> Computing mean and std..')
|
||||||
for inputs, _targets in dataloader:
|
for inputs, targets in dataloader:
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
mean[i] += inputs[:, i, :, :].mean()
|
mean[i] += inputs[:,i,:,:].mean()
|
||||||
std[i] += inputs[:, i, :, :].std()
|
std[i] += inputs[:,i,:,:].std()
|
||||||
mean.div_(len(dataset))
|
mean.div_(len(dataset))
|
||||||
std.div_(len(dataset))
|
std.div_(len(dataset))
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
def init_params(net):
|
def init_params(net):
|
||||||
"""Init layer parameters."""
|
'''Init layer parameters.'''
|
||||||
for m in net.modules():
|
for m in net.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
init.kaiming_normal(m.weight, mode="fan_out")
|
init.kaiming_normal(m.weight, mode='fan_out')
|
||||||
if m.bias:
|
if m.bias:
|
||||||
init.constant(m.bias, 0)
|
init.constant(m.bias, 0)
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
@ -61,9 +53,8 @@ def init_params(net):
|
||||||
if m.bias:
|
if m.bias:
|
||||||
init.constant(m.bias, 0)
|
init.constant(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
def mkdir_p(path):
|
def mkdir_p(path):
|
||||||
"""Make dir if not exist."""
|
'''make dir if not exist'''
|
||||||
try:
|
try:
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
except OSError as exc: # Python >2.5
|
except OSError as exc: # Python >2.5
|
||||||
|
@ -72,12 +63,10 @@ def mkdir_p(path):
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
class AverageMeter(object):
|
||||||
class AverageMeter:
|
|
||||||
"""Computes and stores the average and current value
|
"""Computes and stores the average and current value
|
||||||
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.
|
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
@ -94,26 +83,25 @@ class AverageMeter:
|
||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
TOTAL_BAR_LENGTH = 65.0
|
|
||||||
|
TOTAL_BAR_LENGTH = 65.
|
||||||
last_time = time.time()
|
last_time = time.time()
|
||||||
begin_time = last_time
|
begin_time = last_time
|
||||||
|
|
||||||
|
|
||||||
def progress_bar(current, total, msg=None):
|
def progress_bar(current, total, msg=None):
|
||||||
global last_time, begin_time
|
global last_time, begin_time
|
||||||
if current == 0:
|
if current == 0:
|
||||||
begin_time = time.time() # Reset for new bar.
|
begin_time = time.time() # Reset for new bar.
|
||||||
|
|
||||||
cur_len = int(TOTAL_BAR_LENGTH * current / total)
|
cur_len = int(TOTAL_BAR_LENGTH*current/total)
|
||||||
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
||||||
|
|
||||||
sys.stdout.write(" [")
|
sys.stdout.write(' [')
|
||||||
for _i in range(cur_len):
|
for i in range(cur_len):
|
||||||
sys.stdout.write("=")
|
sys.stdout.write('=')
|
||||||
sys.stdout.write(">")
|
sys.stdout.write('>')
|
||||||
for _i in range(rest_len):
|
for i in range(rest_len):
|
||||||
sys.stdout.write(".")
|
sys.stdout.write('.')
|
||||||
sys.stdout.write("]")
|
sys.stdout.write(']')
|
||||||
|
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
step_time = cur_time - last_time
|
step_time = cur_time - last_time
|
||||||
|
@ -121,12 +109,12 @@ def progress_bar(current, total, msg=None):
|
||||||
tot_time = cur_time - begin_time
|
tot_time = cur_time - begin_time
|
||||||
|
|
||||||
L = []
|
L = []
|
||||||
L.append(" Step: %s" % format_time(step_time))
|
L.append(' Step: %s' % format_time(step_time))
|
||||||
L.append(" | Tot: %s" % format_time(tot_time))
|
L.append(' | Tot: %s' % format_time(tot_time))
|
||||||
if msg:
|
if msg:
|
||||||
L.append(" | " + msg)
|
L.append(' | ' + msg)
|
||||||
|
|
||||||
msg = "".join(L)
|
msg = ''.join(L)
|
||||||
sys.stdout.write(msg)
|
sys.stdout.write(msg)
|
||||||
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
|
# for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
|
||||||
# sys.stdout.write(' ')
|
# sys.stdout.write(' ')
|
||||||
|
@ -134,74 +122,76 @@ def progress_bar(current, total, msg=None):
|
||||||
# Go back to the center of the bar.
|
# Go back to the center of the bar.
|
||||||
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
|
# for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
|
||||||
# sys.stdout.write('\b')
|
# sys.stdout.write('\b')
|
||||||
sys.stdout.write(" %d/%d " % (current + 1, total))
|
sys.stdout.write(' %d/%d ' % (current+1, total))
|
||||||
|
|
||||||
if current < total - 1:
|
if current < total-1:
|
||||||
sys.stdout.write("\r")
|
sys.stdout.write('\r')
|
||||||
else:
|
else:
|
||||||
sys.stdout.write("\n")
|
sys.stdout.write('\n')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds):
|
def format_time(seconds):
|
||||||
days = int(seconds / 3600 / 24)
|
days = int(seconds / 3600/24)
|
||||||
seconds = seconds - days * 3600 * 24
|
seconds = seconds - days*3600*24
|
||||||
hours = int(seconds / 3600)
|
hours = int(seconds / 3600)
|
||||||
seconds = seconds - hours * 3600
|
seconds = seconds - hours*3600
|
||||||
minutes = int(seconds / 60)
|
minutes = int(seconds / 60)
|
||||||
seconds = seconds - minutes * 60
|
seconds = seconds - minutes*60
|
||||||
secondsf = int(seconds)
|
secondsf = int(seconds)
|
||||||
seconds = seconds - secondsf
|
seconds = seconds - secondsf
|
||||||
millis = int(seconds * 1000)
|
millis = int(seconds*1000)
|
||||||
|
|
||||||
f = ""
|
f = ''
|
||||||
i = 1
|
i = 1
|
||||||
if days > 0:
|
if days > 0:
|
||||||
f += str(days) + "D"
|
f += str(days) + 'D'
|
||||||
i += 1
|
i += 1
|
||||||
if hours > 0 and i <= 2:
|
if hours > 0 and i <= 2:
|
||||||
f += str(hours) + "h"
|
f += str(hours) + 'h'
|
||||||
i += 1
|
i += 1
|
||||||
if minutes > 0 and i <= 2:
|
if minutes > 0 and i <= 2:
|
||||||
f += str(minutes) + "m"
|
f += str(minutes) + 'm'
|
||||||
i += 1
|
i += 1
|
||||||
if secondsf > 0 and i <= 2:
|
if secondsf > 0 and i <= 2:
|
||||||
f += str(secondsf) + "s"
|
f += str(secondsf) + 's'
|
||||||
i += 1
|
i += 1
|
||||||
if millis > 0 and i <= 2:
|
if millis > 0 and i <= 2:
|
||||||
f += str(millis) + "ms"
|
f += str(millis) + 'ms'
|
||||||
i += 1
|
i += 1
|
||||||
if f == "":
|
if f == '':
|
||||||
f = "0ms"
|
f = '0ms'
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def save_model(net, epoch, path, acc, is_best, **kwargs):
|
def save_model(net, epoch, path, acc, is_best, **kwargs):
|
||||||
state = {
|
state = {
|
||||||
"net": net.state_dict(),
|
'net': net.state_dict(),
|
||||||
"epoch": epoch,
|
'epoch': epoch,
|
||||||
"acc": acc,
|
'acc': acc
|
||||||
}
|
}
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
state[key] = value
|
state[key] = value
|
||||||
filepath = os.path.join(path, "last_checkpoint.pth")
|
filepath = os.path.join(path, "last_checkpoint.pth")
|
||||||
torch.save(state, filepath)
|
torch.save(state, filepath)
|
||||||
if is_best:
|
if is_best:
|
||||||
shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth"))
|
shutil.copyfile(filepath, os.path.join(path, 'best_checkpoint.pth'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_args(args):
|
def save_args(args):
|
||||||
file = open(os.path.join(args.checkpoint, "args.txt"), "w")
|
file = open(os.path.join(args.checkpoint, 'args.txt'), "w")
|
||||||
for k, v in vars(args).items():
|
for k, v in vars(args).items():
|
||||||
file.write(f"{k}:\t {v}\n")
|
file.write(f"{k}:\t {v}\n")
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed=None):
|
def set_seed(seed=None):
|
||||||
if seed is None:
|
if seed is None:
|
||||||
return
|
return
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
os.environ["PYTHONHASHSEED"] = "%s" % seed
|
os.environ['PYTHONHASHSEED'] = ("%s" % seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
@ -210,14 +200,15 @@ def set_seed(seed=None):
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# create a file and write the text into it
|
# create a file and write the text into it
|
||||||
class IOStream:
|
class IOStream():
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.f = open(path, "a")
|
self.f = open(path, 'a')
|
||||||
|
|
||||||
def cprint(self, text):
|
def cprint(self, text):
|
||||||
print(text)
|
print(text)
|
||||||
self.f.write(text + "\n")
|
self.f.write(text+'\n')
|
||||||
self.f.flush()
|
self.f.flush()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -225,7 +216,8 @@ class IOStream:
|
||||||
|
|
||||||
|
|
||||||
def cal_loss(pred, gold, smoothing=True):
|
def cal_loss(pred, gold, smoothing=True):
|
||||||
"""Calculate cross entropy loss, apply label smoothing if needed."""
|
''' Calculate cross entropy loss, apply label smoothing if needed. '''
|
||||||
|
|
||||||
gold = gold.contiguous().view(-1)
|
gold = gold.contiguous().view(-1)
|
||||||
|
|
||||||
if smoothing:
|
if smoothing:
|
||||||
|
@ -238,6 +230,6 @@ def cal_loss(pred, gold, smoothing=True):
|
||||||
|
|
||||||
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
||||||
else:
|
else:
|
||||||
loss = F.cross_entropy(pred, gold, reduction="mean")
|
loss = F.cross_entropy(pred, gold, reduction='mean')
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
@ -19,12 +20,13 @@ from math import ceil
|
||||||
from sys import stderr
|
from sys import stderr
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
__version__ = "1.3"
|
|
||||||
|
__version__ = '1.3'
|
||||||
|
|
||||||
|
|
||||||
class Infinite:
|
class Infinite(object):
|
||||||
file = stderr
|
file = stderr
|
||||||
sma_window = 10 # Simple Moving Average window
|
sma_window = 10 # Simple Moving Average window
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.index = 0
|
self.index = 0
|
||||||
|
@ -36,7 +38,7 @@ class Infinite:
|
||||||
setattr(self, key, val)
|
setattr(self, key, val)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if key.startswith("_"):
|
if key.startswith('_'):
|
||||||
return None
|
return None
|
||||||
return getattr(self, key, None)
|
return getattr(self, key, None)
|
||||||
|
|
||||||
|
@ -81,8 +83,8 @@ class Infinite:
|
||||||
|
|
||||||
class Progress(Infinite):
|
class Progress(Infinite):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super(Progress, self).__init__(*args, **kwargs)
|
||||||
self.max = kwargs.get("max", 100)
|
self.max = kwargs.get('max', 100)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eta(self):
|
def eta(self):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,18 +14,19 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Progress
|
from . import Progress
|
||||||
from .helpers import WritelnMixin
|
from .helpers import WritelnMixin
|
||||||
|
|
||||||
|
|
||||||
class Bar(WritelnMixin, Progress):
|
class Bar(WritelnMixin, Progress):
|
||||||
width = 32
|
width = 32
|
||||||
message = ""
|
message = ''
|
||||||
suffix = "%(index)d/%(max)d"
|
suffix = '%(index)d/%(max)d'
|
||||||
bar_prefix = " |"
|
bar_prefix = ' |'
|
||||||
bar_suffix = "| "
|
bar_suffix = '| '
|
||||||
empty_fill = " "
|
empty_fill = ' '
|
||||||
fill = "#"
|
fill = '#'
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -34,50 +37,52 @@ class Bar(WritelnMixin, Progress):
|
||||||
bar = self.fill * filled_length
|
bar = self.fill * filled_length
|
||||||
empty = self.empty_fill * empty_length
|
empty = self.empty_fill * empty_length
|
||||||
suffix = self.suffix % self
|
suffix = self.suffix % self
|
||||||
line = "".join([message, self.bar_prefix, bar, empty, self.bar_suffix, suffix])
|
line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
|
||||||
|
suffix])
|
||||||
self.writeln(line)
|
self.writeln(line)
|
||||||
|
|
||||||
|
|
||||||
class ChargingBar(Bar):
|
class ChargingBar(Bar):
|
||||||
suffix = "%(percent)d%%"
|
suffix = '%(percent)d%%'
|
||||||
bar_prefix = " "
|
bar_prefix = ' '
|
||||||
bar_suffix = " "
|
bar_suffix = ' '
|
||||||
empty_fill = "∙"
|
empty_fill = '∙'
|
||||||
fill = "█"
|
fill = '█'
|
||||||
|
|
||||||
|
|
||||||
class FillingSquaresBar(ChargingBar):
|
class FillingSquaresBar(ChargingBar):
|
||||||
empty_fill = "▢"
|
empty_fill = '▢'
|
||||||
fill = "▣"
|
fill = '▣'
|
||||||
|
|
||||||
|
|
||||||
class FillingCirclesBar(ChargingBar):
|
class FillingCirclesBar(ChargingBar):
|
||||||
empty_fill = "◯"
|
empty_fill = '◯'
|
||||||
fill = "◉"
|
fill = '◉'
|
||||||
|
|
||||||
|
|
||||||
class IncrementalBar(Bar):
|
class IncrementalBar(Bar):
|
||||||
phases = (" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉", "█")
|
phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
nphases = len(self.phases)
|
nphases = len(self.phases)
|
||||||
filled_len = self.width * self.progress
|
filled_len = self.width * self.progress
|
||||||
nfull = int(filled_len) # Number of full chars
|
nfull = int(filled_len) # Number of full chars
|
||||||
phase = int((filled_len - nfull) * nphases) # Phase of last char
|
phase = int((filled_len - nfull) * nphases) # Phase of last char
|
||||||
nempty = self.width - nfull # Number of empty chars
|
nempty = self.width - nfull # Number of empty chars
|
||||||
|
|
||||||
message = self.message % self
|
message = self.message % self
|
||||||
bar = self.phases[-1] * nfull
|
bar = self.phases[-1] * nfull
|
||||||
current = self.phases[phase] if phase > 0 else ""
|
current = self.phases[phase] if phase > 0 else ''
|
||||||
empty = self.empty_fill * max(0, nempty - len(current))
|
empty = self.empty_fill * max(0, nempty - len(current))
|
||||||
suffix = self.suffix % self
|
suffix = self.suffix % self
|
||||||
line = "".join([message, self.bar_prefix, bar, current, empty, self.bar_suffix, suffix])
|
line = ''.join([message, self.bar_prefix, bar, current, empty,
|
||||||
|
self.bar_suffix, suffix])
|
||||||
self.writeln(line)
|
self.writeln(line)
|
||||||
|
|
||||||
|
|
||||||
class PixelBar(IncrementalBar):
|
class PixelBar(IncrementalBar):
|
||||||
phases = ("⡀", "⡄", "⡆", "⡇", "⣇", "⣧", "⣷", "⣿")
|
phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')
|
||||||
|
|
||||||
|
|
||||||
class ShadyBar(IncrementalBar):
|
class ShadyBar(IncrementalBar):
|
||||||
phases = (" ", "░", "▒", "▓", "█")
|
phases = (' ', '░', '▒', '▓', '█')
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,12 +14,13 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Infinite, Progress
|
from . import Infinite, Progress
|
||||||
from .helpers import WriteMixin
|
from .helpers import WriteMixin
|
||||||
|
|
||||||
|
|
||||||
class Counter(WriteMixin, Infinite):
|
class Counter(WriteMixin, Infinite):
|
||||||
message = ""
|
message = ''
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -32,7 +35,7 @@ class Countdown(WriteMixin, Progress):
|
||||||
|
|
||||||
|
|
||||||
class Stack(WriteMixin, Progress):
|
class Stack(WriteMixin, Progress):
|
||||||
phases = (" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█")
|
phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -42,4 +45,4 @@ class Stack(WriteMixin, Progress):
|
||||||
|
|
||||||
|
|
||||||
class Pie(Stack):
|
class Pie(Stack):
|
||||||
phases = ("○", "◔", "◑", "◕", "●")
|
phases = ('○', '◔', '◑', '◕', '●')
|
||||||
|
|
|
@ -12,76 +12,78 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
HIDE_CURSOR = "\x1b[?25l"
|
|
||||||
SHOW_CURSOR = "\x1b[?25h"
|
|
||||||
|
|
||||||
|
|
||||||
class WriteMixin:
|
HIDE_CURSOR = '\x1b[?25l'
|
||||||
|
SHOW_CURSOR = '\x1b[?25h'
|
||||||
|
|
||||||
|
|
||||||
|
class WriteMixin(object):
|
||||||
hide_cursor = False
|
hide_cursor = False
|
||||||
|
|
||||||
def __init__(self, message=None, **kwargs):
|
def __init__(self, message=None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super(WriteMixin, self).__init__(**kwargs)
|
||||||
self._width = 0
|
self._width = 0
|
||||||
if message:
|
if message:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
if self.hide_cursor:
|
if self.hide_cursor:
|
||||||
print(HIDE_CURSOR, end="", file=self.file)
|
print(HIDE_CURSOR, end='', file=self.file)
|
||||||
print(self.message, end="", file=self.file)
|
print(self.message, end='', file=self.file)
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def write(self, s):
|
def write(self, s):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
b = "\b" * self._width
|
b = '\b' * self._width
|
||||||
c = s.ljust(self._width)
|
c = s.ljust(self._width)
|
||||||
print(b + c, end="", file=self.file)
|
print(b + c, end='', file=self.file)
|
||||||
self._width = max(self._width, len(s))
|
self._width = max(self._width, len(s))
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.file.isatty() and self.hide_cursor:
|
if self.file.isatty() and self.hide_cursor:
|
||||||
print(SHOW_CURSOR, end="", file=self.file)
|
print(SHOW_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
|
|
||||||
class WritelnMixin:
|
class WritelnMixin(object):
|
||||||
hide_cursor = False
|
hide_cursor = False
|
||||||
|
|
||||||
def __init__(self, message=None, **kwargs):
|
def __init__(self, message=None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super(WritelnMixin, self).__init__(**kwargs)
|
||||||
if message:
|
if message:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
if self.file.isatty() and self.hide_cursor:
|
if self.file.isatty() and self.hide_cursor:
|
||||||
print(HIDE_CURSOR, end="", file=self.file)
|
print(HIDE_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
def clearln(self):
|
def clearln(self):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
print("\r\x1b[K", end="", file=self.file)
|
print('\r\x1b[K', end='', file=self.file)
|
||||||
|
|
||||||
def writeln(self, line):
|
def writeln(self, line):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
self.clearln()
|
self.clearln()
|
||||||
print(line, end="", file=self.file)
|
print(line, end='', file=self.file)
|
||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.file.isatty():
|
if self.file.isatty():
|
||||||
print(file=self.file)
|
print(file=self.file)
|
||||||
if self.hide_cursor:
|
if self.hide_cursor:
|
||||||
print(SHOW_CURSOR, end="", file=self.file)
|
print(SHOW_CURSOR, end='', file=self.file)
|
||||||
|
|
||||||
|
|
||||||
from signal import SIGINT, signal
|
from signal import signal, SIGINT
|
||||||
from sys import exit
|
from sys import exit
|
||||||
|
|
||||||
|
|
||||||
class SigIntMixin:
|
class SigIntMixin(object):
|
||||||
"""Registers a signal handler that calls finish on SIGINT."""
|
"""Registers a signal handler that calls finish on SIGINT"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super(SigIntMixin, self).__init__(*args, **kwargs)
|
||||||
signal(SIGINT, self._sigint_handler)
|
signal(SIGINT, self._sigint_handler)
|
||||||
|
|
||||||
def _sigint_handler(self, signum, frame):
|
def _sigint_handler(self, signum, frame):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
|
||||||
#
|
#
|
||||||
# Permission to use, copy, modify, and distribute this software for any
|
# Permission to use, copy, modify, and distribute this software for any
|
||||||
|
@ -12,13 +14,14 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
from . import Infinite
|
from . import Infinite
|
||||||
from .helpers import WriteMixin
|
from .helpers import WriteMixin
|
||||||
|
|
||||||
|
|
||||||
class Spinner(WriteMixin, Infinite):
|
class Spinner(WriteMixin, Infinite):
|
||||||
message = ""
|
message = ''
|
||||||
phases = ("-", "\\", "|", "/")
|
phases = ('-', '\\', '|', '/')
|
||||||
hide_cursor = True
|
hide_cursor = True
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
@ -27,16 +30,15 @@ class Spinner(WriteMixin, Infinite):
|
||||||
|
|
||||||
|
|
||||||
class PieSpinner(Spinner):
|
class PieSpinner(Spinner):
|
||||||
phases = ["◷", "◶", "◵", "◴"]
|
phases = ['◷', '◶', '◵', '◴']
|
||||||
|
|
||||||
|
|
||||||
class MoonSpinner(Spinner):
|
class MoonSpinner(Spinner):
|
||||||
phases = ["◑", "◒", "◐", "◓"]
|
phases = ['◑', '◒', '◐', '◓']
|
||||||
|
|
||||||
|
|
||||||
class LineSpinner(Spinner):
|
class LineSpinner(Spinner):
|
||||||
phases = ["⎺", "⎻", "⎼", "⎽", "⎼", "⎻"]
|
phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']
|
||||||
|
|
||||||
|
|
||||||
class PixelSpinner(Spinner):
|
class PixelSpinner(Spinner):
|
||||||
phases = ["⣾", "⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽"]
|
phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']
|
||||||
|
|
|
@ -1,27 +1,29 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import progress
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
|
import progress
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="progress",
|
name='progress',
|
||||||
version=progress.__version__,
|
version=progress.__version__,
|
||||||
description="Easy to use progress bars",
|
description='Easy to use progress bars',
|
||||||
long_description=open("README.rst").read(),
|
long_description=open('README.rst').read(),
|
||||||
author="Giorgos Verigakis",
|
author='Giorgos Verigakis',
|
||||||
author_email="verigak@gmail.com",
|
author_email='verigak@gmail.com',
|
||||||
url="http://github.com/verigak/progress/",
|
url='http://github.com/verigak/progress/',
|
||||||
license="ISC",
|
license='ISC',
|
||||||
packages=["progress"],
|
packages=['progress'],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Environment :: Console",
|
'Environment :: Console',
|
||||||
"Intended Audience :: Developers",
|
'Intended Audience :: Developers',
|
||||||
"License :: OSI Approved :: ISC License (ISCL)",
|
'License :: OSI Approved :: ISC License (ISCL)',
|
||||||
"Programming Language :: Python :: 2.6",
|
'Programming Language :: Python :: 2.6',
|
||||||
"Programming Language :: Python :: 2.7",
|
'Programming Language :: Python :: 2.7',
|
||||||
"Programming Language :: Python :: 3.3",
|
'Programming Language :: Python :: 3.3',
|
||||||
"Programming Language :: Python :: 3.4",
|
'Programming Language :: Python :: 3.4',
|
||||||
"Programming Language :: Python :: 3.5",
|
'Programming Language :: Python :: 3.5',
|
||||||
"Programming Language :: Python :: 3.6",
|
'Programming Language :: Python :: 3.6',
|
||||||
],
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from progress.bar import Bar, ChargingBar, FillingCirclesBar, FillingSquaresBar, IncrementalBar, PixelBar, ShadyBar
|
from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
|
||||||
from progress.counter import Countdown, Counter, Pie, Stack
|
FillingCirclesBar, IncrementalBar, PixelBar,
|
||||||
from progress.spinner import LineSpinner, MoonSpinner, PieSpinner, PixelSpinner, Spinner
|
ShadyBar)
|
||||||
|
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
|
||||||
|
PixelSpinner)
|
||||||
|
from progress.counter import Counter, Countdown, Stack, Pie
|
||||||
|
|
||||||
|
|
||||||
def sleep():
|
def sleep():
|
||||||
|
@ -16,29 +20,29 @@ def sleep():
|
||||||
|
|
||||||
|
|
||||||
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
|
for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
|
||||||
suffix = "%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]"
|
suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
|
||||||
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
||||||
for _i in bar.iter(range(200)):
|
for i in bar.iter(range(200)):
|
||||||
sleep()
|
sleep()
|
||||||
|
|
||||||
for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
|
for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
|
||||||
suffix = "%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]"
|
suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
|
||||||
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
bar = bar_cls(bar_cls.__name__, suffix=suffix)
|
||||||
for _i in bar.iter(range(200)):
|
for i in bar.iter(range(200)):
|
||||||
sleep()
|
sleep()
|
||||||
|
|
||||||
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
|
for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
|
||||||
for _i in spin(spin.__name__ + " ").iter(range(100)):
|
for i in spin(spin.__name__ + ' ').iter(range(100)):
|
||||||
sleep()
|
sleep()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
for singleton in (Counter, Countdown, Stack, Pie):
|
for singleton in (Counter, Countdown, Stack, Pie):
|
||||||
for _i in singleton(singleton.__name__ + " ").iter(range(100)):
|
for i in singleton(singleton.__name__ + ' ').iter(range(100)):
|
||||||
sleep()
|
sleep()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
bar = IncrementalBar("Random", suffix="%(index)d")
|
bar = IncrementalBar('Random', suffix='%(index)d')
|
||||||
for _i in range(100):
|
for i in range(100):
|
||||||
bar.goto(random.randint(0, 100))
|
bar.goto(random.randint(0, 100))
|
||||||
sleep()
|
sleep()
|
||||||
bar.finish()
|
bar.finish()
|
||||||
|
|
|
@ -1,33 +1,23 @@
|
||||||
name: pointmlp
|
name: pointmlp
|
||||||
|
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
- pytorch
|
||||||
- nvidia
|
- nvidia
|
||||||
- conda-forge
|
- conda-forge
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
#---# basic python
|
# - cudatoolkit=10.2.89
|
||||||
- pytorch
|
- cudatoolkit=11.1
|
||||||
- tqdm
|
- cycler=0.10.0
|
||||||
- numpy
|
- einops=0.3.0
|
||||||
- scipy
|
- h5py=3.2.1
|
||||||
- scikit-learn
|
- matplotlib=3.4.2
|
||||||
#---# file readers
|
- numpy=1.20.2
|
||||||
- h5py
|
- numpy-base=1.20.2
|
||||||
- pyyaml
|
- pytorch=1.8.1
|
||||||
#---# tooling (linting, typing...)
|
- pyyaml=5.4.1
|
||||||
- ruff
|
- scikit-learn=0.24.2
|
||||||
- mypy
|
- scipy=1.6.3
|
||||||
- black
|
- torchvision=0.9.1
|
||||||
- isort
|
- tqdm=4.61.1
|
||||||
#---# visu
|
|
||||||
- matplotlib
|
|
||||||
#---# pytorch
|
|
||||||
- cudatoolkit
|
|
||||||
- cycler
|
|
||||||
- einops
|
|
||||||
- torchvision
|
|
||||||
|
|
||||||
- pip
|
- pip
|
||||||
- pip:
|
- pip:
|
||||||
- pointnet2_ops_lib/.
|
- pointnet2_ops_lib/.
|
||||||
|
|
BIN
overview.pdf
Normal file
BIN
overview.pdf
Normal file
Binary file not shown.
BIN
overview.png
Normal file
BIN
overview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB |
|
@ -1,46 +1,30 @@
|
||||||
import argparse
|
from __future__ import print_function
|
||||||
import os
|
import os
|
||||||
import random
|
import argparse
|
||||||
from collections import defaultdict
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
|
||||||
|
from util.data_util import PartNormalDataset
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
import model as models
|
import model as models
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
|
||||||
from torch.autograd import Variable
|
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from util.util import to_categorical, compute_overall_iou, IOStream
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from util.data_util import PartNormalDataset
|
from collections import defaultdict
|
||||||
from util.util import IOStream, compute_overall_iou, to_categorical
|
from torch.autograd import Variable
|
||||||
|
import random
|
||||||
|
|
||||||
classes_str = [
|
|
||||||
"aero",
|
classes_str = ['aero','bag','cap','car','chair','ear','guitar','knife','lamp','lapt','moto','mug','Pistol','rock','stake','table']
|
||||||
"bag",
|
|
||||||
"cap",
|
|
||||||
"car",
|
|
||||||
"chair",
|
|
||||||
"ear",
|
|
||||||
"guitar",
|
|
||||||
"knife",
|
|
||||||
"lamp",
|
|
||||||
"lapt",
|
|
||||||
"moto",
|
|
||||||
"mug",
|
|
||||||
"Pistol",
|
|
||||||
"rock",
|
|
||||||
"stake",
|
|
||||||
"table",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _init_():
|
def _init_():
|
||||||
if not os.path.exists("checkpoints"):
|
if not os.path.exists('checkpoints'):
|
||||||
os.makedirs("checkpoints")
|
os.makedirs('checkpoints')
|
||||||
if not os.path.exists("checkpoints/" + args.exp_name):
|
if not os.path.exists('checkpoints/' + args.exp_name):
|
||||||
os.makedirs("checkpoints/" + args.exp_name)
|
os.makedirs('checkpoints/' + args.exp_name)
|
||||||
|
|
||||||
|
|
||||||
def weight_init(m):
|
def weight_init(m):
|
||||||
|
@ -65,6 +49,7 @@ def weight_init(m):
|
||||||
|
|
||||||
|
|
||||||
def train(args, io):
|
def train(args, io):
|
||||||
|
|
||||||
# ============= Model ===================
|
# ============= Model ===================
|
||||||
num_part = 50
|
num_part = 50
|
||||||
device = torch.device("cuda" if args.cuda else "cpu")
|
device = torch.device("cuda" if args.cuda else "cpu")
|
||||||
|
@ -76,19 +61,16 @@ def train(args, io):
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||||
|
|
||||||
"""Resume or not"""
|
'''Resume or not'''
|
||||||
if args.resume:
|
if args.resume:
|
||||||
state_dict = torch.load(
|
state_dict = torch.load("checkpoints/%s/best_insiou_model.pth" % args.exp_name,
|
||||||
"checkpoints/%s/best_insiou_model.pth" % args.exp_name,
|
map_location=torch.device('cpu'))['model']
|
||||||
map_location=torch.device("cpu"),
|
|
||||||
)["model"]
|
|
||||||
for k in state_dict.keys():
|
for k in state_dict.keys():
|
||||||
if "module" not in k:
|
if 'module' not in k:
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
new_state_dict = OrderedDict()
|
new_state_dict = OrderedDict()
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
new_state_dict["module." + k] = state_dict[k]
|
new_state_dict['module.' + k] = state_dict[k]
|
||||||
state_dict = new_state_dict
|
state_dict = new_state_dict
|
||||||
break
|
break
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
@ -99,37 +81,27 @@ def train(args, io):
|
||||||
print("Training from scratch...")
|
print("Training from scratch...")
|
||||||
|
|
||||||
# =========== Dataloader =================
|
# =========== Dataloader =================
|
||||||
train_data = PartNormalDataset(npoints=2048, split="trainval", normalize=False)
|
train_data = PartNormalDataset(npoints=2048, split='trainval', normalize=False)
|
||||||
print("The number of training data is:%d", len(train_data))
|
print("The number of training data is:%d", len(train_data))
|
||||||
|
|
||||||
test_data = PartNormalDataset(npoints=2048, split="test", normalize=False)
|
test_data = PartNormalDataset(npoints=2048, split='test', normalize=False)
|
||||||
print("The number of test data is:%d", len(test_data))
|
print("The number of test data is:%d", len(test_data))
|
||||||
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
|
||||||
train_data,
|
drop_last=True)
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=args.workers,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers,
|
||||||
test_data,
|
drop_last=False)
|
||||||
batch_size=args.test_batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=args.workers,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ============= Optimizer ================
|
# ============= Optimizer ================
|
||||||
if args.use_sgd:
|
if args.use_sgd:
|
||||||
print("Use SGD")
|
print("Use SGD")
|
||||||
opt = optim.SGD(model.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=0)
|
opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=0)
|
||||||
else:
|
else:
|
||||||
print("Use Adam")
|
print("Use Adam")
|
||||||
opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
opt = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||||
|
|
||||||
if args.scheduler == "cos":
|
if args.scheduler == 'cos':
|
||||||
print("Use CosLR")
|
print("Use CosLR")
|
||||||
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100)
|
scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr if args.use_sgd else args.lr / 100)
|
||||||
else:
|
else:
|
||||||
|
@ -144,33 +116,28 @@ def train(args, io):
|
||||||
num_classes = 16
|
num_classes = 16
|
||||||
|
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
|
|
||||||
train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io)
|
train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io)
|
||||||
|
|
||||||
test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io)
|
test_metrics, total_per_cat_iou = test_epoch(test_loader, model, epoch, num_part, num_classes, io)
|
||||||
|
|
||||||
# 1. when get the best accuracy, save the model:
|
# 1. when get the best accuracy, save the model:
|
||||||
if test_metrics["accuracy"] > best_acc:
|
if test_metrics['accuracy'] > best_acc:
|
||||||
best_acc = test_metrics["accuracy"]
|
best_acc = test_metrics['accuracy']
|
||||||
io.cprint("Max Acc:%.5f" % best_acc)
|
io.cprint('Max Acc:%.5f' % best_acc)
|
||||||
state = {
|
state = {
|
||||||
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
||||||
"optimizer": opt.state_dict(),
|
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_acc': best_acc}
|
||||||
"epoch": epoch,
|
torch.save(state, 'checkpoints/%s/best_acc_model.pth' % args.exp_name)
|
||||||
"test_acc": best_acc,
|
|
||||||
}
|
|
||||||
torch.save(state, "checkpoints/%s/best_acc_model.pth" % args.exp_name)
|
|
||||||
|
|
||||||
# 2. when get the best instance_iou, save the model:
|
# 2. when get the best instance_iou, save the model:
|
||||||
if test_metrics["shape_avg_iou"] > best_instance_iou:
|
if test_metrics['shape_avg_iou'] > best_instance_iou:
|
||||||
best_instance_iou = test_metrics["shape_avg_iou"]
|
best_instance_iou = test_metrics['shape_avg_iou']
|
||||||
io.cprint("Max instance iou:%.5f" % best_instance_iou)
|
io.cprint('Max instance iou:%.5f' % best_instance_iou)
|
||||||
state = {
|
state = {
|
||||||
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
||||||
"optimizer": opt.state_dict(),
|
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_instance_iou': best_instance_iou}
|
||||||
"epoch": epoch,
|
torch.save(state, 'checkpoints/%s/best_insiou_model.pth' % args.exp_name)
|
||||||
"test_instance_iou": best_instance_iou,
|
|
||||||
}
|
|
||||||
torch.save(state, "checkpoints/%s/best_insiou_model.pth" % args.exp_name)
|
|
||||||
|
|
||||||
# 3. when get the best class_iou, save the model:
|
# 3. when get the best class_iou, save the model:
|
||||||
# first we need to calculate the average per-class iou
|
# first we need to calculate the average per-class iou
|
||||||
|
@ -182,28 +149,22 @@ def train(args, io):
|
||||||
best_class_iou = avg_class_iou
|
best_class_iou = avg_class_iou
|
||||||
# print the iou of each class:
|
# print the iou of each class:
|
||||||
for cat_idx in range(16):
|
for cat_idx in range(16):
|
||||||
io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx]))
|
io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx]))
|
||||||
io.cprint("Max class iou:%.5f" % best_class_iou)
|
io.cprint('Max class iou:%.5f' % best_class_iou)
|
||||||
state = {
|
state = {
|
||||||
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
||||||
"optimizer": opt.state_dict(),
|
'optimizer': opt.state_dict(), 'epoch': epoch, 'test_class_iou': best_class_iou}
|
||||||
"epoch": epoch,
|
torch.save(state, 'checkpoints/%s/best_clsiou_model.pth' % args.exp_name)
|
||||||
"test_class_iou": best_class_iou,
|
|
||||||
}
|
|
||||||
torch.save(state, "checkpoints/%s/best_clsiou_model.pth" % args.exp_name)
|
|
||||||
|
|
||||||
# report best acc, ins_iou, cls_iou
|
# report best acc, ins_iou, cls_iou
|
||||||
io.cprint("Final Max Acc:%.5f" % best_acc)
|
io.cprint('Final Max Acc:%.5f' % best_acc)
|
||||||
io.cprint("Final Max instance iou:%.5f" % best_instance_iou)
|
io.cprint('Final Max instance iou:%.5f' % best_instance_iou)
|
||||||
io.cprint("Final Max class iou:%.5f" % best_class_iou)
|
io.cprint('Final Max class iou:%.5f' % best_class_iou)
|
||||||
# save last model
|
# save last model
|
||||||
state = {
|
state = {
|
||||||
"model": model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
'model': model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(),
|
||||||
"optimizer": opt.state_dict(),
|
'optimizer': opt.state_dict(), 'epoch': args.epochs - 1, 'test_iou': best_instance_iou}
|
||||||
"epoch": args.epochs - 1,
|
torch.save(state, 'checkpoints/%s/model_ep%d.pth' % (args.exp_name, args.epochs))
|
||||||
"test_iou": best_instance_iou,
|
|
||||||
}
|
|
||||||
torch.save(state, "checkpoints/%s/model_ep%d.pth" % (args.exp_name, args.epochs))
|
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io):
|
def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classes, io):
|
||||||
|
@ -214,41 +175,22 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
|
||||||
metrics = defaultdict(lambda: list())
|
metrics = defaultdict(lambda: list())
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
for _batch_id, (points, label, target, norm_plt) in tqdm(
|
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
|
||||||
enumerate(train_loader),
|
|
||||||
total=len(train_loader),
|
|
||||||
smoothing=0.9,
|
|
||||||
):
|
|
||||||
batch_size, num_point, _ = points.size()
|
batch_size, num_point, _ = points.size()
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \
|
||||||
Variable(points.float()),
|
Variable(norm_plt.float())
|
||||||
Variable(label.long()),
|
|
||||||
Variable(target.long()),
|
|
||||||
Variable(norm_plt.float()),
|
|
||||||
)
|
|
||||||
points = points.transpose(2, 1)
|
points = points.transpose(2, 1)
|
||||||
norm_plt = norm_plt.transpose(2, 1)
|
norm_plt = norm_plt.transpose(2, 1)
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \
|
||||||
points.cuda(non_blocking=True),
|
target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
|
||||||
label.squeeze(1).cuda(non_blocking=True),
|
|
||||||
target.cuda(non_blocking=True),
|
|
||||||
norm_plt.cuda(non_blocking=True),
|
|
||||||
)
|
|
||||||
# target: b,n
|
# target: b,n
|
||||||
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50
|
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # seg_pred: b,n,50
|
||||||
loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0])
|
loss = F.nll_loss(seg_pred.contiguous().view(-1, num_part), target.view(-1, 1)[:, 0])
|
||||||
|
|
||||||
# instance iou without considering the class average at each batch_size:
|
# instance iou without considering the class average at each batch_size:
|
||||||
batch_shapeious = compute_overall_iou(
|
batch_shapeious = compute_overall_iou(seg_pred, target, num_part) # list of of current batch_iou:[iou1,iou2,...,iou#b_size]
|
||||||
seg_pred,
|
|
||||||
target,
|
|
||||||
num_part,
|
|
||||||
) # list of of current batch_iou:[iou1,iou2,...,iou#b_size]
|
|
||||||
# total iou of current batch in each process:
|
# total iou of current batch in each process:
|
||||||
batch_shapeious = seg_pred.new_tensor(
|
batch_shapeious = seg_pred.new_tensor([np.sum(batch_shapeious)], dtype=torch.float64) # same device with seg_pred!!!
|
||||||
[np.sum(batch_shapeious)],
|
|
||||||
dtype=torch.float64,
|
|
||||||
) # same device with seg_pred!!!
|
|
||||||
|
|
||||||
# Loss backward
|
# Loss backward
|
||||||
loss = torch.mean(loss)
|
loss = torch.mean(loss)
|
||||||
|
@ -258,37 +200,33 @@ def train_epoch(train_loader, model, opt, scheduler, epoch, num_part, num_classe
|
||||||
|
|
||||||
# accuracy
|
# accuracy
|
||||||
seg_pred = seg_pred.contiguous().view(-1, num_part) # b*n,50
|
seg_pred = seg_pred.contiguous().view(-1, num_part) # b*n,50
|
||||||
target = target.view(-1, 1)[:, 0] # b*n
|
target = target.view(-1, 1)[:, 0] # b*n
|
||||||
pred_choice = seg_pred.contiguous().data.max(1)[1] # b*n
|
pred_choice = seg_pred.contiguous().data.max(1)[1] # b*n
|
||||||
correct = pred_choice.eq(target.contiguous().data).sum() # torch.int64: total number of correct-predict pts
|
correct = pred_choice.eq(target.contiguous().data).sum() # torch.int64: total number of correct-predict pts
|
||||||
|
|
||||||
# sum
|
# sum
|
||||||
shape_ious += batch_shapeious.item() # count the sum of ious in each iteration
|
shape_ious += batch_shapeious.item() # count the sum of ious in each iteration
|
||||||
count += batch_size # count the total number of samples in each iteration
|
count += batch_size # count the total number of samples in each iteration
|
||||||
train_loss += loss.item() * batch_size
|
train_loss += loss.item() * batch_size
|
||||||
accuracy.append(correct.item() / (batch_size * num_point)) # append the accuracy of each iteration
|
accuracy.append(correct.item()/(batch_size * num_point)) # append the accuracy of each iteration
|
||||||
|
|
||||||
# Note: We do not need to calculate per_class iou during training
|
# Note: We do not need to calculate per_class iou during training
|
||||||
|
|
||||||
if args.scheduler == "cos":
|
if args.scheduler == 'cos':
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
elif args.scheduler == "step":
|
elif args.scheduler == 'step':
|
||||||
if opt.param_groups[0]["lr"] > 0.9e-5:
|
if opt.param_groups[0]['lr'] > 0.9e-5:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
if opt.param_groups[0]["lr"] < 0.9e-5:
|
if opt.param_groups[0]['lr'] < 0.9e-5:
|
||||||
for param_group in opt.param_groups:
|
for param_group in opt.param_groups:
|
||||||
param_group["lr"] = 0.9e-5
|
param_group['lr'] = 0.9e-5
|
||||||
io.cprint("Learning rate: %f" % opt.param_groups[0]["lr"])
|
io.cprint('Learning rate: %f' % opt.param_groups[0]['lr'])
|
||||||
|
|
||||||
metrics["accuracy"] = np.mean(accuracy)
|
metrics['accuracy'] = np.mean(accuracy)
|
||||||
metrics["shape_avg_iou"] = shape_ious * 1.0 / count
|
metrics['shape_avg_iou'] = shape_ious * 1.0 / count
|
||||||
|
|
||||||
outstr = "Train %d, loss: %f, train acc: %f, train ins_iou: %f" % (
|
outstr = 'Train %d, loss: %f, train acc: %f, train ins_iou: %f' % (epoch+1, train_loss * 1.0 / count,
|
||||||
epoch + 1,
|
metrics['accuracy'], metrics['shape_avg_iou'])
|
||||||
train_loss * 1.0 / count,
|
|
||||||
metrics["accuracy"],
|
|
||||||
metrics["shape_avg_iou"],
|
|
||||||
)
|
|
||||||
io.cprint(outstr)
|
io.cprint(outstr)
|
||||||
|
|
||||||
|
|
||||||
|
@ -303,26 +241,14 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# label_size: b, means each sample has one corresponding class
|
# label_size: b, means each sample has one corresponding class
|
||||||
for _batch_id, (points, label, target, norm_plt) in tqdm(
|
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9):
|
||||||
enumerate(test_loader),
|
|
||||||
total=len(test_loader),
|
|
||||||
smoothing=0.9,
|
|
||||||
):
|
|
||||||
batch_size, num_point, _ = points.size()
|
batch_size, num_point, _ = points.size()
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), \
|
||||||
Variable(points.float()),
|
Variable(norm_plt.float())
|
||||||
Variable(label.long()),
|
|
||||||
Variable(target.long()),
|
|
||||||
Variable(norm_plt.float()),
|
|
||||||
)
|
|
||||||
points = points.transpose(2, 1)
|
points = points.transpose(2, 1)
|
||||||
norm_plt = norm_plt.transpose(2, 1)
|
norm_plt = norm_plt.transpose(2, 1)
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze(1).cuda(non_blocking=True), \
|
||||||
points.cuda(non_blocking=True),
|
target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
|
||||||
label.squeeze(1).cuda(non_blocking=True),
|
|
||||||
target.cuda(non_blocking=True),
|
|
||||||
norm_plt.cuda(non_blocking=True),
|
|
||||||
)
|
|
||||||
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
|
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
|
||||||
|
|
||||||
# instance iou without considering the class average at each batch_size:
|
# instance iou without considering the class average at each batch_size:
|
||||||
|
@ -355,19 +281,13 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
|
||||||
|
|
||||||
for cat_idx in range(16):
|
for cat_idx in range(16):
|
||||||
if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending
|
if final_total_per_cat_seen[cat_idx] > 0: # indicating this cat is included during previous iou appending
|
||||||
final_total_per_cat_iou[cat_idx] = (
|
final_total_per_cat_iou[cat_idx] = final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx] # avg class iou across all samples
|
||||||
final_total_per_cat_iou[cat_idx] / final_total_per_cat_seen[cat_idx]
|
|
||||||
) # avg class iou across all samples
|
|
||||||
|
|
||||||
metrics["accuracy"] = np.mean(accuracy)
|
metrics['accuracy'] = np.mean(accuracy)
|
||||||
metrics["shape_avg_iou"] = shape_ious * 1.0 / count
|
metrics['shape_avg_iou'] = shape_ious * 1.0 / count
|
||||||
|
|
||||||
outstr = "Test %d, loss: %f, test acc: %f test ins_iou: %f" % (
|
outstr = 'Test %d, loss: %f, test acc: %f test ins_iou: %f' % (epoch + 1, test_loss * 1.0 / count,
|
||||||
epoch + 1,
|
metrics['accuracy'], metrics['shape_avg_iou'])
|
||||||
test_loss * 1.0 / count,
|
|
||||||
metrics["accuracy"],
|
|
||||||
metrics["shape_avg_iou"],
|
|
||||||
)
|
|
||||||
|
|
||||||
io.cprint(outstr)
|
io.cprint(outstr)
|
||||||
|
|
||||||
|
@ -376,16 +296,11 @@ def test_epoch(test_loader, model, epoch, num_part, num_classes, io):
|
||||||
|
|
||||||
def test(args, io):
|
def test(args, io):
|
||||||
# Dataloader
|
# Dataloader
|
||||||
test_data = PartNormalDataset(npoints=2048, split="test", normalize=False)
|
test_data = PartNormalDataset(npoints=2048, split='test', normalize=False)
|
||||||
print("The number of test data is:%d", len(test_data))
|
print("The number of test data is:%d", len(test_data))
|
||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.workers,
|
||||||
test_data,
|
drop_last=False)
|
||||||
batch_size=args.test_batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=args.workers,
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to load models
|
# Try to load models
|
||||||
num_part = 50
|
num_part = 50
|
||||||
|
@ -395,15 +310,12 @@ def test(args, io):
|
||||||
io.cprint(str(model))
|
io.cprint(str(model))
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
state_dict = torch.load("checkpoints/%s/best_%s_model.pth" % (args.exp_name, args.model_type),
|
||||||
state_dict = torch.load(
|
map_location=torch.device('cpu'))['model']
|
||||||
f"checkpoints/{args.exp_name}/best_{args.model_type}_model.pth",
|
|
||||||
map_location=torch.device("cpu"),
|
|
||||||
)["model"]
|
|
||||||
|
|
||||||
new_state_dict = OrderedDict()
|
new_state_dict = OrderedDict()
|
||||||
for layer in state_dict:
|
for layer in state_dict:
|
||||||
new_state_dict[layer.replace("module.", "")] = state_dict[layer]
|
new_state_dict[layer.replace('module.', '')] = state_dict[layer]
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -412,29 +324,16 @@ def test(args, io):
|
||||||
metrics = defaultdict(lambda: list())
|
metrics = defaultdict(lambda: list())
|
||||||
hist_acc = []
|
hist_acc = []
|
||||||
shape_ious = []
|
shape_ious = []
|
||||||
total_per_cat_iou = np.zeros(16).astype(np.float32)
|
total_per_cat_iou = np.zeros((16)).astype(np.float32)
|
||||||
total_per_cat_seen = np.zeros(16).astype(np.int32)
|
total_per_cat_seen = np.zeros((16)).astype(np.int32)
|
||||||
|
|
||||||
for _batch_id, (points, label, target, norm_plt) in tqdm(
|
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9):
|
||||||
enumerate(test_loader),
|
|
||||||
total=len(test_loader),
|
|
||||||
smoothing=0.9,
|
|
||||||
):
|
|
||||||
batch_size, num_point, _ = points.size()
|
batch_size, num_point, _ = points.size()
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), Variable(target.long()), Variable(norm_plt.float())
|
||||||
Variable(points.float()),
|
|
||||||
Variable(label.long()),
|
|
||||||
Variable(target.long()),
|
|
||||||
Variable(norm_plt.float()),
|
|
||||||
)
|
|
||||||
points = points.transpose(2, 1)
|
points = points.transpose(2, 1)
|
||||||
norm_plt = norm_plt.transpose(2, 1)
|
norm_plt = norm_plt.transpose(2, 1)
|
||||||
points, label, target, norm_plt = (
|
points, label, target, norm_plt = points.cuda(non_blocking=True), label.squeeze().cuda(
|
||||||
points.cuda(non_blocking=True),
|
non_blocking=True), target.cuda(non_blocking=True), norm_plt.cuda(non_blocking=True)
|
||||||
label.squeeze().cuda(non_blocking=True),
|
|
||||||
target.cuda(non_blocking=True),
|
|
||||||
norm_plt.cuda(non_blocking=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
|
seg_pred = model(points, norm_plt, to_categorical(label, num_classes)) # b,n,50
|
||||||
|
@ -454,11 +353,11 @@ def test(args, io):
|
||||||
target = target.view(-1, 1)[:, 0]
|
target = target.view(-1, 1)[:, 0]
|
||||||
pred_choice = seg_pred.data.max(1)[1]
|
pred_choice = seg_pred.data.max(1)[1]
|
||||||
correct = pred_choice.eq(target.data).cpu().sum()
|
correct = pred_choice.eq(target.data).cpu().sum()
|
||||||
metrics["accuracy"].append(correct.item() / (batch_size * num_point))
|
metrics['accuracy'].append(correct.item() / (batch_size * num_point))
|
||||||
|
|
||||||
hist_acc += metrics["accuracy"]
|
hist_acc += metrics['accuracy']
|
||||||
metrics["accuracy"] = np.mean(hist_acc)
|
metrics['accuracy'] = np.mean(hist_acc)
|
||||||
metrics["shape_avg_iou"] = np.mean(shape_ious)
|
metrics['shape_avg_iou'] = np.mean(shape_ious)
|
||||||
for cat_idx in range(16):
|
for cat_idx in range(16):
|
||||||
if total_per_cat_seen[cat_idx] > 0:
|
if total_per_cat_seen[cat_idx] > 0:
|
||||||
total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx]
|
total_per_cat_iou[cat_idx] = total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx]
|
||||||
|
@ -467,51 +366,57 @@ def test(args, io):
|
||||||
class_iou = 0
|
class_iou = 0
|
||||||
for cat_idx in range(16):
|
for cat_idx in range(16):
|
||||||
class_iou += total_per_cat_iou[cat_idx]
|
class_iou += total_per_cat_iou[cat_idx]
|
||||||
io.cprint(classes_str[cat_idx] + " iou: " + str(total_per_cat_iou[cat_idx])) # print the iou of each class
|
io.cprint(classes_str[cat_idx] + ' iou: ' + str(total_per_cat_iou[cat_idx])) # print the iou of each class
|
||||||
avg_class_iou = class_iou / 16
|
avg_class_iou = class_iou / 16
|
||||||
outstr = "Test :: test acc: {:f} test class mIOU: {:f}, test instance mIOU: {:f}".format(
|
outstr = 'Test :: test acc: %f test class mIOU: %f, test instance mIOU: %f' % (metrics['accuracy'], avg_class_iou, metrics['shape_avg_iou'])
|
||||||
metrics["accuracy"],
|
|
||||||
avg_class_iou,
|
|
||||||
metrics["shape_avg_iou"],
|
|
||||||
)
|
|
||||||
io.cprint(outstr)
|
io.cprint(outstr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Training settings
|
# Training settings
|
||||||
parser = argparse.ArgumentParser(description="3D Shape Part Segmentation")
|
parser = argparse.ArgumentParser(description='3D Shape Part Segmentation')
|
||||||
parser.add_argument("--model", type=str, default="PointMLP1")
|
parser.add_argument('--model', type=str, default='PointMLP1')
|
||||||
parser.add_argument("--exp_name", type=str, default="demo1", metavar="N", help="Name of the experiment")
|
parser.add_argument('--exp_name', type=str, default='demo1', metavar='N',
|
||||||
parser.add_argument("--batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)")
|
help='Name of the experiment')
|
||||||
parser.add_argument("--test_batch_size", type=int, default=32, metavar="batch_size", help="Size of batch)")
|
parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
|
||||||
parser.add_argument("--epochs", type=int, default=350, metavar="N", help="number of episode to train")
|
help='Size of batch)')
|
||||||
parser.add_argument("--use_sgd", type=bool, default=False, help="Use SGD")
|
parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size',
|
||||||
parser.add_argument("--scheduler", type=str, default="step", help="lr scheduler")
|
help='Size of batch)')
|
||||||
parser.add_argument("--step", type=int, default=40, help="lr decay step")
|
parser.add_argument('--epochs', type=int, default=350, metavar='N',
|
||||||
parser.add_argument("--lr", type=float, default=0.003, metavar="LR", help="learning rate")
|
help='number of episode to train')
|
||||||
parser.add_argument("--momentum", type=float, default=0.9, metavar="M", help="SGD momentum (default: 0.9)")
|
parser.add_argument('--use_sgd', type=bool, default=False,
|
||||||
parser.add_argument("--no_cuda", type=bool, default=False, help="enables CUDA training")
|
help='Use SGD')
|
||||||
parser.add_argument("--manual_seed", type=int, metavar="S", help="random seed (default: 1)")
|
parser.add_argument('--scheduler', type=str, default='step',
|
||||||
parser.add_argument("--eval", type=bool, default=False, help="evaluate the model")
|
help='lr scheduler')
|
||||||
parser.add_argument("--num_points", type=int, default=2048, help="num of points to use")
|
parser.add_argument('--step', type=int, default=40,
|
||||||
parser.add_argument("--workers", type=int, default=12)
|
help='lr decay step')
|
||||||
parser.add_argument("--resume", type=bool, default=False, help="Resume training or not")
|
parser.add_argument('--lr', type=float, default=0.003, metavar='LR',
|
||||||
parser.add_argument(
|
help='learning rate')
|
||||||
"--model_type",
|
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
||||||
type=str,
|
help='SGD momentum (default: 0.9)')
|
||||||
default="insiou",
|
parser.add_argument('--no_cuda', type=bool, default=False,
|
||||||
help="choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)",
|
help='enables CUDA training')
|
||||||
)
|
parser.add_argument('--manual_seed', type=int, metavar='S',
|
||||||
|
help='random seed (default: 1)')
|
||||||
|
parser.add_argument('--eval', type=bool, default=False,
|
||||||
|
help='evaluate the model')
|
||||||
|
parser.add_argument('--num_points', type=int, default=2048,
|
||||||
|
help='num of points to use')
|
||||||
|
parser.add_argument('--workers', type=int, default=12)
|
||||||
|
parser.add_argument('--resume', type=bool, default=False,
|
||||||
|
help='Resume training or not')
|
||||||
|
parser.add_argument('--model_type', type=str, default='insiou',
|
||||||
|
help='choose to test the best insiou/clsiou/acc model (options: insiou, clsiou, acc)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_name = args.model + "_" + args.exp_name
|
args.exp_name = args.model+"_"+args.exp_name
|
||||||
|
|
||||||
_init_()
|
_init_()
|
||||||
|
|
||||||
if not args.eval:
|
if not args.eval:
|
||||||
io = IOStream("checkpoints/" + args.exp_name + "/%s_train.log" % (args.exp_name))
|
io = IOStream('checkpoints/' + args.exp_name + '/%s_train.log' % (args.exp_name))
|
||||||
else:
|
else:
|
||||||
io = IOStream("checkpoints/" + args.exp_name + "/%s_test.log" % (args.exp_name))
|
io = IOStream('checkpoints/' + args.exp_name + '/%s_test.log' % (args.exp_name))
|
||||||
io.cprint(str(args))
|
io.cprint(str(args))
|
||||||
|
|
||||||
if args.manual_seed is not None:
|
if args.manual_seed is not None:
|
||||||
|
@ -522,12 +427,12 @@ if __name__ == "__main__":
|
||||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
|
||||||
if args.cuda:
|
if args.cuda:
|
||||||
io.cprint("Using GPU")
|
io.cprint('Using GPU')
|
||||||
if args.manual_seed is not None:
|
if args.manual_seed is not None:
|
||||||
torch.cuda.manual_seed(args.manual_seed)
|
torch.cuda.manual_seed(args.manual_seed)
|
||||||
torch.cuda.manual_seed_all(args.manual_seed)
|
torch.cuda.manual_seed_all(args.manual_seed)
|
||||||
else:
|
else:
|
||||||
io.cprint("Using CPU")
|
io.cprint('Using CPU')
|
||||||
|
|
||||||
if not args.eval:
|
if not args.eval:
|
||||||
train(args, io)
|
train(args, io)
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
from .pointMLP import pointMLP
|
from .pointMLP import pointMLP
|
||||||
|
|
|
@ -1,32 +1,34 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
from pointnet2_ops import pointnet2_utils
|
from pointnet2_ops import pointnet2_utils
|
||||||
|
|
||||||
|
|
||||||
def get_activation(activation):
|
def get_activation(activation):
|
||||||
if activation.lower() == "gelu":
|
if activation.lower() == 'gelu':
|
||||||
return nn.GELU()
|
return nn.GELU()
|
||||||
elif activation.lower() == "rrelu":
|
elif activation.lower() == 'rrelu':
|
||||||
return nn.RReLU(inplace=True)
|
return nn.RReLU(inplace=True)
|
||||||
elif activation.lower() == "selu":
|
elif activation.lower() == 'selu':
|
||||||
return nn.SELU(inplace=True)
|
return nn.SELU(inplace=True)
|
||||||
elif activation.lower() == "silu":
|
elif activation.lower() == 'silu':
|
||||||
return nn.SiLU(inplace=True)
|
return nn.SiLU(inplace=True)
|
||||||
elif activation.lower() == "hardswish":
|
elif activation.lower() == 'hardswish':
|
||||||
return nn.Hardswish(inplace=True)
|
return nn.Hardswish(inplace=True)
|
||||||
elif activation.lower() == "leakyrelu":
|
elif activation.lower() == 'leakyrelu':
|
||||||
return nn.LeakyReLU(inplace=True)
|
return nn.LeakyReLU(inplace=True)
|
||||||
elif activation.lower() == "leakyrelu0.2":
|
elif activation.lower() == 'leakyrelu0.2':
|
||||||
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
else:
|
else:
|
||||||
return nn.ReLU(inplace=True)
|
return nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
|
||||||
def square_distance(src, dst):
|
def square_distance(src, dst):
|
||||||
"""Calculate Euclid distance between each two points.
|
"""
|
||||||
src^T * dst = xn * xm + yn * ym + zn * zm;
|
Calculate Euclid distance between each two points.
|
||||||
|
src^T * dst = xn * xm + yn * ym + zn * zm;
|
||||||
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
||||||
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
||||||
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
||||||
|
@ -35,23 +37,23 @@ def square_distance(src, dst):
|
||||||
src: source points, [B, N, C]
|
src: source points, [B, N, C]
|
||||||
dst: target points, [B, M, C]
|
dst: target points, [B, M, C]
|
||||||
Output:
|
Output:
|
||||||
dist: per-point square distance, [B, N, M].
|
dist: per-point square distance, [B, N, M]
|
||||||
"""
|
"""
|
||||||
B, N, _ = src.shape
|
B, N, _ = src.shape
|
||||||
_, M, _ = dst.shape
|
_, M, _ = dst.shape
|
||||||
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
||||||
dist += torch.sum(src**2, -1).view(B, N, 1)
|
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
||||||
dist += torch.sum(dst**2, -1).view(B, 1, M)
|
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
def index_points(points, idx):
|
def index_points(points, idx):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
points: input points data, [B, N, C]
|
points: input points data, [B, N, C]
|
||||||
idx: sample index data, [B, S].
|
idx: sample index data, [B, S]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
new_points:, indexed points data, [B, S, C].
|
new_points:, indexed points data, [B, S, C]
|
||||||
"""
|
"""
|
||||||
device = points.device
|
device = points.device
|
||||||
B = points.shape[0]
|
B = points.shape[0]
|
||||||
|
@ -60,15 +62,17 @@ def index_points(points, idx):
|
||||||
repeat_shape = list(idx.shape)
|
repeat_shape = list(idx.shape)
|
||||||
repeat_shape[0] = 1
|
repeat_shape[0] = 1
|
||||||
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
||||||
return points[batch_indices, idx, :]
|
new_points = points[batch_indices, idx, :]
|
||||||
|
return new_points
|
||||||
|
|
||||||
|
|
||||||
def farthest_point_sample(xyz, npoint):
|
def farthest_point_sample(xyz, npoint):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
xyz: pointcloud data, [B, N, 3]
|
xyz: pointcloud data, [B, N, 3]
|
||||||
npoint: number of samples
|
npoint: number of samples
|
||||||
Return:
|
Return:
|
||||||
centroids: sampled pointcloud index, [B, npoint].
|
centroids: sampled pointcloud index, [B, npoint]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
|
@ -86,21 +90,21 @@ def farthest_point_sample(xyz, npoint):
|
||||||
|
|
||||||
|
|
||||||
def query_ball_point(radius, nsample, xyz, new_xyz):
|
def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
radius: local region radius
|
radius: local region radius
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, 3]
|
xyz: all points, [B, N, 3]
|
||||||
new_xyz: query points, [B, S, 3].
|
new_xyz: query points, [B, S, 3]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
device = xyz.device
|
device = xyz.device
|
||||||
B, N, C = xyz.shape
|
B, N, C = xyz.shape
|
||||||
_, S, _ = new_xyz.shape
|
_, S, _ = new_xyz.shape
|
||||||
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
group_idx[sqrdists > radius**2] = N
|
group_idx[sqrdists > radius ** 2] = N
|
||||||
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
|
||||||
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
|
||||||
mask = group_idx == N
|
mask = group_idx == N
|
||||||
|
@ -109,13 +113,13 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
|
||||||
|
|
||||||
|
|
||||||
def knn_point(nsample, xyz, new_xyz):
|
def knn_point(nsample, xyz, new_xyz):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
nsample: max sample number in local region
|
nsample: max sample number in local region
|
||||||
xyz: all points, [B, N, C]
|
xyz: all points, [B, N, C]
|
||||||
new_xyz: query points, [B, S, C].
|
new_xyz: query points, [B, S, C]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
group_idx: grouped points index, [B, S, nsample].
|
group_idx: grouped points index, [B, S, nsample]
|
||||||
"""
|
"""
|
||||||
sqrdists = square_distance(new_xyz, xyz)
|
sqrdists = square_distance(new_xyz, xyz)
|
||||||
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
|
||||||
|
@ -124,12 +128,13 @@ def knn_point(nsample, xyz, new_xyz):
|
||||||
|
|
||||||
class LocalGrouper(nn.Module):
|
class LocalGrouper(nn.Module):
|
||||||
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs):
|
def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs):
|
||||||
"""Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
"""
|
||||||
|
Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
|
||||||
:param groups: groups number
|
:param groups: groups number
|
||||||
:param kneighbors: k-nerighbors
|
:param kneighbors: k-nerighbors
|
||||||
:param kwargs: others.
|
:param kwargs: others
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(LocalGrouper, self).__init__()
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
self.kneighbors = kneighbors
|
self.kneighbors = kneighbors
|
||||||
self.use_xyz = use_xyz
|
self.use_xyz = use_xyz
|
||||||
|
@ -138,11 +143,11 @@ class LocalGrouper(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize not in ["center", "anchor"]:
|
if self.normalize not in ["center", "anchor"]:
|
||||||
print("Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
add_channel = 3 if self.use_xyz else 0
|
add_channel=3 if self.use_xyz else 0
|
||||||
self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
|
self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel]))
|
||||||
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
|
||||||
|
|
||||||
def forward(self, xyz, points):
|
def forward(self, xyz, points):
|
||||||
|
@ -161,33 +166,29 @@ class LocalGrouper(nn.Module):
|
||||||
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
|
||||||
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
grouped_points = index_points(points, idx) # [B, npoint, k, d]
|
||||||
if self.use_xyz:
|
if self.use_xyz:
|
||||||
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) # [B, npoint, k, d+3]
|
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3]
|
||||||
if self.normalize is not None:
|
if self.normalize is not None:
|
||||||
if self.normalize == "center":
|
if self.normalize =="center":
|
||||||
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
mean = torch.mean(grouped_points, dim=2, keepdim=True)
|
||||||
if self.normalize == "anchor":
|
if self.normalize =="anchor":
|
||||||
mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
|
mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
|
||||||
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
|
||||||
std = (
|
std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
|
||||||
torch.std((grouped_points - mean).reshape(B, -1), dim=-1, keepdim=True)
|
grouped_points = (grouped_points-mean)/(std + 1e-5)
|
||||||
.unsqueeze(dim=-1)
|
grouped_points = self.affine_alpha*grouped_points + self.affine_beta
|
||||||
.unsqueeze(dim=-1)
|
|
||||||
)
|
|
||||||
grouped_points = (grouped_points - mean) / (std + 1e-5)
|
|
||||||
grouped_points = self.affine_alpha * grouped_points + self.affine_beta
|
|
||||||
|
|
||||||
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
|
||||||
return new_xyz, new_points
|
return new_xyz, new_points
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLU1D(nn.Module):
|
class ConvBNReLU1D(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="relu"):
|
def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLU1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.BatchNorm1d(out_channels),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -195,43 +196,30 @@ class ConvBNReLU1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ConvBNReLURes1D(nn.Module):
|
class ConvBNReLURes1D(nn.Module):
|
||||||
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
|
def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(ConvBNReLURes1D, self).__init__()
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
self.net1 = nn.Sequential(
|
self.net1 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
|
||||||
in_channels=channel,
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=int(channel * res_expansion),
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(int(channel * res_expansion)),
|
nn.BatchNorm1d(int(channel * res_expansion)),
|
||||||
self.act,
|
self.act
|
||||||
)
|
)
|
||||||
if groups > 1:
|
if groups > 1:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, groups=groups, bias=bias),
|
||||||
out_channels=channel,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
groups=groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
self.act,
|
self.act,
|
||||||
nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, bias=bias),
|
nn.Conv1d(in_channels=channel, out_channels=channel,
|
||||||
|
kernel_size=kernel_size, bias=bias),
|
||||||
nn.BatchNorm1d(channel),
|
nn.BatchNorm1d(channel),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.net2 = nn.Sequential(
|
self.net2 = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
|
||||||
in_channels=int(channel * res_expansion),
|
kernel_size=kernel_size, bias=bias),
|
||||||
out_channels=channel,
|
nn.BatchNorm1d(channel)
|
||||||
kernel_size=kernel_size,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.BatchNorm1d(channel),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -239,34 +227,21 @@ class ConvBNReLURes1D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PreExtraction(nn.Module):
|
class PreExtraction(nn.Module):
|
||||||
def __init__(
|
def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
|
||||||
self,
|
activation='relu', use_xyz=True):
|
||||||
channels,
|
"""
|
||||||
out_channels,
|
input: [b,g,k,d]: output:[b,d,g]
|
||||||
blocks=1,
|
|
||||||
groups=1,
|
|
||||||
res_expansion=1,
|
|
||||||
bias=True,
|
|
||||||
activation="relu",
|
|
||||||
use_xyz=True,
|
|
||||||
):
|
|
||||||
"""input: [b,g,k,d]: output:[b,d,g]
|
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PreExtraction, self).__init__()
|
||||||
in_channels = 3 + 2 * channels if use_xyz else 2 * channels
|
in_channels = 3+2*channels if use_xyz else 2*channels
|
||||||
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(
|
ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
|
||||||
out_channels,
|
bias=bias, activation=activation)
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -278,20 +253,22 @@ class PreExtraction(nn.Module):
|
||||||
batch_size, _, _ = x.size()
|
batch_size, _, _ = x.size()
|
||||||
x = self.operation(x) # [b, d, k]
|
x = self.operation(x) # [b, d, k]
|
||||||
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
|
||||||
return x.reshape(b, n, -1).permute(0, 2, 1)
|
x = x.reshape(b, n, -1).permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class PosExtraction(nn.Module):
|
class PosExtraction(nn.Module):
|
||||||
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation="relu"):
|
def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
|
||||||
"""input[b,d,g]; output[b,d,g]
|
"""
|
||||||
|
input[b,d,g]; output[b,d,g]
|
||||||
:param channels:
|
:param channels:
|
||||||
:param blocks:
|
:param blocks:
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super(PosExtraction, self).__init__()
|
||||||
operation = []
|
operation = []
|
||||||
for _ in range(blocks):
|
for _ in range(blocks):
|
||||||
operation.append(
|
operation.append(
|
||||||
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation),
|
ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
)
|
)
|
||||||
self.operation = nn.Sequential(*operation)
|
self.operation = nn.Sequential(*operation)
|
||||||
|
|
||||||
|
@ -300,27 +277,22 @@ class PosExtraction(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PointNetFeaturePropagation(nn.Module):
|
class PointNetFeaturePropagation(nn.Module):
|
||||||
def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation="relu"):
|
def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
|
||||||
super().__init__()
|
super(PointNetFeaturePropagation, self).__init__()
|
||||||
self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias)
|
self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias)
|
||||||
self.extraction = PosExtraction(
|
self.extraction = PosExtraction(out_channel, blocks, groups=groups,
|
||||||
out_channel,
|
res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
blocks,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, xyz1, xyz2, points1, points2):
|
def forward(self, xyz1, xyz2, points1, points2):
|
||||||
"""Input:
|
"""
|
||||||
|
Input:
|
||||||
xyz1: input points position data, [B, N, 3]
|
xyz1: input points position data, [B, N, 3]
|
||||||
xyz2: sampled input points position data, [B, S, 3]
|
xyz2: sampled input points position data, [B, S, 3]
|
||||||
points1: input points data, [B, D', N]
|
points1: input points data, [B, D', N]
|
||||||
points2: input points data, [B, D'', S].
|
points2: input points data, [B, D'', S]
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
new_points: upsampled points data, [B, D''', N].
|
new_points: upsampled points data, [B, D''', N]
|
||||||
"""
|
"""
|
||||||
# xyz1 = xyz1.permute(0, 2, 1)
|
# xyz1 = xyz1.permute(0, 2, 1)
|
||||||
# xyz2 = xyz2.permute(0, 2, 1)
|
# xyz2 = xyz2.permute(0, 2, 1)
|
||||||
|
@ -349,40 +321,26 @@ class PointNetFeaturePropagation(nn.Module):
|
||||||
|
|
||||||
new_points = new_points.permute(0, 2, 1)
|
new_points = new_points.permute(0, 2, 1)
|
||||||
new_points = self.fuse(new_points)
|
new_points = self.fuse(new_points)
|
||||||
return self.extraction(new_points)
|
new_points = self.extraction(new_points)
|
||||||
|
return new_points
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PointMLP(nn.Module):
|
class PointMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(self, num_classes=50,points=2048, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
self,
|
activation="relu", bias=True, use_xyz=True, normalize="anchor",
|
||||||
num_classes=50,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
points=2048,
|
k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4],
|
||||||
embed_dim=64,
|
de_dims=[512, 256, 128, 128], de_blocks=[2,2,2,2],
|
||||||
groups=1,
|
gmp_dim=64,cls_dim=64, **kwargs):
|
||||||
res_expansion=1.0,
|
super(PointMLP, self).__init__()
|
||||||
activation="relu",
|
|
||||||
bias=True,
|
|
||||||
use_xyz=True,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[32, 32, 32, 32],
|
|
||||||
reducers=[4, 4, 4, 4],
|
|
||||||
de_dims=[512, 256, 128, 128],
|
|
||||||
de_blocks=[2, 2, 2, 2],
|
|
||||||
gmp_dim=64,
|
|
||||||
cls_dim=64,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.stages = len(pre_blocks)
|
self.stages = len(pre_blocks)
|
||||||
self.class_num = num_classes
|
self.class_num = num_classes
|
||||||
self.points = points
|
self.points = points
|
||||||
self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation)
|
self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation)
|
||||||
assert (
|
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
|
||||||
len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion)
|
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
||||||
), "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
|
|
||||||
self.local_grouper_list = nn.ModuleList()
|
self.local_grouper_list = nn.ModuleList()
|
||||||
self.pre_blocks_list = nn.ModuleList()
|
self.pre_blocks_list = nn.ModuleList()
|
||||||
self.pos_blocks_list = nn.ModuleList()
|
self.pos_blocks_list = nn.ModuleList()
|
||||||
|
@ -401,47 +359,29 @@ class PointMLP(nn.Module):
|
||||||
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
|
||||||
self.local_grouper_list.append(local_grouper)
|
self.local_grouper_list.append(local_grouper)
|
||||||
# append pre_block_list
|
# append pre_block_list
|
||||||
pre_block_module = PreExtraction(
|
pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
|
||||||
last_channel,
|
res_expansion=res_expansion,
|
||||||
out_channel,
|
bias=bias, activation=activation, use_xyz=use_xyz)
|
||||||
pre_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
use_xyz=use_xyz,
|
|
||||||
)
|
|
||||||
self.pre_blocks_list.append(pre_block_module)
|
self.pre_blocks_list.append(pre_block_module)
|
||||||
# append pos_block_list
|
# append pos_block_list
|
||||||
pos_block_module = PosExtraction(
|
pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
|
||||||
out_channel,
|
res_expansion=res_expansion, bias=bias, activation=activation)
|
||||||
pos_block_num,
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
self.pos_blocks_list.append(pos_block_module)
|
self.pos_blocks_list.append(pos_block_module)
|
||||||
|
|
||||||
last_channel = out_channel
|
last_channel = out_channel
|
||||||
en_dims.append(last_channel)
|
en_dims.append(last_channel)
|
||||||
|
|
||||||
|
|
||||||
### Building Decoder #####
|
### Building Decoder #####
|
||||||
self.decode_list = nn.ModuleList()
|
self.decode_list = nn.ModuleList()
|
||||||
en_dims.reverse()
|
en_dims.reverse()
|
||||||
de_dims.insert(0, en_dims[0])
|
de_dims.insert(0,en_dims[0])
|
||||||
assert len(en_dims) == len(de_dims) == len(de_blocks) + 1
|
assert len(en_dims) ==len(de_dims) == len(de_blocks)+1
|
||||||
for i in range(len(en_dims) - 1):
|
for i in range(len(en_dims)-1):
|
||||||
self.decode_list.append(
|
self.decode_list.append(
|
||||||
PointNetFeaturePropagation(
|
PointNetFeaturePropagation(de_dims[i]+en_dims[i+1], de_dims[i+1],
|
||||||
de_dims[i] + en_dims[i + 1],
|
blocks=de_blocks[i], groups=groups, res_expansion=res_expansion,
|
||||||
de_dims[i + 1],
|
bias=bias, activation=activation)
|
||||||
blocks=de_blocks[i],
|
|
||||||
groups=groups,
|
|
||||||
res_expansion=res_expansion,
|
|
||||||
bias=bias,
|
|
||||||
activation=activation,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.act = get_activation(activation)
|
self.act = get_activation(activation)
|
||||||
|
@ -449,26 +389,26 @@ class PointMLP(nn.Module):
|
||||||
# class label mapping
|
# class label mapping
|
||||||
self.cls_map = nn.Sequential(
|
self.cls_map = nn.Sequential(
|
||||||
ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation),
|
ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation),
|
||||||
ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation),
|
ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation)
|
||||||
)
|
)
|
||||||
# global max pooling mapping
|
# global max pooling mapping
|
||||||
self.gmp_map_list = nn.ModuleList()
|
self.gmp_map_list = nn.ModuleList()
|
||||||
for en_dim in en_dims:
|
for en_dim in en_dims:
|
||||||
self.gmp_map_list.append(ConvBNReLU1D(en_dim, gmp_dim, bias=bias, activation=activation))
|
self.gmp_map_list.append(ConvBNReLU1D(en_dim, gmp_dim, bias=bias, activation=activation))
|
||||||
self.gmp_map_end = ConvBNReLU1D(gmp_dim * len(en_dims), gmp_dim, bias=bias, activation=activation)
|
self.gmp_map_end = ConvBNReLU1D(gmp_dim*len(en_dims), gmp_dim, bias=bias, activation=activation)
|
||||||
|
|
||||||
# classifier
|
# classifier
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.Conv1d(gmp_dim + cls_dim + de_dims[-1], 128, 1, bias=bias),
|
nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias),
|
||||||
nn.BatchNorm1d(128),
|
nn.BatchNorm1d(128),
|
||||||
nn.Dropout(),
|
nn.Dropout(),
|
||||||
nn.Conv1d(128, num_classes, 1, bias=bias),
|
nn.Conv1d(128, num_classes, 1, bias=bias)
|
||||||
)
|
)
|
||||||
self.en_dims = en_dims
|
self.en_dims = en_dims
|
||||||
|
|
||||||
def forward(self, x, norm_plt, cls_label):
|
def forward(self, x, norm_plt, cls_label):
|
||||||
xyz = x.permute(0, 2, 1)
|
xyz = x.permute(0, 2, 1)
|
||||||
x = torch.cat([x, norm_plt], dim=1)
|
x = torch.cat([x,norm_plt],dim=1)
|
||||||
x = self.embedding(x) # B,D,N
|
x = self.embedding(x) # B,D,N
|
||||||
|
|
||||||
xyz_list = [xyz] # [B, N, 3]
|
xyz_list = [xyz] # [B, N, 3]
|
||||||
|
@ -488,55 +428,37 @@ class PointMLP(nn.Module):
|
||||||
x_list.reverse()
|
x_list.reverse()
|
||||||
x = x_list[0]
|
x = x_list[0]
|
||||||
for i in range(len(self.decode_list)):
|
for i in range(len(self.decode_list)):
|
||||||
x = self.decode_list[i](xyz_list[i + 1], xyz_list[i], x_list[i + 1], x)
|
x = self.decode_list[i](xyz_list[i+1], xyz_list[i], x_list[i+1],x)
|
||||||
|
|
||||||
# here is the global context
|
# here is the global context
|
||||||
gmp_list = []
|
gmp_list = []
|
||||||
for i in range(len(x_list)):
|
for i in range(len(x_list)):
|
||||||
gmp_list.append(F.adaptive_max_pool1d(self.gmp_map_list[i](x_list[i]), 1))
|
gmp_list.append(F.adaptive_max_pool1d(self.gmp_map_list[i](x_list[i]), 1))
|
||||||
global_context = self.gmp_map_end(torch.cat(gmp_list, dim=1)) # [b, gmp_dim, 1]
|
global_context = self.gmp_map_end(torch.cat(gmp_list, dim=1)) # [b, gmp_dim, 1]
|
||||||
|
|
||||||
# here is the cls_token
|
#here is the cls_token
|
||||||
cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1]
|
cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1]
|
||||||
x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
|
x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
x = F.log_softmax(x, dim=1)
|
x = F.log_softmax(x, dim=1)
|
||||||
return x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def pointMLP(num_classes=50, **kwargs) -> PointMLP:
|
def pointMLP(num_classes=50, **kwargs) -> PointMLP:
|
||||||
return PointMLP(
|
return PointMLP(num_classes=num_classes, points=2048, embed_dim=64, groups=1, res_expansion=1.0,
|
||||||
num_classes=num_classes,
|
activation="relu", bias=True, use_xyz=True, normalize="anchor",
|
||||||
points=2048,
|
dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
|
||||||
embed_dim=64,
|
k_neighbors=[32, 32, 32, 32], reducers=[4, 4, 4, 4],
|
||||||
groups=1,
|
de_dims=[512, 256, 128, 128], de_blocks=[4,4,4,4],
|
||||||
res_expansion=1.0,
|
gmp_dim=64,cls_dim=64, **kwargs)
|
||||||
activation="relu",
|
|
||||||
bias=True,
|
|
||||||
use_xyz=True,
|
|
||||||
normalize="anchor",
|
|
||||||
dim_expansion=[2, 2, 2, 2],
|
|
||||||
pre_blocks=[2, 2, 2, 2],
|
|
||||||
pos_blocks=[2, 2, 2, 2],
|
|
||||||
k_neighbors=[32, 32, 32, 32],
|
|
||||||
reducers=[4, 4, 4, 4],
|
|
||||||
de_dims=[512, 256, 128, 128],
|
|
||||||
de_blocks=[4, 4, 4, 4],
|
|
||||||
gmp_dim=64,
|
|
||||||
cls_dim=64,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
data = torch.rand(2, 3, 2048).cuda()
|
data = torch.rand(2, 3, 2048)
|
||||||
norm = torch.rand(2, 3, 2048).cuda()
|
norm = torch.rand(2, 3, 2048)
|
||||||
cls_label = torch.rand([2, 16]).cuda()
|
cls_label = torch.rand([2, 16])
|
||||||
print(f"data shape: {data.shape}")
|
print("===> testing modelD ...")
|
||||||
print(f"norm shape: {norm.shape}")
|
model = pointMLP(50)
|
||||||
print(f"cls_label shape: {cls_label.shape}")
|
out = model(data, cls_label) # [2,2048,50]
|
||||||
|
print(out.shape)
|
||||||
print("===> testing pointMLP (segmentation) ...")
|
|
||||||
model = pointMLP(50).cuda()
|
|
||||||
out = model(data, norm, cls_label) # [2,2048,50]
|
|
||||||
print(f"out shape: {out.shape}")
|
|
||||||
|
|
|
@ -1,21 +1,19 @@
|
||||||
import glob
|
import glob
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
import os
|
||||||
|
import json
|
||||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
|
|
||||||
def load_data(partition):
|
def load_data(partition):
|
||||||
all_data = []
|
all_data = []
|
||||||
all_label = []
|
all_label = []
|
||||||
for h5_name in glob.glob("./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5" % partition):
|
for h5_name in glob.glob('./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5' % partition):
|
||||||
f = h5py.File(h5_name)
|
f = h5py.File(h5_name)
|
||||||
data = f["data"][:].astype("float32")
|
data = f['data'][:].astype('float32')
|
||||||
label = f["label"][:].astype("int64")
|
label = f['label'][:].astype('int64')
|
||||||
f.close()
|
f.close()
|
||||||
all_data.append(data)
|
all_data.append(data)
|
||||||
all_label.append(label)
|
all_label.append(label)
|
||||||
|
@ -27,34 +25,36 @@ def load_data(partition):
|
||||||
def pc_normalize(pc):
|
def pc_normalize(pc):
|
||||||
centroid = np.mean(pc, axis=0)
|
centroid = np.mean(pc, axis=0)
|
||||||
pc = pc - centroid
|
pc = pc - centroid
|
||||||
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
|
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
|
||||||
return pc / m
|
pc = pc / m
|
||||||
|
return pc
|
||||||
|
|
||||||
|
|
||||||
def translate_pointcloud(pointcloud):
|
def translate_pointcloud(pointcloud):
|
||||||
xyz1 = np.random.uniform(low=2.0 / 3.0, high=3.0 / 2.0, size=[3])
|
xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
|
||||||
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
|
||||||
|
|
||||||
return np.add(np.multiply(pointcloud, xyz1), xyz2).astype("float32")
|
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
|
||||||
|
return translated_pointcloud
|
||||||
|
|
||||||
|
|
||||||
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
|
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
|
||||||
N, C = pointcloud.shape
|
N, C = pointcloud.shape
|
||||||
pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
|
pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip)
|
||||||
return pointcloud
|
return pointcloud
|
||||||
|
|
||||||
|
|
||||||
# =========== ModelNet40 =================
|
# =========== ModelNet40 =================
|
||||||
class ModelNet40(Dataset):
|
class ModelNet40(Dataset):
|
||||||
def __init__(self, num_points, partition="train"):
|
def __init__(self, num_points, partition='train'):
|
||||||
self.data, self.label = load_data(partition)
|
self.data, self.label = load_data(partition)
|
||||||
self.num_points = num_points
|
self.num_points = num_points
|
||||||
self.partition = partition # Here the new given partition will cover the 'train'
|
self.partition = partition # Here the new given partition will cover the 'train'
|
||||||
|
|
||||||
def __getitem__(self, item): # indice of the pts or label
|
def __getitem__(self, item): # indice of the pts or label
|
||||||
pointcloud = self.data[item][: self.num_points]
|
pointcloud = self.data[item][:self.num_points]
|
||||||
label = self.label[item]
|
label = self.label[item]
|
||||||
if self.partition == "train":
|
if self.partition == 'train':
|
||||||
# pointcloud = pc_normalize(pointcloud) # you can try to add it or not to train our model
|
# pointcloud = pc_normalize(pointcloud) # you can try to add it or not to train our model
|
||||||
pointcloud = translate_pointcloud(pointcloud)
|
pointcloud = translate_pointcloud(pointcloud)
|
||||||
np.random.shuffle(pointcloud) # shuffle the order of pts
|
np.random.shuffle(pointcloud) # shuffle the order of pts
|
||||||
|
@ -66,46 +66,46 @@ class ModelNet40(Dataset):
|
||||||
|
|
||||||
# =========== ShapeNet Part =================
|
# =========== ShapeNet Part =================
|
||||||
class PartNormalDataset(Dataset):
|
class PartNormalDataset(Dataset):
|
||||||
def __init__(self, npoints=2500, split="train", normalize=False):
|
def __init__(self, npoints=2500, split='train', normalize=False):
|
||||||
self.npoints = npoints
|
self.npoints = npoints
|
||||||
self.root = "./data/shapenetcore_partanno_segmentation_benchmark_v0_normal"
|
self.root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal'
|
||||||
self.catfile = os.path.join(self.root, "synsetoffset2category.txt")
|
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
|
||||||
self.cat = {}
|
self.cat = {}
|
||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
|
|
||||||
with open(self.catfile) as f:
|
with open(self.catfile, 'r') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
ls = line.strip().split()
|
ls = line.strip().split()
|
||||||
self.cat[ls[0]] = ls[1]
|
self.cat[ls[0]] = ls[1]
|
||||||
self.cat = {k: v for k, v in self.cat.items()}
|
self.cat = {k: v for k, v in self.cat.items()}
|
||||||
|
|
||||||
self.meta = {}
|
self.meta = {}
|
||||||
with open(os.path.join(self.root, "train_test_split", "shuffled_train_file_list.json")) as f:
|
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
|
||||||
train_ids = set([str(d.split("/")[2]) for d in json.load(f)])
|
train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
|
||||||
with open(os.path.join(self.root, "train_test_split", "shuffled_val_file_list.json")) as f:
|
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
|
||||||
val_ids = set([str(d.split("/")[2]) for d in json.load(f)])
|
val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
|
||||||
with open(os.path.join(self.root, "train_test_split", "shuffled_test_file_list.json")) as f:
|
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
|
||||||
test_ids = set([str(d.split("/")[2]) for d in json.load(f)])
|
test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
|
||||||
for item in self.cat:
|
for item in self.cat:
|
||||||
self.meta[item] = []
|
self.meta[item] = []
|
||||||
dir_point = os.path.join(self.root, self.cat[item])
|
dir_point = os.path.join(self.root, self.cat[item])
|
||||||
fns = sorted(os.listdir(dir_point))
|
fns = sorted(os.listdir(dir_point))
|
||||||
|
|
||||||
if split == "trainval":
|
if split == 'trainval':
|
||||||
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
|
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
|
||||||
elif split == "train":
|
elif split == 'train':
|
||||||
fns = [fn for fn in fns if fn[0:-4] in train_ids]
|
fns = [fn for fn in fns if fn[0:-4] in train_ids]
|
||||||
elif split == "val":
|
elif split == 'val':
|
||||||
fns = [fn for fn in fns if fn[0:-4] in val_ids]
|
fns = [fn for fn in fns if fn[0:-4] in val_ids]
|
||||||
elif split == "test":
|
elif split == 'test':
|
||||||
fns = [fn for fn in fns if fn[0:-4] in test_ids]
|
fns = [fn for fn in fns if fn[0:-4] in test_ids]
|
||||||
else:
|
else:
|
||||||
print("Unknown split: %s. Exiting.." % (split))
|
print('Unknown split: %s. Exiting..' % (split))
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
token = os.path.splitext(os.path.basename(fn))[0]
|
token = (os.path.splitext(os.path.basename(fn))[0])
|
||||||
self.meta[item].append(os.path.join(dir_point, token + ".txt"))
|
self.meta[item].append(os.path.join(dir_point, token + '.txt'))
|
||||||
|
|
||||||
self.datapath = []
|
self.datapath = []
|
||||||
for item in self.cat:
|
for item in self.cat:
|
||||||
|
@ -114,24 +114,11 @@ class PartNormalDataset(Dataset):
|
||||||
|
|
||||||
self.classes = dict(zip(self.cat, range(len(self.cat))))
|
self.classes = dict(zip(self.cat, range(len(self.cat))))
|
||||||
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
|
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
|
||||||
self.seg_classes = {
|
self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
|
||||||
"Earphone": [16, 17, 18],
|
'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
|
||||||
"Motorbike": [30, 31, 32, 33, 34, 35],
|
'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
|
||||||
"Rocket": [41, 42, 43],
|
'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
|
||||||
"Car": [8, 9, 10, 11],
|
'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
|
||||||
"Laptop": [28, 29],
|
|
||||||
"Cap": [6, 7],
|
|
||||||
"Skateboard": [44, 45, 46],
|
|
||||||
"Mug": [36, 37],
|
|
||||||
"Guitar": [19, 20, 21],
|
|
||||||
"Bag": [4, 5],
|
|
||||||
"Lamp": [24, 25, 26, 27],
|
|
||||||
"Table": [47, 48, 49],
|
|
||||||
"Airplane": [0, 1, 2, 3],
|
|
||||||
"Pistol": [38, 39, 40],
|
|
||||||
"Chair": [12, 13, 14, 15],
|
|
||||||
"Knife": [22, 23],
|
|
||||||
}
|
|
||||||
|
|
||||||
self.cache = {} # from index to (point_set, cls, seg) tuple
|
self.cache = {} # from index to (point_set, cls, seg) tuple
|
||||||
self.cache_size = 20000
|
self.cache_size = 20000
|
||||||
|
@ -169,9 +156,9 @@ class PartNormalDataset(Dataset):
|
||||||
return len(self.datapath)
|
return len(self.datapath)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
train = PartNormalDataset(npoints=2048, split="trainval", normalize=False)
|
train = PartNormalDataset(npoints=2048, split='trainval', normalize=False)
|
||||||
test = PartNormalDataset(npoints=2048, split="test", normalize=False)
|
test = PartNormalDataset(npoints=2048, split='test', normalize=False)
|
||||||
for data, label, _, _ in train:
|
for data, label, _, _ in train:
|
||||||
print(data.shape)
|
print(data.shape)
|
||||||
print(label.shape)
|
print(label.shape)
|
||||||
|
|
|
@ -4,8 +4,9 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def cal_loss(pred, gold, smoothing=True):
|
def cal_loss(pred, gold, smoothing=True):
|
||||||
"""Calculate cross entropy loss, apply label smoothing if needed."""
|
''' Calculate cross entropy loss, apply label smoothing if needed. '''
|
||||||
gold = gold.contiguous().view(-1) # gold is the groudtruth label in the dataloader
|
|
||||||
|
gold = gold.contiguous().view(-1) # gold is the groudtruth label in the dataloader
|
||||||
|
|
||||||
if smoothing:
|
if smoothing:
|
||||||
eps = 0.2
|
eps = 0.2
|
||||||
|
@ -17,19 +18,19 @@ def cal_loss(pred, gold, smoothing=True):
|
||||||
|
|
||||||
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
loss = -(one_hot * log_prb).sum(dim=1).mean()
|
||||||
else:
|
else:
|
||||||
loss = F.cross_entropy(pred, gold, reduction="mean")
|
loss = F.cross_entropy(pred, gold, reduction='mean')
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# create a file and write the text into it:
|
# create a file and write the text into it:
|
||||||
class IOStream:
|
class IOStream():
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.f = open(path, "a")
|
self.f = open(path, 'a')
|
||||||
|
|
||||||
def cprint(self, text):
|
def cprint(self, text):
|
||||||
print(text)
|
print(text)
|
||||||
self.f.write(text + "\n")
|
self.f.write(text+'\n')
|
||||||
self.f.flush()
|
self.f.flush()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -37,24 +38,22 @@ class IOStream:
|
||||||
|
|
||||||
|
|
||||||
def to_categorical(y, num_classes):
|
def to_categorical(y, num_classes):
|
||||||
"""1-hot encodes a tensor."""
|
""" 1-hot encodes a tensor """
|
||||||
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
|
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
|
||||||
if y.is_cuda:
|
if (y.is_cuda):
|
||||||
return new_y.cuda(non_blocking=True)
|
return new_y.cuda(non_blocking=True)
|
||||||
return new_y
|
return new_y
|
||||||
|
|
||||||
|
|
||||||
def compute_overall_iou(pred, target, num_classes):
|
def compute_overall_iou(pred, target, num_classes):
|
||||||
shape_ious = []
|
shape_ious = []
|
||||||
pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample
|
pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample
|
||||||
pred_np = pred.cpu().data.numpy()
|
pred_np = pred.cpu().data.numpy()
|
||||||
|
|
||||||
target_np = target.cpu().data.numpy()
|
target_np = target.cpu().data.numpy()
|
||||||
for shape_idx in range(pred.size(0)): # sample_idx
|
for shape_idx in range(pred.size(0)): # sample_idx
|
||||||
part_ious = []
|
part_ious = []
|
||||||
for part in range(
|
for part in range(num_classes): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes
|
||||||
num_classes,
|
|
||||||
): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes
|
|
||||||
# for target, each point has a class no matter which category owns this point! also 50 classes!!!
|
# for target, each point has a class no matter which category owns this point! also 50 classes!!!
|
||||||
# only return 1 when both belongs to this class, which means correct:
|
# only return 1 when both belongs to this class, which means correct:
|
||||||
I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
|
I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
|
||||||
|
@ -64,9 +63,7 @@ def compute_overall_iou(pred, target, num_classes):
|
||||||
F = np.sum(target_np[shape_idx] == part)
|
F = np.sum(target_np[shape_idx] == part)
|
||||||
|
|
||||||
if F != 0:
|
if F != 0:
|
||||||
iou = I / float(U) # iou across all points for this class
|
iou = I / float(U) # iou across all points for this class
|
||||||
part_ious.append(iou) # append the iou of this class
|
part_ious.append(iou) # append the iou of this class
|
||||||
shape_ious.append(
|
shape_ious.append(np.mean(part_ious)) # each time append an average iou across all classes of this sample (sample_level!)
|
||||||
np.mean(part_ious),
|
return shape_ious # [batch_size]
|
||||||
) # each time append an average iou across all classes of this sample (sample_level!)
|
|
||||||
return shape_ious # [batch_size]
|
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from pointnet2_ops import pointnet2_utils
|
from pointnet2_ops import pointnet2_utils
|
||||||
|
|
||||||
|
|
||||||
def build_shared_mlp(mlp_spec: list[int], bn: bool = True):
|
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
|
||||||
layers = []
|
layers = []
|
||||||
for i in range(1, len(mlp_spec)):
|
for i in range(1, len(mlp_spec)):
|
||||||
layers.append(
|
layers.append(
|
||||||
nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn),
|
nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
|
||||||
)
|
)
|
||||||
if bn:
|
if bn:
|
||||||
layers.append(nn.BatchNorm2d(mlp_spec[i]))
|
layers.append(nn.BatchNorm2d(mlp_spec[i]))
|
||||||
|
@ -20,37 +21,36 @@ def build_shared_mlp(mlp_spec: list[int], bn: bool = True):
|
||||||
|
|
||||||
class _PointnetSAModuleBase(nn.Module):
|
class _PointnetSAModuleBase(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super(_PointnetSAModuleBase, self).__init__()
|
||||||
self.npoint = None
|
self.npoint = None
|
||||||
self.groupers = None
|
self.groupers = None
|
||||||
self.mlps = None
|
self.mlps = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, xyz: torch.Tensor, features: Optional[torch.Tensor]
|
||||||
xyz: torch.Tensor,
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
features: torch.Tensor | None,
|
r"""
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
Parameters
|
||||||
r"""Parameters
|
|
||||||
----------
|
----------
|
||||||
xyz : torch.Tensor
|
xyz : torch.Tensor
|
||||||
(B, N, 3) tensor of the xyz coordinates of the features
|
(B, N, 3) tensor of the xyz coordinates of the features
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
(B, C, N) tensor of the descriptors of the the features
|
(B, C, N) tensor of the descriptors of the the features
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
new_xyz : torch.Tensor
|
new_xyz : torch.Tensor
|
||||||
(B, npoint, 3) tensor of the new features' xyz
|
(B, npoint, 3) tensor of the new features' xyz
|
||||||
new_features : torch.Tensor
|
new_features : torch.Tensor
|
||||||
(B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
|
(B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
new_features_list = []
|
new_features_list = []
|
||||||
|
|
||||||
xyz_flipped = xyz.transpose(1, 2).contiguous()
|
xyz_flipped = xyz.transpose(1, 2).contiguous()
|
||||||
new_xyz = (
|
new_xyz = (
|
||||||
pointnet2_utils.gather_operation(
|
pointnet2_utils.gather_operation(
|
||||||
xyz_flipped,
|
xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
|
||||||
pointnet2_utils.furthest_point_sample(xyz, self.npoint),
|
|
||||||
)
|
)
|
||||||
.transpose(1, 2)
|
.transpose(1, 2)
|
||||||
.contiguous()
|
.contiguous()
|
||||||
|
@ -60,15 +60,12 @@ class _PointnetSAModuleBase(nn.Module):
|
||||||
|
|
||||||
for i in range(len(self.groupers)):
|
for i in range(len(self.groupers)):
|
||||||
new_features = self.groupers[i](
|
new_features = self.groupers[i](
|
||||||
xyz,
|
xyz, new_xyz, features
|
||||||
new_xyz,
|
|
||||||
features,
|
|
||||||
) # (B, C, npoint, nsample)
|
) # (B, C, npoint, nsample)
|
||||||
|
|
||||||
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
|
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
|
||||||
new_features = F.max_pool2d(
|
new_features = F.max_pool2d(
|
||||||
new_features,
|
new_features, kernel_size=[1, new_features.size(3)]
|
||||||
kernel_size=[1, new_features.size(3)],
|
|
||||||
) # (B, mlp[-1], npoint, 1)
|
) # (B, mlp[-1], npoint, 1)
|
||||||
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
|
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
|
||||||
|
|
||||||
|
@ -78,7 +75,7 @@ class _PointnetSAModuleBase(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
||||||
r"""Pointnet set abstrction layer with multiscale grouping.
|
r"""Pointnet set abstrction layer with multiscale grouping
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -96,7 +93,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
||||||
|
|
||||||
def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
|
def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
|
||||||
# type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
|
# type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
|
||||||
super().__init__()
|
super(PointnetSAModuleMSG, self).__init__()
|
||||||
|
|
||||||
assert len(radii) == len(nsamples) == len(mlps)
|
assert len(radii) == len(nsamples) == len(mlps)
|
||||||
|
|
||||||
|
@ -109,7 +106,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
||||||
self.groupers.append(
|
self.groupers.append(
|
||||||
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
|
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
|
||||||
if npoint is not None
|
if npoint is not None
|
||||||
else pointnet2_utils.GroupAll(use_xyz),
|
else pointnet2_utils.GroupAll(use_xyz)
|
||||||
)
|
)
|
||||||
mlp_spec = mlps[i]
|
mlp_spec = mlps[i]
|
||||||
if use_xyz:
|
if use_xyz:
|
||||||
|
@ -119,7 +116,7 @@ class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
||||||
|
|
||||||
|
|
||||||
class PointnetSAModule(PointnetSAModuleMSG):
|
class PointnetSAModule(PointnetSAModuleMSG):
|
||||||
r"""Pointnet set abstrction layer.
|
r"""Pointnet set abstrction layer
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -136,16 +133,10 @@ class PointnetSAModule(PointnetSAModuleMSG):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
|
||||||
mlp,
|
|
||||||
npoint=None,
|
|
||||||
radius=None,
|
|
||||||
nsample=None,
|
|
||||||
bn=True,
|
|
||||||
use_xyz=True,
|
|
||||||
):
|
):
|
||||||
# type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
|
# type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
|
||||||
super().__init__(
|
super(PointnetSAModule, self).__init__(
|
||||||
mlps=[mlp],
|
mlps=[mlp],
|
||||||
npoint=npoint,
|
npoint=npoint,
|
||||||
radii=[radius],
|
radii=[radius],
|
||||||
|
@ -156,7 +147,7 @@ class PointnetSAModule(PointnetSAModuleMSG):
|
||||||
|
|
||||||
|
|
||||||
class PointnetFPModule(nn.Module):
|
class PointnetFPModule(nn.Module):
|
||||||
r"""Propigates the features of one set to another.
|
r"""Propigates the features of one set to another
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -168,12 +159,13 @@ class PointnetFPModule(nn.Module):
|
||||||
|
|
||||||
def __init__(self, mlp, bn=True):
|
def __init__(self, mlp, bn=True):
|
||||||
# type: (PointnetFPModule, List[int], bool) -> None
|
# type: (PointnetFPModule, List[int], bool) -> None
|
||||||
super().__init__()
|
super(PointnetFPModule, self).__init__()
|
||||||
self.mlp = build_shared_mlp(mlp, bn=bn)
|
self.mlp = build_shared_mlp(mlp, bn=bn)
|
||||||
|
|
||||||
def forward(self, unknown, known, unknow_feats, known_feats):
|
def forward(self, unknown, known, unknow_feats, known_feats):
|
||||||
# type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
|
# type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
|
||||||
r"""Parameters
|
r"""
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
unknown : torch.Tensor
|
unknown : torch.Tensor
|
||||||
(B, n, 3) tensor of the xyz positions of the unknown features
|
(B, n, 3) tensor of the xyz positions of the unknown features
|
||||||
|
@ -184,11 +176,12 @@ class PointnetFPModule(nn.Module):
|
||||||
known_feats : torch.Tensor
|
known_feats : torch.Tensor
|
||||||
(B, C2, m) tensor of features to be propigated
|
(B, C2, m) tensor of features to be propigated
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
new_features : torch.Tensor
|
new_features : torch.Tensor
|
||||||
(B, mlp[-1], n) tensor of the features of the unknown features
|
(B, mlp[-1], n) tensor of the features of the unknown features
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if known is not None:
|
if known is not None:
|
||||||
dist, idx = pointnet2_utils.three_nn(unknown, known)
|
dist, idx = pointnet2_utils.three_nn(unknown, known)
|
||||||
dist_recip = 1.0 / (dist + 1e-8)
|
dist_recip = 1.0 / (dist + 1e-8)
|
||||||
|
@ -196,19 +189,16 @@ class PointnetFPModule(nn.Module):
|
||||||
weight = dist_recip / norm
|
weight = dist_recip / norm
|
||||||
|
|
||||||
interpolated_feats = pointnet2_utils.three_interpolate(
|
interpolated_feats = pointnet2_utils.three_interpolate(
|
||||||
known_feats,
|
known_feats, idx, weight
|
||||||
idx,
|
|
||||||
weight,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
interpolated_feats = known_feats.expand(
|
interpolated_feats = known_feats.expand(
|
||||||
*(known_feats.size()[0:2] + [unknown.size(1)]),
|
*(known_feats.size()[0:2] + [unknown.size(1)])
|
||||||
)
|
)
|
||||||
|
|
||||||
if unknow_feats is not None:
|
if unknow_feats is not None:
|
||||||
new_features = torch.cat(
|
new_features = torch.cat(
|
||||||
[interpolated_feats, unknow_feats],
|
[interpolated_feats, unknow_feats], dim=1
|
||||||
dim=1,
|
|
||||||
) # (B, C2 + C1, n)
|
) # (B, C2 + C1, n)
|
||||||
else:
|
else:
|
||||||
new_features = interpolated_feats
|
new_features = interpolated_feats
|
||||||
|
|
|
@ -1,24 +1,22 @@
|
||||||
import warnings
|
|
||||||
from typing import *
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import warnings
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
|
from typing import *
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pointnet2_ops._ext as _ext
|
import pointnet2_ops._ext as _ext
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
|
import glob
|
||||||
|
import os.path as osp
|
||||||
|
import os
|
||||||
|
|
||||||
warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
|
warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
|
||||||
|
|
||||||
_ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
|
_ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
|
||||||
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
||||||
osp.join(_ext_src_root, "src", "*.cu"),
|
osp.join(_ext_src_root, "src", "*.cu")
|
||||||
)
|
)
|
||||||
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
||||||
|
|
||||||
|
@ -37,8 +35,9 @@ class FurthestPointSampling(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, xyz, npoint):
|
def forward(ctx, xyz, npoint):
|
||||||
# type: (Any, torch.Tensor, int) -> torch.Tensor
|
# type: (Any, torch.Tensor, int) -> torch.Tensor
|
||||||
r"""Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
r"""
|
||||||
minimum distance.
|
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
||||||
|
minimum distance
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -47,7 +46,7 @@ class FurthestPointSampling(Function):
|
||||||
npoint : int32
|
npoint : int32
|
||||||
number of features in the sampled set
|
number of features in the sampled set
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, npoint) tensor containing the set
|
(B, npoint) tensor containing the set
|
||||||
|
@ -70,7 +69,9 @@ class GatherOperation(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, features, idx):
|
def forward(ctx, features, idx):
|
||||||
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
||||||
r"""Parameters
|
r"""
|
||||||
|
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
(B, C, N) tensor
|
(B, C, N) tensor
|
||||||
|
@ -78,11 +79,12 @@ class GatherOperation(Function):
|
||||||
idx : torch.Tensor
|
idx : torch.Tensor
|
||||||
(B, npoint) tensor of the features to gather
|
(B, npoint) tensor of the features to gather
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, C, npoint) tensor
|
(B, C, npoint) tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctx.save_for_backward(idx, features)
|
ctx.save_for_backward(idx, features)
|
||||||
|
|
||||||
return _ext.gather_points(features, idx)
|
return _ext.gather_points(features, idx)
|
||||||
|
@ -103,15 +105,16 @@ class ThreeNN(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, unknown, known):
|
def forward(ctx, unknown, known):
|
||||||
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
||||||
r"""Find the three nearest neighbors of unknown in known
|
r"""
|
||||||
|
Find the three nearest neighbors of unknown in known
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
unknown : torch.Tensor
|
unknown : torch.Tensor
|
||||||
(B, n, 3) tensor of known features
|
(B, n, 3) tensor of known features
|
||||||
known : torch.Tensor
|
known : torch.Tensor
|
||||||
(B, m, 3) tensor of unknown features.
|
(B, m, 3) tensor of unknown features
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
dist : torch.Tensor
|
dist : torch.Tensor
|
||||||
(B, n, 3) l2 distance to the three nearest neighbors
|
(B, n, 3) l2 distance to the three nearest neighbors
|
||||||
|
@ -137,7 +140,8 @@ class ThreeInterpolate(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, features, idx, weight):
|
def forward(ctx, features, idx, weight):
|
||||||
# type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
|
# type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
|
||||||
r"""Performs weight linear interpolation on 3 features
|
r"""
|
||||||
|
Performs weight linear interpolation on 3 features
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
|
@ -145,9 +149,9 @@ class ThreeInterpolate(Function):
|
||||||
idx : torch.Tensor
|
idx : torch.Tensor
|
||||||
(B, n, 3) three nearest neighbors of the target features in features
|
(B, n, 3) three nearest neighbors of the target features in features
|
||||||
weight : torch.Tensor
|
weight : torch.Tensor
|
||||||
(B, n, 3) weights.
|
(B, n, 3) weights
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, c, n) tensor of the interpolated features
|
(B, c, n) tensor of the interpolated features
|
||||||
|
@ -159,12 +163,13 @@ class ThreeInterpolate(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
# type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
# type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||||
r"""Parameters
|
r"""
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
grad_out : torch.Tensor
|
grad_out : torch.Tensor
|
||||||
(B, c, n) tensor with gradients of ouputs
|
(B, c, n) tensor with gradients of ouputs
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
grad_features : torch.Tensor
|
grad_features : torch.Tensor
|
||||||
(B, c, m) tensor with gradients of features
|
(B, c, m) tensor with gradients of features
|
||||||
|
@ -177,10 +182,7 @@ class ThreeInterpolate(Function):
|
||||||
m = features.size(2)
|
m = features.size(2)
|
||||||
|
|
||||||
grad_features = _ext.three_interpolate_grad(
|
grad_features = _ext.three_interpolate_grad(
|
||||||
grad_out.contiguous(),
|
grad_out.contiguous(), idx, weight, m
|
||||||
idx,
|
|
||||||
weight,
|
|
||||||
m,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
|
return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
|
||||||
|
@ -193,14 +195,16 @@ class GroupingOperation(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, features, idx):
|
def forward(ctx, features, idx):
|
||||||
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
||||||
r"""Parameters
|
r"""
|
||||||
|
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
(B, C, N) tensor of features to group
|
(B, C, N) tensor of features to group
|
||||||
idx : torch.Tensor
|
idx : torch.Tensor
|
||||||
(B, npoint, nsample) tensor containing the indicies of features to group with
|
(B, npoint, nsample) tensor containing the indicies of features to group with
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, C, npoint, nsample) tensor
|
(B, C, npoint, nsample) tensor
|
||||||
|
@ -212,12 +216,14 @@ class GroupingOperation(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
# type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
# type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
||||||
r"""Parameters
|
r"""
|
||||||
|
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
grad_out : torch.Tensor
|
grad_out : torch.Tensor
|
||||||
(B, C, npoint, nsample) tensor of the gradients of the output from forward
|
(B, C, npoint, nsample) tensor of the gradients of the output from forward
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, C, N) gradient of the features
|
(B, C, N) gradient of the features
|
||||||
|
@ -238,7 +244,9 @@ class BallQuery(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, radius, nsample, xyz, new_xyz):
|
def forward(ctx, radius, nsample, xyz, new_xyz):
|
||||||
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
|
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
|
||||||
r"""Parameters
|
r"""
|
||||||
|
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
radius : float
|
radius : float
|
||||||
radius of the balls
|
radius of the balls
|
||||||
|
@ -249,7 +257,7 @@ class BallQuery(Function):
|
||||||
new_xyz : torch.Tensor
|
new_xyz : torch.Tensor
|
||||||
(B, npoint, 3) centers of the ball query
|
(B, npoint, 3) centers of the ball query
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
|
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
|
||||||
|
@ -269,7 +277,8 @@ ball_query = BallQuery.apply
|
||||||
|
|
||||||
|
|
||||||
class QueryAndGroup(nn.Module):
|
class QueryAndGroup(nn.Module):
|
||||||
r"""Groups with a ball query of radius.
|
r"""
|
||||||
|
Groups with a ball query of radius
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
---------
|
---------
|
||||||
|
@ -281,12 +290,13 @@ class QueryAndGroup(nn.Module):
|
||||||
|
|
||||||
def __init__(self, radius, nsample, use_xyz=True):
|
def __init__(self, radius, nsample, use_xyz=True):
|
||||||
# type: (QueryAndGroup, float, int, bool) -> None
|
# type: (QueryAndGroup, float, int, bool) -> None
|
||||||
super().__init__()
|
super(QueryAndGroup, self).__init__()
|
||||||
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
|
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
|
||||||
|
|
||||||
def forward(self, xyz, new_xyz, features=None):
|
def forward(self, xyz, new_xyz, features=None):
|
||||||
# type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
|
# type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
|
||||||
r"""Parameters
|
r"""
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
xyz : torch.Tensor
|
xyz : torch.Tensor
|
||||||
xyz coordinates of the features (B, N, 3)
|
xyz coordinates of the features (B, N, 3)
|
||||||
|
@ -295,11 +305,12 @@ class QueryAndGroup(nn.Module):
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
Descriptors of the features (B, C, N)
|
Descriptors of the features (B, C, N)
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
new_features : torch.Tensor
|
new_features : torch.Tensor
|
||||||
(B, 3 + C, npoint, nsample) tensor
|
(B, 3 + C, npoint, nsample) tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
|
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
|
||||||
xyz_trans = xyz.transpose(1, 2).contiguous()
|
xyz_trans = xyz.transpose(1, 2).contiguous()
|
||||||
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
|
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
|
||||||
|
@ -309,20 +320,22 @@ class QueryAndGroup(nn.Module):
|
||||||
grouped_features = grouping_operation(features, idx)
|
grouped_features = grouping_operation(features, idx)
|
||||||
if self.use_xyz:
|
if self.use_xyz:
|
||||||
new_features = torch.cat(
|
new_features = torch.cat(
|
||||||
[grouped_xyz, grouped_features],
|
[grouped_xyz, grouped_features], dim=1
|
||||||
dim=1,
|
|
||||||
) # (B, C + 3, npoint, nsample)
|
) # (B, C + 3, npoint, nsample)
|
||||||
else:
|
else:
|
||||||
new_features = grouped_features
|
new_features = grouped_features
|
||||||
else:
|
else:
|
||||||
assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
|
assert (
|
||||||
|
self.use_xyz
|
||||||
|
), "Cannot have not features and not use xyz as a feature!"
|
||||||
new_features = grouped_xyz
|
new_features = grouped_xyz
|
||||||
|
|
||||||
return new_features
|
return new_features
|
||||||
|
|
||||||
|
|
||||||
class GroupAll(nn.Module):
|
class GroupAll(nn.Module):
|
||||||
r"""Groups all features.
|
r"""
|
||||||
|
Groups all features
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
---------
|
---------
|
||||||
|
@ -330,12 +343,13 @@ class GroupAll(nn.Module):
|
||||||
|
|
||||||
def __init__(self, use_xyz=True):
|
def __init__(self, use_xyz=True):
|
||||||
# type: (GroupAll, bool) -> None
|
# type: (GroupAll, bool) -> None
|
||||||
super().__init__()
|
super(GroupAll, self).__init__()
|
||||||
self.use_xyz = use_xyz
|
self.use_xyz = use_xyz
|
||||||
|
|
||||||
def forward(self, xyz, new_xyz, features=None):
|
def forward(self, xyz, new_xyz, features=None):
|
||||||
# type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
|
# type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
|
||||||
r"""Parameters
|
r"""
|
||||||
|
Parameters
|
||||||
----------
|
----------
|
||||||
xyz : torch.Tensor
|
xyz : torch.Tensor
|
||||||
xyz coordinates of the features (B, N, 3)
|
xyz coordinates of the features (B, N, 3)
|
||||||
|
@ -344,18 +358,18 @@ class GroupAll(nn.Module):
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
Descriptors of the features (B, C, N)
|
Descriptors of the features (B, C, N)
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
-------
|
-------
|
||||||
new_features : torch.Tensor
|
new_features : torch.Tensor
|
||||||
(B, C + 3, 1, N) tensor
|
(B, C + 3, 1, N) tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
|
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
|
||||||
if features is not None:
|
if features is not None:
|
||||||
grouped_features = features.unsqueeze(2)
|
grouped_features = features.unsqueeze(2)
|
||||||
if self.use_xyz:
|
if self.use_xyz:
|
||||||
new_features = torch.cat(
|
new_features = torch.cat(
|
||||||
[grouped_xyz, grouped_features],
|
[grouped_xyz, grouped_features], dim=1
|
||||||
dim=1,
|
|
||||||
) # (B, 3 + C, 1, N)
|
) # (B, 3 + C, 1, N)
|
||||||
else:
|
else:
|
||||||
new_features = grouped_features
|
new_features = grouped_features
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
this_dir = osp.dirname(osp.abspath(__file__))
|
this_dir = osp.dirname(osp.abspath(__file__))
|
||||||
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
|
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
|
||||||
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
||||||
osp.join(_ext_src_root, "src", "*.cu"),
|
osp.join(_ext_src_root, "src", "*.cu")
|
||||||
)
|
)
|
||||||
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ setup(
|
||||||
"nvcc": ["-O3", "-Xfatbin", "-compress-all"],
|
"nvcc": ["-O3", "-Xfatbin", "-compress-all"],
|
||||||
},
|
},
|
||||||
include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
|
include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
|
||||||
),
|
)
|
||||||
],
|
],
|
||||||
cmdclass={"build_ext": BuildExtension},
|
cmdclass={"build_ext": BuildExtension},
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
|
|
|
@ -1,64 +0,0 @@
|
||||||
[tool.ruff]
|
|
||||||
line-length = 120
|
|
||||||
ignore-init-module-imports = true
|
|
||||||
ignore = [
|
|
||||||
"G004", # Logging statement uses f-string
|
|
||||||
"EM102", # Exception must not use an f-string literal, assign to variable first
|
|
||||||
"D100", # Missing docstring in public module
|
|
||||||
"D104", # Missing docstring in public package
|
|
||||||
"N812", # Lowercase imported as non lowercase
|
|
||||||
]
|
|
||||||
select = [
|
|
||||||
"A", # flake8-builtins
|
|
||||||
"B", # flake8-bugbear
|
|
||||||
"C90", # mccabe
|
|
||||||
"COM", # flake8-commas
|
|
||||||
"D", # pydocstyle
|
|
||||||
"EM", # flake8-errmsg
|
|
||||||
"E", # pycodestyle errors
|
|
||||||
"F", # Pyflakes
|
|
||||||
"G", # flake8-logging-format
|
|
||||||
"I", # isort
|
|
||||||
"N", # pep8-naming
|
|
||||||
"PIE", # flake8-pie
|
|
||||||
"PTH", # flake8-use-pathlib
|
|
||||||
"TD", # flake8-todo
|
|
||||||
"FIX", # flake8-fixme
|
|
||||||
"RET", # flake8-return
|
|
||||||
"RUF", # ruff
|
|
||||||
"S", # flake8-bandit
|
|
||||||
"TCH", # flake8-type-checking
|
|
||||||
"TID", # flake8-tidy-imports
|
|
||||||
"UP", # pyupgrade
|
|
||||||
"W", # pycodestyle warnings
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.pydocstyle]
|
|
||||||
convention = "google"
|
|
||||||
|
|
||||||
[tool.ruff.isort]
|
|
||||||
known-first-party = ["pointnet2_ops"]
|
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
|
||||||
"__init__.py" = ["F401"]
|
|
||||||
"src/aube/main.py" = ["E402", "F401"]
|
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
exclude = '''
|
|
||||||
/(
|
|
||||||
\.git
|
|
||||||
\.venv
|
|
||||||
)/
|
|
||||||
'''
|
|
||||||
include = '\.pyi?$'
|
|
||||||
line-length = 120
|
|
||||||
target-version = ["py310"]
|
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
multi_line_output = 3
|
|
||||||
profile = "black"
|
|
||||||
|
|
||||||
[tool.mypy]
|
|
||||||
python_version = "3.10"
|
|
||||||
warn_return_any = true
|
|
||||||
warn_unused_configs = true
|
|
12
requirements.txt
Normal file
12
requirements.txt
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
cudatoolkit
|
||||||
|
cycler
|
||||||
|
einops
|
||||||
|
h5py
|
||||||
|
matplotlib==3.4.2
|
||||||
|
pytorch
|
||||||
|
pyyaml==5.4.1
|
||||||
|
scikit-learn==0.24.2
|
||||||
|
scipy
|
||||||
|
tqdm
|
Loading…
Reference in a new issue