feat: got precision 16 back

Former-commit-id: 6b19dc9bd17078bb2c151d5cd96e7ba4da9e1b89 [formerly 5d1eac2ed10be960c89407ad265ff350e11c1adf]
Former-commit-id: 1db4ca0ce11ac818408b94625b872c1202b5d4ed
This commit is contained in:
Laurent Fainsin 2022-07-11 17:02:13 +02:00
parent ed07e130e6
commit 82682ceeb2
5 changed files with 11 additions and 7 deletions

View file

@ -58,7 +58,7 @@ class LabeledDataset(Dataset):
# open and convert mask
mask_path = self.images[index].parent.joinpath("MASK.PNG")
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255
# convert image & mask to Tensor float in [0, 1]
post_process = A.Compose(
@ -72,4 +72,8 @@ class LabeledDataset(Dataset):
image = augmentations["image"]
mask = augmentations["mask"]
# make sure image and mask are floats, TODO: mettre dans le post_process, ToFloat Image only
image = image.float()
mask = mask.float()
return image, mask

View file

@ -38,7 +38,7 @@ if __name__ == "__main__":
# model.load_state_dict(state_dict)
# log gradients and weights regularly
logger.watch(model, log="all")
logger.watch(model.model, log="all")
# Create the dataloaders
datamodule = Spheres()
@ -49,7 +49,7 @@ if __name__ == "__main__":
accelerator=wandb.config.DEVICE,
benchmark=wandb.config.BENCHMARK,
# profiler="simple",
# precision=16,
precision=16,
logger=logger,
log_every_n_steps=1,
val_check_interval=100,

View file

@ -38,7 +38,7 @@ class UNetModule(pl.LightningModule):
# forward pass, compute masks
prediction = self.model(data)
binary = (torch.sigmoid(prediction) > 0.5).float() # TODO: check if float necessary
binary = (torch.sigmoid(prediction) > 0.5).half()
# compute metrics (in dictionnary)
metrics = {

View file

@ -31,7 +31,7 @@ class TableLog(Callback):
zip(
images.cpu(),
ground_truth.cpu(),
predictions["linear"].cpu(),
predictions["linear"].cpu().float(),
predictions["binary"].cpu().squeeze(1).int().numpy(),
)
):

View file

@ -87,7 +87,7 @@ class RandomPaste(A.DualTransform):
img.paste(paste_img, (x, y), paste_mask)
return np.asarray(img.convert("RGB"))
return np.array(img.convert("RGB"))
def apply_to_mask(self, mask, augmentations, paste_mask, **params):
# convert mask to Image, needed for `paste` function
@ -116,7 +116,7 @@ class RandomPaste(A.DualTransform):
mask.paste(paste_mask, (x, y), paste_mask_bin)
return np.asarray(mask.convert("L"))
return np.array(mask.convert("L"))
def get_params_dependent_on_targets(self, params):
# choose a random image and its corresponding mask