feat: saving models with wandb

Former-commit-id: 994d061d49c8be9f680878d093b1e2b5710fa37c [formerly 3cadb0699c150421244c78511ee61b14262728d0]
Former-commit-id: 64b01c6ed73ba64a25d85acd65adc9b6c9cf126e
This commit is contained in:
Laurent Fainsin 2022-06-29 14:15:04 +02:00
parent c700835065
commit e4b155991b
3 changed files with 33 additions and 34 deletions

View file

@ -20,18 +20,18 @@ def evaluate(net, dataloader, device):
# forward, predict the mask # forward, predict the mask
with torch.inference_mode(): with torch.inference_mode():
masks_pred = net(images) masks_pred = net(images)
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float() masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
# compute the Dice score # compute the Dice score
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False) dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
# update progress bar # update progress bar
pbar.update(images.shape[0]) pbar.update(images.shape[0])
# save some images to wandb # save some images to wandb
table = wandb.Table(columns=["image", "mask", "prediction"]) table = wandb.Table(columns=["id", "image", "mask", "prediction"])
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")): for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
wandb.log({"predictions_table": table}, commit=False) wandb.log({"predictions_table": table}, commit=False)
net.train() net.train()

View file

@ -13,32 +13,6 @@ from unet import UNet
from utils.utils import plot_img_and_mask from utils.utils import plot_img_and_mask
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
with torch.no_grad():
output = net(img)
if net.n_classes > 1:
probs = F.softmax(output, dim=1)[0]
else:
probs = torch.sigmoid(output)[0]
tf = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
)
full_mask = tf(probs.cpu()).squeeze()
if net.n_classes == 1:
return (full_mask > out_threshold).numpy()
else:
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Predict masks from input images", description="Predict masks from input images",
@ -95,6 +69,29 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
with torch.inference_mode():
output = net(img)
probs = torch.sigmoid(output)[0]
tf = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
)
full_mask = tf(probs.cpu()).squeeze()
if net.n_classes == 1:
return (full_mask > out_threshold).numpy()
else:
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
def get_output_filenames(args): def get_output_filenames(args):
def _generate_name(fn): def _generate_name(fn):
split = os.path.splitext(fn) split = os.path.splitext(fn)
@ -127,6 +124,7 @@ if __name__ == "__main__":
logging.info("Model loaded!") logging.info("Model loaded!")
for i, filename in enumerate(in_files): for i, filename in enumerate(in_files):
logging.info(f"\nPredicting image {filename} ...") logging.info(f"\nPredicting image {filename} ...")
img = Image.open(filename) img = Image.open(filename)

View file

@ -111,10 +111,11 @@ def main():
# 1. Create transforms # 1. Create transforms
tf_train = A.Compose( tf_train = A.Compose(
[ [
A.Resize(500, 500), A.Resize(512, 512),
A.Flip(), A.Flip(),
A.ColorJitter(), A.ColorJitter(),
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
A.GaussianBlur(),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
@ -122,7 +123,7 @@ def main():
) )
tf_valid = A.Compose( tf_valid = A.Compose(
[ [
A.Resize(500, 500), A.Resize(512, 512),
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
@ -154,7 +155,7 @@ def main():
amp=args.amp, amp=args.amp,
), ),
) )
wandb.save(f"{CHECKPOINT_DIR}/*") wandb.watch(net, log_freq=100)
logging.info( logging.info(
f"""Starting training: f"""Starting training: