diff --git a/mace/data/__init__.py b/mace/data/__init__.py index b231cb12..b78e340d 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -1,6 +1,5 @@ from .atomic_data import AtomicData from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 -from .lmdb_dataset import LMDBDataset from .neighborhood import get_neighborhood from .utils import ( Configuration, @@ -29,8 +28,13 @@ "compute_average_E0s", "save_dataset_as_HDF5", "HDF5Dataset", - "LMDBDataset", "dataset_from_sharded_hdf5", "save_AtomicData_to_HDF5", "save_configurations_as_HDF5", ] + + +def _import_lmdb(): + global LMDBDataset + from .lmdb_dataset import LMDBDataset + __all__.append("LMDBDataset") diff --git a/mace/data/lmdb_dataset.py b/mace/data/lmdb_dataset.py index 5eb66656..2fbe1ee5 100644 --- a/mace/data/lmdb_dataset.py +++ b/mace/data/lmdb_dataset.py @@ -1,12 +1,16 @@ from mace.data.atomic_data import AtomicData from mace.data.utils import Configuration, config_from_atoms -from fairchem.core.datasets import AseDBDataset from torch.utils.data import Dataset from mace.tools.utils import AtomicNumberTable from ase.io.extxyz import save_calc_results import numpy as np import os +try: + from fairchem.core.datasets import AseDBDataset +except ImportError: + raise ImportError("Please install MACE with LMDB support: pip install mace[lmdb]") + class LMDBDataset(Dataset): def __init__(self, file_path, r_max, z_table, **kwargs): dataset_paths = file_path.split(":") # using : split multiple paths