Compare commits
No commits in common. "403e04de8a9b2395454fa5c67bafee7fa1115c50" and "b84169e724478f136a9ee0eaa541218086bd5f2e" have entirely different histories.
403e04de8a
...
b84169e724
12
README.md
12
README.md
|
@ -27,17 +27,13 @@
|
|||
* Setup the environment
|
||||
Install from conda file
|
||||
```
|
||||
mamba env create -f environment.yml
|
||||
# mamba env update -f environment.yml
|
||||
conda activate LION
|
||||
conda env create --name lion_env --file=env.yaml
|
||||
conda activate lion_env
|
||||
|
||||
# Install some other packages (use proxy)
|
||||
# Install some other packages
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
|
||||
# build some packages first (optional)
|
||||
export CUDA_HOME=/usr/local/cuda # just in case rosetta cucks you
|
||||
module load compilers
|
||||
module load mpfr
|
||||
python build_pkg.py
|
||||
```
|
||||
Tested with conda version 22.9.0
|
||||
|
@ -48,7 +44,7 @@
|
|||
|
||||
|
||||
## Demo
|
||||
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. Download checkpoints from [HuggingFace Hub](https://huggingface.co/xiaohui2022/lion_ckpt)
|
||||
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud. (Note: the checkpoint is not released yet, the files loaded in the `demo.py` file is not available at this point)
|
||||
|
||||
## Released checkpoint and samples
|
||||
* will be release soon
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
bash_name: ./exp/tmp/2022_0407_0300_45.sh
|
||||
bash_name: ../exp/tmp/2022_0407_0300_45.sh
|
||||
clipforge:
|
||||
clip_model: ViT-B/32
|
||||
enable: 0
|
||||
|
@ -105,13 +105,13 @@ latent_pts:
|
|||
weight_kl_feat: 1.0
|
||||
weight_kl_glb: 1.0
|
||||
weight_kl_pt: 1.0
|
||||
log_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
log_name: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
log_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
log_name: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
model_config: default
|
||||
ngpu: 8
|
||||
num_ref: 0
|
||||
num_val_samples: 24
|
||||
save_dir: ./exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
save_dir: ../exp/0407/airplane/cb8fb3h_train_l2e-4GlobalP2048_vae_adainB20l1E3W8
|
||||
sde:
|
||||
attn_mhead: 0
|
||||
attn_mhead_local: -1
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
bash_name: ./exp/tmp/2022_0407_1347_21.sh
|
||||
bash_name: ../exp/tmp/2022_0407_1347_21.sh
|
||||
clipforge:
|
||||
clip_model: ViT-B/32
|
||||
enable: 0
|
||||
|
@ -105,13 +105,13 @@ latent_pts:
|
|||
weight_kl_feat: 1.0
|
||||
weight_kl_glb: 1.0
|
||||
weight_kl_pt: 1.0
|
||||
log_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
log_name: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
log_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
log_name: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
model_config: default
|
||||
ngpu: 8
|
||||
num_ref: 0
|
||||
num_val_samples: 24
|
||||
save_dir: ./exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
save_dir: ../exp/0407/car/fbb941h_train_l2e-4GlobalP2048D03_vae_adainB20l1E3W8
|
||||
sde:
|
||||
attn_mhead: 0
|
||||
attn_mhead_local: -1
|
||||
|
@ -195,7 +195,7 @@ sde:
|
|||
update_q_ema: false
|
||||
use_adam: true
|
||||
use_adamax: false
|
||||
vae_checkpoint: ./exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
||||
vae_checkpoint: ../exp/0326/car/f91abeh_hvae_kl0.5N32H1Anneall1_sumWlrInitScale_vae_adainB32l1E3W4/checkpoints/epoch_7999_iters_151999.pt
|
||||
warmup_epochs: 20
|
||||
weight_decay: 0.0003
|
||||
weight_decay_norm_dae: 0.0
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
bash_name: ./exp/tmp/2022_0416_1418_42.sh
|
||||
bash_name: ../exp/tmp/2022_0416_1418_42.sh
|
||||
clipforge:
|
||||
clip_model: ViT-B/32
|
||||
enable: 0
|
||||
|
@ -105,13 +105,13 @@ latent_pts:
|
|||
weight_kl_feat: 1.0
|
||||
weight_kl_glb: 1.0
|
||||
weight_kl_pt: 1.0
|
||||
log_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
log_name: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
log_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
log_name: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
model_config: default
|
||||
ngpu: 8
|
||||
num_ref: 0
|
||||
num_val_samples: 24
|
||||
save_dir: ./exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
save_dir: ../exp/0416/chair/afc967h_train_l2e-4GlobalP2048D04_vae_adainB20l1E3W8
|
||||
sde:
|
||||
attn_mhead: 0
|
||||
attn_mhead_local: -1
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
||||
import os
|
||||
import open3d as o3d
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
@ -173,7 +174,7 @@ class ShapeNet15kPointClouds(Dataset):
|
|||
# obj_fname = os.path.join(sub_path, x)
|
||||
if self.clip_forge_enable:
|
||||
synset_id = subd
|
||||
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
|
||||
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
|
||||
|
||||
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
||||
#if not (os.path.exists(render_img_path)): continue
|
||||
|
|
311
env.yaml
Normal file
311
env.yaml
Normal file
|
@ -0,0 +1,311 @@
|
|||
name: lion_env
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- anaconda
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- argon2-cffi=20.1.0=py38h27cfd23_1
|
||||
- async_generator=1.10=pyhd3eb1b0_0
|
||||
- attrs=21.4.0=pyhd3eb1b0_0
|
||||
- backcall=0.2.0=pyhd3eb1b0_0
|
||||
- blas=1.0=mkl
|
||||
- bleach=4.1.0=pyhd3eb1b0_0
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2020.10.14=0
|
||||
- certifi=2020.6.20=py38_0
|
||||
- cffi=1.15.0=py38hd667e15_1
|
||||
- cmake=3.18.2=ha30ef3c_0
|
||||
- cudatoolkit=11.1.74=h6bb024c_0
|
||||
- debugpy=1.5.1=py38h295c915_0
|
||||
- decorator=5.1.1=pyhd3eb1b0_0
|
||||
- defusedxml=0.7.1=pyhd3eb1b0_0
|
||||
- entrypoints=0.3=py38_0
|
||||
- expat=2.2.10=he6710b0_2
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
- giflib=5.2.1=h7b6447c_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- importlib_metadata=4.8.2=hd3eb1b0_0
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- ipykernel=6.4.1=py38h06a4308_1
|
||||
- ipython=7.31.1=py38h06a4308_0
|
||||
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
||||
- ipywidgets=7.6.5=pyhd3eb1b0_1
|
||||
- jedi=0.18.1=py38h06a4308_0
|
||||
- jpeg=9d=h7f8727e_0
|
||||
- jupyter_client=7.1.2=pyhd3eb1b0_0
|
||||
- jupyter_core=4.9.1=py38h06a4308_0
|
||||
- jupyterlab_pygments=0.1.2=py_0
|
||||
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
|
||||
- krb5=1.18.2=h173b8e3_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libcurl=7.71.1=h20c2e04_1
|
||||
- libedit=3.1.20191231=h14c3975_1
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libsodium=1.0.18=h7b6447c_0
|
||||
- libssh2=1.9.0=h1ba5d50_1
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp=1.2.0=h89dd481_0
|
||||
- libwebp-base=1.2.0=h27cfd23_0
|
||||
- lz4-c=1.9.3=h295c915_1
|
||||
- markupsafe=2.0.1=py38h27cfd23_0
|
||||
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
||||
- mistune=0.8.4=py38h7b6447c_1000
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py38h7f8727e_0
|
||||
- mkl_fft=1.3.1=py38hd3c417c_0
|
||||
- mkl_random=1.2.2=py38h51133e4_0
|
||||
- nbclient=0.5.3=pyhd3eb1b0_0
|
||||
- nbconvert=6.3.0=py38h06a4308_0
|
||||
- ncurses=6.3=h7f8727e_2
|
||||
- nest-asyncio=1.5.1=pyhd3eb1b0_0
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- notebook=6.4.6=py38h06a4308_0
|
||||
- numpy=1.21.2=py38h20f2e39_0
|
||||
- numpy-base=1.21.2=py38h79a1101_0
|
||||
- olefile=0.46=pyhd3eb1b0_0
|
||||
- openh264=2.1.1=h4ff587b_0
|
||||
- openssl=1.1.1m=h7f8727e_0
|
||||
- packaging=21.3=pyhd3eb1b0_0
|
||||
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
||||
- parso=0.8.3=pyhd3eb1b0_0
|
||||
- pexpect=4.8.0=pyhd3eb1b0_3
|
||||
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
||||
- pillow=8.4.0=py38h5aabda8_0
|
||||
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
|
||||
- prometheus_client=0.13.1=pyhd3eb1b0_0
|
||||
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
||||
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
||||
- pycparser=2.21=pyhd3eb1b0_0
|
||||
- pygments=2.11.2=pyhd3eb1b0_0
|
||||
- python=3.8.12=h12debd9_0
|
||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
|
||||
- python_abi=3.8=2_cp38
|
||||
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pyzmq=22.3.0=py38h295c915_2
|
||||
- readline=8.1.2=h7f8727e_1
|
||||
- rhash=1.4.0=h1ba5d50_0
|
||||
- send2trash=1.8.0=pyhd3eb1b0_1
|
||||
- six=1.16.0=pyhd3eb1b0_0
|
||||
- sqlite=3.37.2=hc218d9a_0
|
||||
- terminado=0.9.4=py38h06a4308_0
|
||||
- testpath=0.5.0=pyhd3eb1b0_0
|
||||
- tk=8.6.11=h1ccaba5_0
|
||||
- torchaudio=0.10.2=py38_cu111
|
||||
- torchvision=0.11.3=py38_cu111
|
||||
- tornado=6.1=py38h27cfd23_0
|
||||
- traitlets=5.1.1=pyhd3eb1b0_0
|
||||
- wcwidth=0.2.5=pyhd3eb1b0_0
|
||||
- webencodings=0.5.1=py38_1
|
||||
- wheel=0.37.1=pyhd3eb1b0_0
|
||||
- widgetsnbextension=3.5.1=py38_0
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- zeromq=4.3.4=h2531618_0
|
||||
- zipp=3.7.0=pyhd3eb1b0_0
|
||||
- zlib=1.2.11=h7f8727e_4
|
||||
- zstd=1.4.5=h9ceee32_0
|
||||
- pip:
|
||||
- about-time==3.1.1
|
||||
- absl-py==1.0.0
|
||||
- addict==2.4.0
|
||||
- aiohttp==3.8.1
|
||||
- aiosignal==1.2.0
|
||||
- alive-progress==2.2.0
|
||||
- antlr4-python3-runtime==4.9.3
|
||||
- anyio==3.5.0
|
||||
- astunparse==1.6.3
|
||||
- async-timeout==4.0.2
|
||||
- babel==2.9.1
|
||||
- cachetools==5.0.0
|
||||
- calmsize==0.1.3
|
||||
- ccimport==0.3.7
|
||||
- cftime==1.6.0
|
||||
- charset-normalizer==2.0.11
|
||||
- click==8.0.3
|
||||
- colorama==0.4.4
|
||||
- comet-ml==3.31.21
|
||||
- commonmark==0.9.1
|
||||
- configobj==5.0.6
|
||||
- crc32c==2.2.post0
|
||||
- cumm-cu111==0.2.8
|
||||
- cupy-cuda111==10.2.0
|
||||
- cycler==0.11.0
|
||||
- cython==0.29.20
|
||||
- dataclasses==0.6
|
||||
- deepspeed==0.6.5
|
||||
- deprecation==2.1.0
|
||||
- diffusers==0.11.1
|
||||
- docker-pycreds==0.4.0
|
||||
- drjit==0.2.1
|
||||
- dulwich==0.20.32
|
||||
- easydict==1.9
|
||||
- einops==0.4.0
|
||||
- everett==3.0.0
|
||||
- fastrlock==0.8
|
||||
- filelock==3.9.0
|
||||
- fire==0.4.0
|
||||
- flatbuffers==2.0
|
||||
- flatten-dict==0.4.2
|
||||
- fonttools==4.29.1
|
||||
- freetype-py==2.3.0
|
||||
- frozenlist==1.3.0
|
||||
- fsspec==2022.2.0
|
||||
- ftfy==6.1.1
|
||||
- future==0.18.2
|
||||
- fvcore==0.1.5.post20220512
|
||||
- gast==0.5.3
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.26
|
||||
- google-auth==2.6.0
|
||||
- google-auth-oauthlib==0.4.6
|
||||
- google-pasta==0.2.0
|
||||
- grapheme==0.6.0
|
||||
- grpcio==1.43.0
|
||||
- h5py==3.6.0
|
||||
- hjson==3.0.2
|
||||
- huggingface-hub==0.11.1
|
||||
- idna==3.3
|
||||
- imageio==2.15.0
|
||||
- imageio-ffmpeg==0.4.5
|
||||
- importlib-metadata==4.10.1
|
||||
- importlib-resources==5.4.0
|
||||
- iopath==0.1.10
|
||||
- jinja2==3.1.1
|
||||
- joblib==1.1.0
|
||||
- json5==0.9.6
|
||||
- jsonschema==4.4.0
|
||||
- jupyter-packaging==0.12.0
|
||||
- jupyter-server==1.15.6
|
||||
- jupyterlab==3.3.2
|
||||
- jupyterlab-server==2.11.2
|
||||
- keras==2.8.0
|
||||
- keras-preprocessing==1.1.2
|
||||
- kiwisolver==1.3.2
|
||||
- kornia==0.6.6
|
||||
- lark==1.1.2
|
||||
- libclang==14.0.1
|
||||
- llvmlite==0.39.0
|
||||
- loguru==0.6.0
|
||||
- markdown==3.3.6
|
||||
- matplotlib==3.5.1
|
||||
- matplotlib2tikz==0.7.6
|
||||
- meshio==5.3.4
|
||||
- mitsuba==3.0.1
|
||||
- mrcfile==1.3.0
|
||||
- multidict==6.0.2
|
||||
- multipledispatch==0.6.0
|
||||
- mypy-extensions==0.4.3
|
||||
- nbclassic==0.3.7
|
||||
- nbformat==5.2.0
|
||||
- nestargs==0.5.0
|
||||
- netcdf4==1.5.8
|
||||
- networkx==2.6.3
|
||||
- ninja==1.10.2.3
|
||||
- notebook-shim==0.1.0
|
||||
- numba==0.56.0
|
||||
- nvidia-ml-py3==7.352.0
|
||||
- oauthlib==3.2.0
|
||||
- omegaconf==2.2.2
|
||||
- open3d==0.15.2
|
||||
- opencv-python==4.5.5.64
|
||||
- openexr==1.3.7
|
||||
- opt-einsum==3.3.0
|
||||
- pandas==1.4.0
|
||||
- pathtools==0.1.2
|
||||
- pccm==0.3.4
|
||||
- pip==22.3.1
|
||||
- plyfile==0.7.4
|
||||
- portalocker==2.5.1
|
||||
- progressbar2==4.0.0
|
||||
- promise==2.3
|
||||
- protobuf==3.19.4
|
||||
- psutil==5.9.0
|
||||
- py-cpuinfo==8.0.0
|
||||
- pyasn1==0.4.8
|
||||
- pyasn1-modules==0.2.8
|
||||
- pybind11==2.10.0
|
||||
- pydeprecate==0.3.1
|
||||
- pyglet==1.5.23
|
||||
- pykeops==1.5
|
||||
- pymcubes==0.1.2
|
||||
- pyopengl==3.1.0
|
||||
- pyparsing==3.0.7
|
||||
- pyquaternion==0.9.9
|
||||
- pyrr==0.10.3
|
||||
- pyrsistent==0.18.1
|
||||
- python-swiftclient==4.0.0
|
||||
- python-utils==3.3.3
|
||||
- pytorch-lightning==1.5.1
|
||||
- pytorch3d==0.3.0
|
||||
- pytz==2021.3
|
||||
- pywavelets==1.2.0
|
||||
- pyyaml==6.0
|
||||
- regex==2022.3.15
|
||||
- requests==2.27.1
|
||||
- requests-oauthlib==1.3.1
|
||||
- requests-toolbelt==0.9.1
|
||||
- rich==12.3.0
|
||||
- rsa==4.8
|
||||
- ruamel-yaml==0.17.20
|
||||
- ruamel-yaml-clib==0.2.6
|
||||
- scikit-image==0.19.1
|
||||
- scikit-learn==1.0.2
|
||||
- scipy==1.8.0
|
||||
- seaborn==0.11.2
|
||||
- semantic-version==2.9.0
|
||||
- sentry-sdk==1.5.4
|
||||
- sharedarray==3.2.1
|
||||
- shortuuid==1.0.8
|
||||
- simple-parsing==0.0.18
|
||||
- simplejson==3.18.0
|
||||
- sklearn==0.0
|
||||
- smmap==5.0.0
|
||||
- sniffio==1.2.0
|
||||
- tabulate==0.8.9
|
||||
- tensorboard==2.8.0
|
||||
- tensorboard-data-server==0.6.1
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- tensorboardx==2.4.1
|
||||
- tensorflow-gpu==2.8.0
|
||||
- tensorflow-io-gcs-filesystem==0.25.0
|
||||
- termcolor==1.1.0
|
||||
- tf-estimator-nightly==2.8.0.dev2021122109
|
||||
- tflearn==0.5.0
|
||||
- tfrecord==1.14.1
|
||||
- threadpoolctl==3.1.0
|
||||
- tifffile==2022.2.2
|
||||
- tikzplotlib==0.10.1
|
||||
- tomlkit==0.10.0
|
||||
- torchmetrics==0.7.2
|
||||
- tqdm==4.62.3
|
||||
- trimesh==3.10.1
|
||||
- typing-extensions==4.2.0
|
||||
- typing-inspect==0.7.1
|
||||
- urllib3==1.26.8
|
||||
- wandb==0.12.10
|
||||
- webcolors==1.11.1
|
||||
- websocket-client==1.2.3
|
||||
- werkzeug==2.0.3
|
||||
- wrapt==1.13.3
|
||||
- wurlitzer==3.0.2
|
||||
- yacs==0.1.8
|
||||
- yarl==1.7.2
|
||||
- yaspin==2.1.0
|
|
@ -1,44 +0,0 @@
|
|||
name: LION
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
- open3d-admin
|
||||
- comet_ml
|
||||
dependencies:
|
||||
- pip
|
||||
|
||||
- python
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit
|
||||
- matplotlib
|
||||
- tqdm
|
||||
- trimesh
|
||||
- scipy
|
||||
- scikit-learn
|
||||
|
||||
- ftfy
|
||||
- tqdm
|
||||
- regex
|
||||
|
||||
- tabulate
|
||||
- h5py
|
||||
- wandb
|
||||
- pyyaml
|
||||
- open3d
|
||||
- cupy
|
||||
- scikit-image
|
||||
- loguru
|
||||
- einops
|
||||
- comet_ml
|
||||
- diffusers
|
||||
|
||||
- pip:
|
||||
- pykeops
|
||||
- nestargs
|
||||
- flatten_dict
|
||||
- point-cloud-utils
|
||||
- calmsize
|
||||
# - git+https://github.com/openai/CLIP.git
|
|
@ -160,7 +160,7 @@ class Prior(nn.Module):
|
|||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
||||
# self.is_active = None
|
||||
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
||||
# init = torch.load('./exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
||||
# init = torch.load('../exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
||||
# self.is_active = None
|
||||
# else:
|
||||
|
|
2
script/train_vae.sh
Executable file → Normal file
2
script/train_vae.sh
Executable file → Normal file
|
@ -6,7 +6,7 @@ fi
|
|||
DATA=" ddpm.input_dim 3 data.cates car "
|
||||
NGPU=$1 #
|
||||
num_node=1
|
||||
BS=6
|
||||
BS=32
|
||||
total_bs=$(( $NGPU * $BS ))
|
||||
if (( $total_bs > 128 )); then
|
||||
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
||||
|
|
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||
#include <THC/THC.h>
|
||||
|
||||
#define CHECK_INPUT(x)
|
||||
|
||||
|
@ -172,9 +173,9 @@ at::Tensor ApproxMatchForward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -184,7 +185,7 @@ at::Tensor ApproxMatchForward(
|
|||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return match;
|
||||
}
|
||||
|
@ -259,9 +260,9 @@ at::Tensor MatchCostForward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -270,7 +271,7 @@ at::Tensor MatchCostForward(
|
|||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
@ -376,9 +377,9 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -389,7 +390,7 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return std::vector<at::Tensor>({grad1, grad2});
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ def main(args, config):
|
|||
def get_args():
|
||||
parser = argparse.ArgumentParser('encoder decoder examiner')
|
||||
# experimental results
|
||||
parser.add_argument('--exp_root', type=str, default='./exp',
|
||||
parser.add_argument('--exp_root', type=str, default='../exp',
|
||||
help='location of the results')
|
||||
# parser.add_argument('--save', type=str, default='exp',
|
||||
# help='id used for storing intermediate results')
|
||||
|
@ -176,7 +176,7 @@ def get_args():
|
|||
config.merge_from_list(args.opt)
|
||||
|
||||
# Create log_name
|
||||
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', './exp/')
|
||||
EXP_ROOT = args.exp_root # os.environ.get('EXP_ROOT', '../exp/')
|
||||
if config.exp_name == '' or config.exp_name == 'none':
|
||||
config.hash = io_helper.hash_str('%s' % config) + 'h'
|
||||
cfg_file_name = exp_helper.get_expname(config)
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
import os
|
||||
import json
|
||||
from comet_ml import Experiment, OfflineExperiment
|
||||
## import open3d as o3d
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -323,7 +324,7 @@ def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda',
|
|||
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
|
||||
results['jsd'] = jsd
|
||||
msg = print_results(results, **print_kwargs)
|
||||
# with open('./exp/eval_out.txt', 'a') as f:
|
||||
# with open('../exp/eval_out.txt', 'a') as f:
|
||||
# run_time = time.strftime('%m%d-%H%M-%S')
|
||||
# f.write('<< date: %s >>\n' % run_time)
|
||||
# f.write('%s\n%s\n' % (exp.url, msg))
|
||||
|
|
|
@ -491,7 +491,7 @@ def common_init(rank, seed, save_dir, comet_key=''):
|
|||
|
||||
if os.path.exists('.wandb_api'):
|
||||
wb_args = json.load(open('.wandb_api', 'r'))
|
||||
wb_dir = './exp/wandb/' if not os.path.exists(
|
||||
wb_dir = '../exp/wandb/' if not os.path.exists(
|
||||
'/workspace/result') else '/workspace/result/wandb/'
|
||||
if not os.path.exists(wb_dir):
|
||||
os.makedirs(wb_dir)
|
||||
|
@ -1130,26 +1130,24 @@ def init_processes(rank, size, fn, args, config):
|
|||
""" Initialize the distributed environment. """
|
||||
os.environ['MASTER_ADDR'] = args.master_address
|
||||
os.environ['MASTER_PORT'] = '6020'
|
||||
logger.info('set MASTER_PORT: {}, MASTER_PORT: {}', os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
||||
|
||||
# if args.num_proc_node == 1: # try to solve the port occupied issue
|
||||
# import socket
|
||||
# import errno
|
||||
# a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# for p in range(6010, 6030):
|
||||
# location = (args.master_address, p) # "127.0.0.1", p)
|
||||
# try:
|
||||
# a_socket.bind((args.master_address, p))
|
||||
# logger.debug('set port as {}', p)
|
||||
# os.environ['MASTER_PORT'] = '%d' % p
|
||||
# a_socket.close()
|
||||
# break
|
||||
# except socket.error as e:
|
||||
# a = 0
|
||||
# # if e.errno == errno.EADDRINUSE:
|
||||
# # # logger.debug("Port {} is already in use", p)
|
||||
# # else:
|
||||
# # logger.debug(e)
|
||||
if args.num_proc_node == 1:
|
||||
import socket
|
||||
import errno
|
||||
a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
for p in range(6010, 6030):
|
||||
location = (args.master_address, p) # "127.0.0.1", p)
|
||||
try:
|
||||
a_socket.bind((args.master_address, p))
|
||||
logger.debug('set port as {}', p)
|
||||
os.environ['MASTER_PORT'] = '%d' % p
|
||||
a_socket.close()
|
||||
break
|
||||
except socket.error as e:
|
||||
a = 0
|
||||
# if e.errno == errno.EADDRINUSE:
|
||||
# # logger.debug("Port {} is already in use", p)
|
||||
# else:
|
||||
# logger.debug(e)
|
||||
|
||||
logger.info('init_process: rank={}, world_size={}', rank, size)
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
|
|
Loading…
Reference in a new issue