diff --git a/predict.py b/predict.py index 16b7b1c..cb3c1c2 100755 --- a/predict.py +++ b/predict.py @@ -23,7 +23,7 @@ def predict_img(net, net.eval() ds = BasicDataset('', '', scale=scale_factor) - img = ds.preprocess(full_img) + img = torch.from_numpy(ds.preprocess(full_img)) img = img.unsqueeze(0) img = img.to(device=device, dtype=torch.float32)