diff --git a/script/train_vae.sh b/script/train_vae.sh old mode 100644 new mode 100755 index f5627e0..fe4468e --- a/script/train_vae.sh +++ b/script/train_vae.sh @@ -6,7 +6,7 @@ fi DATA=" ddpm.input_dim 3 data.cates car " NGPU=$1 # num_node=1 -BS=32 +BS=6 total_bs=$(( $NGPU * $BS )) if (( $total_bs > 128 )); then echo "[WARNING] total batch_size larger than 128 may lead to unstable training, please reduce the size"