diff --git a/src/train.py b/src/train.py index 6e331b8..57e1179 100644 --- a/src/train.py +++ b/src/train.py @@ -43,7 +43,7 @@ def main(): ) # create device - device = torch.device(wandb.config.device) + device = torch.device(wandb.config.DEVICE) # enable cudnn benchmarking torch.backends.cudnn.benchmark = wandb.config.BENCHMARK @@ -112,7 +112,7 @@ def main(): # save model.onxx dummy_input = torch.randn( - 1, wandb.config.n_channels, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True + 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True ).to(device) torch.onnx.export(net, dummy_input, "model.onnx") artifact = wandb.Artifact("onnx", type="model") @@ -132,8 +132,8 @@ def main(): # Training round for step, (images, true_masks) in enumerate(train_loader): - assert images.shape[1] == net.N_CHANNELS, ( - f"Network has been defined with {net.N_CHANNELS} input channels, " + assert images.shape[1] == net.n_channels, ( + f"Network has been defined with {net.n_channels} input channels, " f"but loaded images have {images.shape[1]} channels. Please check that " "the images are loaded correctly." )