From 1c450fa978f16424ff75c52df4da1e144e854283 Mon Sep 17 00:00:00 2001 From: milesial Date: Thu, 19 Aug 2021 11:14:08 +0200 Subject: [PATCH] Update torchhub Former-commit-id: 7a89332de99777c24419d8ed4154bb1963ddc8f5 --- hubconf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hubconf.py b/hubconf.py index db0591a..fa6f7a0 100644 --- a/hubconf.py +++ b/hubconf.py @@ -4,11 +4,11 @@ from unet import UNet as _UNet def unet_carvana(pretrained=False): """ UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ). - Set the scale to 1 (100%) when predicting. + Set the scale to 0.5 (50%) when predicting. """ - net = _UNet(n_channels=3, n_classes=1, bilinear=True) + net = _UNet(n_channels=3, n_classes=2, bilinear=True) if pretrained: - checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth' + checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v2.0/unet_carvana_scale0.5_epoch1.pth' net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) return net