feat: saving models with wandb
Former-commit-id: 994d061d49c8be9f680878d093b1e2b5710fa37c [formerly 3cadb0699c150421244c78511ee61b14262728d0]
Former-commit-id: 64b01c6ed73ba64a25d85acd65adc9b6c9cf126e
parent c700835065
commit e4b155991b
@@ -20,18 +20,18 @@ def evaluate(net, dataloader, device):
         # forward, predict the mask
         with torch.inference_mode():
             masks_pred = net(images)
-            masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
+            masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()

         # compute the Dice score
-        dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
+        dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)

         # update progress bar
         pbar.update(images.shape[0])

     # save some images to wandb
-    table = wandb.Table(columns=["image", "mask", "prediction"])
-    for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
-        table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
+    table = wandb.Table(columns=["id", "image", "mask", "prediction"])
+    for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
+        table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
     wandb.log({"predictions_table": table}, commit=False)

     net.train()
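For reference, the prediction table built in this hunk can be factored into a small helper; a minimal sketch, assuming batch tensors shaped (N, C, H, W), binarized predictions, and a prior wandb.init() call (log_predictions is a hypothetical name, not part of the diff):

import wandb

def log_predictions(images, masks_true, masks_pred_bin, key="predictions_table"):
    # One row per sample: index, input image, ground-truth mask, binarized prediction.
    table = wandb.Table(columns=["id", "image", "mask", "prediction"])
    for i, (img, mask, pred) in enumerate(
        zip(images.cpu(), masks_true.cpu(), masks_pred_bin.cpu())
    ):
        table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
    # commit=False defers the step increment to the next wandb.log() call.
    wandb.log({key: table}, commit=False)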
@@ -13,32 +13,6 @@ from unet import UNet
 from utils.utils import plot_img_and_mask


-def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
-    net.eval()
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
-    img = img.unsqueeze(0)
-    img = img.to(device=device, dtype=torch.float32)
-
-    with torch.no_grad():
-        output = net(img)
-
-    if net.n_classes > 1:
-        probs = F.softmax(output, dim=1)[0]
-    else:
-        probs = torch.sigmoid(output)[0]
-
-    tf = transforms.Compose(
-        [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
-    )
-
-    full_mask = tf(probs.cpu()).squeeze()
-
-    if net.n_classes == 1:
-        return (full_mask > out_threshold).numpy()
-    else:
-        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
-
-
 def get_args():
     parser = argparse.ArgumentParser(
         description="Predict masks from input images",
@@ -95,6 +69,29 @@ def get_args():
     return parser.parse_args()


+def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
+    net.eval()
+    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
+    img = img.unsqueeze(0)
+    img = img.to(device=device, dtype=torch.float32)
+
+    with torch.inference_mode():
+        output = net(img)
+
+    probs = torch.sigmoid(output)[0]
+
+    tf = transforms.Compose(
+        [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
+    )
+
+    full_mask = tf(probs.cpu()).squeeze()
+
+    if net.n_classes == 1:
+        return (full_mask > out_threshold).numpy()
+    else:
+        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
+
+
 def get_output_filenames(args):
     def _generate_name(fn):
         split = os.path.splitext(fn)
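Calling the relocated predict_img is unchanged; a hedged usage sketch, meant to run inside predict.py where predict_img and BasicDataset live, assuming the upstream Pytorch-UNet constructor UNet(n_channels, n_classes) and illustrative file paths:

import torch
from PIL import Image

from unet import UNet

# Illustrative values; the real script reads these from get_args().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet(n_channels=3, n_classes=1)  # assumed constructor signature
net.to(device=device)
net.load_state_dict(torch.load("MODEL.pth", map_location=device))

img = Image.open("input.jpg")
mask = predict_img(net, img, device, scale_factor=0.5, out_threshold=0.5)
print(mask.shape)  # binary numpy mask resized back to the input resolution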
@@ -127,6 +124,7 @@ if __name__ == "__main__":
     logging.info("Model loaded!")

     for i, filename in enumerate(in_files):
+
         logging.info(f"\nPredicting image {filename} ...")
         img = Image.open(filename)

@@ -111,10 +111,11 @@ def main():
     # 1. Create transforms
     tf_train = A.Compose(
         [
-            A.Resize(500, 500),
+            A.Resize(512, 512),
             A.Flip(),
             A.ColorJitter(),
             RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
+            A.GaussianBlur(),
             A.ISONoise(),
             A.ToFloat(max_value=255),
             ToTensorV2(),
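One plausible motivation for the 500 -> 512 resize (an assumption, not stated in the commit): a U-Net with four 2x poolings only downsamples and re-upsamples without rounding when the spatial size is divisible by 16, which 512 satisfies and 500 does not.

# Quick divisibility check through four 2x downsamplings (illustration, not project code).
for size in (500, 512):
    print(size, "divisible by 16:", size % 16 == 0)  # 500 -> False, 512 -> True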
@@ -122,7 +123,7 @@ def main():
     )
     tf_valid = A.Compose(
         [
-            A.Resize(500, 500),
+            A.Resize(512, 512),
             RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
             A.ToFloat(max_value=255),
             ToTensorV2(),
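For context, an albumentations pipeline like tf_valid is invoked with named image/mask arguments and transforms both together; a minimal sketch with dummy arrays, leaving out the project-specific RandomPaste:

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

# Stand-in for tf_valid without RandomPaste; values mirror the diff above.
tf = A.Compose(
    [
        A.Resize(512, 512),
        A.ToFloat(max_value=255),
        ToTensorV2(),
    ]
)

image = np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8)  # dummy HWC image
mask = np.random.randint(0, 2, (600, 800), dtype=np.uint8)        # dummy binary mask
out = tf(image=image, mask=mask)  # image and mask get the same spatial transforms
print(out["image"].shape, out["mask"].shape)  # torch.Size([3, 512, 512]) torch.Size([512, 512])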
@@ -154,7 +155,7 @@ def main():
             amp=args.amp,
         ),
     )
-    wandb.save(f"{CHECKPOINT_DIR}/*")
+    wandb.watch(net, log_freq=100)

     logging.info(
         f"""Starting training:
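Given the commit title, a hedged sketch of how checkpoint files are typically pushed to the run alongside wandb.watch; the loop, train_one_epoch, epochs, and the CHECKPOINT_DIR usage are illustrative stand-ins, not the project's actual training code:

import torch
import wandb

wandb.watch(net, log_freq=100)  # log gradients/parameters every 100 batches

for epoch in range(1, epochs + 1):
    train_one_epoch(net)  # hypothetical placeholder for the real training step
    ckpt_path = f"{CHECKPOINT_DIR}/checkpoint_epoch{epoch}.pth"
    torch.save(net.state_dict(), ckpt_path)
    wandb.save(ckpt_path, policy="live")  # upload the file as soon as it is written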