automatically download dataset
This commit is contained in:
parent
c9dbca9c7c
commit
16baefaa74
|
@ -12,7 +12,24 @@ from torch.utils.data import Dataset
|
||||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||||
|
|
||||||
|
|
||||||
|
def download():
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||||
|
if not os.path.exists(DATA_DIR):
|
||||||
|
os.mkdir(DATA_DIR)
|
||||||
|
if not os.path.exists(os.path.join(DATA_DIR, 'h5_files')):
|
||||||
|
# note that this link only contains the hardest perturbed variant (PB_T50_RS).
|
||||||
|
# for full versions, consider the following link.
|
||||||
|
www = 'https://web.northeastern.edu/smilelab/xuma/datasets/h5_files.zip'
|
||||||
|
# www = 'http://103.24.77.34/scanobjectnn/h5_files.zip'
|
||||||
|
zipfile = os.path.basename(www)
|
||||||
|
os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
|
||||||
|
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
|
||||||
|
os.system('rm %s' % (zipfile))
|
||||||
|
|
||||||
|
|
||||||
def load_scanobjectnn_data(partition):
|
def load_scanobjectnn_data(partition):
|
||||||
|
download()
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
all_data = []
|
all_data = []
|
||||||
all_label = []
|
all_label = []
|
||||||
|
|
Loading…
Reference in a new issue