fix: paste bad algo :(

Former-commit-id: 23df9b5127d0abd3cc9ac77f56b0a1755b28eb00 [formerly 30326f6dd6eace91f4e0e463eab3fdb2083fb38e]
Former-commit-id: 23cbd859a96b5cf0cc361b71b32b54f54edb321b
This commit is contained in:
Laurent Fainsin 2022-06-29 11:22:23 +02:00
parent 83f8825340
commit c700835065
3 changed files with 19 additions and 7 deletions

View file

@ -1 +1 @@
eb93e8ed23714201f3e83001ee4821b791b01a53
946598e599580eb39a7de79d3607416538bc0019

View file

@ -18,7 +18,7 @@ from utils.paste import RandomPaste
CHECKPOINT_DIR = Path("./checkpoints/")
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/")
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
@ -88,6 +88,9 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
# enable cudnn benchmarking
torch.backends.cudnn.benchmark = True
# 0. Create network
net = UNet(n_channels=3, n_classes=args.classes)
logging.info(
@ -131,7 +134,7 @@ def main():
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
# 3. Create data loaders
loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True)
loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True)
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
@ -141,6 +144,7 @@ def main():
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
criterion = nn.BCEWithLogitsLoss()
# setup wandb
wandb.init(
project="U-Net-tmp",
config=dict(
@ -150,6 +154,7 @@ def main():
amp=args.amp,
),
)
wandb.save(f"{CHECKPOINT_DIR}/*")
logging.info(
f"""Starting training:

View file

@ -56,6 +56,13 @@ class RandomPaste(A.DualTransform):
return np.asarray(mask.convert("L"))
@staticmethod
def overlap(positions, x1, y1, w, h):
for x2, y2 in positions:
if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
return True
return False
def get_params_dependent_on_targets(self, params):
# choose a random image inside the image folder
filename = rd.choice(os.listdir(self.path_paste_img_dir))
@ -107,13 +114,13 @@ class RandomPaste(A.DualTransform):
# generate some positions
positions = []
while len(positions) <= rd.randint(1, self.nb):
NB = rd.randint(1, self.nb)
while len(positions) <= NB:
x = rd.randint(0, target_shape[0] - paste_shape[0])
y = rd.randint(0, target_shape[1] - paste_shape[1])
# check for overlapping
for xo, yo in positions:
if (x <= xo + paste_shape[0]) and (y <= yo + paste_shape[1]):
if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
continue
positions.append((x, y))