update readme; minor fix and add log msg
This commit is contained in:
parent
031ccd75d6
commit
110f73f4c0
|
@ -48,11 +48,12 @@ run `python demo.py`, will load the released text2shape model on hugging face an
|
||||||
* Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` for the ShapeNet dataset path.
|
* Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` for the ShapeNet dataset path.
|
||||||
|
|
||||||
### train VAE
|
### train VAE
|
||||||
* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4`)
|
* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4` on A100)
|
||||||
|
* if want to use comet to log the experiment, add `.comet_api` file under the current folder, write the api key as `{"api_key": "${COMET_API_KEY}"}` in the `.comet_api` file
|
||||||
|
|
||||||
### train diffusion prior
|
### train diffusion prior
|
||||||
* require the vae checkpoint
|
* require the vae checkpoint
|
||||||
* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` with 2 node)
|
* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` with 2 node on V100)
|
||||||
|
|
||||||
### evaluate a trained prior
|
### evaluate a trained prior
|
||||||
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
|
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
|
||||||
|
|
|
@ -1,3 +1,8 @@
|
||||||
|
if [ -z "$1" ]
|
||||||
|
then
|
||||||
|
echo "Require NGPU input; "
|
||||||
|
exit
|
||||||
|
fi
|
||||||
loss="mse_sum"
|
loss="mse_sum"
|
||||||
NGPU=$1 ## 1 #8
|
NGPU=$1 ## 1 #8
|
||||||
num_node=2
|
num_node=2
|
||||||
|
|
|
@ -1,8 +1,17 @@
|
||||||
|
if [ -z "$1" ]
|
||||||
|
then
|
||||||
|
echo "Require NGPU input; "
|
||||||
|
exit
|
||||||
|
fi
|
||||||
DATA=" ddpm.input_dim 3 data.cates car "
|
DATA=" ddpm.input_dim 3 data.cates car "
|
||||||
NGPU=$1 #
|
NGPU=$1 #
|
||||||
num_node=1
|
num_node=1
|
||||||
mem=40
|
|
||||||
BS=32
|
BS=32
|
||||||
|
total_bs=$(( $NGPU * $BS ))
|
||||||
|
if (( $total_bs > 128 )); then
|
||||||
|
echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
|
||||||
ENT="python train_dist.py --num_process_per_node $NGPU "
|
ENT="python train_dist.py --num_process_per_node $NGPU "
|
||||||
kl=0.5
|
kl=0.5
|
||||||
|
|
|
@ -12,6 +12,7 @@ import wandb as WB
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import shutil
|
import shutil
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
|
|
Loading…
Reference in a new issue