feat: almost there, remove scheduler next time
Former-commit-id: 0a08b5a9559e46ca72f7d07ae84202c1412a63e9 [formerly 522877adbc8f7d132875405a86e594b4fb753850] Former-commit-id: 2ea20809c265b8366ec2e0aa3867b13886cbd500
This commit is contained in:
parent
d785a5c6be
commit
c0772a390e
103
src/train.py
103
src/train.py
|
@ -30,14 +30,14 @@ if __name__ == "__main__":
|
|||
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
|
||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
FEATURES=[64, 128, 256, 512],
|
||||
FEATURES=[16, 32, 64, 128],
|
||||
N_CHANNELS=3,
|
||||
N_CLASSES=1,
|
||||
AMP=True,
|
||||
PIN_MEMORY=True,
|
||||
BENCHMARK=True,
|
||||
DEVICE="cuda",
|
||||
WORKERS=8,
|
||||
WORKERS=7,
|
||||
EPOCHS=5,
|
||||
BATCH_SIZE=16,
|
||||
LEARNING_RATE=1e-4,
|
||||
|
@ -88,13 +88,10 @@ if __name__ == "__main__":
|
|||
|
||||
# 2. Create datasets
|
||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
||||
|
||||
# 2.5. Create subset, if uncommented
|
||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
|
||||
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
||||
|
||||
# 3. Create data loaders
|
||||
train_loader = DataLoader(
|
||||
|
@ -108,14 +105,6 @@ if __name__ == "__main__":
|
|||
ds_valid,
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
batch_size=wandb.config.BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
ds_test,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
batch_size=1,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
|
@ -136,13 +125,13 @@ if __name__ == "__main__":
|
|||
dummy_input = torch.randn(
|
||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||
).to(device)
|
||||
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
|
||||
torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file("checkpoints/model-0.onnx")
|
||||
artifact.add_file("checkpoints/model.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# log gradients and weights four time per epoch
|
||||
wandb.watch(net, criterion, log_freq=100)
|
||||
wandb.watch(net, log_freq=100)
|
||||
|
||||
# print the config
|
||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||
|
@ -156,6 +145,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
try:
|
||||
global_step = 0
|
||||
for epoch in range(1, wandb.config.EPOCHS + 1):
|
||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
||||
|
||||
|
@ -195,7 +185,6 @@ if __name__ == "__main__":
|
|||
# log metrics
|
||||
wandb.log(
|
||||
{
|
||||
"epoch": epoch - 1 + step / len(train_loader),
|
||||
"train/accuracy": accuracy,
|
||||
"train/bce": train_loss,
|
||||
"train/dice": dice,
|
||||
|
@ -203,14 +192,16 @@ if __name__ == "__main__":
|
|||
}
|
||||
)
|
||||
|
||||
if step and (step % 250 == 0 or step == len(train_loader)):
|
||||
global_step += 1
|
||||
|
||||
if global_step % 100 == 0:
|
||||
# Evaluation round
|
||||
net.eval()
|
||||
accuracy = 0
|
||||
val_loss = 0
|
||||
dice = 0
|
||||
mae = 0
|
||||
with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
|
||||
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
|
||||
for images, masks_true in val_loader:
|
||||
|
||||
# transfer images to device
|
||||
|
@ -228,8 +219,9 @@ if __name__ == "__main__":
|
|||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
# update progress bar
|
||||
# update progress bars
|
||||
pbar2.update(images.shape[0])
|
||||
pbar.refresh()
|
||||
|
||||
accuracy /= len(val_loader)
|
||||
val_loss /= len(val_loader)
|
||||
|
@ -285,75 +277,6 @@ if __name__ == "__main__":
|
|||
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# testing round
|
||||
net.eval()
|
||||
accuracy = 0
|
||||
val_loss = 0
|
||||
dice = 0
|
||||
mae = 0
|
||||
with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
|
||||
for images, masks_true in test_loader:
|
||||
|
||||
# transfer images to device
|
||||
images = images.to(device=device)
|
||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
||||
|
||||
# forward
|
||||
with torch.inference_mode():
|
||||
masks_pred = net(images)
|
||||
|
||||
# compute metrics
|
||||
val_loss += criterion(masks_pred, masks_true)
|
||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
# update progress bar
|
||||
pbar3.update(images.shape[0])
|
||||
|
||||
accuracy /= len(test_loader)
|
||||
val_loss /= len(test_loader)
|
||||
dice /= len(test_loader)
|
||||
mae /= len(test_loader)
|
||||
|
||||
# save the last validation batch to table
|
||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip(
|
||||
images.cpu(),
|
||||
masks_true.cpu(),
|
||||
masks_pred.cpu(),
|
||||
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
||||
)
|
||||
):
|
||||
table.add_data(
|
||||
i,
|
||||
wandb.Image(img),
|
||||
wandb.Image(mask),
|
||||
wandb.Image(
|
||||
pred,
|
||||
masks={
|
||||
"predictions": {
|
||||
"mask_data": pred_bin,
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# log validation metrics
|
||||
wandb.log(
|
||||
{
|
||||
"test/predictions": table,
|
||||
"test/accuracy": accuracy,
|
||||
"test/bce": val_loss,
|
||||
"test/dice": dice,
|
||||
"test/mae": mae,
|
||||
},
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# stop wandb
|
||||
wandb.run.finish()
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ class UNet(nn.Module):
|
|||
self.outc = OutConv(features[0], n_classes)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
skips = []
|
||||
|
||||
x = self.inc(x)
|
||||
|
|
Loading…
Reference in a new issue