Update train.py

Former-commit-id: dee78b12ca6810f5e02febfb244dce3885ed49ac
This commit is contained in:
Gouvernathor 2022-04-06 13:35:02 +02:00 committed by GitHub
parent 182cbd48f9
commit 758f0f8070

View file

@ -72,10 +72,10 @@ def train_net(net,
global_step = 0
# 5. Begin training
for epoch in range(epochs):
for epoch in range(1, epochs+1):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
for batch in train_loader:
images = batch['image']
true_masks = batch['mask']
@ -139,8 +139,8 @@ def train_net(net,
if save_checkpoint:
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
logging.info(f'Checkpoint {epoch + 1} saved!')
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
logging.info(f'Checkpoint {epoch} saved!')
def get_args():
@ -155,6 +155,7 @@ def get_args():
help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
return parser.parse_args()
@ -169,7 +170,7 @@ if __name__ == '__main__':
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
@ -193,4 +194,4 @@ if __name__ == '__main__':
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
sys.exit(0)
raise