update readme; minor fix and add log msg

xzeng 2023-01-25 17:02:07 -05:00
parent 031ccd75d6
commit 110f73f4c0
4 changed files with 19 additions and 3 deletions

@@ -48,11 +48,12 @@ run `python demo.py`, will load the released text2shape model on hugging face an
 * Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` for the ShapeNet dataset path.
 ### train VAE
-* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4`)
+* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4` on A100)
+* if want to use comet to log the experiment, add `.comet_api` file under the current folder, write the api key as `{"api_key": "${COMET_API_KEY}"}` in the `.comet_api` file
 ### train diffusion prior
 * require the vae checkpoint
-* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` with 2 node)
+* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` with 2 node on V100)
 ### evaluate a trained prior
 * download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`

@@ -1,3 +1,8 @@
+if [ -z "$1" ]
+then
+echo "Require NGPU input; "
+exit
+fi
 loss="mse_sum"
 NGPU=$1 ## 1 #8
 num_node=2

@@ -1,8 +1,17 @@
+if [ -z "$1" ]
+then
+echo "Require NGPU input; "
+exit
+fi
 DATA=" ddpm.input_dim 3 data.cates car "
 NGPU=$1 #
 num_node=1
+mem=40
 BS=32
+total_bs=$(( $NGPU * $BS ))
+if (( $total_bs > 128 )); then
+echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
+exit
+fi
 ENT="python train_dist.py --num_process_per_node $NGPU "
 kl=0.5
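
The guard added in this hunk multiplies the GPU count by the per-GPU batch size and aborts when the product exceeds 128. A minimal Python sketch of the same arithmetic, assuming the script's `BS=32`; the function name and the `limit` parameter are hypothetical and do not appear in the repository:

```python
def batch_size_ok(ngpu, per_gpu_bs=32, limit=128):
    """Mirror the shell guard: total batch size must not exceed `limit`.

    `per_gpu_bs` and `limit` default to the values in the script (BS=32, 128).
    """
    total_bs = ngpu * per_gpu_bs
    return total_bs <= limit


if __name__ == "__main__":
    for ngpu in (1, 4, 8):
        status = "ok" if batch_size_ok(ngpu) else "rejected"
        print(f"NGPU={ngpu}: total batch size {ngpu * 32} is {status}")
```

Under these values, `NGPU=4` gives a total of 128 and passes, while `NGPU=8` gives 256 and would make the script exit.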

@@ -12,6 +12,7 @@ import wandb as WB
 import os
 import math
 import shutil
+import json
 import time
 import sys
 import types
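
The README change in this commit describes a `.comet_api` file holding `{"api_key": "${COMET_API_KEY}"}`, and this hunk adds `import json`, which is presumably used to read that file (the reading code itself is not shown in this diff). A minimal sketch of how such a file could be loaded; `load_comet_api_key` is a hypothetical helper, not a function from the repository:

```python
import json
from pathlib import Path


def load_comet_api_key(path=".comet_api"):
    """Return the value of "api_key" from a JSON file, or None if unavailable.

    Hypothetical helper: the file name and the "api_key" field follow the
    README text; everything else is an assumption.
    """
    config_file = Path(path)
    if not config_file.is_file():
        # No file means Comet logging is simply not configured.
        return None
    with config_file.open() as f:
        config = json.load(f)
    return config.get("api_key")
```

Returning `None` when the file is missing keeps Comet logging optional, which matches the README's "if want to use comet" wording.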