feat: it's training, but for how long ?
Former-commit-id: edc9e9bc0c32a08b263c945b9296980b5242924b [formerly 41b1c1a9704f82518e44dab52adc482e02cbbf73] Former-commit-id: e157962777b0cf057628b6c127ee68537731a528
This commit is contained in:
parent
4dab157dda
commit
5bd2e5b2c4
|
@ -1,11 +1,15 @@
|
|||
import albumentations as A
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from utils import RandomPaste
|
||||
import wandb
|
||||
|
||||
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
|
||||
from .dataset import RealDataset
|
||||
|
||||
|
||||
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    Each sample in ``batch`` is an iterable of fields (e.g. ``(image, target)``).
    The result is a tuple holding one tuple per field position, so samples are
    grouped by field instead of being stacked into a single tensor.

    Args:
        batch: iterable of samples, each sample itself an iterable.

    Returns:
        A tuple of tuples; empty when ``batch`` is empty.
    """
    per_field = zip(*batch)
    return tuple(per_field)
|
||||
|
||||
|
||||
class Spheres(pl.LightningDataModule):
|
||||
|
@ -13,21 +17,22 @@ class Spheres(pl.LightningDataModule):
|
|||
super().__init__()
|
||||
|
||||
def train_dataloader(self):
|
||||
# transform = A.Compose(
|
||||
# [
|
||||
# A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||
# A.Flip(),
|
||||
# A.ColorJitter(),
|
||||
# RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
||||
# A.GaussianBlur(),
|
||||
# A.ISONoise(),
|
||||
# ],
|
||||
# )
|
||||
transforms = A.Compose(
|
||||
[
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
bbox_params=A.BboxParams(
|
||||
format="pascal_voc",
|
||||
min_area=0.0,
|
||||
min_visibility=0.0,
|
||||
label_fields=["labels"],
|
||||
),
|
||||
)
|
||||
|
||||
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||
|
||||
dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
|
||||
dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
||||
dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
|
||||
print(f"len(dataset)={len(dataset)}")
|
||||
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
|
@ -36,11 +41,12 @@ class Spheres(pl.LightningDataModule):
|
|||
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
# def val_dataloader(self):
|
||||
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
# dataset = Subset(dataset, list(range(len(dataset)))) # somhow this allows to better utilize the gpu
|
||||
# dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
|
||||
|
||||
# return DataLoader(
|
||||
# dataset,
|
||||
|
|
|
@ -1,147 +1,42 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class SyntheticDataset(Dataset):
    """Dataset of JPEG images, each paired with an all-zero mask.

    Images are discovered by recursively globbing ``*.jpg`` under ``image_dir``.
    The caller-supplied ``transform`` is applied to the image and its blank
    mask before both are converted to float tensors.
    """

    def __init__(self, image_dir, transform):
        """Collect image paths and remember the augmentation callable.

        Args:
            image_dir: directory searched recursively for ``*.jpg`` files.
            transform: callable invoked as ``transform(image=..., mask=...)``
                that returns a mapping with ``"image"`` and ``"mask"`` keys.
        """
        self.transform = transform
        self.images = [*Path(image_dir).glob("**/*.jpg")]

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, augment and tensorize the sample at ``index``.

        Returns:
            ``(image, mask)`` as float tensors.
        """
        # load the picture as an RGB uint8 array
        picture = Image.open(self.images[index]).convert("RGB")
        image = np.array(picture, dtype=np.uint8)

        # blank mask with the same height and width as the image
        height, width = image.shape[0], image.shape[1]
        mask = np.zeros((height, width), dtype=np.uint8)

        # user-provided augmentations act on image and mask together
        augmented = self.transform(image=image, mask=mask)

        # final conversion to tensors (image presumably scaled to [0, 1] by ToFloat)
        to_tensor = A.Compose(
            [
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )
        tensors = to_tensor(image=augmented["image"], mask=augmented["mask"])

        # force float dtype on both outputs
        return tensors["image"].float(), tensors["mask"].float()
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
    """Dataset of JPEG images with a ``MASK.PNG`` stored next to each image.

    Images are discovered by recursively globbing ``*.jpg`` under ``image_dir``;
    the mask for an image is the ``MASK.PNG`` file in that image's directory.
    """

    def __init__(self, image_dir):
        """Collect every ``*.jpg`` path found recursively under ``image_dir``."""
        self.images = [*Path(image_dir).glob("**/*.jpg")]

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image/mask pair at ``index`` and convert both to float tensors.

        Returns:
            ``(image, mask)`` as float tensors.
        """
        image_path = self.images[index]

        # load the picture as an RGB uint8 array
        image = np.array(Image.open(image_path).convert("RGB"), dtype=np.uint8)

        # the mask sits beside the image; integer-divide by 255 so 255 maps to 1
        mask_file = image_path.parent.joinpath("MASK.PNG")
        grayscale = Image.open(mask_file).convert("L")
        mask = np.array(grayscale, dtype=np.uint8) // 255

        # resize, then convert image & mask to tensors
        pipeline = A.Compose(
            [
                A.SmallestMaxSize(1024),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )
        converted = pipeline(image=image, mask=mask)

        # make sure image and mask are floats
        # TODO: move this into the pipeline (ToFloat is image-only)
        return converted["image"].float(), converted["mask"].float()
|
||||
|
||||
|
||||
class LabeledDataset2(Dataset):
    """Dataset laid out as ``<image_dir>/<index>/{image.jpg, MASK.PNG}``.

    The sample at position ``index`` is read from the sub-directory whose name
    is ``str(index)``.
    """

    def __init__(self, image_dir):
        """Remember the root directory as a ``Path``."""
        self.image_dir = Path(image_dir)

    def __len__(self):
        """Return the number of entries directly inside the root directory."""
        return sum(1 for _ in self.image_dir.iterdir())

    def __getitem__(self, index):
        """Load the image/mask pair stored under ``<image_dir>/<index>``.

        Returns:
            ``(image, mask)`` as float tensors.
        """
        sample_dir = self.image_dir / str(index)

        # load the picture as an RGB uint8 array
        rgb = Image.open(sample_dir / "image.jpg").convert("RGB")
        image = np.array(rgb, dtype=np.uint8)

        # load the mask as grayscale; integer-divide by 255 so 255 maps to 1
        grayscale = Image.open(sample_dir / "MASK.PNG").convert("L")
        mask = np.array(grayscale, dtype=np.uint8) // 255

        # convert image & mask to tensors
        pipeline = A.Compose(
            [
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )
        converted = pipeline(image=image, mask=mask)

        # make sure image and mask are floats
        # TODO: move this into the pipeline (ToFloat is image-only)
        return converted["image"].float(), converted["mask"].float()
|
||||
|
||||
|
||||
class LabeledDataset3(object):
|
||||
def __init__(self, root, transforms):
|
||||
class RealDataset(Dataset):
|
||||
def __init__(self, root, transforms=None):
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
|
||||
# load all image files, sorting them to ensure that they are aligned
|
||||
self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
|
||||
self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# create paths from ids
|
||||
img_path = os.path.join(self.root, "images", self.imgs[idx])
|
||||
image_path = os.path.join(self.root, "images", self.imgs[idx])
|
||||
mask_path = os.path.join(self.root, "masks", self.masks[idx])
|
||||
|
||||
# load image and mask
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
mask = Image.open(mask_path)
|
||||
|
||||
# convert mask to numpy array to apply operations
|
||||
# convert to numpy arrays
|
||||
image = np.array(image)
|
||||
mask = np.array(mask)
|
||||
|
||||
# get ids from mask
|
||||
obj_ids = np.unique(mask)
|
||||
obj_ids = obj_ids[1:] # first id is the background, so remove it
|
||||
|
||||
# split the color-encoded mask into a set of binary masks
|
||||
masks = mask == obj_ids[:, None, None]
|
||||
|
||||
# get bounding box coordinates for each mask
|
||||
# create bboxes from masks (pascal format)
|
||||
num_objs = len(obj_ids)
|
||||
bboxes = []
|
||||
for i in range(num_objs):
|
||||
|
@ -152,27 +47,46 @@ class LabeledDataset3(object):
|
|||
ymax = np.max(pos[0])
|
||||
bboxes.append([xmin, ymin, xmax, ymax])
|
||||
|
||||
# convert arrays to tensors
|
||||
# convert arrays to tensors, TODO: check what albumentations expects, to reduce the following lines
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
|
||||
labels = torch.ones((num_objs,), dtype=torch.int64) # there is only one class
|
||||
masks = torch.as_tensor(masks, dtype=torch.uint8)
|
||||
|
||||
# image_id = torch.tensor([idx])
|
||||
# area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
||||
# iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # suppose all instances are not crowd
|
||||
|
||||
target = {}
|
||||
target["boxes"] = bboxes
|
||||
target["labels"] = labels
|
||||
target["masks"] = masks
|
||||
# target["image_id"] = image_id
|
||||
# target["area"] = area
|
||||
# target["iscrowd"] = iscrowd
|
||||
labels = torch.ones((num_objs,), dtype=torch.int64) # suppose there is only one class (id=1)
|
||||
masks = [mask for mask in masks] # albumentations wants list of masks
|
||||
|
||||
if self.transforms is not None:
|
||||
img, target = self.transforms(img, target)
|
||||
# arrange transform data
|
||||
data = {
|
||||
"image": image,
|
||||
"labels": labels,
|
||||
"bboxes": bboxes,
|
||||
"masks": masks,
|
||||
}
|
||||
# apply transform
|
||||
augmented = self.transforms(**data)
|
||||
# get augmented image and bboxes
|
||||
image = augmented["image"]
|
||||
bboxes = augmented["bboxes"]
|
||||
labels = augmented["labels"]
|
||||
# get masks
|
||||
masks = augmented["masks"]
|
||||
|
||||
return img, target
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
|
||||
labels = torch.as_tensor(labels, dtype=torch.int64) # int64 required by torchvision maskrcnn
|
||||
masks = torch.stack(masks) # stack masks, wanted by maskrcnn from torchvision
|
||||
|
||||
area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
||||
image_id = torch.tensor([idx])
|
||||
iscrowd = torch.zeros((num_objs,), dtype=torch.int64) # suppose all instances are not crowd
|
||||
|
||||
target = {
|
||||
"boxes": bboxes,
|
||||
"labels": labels,
|
||||
"masks": masks,
|
||||
"area": area,
|
||||
"image_id": image_id,
|
||||
"iscrowd": iscrowd,
|
||||
}
|
||||
|
||||
return image, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
|
|
@ -3,16 +3,17 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
import wandb
|
||||
from torchvision.models.detection._utils import Matcher
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||
from torchvision.ops.boxes import box_iou
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
def get_model_instance_segmentation(num_classes):
|
||||
# load an instance segmentation model pre-trained on COCO
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
||||
|
||||
# get number of input features for the classifier
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
|
@ -42,82 +43,84 @@ class MRCNNModule(pl.LightningModule):
|
|||
# Network
|
||||
self.model = get_model_instance_segmentation(n_classes)
|
||||
|
||||
def forward(self, imgs):
|
||||
# Torchvision FasterRCNN returns the loss during training
|
||||
# and the boxes during eval
|
||||
self.model.eval()
|
||||
return self.model(imgs)
|
||||
# def forward(self, imgs):
|
||||
# # Torchvision FasterRCNN returns the loss during training
|
||||
# # and the boxes during eval
|
||||
# self.model.eval()
|
||||
# return self.model(imgs)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# unpack batch
|
||||
images, targets = batch
|
||||
|
||||
# enable train mode
|
||||
self.model.train()
|
||||
# self.model.train()
|
||||
|
||||
# fasterrcnn takes both images and targets for training
|
||||
loss_dict = self.model(images, targets)
|
||||
loss = sum(loss_dict.values())
|
||||
# self.log_dict(loss_dict)
|
||||
# self.log(loss)
|
||||
return {"loss": loss, "log": loss_dict}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# unpack batch
|
||||
images, targets = batch
|
||||
# def validation_step(self, batch, batch_idx):
|
||||
# # unpack batch
|
||||
# images, targets = batch
|
||||
|
||||
# enable eval mode
|
||||
self.detector.eval()
|
||||
# # enable eval mode
|
||||
# # self.detector.eval()
|
||||
|
||||
# make a prediction
|
||||
preds = self.detector(images)
|
||||
# # make a prediction
|
||||
# preds = self.model(images)
|
||||
|
||||
# compute validation loss
|
||||
self.val_loss = torch.mean(
|
||||
torch.stack(
|
||||
[
|
||||
self.accuracy(
|
||||
target,
|
||||
pred["boxes"],
|
||||
iou_threshold=0.5,
|
||||
)
|
||||
for target, pred in zip(targets, preds)
|
||||
],
|
||||
)
|
||||
)
|
||||
# # compute validation loss
|
||||
# self.val_loss = torch.mean(
|
||||
# torch.stack(
|
||||
# [
|
||||
# self.accuracy(
|
||||
# target,
|
||||
# pred["boxes"],
|
||||
# iou_threshold=0.5,
|
||||
# )
|
||||
# for target, pred in zip(targets, preds)
|
||||
# ],
|
||||
# )
|
||||
# )
|
||||
|
||||
return self.val_loss
|
||||
# return self.val_loss
|
||||
|
||||
def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
||||
"""
|
||||
The accuracy method is not the one used in the evaluator but very similar
|
||||
"""
|
||||
total_gt = len(src_boxes)
|
||||
total_pred = len(pred_boxes)
|
||||
if total_gt > 0 and total_pred > 0:
|
||||
# def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
||||
# """
|
||||
# The accuracy method is not the one used in the evaluator but very similar
|
||||
# """
|
||||
# total_gt = len(src_boxes)
|
||||
# total_pred = len(pred_boxes)
|
||||
# if total_gt > 0 and total_pred > 0:
|
||||
|
||||
# Define the matcher and distance matrix based on iou
|
||||
matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
||||
match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
||||
# # Define the matcher and distance matrix based on iou
|
||||
# matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
||||
# match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
||||
|
||||
results = matcher(match_quality_matrix)
|
||||
# results = matcher(match_quality_matrix)
|
||||
|
||||
true_positive = torch.count_nonzero(results.unique() != -1)
|
||||
matched_elements = results[results > -1]
|
||||
# true_positive = torch.count_nonzero(results.unique() != -1)
|
||||
# matched_elements = results[results > -1]
|
||||
|
||||
# in Matcher, a pred element can be matched only twice
|
||||
false_positive = torch.count_nonzero(results == -1) + (
|
||||
len(matched_elements) - len(matched_elements.unique())
|
||||
)
|
||||
false_negative = total_gt - true_positive
|
||||
# # in Matcher, a pred element can be matched only twice
|
||||
# false_positive = torch.count_nonzero(results == -1) + (
|
||||
# len(matched_elements) - len(matched_elements.unique())
|
||||
# )
|
||||
# false_negative = total_gt - true_positive
|
||||
|
||||
return true_positive / (true_positive + false_positive + false_negative)
|
||||
# return true_positive / (true_positive + false_positive + false_negative)
|
||||
|
||||
elif total_gt == 0:
|
||||
if total_pred > 0:
|
||||
return torch.tensor(0.0).cuda()
|
||||
else:
|
||||
return torch.tensor(1.0).cuda()
|
||||
elif total_gt > 0 and total_pred == 0:
|
||||
return torch.tensor(0.0).cuda()
|
||||
# elif total_gt == 0:
|
||||
# if total_pred > 0:
|
||||
# return torch.tensor(0.0).cuda()
|
||||
# else:
|
||||
# return torch.tensor(1.0).cuda()
|
||||
# elif total_gt > 0 and total_pred == 0:
|
||||
# return torch.tensor(0.0).cuda()
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(
|
||||
|
@ -125,20 +128,22 @@ class MRCNNModule(pl.LightningModule):
|
|||
lr=wandb.config.LEARNING_RATE,
|
||||
momentum=wandb.config.MOMENTUM,
|
||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||
nesterov=wandb.config.NESTEROV,
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer,
|
||||
T_0=3,
|
||||
T_mult=1,
|
||||
lr=wandb.config.LEARNING_RATE_MIN,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val_accuracy",
|
||||
},
|
||||
}
|
||||
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
# optimizer,
|
||||
# T_0=3,
|
||||
# T_mult=1,
|
||||
# lr=wandb.config.LEARNING_RATE_MIN,
|
||||
# verbose=True,
|
||||
# )
|
||||
|
||||
# return {
|
||||
# "optimizer": optimizer,
|
||||
# "lr_scheduler": {
|
||||
# "scheduler": scheduler,
|
||||
# "monitor": "val_accuracy",
|
||||
# },
|
||||
# }
|
||||
|
||||
return optimizer
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from pytorch_lightning.callbacks import RichProgressBar
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
import wandb
|
||||
from data import Spheres
|
||||
from mrcnn import MRCNNModule
|
||||
from unet import UNetModule
|
||||
|
@ -58,10 +58,10 @@ if __name__ == "__main__":
|
|||
precision=16,
|
||||
logger=logger,
|
||||
log_every_n_steps=1,
|
||||
val_check_interval=100,
|
||||
# val_check_interval=100,
|
||||
callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
|
||||
# profiler="simple",
|
||||
# num_sanity_val_steps=0,
|
||||
num_sanity_val_steps=0,
|
||||
)
|
||||
|
||||
# actually train the model
|
||||
|
|
10
wandb.yaml
10
wandb.yaml
|
@ -17,11 +17,11 @@ AMP:
|
|||
PIN_MEMORY:
|
||||
value: True
|
||||
BENCHMARK:
|
||||
value: True
|
||||
value: False
|
||||
DEVICE:
|
||||
value: gpu
|
||||
WORKERS:
|
||||
value: 8
|
||||
value: 1
|
||||
|
||||
IMG_SIZE:
|
||||
value: 512
|
||||
|
@ -31,11 +31,11 @@ SPHERES:
|
|||
EPOCHS:
|
||||
value: 3
|
||||
TRAIN_BATCH_SIZE:
|
||||
value: 128 # 100
|
||||
value: 2 # 100
|
||||
VAL_BATCH_SIZE:
|
||||
value: 8 # 10
|
||||
value: 0 # 10
|
||||
PREFETCH_FACTOR:
|
||||
value: 2
|
||||
value: 1
|
||||
|
||||
LEARNING_RATE:
|
||||
value: 1.0e-4
|
||||
|
|
Loading…
Reference in a new issue