import torch

from unet import UNet as _UNet


def unet_carvana(pretrained=False, scale=0.5):
    """Build a UNet for the Carvana dataset, optionally with trained weights.

    The network is always created with ``n_channels=3``, ``n_classes=2`` and
    ``bilinear=False``. Set the scale to 0.5 (50%) when predicting.

    Args:
        pretrained: When truthy, download a checkpoint through
            ``torch.hub.load_state_dict_from_url`` and load it into the
            network before returning it.
        scale: Selects which checkpoint to fetch; only consulted when
            ``pretrained`` is truthy. Must compare equal to 0.5 or 1.0.

    Returns:
        The constructed ``_UNet`` instance.

    Raises:
        RuntimeError: If ``pretrained`` is truthy and ``scale`` is neither
            0.5 nor 1.0.
    """
    model = _UNet(n_channels=3, n_classes=2, bilinear=False)

    # Nothing to download: hand back the freshly initialised network.
    if not pretrained:
        return model

    # NOTE(review): both checkpoint URLs below are empty strings in this
    # file, so the download cannot succeed as written -- confirm the intended
    # checkpoint locations.
    if scale == 0.5:
        checkpoint_url = ''
    elif scale == 1.0:
        checkpoint_url = ''
    else:
        raise RuntimeError('Only 0.5 and 1.0 scales are available')

    weights = torch.hub.load_state_dict_from_url(checkpoint_url, progress=True)
    model.load_state_dict(weights)
    return model