diff --git a/src/train.py b/src/train.py
index c717652..5feb01c 100644
--- a/src/train.py
+++ b/src/train.py
@@ -26,7 +26,7 @@ def main():
             DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
             DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
             DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-            FEATURES=[16, 32, 64, 128],
+            FEATURES=[64, 128, 256, 512],
             N_CHANNELS=3,
             N_CLASSES=1,
             AMP=True,
@@ -35,7 +35,7 @@ def main():
             DEVICE="cuda",
             WORKERS=8,
             EPOCHS=5,
-            BATCH_SIZE=64,
+            BATCH_SIZE=16,
             LEARNING_RATE=1e-4,
             IMG_SIZE=512,
             SPHERES=5,
@@ -105,11 +105,15 @@ def main():
     )
 
     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
 
+    # accuracy stuff
+    mse = torch.nn.MSELoss()
+    mae = torch.nn.L1Loss()
+
     # save model.pth
     torch.save(net.state_dict(), "checkpoints/model-0.pth")
     artifact = wandb.Artifact("pth", type="model")
@@ -151,7 +155,7 @@ def main():
                 # forward
                 with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
                     pred_masks = net(images)
-                    train_loss = criterion(true_masks, pred_masks)
+                    train_loss = criterion(pred_masks, true_masks)
 
                 # backward
                 optimizer.zero_grad(set_to_none=True)
@@ -159,21 +163,29 @@ def main():
                 grad_scaler.step(optimizer)
                 grad_scaler.update()
 
+                # compute metrics
+                accuracy = (true_masks == pred_masks).float().mean()
+                mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
+                mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
+
                 # update tqdm progress bar
                 pbar.update(images.shape[0])
                 pbar.set_postfix(**{"loss": train_loss.item()})
 
-                # log training metrics
+                # log metrics
                 wandb.log(
                     {
                         "train/epoch": epoch - 1 + step / len(train_loader),
-                        "train/train_loss": train_loss,
+                        "train/accuracy": accuracy,
+                        "train/loss": train_loss,
+                        "train/mse": mse,
+                        "train/mae": mae,
                     }
                 )
 
             # Evaluation round
             val_score = evaluate(net, val_loader, device)
-            # scheduler.step(val_score)
+            scheduler.step(val_score)
 
             # log validation metrics
             wandb.log(
diff --git a/src/unet/blocks.py b/src/unet/blocks.py
index b5b9267..b1f7d7f 100644
--- a/src/unet/blocks.py
+++ b/src/unet/blocks.py
@@ -72,7 +72,7 @@ class OutConv(nn.Module):
 
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1),
-            nn.Sigmoid(),
+            # nn.Sigmoid(),
         )
 
     def forward(self, x):
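
Two notes on the metrics introduced here. Commenting out `nn.Sigmoid()` in `OutConv` means `net(images)` now returns raw logits, which is exactly what `BCEWithLogitsLoss` expects, and the swapped arguments `criterion(pred_masks, true_masks)` now match its `(input, target)` signature. The added `accuracy = (true_masks == pred_masks).float().mean()` line, however, compares 0/1 masks against unbounded logits element-wise, so it will sit near zero. A minimal sketch of a thresholded variant (the `pixel_metrics` helper is hypothetical, not part of this diff):

```python
import torch

def pixel_metrics(pred_logits: torch.Tensor, true_masks: torch.Tensor) -> dict:
    """Hypothetical helper: pixel-level metrics for a logit-output network.

    Assumes true_masks holds 0.0/1.0 values with the same shape as
    pred_logits (binary segmentation, N_CLASSES=1).
    """
    probs = torch.sigmoid(pred_logits)  # logits -> probabilities in [0, 1]
    pred_bin = (probs > 0.5).float()    # hard 0/1 predictions
    return {
        "accuracy": (pred_bin == true_masks).float().mean(),
        "mse": torch.nn.functional.mse_loss(probs, true_masks),
        "mae": torch.nn.functional.l1_loss(probs, true_masks),
    }
```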
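Computing mse/mae on probabilities rather than raw logits, as sketched above, keeps both metrics in [0, 1] and comparable across runs; as written, the functional calls in the loop also shadow the `mse = torch.nn.MSELoss()` / `mae = torch.nn.L1Loss()` modules created under `# accuracy stuff`, which therefore go unused. Separately, re-enabling `scheduler.step(val_score)` completes the earlier `ReduceLROnPlateau(optimizer, "max", patience=2)` setup: in "max" mode the learning rate is reduced once `val_score` stops improving for `patience` consecutive epochs.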