Conversion to FloatTensor
Former-commit-id: cc44eebf19f5d98fbd1ca800608ae0b38d998b54
This commit is contained in:
parent
aa9defe6a5
commit
416e076dbc
|
@ -59,4 +59,7 @@ class BasicDataset(Dataset):
|
|||
img = self.preprocess(img, self.scale)
|
||||
mask = self.preprocess(mask, self.scale)
|
||||
|
||||
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
|
||||
return {
|
||||
'image': torch.from_numpy(img).type(torch.FloatTensor),
|
||||
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue