feat: it's training, but for how long?
Former-commit-id: edc9e9bc0c32a08b263c945b9296980b5242924b [formerly 41b1c1a9704f82518e44dab52adc482e02cbbf73] Former-commit-id: e157962777b0cf057628b6c127ee68537731a528
parent 4dab157dda
commit 5bd2e5b2c4
@@ -1,11 +1,15 @@
 import albumentations as A
 import pytorch_lightning as pl
-import wandb
+from albumentations.pytorch import ToTensorV2
 from torch.utils.data import DataLoader, Subset
 
-from utils import RandomPaste
+import wandb
 
-from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
+from .dataset import RealDataset
+
+
+def collate_fn(batch):
+    return tuple(zip(*batch))
 
 
 class Spheres(pl.LightningDataModule):
@@ -13,21 +17,22 @@ class Spheres(pl.LightningDataModule):
         super().__init__()
 
     def train_dataloader(self):
-        # transform = A.Compose(
-        #     [
-        #         A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-        #         A.Flip(),
-        #         A.ColorJitter(),
-        #         RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
-        #         A.GaussianBlur(),
-        #         A.ISONoise(),
-        #     ],
-        # )
+        transforms = A.Compose(
+            [
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+            bbox_params=A.BboxParams(
+                format="pascal_voc",
+                min_area=0.0,
+                min_visibility=0.0,
+                label_fields=["labels"],
+            ),
+        )
 
-        # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        dataset = LabeledDataset2(image_dir="/media/disk1/lfainsin/TEST_tmp_mrcnn/")
-        dataset = Subset(dataset, list(range(len(dataset))))  # somhow this allows to better utilize the gpu
+        dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
+        print(f"len(dataset)={len(dataset)}")
+        dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu
 
         return DataLoader(
             dataset,
@@ -36,11 +41,12 @@ class Spheres(pl.LightningDataModule):
             batch_size=wandb.config.TRAIN_BATCH_SIZE,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
+            collate_fn=collate_fn,
         )
 
     # def val_dataloader(self):
    #     dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-    #     dataset = Subset(dataset, list(range(len(dataset))))  # somhow this allows to better utilize the gpu
+    #     dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu
 
    #     return DataLoader(
    #         dataset,
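Note on the new collate_fn: detection batches cannot be stacked into one tensor because each image carries a different number of boxes and masks, so the DataLoader needs a collate function that simply regroups the pairs. A minimal illustration of what tuple(zip(*batch)) does (dummy values, not code from this commit):

batch = [("img0", {"boxes": 2}), ("img1", {"boxes": 5})]
images, targets = tuple(zip(*batch))
assert images == ("img0", "img1")
assert targets == ({"boxes": 2}, {"boxes": 5})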
@@ -1,147 +1,42 @@
 import os
 from pathlib import Path
 
-import albumentations as A
 import numpy as np
 import torch
-from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from torch.utils.data import Dataset
 
 
-class SyntheticDataset(Dataset):
-    def __init__(self, image_dir, transform):
-        self.images = list(Path(image_dir).glob("**/*.jpg"))
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, index):
-        # open and convert image
-        image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
-
-        # create empty mask of same size
-        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
-        # augment image and mask
-        augmentations = self.transform(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset(Dataset):
-    def __init__(self, image_dir):
-        self.images = list(Path(image_dir).glob("**/*.jpg"))
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, index):
-        # open and convert image
-        image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
-
-        # open and convert mask
-        mask_path = self.images[index].parent.joinpath("MASK.PNG")
-        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.SmallestMaxSize(1024),
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats, TODO: move into post_process, ToFloat image only
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset2(Dataset):
-    def __init__(self, image_dir):
-        self.image_dir = Path(image_dir)
-
-    def __len__(self):
-        return len(list(self.image_dir.iterdir()))
-
-    def __getitem__(self, index):
-        path = self.image_dir / str(index)
-
-        # open and convert image
-        image = np.array(Image.open(path / "image.jpg").convert("RGB"), dtype=np.uint8)
-
-        # open and convert mask
-        mask = np.array(Image.open(path / "MASK.PNG").convert("L"), dtype=np.uint8) // 255
-
-        # convert image & mask to Tensor float in [0, 1]
-        post_process = A.Compose(
-            [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
-            ],
-        )
-        augmentations = post_process(image=image, mask=mask)
-        image = augmentations["image"]
-        mask = augmentations["mask"]
-
-        # make sure image and mask are floats, TODO: move into post_process, ToFloat image only
-        image = image.float()
-        mask = mask.float()
-
-        return image, mask
-
-
-class LabeledDataset3(object):
-    def __init__(self, root, transforms):
+class RealDataset(Dataset):
+    def __init__(self, root, transforms=None):
         self.root = root
         self.transforms = transforms
 
         # load all image files, sorting them to ensure that they are aligned
         self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
         self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))
 
     def __getitem__(self, idx):
         # create paths from ids
-        img_path = os.path.join(self.root, "images", self.imgs[idx])
+        image_path = os.path.join(self.root, "images", self.imgs[idx])
         mask_path = os.path.join(self.root, "masks", self.masks[idx])
 
         # load image and mask
-        img = Image.open(img_path).convert("RGB")
+        image = Image.open(image_path).convert("RGB")
         mask = Image.open(mask_path)
 
-        # convert mask to numpy array to apply operations
+        # convert to numpy arrays
+        image = np.array(image)
         mask = np.array(mask)
 
+        # get ids from mask
         obj_ids = np.unique(mask)
         obj_ids = obj_ids[1:]  # first id is the background, so remove it
 
         # split the color-encoded mask into a set of binary masks
         masks = mask == obj_ids[:, None, None]
 
-        # get bounding box coordinates for each mask
+        # create bboxes from masks (pascal format)
         num_objs = len(obj_ids)
         bboxes = []
         for i in range(num_objs):
@@ -152,27 +47,46 @@ class LabeledDataset3(object):
             ymax = np.max(pos[0])
             bboxes.append([xmin, ymin, xmax, ymax])
 
-        # convert arrays to tensors
+        # convert arrays to tensors, TODO: check what albumentations wants, to reduce follow lines
         bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
-        labels = torch.ones((num_objs,), dtype=torch.int64)  # there is only one class
-        masks = torch.as_tensor(masks, dtype=torch.uint8)
+        labels = torch.ones((num_objs,), dtype=torch.int64)  # suppose there is only one class (id=1)
+        masks = [mask for mask in masks]  # albumentations wants list of masks
 
-        # image_id = torch.tensor([idx])
-        # area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-        # iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
-
-        target = {}
-        target["boxes"] = bboxes
-        target["labels"] = labels
-        target["masks"] = masks
-        # target["image_id"] = image_id
-        # target["area"] = area
-        # target["iscrowd"] = iscrowd
-
         if self.transforms is not None:
-            img, target = self.transforms(img, target)
+            # arrange transform data
+            data = {
+                "image": image,
+                "labels": labels,
+                "bboxes": bboxes,
+                "masks": masks,
+            }
+            # apply transform
+            augmented = self.transforms(**data)
+            # get augmented image and bboxes
+            image = augmented["image"]
+            bboxes = augmented["bboxes"]
+            labels = augmented["labels"]
+            # get masks
+            masks = augmented["masks"]
 
-        return img, target
+        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
+        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
+        masks = torch.stack(masks)  # stack masks, wanted by maskrcnn from torchvision
+
+        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        image_id = torch.tensor([idx])
+        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # suppose all instances are not crowd
+
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "masks": masks,
+            "area": area,
+            "image_id": image_id,
+            "iscrowd": iscrowd,
+        }
+
+        return image, target
 
     def __len__(self):
         return len(self.imgs)
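Note on the rewritten __getitem__: it hands albumentations a dict whose keys match the bbox_params declared in the datamodule (format="pascal_voc", label_fields=["labels"]) plus a list of binary masks. A self-contained sketch of that round trip on dummy data (shapes and values are made up; only the calling convention matters):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

transforms = A.Compose(
    [A.ToFloat(max_value=255), ToTensorV2()],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

image = np.zeros((64, 64, 3), dtype=np.uint8)   # dummy RGB image
masks = [np.zeros((64, 64), dtype=np.uint8)]    # one binary mask per object
bboxes = [[10, 10, 30, 30]]                     # pascal_voc: xmin, ymin, xmax, ymax
labels = [1]

augmented = transforms(image=image, masks=masks, bboxes=bboxes, labels=labels)
image = augmented["image"]    # CxHxW float tensor in [0, 1]
masks = augmented["masks"]    # per-object masks, torch.stack-able as in the dataset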
@@ -3,16 +3,17 @@
 import pytorch_lightning as pl
 import torch
 import torchvision
-import wandb
 from torchvision.models.detection._utils import Matcher
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
 from torchvision.ops.boxes import box_iou
 
+import wandb
+
 
 def get_model_instance_segmentation(num_classes):
     # load an instance segmentation model pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
 
     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -42,82 +43,84 @@ class MRCNNModule(pl.LightningModule):
         # Network
         self.model = get_model_instance_segmentation(n_classes)
 
-    def forward(self, imgs):
-        # Torchvision FasterRCNN returns the loss during training
-        # and the boxes during eval
-        self.model.eval()
-        return self.model(imgs)
+    # def forward(self, imgs):
+    #     # Torchvision FasterRCNN returns the loss during training
+    #     # and the boxes during eval
+    #     self.model.eval()
+    #     return self.model(imgs)
 
     def training_step(self, batch, batch_idx):
         # unpack batch
         images, targets = batch
 
         # enable train mode
-        self.model.train()
+        # self.model.train()
 
         # fasterrcnn takes both images and targets for training
         loss_dict = self.model(images, targets)
         loss = sum(loss_dict.values())
+        # self.log_dict(loss_dict)
+        # self.log(loss)
         return {"loss": loss, "log": loss_dict}
 
-    def validation_step(self, batch, batch_idx):
-        # unpack batch
-        images, targets = batch
+    # def validation_step(self, batch, batch_idx):
+    #     # unpack batch
+    #     images, targets = batch
 
-        # enable eval mode
-        self.detector.eval()
+    #     # enable eval mode
+    #     # self.detector.eval()
 
-        # make a prediction
-        preds = self.detector(images)
+    #     # make a prediction
+    #     preds = self.model(images)
 
-        # compute validation loss
-        self.val_loss = torch.mean(
-            torch.stack(
-                [
-                    self.accuracy(
-                        target,
-                        pred["boxes"],
-                        iou_threshold=0.5,
-                    )
-                    for target, pred in zip(targets, preds)
-                ],
-            )
-        )
+    #     # compute validation loss
+    #     self.val_loss = torch.mean(
+    #         torch.stack(
+    #             [
+    #                 self.accuracy(
+    #                     target,
+    #                     pred["boxes"],
+    #                     iou_threshold=0.5,
+    #                 )
+    #                 for target, pred in zip(targets, preds)
+    #             ],
+    #         )
+    #     )
 
-        return self.val_loss
+    #     return self.val_loss
 
-    def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
-        """
-        The accuracy method is not the one used in the evaluator but very similar
-        """
-        total_gt = len(src_boxes)
-        total_pred = len(pred_boxes)
-        if total_gt > 0 and total_pred > 0:
+    # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
+    #     """
+    #     The accuracy method is not the one used in the evaluator but very similar
+    #     """
+    #     total_gt = len(src_boxes)
+    #     total_pred = len(pred_boxes)
+    #     if total_gt > 0 and total_pred > 0:
 
-            # Define the matcher and distance matrix based on iou
-            matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
-            match_quality_matrix = box_iou(src_boxes, pred_boxes)
+    #         # Define the matcher and distance matrix based on iou
+    #         matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
+    #         match_quality_matrix = box_iou(src_boxes, pred_boxes)
 
-            results = matcher(match_quality_matrix)
+    #         results = matcher(match_quality_matrix)
 
-            true_positive = torch.count_nonzero(results.unique() != -1)
-            matched_elements = results[results > -1]
+    #         true_positive = torch.count_nonzero(results.unique() != -1)
+    #         matched_elements = results[results > -1]
 
-            # in Matcher, a pred element can be matched only twice
-            false_positive = torch.count_nonzero(results == -1) + (
-                len(matched_elements) - len(matched_elements.unique())
-            )
-            false_negative = total_gt - true_positive
+    #         # in Matcher, a pred element can be matched only twice
+    #         false_positive = torch.count_nonzero(results == -1) + (
+    #             len(matched_elements) - len(matched_elements.unique())
+    #         )
+    #         false_negative = total_gt - true_positive
 
-            return true_positive / (true_positive + false_positive + false_negative)
+    #         return true_positive / (true_positive + false_positive + false_negative)
 
-        elif total_gt == 0:
-            if total_pred > 0:
-                return torch.tensor(0.0).cuda()
-            else:
-                return torch.tensor(1.0).cuda()
-        elif total_gt > 0 and total_pred == 0:
-            return torch.tensor(0.0).cuda()
+    #     elif total_gt == 0:
+    #         if total_pred > 0:
+    #             return torch.tensor(0.0).cuda()
+    #         else:
+    #             return torch.tensor(1.0).cuda()
+    #     elif total_gt > 0 and total_pred == 0:
+    #         return torch.tensor(0.0).cuda()
 
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(
@@ -125,20 +128,22 @@ class MRCNNModule(pl.LightningModule):
             lr=wandb.config.LEARNING_RATE,
             momentum=wandb.config.MOMENTUM,
             weight_decay=wandb.config.WEIGHT_DECAY,
-            nesterov=wandb.config.NESTEROV,
-        )
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer,
-            T_0=3,
-            T_mult=1,
-            lr=wandb.config.LEARNING_RATE_MIN,
-            verbose=True,
         )
 
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": scheduler,
-                "monitor": "val_accuracy",
-            },
-        }
+        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+        #     optimizer,
+        #     T_0=3,
+        #     T_mult=1,
+        #     lr=wandb.config.LEARNING_RATE_MIN,
+        #     verbose=True,
+        # )
+
+        # return {
+        #     "optimizer": optimizer,
+        #     "lr_scheduler": {
+        #         "scheduler": scheduler,
+        #         "monitor": "val_accuracy",
+        #     },
+        # }
+
+        return optimizer
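For context on why forward and validation_step end up commented out: torchvision detection models switch behavior with the module mode, returning a dict of losses when called as model(images, targets) in train mode and a list of per-image predictions in eval mode, which is what training_step relies on. A standalone sketch of the two calling conventions (dummy tensors, not repo code):

import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

images = [torch.rand(3, 256, 256)]  # list of CxHxW images, not a stacked batch
targets = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "labels": torch.tensor([1], dtype=torch.int64),
    "masks": torch.zeros((1, 256, 256), dtype=torch.uint8),
}]

model.train()
loss_dict = model(images, targets)  # {"loss_classifier": ..., "loss_mask": ..., ...}
loss = sum(loss_dict.values())      # the scalar returned by training_step

model.eval()
with torch.no_grad():
    preds = model(images)           # list of {"boxes", "labels", "scores", "masks"} dicts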
@@ -1,10 +1,10 @@
 import logging
 
 import pytorch_lightning as pl
-import wandb
 from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
+import wandb
 from data import Spheres
 from mrcnn import MRCNNModule
 from unet import UNetModule
@@ -58,10 +58,10 @@ if __name__ == "__main__":
         precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=100,
+        # val_check_interval=100,
         callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
         # profiler="simple",
-        # num_sanity_val_steps=0,
+        num_sanity_val_steps=0,
     )
 
     # actually train the model
wandb.yaml
@@ -17,11 +17,11 @@ AMP:
 PIN_MEMORY:
   value: True
 BENCHMARK:
-  value: True
+  value: False
 DEVICE:
   value: gpu
 WORKERS:
-  value: 8
+  value: 1
 
 IMG_SIZE:
   value: 512
@@ -31,11 +31,11 @@ SPHERES:
 EPOCHS:
   value: 3
 TRAIN_BATCH_SIZE:
-  value: 128 # 100
+  value: 2 # 100
 VAL_BATCH_SIZE:
-  value: 8 # 10
+  value: 0 # 10
 PREFETCH_FACTOR:
-  value: 2
+  value: 1
 
 LEARNING_RATE:
   value: 1.0e-4
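The training script reads every hyperparameter through wandb.config, so these values presumably reach it via wandb.init; a sketch of that wiring (the project name is hypothetical, and passing a YAML path as config is one of the forms wandb.init accepts):

import wandb

wandb.init(project="spheres", config="wandb.yaml")  # hypothetical project name
print(wandb.config.TRAIN_BATCH_SIZE)  # -> 2 after this commit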