2021-11-08 10:09:50 +00:00
|
|
|
import os
|
|
|
|
|
|
|
|
# Anchor the process working directory at this script's own folder so that
# relative paths used below (e.g. "configs/default.yaml") resolve the same
# way no matter where the script is launched from.
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
|
|
|
|
|
2023-05-26 12:59:53 +00:00
|
|
|
import argparse
|
|
|
|
import shutil
|
|
|
|
import time
|
|
|
|
|
|
|
|
import numpy as np
|
2021-11-08 10:09:50 +00:00
|
|
|
import torch
|
|
|
|
import torch.optim as optim
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
from src import config
|
2023-05-26 12:59:53 +00:00
|
|
|
from src.data import collate_remove_none, worker_init_fn
|
2021-11-08 10:09:50 +00:00
|
|
|
from src.model import Encode2Points
|
2023-05-26 12:59:53 +00:00
|
|
|
from src.training import Trainer
|
|
|
|
from src.utils import AverageMeter, initialize_logger, load_config, load_model_manual
|
|
|
|
|
|
|
|
np.set_printoptions(precision=4)
|
2021-11-08 10:09:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train an Encode2Points model from a YAML config.

    Parses command-line arguments, sets up output dir / logger / TensorBoard,
    builds train/val/vis data loaders, optionally resumes from a previously
    saved checkpoint, then runs the training loop with periodic logging,
    visualization, validation and checkpointing.

    Raises:
        AssertionError: if the installed PyTorch major version is < 1.
    """
    parser = argparse.ArgumentParser(description="MNIST toy experiment")
    parser.add_argument("config", type=str, help="Path to config file.")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")

    args = parser.parse_args()
    cfg = load_config(args.config, "configs/default.yaml")
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Accessed only to fail fast (KeyError) if the config lacks an input type.
    cfg["data"]["input_type"]
    batch_size = cfg["train"]["batch_size"]
    model_selection_metric = cfg["train"]["model_selection_metric"]

    # PYTORCH VERSION > 1.0.0
    # FIX: take the major version from the front of the string. The previous
    # split(".")[-3] breaks on version strings without exactly three "."
    # fields (e.g. local builds like "1.13.0a0+d0d6b1f").
    assert float(torch.__version__.split(".")[0]) >= 1

    # boiler-plate
    if cfg["train"]["timestamp"]:
        cfg["train"]["out_dir"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Keep an exact copy of the config next to this run's outputs.
    shutil.copyfile(args.config, os.path.join(cfg["train"]["out_dir"], "config.yaml"))

    # FIX: only query the device name when CUDA is actually used; the
    # previous unconditional call raised on CPU-only machines.
    if use_cuda:
        logger.info("using GPU: " + torch.cuda.get_device_name(0))

    # TensorboardX writer
    tblogdir = os.path.join(cfg["train"]["out_dir"], "tensorboard_log")
    if not os.path.exists(tblogdir):
        os.makedirs(tblogdir, exist_ok=True)
    writer = SummaryWriter(log_dir=tblogdir)

    # No separate encoder input is used; batches carry everything.
    inputs = None
    train_dataset = config.get_dataset("train", cfg)
    val_dataset = config.get_dataset("val", cfg)
    vis_dataset = config.get_dataset("vis", cfg)

    collate_fn = collate_remove_none

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=cfg["train"]["n_workers"],
        shuffle=True,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=cfg["train"]["n_workers_val"],
        shuffle=False,
        collate_fn=collate_remove_none,
        worker_init_fn=worker_init_fn,
    )

    vis_loader = torch.utils.data.DataLoader(
        vis_dataset,
        batch_size=1,
        num_workers=cfg["train"]["n_workers_val"],
        shuffle=False,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )

    # Wrap in DataParallel only when multiple GPUs are visible.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
    else:
        model = Encode2Points(cfg).to(device)

    n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("Number of parameters: %d" % n_parameter)

    # load model (best-effort resume from out_dir/model.pt if one exists)
    try:
        # load model
        state_dict = torch.load(os.path.join(cfg["train"]["out_dir"], "model.pt"))
        load_model_manual(state_dict["state_dict"], model)

        out = "Load model from iteration %d" % state_dict.get("it", 0)
        logger.info(out)
        # load point cloud
    # FIX: the previous bare "except:" also swallowed KeyboardInterrupt and
    # SystemExit; catching Exception keeps resume best-effort without
    # hiding process exits.
    except Exception:
        state_dict = dict()

    metric_val_best = state_dict.get("loss_val_best", np.inf)

    logger.info(f"Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}")

    LR = float(cfg["train"]["lr"])
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # -1 so that the first epoch/iteration after "+= 1" is 0.
    start_epoch = state_dict.get("epoch", -1)
    it = state_dict.get("it", -1)

    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime["all"] = AverageMeter()

    # training loop
    for epoch in range(start_epoch + 1, cfg["train"]["total_epochs"] + 1):
        for batch in train_loader:
            it += 1

            start = time.time()
            loss, loss_each = trainer.train_step(inputs, batch, model)
            # measure elapsed time
            end = time.time()
            runtime["all"].update(end - start)

            if it % cfg["train"]["print_every"] == 0:
                log_text = ("[Epoch %02d] it=%d, loss=%.4f") % (epoch, it, loss)
                writer.add_scalar("train/loss", loss, it)
                if loss_each is not None:
                    for k, l in loss_each.items():
                        if l.item() != 0.0:
                            log_text += f" loss_{k}={l.item():.4f}"
                        writer.add_scalar("train/%s" % k, l, it)

                log_text += (" time={:.3f} / {:.2f}").format(runtime["all"].val, runtime["all"].sum)
                logger.info(log_text)

            # FIX: logical "and" instead of bitwise "&" on booleans.
            if it > 0 and (it % cfg["train"]["visualize_every"] == 0):
                # Dump meshes/point clouds for at most 5 visualization samples.
                for i, batch_vis in enumerate(vis_loader):
                    trainer.save(model, batch_vis, it, i)
                    if i >= 4:
                        break
                logger.info("Saved mesh and pointcloud")

            # run validation
            if it > 0 and (it % cfg["train"]["validate_every"]) == 0:
                eval_dict = trainer.evaluate(val_loader, model)
                metric_val = eval_dict[model_selection_metric]
                logger.info(f"Validation metric ({model_selection_metric}): {metric_val:.4f}")

                for k, v in eval_dict.items():
                    writer.add_scalar("val/%s" % k, v, it)

                # Lower metric is better; keep the best-scoring checkpoint.
                if -(metric_val - metric_val_best) >= 0:
                    metric_val_best = metric_val
                    logger.info("New best model (loss %.4f)" % metric_val_best)
                    state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
                    state["state_dict"] = model.state_dict()
                    torch.save(state, os.path.join(cfg["train"]["out_dir"], "model_best.pt"))

            # save checkpoint
            # FIX: logical "and" instead of bitwise "&" on booleans.
            if epoch > 0 and (it % cfg["train"]["checkpoint_every"] == 0):
                state = {"epoch": epoch, "it": it, "loss_val_best": metric_val_best}
                state["state_dict"] = model.state_dict()

                torch.save(state, os.path.join(cfg["train"]["out_dir"], "model.pt"))

                if it % cfg["train"]["backup_every"] == 0:
                    torch.save(state, os.path.join(cfg["train"]["dir_model"], "%04d" % it + ".pt"))
                    logger.info("Backup model at iteration %d" % it)
                logger.info("Save new model at iteration %d" % it)

        # FIX: removed a dead trailing "time.time()" call whose result was
        # discarded at the end of each epoch.
|
|
|
|
|
2021-11-08 10:09:50 +00:00
|
|
|
|
2023-05-26 12:59:53 +00:00
|
|
|
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|