Skip to content

Commit

Permalink
PR: ML Potential Allegro fix
Browse files Browse the repository at this point in the history
1. now your order of pseudo atoms will automatically match to the Allegro model's element list without sorting the order beforehand
2. some small update to the compilation file
  • Loading branch information
Zhaoli2042 authored Aug 21, 2024
2 parents 0c5b63a + 2bc2e1b commit 11548d0
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 13 deletions.
2 changes: 1 addition & 1 deletion NVC_COMPILE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CXX="nvc++"

LINKFLAG="-L/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/lib64 -L/usr/lib64/ -L/opt/local/lib/gcc11/"

NVCFLAG="-O3 -std=c++20 -gpu=cc86 -Minline -fopenmp -cuda -stdpar=multicore"
NVCFLAG="-O3 -std=c++20 -target=gpu -Minline -fopenmp -cuda -stdpar=multicore"

$CXX $NVCFLAG $LINKFLAG -c axpy.cu

Expand Down
2 changes: 1 addition & 1 deletion cppflow_NVC_COMPILE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CXX="nvc++"

LINKFLAG="-ltensorflow -L/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/lib64 -L/usr/lib64/ -L/opt/local/lib/gcc11/"

NVCFLAG="-O3 -std=c++20 -gpu=cc86 -Minline -fopenmp -cuda -stdpar=multicore"
NVCFLAG="-O3 -std=c++20 -target=gpu -Minline -fopenmp -cuda -stdpar=multicore"

$CXX $NVCFLAG $LINKFLAG -c axpy.cu

Expand Down
2 changes: 2 additions & 0 deletions libtorch-patch/Allegro/PATCH_ALLEGRO_main.cpp.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ PATCH_ALLEGRO_MAIN_PREP
printf("Setting up Allegro model\n");
SystemComponents[a].DNN.ReadModel(SystemComponents[a].ModelName[0]);
printf("DONE Reading the model, model name %s\n", SystemComponents[a].ModelName[0].c_str());
SystemComponents[a].DNN.Match_Element_PseudoAtom_with_model(SystemComponents[a].PseudoAtoms);

SystemComponents[a].DNN.UCAtoms.resize(SystemComponents[a].NComponents.x);
SystemComponents[a].DNN.ReplicaAtoms.resize(SystemComponents[a].NComponents.x);

Expand Down
2 changes: 1 addition & 1 deletion libtorch_NVC_COMPILE
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ LINKFLAG="-D_GLIBCXX_USE_CXX11_ABI=1 -L${torchDir}/lib -I${torchDir}/include/ -I
#LINKFLAG="-ltensorflow -L/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/lib64 -L/usr/lib64/ -L/opt/local/lib/gcc11/"

#NVCFLAG="-O3 -std=c++20 -gpu=cc86 -Minline -fopenmp -Minfo -DUSE_DOUBLE -cuda"
NVCFLAG="-O3 -std=c++20 -gpu=cc86 -Minline -fopenmp -cuda -stdpar=multicore"
NVCFLAG="-O3 -std=c++20 -target=gpu -Minline -fopenmp -cuda -stdpar=multicore"

$CXX $NVCFLAG $LINKFLAG -c axpy.cu

Expand Down
11 changes: 6 additions & 5 deletions src_clean/data_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,18 +774,19 @@ struct FRAMEWORK_COMPONENT_LISTS
struct PseudoAtomDefinitions //Always a host struct, never on the device
{
std::vector<std::string> Name;
std::vector<std::string> Symbol;
std::vector<std::string> Symbol; //Symbol name for each pseudo-atom
std::vector<std::string> UniqueSymbol; //all the unique Symbol list //
std::vector<size_t> SymbolIndex; //It has the size of the number of pseudo atoms, it tells the ID of the symbol for the pseudo-atoms, e.g., CO2->C->2
std::vector<double> oxidation;
std::vector<double> mass;
std::vector<double> charge;
std::vector<double> polar; //polarizability
size_t MatchSymbolTypeFromSymbolName(std::string& SymbolName)
size_t MatchUniqueSymbolTypeFromSymbolName(std::string& SymbolName)
{
size_t SymbolIdx = 0;
for(size_t i = 0; i < Symbol.size(); i++)
size_t SymbolIdx = UniqueSymbol.size();
for(size_t i = 0; i < UniqueSymbol.size(); i++)
{
if(SymbolName == Symbol[i])
if(SymbolName == UniqueSymbol[i])
{
SymbolIdx = i; break;
}
Expand Down
7 changes: 5 additions & 2 deletions src_clean/read_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,9 +1081,12 @@ void PseudoAtomParser(ForceField& FF, PseudoAtomDefinitions& PseudoAtom)
{
Split_Tab_Space(termsScannedLined, str);
if(termsScannedLined[0] != PseudoAtom.Name[counter-3]) throw std::runtime_error("Order of pseudo-atom and force field definition don't match!");
PseudoAtom.Symbol.push_back(termsScannedLined[2]);

//Match 1-to-1 list of pseudo_atom type and symbol type//
size_t SymbolIdx = PseudoAtom.MatchSymbolTypeFromSymbolName(termsScannedLined[2]);
PseudoAtom.Symbol.push_back(termsScannedLined[2]);
size_t SymbolIdx = PseudoAtom.MatchUniqueSymbolTypeFromSymbolName(termsScannedLined[2]);
if(SymbolIdx >= PseudoAtom.UniqueSymbol.size()) PseudoAtom.UniqueSymbol.push_back(termsScannedLined[2]);

PseudoAtom.SymbolIndex.push_back(SymbolIdx);
PseudoAtom.oxidation.push_back(std::stod(termsScannedLined[4]));
PseudoAtom.mass.push_back(std::stod(termsScannedLined[5]));
Expand Down
34 changes: 31 additions & 3 deletions src_clean/torch_allegro.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct Allegro
Boxsize ReplicaBox;
std::vector<Atoms> UCAtoms;
std::vector<Atoms> ReplicaAtoms;
std::vector<std::string> ElementSymbolUsed;
std::vector<int>Match_AllegroElement_PseudoAtom_order; //length = # of PseudoAtoms, value stored = order in the Allegro Model//
double Cutoff = 6.0;
double Cutoffsq = 0.0;
NeighList NL;
Expand All @@ -15,6 +17,22 @@ struct Allegro

size_t nstep = 0;

void Match_Element_PseudoAtom_with_model(PseudoAtomDefinitions& PseudoAtoms)
{
printf("------- MATCHING ALLEGRO ELEMENTS WITH PSEUDO ATOM ELEMENT SYMBOLS -------\n");
Match_AllegroElement_PseudoAtom_order.resize(PseudoAtoms.Symbol.size(), -1);
for(size_t i = 0; i < PseudoAtoms.Symbol.size(); i++)
{
for(size_t j = 0; j < ElementSymbolUsed.size(); j++)
if(PseudoAtoms.Symbol[i] == ElementSymbolUsed[j])
{
Match_AllegroElement_PseudoAtom_order[i] = j;
printf("PseudoAtom Symbol[%zu]: %s, Allegro Symbol[%zu]: %s, MATCHED\n", i, PseudoAtoms.Symbol[i].c_str(), j, ElementSymbolUsed[j].c_str());
break;
}
}
}

void GetSQ_From_Cutoff()
{
Cutoffsq = Cutoff * Cutoff;
Expand Down Expand Up @@ -190,9 +208,9 @@ struct Allegro
if(comp != 0)
if(!ConsiderThisAdsorbateAtom[i]) continue;
UCAtoms[comp].pos[update_i] = HostAtoms.pos[i];
size_t SymbolIdx = PseudoAtoms.GetSymbolIdxFromPseudoAtomTypeIdx(HostAtoms.Type[i]);
size_t SymbolIdx = Match_AllegroElement_PseudoAtom_order[HostAtoms.Type[i]];
UCAtoms[comp].Type[update_i] = SymbolIdx;
//printf("Component %zu, Atom %zu, xyz %f %f %f, Type %zu, SymbolIndex %zu\n", comp, i, UCAtoms[comp].pos[i].x, UCAtoms[comp].pos[i].y, UCAtoms[comp].pos[i].z, HostAtoms.Type[i], UCAtoms[comp].Type[i]);
if(i < 5 || i > (NAtoms - 5)) printf("Component %zu, Atom %zu, xyz %f %f %f, Type %zu, SymbolIndex %zu\n", comp, i, UCAtoms[comp].pos[i].x, UCAtoms[comp].pos[i].y, UCAtoms[comp].pos[i].z, HostAtoms.Type[i], UCAtoms[comp].Type[i]);
update_i ++;
}
}
Expand Down Expand Up @@ -239,6 +257,15 @@ struct Allegro
Model.eval();
//Freeze the model
Model = torch::jit::freeze(Model);
std::cout << "MODEL TYPE NAMES " << metadata["type_names"] << "\n";
std::cout << "MODEL SPECIES " << metadata["n_species"] << "\n";
//std::string name = std::to_string(metadata["type_names"]);
std::string name = metadata["type_names"];
std::vector<std::string> termsScannedLined{};
Split_Tab_Space(termsScannedLined, name);
ElementSymbolUsed = termsScannedLined;
printf("First element of type: %s, first: %s\n", name.c_str(), ElementSymbolUsed[0].c_str());
//printf("Model");
//ReadCutOffFromModel(ModelName);
}

Expand All @@ -265,7 +292,7 @@ struct Allegro
pos[counter][2] = ReplicaAtoms[comp].pos[i].z;
ij2type[counter]= ReplicaAtoms[comp].Type[i];
//if(comp != 0)
// printf("comp %zu, counter %zu, ij2type %zu\n", comp, counter, ij2type[counter]);
//printf("comp %zu, counter %zu, ij2type %zu\n", comp, counter, ij2type[counter]);
counter ++;
}
size_t N_Replica_FrameworkAtoms = 0; size_t NFrameworkAtoms = 0;
Expand Down Expand Up @@ -551,6 +578,7 @@ struct Allegro
WrapSuperCellAtomIntoUCBox(comp);
GenerateReplicaCells(Initialize);
Get_Neighbor_List_Replica(Initialize);
//printf("Doing Neighbor list\n");
double DNN_E = Predict();
//This generates the unit of eV, convert to 10J/mol.
//https://www.weizmann.ac.il/oc/martin/tools/hartree.html
Expand Down

0 comments on commit 11548d0

Please sign in to comment.