From 758f0f80701c3bb11879a1cc26c74cb6d3087af3 Mon Sep 17 00:00:00 2001 From: Gouvernathor Date: Wed, 6 Apr 2022 13:35:02 +0200 Subject: [PATCH] Update train.py Former-commit-id: dee78b12ca6810f5e02febfb244dce3885ed49ac --- train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index bffe179..c098bf9 100644 --- a/train.py +++ b/train.py @@ -72,10 +72,10 @@ def train_net(net, global_step = 0 # 5. Begin training - for epoch in range(epochs): + for epoch in range(1, epochs+1): net.train() epoch_loss = 0 - with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: + with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: for batch in train_loader: images = batch['image'] true_masks = batch['mask'] @@ -139,8 +139,8 @@ def train_net(net, if save_checkpoint: Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) - torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1))) - logging.info(f'Checkpoint {epoch + 1} saved!') + torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch))) + logging.info(f'Checkpoint {epoch} saved!') def get_args(): @@ -155,6 +155,7 @@ def get_args(): help='Percent of the data that is used as validation (0-100)') parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision') parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') + parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') return parser.parse_args() @@ -169,7 +170,7 @@ if __name__ == '__main__': # Change here to adapt to your data # n_channels=3 for RGB images # n_classes is the number of probabilities you want to get per pixel - net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear) + net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) logging.info(f'Network:\n' f'\t{net.n_channels} input channels\n' @@ -193,4 +194,4 @@ if __name__ == '__main__': except KeyboardInterrupt: torch.save(net.state_dict(), 'INTERRUPTED.pth') logging.info('Saved interrupt') - sys.exit(0) + raise