mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: saving models with wandb
Former-commit-id: 994d061d49c8be9f680878d093b1e2b5710fa37c [formerly 3cadb0699c150421244c78511ee61b14262728d0] Former-commit-id: 64b01c6ed73ba64a25d85acd65adc9b6c9cf126e
This commit is contained in:
parent
c700835065
commit
e4b155991b
|
@ -20,18 +20,18 @@ def evaluate(net, dataloader, device):
|
|||
# forward, predict the mask
|
||||
with torch.inference_mode():
|
||||
masks_pred = net(images)
|
||||
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
|
||||
# compute the Dice score
|
||||
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
|
||||
dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(images.shape[0])
|
||||
|
||||
# save some images to wandb
|
||||
table = wandb.Table(columns=["image", "mask", "prediction"])
|
||||
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
|
||||
table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||
table = wandb.Table(columns=["id", "image", "mask", "prediction"])
|
||||
for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
|
||||
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||
wandb.log({"predictions_table": table}, commit=False)
|
||||
|
||||
net.train()
|
||||
|
|
|
@ -13,32 +13,6 @@ from unet import UNet
|
|||
from utils.utils import plot_img_and_mask
|
||||
|
||||
|
||||
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
|
||||
net.eval()
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
output = net(img)
|
||||
|
||||
if net.n_classes > 1:
|
||||
probs = F.softmax(output, dim=1)[0]
|
||||
else:
|
||||
probs = torch.sigmoid(output)[0]
|
||||
|
||||
tf = transforms.Compose(
|
||||
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
|
||||
)
|
||||
|
||||
full_mask = tf(probs.cpu()).squeeze()
|
||||
|
||||
if net.n_classes == 1:
|
||||
return (full_mask > out_threshold).numpy()
|
||||
else:
|
||||
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Predict masks from input images",
|
||||
|
@ -95,6 +69,29 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
|
||||
net.eval()
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
||||
with torch.inference_mode():
|
||||
output = net(img)
|
||||
|
||||
probs = torch.sigmoid(output)[0]
|
||||
|
||||
tf = transforms.Compose(
|
||||
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
|
||||
)
|
||||
|
||||
full_mask = tf(probs.cpu()).squeeze()
|
||||
|
||||
if net.n_classes == 1:
|
||||
return (full_mask > out_threshold).numpy()
|
||||
else:
|
||||
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
|
||||
|
||||
|
||||
def get_output_filenames(args):
|
||||
def _generate_name(fn):
|
||||
split = os.path.splitext(fn)
|
||||
|
@ -127,6 +124,7 @@ if __name__ == "__main__":
|
|||
logging.info("Model loaded!")
|
||||
|
||||
for i, filename in enumerate(in_files):
|
||||
|
||||
logging.info(f"\nPredicting image {filename} ...")
|
||||
img = Image.open(filename)
|
||||
|
||||
|
|
|
@ -111,10 +111,11 @@ def main():
|
|||
# 1. Create transforms
|
||||
tf_train = A.Compose(
|
||||
[
|
||||
A.Resize(500, 500),
|
||||
A.Resize(512, 512),
|
||||
A.Flip(),
|
||||
A.ColorJitter(),
|
||||
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.GaussianBlur(),
|
||||
A.ISONoise(),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
|
@ -122,7 +123,7 @@ def main():
|
|||
)
|
||||
tf_valid = A.Compose(
|
||||
[
|
||||
A.Resize(500, 500),
|
||||
A.Resize(512, 512),
|
||||
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
|
@ -154,7 +155,7 @@ def main():
|
|||
amp=args.amp,
|
||||
),
|
||||
)
|
||||
wandb.save(f"{CHECKPOINT_DIR}/*")
|
||||
wandb.watch(net, log_freq=100)
|
||||
|
||||
logging.info(
|
||||
f"""Starting training:
|
||||
|
|
Loading…
Reference in a new issue