From b2cf1e19ac95085273ad9543e21cda0cefd96e5a Mon Sep 17 00:00:00 2001
From: xzeng
Date: Mon, 3 Apr 2023 17:03:27 -0400
Subject: [PATCH] add scripts to train with clip feat

---
 datasets/data_path.py          |  3 +++
 datasets/pointflow_datasets.py | 46 ++++++++++++++++++++++++++++++++--
 script/train_prior_clip.sh     | 44 ++++++++++++++++++++++++++++++++
 trainers/train_2prior.py       |  2 ++
 4 files changed, 93 insertions(+), 2 deletions(-)
 create mode 100644 script/train_prior_clip.sh

diff --git a/datasets/data_path.py b/datasets/data_path.py
index 7c4f94e..849c2c9 100644
--- a/datasets/data_path.py
+++ b/datasets/data_path.py
@@ -14,6 +14,9 @@ def get_path(dataname=None):
         './data/ShapeNetCore.v2.PC15k/'
     ]
 
+    dataset_path['clip_forge_image'] = [
+        './data/shapenet_render/'
+    ]
     if dataname is None:
         return dataset_path

diff --git a/datasets/pointflow_datasets.py b/datasets/pointflow_datasets.py
index 9ab7c5e..a5aa700 100644
--- a/datasets/pointflow_datasets.py
+++ b/datasets/pointflow_datasets.py
@@ -18,6 +18,7 @@ from torch.utils import data
 import random
 import tqdm
 from datasets.data_path import get_path
+from PIL import Image
 
 OVERFIT = 0
 # taken from https://github.com/optas/latent_3d_points/blob/
@@ -101,7 +102,16 @@ class ShapeNet15kPointClouds(Dataset):
                  all_points_mean=None, all_points_std=None,
                  input_dim=3,
+                 clip_forge_enable=0, clip_model=None
                  ):
+        self.clip_forge_enable = clip_forge_enable
+        if clip_forge_enable:
+            import clip
+            _, self.clip_preprocess = clip.load(clip_model)
+
+        if self.clip_forge_enable:
+            self.img_path = []
+            img_path = get_path('clip_forge_image')
+
         self.normalize_shape_box = normalize_shape_box
         root_dir = get_path('pointflow')
         self.root_dir = root_dir
@@ -146,6 +156,7 @@ class ShapeNet15kPointClouds(Dataset):
             print("Directory missing : %s " % (sub_path))
             raise ValueError('check the data path')
             continue
+
         if True:
             all_mids = []
             assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
@@ -161,6 +172,15 @@ class ShapeNet15kPointClouds(Dataset):
         all_mids = sorted(all_mids)
         for mid in all_mids:
             # obj_fname = os.path.join(sub_path, x)
+            if self.clip_forge_enable:
+                synset_id = subd
+                render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
+                #render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
+                #if not (os.path.exists(render_img_path)): continue
+                self.img_path.append(render_img_path)
+                assert(os.path.exists(render_img_path)), f'render img path not found: {render_img_path}'
+
             obj_fname = os.path.join(root_dir, subd, mid + ".npy")
             point_cloud = np.load(obj_fname)  # (15k, 3)
             self.all_points.append(point_cloud[np.newaxis, ...])
@@ -177,6 +197,8 @@ class ShapeNet15kPointClouds(Dataset):
             self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
             self.all_points = [self.all_points[i] for i in self.shuffle_idx]
             self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
+            if self.clip_forge_enable:
+                self.img_path = [self.img_path[i] for i in self.shuffle_idx]
 
         # Normalization
         self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
@@ -248,7 +270,7 @@ class ShapeNet15kPointClouds(Dataset):
         # TODO: why do we need this??
         self.train_points = self.all_points[:, :min(
-            10000, self.all_points.shape[1])]
+            10000, self.all_points.shape[1])]  # subsample the 15k points to 10k points per shape
         self.tr_sample_size = min(10000, tr_sample_size)
         self.te_sample_size = min(5000, te_sample_size)
         assert self.scale == 1, "Scale (!= 1) is deprecated"
@@ -300,7 +322,7 @@ class ShapeNet15kPointClouds(Dataset):
         cate_idx = self.cate_idx_lst[idx]
         sid, mid = self.all_cate_mids[idx]
         input_pts = tr_out
-
+
         output.update(
             {
                 'idx': idx,
@@ -314,6 +336,22 @@ class ShapeNet15kPointClouds(Dataset):
                 'mid': mid,
                 'display_axis_order': self.display_axis_order
             })
+
+        # read the rendered images for CLIP conditioning
+        if self.clip_forge_enable:
+            img_path = self.img_path[idx]
+            img_list = os.listdir(img_path)
+            img_list = [os.path.join(img_path, p) for p in img_list if 'jpg' in p or 'png' in p]
+            assert(len(img_list) > 0), f'got empty image list at {img_path}: {os.listdir(img_path)}'
+            # randomly sample 5 images per shape
+            img_idx = np.random.choice(len(img_list), 5)
+            img_list = [img_list[o] for o in img_idx]
+            img_list = [Image.open(img).convert('RGB') for img in img_list]
+            img_list = [self.clip_preprocess(img) for img in img_list]
+            img_list = torch.stack(img_list, dim=0)  # (5, 3, H, W)
+            all_img = img_list
+            output['tr_img'] = all_img
+
         return output
@@ -352,6 +390,8 @@ def get_datasets(cfg, args):
         normalize_global=cfg.normalize_global,
         recenter_per_shape=cfg.recenter_per_shape,
         random_subsample=random_subsample,
+        clip_forge_enable=cfg.clip_forge_enable,
+        clip_model=cfg.clip_model,
         **kwargs)
 
     eval_split = getattr(args, "eval_split", "val")
@@ -369,6 +409,8 @@ def get_datasets(cfg, args):
         recenter_per_shape=cfg.recenter_per_shape,
         all_points_mean=tr_dataset.all_points_mean,
         all_points_std=tr_dataset.all_points_std,
+        clip_forge_enable=cfg.clip_forge_enable,
+        clip_model=cfg.clip_model,
     )
     return tr_dataset, te_dataset

diff --git a/script/train_prior_clip.sh b/script/train_prior_clip.sh
new file mode 100644
index 0000000..43490b6
--- /dev/null
+++ b/script/train_prior_clip.sh
@@ -0,0 +1,44 @@
+if [ -z "$1" ]
+  then
+    echo "Require NGPU as the first argument"
+    exit 1
+fi
+loss="mse_sum"
+NGPU=$1  # e.g. 1 or 8
+num_node=2
+mem=32
+BS=10
+lr=2e-4
+ENT="python train_dist.py --num_process_per_node $NGPU "
+train_vae=False
+cmt="lion"
+ckpt="./lion_ckpt/unconditional/car/checkpoints/vae_only.pt"
+
+$ENT \
+    --config "./lion_ckpt/unconditional/car/cfg.yml" \
+    latent_pts.pvd_mse_loss 1 \
+    vis_latent_point 1 \
+    num_val_samples 24 \
+    ddpm.ema 1 \
+    ddpm.use_bn False ddpm.use_gn True \
+    ddpm.time_dim 64 \
+    ddpm.beta_T 0.02 \
+    sde.vae_checkpoint $ckpt \
+    sde.learning_rate_dae $lr sde.learning_rate_min_dae $lr \
+    trainer.epochs 18000 \
+    sde.num_channels_dae 2048 \
+    sde.dropout 0.3 \
+    sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \
+    sde.train_vae $train_vae \
+    sde.embedding_scale 1.0 \
+    viz.save_freq 1000 \
+    viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \
+    data.batch_size $BS \
+    trainer.type 'trainers.train_2prior' \
+    cmt $cmt \
+    clipforge.enable 1 \
+    data.clip_forge_enable 1 \
+    data.clip_model 'ViT-B/32' \
+    clipforge.clip_model 'ViT-B/32' \
+    latent_pts.style_prior 'models.score_sde.resnet.PriorSEClip'
+
diff --git a/trainers/train_2prior.py b/trainers/train_2prior.py
index 7f90018..4998850 100644
--- a/trainers/train_2prior.py
+++ b/trainers/train_2prior.py
@@ -254,6 +254,8 @@ class Trainer(PriorTrainer):
                     tr_img).view(B, nimg, -1).mean(1).float()
             else:
                 clip_feat = None
+            if self.cfg.clipforge.enable:
+                assert(clip_feat is not None)
 
             # optimize vae params
             vae_optimizer.zero_grad()
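
Usage sketch (not part of the patch): this assumes the ShapeNet renders are unpacked to ./data/shapenet_render/<synset_id>/<model_id>/img_choy2016/, matching the path registered in datasets/data_path.py above, and that the pretrained unconditional car VAE checkpoint referenced in the script exists. The script takes a single argument, consumed as NGPU:

    # launch CLIP-conditioned prior training on one node with 8 GPUs
    bash script/train_prior_clip.sh 8

With data.clip_forge_enable 1, each dataset item additionally carries output['tr_img'] of shape (5, 3, H, W): five randomly sampled renders per shape, preprocessed by the transform returned from clip.load('ViT-B/32') (H = W = 224 for ViT-B/32). The trainer then embeds these views with CLIP and averages over them (the .mean(1) in train_2prior.py) to form the per-shape conditioning feature clip_feat.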