mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Torch hub
Former-commit-id: 4ce26d9d02760fead8fa34445f0e01acf9e4efef
This commit is contained in:
parent
8780e424b4
commit
c3c2675369
15
hubconf.py
Normal file
15
hubconf.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
import torch
|
||||||
|
from unet import UNet as _UNet
|
||||||
|
|
||||||
|
def unet_carvana(pretrained=False):
|
||||||
|
"""
|
||||||
|
UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
|
||||||
|
Set the scale to 1 (100%) when predicting.
|
||||||
|
"""
|
||||||
|
net = _UNet(n_channels=3, n_classes=1, bilinear=True)
|
||||||
|
if pretrained:
|
||||||
|
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth'
|
||||||
|
net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
|
||||||
|
|
||||||
|
return net
|
||||||
|
|
Loading…
Reference in a new issue