add scripts to train with clip feat

parent 0ee193fa1e
commit b2cf1e19ac

@@ -14,6 +14,9 @@ def get_path(dataname=None):
         './data/ShapeNetCore.v2.PC15k/'
     ]
+    dataset_path['clip_forge_image'] = [
+        './data/shapenet_render/'
+    ]
 
     if dataname is None:
         return dataset_path
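For reference, a minimal sketch of how the new dictionary entry is consumed elsewhere in this commit (the dict layout is from the hunk above; resolving a name to an existing candidate path is an assumption about get_path's behavior):

    # hedged usage sketch; get_path() with no argument returns the whole dict
    from datasets.data_path import get_path

    img_root = get_path('clip_forge_image')   # e.g. './data/shapenet_render/'
    all_paths = get_path()                    # {'pointflow': [...], 'clip_forge_image': [...]}
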
@@ -18,6 +18,7 @@ from torch.utils import data
 import random
 import tqdm
 from datasets.data_path import get_path
+from PIL import Image
 OVERFIT = 0
 
 # taken from https://github.com/optas/latent_3d_points/blob/
@@ -101,7 +102,16 @@ class ShapeNet15kPointClouds(Dataset):
                  all_points_mean=None,
                  all_points_std=None,
                  input_dim=3,
+                 clip_forge_enable=0, clip_model=None
                  ):
+        self.clip_forge_enable = clip_forge_enable
+        if clip_forge_enable:
+            import clip
+            _, self.clip_preprocess = clip.load(clip_model)
+        if self.clip_forge_enable:
+            self.img_path = []
+            img_path = get_path('clip_forge_image')
+
         self.normalize_shape_box = normalize_shape_box
         root_dir = get_path('pointflow')
         self.root_dir = root_dir
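Note that clip.load returns a (model, preprocess) pair; the dataset keeps only the preprocessing transform and discards the encoder, which is presumably instantiated separately in the trainer. A minimal sketch of that OpenAI CLIP API (the device argument is illustrative):

    import clip

    # clip.load returns the model and its image preprocessing transform;
    # __init__ above keeps only the transform for use in __getitem__.
    model, preprocess = clip.load('ViT-B/32', device='cpu')
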
@@ -146,6 +156,7 @@ class ShapeNet15kPointClouds(Dataset):
                 print("Directory missing : %s " % (sub_path))
+                raise ValueError('check the data path')
                 continue
 
             if True:
                 all_mids = []
             assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
@@ -161,6 +172,15 @@ class ShapeNet15kPointClouds(Dataset):
             all_mids = sorted(all_mids)
             for mid in all_mids:
                 # obj_fname = os.path.join(sub_path, x)
+                if self.clip_forge_enable:
+                    synset_id = subd
+                    render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016')
+
+                    #render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
+                    #if not (os.path.exists(render_img_path)): continue
+                    self.img_path.append(render_img_path)
+                    assert(os.path.exists(render_img_path)), f'render img path not find: {render_img_path}'
+
                 obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                 point_cloud = np.load(obj_fname)  # (15k, 3)
                 self.all_points.append(point_cloud[np.newaxis, ...])
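The assertion implies the renders live under ./data/shapenet_render/<synset_id>/<model_id>/img_choy2016/ (the per-view image file names inside that folder are not shown in the diff). Note that the existence check which would have silently skipped shapes with missing renders is commented out, so a missing folder now fails hard at the assert.
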
@@ -177,6 +197,8 @@ class ShapeNet15kPointClouds(Dataset):
         self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
         self.all_points = [self.all_points[i] for i in self.shuffle_idx]
         self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
+        if self.clip_forge_enable:
+            self.img_path = [self.img_path[i] for i in self.shuffle_idx]
 
         # Normalization
         self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
@@ -248,7 +270,7 @@ class ShapeNet15kPointClouds(Dataset):
 
         # TODO: why do we need this??
         self.train_points = self.all_points[:, :min(
-            10000, self.all_points.shape[1])]
+            10000, self.all_points.shape[1])]  # subsample 15k points to 10k points per shape
         self.tr_sample_size = min(10000, tr_sample_size)
         self.te_sample_size = min(5000, te_sample_size)
         assert self.scale == 1, "Scale (!= 1) is deprecated"
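For context, the 15k points stored per shape are conventionally split into a 10k training pool and a 5k evaluation pool (the PointFlow convention; only the training-side slice is visible in this hunk). A hedged sketch of that split:

    # hedged sketch, assuming the PointFlow 10k/5k split convention
    train_points = all_points[:, :10000]   # pool drawn from with tr_sample_size
    test_points = all_points[:, 10000:]    # held-out pool for te_sample_size draws
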
@@ -300,7 +322,7 @@ class ShapeNet15kPointClouds(Dataset):
         cate_idx = self.cate_idx_lst[idx]
         sid, mid = self.all_cate_mids[idx]
         input_pts = tr_out
 
         output.update(
             {
                 'idx': idx,
@@ -314,6 +336,22 @@ class ShapeNet15kPointClouds(Dataset):
                 'mid': mid,
                 'display_axis_order': self.display_axis_order
             })
 
+        # read image
+        if self.clip_forge_enable:
+            img_path = self.img_path[idx]
+            img_list = os.listdir(img_path)
+            img_list = [os.path.join(img_path, p) for p in img_list if 'jpg' in p or 'png' in p]
+            assert(len(img_list) > 0), f'get empty list at {img_path}: {os.listdir(img_path)}'
+            # subset 5 image
+            img_idx = np.random.choice(len(img_list), 5)
+            img_list = [img_list[o] for o in img_idx]
+            img_list = [Image.open(img).convert('RGB') for img in img_list]
+            img_list = [self.clip_preprocess(img) for img in img_list]
+            img_list = torch.stack(img_list, dim=0)  # B,3,H,W
+            all_img = img_list
+            output['tr_img'] = all_img
+
         return output
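Note that np.random.choice(len(img_list), 5) samples with replacement by default, so the five selected views may contain duplicates. A minimal sketch of what a consumer sees per batch, assuming standard PyTorch batching (the 224x224 size follows from the ViT-B/32 preprocess and is an assumption about the configured model):

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=4)  # `dataset` built as above
    batch = next(iter(loader))
    print(batch['tr_img'].shape)                # torch.Size([4, 5, 3, 224, 224])
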
@@ -352,6 +390,8 @@ def get_datasets(cfg, args):
         normalize_global=cfg.normalize_global,
         recenter_per_shape=cfg.recenter_per_shape,
         random_subsample=random_subsample,
+        clip_forge_enable=cfg.clip_forge_enable,
+        clip_model=cfg.clip_model,
         **kwargs)
 
     eval_split = getattr(args, "eval_split", "val")
@@ -369,6 +409,8 @@ def get_datasets(cfg, args):
         recenter_per_shape=cfg.recenter_per_shape,
         all_points_mean=tr_dataset.all_points_mean,
         all_points_std=tr_dataset.all_points_std,
+        clip_forge_enable=cfg.clip_forge_enable,
+        clip_model=cfg.clip_model,
     )
     return tr_dataset, te_dataset

script/train_prior_clip.sh (new file, 44 lines)
@@ -0,0 +1,44 @@
+if [ -z "$1" ]
+then
+    echo "Require NGPU input; "
+    exit
+fi
+loss="mse_sum"
+NGPU=$1 ## 1 #8
+num_node=2
+mem=32
+BS=10
+lr=2e-4
+ENT="python train_dist.py --num_process_per_node $NGPU "
+train_vae=False
+cmt="lion"
+ckpt="./lion_ckpt/unconditional/car/checkpoints/vae_only.pt"
+
+$ENT \
+    --config "./lion_ckpt/unconditional/car/cfg.yml" \
+    latent_pts.pvd_mse_loss 1 \
+    vis_latent_point 1 \
+    num_val_samples 24 \
+    ddpm.ema 1 \
+    ddpm.use_bn False ddpm.use_gn True \
+    ddpm.time_dim 64 \
+    ddpm.beta_T 0.02 \
+    sde.vae_checkpoint $ckpt \
+    sde.learning_rate_dae $lr sde.learning_rate_min_dae $lr \
+    trainer.epochs 18000 \
+    sde.num_channels_dae 2048 \
+    sde.dropout 0.3 \
+    sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \
+    sde.train_vae $train_vae \
+    sde.embedding_scale 1.0 \
+    viz.save_freq 1000 \
+    viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \
+    data.batch_size $BS \
+    trainer.type 'trainers.train_2prior' \
+    cmt $cmt \
+    clipforge.enable 1 \
+    data.clip_forge_enable 1 \
+    data.clip_model 'ViT-B/32' \
+    clipforge.clip_model 'ViT-B/32' \
+    latent_pts.style_prior 'models.score_sde.resnet.PriorSEClip' \
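Usage note: the script takes the GPU count as its single required argument, e.g. `bash script/train_prior_clip.sh 8`, which is forwarded to `train_dist.py --num_process_per_node`. The `loss`, `num_node`, and `mem` variables are assigned but never referenced in the command below them, so they appear to be leftovers from a cluster-submission template; note also that the command ends with a trailing `\`, which relies on the following blank line to terminate it.
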
@@ -254,6 +254,8 @@ class Trainer(PriorTrainer):
                 tr_img).view(B, nimg, -1).mean(1).float()
         else:
             clip_feat = None
+        if self.cfg.clipforge.enable:
+            assert(clip_feat is not None)
 
         # optimize vae params
         vae_optimizer.zero_grad()
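The truncated first context line suggests clip_feat is the mean CLIP image embedding over the nimg views sampled by the dataset. A hedged reconstruction of the presumed surrounding code (only the `.view(B, nimg, -1).mean(1).float()` tail is visible in the hunk; the attribute names are assumptions):

    # hedged reconstruction, not the verbatim trainer code
    if 'tr_img' in data:
        tr_img = data['tr_img'].to(device)          # (B, nimg, 3, H, W)
        B, nimg = tr_img.shape[:2]
        with torch.no_grad():
            clip_feat = self.clip_model.encode_image(
                tr_img.view(B * nimg, *tr_img.shape[2:])
            ).view(B, nimg, -1).mean(1).float()     # average the view embeddings
    else:
        clip_feat = None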