Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-09 23:12:05 +00:00)
Fix torchhub
Former-commit-id: 9a6ee84c50390306e24bb369d1c24d2fdd9d7e60
This commit is contained in:
parent f55693b6a4
commit 2ca43802cc

README.md

@@ -155,16 +155,16 @@ You can specify which model file to use with `--model MODEL.pth`.
 The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.
 
 When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it
-by setting the `WANDB_API_KEY` environment variable.
+by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days.
 
 
 ## Pretrained model
-A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v2.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
+A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
 
 ```python
-net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True)
+net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
 ```
-The training was done with a 50% scale and bilinear upsampling.
+Available scales are 0.5 and 1.0.
 
 ## Data
 The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
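
For context on the updated README snippet, here is a minimal, self-contained sketch of loading the released weights through torch.hub and running one forward pass; the dummy input size and the argmax step are illustrative assumptions, not something this commit specifies.

```python
import torch

# Load the Carvana-pretrained UNet via torch.hub; `scale` must be 0.5 or 1.0,
# matching one of the v3.0 checkpoints selected in hubconf.py below.
net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
net.eval()

# Illustrative dummy batch: one RGB image with values in [0, 1].
# Height and width stay divisible by 16 so the four down/up-sampling stages line up.
x = torch.rand(1, 3, 480, 640)

with torch.no_grad():
    logits = net(x)              # shape (1, 2, 480, 640): two class scores per pixel
    mask = logits.argmax(dim=1)  # (1, 480, 640) hard segmentation mask
```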

hubconf.py (12 changed lines)

@@ -1,14 +1,20 @@
 import torch
 from unet import UNet as _UNet
 
-def unet_carvana(pretrained=False):
+def unet_carvana(pretrained=False, scale=0.5):
     """
     UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
     Set the scale to 0.5 (50%) when predicting.
     """
-    net = _UNet(n_channels=3, n_classes=2, bilinear=True)
+    net = _UNet(n_channels=3, n_classes=2, bilinear=False)
     if pretrained:
-        checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v2.0/unet_carvana_scale0.5_epoch1.pth'
+        if scale == 0.5:
+            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
+        elif scale == 1.0:
+            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
+        else:
+            raise RuntimeError('Only 0.5 and 1.0 scales are available')
+
         net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
 
     return net
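
As a usage note on the updated entry point: when called directly (assuming the repository root is on the Python path), `scale` now selects which released checkpoint is downloaded, and any other value is rejected before a download is attempted. A short sketch:

```python
from hubconf import unet_carvana

net_half = unet_carvana(pretrained=True, scale=0.5)  # v3.0 checkpoint trained at 50% scale
net_full = unet_carvana(pretrained=True, scale=1.0)  # v3.0 checkpoint trained at full scale

try:
    unet_carvana(pretrained=True, scale=0.25)        # unsupported scale
except RuntimeError as err:
    print(err)                                       # "Only 0.5 and 1.0 scales are available"
```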

predict.py

@@ -57,6 +57,7 @@ def get_args():
                         help='Minimum probability value to consider a mask pixel white')
     parser.add_argument('--scale', '-s', type=float, default=0.5,
                         help='Scale factor for the input images')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
 
     return parser.parse_args()
 
@@ -81,7 +82,7 @@ if __name__ == '__main__':
     in_files = args.input
     out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=2)
+    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     logging.info(f'Loading model {args.model}')
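
The practical effect of the new `--bilinear` flag in predict.py is that it has to match the architecture the checkpoint was trained with: bilinear and transposed-convolution decoders have different parameters, so a mismatch makes `load_state_dict()` fail. Below is a hedged sketch of that wiring; the `--model` flag is taken from the README and the `torch.load` call is an assumption, since neither appears in this diff.

```python
import argparse
import torch
from unet import UNet

# Minimal stand-in for predict.py's get_args(); only the flags relevant here.
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', default='MODEL.pth')
parser.add_argument('--scale', '-s', type=float, default=0.5)
parser.add_argument('--bilinear', action='store_true', default=False)
args = parser.parse_args()

# The flag must mirror how the checkpoint was trained: a bilinear=True decoder has
# no transposed-convolution weights, so loading it into a bilinear=False network
# (or vice versa) raises missing/unexpected key errors.
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
net.load_state_dict(torch.load(args.model, map_location=device))  # assumed loading step
```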

train.py (7 changed lines)

@@ -25,7 +25,7 @@ def train_net(net,
               device,
               epochs: int = 5,
               batch_size: int = 1,
-              learning_rate: float = 0.001,
+              learning_rate: float = 1e-5,
               val_percent: float = 0.1,
               save_checkpoint: bool = True,
               img_scale: float = 0.5,
@@ -147,13 +147,14 @@ def get_args():
     parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
     parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
     parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
-    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
+    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                         help='Learning rate', dest='lr')
     parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
     parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
     parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                         help='Percent of the data that is used as validation (0-100)')
     parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
 
     return parser.parse_args()
 
@@ -168,7 +169,7 @@ if __name__ == '__main__':
     # Change here to adapt to your data
     # n_channels=3 for RGB images
     # n_classes is the number of probabilities you want to get per pixel
-    net = UNet(n_channels=3, n_classes=2, bilinear=True)
+    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
 
     logging.info(f'Network:\n'
                  f'\t{net.n_channels} input channels\n'
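
To see how the new train.py defaults fit together, here is a condensed sketch of the call-site wiring using only the arguments visible in this diff; treat it as illustrative rather than the file's exact `__main__` block, and note that the percent-to-fraction conversion for `--validation` is an assumption based on its help text.

```python
import torch
from unet import UNet

args = get_args()  # as defined in train.py above
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The decoder type is now chosen on the command line instead of being hard-coded.
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
net.to(device)

train_net(net=net,
          device=device,
          epochs=args.epochs,
          batch_size=args.batch_size,
          learning_rate=args.lr,       # default lowered from 0.001 to 1e-5
          val_percent=args.val / 100,  # --validation is given in percent (0-100)
          img_scale=args.scale)
```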

unet/unet_model.py

@@ -4,7 +4,7 @@ from .unet_parts import *
 
 
 class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, bilinear=True):
+    def __init__(self, n_channels, n_classes, bilinear=False):
         super(UNet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
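
Finally, what flipping the `bilinear` default to False means in practice: in this kind of UNet the flag typically switches the decoder's upscaling blocks between a parameter-free bilinear `nn.Upsample` and a learned `nn.ConvTranspose2d`. The block below is a generic illustration of that pattern, not the repository's actual `unet_parts` code.

```python
import torch.nn as nn

class UpBlock(nn.Module):
    """Illustrative UNet upscaling step, not the repo's Up class."""

    def __init__(self, in_channels, bilinear=False):
        super().__init__()
        if bilinear:
            # Parameter-free upsampling: nothing to learn, smaller checkpoints.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Learned upsampling: consistent with the new default and the v3.0 weights.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                         kernel_size=2, stride=2)

    def forward(self, x):
        return self.up(x)
```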