add scripts to train with clip feat
This commit is contained in:
parent
0ee193fa1e
commit
b2cf1e19ac
|
@ -14,6 +14,9 @@ def get_path(dataname=None):
|
||||||
'./data/ShapeNetCore.v2.PC15k/'
|
'./data/ShapeNetCore.v2.PC15k/'
|
||||||
|
|
||||||
]
|
]
|
||||||
|
dataset_path['clip_forge_image'] = [
|
||||||
|
'./data/shapenet_render/'
|
||||||
|
]
|
||||||
|
|
||||||
if dataname is None:
|
if dataname is None:
|
||||||
return dataset_path
|
return dataset_path
|
||||||
|
|
|
@ -18,6 +18,7 @@ from torch.utils import data
|
||||||
import random
|
import random
|
||||||
import tqdm
|
import tqdm
|
||||||
from datasets.data_path import get_path
|
from datasets.data_path import get_path
|
||||||
|
from PIL import Image
|
||||||
OVERFIT = 0
|
OVERFIT = 0
|
||||||
|
|
||||||
# taken from https://github.com/optas/latent_3d_points/blob/
|
# taken from https://github.com/optas/latent_3d_points/blob/
|
||||||
|
@ -101,7 +102,16 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
all_points_mean=None,
|
all_points_mean=None,
|
||||||
all_points_std=None,
|
all_points_std=None,
|
||||||
input_dim=3,
|
input_dim=3,
|
||||||
|
clip_forge_enable=0, clip_model=None
|
||||||
):
|
):
|
||||||
|
self.clip_forge_enable = clip_forge_enable
|
||||||
|
if clip_forge_enable:
|
||||||
|
import clip
|
||||||
|
_, self.clip_preprocess = clip.load(clip_model)
|
||||||
|
if self.clip_forge_enable:
|
||||||
|
self.img_path = []
|
||||||
|
img_path = get_path('clip_forge_image')
|
||||||
|
|
||||||
self.normalize_shape_box = normalize_shape_box
|
self.normalize_shape_box = normalize_shape_box
|
||||||
root_dir = get_path('pointflow')
|
root_dir = get_path('pointflow')
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
|
@ -146,6 +156,7 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
print("Directory missing : %s " % (sub_path))
|
print("Directory missing : %s " % (sub_path))
|
||||||
raise ValueError('check the data path')
|
raise ValueError('check the data path')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if True:
|
if True:
|
||||||
all_mids = []
|
all_mids = []
|
||||||
assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
|
assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
|
||||||
|
@ -161,6 +172,15 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
all_mids = sorted(all_mids)
|
all_mids = sorted(all_mids)
|
||||||
for mid in all_mids:
|
for mid in all_mids:
|
||||||
# obj_fname = os.path.join(sub_path, x)
|
# obj_fname = os.path.join(sub_path, x)
|
||||||
|
if self.clip_forge_enable:
|
||||||
|
synset_id = subd
|
||||||
|
render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1], 'img_choy2016'
|
||||||
|
|
||||||
|
#render_img_path = os.path.join(img_path, synset_id, mid.split('/')[-1])
|
||||||
|
#if not (os.path.exists(render_img_path)): continue
|
||||||
|
self.img_path.append(render_img_path)
|
||||||
|
assert(os.path.exists(render_img_path)), f'render img path not find: {render_img_path}'
|
||||||
|
|
||||||
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
|
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
|
||||||
point_cloud = np.load(obj_fname) # (15k, 3)
|
point_cloud = np.load(obj_fname) # (15k, 3)
|
||||||
self.all_points.append(point_cloud[np.newaxis, ...])
|
self.all_points.append(point_cloud[np.newaxis, ...])
|
||||||
|
@ -177,6 +197,8 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
|
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
|
||||||
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
|
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
|
||||||
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
|
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
|
||||||
|
if self.clip_forge_enable:
|
||||||
|
self.img_path = [self.img_path[i] for i in self.shuffle_idx]
|
||||||
|
|
||||||
# Normalization
|
# Normalization
|
||||||
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
|
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
|
||||||
|
@ -248,7 +270,7 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
|
|
||||||
# TODO: why do we need this??
|
# TODO: why do we need this??
|
||||||
self.train_points = self.all_points[:, :min(
|
self.train_points = self.all_points[:, :min(
|
||||||
10000, self.all_points.shape[1])]
|
10000, self.all_points.shape[1])] # subsample 15k points to 10k points per shape
|
||||||
self.tr_sample_size = min(10000, tr_sample_size)
|
self.tr_sample_size = min(10000, tr_sample_size)
|
||||||
self.te_sample_size = min(5000, te_sample_size)
|
self.te_sample_size = min(5000, te_sample_size)
|
||||||
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
||||||
|
@ -314,6 +336,22 @@ class ShapeNet15kPointClouds(Dataset):
|
||||||
'mid': mid,
|
'mid': mid,
|
||||||
'display_axis_order': self.display_axis_order
|
'display_axis_order': self.display_axis_order
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# read image
|
||||||
|
if self.clip_forge_enable:
|
||||||
|
img_path = self.img_path[idx]
|
||||||
|
img_list = os.listdir(img_path)
|
||||||
|
img_list = [os.path.join(img_path, p) for p in img_list if 'jpg' in p or 'png' in p]
|
||||||
|
assert(len(img_list) > 0), f'get empty list at {img_path}: {os.listdir(img_path)}'
|
||||||
|
# subset 5 image
|
||||||
|
img_idx = np.random.choice(len(img_list), 5)
|
||||||
|
img_list = [img_list[o] for o in img_idx]
|
||||||
|
img_list = [Image.open(img).convert('RGB') for img in img_list]
|
||||||
|
img_list = [self.clip_preprocess(img) for img in img_list]
|
||||||
|
img_list = torch.stack(img_list, dim=0) # B,3,H,W
|
||||||
|
all_img = img_list
|
||||||
|
output['tr_img'] = all_img
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@ -352,6 +390,8 @@ def get_datasets(cfg, args):
|
||||||
normalize_global=cfg.normalize_global,
|
normalize_global=cfg.normalize_global,
|
||||||
recenter_per_shape=cfg.recenter_per_shape,
|
recenter_per_shape=cfg.recenter_per_shape,
|
||||||
random_subsample=random_subsample,
|
random_subsample=random_subsample,
|
||||||
|
clip_forge_enable=cfg.clip_forge_enable,
|
||||||
|
clip_model=cfg.clip_model,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
eval_split = getattr(args, "eval_split", "val")
|
eval_split = getattr(args, "eval_split", "val")
|
||||||
|
@ -369,6 +409,8 @@ def get_datasets(cfg, args):
|
||||||
recenter_per_shape=cfg.recenter_per_shape,
|
recenter_per_shape=cfg.recenter_per_shape,
|
||||||
all_points_mean=tr_dataset.all_points_mean,
|
all_points_mean=tr_dataset.all_points_mean,
|
||||||
all_points_std=tr_dataset.all_points_std,
|
all_points_std=tr_dataset.all_points_std,
|
||||||
|
clip_forge_enable=cfg.clip_forge_enable,
|
||||||
|
clip_model=cfg.clip_model,
|
||||||
)
|
)
|
||||||
return tr_dataset, te_dataset
|
return tr_dataset, te_dataset
|
||||||
|
|
||||||
|
|
44
script/train_prior_clip.sh
Normal file
44
script/train_prior_clip.sh
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
if [ -z "$1" ]
|
||||||
|
then
|
||||||
|
echo "Require NGPU input; "
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
loss="mse_sum"
|
||||||
|
NGPU=$1 ## 1 #8
|
||||||
|
num_node=2
|
||||||
|
mem=32
|
||||||
|
BS=10
|
||||||
|
lr=2e-4
|
||||||
|
ENT="python train_dist.py --num_process_per_node $NGPU "
|
||||||
|
train_vae=False
|
||||||
|
cmt="lion"
|
||||||
|
ckpt="./lion_ckpt/unconditional/car/checkpoints/vae_only.pt"
|
||||||
|
|
||||||
|
$ENT \
|
||||||
|
--config "./lion_ckpt/unconditional/car/cfg.yml" \
|
||||||
|
latent_pts.pvd_mse_loss 1 \
|
||||||
|
vis_latent_point 1 \
|
||||||
|
num_val_samples 24 \
|
||||||
|
ddpm.ema 1 \
|
||||||
|
ddpm.use_bn False ddpm.use_gn True \
|
||||||
|
ddpm.time_dim 64 \
|
||||||
|
ddpm.beta_T 0.02 \
|
||||||
|
sde.vae_checkpoint $ckpt \
|
||||||
|
sde.learning_rate_dae $lr sde.learning_rate_min_dae $lr \
|
||||||
|
trainer.epochs 18000 \
|
||||||
|
sde.num_channels_dae 2048 \
|
||||||
|
sde.dropout 0.3 \
|
||||||
|
sde.prior_model 'models.latent_points_ada_localprior.PVCNN2Prior' \
|
||||||
|
sde.train_vae $train_vae \
|
||||||
|
sde.embedding_scale 1.0 \
|
||||||
|
viz.save_freq 1000 \
|
||||||
|
viz.viz_freq -200 viz.log_freq -1 viz.val_freq -10000 \
|
||||||
|
data.batch_size $BS \
|
||||||
|
trainer.type 'trainers.train_2prior' \
|
||||||
|
cmt $cmt \
|
||||||
|
clipforge.enable 1 \
|
||||||
|
data.clip_forge_enable 1 \
|
||||||
|
data.clip_model 'ViT-B/32' \
|
||||||
|
clipforge.clip_model 'ViT-B/32' \
|
||||||
|
latent_pts.style_prior 'models.score_sde.resnet.PriorSEClip' \
|
||||||
|
|
|
@ -254,6 +254,8 @@ class Trainer(PriorTrainer):
|
||||||
tr_img).view(B, nimg, -1).mean(1).float()
|
tr_img).view(B, nimg, -1).mean(1).float()
|
||||||
else:
|
else:
|
||||||
clip_feat = None
|
clip_feat = None
|
||||||
|
if self.cfg.clipforge.enable:
|
||||||
|
assert(clip_feat is not None)
|
||||||
|
|
||||||
# optimize vae params
|
# optimize vae params
|
||||||
vae_optimizer.zero_grad()
|
vae_optimizer.zero_grad()
|
||||||
|
|
Loading…
Reference in a new issue