add scripts to train with clip feat

This commit is contained in:
xzeng 2023-04-03 17:03:27 -04:00
parent 0ee193fa1e
commit b2cf1e19ac
4 changed files with 93 additions and 2 deletions

View file

@ -14,6 +14,9 @@ def get_path(dataname=None):
'./data/ShapeNetCore.v2.PC15k/'
]
dataset_path['clip_forge_image'] = [
'./data/shapenet_render/'
]
if dataname is None:
return dataset_path

View file

@ -18,6 +18,7 @@ from torch.utils import data
import random
import tqdm
from datasets.data_path import get_path
from PIL import Image
OVERFIT = 0
# taken from https://github.com/optas/latent_3d_points/blob/
@ -101,7 +102,16 @@ class ShapeNet15kPointClouds(Dataset):
all_points_mean=None,
all_points_std=None,
input_dim=3,
clip_forge_enable=0, clip_model=None
):
self.clip_forge_enable = clip_forge_enable
if clip_forge_enable:
import clip
_, self.clip_preprocess = clip.load(clip_model)
if self.clip_forge_enable:
self.img_path = []
img_path = get_path('clip_forge_image')
self.normalize_shape_box = normalize_shape_box
root_dir = get_path('pointflow')
self.root_dir = root_dir
@ -146,6 +156,7 @@ class ShapeNet15kPointClouds(Dataset):
print("Directory missing : %s " % (sub_path))
raise ValueError('check the data path')
continue
if True:
all_mids = []
assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
@ -161,6 +172,15 @@ class ShapeNet15kPointClouds(Dataset):
all_mids = sorted(all_mids)
for mid in all_mids:
# obj_fname = os.path.join(sub_path, x)
if self.clip_forge_enable:
synset_id = subd
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
#if not (os.path.exists(render_img_path)): continue
self.img_path.append(render_img_path)
assert(os.path.exists(render_img_path)), f'render img path not find: {render_img_path}'
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
point_cloud = np.load(obj_fname) # (15k, 3)
self.all_points.append(point_cloud[np.newaxis, ...])
@ -177,6 +197,8 @@ class ShapeNet15kPointClouds(Dataset):
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
if self.clip_forge_enable:
self.img_path = [self.img_path[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
@ -248,7 +270,7 @@ class ShapeNet15kPointClouds(Dataset):
# TODO: why do we need this??
self.train_points = self.all_points[:, :min(
10000, self.all_points.shape[1])]
10000, self.all_points.shape[1])] # subsample 15k points to 10k points per shape
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
assert self.scale == 1, "Scale (!= 1) is deprecated"
@ -314,6 +336,22 @@ class ShapeNet15kPointClouds(Dataset):
'mid': mid,
'display_axis_order': self.display_axis_order
})
# read image
if self.clip_forge_enable:
img_path = self.img_path[idx]
img_list = os.listdir(img_path)
img_list = [os.path.join(img_path, p) for p in img_list if 'jpg' in p or 'png' in p]
assert(len(img_list) > 0), f'get empty list at {img_path}: {os.listdir(img_path)}'
# subset 5 image
img_idx = np.random.choice(len(img_list), 5)
img_list = [img_list[o] for o in img_idx]
img_list = [Image.open(img).convert('RGB') for img in img_list]
img_list = [self.clip_preprocess(img) for img in img_list]
img_list = torch.stack(img_list, dim=0) # B,3,H,W
all_img = img_list
output['tr_img'] = all_img
return output
@ -352,6 +390,8 @@ def get_datasets(cfg, args):
normalize_global=cfg.normalize_global,
recenter_per_shape=cfg.recenter_per_shape,
random_subsample=random_subsample,
clip_forge_enable=cfg.clip_forge_enable,
clip_model=cfg.clip_model,
**kwargs)
eval_split = getattr(args, "eval_split", "val")
@ -369,6 +409,8 @@ def get_datasets(cfg, args):
recenter_per_shape=cfg.recenter_per_shape,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
clip_forge_enable=cfg.clip_forge_enable,
clip_model=cfg.clip_model,
)
return tr_dataset, te_dataset

View file

@ -0,0 +1,44 @@
if [ -z "$1" ]
then
echo "Require NGPU input; "
exit
fi
loss="mse_sum"
NGPU=$1 ## 1 #8
num_node=2
mem=32
BS=10
lr=2e-4
ENT="python train_dist.py --num_process_per_node $NGPU "
train_vae=False
cmt="lion"
ckpt="./lion_ckpt/unconditional/car/checkpoints/vae_only.pt"
$ENT \
--config "./lion_ckpt/unconditional/car/cfg.yml" \
latent_pts.pvd_mse_loss 1 \
vis_latent_point 1 \
num_val_samples 24 \
ddpm.ema 1 \
ddpm.use_bn False ddpm.use_gn True \
ddpm.time_dim 64 \
ddpm.beta_T 0.02 \
sde.vae_checkpoint $ckpt \
sde.learning_rate_dae $lr sde.learning_rate_min_dae $lr \
trainer.epochs 18000 \
sde.num_channels_dae 2048 \
sde.dropout 0.3 \
sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \
sde.train_vae $train_vae \
sde.embedding_scale 1.0 \
viz.save_freq 1000 \
viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \
data.batch_size $BS \
trainer.type 'trainers.train_2prior' \
cmt $cmt \
clipforge.enable 1 \
data.clip_forge_enable 1 \
data.clip_model 'ViT-B/32' \
clipforge.clip_model 'ViT-B/32' \
latent_pts.style_prior 'models.score_sde.resnet.PriorSEClip' \

View file

@ -254,6 +254,8 @@ class Trainer(PriorTrainer):
tr_img).view(B, nimg, -1).mean(1).float()
else:
clip_feat = None
if self.cfg.clipforge.enable:
assert(clip_feat is not None)
# optimize vae params
vae_optimizer.zero_grad()