diff --git a/utils/dataset.py b/utils/dataset.py index bfba1a6..3afeca8 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -59,4 +59,7 @@ class BasicDataset(Dataset): img = self.preprocess(img, self.scale) mask = self.preprocess(mask, self.scale) - return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} + return { + 'image': torch.from_numpy(img).type(torch.FloatTensor), + 'mask': torch.from_numpy(mask).type(torch.FloatTensor) + }