Update torchhub
Former-commit-id: 7a89332de99777c24419d8ed4154bb1963ddc8f5
This commit is contained in:
parent
fdb686fc43
commit
1c450fa978
|
@ -4,11 +4,11 @@ from unet import UNet as _UNet
|
|||
def unet_carvana(pretrained=False):
|
||||
"""
|
||||
UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
|
||||
Set the scale to 1 (100%) when predicting.
|
||||
Set the scale to 0.5 (50%) when predicting.
|
||||
"""
|
||||
net = _UNet(n_channels=3, n_classes=1, bilinear=True)
|
||||
net = _UNet(n_channels=3, n_classes=2, bilinear=True)
|
||||
if pretrained:
|
||||
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth'
|
||||
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v2.0/unet_carvana_scale0.5_epoch1.pth'
|
||||
net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
|
||||
|
||||
return net
|
||||
|
|
Loading…
Reference in a new issue