From 16baefaa745405a511b317cf06a6714039b2d08c Mon Sep 17 00:00:00 2001 From: Xu Ma Date: Tue, 15 Feb 2022 21:32:53 -0500 Subject: [PATCH] automatically download dataset --- classification_ScanObjectNN/ScanObjectNN.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/classification_ScanObjectNN/ScanObjectNN.py b/classification_ScanObjectNN/ScanObjectNN.py index 66e11e7..b479141 100644 --- a/classification_ScanObjectNN/ScanObjectNN.py +++ b/classification_ScanObjectNN/ScanObjectNN.py @@ -12,7 +12,24 @@ from torch.utils.data import Dataset os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" +def download(): + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + DATA_DIR = os.path.join(BASE_DIR, 'data') + if not os.path.exists(DATA_DIR): + os.mkdir(DATA_DIR) + if not os.path.exists(os.path.join(DATA_DIR, 'h5_files')): + # note that this link only contains the hardest perturbed variant (PB_T50_RS). + # for full versions, consider the following link. + www = 'https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip' + # www = 'http://103.24.77.34/scanobjectnn/h5_files.zip' + zipfile = os.path.basename(www) + os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) + os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) + os.system('rm %s' % (zipfile)) + + def load_scanobjectnn_data(partition): + download() BASE_DIR = os.path.dirname(os.path.abspath(__file__)) all_data = [] all_label = []