Compare commits
No commits in common. "403e04de8a9b2395454fa5c67bafee7fa1115c50" and "b84169e724478f136a9ee0eaa541218086bd5f2e" have entirely different histories.
403e04de8a
...
b84169e724
12
README.md
12
README.md
|
@ -27,17 +27,13 @@
|
||||||
* Setup the environment
|
* Setup the environment
|
||||||
Install from conda file
|
Install from conda file
|
||||||
```
|
```
|
||||||
mamba env create -f environment.yml
|
conda env create --name lion_env --file=env.yaml
|
||||||
# mamba env update -f environment.yml
|
conda activate lion_env
|
||||||
conda activate LION
|
|
||||||
|
|
||||||
# Install some other packages (use proxy)
|
# Install some other packages
|
||||||
pip install git+https://github.com/openai/CLIP.git
|
pip install git+https://github.com/openai/CLIP.git
|
||||||
|
|
||||||
# build some packages first (optional)
|
# build some packages first (optional)
|
||||||
export CUDA_HOME=/usr/local/cuda # just in case rosetta cucks you
|
|
||||||
module load compilers
|
|
||||||
module load mpfr
|
|
||||||
python build_pkg.py
|
python build_pkg.py
|
||||||
```
|
```
|
||||||
Tested with conda version 22.9.0
|
Tested with conda version 22.9.0
|
||||||
|
@ -48,7 +44,7 @@
|
||||||
|
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. Download checkpoints from [HuggingFace Hub](https://huggingface.co/xiaohui2022/lion_ckpt)
|
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. (Note: the checkpoint is not released yet, the files loaded in the `demo.py` file is not available at this point)
|
||||||
|
|
||||||
## Released checkpoint and samples
|
## Released checkpoint and samples
|
||||||
* will be release soon
|
* will be release soon
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ./exp/tmp/2022_0407_0300_45.sh
|
bash_name: ../exp/tmp/2022_0407_0300_45.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
log_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
log_name: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
log_name: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
save_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ./exp/tmp/2022_0407_1347_21.sh
|
bash_name: ../exp/tmp/2022_0407_1347_21.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
log_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
log_name: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
log_name: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
save_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
@ -195,7 +195,7 @@ sde:
|
||||||
update_q_ema: false
|
update_q_ema: false
|
||||||
use_adam: true
|
use_adam: true
|
||||||
use_adamax: false
|
use_adamax: false
|
||||||
vae_checkpoint: ./exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
vae_checkpoint: ../exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
||||||
warmup_epochs: 20
|
warmup_epochs: 20
|
||||||
weight_decay: 0.0003
|
weight_decay: 0.0003
|
||||||
weight_decay_norm_dae: 0.0
|
weight_decay_norm_dae: 0.0
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
bash_name: ./exp/tmp/2022_0416_1418_42.sh
|
bash_name: ../exp/tmp/2022_0416_1418_42.sh
|
||||||
clipforge:
|
clipforge:
|
||||||
clip_model: ViT-B/32
|
clip_model: ViT-B/32
|
||||||
enable: 0
|
enable: 0
|
||||||
|
@ -105,13 +105,13 @@ latent_pts:
|
||||||
weight_kl_feat: 1.0
|
weight_kl_feat: 1.0
|
||||||
weight_kl_glb: 1.0
|
weight_kl_glb: 1.0
|
||||||
weight_kl_pt: 1.0
|
weight_kl_pt: 1.0
|
||||||
log_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
log_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
log_name: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
log_name: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
model_config: default
|
model_config: default
|
||||||
ngpu: 8
|
ngpu: 8
|
||||||
num_ref: 0
|
num_ref: 0
|
||||||
num_val_samples: 24
|
num_val_samples: 24
|
||||||
save_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
save_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||||
sde:
|
sde:
|
||||||
attn_mhead: 0
|
attn_mhead: 0
|
||||||
attn_mhead_local: -1
|
attn_mhead_local: -1
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
||||||
import os
|
import os
|
||||||
|
import open3d as o3d
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -173,7 +174,7 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
# obj_fname = os.path.join(sub_path, x)
|
# obj_fname = os.path.join(sub_path, x)
|
||||||
if self.clip_forge_enable:
|
if self.clip_forge_enable:
|
||||||
synset_id = subd
|
synset_id = subd
|
||||||
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
|
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
|
||||||
|
|
||||||
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
||||||
#if not (os.path.exists(render_img_path)): continue
|
#if not (os.path.exists(render_img_path)): continue
|
||||||
|
|
311
env.yaml
Normal file
311
env.yaml
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
name: lion_env
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- nvidia
|
||||||
|
- anaconda
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=4.5=1_gnu
|
||||||
|
- argon2-cffi=20.1.0=py38h27cfd23_1
|
||||||
|
- async_generator=1.10=pyhd3eb1b0_0
|
||||||
|
- attrs=21.4.0=pyhd3eb1b0_0
|
||||||
|
- backcall=0.2.0=pyhd3eb1b0_0
|
||||||
|
- blas=1.0=mkl
|
||||||
|
- bleach=4.1.0=pyhd3eb1b0_0
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2020.10.14=0
|
||||||
|
- certifi=2020.6.20=py38_0
|
||||||
|
- cffi=1.15.0=py38hd667e15_1
|
||||||
|
- cmake=3.18.2=ha30ef3c_0
|
||||||
|
- cudatoolkit=11.1.74=h6bb024c_0
|
||||||
|
- debugpy=1.5.1=py38h295c915_0
|
||||||
|
- decorator=5.1.1=pyhd3eb1b0_0
|
||||||
|
- defusedxml=0.7.1=pyhd3eb1b0_0
|
||||||
|
- entrypoints=0.3=py38_0
|
||||||
|
- expat=2.2.10=he6710b0_2
|
||||||
|
- ffmpeg=4.3=hf484d3e_0
|
||||||
|
- freetype=2.11.0=h70c0345_0
|
||||||
|
- giflib=5.2.1=h7b6447c_0
|
||||||
|
- gmp=6.2.1=h2531618_2
|
||||||
|
- gnutls=3.6.15=he1e5248_0
|
||||||
|
- importlib_metadata=4.8.2=hd3eb1b0_0
|
||||||
|
- intel-openmp=2021.4.0=h06a4308_3561
|
||||||
|
- ipykernel=6.4.1=py38h06a4308_1
|
||||||
|
- ipython=7.31.1=py38h06a4308_0
|
||||||
|
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
||||||
|
- ipywidgets=7.6.5=pyhd3eb1b0_1
|
||||||
|
- jedi=0.18.1=py38h06a4308_0
|
||||||
|
- jpeg=9d=h7f8727e_0
|
||||||
|
- jupyter_client=7.1.2=pyhd3eb1b0_0
|
||||||
|
- jupyter_core=4.9.1=py38h06a4308_0
|
||||||
|
- jupyterlab_pygments=0.1.2=py_0
|
||||||
|
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
|
||||||
|
- krb5=1.18.2=h173b8e3_0
|
||||||
|
- lame=3.100=h7b6447c_0
|
||||||
|
- lcms2=2.12=h3be6417_0
|
||||||
|
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||||
|
- libcurl=7.71.1=h20c2e04_1
|
||||||
|
- libedit=3.1.20191231=h14c3975_1
|
||||||
|
- libffi=3.3=he6710b0_2
|
||||||
|
- libgcc-ng=9.3.0=h5101ec6_17
|
||||||
|
- libgomp=9.3.0=h5101ec6_17
|
||||||
|
- libiconv=1.15=h63c8f33_5
|
||||||
|
- libidn2=2.3.2=h7f8727e_0
|
||||||
|
- libpng=1.6.37=hbc83047_0
|
||||||
|
- libsodium=1.0.18=h7b6447c_0
|
||||||
|
- libssh2=1.9.0=h1ba5d50_1
|
||||||
|
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||||
|
- libtasn1=4.16.0=h27cfd23_0
|
||||||
|
- libtiff=4.2.0=h85742a9_0
|
||||||
|
- libunistring=0.9.10=h27cfd23_0
|
||||||
|
- libuv=1.40.0=h7b6447c_0
|
||||||
|
- libwebp=1.2.0=h89dd481_0
|
||||||
|
- libwebp-base=1.2.0=h27cfd23_0
|
||||||
|
- lz4-c=1.9.3=h295c915_1
|
||||||
|
- markupsafe=2.0.1=py38h27cfd23_0
|
||||||
|
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
||||||
|
- mistune=0.8.4=py38h7b6447c_1000
|
||||||
|
- mkl=2021.4.0=h06a4308_640
|
||||||
|
- mkl-service=2.4.0=py38h7f8727e_0
|
||||||
|
- mkl_fft=1.3.1=py38hd3c417c_0
|
||||||
|
- mkl_random=1.2.2=py38h51133e4_0
|
||||||
|
- nbclient=0.5.3=pyhd3eb1b0_0
|
||||||
|
- nbconvert=6.3.0=py38h06a4308_0
|
||||||
|
- ncurses=6.3=h7f8727e_2
|
||||||
|
- nest-asyncio=1.5.1=pyhd3eb1b0_0
|
||||||
|
- nettle=3.7.3=hbbd107a_1
|
||||||
|
- notebook=6.4.6=py38h06a4308_0
|
||||||
|
- numpy=1.21.2=py38h20f2e39_0
|
||||||
|
- numpy-base=1.21.2=py38h79a1101_0
|
||||||
|
- olefile=0.46=pyhd3eb1b0_0
|
||||||
|
- openh264=2.1.1=h4ff587b_0
|
||||||
|
- openssl=1.1.1m=h7f8727e_0
|
||||||
|
- packaging=21.3=pyhd3eb1b0_0
|
||||||
|
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
||||||
|
- parso=0.8.3=pyhd3eb1b0_0
|
||||||
|
- pexpect=4.8.0=pyhd3eb1b0_3
|
||||||
|
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
||||||
|
- pillow=8.4.0=py38h5aabda8_0
|
||||||
|
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
|
||||||
|
- prometheus_client=0.13.1=pyhd3eb1b0_0
|
||||||
|
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
||||||
|
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
||||||
|
- pycparser=2.21=pyhd3eb1b0_0
|
||||||
|
- pygments=2.11.2=pyhd3eb1b0_0
|
||||||
|
- python=3.8.12=h12debd9_0
|
||||||
|
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||||
|
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
|
||||||
|
- python_abi=3.8=2_cp38
|
||||||
|
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
|
||||||
|
- pytorch-mutex=1.0=cuda
|
||||||
|
- pyzmq=22.3.0=py38h295c915_2
|
||||||
|
- readline=8.1.2=h7f8727e_1
|
||||||
|
- rhash=1.4.0=h1ba5d50_0
|
||||||
|
- send2trash=1.8.0=pyhd3eb1b0_1
|
||||||
|
- six=1.16.0=pyhd3eb1b0_0
|
||||||
|
- sqlite=3.37.2=hc218d9a_0
|
||||||
|
- terminado=0.9.4=py38h06a4308_0
|
||||||
|
- testpath=0.5.0=pyhd3eb1b0_0
|
||||||
|
- tk=8.6.11=h1ccaba5_0
|
||||||
|
- torchaudio=0.10.2=py38_cu111
|
||||||
|
- torchvision=0.11.3=py38_cu111
|
||||||
|
- tornado=6.1=py38h27cfd23_0
|
||||||
|
- traitlets=5.1.1=pyhd3eb1b0_0
|
||||||
|
- wcwidth=0.2.5=pyhd3eb1b0_0
|
||||||
|
- webencodings=0.5.1=py38_1
|
||||||
|
- wheel=0.37.1=pyhd3eb1b0_0
|
||||||
|
- widgetsnbextension=3.5.1=py38_0
|
||||||
|
- xz=5.2.5=h7b6447c_0
|
||||||
|
- zeromq=4.3.4=h2531618_0
|
||||||
|
- zipp=3.7.0=pyhd3eb1b0_0
|
||||||
|
- zlib=1.2.11=h7f8727e_4
|
||||||
|
- zstd=1.4.5=h9ceee32_0
|
||||||
|
- pip:
|
||||||
|
- about-time==3.1.1
|
||||||
|
- absl-py==1.0.0
|
||||||
|
- addict==2.4.0
|
||||||
|
- aiohttp==3.8.1
|
||||||
|
- aiosignal==1.2.0
|
||||||
|
- alive-progress==2.2.0
|
||||||
|
- antlr4-python3-runtime==4.9.3
|
||||||
|
- anyio==3.5.0
|
||||||
|
- astunparse==1.6.3
|
||||||
|
- async-timeout==4.0.2
|
||||||
|
- babel==2.9.1
|
||||||
|
- cachetools==5.0.0
|
||||||
|
- calmsize==0.1.3
|
||||||
|
- ccimport==0.3.7
|
||||||
|
- cftime==1.6.0
|
||||||
|
- charset-normalizer==2.0.11
|
||||||
|
- click==8.0.3
|
||||||
|
- colorama==0.4.4
|
||||||
|
- comet-ml==3.31.21
|
||||||
|
- commonmark==0.9.1
|
||||||
|
- configobj==5.0.6
|
||||||
|
- crc32c==2.2.post0
|
||||||
|
- cumm-cu111==0.2.8
|
||||||
|
- cupy-cuda111==10.2.0
|
||||||
|
- cycler==0.11.0
|
||||||
|
- cython==0.29.20
|
||||||
|
- dataclasses==0.6
|
||||||
|
- deepspeed==0.6.5
|
||||||
|
- deprecation==2.1.0
|
||||||
|
- diffusers==0.11.1
|
||||||
|
- docker-pycreds==0.4.0
|
||||||
|
- drjit==0.2.1
|
||||||
|
- dulwich==0.20.32
|
||||||
|
- easydict==1.9
|
||||||
|
- einops==0.4.0
|
||||||
|
- everett==3.0.0
|
||||||
|
- fastrlock==0.8
|
||||||
|
- filelock==3.9.0
|
||||||
|
- fire==0.4.0
|
||||||
|
- flatbuffers==2.0
|
||||||
|
- flatten-dict==0.4.2
|
||||||
|
- fonttools==4.29.1
|
||||||
|
- freetype-py==2.3.0
|
||||||
|
- frozenlist==1.3.0
|
||||||
|
- fsspec==2022.2.0
|
||||||
|
- ftfy==6.1.1
|
||||||
|
- future==0.18.2
|
||||||
|
- fvcore==0.1.5.post20220512
|
||||||
|
- gast==0.5.3
|
||||||
|
- gitdb==4.0.9
|
||||||
|
- gitpython==3.1.26
|
||||||
|
- google-auth==2.6.0
|
||||||
|
- google-auth-oauthlib==0.4.6
|
||||||
|
- google-pasta==0.2.0
|
||||||
|
- grapheme==0.6.0
|
||||||
|
- grpcio==1.43.0
|
||||||
|
- h5py==3.6.0
|
||||||
|
- hjson==3.0.2
|
||||||
|
- huggingface-hub==0.11.1
|
||||||
|
- idna==3.3
|
||||||
|
- imageio==2.15.0
|
||||||
|
- imageio-ffmpeg==0.4.5
|
||||||
|
- importlib-metadata==4.10.1
|
||||||
|
- importlib-resources==5.4.0
|
||||||
|
- iopath==0.1.10
|
||||||
|
- jinja2==3.1.1
|
||||||
|
- joblib==1.1.0
|
||||||
|
- json5==0.9.6
|
||||||
|
- jsonschema==4.4.0
|
||||||
|
- jupyter-packaging==0.12.0
|
||||||
|
- jupyter-server==1.15.6
|
||||||
|
- jupyterlab==3.3.2
|
||||||
|
- jupyterlab-server==2.11.2
|
||||||
|
- keras==2.8.0
|
||||||
|
- keras-preprocessing==1.1.2
|
||||||
|
- kiwisolver==1.3.2
|
||||||
|
- kornia==0.6.6
|
||||||
|
- lark==1.1.2
|
||||||
|
- libclang==14.0.1
|
||||||
|
- llvmlite==0.39.0
|
||||||
|
- loguru==0.6.0
|
||||||
|
- markdown==3.3.6
|
||||||
|
- matplotlib==3.5.1
|
||||||
|
- matplotlib2tikz==0.7.6
|
||||||
|
- meshio==5.3.4
|
||||||
|
- mitsuba==3.0.1
|
||||||
|
- mrcfile==1.3.0
|
||||||
|
- multidict==6.0.2
|
||||||
|
- multipledispatch==0.6.0
|
||||||
|
- mypy-extensions==0.4.3
|
||||||
|
- nbclassic==0.3.7
|
||||||
|
- nbformat==5.2.0
|
||||||
|
- nestargs==0.5.0
|
||||||
|
- netcdf4==1.5.8
|
||||||
|
- networkx==2.6.3
|
||||||
|
- ninja==1.10.2.3
|
||||||
|
- notebook-shim==0.1.0
|
||||||
|
- numba==0.56.0
|
||||||
|
- nvidia-ml-py3==7.352.0
|
||||||
|
- oauthlib==3.2.0
|
||||||
|
- omegaconf==2.2.2
|
||||||
|
- open3d==0.15.2
|
||||||
|
- opencv-python==4.5.5.64
|
||||||
|
- openexr==1.3.7
|
||||||
|
- opt-einsum==3.3.0
|
||||||
|
- pandas==1.4.0
|
||||||
|
- pathtools==0.1.2
|
||||||
|
- pccm==0.3.4
|
||||||
|
- pip==22.3.1
|
||||||
|
- plyfile==0.7.4
|
||||||
|
- portalocker==2.5.1
|
||||||
|
- progressbar2==4.0.0
|
||||||
|
- promise==2.3
|
||||||
|
- protobuf==3.19.4
|
||||||
|
- psutil==5.9.0
|
||||||
|
- py-cpuinfo==8.0.0
|
||||||
|
- pyasn1==0.4.8
|
||||||
|
- pyasn1-modules==0.2.8
|
||||||
|
- pybind11==2.10.0
|
||||||
|
- pydeprecate==0.3.1
|
||||||
|
- pyglet==1.5.23
|
||||||
|
- pykeops==1.5
|
||||||
|
- pymcubes==0.1.2
|
||||||
|
- pyopengl==3.1.0
|
||||||
|
- pyparsing==3.0.7
|
||||||
|
- pyquaternion==0.9.9
|
||||||
|
- pyrr==0.10.3
|
||||||
|
- pyrsistent==0.18.1
|
||||||
|
- python-swiftclient==4.0.0
|
||||||
|
- python-utils==3.3.3
|
||||||
|
- pytorch-lightning==1.5.1
|
||||||
|
- pytorch3d==0.3.0
|
||||||
|
- pytz==2021.3
|
||||||
|
- pywavelets==1.2.0
|
||||||
|
- pyyaml==6.0
|
||||||
|
- regex==2022.3.15
|
||||||
|
- requests==2.27.1
|
||||||
|
- requests-oauthlib==1.3.1
|
||||||
|
- requests-toolbelt==0.9.1
|
||||||
|
- rich==12.3.0
|
||||||
|
- rsa==4.8
|
||||||
|
- ruamel-yaml==0.17.20
|
||||||
|
- ruamel-yaml-clib==0.2.6
|
||||||
|
- scikit-image==0.19.1
|
||||||
|
- scikit-learn==1.0.2
|
||||||
|
- scipy==1.8.0
|
||||||
|
- seaborn==0.11.2
|
||||||
|
- semantic-version==2.9.0
|
||||||
|
- sentry-sdk==1.5.4
|
||||||
|
- sharedarray==3.2.1
|
||||||
|
- shortuuid==1.0.8
|
||||||
|
- simple-parsing==0.0.18
|
||||||
|
- simplejson==3.18.0
|
||||||
|
- sklearn==0.0
|
||||||
|
- smmap==5.0.0
|
||||||
|
- sniffio==1.2.0
|
||||||
|
- tabulate==0.8.9
|
||||||
|
- tensorboard==2.8.0
|
||||||
|
- tensorboard-data-server==0.6.1
|
||||||
|
- tensorboard-plugin-wit==1.8.1
|
||||||
|
- tensorboardx==2.4.1
|
||||||
|
- tensorflow-gpu==2.8.0
|
||||||
|
- tensorflow-io-gcs-filesystem==0.25.0
|
||||||
|
- termcolor==1.1.0
|
||||||
|
- tf-estimator-nightly==2.8.0.dev2021122109
|
||||||
|
- tflearn==0.5.0
|
||||||
|
- tfrecord==1.14.1
|
||||||
|
- threadpoolctl==3.1.0
|
||||||
|
- tifffile==2022.2.2
|
||||||
|
- tikzplotlib==0.10.1
|
||||||
|
- tomlkit==0.10.0
|
||||||
|
- torchmetrics==0.7.2
|
||||||
|
- tqdm==4.62.3
|
||||||
|
- trimesh==3.10.1
|
||||||
|
- typing-extensions==4.2.0
|
||||||
|
- typing-inspect==0.7.1
|
||||||
|
- urllib3==1.26.8
|
||||||
|
- wandb==0.12.10
|
||||||
|
- webcolors==1.11.1
|
||||||
|
- websocket-client==1.2.3
|
||||||
|
- werkzeug==2.0.3
|
||||||
|
- wrapt==1.13.3
|
||||||
|
- wurlitzer==3.0.2
|
||||||
|
- yacs==0.1.8
|
||||||
|
- yarl==1.7.2
|
||||||
|
- yaspin==2.1.0
|
|
@ -1,44 +0,0 @@
|
||||||
name: LION
|
|
||||||
channels:
|
|
||||||
- pytorch
|
|
||||||
- nvidia
|
|
||||||
- conda-forge
|
|
||||||
- defaults
|
|
||||||
- open3d-admin
|
|
||||||
- comet_ml
|
|
||||||
dependencies:
|
|
||||||
- pip
|
|
||||||
|
|
||||||
- python
|
|
||||||
- pytorch
|
|
||||||
- torchvision
|
|
||||||
- cudatoolkit
|
|
||||||
- matplotlib
|
|
||||||
- tqdm
|
|
||||||
- trimesh
|
|
||||||
- scipy
|
|
||||||
- scikit-learn
|
|
||||||
|
|
||||||
- ftfy
|
|
||||||
- tqdm
|
|
||||||
- regex
|
|
||||||
|
|
||||||
- tabulate
|
|
||||||
- h5py
|
|
||||||
- wandb
|
|
||||||
- pyyaml
|
|
||||||
- open3d
|
|
||||||
- cupy
|
|
||||||
- scikit-image
|
|
||||||
- loguru
|
|
||||||
- einops
|
|
||||||
- comet_ml
|
|
||||||
- diffusers
|
|
||||||
|
|
||||||
- pip:
|
|
||||||
- pykeops
|
|
||||||
- nestargs
|
|
||||||
- flatten_dict
|
|
||||||
- point-cloud-utils
|
|
||||||
- calmsize
|
|
||||||
# - git+https://github.com/openai/CLIP.git
|
|
|
@ -160,7 +160,7 @@ class Prior(nn.Module):
|
||||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
||||||
# self.is_active = None
|
# self.is_active = None
|
||||||
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
||||||
# init = torch.load('./exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
# init = torch.load('../exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
||||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
||||||
# self.is_active = None
|
# self.is_active = None
|
||||||
# else:
|
# else:
|
||||||
|
|
2
script/train_vae.sh
Executable file → Normal file
2
script/train_vae.sh
Executable file → Normal file
|
@ -6,7 +6,7 @@ fi
|
||||||
DATA=" ddpm.input_dim 3 data.cates car "
|
DATA=" ddpm.input_dim 3 data.cates car "
|
||||||
NGPU=$1 #
|
NGPU=$1 #
|
||||||
num_node=1
|
num_node=1
|
||||||
BS=6
|
BS=32
|
||||||
total_bs=$(( $NGPU * $BS ))
|
total_bs=$(( $NGPU * $BS ))
|
||||||
if (( $total_bs > 128 )); then
|
if (( $total_bs > 128 )); then
|
||||||
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
||||||
|
|
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||||
|
#include <THC/THC.h>
|
||||||
|
|
||||||
#define CHECK_INPUT(x)
|
#define CHECK_INPUT(x)
|
||||||
|
|
||||||
|
@ -172,9 +173,9 @@ at::Tensor ApproxMatchForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -184,7 +185,7 @@ at::Tensor ApproxMatchForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
C10_CUDA_CHECK(cudaGetLastError());
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
return match;
|
return match;
|
||||||
}
|
}
|
||||||
|
@ -259,9 +260,9 @@ at::Tensor MatchCostForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -270,7 +271,7 @@ at::Tensor MatchCostForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
C10_CUDA_CHECK(cudaGetLastError());
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
@ -376,9 +377,9 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -389,7 +390,7 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
C10_CUDA_CHECK(cudaGetLastError());
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
return std::vector<at::Tensor>({grad1, grad2});
|
return std::vector<at::Tensor>({grad1, grad2});
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,7 +105,7 @@ def main(args, config):
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser('encoder decoder examiner')
|
parser = argparse.ArgumentParser('encoder decoder examiner')
|
||||||
# experimental results
|
# experimental results
|
||||||
parser.add_argument('--exp_root', type=str, default='./exp',
|
parser.add_argument('--exp_root', type=str, default='../exp',
|
||||||
help='location of the results')
|
help='location of the results')
|
||||||
# parser.add_argument('--save', type=str, default='exp',
|
# parser.add_argument('--save', type=str, default='exp',
|
||||||
# help='id used for storing intermediate results')
|
# help='id used for storing intermediate results')
|
||||||
|
@ -176,7 +176,7 @@ def get_args():
|
||||||
config.merge_from_list(args.opt)
|
config.merge_from_list(args.opt)
|
||||||
|
|
||||||
# Create log_name
|
# Create log_name
|
||||||
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', './exp/')
|
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', '../exp/')
|
||||||
if config.exp_name == '' or config.exp_name == 'none':
|
if config.exp_name == '' or config.exp_name == 'none':
|
||||||
config.hash = io_helper.hash_str('%s' % config) + 'h'
|
config.hash = io_helper.hash_str('%s' % config) + 'h'
|
||||||
cfg_file_name = exp_helper.get_expname(config)
|
cfg_file_name = exp_helper.get_expname(config)
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from comet_ml import Experiment, OfflineExperiment
|
from comet_ml import Experiment, OfflineExperiment
|
||||||
|
## import open3d as o3d
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -323,7 +324,7 @@ def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda',
|
||||||
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
|
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
|
||||||
results['jsd'] = jsd
|
results['jsd'] = jsd
|
||||||
msg = print_results(results, **print_kwargs)
|
msg = print_results(results, **print_kwargs)
|
||||||
# with open('./exp/eval_out.txt', 'a') as f:
|
# with open('../exp/eval_out.txt', 'a') as f:
|
||||||
# run_time = time.strftime('%m%d-%H%M-%S')
|
# run_time = time.strftime('%m%d-%H%M-%S')
|
||||||
# f.write('<< date: %s >>\n' % run_time)
|
# f.write('<< date: %s >>\n' % run_time)
|
||||||
# f.write('%s\n%s\n' % (exp.url, msg))
|
# f.write('%s\n%s\n' % (exp.url, msg))
|
||||||
|
|
|
@ -491,7 +491,7 @@ def common_init(rank, seed, save_dir, comet_key=''):
|
||||||
|
|
||||||
if os.path.exists('.wandb_api'):
|
if os.path.exists('.wandb_api'):
|
||||||
wb_args = json.load(open('.wandb_api', 'r'))
|
wb_args = json.load(open('.wandb_api', 'r'))
|
||||||
wb_dir = './exp/wandb/' if not os.path.exists(
|
wb_dir = '../exp/wandb/' if not os.path.exists(
|
||||||
'/workspace/result') else '/workspace/result/wandb/'
|
'/workspace/result') else '/workspace/result/wandb/'
|
||||||
if not os.path.exists(wb_dir):
|
if not os.path.exists(wb_dir):
|
||||||
os.makedirs(wb_dir)
|
os.makedirs(wb_dir)
|
||||||
|
@ -1130,26 +1130,24 @@ def init_processes(rank, size, fn, args, config):
|
||||||
""" Initialize the distributed environment. """
|
""" Initialize the distributed environment. """
|
||||||
os.environ['MASTER_ADDR'] = args.master_address
|
os.environ['MASTER_ADDR'] = args.master_address
|
||||||
os.environ['MASTER_PORT'] = '6020'
|
os.environ['MASTER_PORT'] = '6020'
|
||||||
logger.info('set MASTER_PORT: {}, MASTER_PORT: {}', os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
if args.num_proc_node == 1:
|
||||||
|
import socket
|
||||||
# if args.num_proc_node == 1: # try to solve the port occupied issue
|
import errno
|
||||||
# import socket
|
a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
# import errno
|
for p in range(6010, 6030):
|
||||||
# a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
location = (args.master_address, p) # "127.0.0.1", p)
|
||||||
# for p in range(6010, 6030):
|
try:
|
||||||
# location = (args.master_address, p) # "127.0.0.1", p)
|
a_socket.bind((args.master_address, p))
|
||||||
# try:
|
logger.debug('set port as {}', p)
|
||||||
# a_socket.bind((args.master_address, p))
|
os.environ['MASTER_PORT'] = '%d' % p
|
||||||
# logger.debug('set port as {}', p)
|
a_socket.close()
|
||||||
# os.environ['MASTER_PORT'] = '%d' % p
|
break
|
||||||
# a_socket.close()
|
except socket.error as e:
|
||||||
# break
|
a = 0
|
||||||
# except socket.error as e:
|
# if e.errno == errno.EADDRINUSE:
|
||||||
# a = 0
|
# # logger.debug("Port {} is already in use", p)
|
||||||
# # if e.errno == errno.EADDRINUSE:
|
# else:
|
||||||
# # # logger.debug("Port {} is already in use", p)
|
# logger.debug(e)
|
||||||
# # else:
|
|
||||||
# # logger.debug(e)
|
|
||||||
|
|
||||||
logger.info('init_process: rank={}, world_size={}', rank, size)
|
logger.info('init_process: rank={}, world_size={}', rank, size)
|
||||||
torch.cuda.set_device(args.local_rank)
|
torch.cuda.set_device(args.local_rank)
|
||||||
|
|
Loading…
Reference in a new issue