Compare commits
10 commits
b84169e724
...
403e04de8a
Author | SHA1 | Date | |
---|---|---|---|
403e04de8a | |||
8f37d2838b | |||
f79b9c697b | |||
7225eabb15 | |||
cb3c32d6b1 | |||
55902e13bc | |||
c67e7faf35 | |||
c0d5c6754b | |||
231ff196c7 | |||
0467d21990 |
12
README.md
12
README.md
|
@ -27,13 +27,17 @@
|
||||||
* Setup the environment
|
* Setup the environment
|
||||||
Install from conda file
|
Install from conda file
|
||||||
```
|
```
|
||||||
conda env create --name lion_env --file=env.yaml
|
mamba env create -f environment.yml
|
||||||
conda activate lion_env
|
# mamba env update -f environment.yml
|
||||||
|
conda activate LION
|
||||||
|
|
||||||
# Install some other packages
|
# Install some other packages (use proxy)
|
||||||
pip install git+https://github.com/openai/CLIP.git
|
pip install git+https://github.com/openai/CLIP.git
|
||||||
|
|
||||||
# build some packages first (optional)
|
# build some packages first (optional)
|
||||||
|
export CUDA_HOME=/usr/local/cuda # just in case rosetta cucks you
|
||||||
|
module load compilers
|
||||||
|
module load mpfr
|
||||||
python build_pkg.py
|
python build_pkg.py
|
||||||
```
|
```
|
||||||
Tested with conda version 22.9.0
|
Tested with conda version 22.9.0
|
||||||
|
@ -44,7 +48,7 @@
|
||||||
|
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. (Note: the checkpoint is not released yet, the files loaded in the `demo.py` file is not available at this point)
|
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. Download checkpoints from [HuggingFace Hub](https://huggingface.co/xiaohui2022/lion_ckpt)
|
||||||
|
|
||||||
## Released checkpoint and samples
|
## Released checkpoint and samples
|
||||||
* will be release soon
|
* will be release soon
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ../exp/tmp/2022_0407_0300_45.sh
|
bash_name: ./exp/tmp/2022_0407_0300_45.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
log_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
log_name: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
log_name: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
save_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ../exp/tmp/2022_0407_1347_21.sh
|
bash_name: ./exp/tmp/2022_0407_1347_21.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
log_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
log_name: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
log_name: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
save_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
@ -195,7 +195,7 @@ sde:
|
||||||
update_q_ema: false
|
update_q_ema: false
|
||||||
use_adam: true
|
use_adam: true
|
||||||
use_adamax: false
|
use_adamax: false
|
||||||
vae_checkpoint: ../exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
vae_checkpoint: ./exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
||||||
warmup_epochs: 20
|
warmup_epochs: 20
|
||||||
weight_decay: 0.0003
|
weight_decay: 0.0003
|
||||||
weight_decay_norm_dae: 0.0
|
weight_decay_norm_dae: 0.0
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ../exp/tmp/2022_0416_1418_42.sh
|
bash_name: ./exp/tmp/2022_0416_1418_42.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
log_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
log_name: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
log_name: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
save_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
|
|
||||||
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
||||||
import os
|
import os
|
||||||
import open3d as o3d
|
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -174,7 +173,7 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
# obj_fname = os.path.join(sub_path, x)
|
# obj_fname = os.path.join(sub_path, x)
|
||||||
if self.clip_forge_enable:
|
if self.clip_forge_enable:
|
||||||
synset_id = subd
|
synset_id = subd
|
||||||
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
|
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
|
||||||
|
|
||||||
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
||||||
#if not (os.path.exists(render_img_path)): continue
|
#if not (os.path.exists(render_img_path)): continue
|
||||||
|
|
311
env.yaml
311
env.yaml
|
@ -1,311 +0,0 @@
|
||||||
name: lion_env
|
|
||||||
channels:
|
|
||||||
- pytorch
|
|
||||||
- nvidia
|
|
||||||
- anaconda
|
|
||||||
- conda-forge
|
|
||||||
- defaults
|
|
||||||
dependencies:
|
|
||||||
- _libgcc_mutex=0.1=main
|
|
||||||
- _openmp_mutex=4.5=1_gnu
|
|
||||||
- argon2-cffi=20.1.0=py38h27cfd23_1
|
|
||||||
- async_generator=1.10=pyhd3eb1b0_0
|
|
||||||
- attrs=21.4.0=pyhd3eb1b0_0
|
|
||||||
- backcall=0.2.0=pyhd3eb1b0_0
|
|
||||||
- blas=1.0=mkl
|
|
||||||
- bleach=4.1.0=pyhd3eb1b0_0
|
|
||||||
- bzip2=1.0.8=h7b6447c_0
|
|
||||||
- ca-certificates=2020.10.14=0
|
|
||||||
- certifi=2020.6.20=py38_0
|
|
||||||
- cffi=1.15.0=py38hd667e15_1
|
|
||||||
- cmake=3.18.2=ha30ef3c_0
|
|
||||||
- cudatoolkit=11.1.74=h6bb024c_0
|
|
||||||
- debugpy=1.5.1=py38h295c915_0
|
|
||||||
- decorator=5.1.1=pyhd3eb1b0_0
|
|
||||||
- defusedxml=0.7.1=pyhd3eb1b0_0
|
|
||||||
- entrypoints=0.3=py38_0
|
|
||||||
- expat=2.2.10=he6710b0_2
|
|
||||||
- ffmpeg=4.3=hf484d3e_0
|
|
||||||
- freetype=2.11.0=h70c0345_0
|
|
||||||
- giflib=5.2.1=h7b6447c_0
|
|
||||||
- gmp=6.2.1=h2531618_2
|
|
||||||
- gnutls=3.6.15=he1e5248_0
|
|
||||||
- importlib_metadata=4.8.2=hd3eb1b0_0
|
|
||||||
- intel-openmp=2021.4.0=h06a4308_3561
|
|
||||||
- ipykernel=6.4.1=py38h06a4308_1
|
|
||||||
- ipython=7.31.1=py38h06a4308_0
|
|
||||||
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
|
||||||
- ipywidgets=7.6.5=pyhd3eb1b0_1
|
|
||||||
- jedi=0.18.1=py38h06a4308_0
|
|
||||||
- jpeg=9d=h7f8727e_0
|
|
||||||
- jupyter_client=7.1.2=pyhd3eb1b0_0
|
|
||||||
- jupyter_core=4.9.1=py38h06a4308_0
|
|
||||||
- jupyterlab_pygments=0.1.2=py_0
|
|
||||||
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
|
|
||||||
- krb5=1.18.2=h173b8e3_0
|
|
||||||
- lame=3.100=h7b6447c_0
|
|
||||||
- lcms2=2.12=h3be6417_0
|
|
||||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
|
||||||
- libcurl=7.71.1=h20c2e04_1
|
|
||||||
- libedit=3.1.20191231=h14c3975_1
|
|
||||||
- libffi=3.3=he6710b0_2
|
|
||||||
- libgcc-ng=9.3.0=h5101ec6_17
|
|
||||||
- libgomp=9.3.0=h5101ec6_17
|
|
||||||
- libiconv=1.15=h63c8f33_5
|
|
||||||
- libidn2=2.3.2=h7f8727e_0
|
|
||||||
- libpng=1.6.37=hbc83047_0
|
|
||||||
- libsodium=1.0.18=h7b6447c_0
|
|
||||||
- libssh2=1.9.0=h1ba5d50_1
|
|
||||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
|
||||||
- libtasn1=4.16.0=h27cfd23_0
|
|
||||||
- libtiff=4.2.0=h85742a9_0
|
|
||||||
- libunistring=0.9.10=h27cfd23_0
|
|
||||||
- libuv=1.40.0=h7b6447c_0
|
|
||||||
- libwebp=1.2.0=h89dd481_0
|
|
||||||
- libwebp-base=1.2.0=h27cfd23_0
|
|
||||||
- lz4-c=1.9.3=h295c915_1
|
|
||||||
- markupsafe=2.0.1=py38h27cfd23_0
|
|
||||||
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
|
||||||
- mistune=0.8.4=py38h7b6447c_1000
|
|
||||||
- mkl=2021.4.0=h06a4308_640
|
|
||||||
- mkl-service=2.4.0=py38h7f8727e_0
|
|
||||||
- mkl_fft=1.3.1=py38hd3c417c_0
|
|
||||||
- mkl_random=1.2.2=py38h51133e4_0
|
|
||||||
- nbclient=0.5.3=pyhd3eb1b0_0
|
|
||||||
- nbconvert=6.3.0=py38h06a4308_0
|
|
||||||
- ncurses=6.3=h7f8727e_2
|
|
||||||
- nest-asyncio=1.5.1=pyhd3eb1b0_0
|
|
||||||
- nettle=3.7.3=hbbd107a_1
|
|
||||||
- notebook=6.4.6=py38h06a4308_0
|
|
||||||
- numpy=1.21.2=py38h20f2e39_0
|
|
||||||
- numpy-base=1.21.2=py38h79a1101_0
|
|
||||||
- olefile=0.46=pyhd3eb1b0_0
|
|
||||||
- openh264=2.1.1=h4ff587b_0
|
|
||||||
- openssl=1.1.1m=h7f8727e_0
|
|
||||||
- packaging=21.3=pyhd3eb1b0_0
|
|
||||||
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
|
||||||
- parso=0.8.3=pyhd3eb1b0_0
|
|
||||||
- pexpect=4.8.0=pyhd3eb1b0_3
|
|
||||||
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
|
||||||
- pillow=8.4.0=py38h5aabda8_0
|
|
||||||
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
|
|
||||||
- prometheus_client=0.13.1=pyhd3eb1b0_0
|
|
||||||
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
|
||||||
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
|
||||||
- pycparser=2.21=pyhd3eb1b0_0
|
|
||||||
- pygments=2.11.2=pyhd3eb1b0_0
|
|
||||||
- python=3.8.12=h12debd9_0
|
|
||||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
|
||||||
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
|
|
||||||
- python_abi=3.8=2_cp38
|
|
||||||
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
|
|
||||||
- pytorch-mutex=1.0=cuda
|
|
||||||
- pyzmq=22.3.0=py38h295c915_2
|
|
||||||
- readline=8.1.2=h7f8727e_1
|
|
||||||
- rhash=1.4.0=h1ba5d50_0
|
|
||||||
- send2trash=1.8.0=pyhd3eb1b0_1
|
|
||||||
- six=1.16.0=pyhd3eb1b0_0
|
|
||||||
- sqlite=3.37.2=hc218d9a_0
|
|
||||||
- terminado=0.9.4=py38h06a4308_0
|
|
||||||
- testpath=0.5.0=pyhd3eb1b0_0
|
|
||||||
- tk=8.6.11=h1ccaba5_0
|
|
||||||
- torchaudio=0.10.2=py38_cu111
|
|
||||||
- torchvision=0.11.3=py38_cu111
|
|
||||||
- tornado=6.1=py38h27cfd23_0
|
|
||||||
- traitlets=5.1.1=pyhd3eb1b0_0
|
|
||||||
- wcwidth=0.2.5=pyhd3eb1b0_0
|
|
||||||
- webencodings=0.5.1=py38_1
|
|
||||||
- wheel=0.37.1=pyhd3eb1b0_0
|
|
||||||
- widgetsnbextension=3.5.1=py38_0
|
|
||||||
- xz=5.2.5=h7b6447c_0
|
|
||||||
- zeromq=4.3.4=h2531618_0
|
|
||||||
- zipp=3.7.0=pyhd3eb1b0_0
|
|
||||||
- zlib=1.2.11=h7f8727e_4
|
|
||||||
- zstd=1.4.5=h9ceee32_0
|
|
||||||
- pip:
|
|
||||||
- about-time==3.1.1
|
|
||||||
- absl-py==1.0.0
|
|
||||||
- addict==2.4.0
|
|
||||||
- aiohttp==3.8.1
|
|
||||||
- aiosignal==1.2.0
|
|
||||||
- alive-progress==2.2.0
|
|
||||||
- antlr4-python3-runtime==4.9.3
|
|
||||||
- anyio==3.5.0
|
|
||||||
- astunparse==1.6.3
|
|
||||||
- async-timeout==4.0.2
|
|
||||||
- babel==2.9.1
|
|
||||||
- cachetools==5.0.0
|
|
||||||
- calmsize==0.1.3
|
|
||||||
- ccimport==0.3.7
|
|
||||||
- cftime==1.6.0
|
|
||||||
- charset-normalizer==2.0.11
|
|
||||||
- click==8.0.3
|
|
||||||
- colorama==0.4.4
|
|
||||||
- comet-ml==3.31.21
|
|
||||||
- commonmark==0.9.1
|
|
||||||
- configobj==5.0.6
|
|
||||||
- crc32c==2.2.post0
|
|
||||||
- cumm-cu111==0.2.8
|
|
||||||
- cupy-cuda111==10.2.0
|
|
||||||
- cycler==0.11.0
|
|
||||||
- cython==0.29.20
|
|
||||||
- dataclasses==0.6
|
|
||||||
- deepspeed==0.6.5
|
|
||||||
- deprecation==2.1.0
|
|
||||||
- diffusers==0.11.1
|
|
||||||
- docker-pycreds==0.4.0
|
|
||||||
- drjit==0.2.1
|
|
||||||
- dulwich==0.20.32
|
|
||||||
- easydict==1.9
|
|
||||||
- einops==0.4.0
|
|
||||||
- everett==3.0.0
|
|
||||||
- fastrlock==0.8
|
|
||||||
- filelock==3.9.0
|
|
||||||
- fire==0.4.0
|
|
||||||
- flatbuffers==2.0
|
|
||||||
- flatten-dict==0.4.2
|
|
||||||
- fonttools==4.29.1
|
|
||||||
- freetype-py==2.3.0
|
|
||||||
- frozenlist==1.3.0
|
|
||||||
- fsspec==2022.2.0
|
|
||||||
- ftfy==6.1.1
|
|
||||||
- future==0.18.2
|
|
||||||
- fvcore==0.1.5.post20220512
|
|
||||||
- gast==0.5.3
|
|
||||||
- gitdb==4.0.9
|
|
||||||
- gitpython==3.1.26
|
|
||||||
- google-auth==2.6.0
|
|
||||||
- google-auth-oauthlib==0.4.6
|
|
||||||
- google-pasta==0.2.0
|
|
||||||
- grapheme==0.6.0
|
|
||||||
- grpcio==1.43.0
|
|
||||||
- h5py==3.6.0
|
|
||||||
- hjson==3.0.2
|
|
||||||
- huggingface-hub==0.11.1
|
|
||||||
- idna==3.3
|
|
||||||
- imageio==2.15.0
|
|
||||||
- imageio-ffmpeg==0.4.5
|
|
||||||
- importlib-metadata==4.10.1
|
|
||||||
- importlib-resources==5.4.0
|
|
||||||
- iopath==0.1.10
|
|
||||||
- jinja2==3.1.1
|
|
||||||
- joblib==1.1.0
|
|
||||||
- json5==0.9.6
|
|
||||||
- jsonschema==4.4.0
|
|
||||||
- jupyter-packaging==0.12.0
|
|
||||||
- jupyter-server==1.15.6
|
|
||||||
- jupyterlab==3.3.2
|
|
||||||
- jupyterlab-server==2.11.2
|
|
||||||
- keras==2.8.0
|
|
||||||
- keras-preprocessing==1.1.2
|
|
||||||
- kiwisolver==1.3.2
|
|
||||||
- kornia==0.6.6
|
|
||||||
- lark==1.1.2
|
|
||||||
- libclang==14.0.1
|
|
||||||
- llvmlite==0.39.0
|
|
||||||
- loguru==0.6.0
|
|
||||||
- markdown==3.3.6
|
|
||||||
- matplotlib==3.5.1
|
|
||||||
- matplotlib2tikz==0.7.6
|
|
||||||
- meshio==5.3.4
|
|
||||||
- mitsuba==3.0.1
|
|
||||||
- mrcfile==1.3.0
|
|
||||||
- multidict==6.0.2
|
|
||||||
- multipledispatch==0.6.0
|
|
||||||
- mypy-extensions==0.4.3
|
|
||||||
- nbclassic==0.3.7
|
|
||||||
- nbformat==5.2.0
|
|
||||||
- nestargs==0.5.0
|
|
||||||
- netcdf4==1.5.8
|
|
||||||
- networkx==2.6.3
|
|
||||||
- ninja==1.10.2.3
|
|
||||||
- notebook-shim==0.1.0
|
|
||||||
- numba==0.56.0
|
|
||||||
- nvidia-ml-py3==7.352.0
|
|
||||||
- oauthlib==3.2.0
|
|
||||||
- omegaconf==2.2.2
|
|
||||||
- open3d==0.15.2
|
|
||||||
- opencv-python==4.5.5.64
|
|
||||||
- openexr==1.3.7
|
|
||||||
- opt-einsum==3.3.0
|
|
||||||
- pandas==1.4.0
|
|
||||||
- pathtools==0.1.2
|
|
||||||
- pccm==0.3.4
|
|
||||||
- pip==22.3.1
|
|
||||||
- plyfile==0.7.4
|
|
||||||
- portalocker==2.5.1
|
|
||||||
- progressbar2==4.0.0
|
|
||||||
- promise==2.3
|
|
||||||
- protobuf==3.19.4
|
|
||||||
- psutil==5.9.0
|
|
||||||
- py-cpuinfo==8.0.0
|
|
||||||
- pyasn1==0.4.8
|
|
||||||
- pyasn1-modules==0.2.8
|
|
||||||
- pybind11==2.10.0
|
|
||||||
- pydeprecate==0.3.1
|
|
||||||
- pyglet==1.5.23
|
|
||||||
- pykeops==1.5
|
|
||||||
- pymcubes==0.1.2
|
|
||||||
- pyopengl==3.1.0
|
|
||||||
- pyparsing==3.0.7
|
|
||||||
- pyquaternion==0.9.9
|
|
||||||
- pyrr==0.10.3
|
|
||||||
- pyrsistent==0.18.1
|
|
||||||
- python-swiftclient==4.0.0
|
|
||||||
- python-utils==3.3.3
|
|
||||||
- pytorch-lightning==1.5.1
|
|
||||||
- pytorch3d==0.3.0
|
|
||||||
- pytz==2021.3
|
|
||||||
- pywavelets==1.2.0
|
|
||||||
- pyyaml==6.0
|
|
||||||
- regex==2022.3.15
|
|
||||||
- requests==2.27.1
|
|
||||||
- requests-oauthlib==1.3.1
|
|
||||||
- requests-toolbelt==0.9.1
|
|
||||||
- rich==12.3.0
|
|
||||||
- rsa==4.8
|
|
||||||
- ruamel-yaml==0.17.20
|
|
||||||
- ruamel-yaml-clib==0.2.6
|
|
||||||
- scikit-image==0.19.1
|
|
||||||
- scikit-learn==1.0.2
|
|
||||||
- scipy==1.8.0
|
|
||||||
- seaborn==0.11.2
|
|
||||||
- semantic-version==2.9.0
|
|
||||||
- sentry-sdk==1.5.4
|
|
||||||
- sharedarray==3.2.1
|
|
||||||
- shortuuid==1.0.8
|
|
||||||
- simple-parsing==0.0.18
|
|
||||||
- simplejson==3.18.0
|
|
||||||
- sklearn==0.0
|
|
||||||
- smmap==5.0.0
|
|
||||||
- sniffio==1.2.0
|
|
||||||
- tabulate==0.8.9
|
|
||||||
- tensorboard==2.8.0
|
|
||||||
- tensorboard-data-server==0.6.1
|
|
||||||
- tensorboard-plugin-wit==1.8.1
|
|
||||||
- tensorboardx==2.4.1
|
|
||||||
- tensorflow-gpu==2.8.0
|
|
||||||
- tensorflow-io-gcs-filesystem==0.25.0
|
|
||||||
- termcolor==1.1.0
|
|
||||||
- tf-estimator-nightly==2.8.0.dev2021122109
|
|
||||||
- tflearn==0.5.0
|
|
||||||
- tfrecord==1.14.1
|
|
||||||
- threadpoolctl==3.1.0
|
|
||||||
- tifffile==2022.2.2
|
|
||||||
- tikzplotlib==0.10.1
|
|
||||||
- tomlkit==0.10.0
|
|
||||||
- torchmetrics==0.7.2
|
|
||||||
- tqdm==4.62.3
|
|
||||||
- trimesh==3.10.1
|
|
||||||
- typing-extensions==4.2.0
|
|
||||||
- typing-inspect==0.7.1
|
|
||||||
- urllib3==1.26.8
|
|
||||||
- wandb==0.12.10
|
|
||||||
- webcolors==1.11.1
|
|
||||||
- websocket-client==1.2.3
|
|
||||||
- werkzeug==2.0.3
|
|
||||||
- wrapt==1.13.3
|
|
||||||
- wurlitzer==3.0.2
|
|
||||||
- yacs==0.1.8
|
|
||||||
- yarl==1.7.2
|
|
||||||
- yaspin==2.1.0
|
|
44
environment.yml
Normal file
44
environment.yml
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
name: LION
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- nvidia
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
- open3d-admin
|
||||||
|
- comet_ml
|
||||||
|
dependencies:
|
||||||
|
- pip
|
||||||
|
|
||||||
|
- python
|
||||||
|
- pytorch
|
||||||
|
- torchvision
|
||||||
|
- cudatoolkit
|
||||||
|
- matplotlib
|
||||||
|
- tqdm
|
||||||
|
- trimesh
|
||||||
|
- scipy
|
||||||
|
- scikit-learn
|
||||||
|
|
||||||
|
- ftfy
|
||||||
|
- tqdm
|
||||||
|
- regex
|
||||||
|
|
||||||
|
- tabulate
|
||||||
|
- h5py
|
||||||
|
- wandb
|
||||||
|
- pyyaml
|
||||||
|
- open3d
|
||||||
|
- cupy
|
||||||
|
- scikit-image
|
||||||
|
- loguru
|
||||||
|
- einops
|
||||||
|
- comet_ml
|
||||||
|
- diffusers
|
||||||
|
|
||||||
|
- pip:
|
||||||
|
- pykeops
|
||||||
|
- nestargs
|
||||||
|
- flatten_dict
|
||||||
|
- point-cloud-utils
|
||||||
|
- calmsize
|
||||||
|
# - git+https://github.com/openai/CLIP.git
|
|
@ -160,7 +160,7 @@ class Prior(nn.Module):
|
||||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
||||||
# self.is_active = None
|
# self.is_active = None
|
||||||
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
||||||
# init = torch.load('../exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
# init = torch.load('./exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
||||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
||||||
# self.is_active = None
|
# self.is_active = None
|
||||||
# else:
|
# else:
|
||||||
|
|
2
script/train_vae.sh
Normal file → Executable file
2
script/train_vae.sh
Normal file → Executable file
|
@ -6,7 +6,7 @@ fi
|
||||||
DATA=" ddpm.input_dim 3 data.cates car "
|
DATA=" ddpm.input_dim 3 data.cates car "
|
||||||
NGPU=$1 #
|
NGPU=$1 #
|
||||||
num_node=1
|
num_node=1
|
||||||
BS=32
|
BS=6
|
||||||
total_bs=$(( $NGPU * $BS ))
|
total_bs=$(( $NGPU * $BS ))
|
||||||
if (( $total_bs > 128 )); then
|
if (( $total_bs > 128 )); then
|
||||||
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
||||||
|
|
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
|
@ -11,7 +11,6 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||||
#include <THC/THC.h>
|
|
||||||
|
|
||||||
#define CHECK_INPUT(x)
|
#define CHECK_INPUT(x)
|
||||||
|
|
||||||
|
@ -173,9 +172,9 @@ at::Tensor ApproxMatchForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -185,7 +184,7 @@ at::Tensor ApproxMatchForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return match;
|
return match;
|
||||||
}
|
}
|
||||||
|
@ -260,9 +259,9 @@ at::Tensor MatchCostForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -271,7 +270,7 @@ at::Tensor MatchCostForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
@ -377,9 +376,9 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -390,7 +389,7 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return std::vector<at::Tensor>({grad1, grad2});
|
return std::vector<at::Tensor>({grad1, grad2});
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,7 +105,7 @@ def main(args, config):
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser('encoder decoder examiner')
|
parser = argparse.ArgumentParser('encoder decoder examiner')
|
||||||
# experimental results
|
# experimental results
|
||||||
parser.add_argument('--exp_root', type=str, default='../exp',
|
parser.add_argument('--exp_root', type=str, default='./exp',
|
||||||
help='location of the results')
|
help='location of the results')
|
||||||
# parser.add_argument('--save', type=str, default='exp',
|
# parser.add_argument('--save', type=str, default='exp',
|
||||||
# help='id used for storing intermediate results')
|
# help='id used for storing intermediate results')
|
||||||
|
@ -176,7 +176,7 @@ def get_args():
|
||||||
config.merge_from_list(args.opt)
|
config.merge_from_list(args.opt)
|
||||||
|
|
||||||
# Create log_name
|
# Create log_name
|
||||||
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', '../exp/')
|
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', './exp/')
|
||||||
if config.exp_name == '' or config.exp_name == 'none':
|
if config.exp_name == '' or config.exp_name == 'none':
|
||||||
config.hash = io_helper.hash_str('%s' % config) + 'h'
|
config.hash = io_helper.hash_str('%s' % config) + 'h'
|
||||||
cfg_file_name = exp_helper.get_expname(config)
|
cfg_file_name = exp_helper.get_expname(config)
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from comet_ml import Experiment, OfflineExperiment
|
from comet_ml import Experiment, OfflineExperiment
|
||||||
## import open3d as o3d
|
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -324,7 +323,7 @@ def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda',
|
||||||
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
|
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
|
||||||
results['jsd'] = jsd
|
results['jsd'] = jsd
|
||||||
msg = print_results(results, **print_kwargs)
|
msg = print_results(results, **print_kwargs)
|
||||||
# with open('../exp/eval_out.txt', 'a') as f:
|
# with open('./exp/eval_out.txt', 'a') as f:
|
||||||
# run_time = time.strftime('%m%d-%H%M-%S')
|
# run_time = time.strftime('%m%d-%H%M-%S')
|
||||||
# f.write('<< date: %s >>\n' % run_time)
|
# f.write('<< date: %s >>\n' % run_time)
|
||||||
# f.write('%s\n%s\n' % (exp.url, msg))
|
# f.write('%s\n%s\n' % (exp.url, msg))
|
||||||
|
|
|
@ -491,7 +491,7 @@ def common_init(rank, seed, save_dir, comet_key=''):
|
||||||
|
|
||||||
if os.path.exists('.wandb_api'):
|
if os.path.exists('.wandb_api'):
|
||||||
wb_args = json.load(open('.wandb_api', 'r'))
|
wb_args = json.load(open('.wandb_api', 'r'))
|
||||||
wb_dir = '../exp/wandb/' if not os.path.exists(
|
wb_dir = './exp/wandb/' if not os.path.exists(
|
||||||
'/workspace/result') else '/workspace/result/wandb/'
|
'/workspace/result') else '/workspace/result/wandb/'
|
||||||
if not os.path.exists(wb_dir):
|
if not os.path.exists(wb_dir):
|
||||||
os.makedirs(wb_dir)
|
os.makedirs(wb_dir)
|
||||||
|
@ -1130,24 +1130,26 @@ def init_processes(rank, size, fn, args, config):
|
||||||
""" Initialize the distributed environment. """
|
""" Initialize the distributed environment. """
|
||||||
os.environ['MASTER_ADDR'] = args.master_address
|
os.environ['MASTER_ADDR'] = args.master_address
|
||||||
os.environ['MASTER_PORT'] = '6020'
|
os.environ['MASTER_PORT'] = '6020'
|
||||||
if args.num_proc_node == 1:
|
logger.info('set MASTER_PORT: {}, MASTER_PORT: {}', os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
||||||
import socket
|
|
||||||
import errno
|
# if args.num_proc_node == 1: # try to solve the port occupied issue
|
||||||
a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
# import socket
|
||||||
for p in range(6010, 6030):
|
# import errno
|
||||||
location = (args.master_address, p) # "127.0.0.1", p)
|
# a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
try:
|
# for p in range(6010, 6030):
|
||||||
a_socket.bind((args.master_address, p))
|
# location = (args.master_address, p) # "127.0.0.1", p)
|
||||||
logger.debug('set port as {}', p)
|
# try:
|
||||||
os.environ['MASTER_PORT'] = '%d' % p
|
# a_socket.bind((args.master_address, p))
|
||||||
a_socket.close()
|
# logger.debug('set port as {}', p)
|
||||||
break
|
# os.environ['MASTER_PORT'] = '%d' % p
|
||||||
except socket.error as e:
|
# a_socket.close()
|
||||||
a = 0
|
# break
|
||||||
# if e.errno == errno.EADDRINUSE:
|
# except socket.error as e:
|
||||||
# # logger.debug("Port {} is already in use", p)
|
# a = 0
|
||||||
# else:
|
# # if e.errno == errno.EADDRINUSE:
|
||||||
# logger.debug(e)
|
# # # logger.debug("Port {} is already in use", p)
|
||||||
|
# # else:
|
||||||
|
# # logger.debug(e)
|
||||||
|
|
||||||
logger.info('init_process: rank={}, world_size={}', rank, size)
|
logger.info('init_process: rank={}, world_size={}', rank, size)
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
|
|
Loading…
Reference in a new issue