Compare commits

...

10 commits

Author SHA1 Message Date
Laurent FAINSIN 403e04de8a fix: add link to checkpoints 2023-07-18 14:41:59 +02:00
Laurent FAINSIN 8f37d2838b fix: outdated/broken instructions 2023-07-18 14:40:41 +02:00
Laurent FAINSIN f79b9c697b "../exp" -> "./exp" 2023-04-07 13:33:06 +02:00
Laurent FAINSIN 7225eabb15 fix: change batch size to fit on bigviz 2023-04-07 13:32:44 +02:00
Laurent FAINSIN cb3c32d6b1 fix: fucking syntax error, you kidding me ? 2023-04-07 13:32:24 +02:00
Laurent FAINSIN 55902e13bc fix: update README instructions for rosetta 2023-04-07 13:31:51 +02:00
Laurent FAINSIN c67e7faf35 fix: missed some deps 2023-04-07 13:31:31 +02:00
Laurent FAINSIN c0d5c6754b fix: apply same patches as PVD 2023-04-07 10:15:59 +02:00
Laurent FAINSIN 231ff196c7 fix: update caveman conda env 2023-04-07 10:15:37 +02:00
xzeng 0467d21990 try to fix job hang issue
reference: https://github.com/nv-tlabs/LION/issues/32#issuecomment-1496997294
2023-04-05 03:21:20 -04:00
13 changed files with 104 additions and 368 deletions

View file

@ -27,13 +27,17 @@
* Setup the environment
Install from conda file
```
conda env create --name lion_env --file=env.yaml
conda activate lion_env
mamba env create -f environment.yml
# mamba env update -f environment.yml
conda activate LION
# Install some other packages
# Install some other packages (use proxy)
pip install git+https://github.com/openai/CLIP.git
# build some packages first (optional)
export CUDA_HOME=/usr/local/cuda # just in case rosetta cucks you
module load compilers
module load mpfr
python build_pkg.py
```
Tested with conda version 22.9.0
@ -44,7 +48,7 @@
## Demo
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. (Note: the checkpoint is not released yet, the files loaded in the `demo.py` file is not available at this point)
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. Download checkpoints from [HuggingFace Hub](https://huggingface.co/xiaohui2022/lion_ckpt)
## Released checkpoint and samples
* will be release soon

View file

@ -1,4 +1,4 @@
bash_name: ../exp/tmp/2022_0407_0300_45.sh
bash_name: ./exp/tmp/2022_0407_0300_45.sh
clipforge:
clip_model: ViT-B/32
enable: 0
@ -105,13 +105,13 @@ latent_pts:
weight_kl_feat: 1.0
weight_kl_glb: 1.0
weight_kl_pt: 1.0
log_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
log_name: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
log_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
log_name: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
model_config: default
ngpu: 8
num_ref: 0
num_val_samples: 24
save_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
save_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
sde:
attn_mhead: 0
attn_mhead_local: -1

View file

@ -1,4 +1,4 @@
bash_name: ../exp/tmp/2022_0407_1347_21.sh
bash_name: ./exp/tmp/2022_0407_1347_21.sh
clipforge:
clip_model: ViT-B/32
enable: 0
@ -105,13 +105,13 @@ latent_pts:
weight_kl_feat: 1.0
weight_kl_glb: 1.0
weight_kl_pt: 1.0
log_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
log_name: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
log_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
log_name: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
model_config: default
ngpu: 8
num_ref: 0
num_val_samples: 24
save_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
save_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
sde:
attn_mhead: 0
attn_mhead_local: -1
@ -195,7 +195,7 @@ sde:
update_q_ema: false
use_adam: true
use_adamax: false
vae_checkpoint: ../exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
vae_checkpoint: ./exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
warmup_epochs: 20
weight_decay: 0.0003
weight_decay_norm_dae: 0.0

View file

@ -1,4 +1,4 @@
bash_name: ../exp/tmp/2022_0416_1418_42.sh
bash_name: ./exp/tmp/2022_0416_1418_42.sh
clipforge:
clip_model: ViT-B/32
enable: 0
@ -105,13 +105,13 @@ latent_pts:
weight_kl_feat: 1.0
weight_kl_glb: 1.0
weight_kl_pt: 1.0
log_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
log_name: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
log_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
log_name: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
model_config: default
ngpu: 8
num_ref: 0
num_val_samples: 24
save_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
save_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
sde:
attn_mhead: 0
attn_mhead_local: -1

View file

@ -8,7 +8,6 @@
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
import os
import open3d as o3d
import time
import torch
import numpy as np
@ -174,7 +173,7 @@ class ShapeNet15kPointClouds(Dataset):
# obj_fname = os.path.join(sub_path, x)
if self.clip_forge_enable:
synset_id = subd
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
#if not (os.path.exists(render_img_path)): continue

311
env.yaml
View file

@ -1,311 +0,0 @@
name: lion_env
channels:
- pytorch
- nvidia
- anaconda
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- argon2-cffi=20.1.0=py38h27cfd23_1
- async_generator=1.10=pyhd3eb1b0_0
- attrs=21.4.0=pyhd3eb1b0_0
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bleach=4.1.0=pyhd3eb1b0_0
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2020.10.14=0
- certifi=2020.6.20=py38_0
- cffi=1.15.0=py38hd667e15_1
- cmake=3.18.2=ha30ef3c_0
- cudatoolkit=11.1.74=h6bb024c_0
- debugpy=1.5.1=py38h295c915_0
- decorator=5.1.1=pyhd3eb1b0_0
- defusedxml=0.7.1=pyhd3eb1b0_0
- entrypoints=0.3=py38_0
- expat=2.2.10=he6710b0_2
- ffmpeg=4.3=hf484d3e_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- importlib_metadata=4.8.2=hd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=6.4.1=py38h06a4308_1
- ipython=7.31.1=py38h06a4308_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- ipywidgets=7.6.5=pyhd3eb1b0_1
- jedi=0.18.1=py38h06a4308_0
- jpeg=9d=h7f8727e_0
- jupyter_client=7.1.2=pyhd3eb1b0_0
- jupyter_core=4.9.1=py38h06a4308_0
- jupyterlab_pygments=0.1.2=py_0
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
- krb5=1.18.2=h173b8e3_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libcurl=7.71.1=h20c2e04_1
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.2=h7f8727e_0
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libssh2=1.9.0=h1ba5d50_1
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.0=h89dd481_0
- libwebp-base=1.2.0=h27cfd23_0
- lz4-c=1.9.3=h295c915_1
- markupsafe=2.0.1=py38h27cfd23_0
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mistune=0.8.4=py38h7b6447c_1000
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- nbclient=0.5.3=pyhd3eb1b0_0
- nbconvert=6.3.0=py38h06a4308_0
- ncurses=6.3=h7f8727e_2
- nest-asyncio=1.5.1=pyhd3eb1b0_0
- nettle=3.7.3=hbbd107a_1
- notebook=6.4.6=py38h06a4308_0
- numpy=1.21.2=py38h20f2e39_0
- numpy-base=1.21.2=py38h79a1101_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1m=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- pandocfilters=1.5.0=pyhd3eb1b0_0
- parso=0.8.3=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h5aabda8_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
- prometheus_client=0.13.1=pyhd3eb1b0_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pycparser=2.21=pyhd3eb1b0_0
- pygments=2.11.2=pyhd3eb1b0_0
- python=3.8.12=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
- pytorch-mutex=1.0=cuda
- pyzmq=22.3.0=py38h295c915_2
- readline=8.1.2=h7f8727e_1
- rhash=1.4.0=h1ba5d50_0
- send2trash=1.8.0=pyhd3eb1b0_1
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.37.2=hc218d9a_0
- terminado=0.9.4=py38h06a4308_0
- testpath=0.5.0=pyhd3eb1b0_0
- tk=8.6.11=h1ccaba5_0
- torchaudio=0.10.2=py38_cu111
- torchvision=0.11.3=py38_cu111
- tornado=6.1=py38h27cfd23_0
- traitlets=5.1.1=pyhd3eb1b0_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- webencodings=0.5.1=py38_1
- wheel=0.37.1=pyhd3eb1b0_0
- widgetsnbextension=3.5.1=py38_0
- xz=5.2.5=h7b6447c_0
- zeromq=4.3.4=h2531618_0
- zipp=3.7.0=pyhd3eb1b0_0
- zlib=1.2.11=h7f8727e_4
- zstd=1.4.5=h9ceee32_0
- pip:
- about-time==3.1.1
- absl-py==1.0.0
- addict==2.4.0
- aiohttp==3.8.1
- aiosignal==1.2.0
- alive-progress==2.2.0
- antlr4-python3-runtime==4.9.3
- anyio==3.5.0
- astunparse==1.6.3
- async-timeout==4.0.2
- babel==2.9.1
- cachetools==5.0.0
- calmsize==0.1.3
- ccimport==0.3.7
- cftime==1.6.0
- charset-normalizer==2.0.11
- click==8.0.3
- colorama==0.4.4
- comet-ml==3.31.21
- commonmark==0.9.1
- configobj==5.0.6
- crc32c==2.2.post0
- cumm-cu111==0.2.8
- cupy-cuda111==10.2.0
- cycler==0.11.0
- cython==0.29.20
- dataclasses==0.6
- deepspeed==0.6.5
- deprecation==2.1.0
- diffusers==0.11.1
- docker-pycreds==0.4.0
- drjit==0.2.1
- dulwich==0.20.32
- easydict==1.9
- einops==0.4.0
- everett==3.0.0
- fastrlock==0.8
- filelock==3.9.0
- fire==0.4.0
- flatbuffers==2.0
- flatten-dict==0.4.2
- fonttools==4.29.1
- freetype-py==2.3.0
- frozenlist==1.3.0
- fsspec==2022.2.0
- ftfy==6.1.1
- future==0.18.2
- fvcore==0.1.5.post20220512
- gast==0.5.3
- gitdb==4.0.9
- gitpython==3.1.26
- google-auth==2.6.0
- google-auth-oauthlib==0.4.6
- google-pasta==0.2.0
- grapheme==0.6.0
- grpcio==1.43.0
- h5py==3.6.0
- hjson==3.0.2
- huggingface-hub==0.11.1
- idna==3.3
- imageio==2.15.0
- imageio-ffmpeg==0.4.5
- importlib-metadata==4.10.1
- importlib-resources==5.4.0
- iopath==0.1.10
- jinja2==3.1.1
- joblib==1.1.0
- json5==0.9.6
- jsonschema==4.4.0
- jupyter-packaging==0.12.0
- jupyter-server==1.15.6
- jupyterlab==3.3.2
- jupyterlab-server==2.11.2
- keras==2.8.0
- keras-preprocessing==1.1.2
- kiwisolver==1.3.2
- kornia==0.6.6
- lark==1.1.2
- libclang==14.0.1
- llvmlite==0.39.0
- loguru==0.6.0
- markdown==3.3.6
- matplotlib==3.5.1
- matplotlib2tikz==0.7.6
- meshio==5.3.4
- mitsuba==3.0.1
- mrcfile==1.3.0
- multidict==6.0.2
- multipledispatch==0.6.0
- mypy-extensions==0.4.3
- nbclassic==0.3.7
- nbformat==5.2.0
- nestargs==0.5.0
- netcdf4==1.5.8
- networkx==2.6.3
- ninja==1.10.2.3
- notebook-shim==0.1.0
- numba==0.56.0
- nvidia-ml-py3==7.352.0
- oauthlib==3.2.0
- omegaconf==2.2.2
- open3d==0.15.2
- opencv-python==4.5.5.64
- openexr==1.3.7
- opt-einsum==3.3.0
- pandas==1.4.0
- pathtools==0.1.2
- pccm==0.3.4
- pip==22.3.1
- plyfile==0.7.4
- portalocker==2.5.1
- progressbar2==4.0.0
- promise==2.3
- protobuf==3.19.4
- psutil==5.9.0
- py-cpuinfo==8.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pybind11==2.10.0
- pydeprecate==0.3.1
- pyglet==1.5.23
- pykeops==1.5
- pymcubes==0.1.2
- pyopengl==3.1.0
- pyparsing==3.0.7
- pyquaternion==0.9.9
- pyrr==0.10.3
- pyrsistent==0.18.1
- python-swiftclient==4.0.0
- python-utils==3.3.3
- pytorch-lightning==1.5.1
- pytorch3d==0.3.0
- pytz==2021.3
- pywavelets==1.2.0
- pyyaml==6.0
- regex==2022.3.15
- requests==2.27.1
- requests-oauthlib==1.3.1
- requests-toolbelt==0.9.1
- rich==12.3.0
- rsa==4.8
- ruamel-yaml==0.17.20
- ruamel-yaml-clib==0.2.6
- scikit-image==0.19.1
- scikit-learn==1.0.2
- scipy==1.8.0
- seaborn==0.11.2
- semantic-version==2.9.0
- sentry-sdk==1.5.4
- sharedarray==3.2.1
- shortuuid==1.0.8
- simple-parsing==0.0.18
- simplejson==3.18.0
- sklearn==0.0
- smmap==5.0.0
- sniffio==1.2.0
- tabulate==0.8.9
- tensorboard==2.8.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- tensorboardx==2.4.1
- tensorflow-gpu==2.8.0
- tensorflow-io-gcs-filesystem==0.25.0
- termcolor==1.1.0
- tf-estimator-nightly==2.8.0.dev2021122109
- tflearn==0.5.0
- tfrecord==1.14.1
- threadpoolctl==3.1.0
- tifffile==2022.2.2
- tikzplotlib==0.10.1
- tomlkit==0.10.0
- torchmetrics==0.7.2
- tqdm==4.62.3
- trimesh==3.10.1
- typing-extensions==4.2.0
- typing-inspect==0.7.1
- urllib3==1.26.8
- wandb==0.12.10
- webcolors==1.11.1
- websocket-client==1.2.3
- werkzeug==2.0.3
- wrapt==1.13.3
- wurlitzer==3.0.2
- yacs==0.1.8
- yarl==1.7.2
- yaspin==2.1.0

44
environment.yml Normal file
View file

@ -0,0 +1,44 @@
name: LION
channels:
- pytorch
- nvidia
- conda-forge
- defaults
- open3d-admin
- comet_ml
dependencies:
- pip
- python
- pytorch
- torchvision
- cudatoolkit
- matplotlib
- tqdm
- trimesh
- scipy
- scikit-learn
- ftfy
- tqdm
- regex
- tabulate
- h5py
- wandb
- pyyaml
- open3d
- cupy
- scikit-image
- loguru
- einops
- comet_ml
- diffusers
- pip:
- pykeops
- nestargs
- flatten_dict
- point-cloud-utils
- calmsize
# - git+https://github.com/openai/CLIP.git

View file

@ -160,7 +160,7 @@ class Prior(nn.Module):
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
# self.is_active = None
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
# init = torch.load('../exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
# init = torch.load('./exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
# self.is_active = None
# else:

2
script/train_vae.sh Normal file → Executable file
View file

@ -6,7 +6,7 @@ fi
DATA=" ddpm.input_dim 3 data.cates car "
NGPU=$1 #
num_node=1
BS=32
BS=6
total_bs=$(( $NGPU * $BS ))
if (( $total_bs > 128 )); then
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"

View file

@ -11,7 +11,6 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
#include <THC/THC.h>
#define CHECK_INPUT(x)
@ -173,9 +172,9 @@ at::Tensor ApproxMatchForward(
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
@ -185,7 +184,7 @@ at::Tensor ApproxMatchForward(
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return match;
}
@ -260,9 +259,9 @@ at::Tensor MatchCostForward(
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
@ -271,7 +270,7 @@ at::Tensor MatchCostForward(
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return cost;
}
@ -377,9 +376,9 @@ std::vector<at::Tensor> MatchCostBackward(
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
@ -390,7 +389,7 @@ std::vector<at::Tensor> MatchCostBackward(
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}

View file

@ -105,7 +105,7 @@ def main(args, config):
def get_args():
parser = argparse.ArgumentParser('encoder decoder examiner')
# experimental results
parser.add_argument('--exp_root', type=str, default='../exp',
parser.add_argument('--exp_root', type=str, default='./exp',
help='location of the results')
# parser.add_argument('--save', type=str, default='exp',
# help='id used for storing intermediate results')
@ -176,7 +176,7 @@ def get_args():
config.merge_from_list(args.opt)
# Create log_name
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', '../exp/')
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', './exp/')
if config.exp_name == '' or config.exp_name == 'none':
config.hash = io_helper.hash_str('%s' % config) + 'h'
cfg_file_name = exp_helper.get_expname(config)

View file

@ -8,7 +8,6 @@
import os
import json
from comet_ml import Experiment, OfflineExperiment
## import open3d as o3d
import time
import numpy as np
import torch
@ -324,7 +323,7 @@ def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda',
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
results['jsd'] = jsd
msg = print_results(results, **print_kwargs)
# with open('../exp/eval_out.txt', 'a') as f:
# with open('./exp/eval_out.txt', 'a') as f:
# run_time = time.strftime('%m%d-%H%M-%S')
# f.write('<< date: %s >>\n' % run_time)
# f.write('%s\n%s\n' % (exp.url, msg))

View file

@ -491,7 +491,7 @@ def common_init(rank, seed, save_dir, comet_key=''):
if os.path.exists('.wandb_api'):
wb_args = json.load(open('.wandb_api', 'r'))
wb_dir = '../exp/wandb/' if not os.path.exists(
wb_dir = './exp/wandb/' if not os.path.exists(
'/workspace/result') else '/workspace/result/wandb/'
if not os.path.exists(wb_dir):
os.makedirs(wb_dir)
@ -1130,24 +1130,26 @@ def init_processes(rank, size, fn, args, config):
""" Initialize the distributed environment. """
os.environ['MASTER_ADDR'] = args.master_address
os.environ['MASTER_PORT'] = '6020'
if args.num_proc_node == 1:
import socket
import errno
a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for p in range(6010, 6030):
location = (args.master_address, p) # "127.0.0.1", p)
try:
a_socket.bind((args.master_address, p))
logger.debug('set port as {}', p)
os.environ['MASTER_PORT'] = '%d' % p
a_socket.close()
break
except socket.error as e:
a = 0
# if e.errno == errno.EADDRINUSE:
# # logger.debug("Port {} is already in use", p)
# else:
# logger.debug(e)
logger.info('set MASTER_PORT: {}, MASTER_PORT: {}', os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
# if args.num_proc_node == 1: # try to solve the port occupied issue
# import socket
# import errno
# a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# for p in range(6010, 6030):
# location = (args.master_address, p) # "127.0.0.1", p)
# try:
# a_socket.bind((args.master_address, p))
# logger.debug('set port as {}', p)
# os.environ['MASTER_PORT'] = '%d' % p
# a_socket.close()
# break
# except socket.error as e:
# a = 0
# # if e.errno == errno.EADDRINUSE:
# # # logger.debug("Port {} is already in use", p)
# # else:
# # logger.debug(e)
logger.info('init_process: rank={}, world_size={}', rank, size)
torch.cuda.set_device(args.local_rank)