fix: paste bad algo :(
Former-commit-id: 23df9b5127d0abd3cc9ac77f56b0a1755b28eb00 [formerly 30326f6dd6eace91f4e0e463eab3fdb2083fb38e] Former-commit-id: 23cbd859a96b5cf0cc361b71b32b54f54edb321b
This commit is contained in:
parent
83f8825340
commit
c700835065
|
@ -1 +1 @@
|
||||||
eb93e8ed23714201f3e83001ee4821b791b01a53
|
946598e599580eb39a7de79d3607416538bc0019
|
|
@ -18,7 +18,7 @@ from utils.paste import RandomPaste
|
||||||
|
|
||||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
CHECKPOINT_DIR = Path("./checkpoints/")
|
||||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
|
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
|
||||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/")
|
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
|
||||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
||||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
||||||
|
|
||||||
|
@ -88,6 +88,9 @@ def main():
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logging.info(f"Using device {device}")
|
logging.info(f"Using device {device}")
|
||||||
|
|
||||||
|
# enable cudnn benchmarking
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=3, n_classes=args.classes)
|
net = UNet(n_channels=3, n_classes=args.classes)
|
||||||
logging.info(
|
logging.info(
|
||||||
|
@ -131,7 +134,7 @@ def main():
|
||||||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True)
|
loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True)
|
||||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||||
|
|
||||||
|
@ -141,6 +144,7 @@ def main():
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
# setup wandb
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project="U-Net-tmp",
|
project="U-Net-tmp",
|
||||||
config=dict(
|
config=dict(
|
||||||
|
@ -150,6 +154,7 @@ def main():
|
||||||
amp=args.amp,
|
amp=args.amp,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
wandb.save(f"{CHECKPOINT_DIR}/*")
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"""Starting training:
|
f"""Starting training:
|
||||||
|
|
|
@ -56,6 +56,13 @@ class RandomPaste(A.DualTransform):
|
||||||
|
|
||||||
return np.asarray(mask.convert("L"))
|
return np.asarray(mask.convert("L"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def overlap(positions, x1, y1, w, h):
|
||||||
|
for x2, y2 in positions:
|
||||||
|
if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_params_dependent_on_targets(self, params):
|
def get_params_dependent_on_targets(self, params):
|
||||||
# choose a random image inside the image folder
|
# choose a random image inside the image folder
|
||||||
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
||||||
|
@ -107,14 +114,14 @@ class RandomPaste(A.DualTransform):
|
||||||
|
|
||||||
# generate some positions
|
# generate some positions
|
||||||
positions = []
|
positions = []
|
||||||
while len(positions) <= rd.randint(1, self.nb):
|
NB = rd.randint(1, self.nb)
|
||||||
|
while len(positions) <= NB:
|
||||||
x = rd.randint(0, target_shape[0] - paste_shape[0])
|
x = rd.randint(0, target_shape[0] - paste_shape[0])
|
||||||
y = rd.randint(0, target_shape[1] - paste_shape[1])
|
y = rd.randint(0, target_shape[1] - paste_shape[1])
|
||||||
|
|
||||||
# check for overlapping
|
# check for overlapping
|
||||||
for xo, yo in positions:
|
if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
|
||||||
if (x <= xo + paste_shape[0]) and (y <= yo + paste_shape[1]):
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
positions.append((x, y))
|
positions.append((x, y))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue