From c3c26753691252b2e59980a5c91ab8dfbf0cb4b8 Mon Sep 17 00:00:00 2001 From: milesial Date: Wed, 29 Jul 2020 18:50:33 -0700 Subject: [PATCH] Torch hub Former-commit-id: 4ce26d9d02760fead8fa34445f0e01acf9e4efef --- hubconf.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 hubconf.py diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..db0591a --- /dev/null +++ b/hubconf.py @@ -0,0 +1,15 @@ +import torch +from unet import UNet as _UNet + +def unet_carvana(pretrained=False): + """ + UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ). + Set the scale to 1 (100%) when predicting. + """ + net = _UNet(n_channels=3, n_classes=1, bilinear=True) + if pretrained: + checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth' + net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) + + return net +