From a0496fc843758f50de6c58322feeed47437ff4b4 Mon Sep 17 00:00:00 2001 From: xzeng Date: Tue, 24 Jan 2023 11:06:55 -0500 Subject: [PATCH] add scripts --- script/eval.sh | 1 + script/train_prior.sh | 34 ++++++++++++++++++++++++++++++++++ script/train_vae.sh | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+) create mode 100644 script/eval.sh create mode 100644 script/train_prior.sh create mode 100644 script/train_vae.sh diff --git a/script/eval.sh b/script/eval.sh new file mode 100644 index 0000000..cb7bfdd --- /dev/null +++ b/script/eval.sh @@ -0,0 +1 @@ +python train_dist.py --skip_nll 1 --eval_generation --pretrained $1 ddpm.model_var_type "fixedlarge" data.batch_size_test 3 ddpm.ema 1 num_val_samples 3 diff --git a/script/train_prior.sh b/script/train_prior.sh new file mode 100644 index 0000000..9562605 --- /dev/null +++ b/script/train_prior.sh @@ -0,0 +1,34 @@ +loss="mse_sum" +NGPU=$1 ## 1 #8 +num_node=2 +mem=32 +BS=10 +lr=2e-4 +ENT="python train_dist.py --num_process_per_node $NGPU " +train_vae=False +cmt="lion" +ckpt="./lion_ckpt/unconditional/car/checkpoints/vae_only.pt" + +$ENT \ + --config "./lion_ckpt/unconditional/car/cfg.yml" \ + latent_pts.pvd_mse_loss 1 \ + vis_latent_point 1 \ + num_val_samples 24 \ + ddpm.ema 1 \ + ddpm.use_bn False ddpm.use_gn True \ + ddpm.time_dim 64 \ + ddpm.beta_T 0.02 \ + sde.vae_checkpoint $ckpt \ + sde.learning_rate_dae $lr sde.learning_rate_min_dae $lr \ + trainer.epochs 18000 \ + sde.num_channels_dae 2048 \ + sde.dropout 0.3 \ + latent_pts.style_prior 'models.score_sde.resnet.PriorSEDrop' \ + sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \ + sde.train_vae $train_vae \ + sde.embedding_scale 1.0 \ + viz.save_freq 1000 \ + viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \ + data.batch_size $BS \ + trainer.type 'trainers.train_2prior' \ + cmt $cmt diff --git a/script/train_vae.sh b/script/train_vae.sh new file mode 100644 index 0000000..c4c6d80 --- /dev/null +++ b/script/train_vae.sh @@ -0,0 +1,39 @@ +DATA=" ddpm.input_dim 3 data.cates car " +NGPU=$1 # +num_node=1 +mem=40 +BS=32 + +ENT="python train_dist.py --num_process_per_node $NGPU " +kl=0.5 +lr=1e-3 +latent=1 +skip_weight=0.01 +sigma_offset=6.0 +loss='l1_sum' + +$ENT ddpm.num_steps 1 ddpm.ema 0 \ + trainer.opt.vae_lr_warmup_epochs 0 \ + latent_pts.ada_mlp_init_scale 0.1 \ + sde.kl_const_coeff_vada 1e-7 \ + trainer.anneal_kl 1 sde.kl_max_coeff_vada $kl \ + sde.kl_anneal_portion_vada 0.5 \ + shapelatent.log_sigma_offset $sigma_offset latent_pts.skip_weight $skip_weight \ + trainer.opt.beta2 0.99 \ + data.num_workers 4 \ + ddpm.loss_weight_emd 1.0 \ + trainer.epochs 8000 data.random_subsample 1 \ + viz.viz_freq -400 viz.log_freq -1 viz.val_freq 200 \ + data.batch_size $BS viz.save_freq 2000 \ + trainer.type 'trainers.hvae_trainer' \ + model_config default shapelatent.model 'models.vae_adain' \ + shapelatent.decoder_type 'models.latent_points_ada.LatentPointDecPVC' \ + shapelatent.encoder_type 'models.latent_points_ada.PointTransPVC' \ + latent_pts.style_encoder 'models.shapelatent_modules.PointNetPlusEncoder' \ + shapelatent.prior_type normal \ + shapelatent.latent_dim $latent trainer.opt.lr $lr \ + shapelatent.kl_weight ${kl} \ + shapelatent.decoder_num_points 2048 \ + data.tr_max_sample_points 2048 data.te_max_sample_points 2048 \ + ddpm.loss_type $loss cmt "lion" \ + $DATA viz.viz_order [2,0,1] data.recenter_per_shape False data.normalize_global True