Conversion to FloatTensor

Former-commit-id: cc44eebf19f5d98fbd1ca800608ae0b38d998b54
This commit is contained in:
Louis Lac 2020-06-05 19:00:04 +02:00
parent aa9defe6a5
commit 416e076dbc

View file

@ -59,4 +59,7 @@ class BasicDataset(Dataset):
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
return {
'image': torch.from_numpy(img).type(torch.FloatTensor),
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
}