From c700835065e0d53904f15581f019b2d9f0616546 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 29 Jun 2022 11:22:23 +0200 Subject: [PATCH] fix: paste bad algo :( Former-commit-id: 23df9b5127d0abd3cc9ac77f56b0a1755b28eb00 [formerly 30326f6dd6eace91f4e0e463eab3fdb2083fb38e] Former-commit-id: 23cbd859a96b5cf0cc361b71b32b54f54edb321b --- comp.ipynb.REMOVED.git-id | 2 +- src/train.py | 9 +++++++-- src/utils/paste.py | 15 +++++++++++---- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index 57a48a5..00cc96d 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -eb93e8ed23714201f3e83001ee4821b791b01a53 \ No newline at end of file +946598e599580eb39a7de79d3607416538bc0019 \ No newline at end of file diff --git a/src/train.py b/src/train.py index e9c63dc..809244f 100644 --- a/src/train.py +++ b/src/train.py @@ -18,7 +18,7 @@ from utils.paste import RandomPaste CHECKPOINT_DIR = Path("./checkpoints/") DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017") -DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/") +DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/") DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") @@ -88,6 +88,9 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logging.info(f"Using device {device}") + # enable cudnn benchmarking + torch.backends.cudnn.benchmark = True + # 0. Create network net = UNet(n_channels=3, n_classes=args.classes) logging.info( @@ -131,7 +134,7 @@ def main(): ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True) + loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True) train_loader = DataLoader(ds_train, shuffle=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) @@ -141,6 +144,7 @@ def main(): grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) criterion = nn.BCEWithLogitsLoss() + # setup wandb wandb.init( project="U-Net-tmp", config=dict( @@ -150,6 +154,7 @@ def main(): amp=args.amp, ), ) + wandb.save(f"{CHECKPOINT_DIR}/*") logging.info( f"""Starting training: diff --git a/src/utils/paste.py b/src/utils/paste.py index 6dbaa6b..60ded70 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -56,6 +56,13 @@ class RandomPaste(A.DualTransform): return np.asarray(mask.convert("L")) + @staticmethod + def overlap(positions, x1, y1, w, h): + for x2, y2 in positions: + if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h: + return True + return False + def get_params_dependent_on_targets(self, params): # choose a random image inside the image folder filename = rd.choice(os.listdir(self.path_paste_img_dir)) @@ -107,14 +114,14 @@ class RandomPaste(A.DualTransform): # generate some positions positions = [] - while len(positions) <= rd.randint(1, self.nb): + NB = rd.randint(1, self.nb) + while len(positions) <= NB: x = rd.randint(0, target_shape[0] - paste_shape[0]) y = rd.randint(0, target_shape[1] - paste_shape[1]) # check for overlapping - for xo, yo in positions: - if (x <= xo + paste_shape[0]) and (y <= yo + paste_shape[1]): - continue + if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]): + continue positions.append((x, y))