automatically download dataset
This commit is contained in:
parent
c9dbca9c7c
commit
16baefaa74
|
@ -12,7 +12,24 @@ from torch.utils.data import Dataset
|
|||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||
|
||||
|
||||
def download():
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||
if not os.path.exists(DATA_DIR):
|
||||
os.mkdir(DATA_DIR)
|
||||
if not os.path.exists(os.path.join(DATA_DIR, 'h5_files')):
|
||||
# note that this link only contains the hardest perturbed variant (PB_T50_RS).
|
||||
# for full versions, consider the following link.
|
||||
www = 'https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip'
|
||||
# www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
|
||||
zipfile = os.path.basename(www)
|
||||
os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
|
||||
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
|
||||
os.system('rm %s' % (zipfile))
|
||||
|
||||
|
||||
def load_scanobjectnn_data(partition):
|
||||
download()
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
all_data = []
|
||||
all_label = []
|
||||
|
|
Loading…
Reference in a new issue