Merge pull request #353 from Gouvernathor/patch-1
Various minor changes Former-commit-id: e36c782fbfc976b7326182a47dd7213bd3360a7e
This commit is contained in:
commit
408f2c9ec2
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -6,3 +6,4 @@ checkpoints/
|
||||||
*.jpg
|
*.jpg
|
||||||
venv/
|
venv/
|
||||||
.idea/
|
.idea/
|
||||||
|
wandb/
|
||||||
|
|
13
train.py
13
train.py
|
@ -72,10 +72,10 @@ def train_net(net,
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
# 5. Begin training
|
# 5. Begin training
|
||||||
for epoch in range(epochs):
|
for epoch in range(1, epochs+1):
|
||||||
net.train()
|
net.train()
|
||||||
epoch_loss = 0
|
epoch_loss = 0
|
||||||
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
|
with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
|
||||||
for batch in train_loader:
|
for batch in train_loader:
|
||||||
images = batch['image']
|
images = batch['image']
|
||||||
true_masks = batch['mask']
|
true_masks = batch['mask']
|
||||||
|
@ -139,8 +139,8 @@ def train_net(net,
|
||||||
|
|
||||||
if save_checkpoint:
|
if save_checkpoint:
|
||||||
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
||||||
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
|
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
|
||||||
logging.info(f'Checkpoint {epoch + 1} saved!')
|
logging.info(f'Checkpoint {epoch} saved!')
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -155,6 +155,7 @@ def get_args():
|
||||||
help='Percent of the data that is used as validation (0-100)')
|
help='Percent of the data that is used as validation (0-100)')
|
||||||
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
|
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
|
||||||
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
|
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
|
||||||
|
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -169,7 +170,7 @@ if __name__ == '__main__':
|
||||||
# Change here to adapt to your data
|
# Change here to adapt to your data
|
||||||
# n_channels=3 for RGB images
|
# n_channels=3 for RGB images
|
||||||
# n_classes is the number of probabilities you want to get per pixel
|
# n_classes is the number of probabilities you want to get per pixel
|
||||||
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
|
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
|
||||||
|
|
||||||
logging.info(f'Network:\n'
|
logging.info(f'Network:\n'
|
||||||
f'\t{net.n_channels} input channels\n'
|
f'\t{net.n_channels} input channels\n'
|
||||||
|
@ -193,4 +194,4 @@ if __name__ == '__main__':
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||||
logging.info('Saved interrupt')
|
logging.info('Saved interrupt')
|
||||||
sys.exit(0)
|
raise
|
||||||
|
|
|
@ -58,8 +58,8 @@ class BasicDataset(Dataset):
|
||||||
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
|
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
|
||||||
img_file = list(self.images_dir.glob(name + '.*'))
|
img_file = list(self.images_dir.glob(name + '.*'))
|
||||||
|
|
||||||
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
|
|
||||||
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
|
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
|
||||||
|
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
|
||||||
mask = self.load(mask_file[0])
|
mask = self.load(mask_file[0])
|
||||||
img = self.load(img_file[0])
|
img = self.load(img_file[0])
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue