Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_mtp.py from_config() feature update #618

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
24 changes: 22 additions & 2 deletions maml/apps/pes/_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import re
import shutil
import subprocess
import warnings
from collections import OrderedDict
from shutil import which

import numpy as np
from monty.io import zopen
from monty.serialization import loadfn
from monty.tempfile import ScratchDir
from pymatgen.core import Lattice, Structure
from pymatgen.core import Element, Lattice, Structure

from maml.utils import check_structures_forces_stresses, convert_docs, pool_from

Expand Down Expand Up @@ -779,14 +780,16 @@ def evaluate(self, test_structures, test_energies, test_forces, test_stresses=No
return df_orig, df_predict

@staticmethod
def from_config(filename, elements):
def from_config(filename, elements, default_element_ordering=True):
"""
Initialize potentials with parameters file.

Args:
filename (str): The file storing parameters of potentials, filename should
ends with ".mtp".
elements (list): The list of elements.
default_element_ordering (bool): If True, elements argument is ordered following the
convention of Pauling electronegativity. If False, given order is kept.

Returns:
MTPotential
Expand All @@ -799,8 +802,25 @@ def from_config(filename, elements):
key = line.rstrip().split(" = ")[0]
value = json.loads(line.rstrip().split(" = ")[1].replace("{", "[").replace("}", "]"))
param[key] = value
num_species = -1
for line in lines:
if "species_count" in line:
num_species = int(line.split()[2])
break
if len(set(elements)) != num_species:
raise ValueError("Inconsistent number of species between the provided .mtp file and the elements argument")

mtp = MTPotential(param=param)
if default_element_ordering:
ordered_elements = [str(x) for x in sorted([Element(x) for x in elements])]
if elements != ordered_elements:
warnings.warn(
f"Order for the elements has been altered from {elements} to {ordered_elements} to ensure "
"consistency with default element ordering in maml during MTP fitting. Change the "
"'default_element_ordering' argument to keep original order.",
ImportWarning,
)
elements = ordered_elements
mtp.elements = elements

return mtp
4 changes: 2 additions & 2 deletions maml/utils/_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import ceil, floor
from typing import Callable
from typing import Callable, Any

import numpy as np
from monty.dev import requires
Expand Down Expand Up @@ -81,7 +81,7 @@ def wvd(z: np.ndarray, return_all: bool = False) -> tuple | np.ndarray:
AVAILABLE_SP_METHODS = {"fft_magnitude": fft_magnitude, "spectrogram": spectrogram, "cwt": cwt, "wvd": wvd}


def get_sp_method(sp_method: str | Callable) -> Callable: # type: ignore
def get_sp_method(sp_method: str | Callable[..., Any]) -> Callable[..., Any]: # type: ignore
"""
Providing a signal processing method name return the callable
Args:
Expand Down
Loading
Loading