From 3056ebc49804175c246c8fbbe767478b59ffb7e9 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 30 Jun 2022 14:43:41 +0200 Subject: [PATCH] fix: typo Former-commit-id: 0336f7e3a3809f6501dec4ca11a9ae1f3be44f29 [formerly 1bd5e48616d48d52e13067ff505a94162acf9e6f] Former-commit-id: 471e9d11a26b14f6b878f9735d3e9c40eb10161b --- src/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index 6e331b8..57e1179 100644 --- a/src/train.py +++ b/src/train.py @@ -43,7 +43,7 @@ def main(): ) # create device - device = torch.device(wandb.config.device) + device = torch.device(wandb.config.DEVICE) # enable cudnn benchmarking torch.backends.cudnn.benchmark = wandb.config.BENCHMARK @@ -112,7 +112,7 @@ def main(): # save model.onxx dummy_input = torch.randn( - 1, wandb.config.n_channels, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True + 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True ).to(device) torch.onnx.export(net, dummy_input, "model.onnx") artifact = wandb.Artifact("onnx", type="model") @@ -132,8 +132,8 @@ def main(): # Training round for step, (images, true_masks) in enumerate(train_loader): - assert images.shape[1] == net.N_CHANNELS, ( - f"Network has been defined with {net.N_CHANNELS} input channels, " + assert images.shape[1] == net.n_channels, ( + f"Network has been defined with {net.n_channels} input channels, " f"but loaded images have {images.shape[1]} channels. Please check that " "the images are loaded correctly." )