diff --git a/src/train.py b/src/train.py
index 9ddac8e..fe5e35b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -30,14 +30,14 @@ if __name__ == "__main__":
         DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
         DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
         DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-        FEATURES=[64, 128, 256, 512],
+        FEATURES=[16, 32, 64, 128],
         N_CHANNELS=3,
         N_CLASSES=1,
         AMP=True,
         PIN_MEMORY=True,
         BENCHMARK=True,
         DEVICE="cuda",
-        WORKERS=8,
+        WORKERS=7,
         EPOCHS=5,
         BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
@@ -88,13 +88,10 @@ if __name__ == "__main__":
 
     # 2. Create datasets
     ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
-    ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
-    ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+    ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
 
     # 2.5. Create subset, if uncommented
     ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
-    ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
-    ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
 
     # 3. Create data loaders
     train_loader = DataLoader(
@@ -108,14 +105,6 @@ if __name__ == "__main__":
         ds_valid,
         shuffle=False,
         drop_last=True,
-        batch_size=wandb.config.BATCH_SIZE,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-    test_loader = DataLoader(
-        ds_test,
-        shuffle=False,
-        drop_last=False,
         batch_size=1,
         num_workers=wandb.config.WORKERS,
         pin_memory=wandb.config.PIN_MEMORY,
@@ -136,13 +125,13 @@ if __name__ == "__main__":
     dummy_input = torch.randn(
         1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
     ).to(device)
-    torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
+    torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
     artifact = wandb.Artifact("onnx", type="model")
-    artifact.add_file("checkpoints/model-0.onnx")
+    artifact.add_file("checkpoints/model.onnx")
     wandb.run.log_artifact(artifact)
 
     # log gradients and weights four times per epoch
-    wandb.watch(net, criterion, log_freq=100)
+    wandb.watch(net, log_freq=100)
 
     # print the config
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@@ -156,6 +145,7 @@ if __name__ == "__main__":
     )
 
     try:
+        global_step = 0
         for epoch in range(1, wandb.config.EPOCHS + 1):
             with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
 
@@ -195,7 +185,6 @@ if __name__ == "__main__":
                     # log metrics
                     wandb.log(
                         {
-                            "epoch": epoch - 1 + step / len(train_loader),
                             "train/accuracy": accuracy,
                             "train/bce": train_loss,
                             "train/dice": dice,
@@ -203,14 +192,16 @@ if __name__ == "__main__":
                         }
                     )
 
-                    if step and (step % 250 == 0 or step == len(train_loader)):
+                    global_step += 1
+
+                    if global_step % 100 == 0:
                         # Evaluation round
                         net.eval()
                         accuracy = 0
                         val_loss = 0
                         dice = 0
                         mae = 0
-                        with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
+                        with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
                             for images, masks_true in val_loader:
 
                                 # transfer images to device
@@ -228,8 +219,9 @@ if __name__ == "__main__":
                                 accuracy += (masks_true == masks_pred_bin).float().mean()
                                 dice += dice_coeff(masks_pred_bin, masks_true)
 
-                                # update progress bar
+                                # update progress bars
                                 pbar2.update(images.shape[0])
+                                pbar.refresh()
 
                         accuracy /= len(val_loader)
                         val_loss /= len(val_loader)
@@ -285,75 +277,6 @@ if __name__ == "__main__":
                         artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
                         wandb.run.log_artifact(artifact)
 
-                    # testing round
-                    net.eval()
-                    accuracy = 0
-                    val_loss = 0
-                    dice = 0
-                    mae = 0
-                    with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
-                        for images, masks_true in test_loader:
-
-                            # transfer images to device
-                            images = images.to(device=device)
-                            masks_true = masks_true.unsqueeze(1).to(device=device)
-
-                            # forward
-                            with torch.inference_mode():
-                                masks_pred = net(images)
-
-                            # compute metrics
-                            val_loss += criterion(masks_pred, masks_true)
-                            masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                            mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-                            accuracy += (masks_true == masks_pred_bin).float().mean()
-                            dice += dice_coeff(masks_pred_bin, masks_true)
-
-                            # update progress bar
-                            pbar3.update(images.shape[0])
-
-                    accuracy /= len(test_loader)
-                    val_loss /= len(test_loader)
-                    dice /= len(test_loader)
-                    mae /= len(test_loader)
-
-                    # save the last validation batch to table
-                    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-                    for i, (img, mask, pred, pred_bin) in enumerate(
-                        zip(
-                            images.cpu(),
-                            masks_true.cpu(),
-                            masks_pred.cpu(),
-                            masks_pred_bin.cpu().squeeze(1).int().numpy(),
-                        )
-                    ):
-                        table.add_data(
-                            i,
-                            wandb.Image(img),
-                            wandb.Image(mask),
-                            wandb.Image(
-                                pred,
-                                masks={
-                                    "predictions": {
-                                        "mask_data": pred_bin,
-                                        "class_labels": class_labels,
-                                    },
-                                },
-                            ),
-                        )
-
-                    # log validation metrics
-                    wandb.log(
-                        {
-                            "test/predictions": table,
-                            "test/accuracy": accuracy,
-                            "test/bce": val_loss,
-                            "test/dice": dice,
-                            "test/mae": mae,
-                        },
-                        commit=False,
-                    )
-
     # stop wandb
     wandb.run.finish()
diff --git a/src/unet/model.py b/src/unet/model.py
index 08d2807..296f5c0 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -26,7 +26,6 @@ class UNet(nn.Module):
         self.outc = OutConv(features[0], n_classes)
 
     def forward(self, x):
-        skips = []
 
        x = self.inc(x)
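A note on the surviving Subset line: with stride k = len(ds_train) // 10000, range(0, len(ds_train), k) keeps roughly 10,000 evenly spaced samples, and raises ValueError (zero stride) if the dataset has fewer than 10,000 images. A hypothetical helper sketching the idea — evenly_spaced_subset and its max(1, ...) guard are not in the repo:

import torch

def evenly_spaced_subset(ds, target: int) -> torch.utils.data.Subset:
    # stride = len(ds) // target yields roughly `target` indices; the
    # max(1, ...) guard is an addition (not in the patch) that avoids a
    # zero stride when len(ds) < target
    stride = max(1, len(ds) // target)
    return torch.utils.data.Subset(ds, list(range(0, len(ds), stride)))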
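The checkpoint rename (model-0.onnx to model.onnx) has to land in both the export call and the artifact upload, since add_file fails when the path does not exist. In isolation the export-then-version pattern is sketched below; the wrapper name and default shape arguments are assumptions, and an active wandb.init() run is required:

import torch
import wandb

def export_and_log_onnx(net, n_channels=3, img_size=512, path="checkpoints/model.onnx"):
    # trace the network with a dummy input at the training resolution
    dummy_input = torch.randn(1, n_channels, img_size, img_size)
    torch.onnx.export(net, dummy_input, path)
    # attach the file to a versioned W&B artifact tied to the current run
    artifact = wandb.Artifact("onnx", type="model")
    artifact.add_file(path)
    wandb.run.log_artifact(artifact)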
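On the evaluation trigger: the old condition step % 250 == 0 used the per-epoch batch index, so validation frequency depended on where each epoch happened to end, while the new global_step counts batches across epochs and fires every 100 steps regardless. A minimal sketch of that pattern, with placeholder names (evaluate, eval_every) that are not this repo's API:

import torch

def evaluate(net, loader, criterion):
    # average a metric over the whole loader, as the evaluation round does
    total = 0.0
    with torch.inference_mode():
        for images, masks_true in loader:
            total += criterion(net(images), masks_true).item()
    return total / len(loader)

def train(net, train_loader, val_loader, optimizer, criterion, epochs=5, eval_every=100):
    global_step = 0  # survives epoch boundaries, unlike the per-epoch `step`
    for epoch in range(1, epochs + 1):
        for images, masks_true in train_loader:
            optimizer.zero_grad()
            loss = criterion(net(images), masks_true)
            loss.backward()
            optimizer.step()

            global_step += 1
            if global_step % eval_every == 0:
                net.eval()
                val_loss = evaluate(net, val_loader, criterion)
                net.train()  # restore training mode before the next batch
                print(f"step {global_step}: val_loss={val_loss:.4f}")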
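For reference, the evaluation round computes its metrics on masks binarized at 0.5 via (torch.sigmoid(masks_pred) > 0.5).float(). A generic Dice coefficient consistent with that call shape — the repo's actual dice_coeff may differ:

import torch

def dice_coeff(pred_bin: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # overlap-based similarity on {0, 1} masks; eps keeps two empty masks near 1.0
    inter = (pred_bin * target).sum()
    return (2 * inter + eps) / (pred_bin.sum() + target.sum() + eps)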