mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
fix: paste bad algo :(
Former-commit-id: 23df9b5127d0abd3cc9ac77f56b0a1755b28eb00 [formerly 30326f6dd6eace91f4e0e463eab3fdb2083fb38e] Former-commit-id: 23cbd859a96b5cf0cc361b71b32b54f54edb321b
This commit is contained in:
parent
83f8825340
commit
c700835065
|
@ -1 +1 @@
|
|||
eb93e8ed23714201f3e83001ee4821b791b01a53
|
||||
946598e599580eb39a7de79d3607416538bc0019
|
|
@ -18,7 +18,7 @@ from utils.paste import RandomPaste
|
|||
|
||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
|
||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
||||
|
||||
|
@ -88,6 +88,9 @@ def main():
|
|||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f"Using device {device}")
|
||||
|
||||
# enable cudnn benchmarking
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# 0. Create network
|
||||
net = UNet(n_channels=3, n_classes=args.classes)
|
||||
logging.info(
|
||||
|
@ -131,7 +134,7 @@ def main():
|
|||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||
|
||||
# 3. Create data loaders
|
||||
loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True)
|
||||
loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True)
|
||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||
|
||||
|
@ -141,6 +144,7 @@ def main():
|
|||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
|
||||
# setup wandb
|
||||
wandb.init(
|
||||
project="U-Net-tmp",
|
||||
config=dict(
|
||||
|
@ -150,6 +154,7 @@ def main():
|
|||
amp=args.amp,
|
||||
),
|
||||
)
|
||||
wandb.save(f"{CHECKPOINT_DIR}/*")
|
||||
|
||||
logging.info(
|
||||
f"""Starting training:
|
||||
|
|
|
@ -56,6 +56,13 @@ class RandomPaste(A.DualTransform):
|
|||
|
||||
return np.asarray(mask.convert("L"))
|
||||
|
||||
@staticmethod
|
||||
def overlap(positions, x1, y1, w, h):
|
||||
for x2, y2 in positions:
|
||||
if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_params_dependent_on_targets(self, params):
|
||||
# choose a random image inside the image folder
|
||||
filename = rd.choice(os.listdir(self.path_paste_img_dir))
|
||||
|
@ -107,14 +114,14 @@ class RandomPaste(A.DualTransform):
|
|||
|
||||
# generate some positions
|
||||
positions = []
|
||||
while len(positions) <= rd.randint(1, self.nb):
|
||||
NB = rd.randint(1, self.nb)
|
||||
while len(positions) <= NB:
|
||||
x = rd.randint(0, target_shape[0] - paste_shape[0])
|
||||
y = rd.randint(0, target_shape[1] - paste_shape[1])
|
||||
|
||||
# check for overlapping
|
||||
for xo, yo in positions:
|
||||
if (x <= xo + paste_shape[0]) and (y <= yo + paste_shape[1]):
|
||||
continue
|
||||
if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
|
||||
continue
|
||||
|
||||
positions.append((x, y))
|
||||
|
||||
|
|
Loading…
Reference in a new issue