diff --git a/script/train_prior.sh b/script/train_prior.sh index 2ac1698..a867dc9 100644 --- a/script/train_prior.sh +++ b/script/train_prior.sh @@ -32,7 +32,7 @@ $ENT \ sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \ sde.train_vae $train_vae \ sde.embedding_scale 1.0 \ - viz.save_freq 1000 \ + viz.save_freq 1 \ viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \ data.batch_size $BS \ trainer.type 'trainers.train_2prior' \ diff --git a/train_dist.py b/train_dist.py index 2786179..88422b0 100644 --- a/train_dist.py +++ b/train_dist.py @@ -76,7 +76,9 @@ def main(args, config): else: raise NotImplementedError elif args.pretrained is not None: - trainer.load_vae(args.pretrained) + logger.info('Resuming training from {}; if you dont want resume training, edit the cmt to change the exp name', + args.pretrained) + trainer.resume(args.pretrained) if not args.eval_generation: trainer.train_epochs()