Skip to content

Commit

Permalink
Under migration from _phono3py.c to _phono3py.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Jul 8, 2024
1 parent 4afa5a7 commit f502046
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
146 changes: 145 additions & 1 deletion c/_phono3py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ static Darray *convert_to_darray(nb::ndarray<> npyary) {
return ary;
}

// static void show_colmat_info(const PyArrayObject *py_collision_matrix,
// const long i_sigma, const long i_temp,
// const long adrs_shift) {
// long i;

// printf(" Array_shape:(");
// for (i = 0; i < PyArray_NDIM(py_collision_matrix); i++) {
// printf("%d", (int)PyArray_DIM(py_collision_matrix, i));
// if (i < PyArray_NDIM(py_collision_matrix) - 1) {
// printf(",");
// } else {
// printf("), ");
// }
// }
// printf("Data shift:%lu [%lu, %lu]\n", adrs_shift, i_sigma, i_temp);
// }

void py_get_interaction(nb::ndarray<> py_fc3_normal_squared,
nb::ndarray<> py_g_zero, nb::ndarray<> py_frequencies,
nb::ndarray<> py_eigenvectors,
Expand Down Expand Up @@ -1008,7 +1025,128 @@ long py_get_bz_grid_addresses(nb::ndarray<> py_bz_grid_addresses,
return num_total_gp;
}

NB_MODULE(_phonopy, m) {
long py_rotate_bz_grid_addresses(long bz_grid_index, nb::ndarray<> py_rotation,
nb::ndarray<> py_bz_grid_addresses,
nb::ndarray<> py_bz_map,
nb::ndarray<> py_D_diag, nb::ndarray<> py_PS,
long type) {
long(*bz_grid_addresses)[3];
long(*rotation)[3];
long *bz_map;
long *D_diag;
long *PS;
long ret_bz_gp;

bz_grid_addresses = (long(*)[3])py_bz_grid_addresses.data();
rotation = (long(*)[3])py_rotation.data();
bz_map = (long *)py_bz_map.data();
D_diag = (long *)py_D_diag.data();
PS = (long *)py_PS.data();

ret_bz_gp = ph3py_rotate_bz_grid_index(
bz_grid_index, rotation, bz_grid_addresses, bz_map, D_diag, PS, type);

return ret_bz_gp;
}

long py_diagonalize_collision_matrix(nb::ndarray<> py_collision_matrix,
nb::ndarray<> py_eigenvalues, long i_sigma,
long i_temp, double cutoff, long solver,
long is_pinv) {
double *collision_matrix;
double *eigvals;
long num_temp, num_grid_point, num_band;
long num_column, adrs_shift;
long info;

collision_matrix = (double *)py_collision_matrix.data();
eigvals = (double *)py_eigenvalues.data();

if (py_collision_matrix.ndim() == 2) {
num_temp = 1;
num_column = py_collision_matrix.shape(1);
} else {
num_temp = py_collision_matrix.shape(1);
num_grid_point = py_collision_matrix.shape(2);
num_band = py_collision_matrix.shape(3);
if (py_collision_matrix.ndim() == 8) {
num_column = num_grid_point * num_band * 3;
} else {
num_column = num_grid_point * num_band;
}
}
adrs_shift = (i_sigma * num_column * num_column * num_temp +
i_temp * num_column * num_column);

/* show_colmat_info(py_collision_matrix, i_sigma, i_temp, adrs_shift); */

info = ph3py_phonopy_dsyev(collision_matrix + adrs_shift, eigvals,
num_column, solver);
if (is_pinv) {
ph3py_pinv_from_eigensolution(collision_matrix + adrs_shift, eigvals,
num_column, cutoff, 0);
}

return info;
}

void py_pinv_from_eigensolution(nb::ndarray<> py_collision_matrix,
nb::ndarray<> py_eigenvalues, long i_sigma,
long i_temp, double cutoff, long pinv_method) {
double *collision_matrix;
double *eigvals;
long num_temp, num_grid_point, num_band;
long num_column, adrs_shift;

collision_matrix = (double *)py_collision_matrix.data();
eigvals = (double *)py_eigenvalues.data();
num_temp = py_collision_matrix.shape(1);
num_grid_point = py_collision_matrix.shape(2);
num_band = py_collision_matrix.shape(3);

if (py_collision_matrix.ndim() == 8) {
num_column = num_grid_point * num_band * 3;
} else {
num_column = num_grid_point * num_band;
}
adrs_shift = (i_sigma * num_column * num_column * num_temp +
i_temp * num_column * num_column);

/* show_colmat_info(py_collision_matrix, i_sigma, i_temp, adrs_shift); */

ph3py_pinv_from_eigensolution(collision_matrix + adrs_shift, eigvals,
num_column, cutoff, pinv_method);
}

long py_get_default_colmat_solver() {
#if defined(MKL_LAPACKE) || defined(SCIPY_MKL_H)
return (long)1;
#else
return (long)4;
#endif
}

long py_lapacke_pinv(nb::ndarray<> data_out_py, nb::ndarray<> data_in_py,
double cutoff) {
long m;
long n;
double *data_in;
double *data_out;
long info;

m = data_in_py.shape(0);
n = data_in_py.shape(1);
data_in = (double *)data_in_py.data();
data_out = (double *)data_out_py.data();

info = ph3py_phonopy_pinv(data_out, data_in, m, n, cutoff);

return info;
}

long py_get_omp_max_threads() { return ph3py_get_max_threads(); }

NB_MODULE(_phono3py, m) {
m.def("interaction", &py_get_interaction);
m.def("pp_collision", &py_get_pp_collision);
m.def("pp_collision_with_sigma", &py_get_pp_collision_with_sigma);
Expand Down Expand Up @@ -1048,4 +1186,10 @@ NB_MODULE(_phonopy, m) {
m.def("transform_rotations", &py_transform_rotations);
m.def("snf3x3", &py_get_snf3x3);
m.def("bz_grid_addresses", &py_get_bz_grid_addresses);
m.def("rotate_bz_grid_index", &py_rotate_bz_grid_addresses);
m.def("diagonalize_collision_matrix", &py_diagonalize_collision_matrix);
m.def("pinv_from_eigensolution", &py_pinv_from_eigensolution);
m.def("default_colmat_solver", &py_get_default_colmat_solver);
m.def("lapacke_pinv", &py_lapacke_pinv);
m.def("omp_max_threads", &py_get_omp_max_threads);
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"matplotlib>=2.2.2",
"h5py>=3.0",
"spglib>=2.3",
"phonopy>=2.25,<2.26",
"phonopy>=2.26,<2.27",
]
license = { file = "LICENSE" }

Expand Down

0 comments on commit f502046

Please sign in to comment.