diff --git a/pytorch/data.py b/pytorch/data.py index d1e186e..24ffac4 100644 --- a/pytorch/data.py +++ b/pytorch/data.py @@ -36,7 +36,7 @@ def load_data(partition): all_data = [] all_label = [] for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)): - f = h5py.File(h5_name) + f = h5py.File(h5_name, 'r') data = f['data'][:].astype('float32') label = f['label'][:].astype('int64') f.close()