2020-07-30 01:50:33 +00:00
|
|
|
import torch
|
|
|
|
from unet import UNet as _UNet
|
|
|
|
|
2022-02-19 04:01:54 +00:00
|
|
|
def unet_carvana(pretrained=False, scale=0.5):
    """
    UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
    Set the scale to 0.5 (50%) when predicting.

    Builds ``UNet(n_channels=3, n_classes=2, bilinear=False)`` and, if requested,
    loads the released Carvana weights that match ``scale``.

    Args:
        pretrained: If True, download the checkpoint for ``scale`` via
            ``torch.hub.load_state_dict_from_url`` and load it into the model.
        scale: Image scale the checkpoint was trained at. Only 0.5 and 1.0
            have released weights. Only consulted when ``pretrained`` is True.

    Returns:
        The constructed UNet module, with pretrained weights if requested.

    Raises:
        RuntimeError: If ``pretrained`` is True and ``scale`` is neither 0.5 nor 1.0.
    """
    net = _UNet(n_channels=3, n_classes=2, bilinear=False)

    if pretrained:
        if scale == 0.5:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
        elif scale == 1.0:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
        else:
            raise RuntimeError('Only 0.5 and 1.0 scales are available')

        # Deserialize onto the CPU: without map_location, tensors saved from a GPU
        # are restored onto that GPU and fail on CPU-only machines. load_state_dict
        # then copies the values into the model's own parameters, wherever they live.
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location='cpu', progress=True
        )
        net.load_state_dict(state_dict)

    return net
|
|
|
|
|