From 9826c0ac9ed542dbc93255dac1e66bbdd9d50411 Mon Sep 17 00:00:00 2001 From: William Dawson Date: Thu, 18 Apr 2024 16:27:43 +0900 Subject: [PATCH] Fix test conditions to handle the case an edge case --- Source/C/PSMatrix_c.h | 3 ++- Source/CPlusPlus/PSMatrix.cc | 7 +++++- Source/CPlusPlus/PSMatrix.h | 3 +++ Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 | 25 ++++++++++++++++---- UnitTests/test_psmatrixalgebra.py | 12 ++++++---- 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/Source/C/PSMatrix_c.h b/Source/C/PSMatrix_c.h index c07bd7a0..a211fe73 100644 --- a/Source/C/PSMatrix_c.h +++ b/Source/C/PSMatrix_c.h @@ -62,6 +62,7 @@ void ScaleMatrix_ps_wrp(int *ih_this, const double *constant); double MatrixNorm_ps_wrp(const int *ih_this); void MatrixTrace_ps_wrp(const int *ih_this, double *trace_val); int IsIdentity_ps_wrp(const int *ih_this); -void MatrixDiagonalScale_ps_wrp(int *ih_mat, const int *ih_tlist); +void MatrixDiagonalScale_psr_wrp(int *ih_mat, const int *ih_tlist); +void MatrixDiagonalScale_psc_wrp(int *ih_mat, const int *ih_tlist); #endif diff --git a/Source/CPlusPlus/PSMatrix.cc b/Source/CPlusPlus/PSMatrix.cc index d421c4df..05f0d593 100644 --- a/Source/CPlusPlus/PSMatrix.cc +++ b/Source/CPlusPlus/PSMatrix.cc @@ -219,7 +219,12 @@ void Matrix_ps::Scale(double constant) { ////////////////////////////////////////////////////////////////////////////// void Matrix_ps::DiagonalScale(const TripletList_r &tlist) { - MatrixDiagonalScale_ps_wrp(ih_this, tlist.ih_this); + MatrixDiagonalScale_psr_wrp(ih_this, tlist.ih_this); +} + +////////////////////////////////////////////////////////////////////////////// +void Matrix_ps::DiagonalScale(const TripletList_c &tlist) { + MatrixDiagonalScale_psc_wrp(ih_this, tlist.ih_this); } ////////////////////////////////////////////////////////////////////////////// diff --git a/Source/CPlusPlus/PSMatrix.h b/Source/CPlusPlus/PSMatrix.h index d6e3bc90..8aba171a 100644 --- a/Source/CPlusPlus/PSMatrix.h +++ b/Source/CPlusPlus/PSMatrix.h @@ -170,6 +170,9 @@ class Matrix_ps { //!\param tlist the triplet list. //!\param threshold for flushing small values. void DiagonalScale(const NTPoly::TripletList_r &tlist); + //!\param tlist the triplet list. + //!\param threshold for flushing small values. + void DiagonalScale(const NTPoly::TripletList_c &tlist); public: //! Destructor. diff --git a/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 b/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 index e333d6ac..f57aa77a 100644 --- a/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 +++ b/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 @@ -6,7 +6,7 @@ MODULE PSMatrixAlgebraModule_wrp USE PSMatrixAlgebraModule USE PSMatrixModule_wrp, ONLY : Matrix_ps_wrp USE PermutationModule_wrp, ONLY : Permutation_wrp - USE TripletListModule_wrp, ONLY : TripletList_r_wrp + USE TripletListModule_wrp, ONLY : TripletList_r_wrp, TripletList_c_wrp USE WrapperModule, ONLY : SIZE_wrp USE ISO_C_BINDING, ONLY : c_int, c_char, c_bool IMPLICIT NONE @@ -20,7 +20,8 @@ MODULE PSMatrixAlgebraModule_wrp PUBLIC :: ScaleMatrix_ps_wrp PUBLIC :: MatrixNorm_ps_wrp PUBLIC :: MatrixTrace_ps_wrp - PUBLIC :: MatrixDiagonalScale_ps_wrp + PUBLIC :: MatrixDiagonalScale_psr_wrp + PUBLIC :: MatrixDiagonalScale_psc_wrp CONTAINS!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !> Matrix B = alpha*Matrix A + Matrix B (AXPY) SUBROUTINE IncrementMatrix_ps_wrp(ih_matA, ih_matB, alpha_in,threshold_in) & @@ -146,8 +147,8 @@ SUBROUTINE MatrixTrace_ps_wrp(ih_this, trace_value) & END SUBROUTINE MatrixTrace_ps_wrp !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !> Scale a matrix using a diagonal matrix (triplet list form). - SUBROUTINE MatrixDiagonalScale_ps_wrp(ih_mat, ih_tlist) & - & BIND(c,name="MatrixDiagonalScale_ps_wrp") + SUBROUTINE MatrixDiagonalScale_psr_wrp(ih_mat, ih_tlist) & + & BIND(c,name="MatrixDiagonalScale_psr_wrp") INTEGER(kind=c_int), INTENT(INOUT) :: ih_mat(SIZE_wrp) INTEGER(kind=c_int), INTENT(IN) :: ih_tlist(SIZE_wrp) TYPE(Matrix_ps_wrp) :: h_mat @@ -157,6 +158,20 @@ SUBROUTINE MatrixDiagonalScale_ps_wrp(ih_mat, ih_tlist) & h_tlist = TRANSFER(ih_tlist, h_tlist) CALL MatrixDiagonalScale(h_mat%DATA, h_tlist%DATA) - END SUBROUTINE MatrixDiagonalScale_ps_wrp + END SUBROUTINE MatrixDiagonalScale_psr_wrp +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + !> Scale a matrix using a diagonal matrix (triplet list form). + SUBROUTINE MatrixDiagonalScale_psc_wrp(ih_mat, ih_tlist) & + & BIND(c,name="MatrixDiagonalScale_psc_wrp") + INTEGER(kind=c_int), INTENT(INOUT) :: ih_mat(SIZE_wrp) + INTEGER(kind=c_int), INTENT(IN) :: ih_tlist(SIZE_wrp) + TYPE(Matrix_ps_wrp) :: h_mat + TYPE(TripletList_c_wrp) :: h_tlist + + h_mat = TRANSFER(ih_mat, h_mat) + h_tlist = TRANSFER(ih_tlist, h_tlist) + + CALL MatrixDiagonalScale(h_mat%DATA, h_tlist%DATA) + END SUBROUTINE MatrixDiagonalScale_psc_wrp !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! END MODULE PSMatrixAlgebraModule_wrp diff --git a/UnitTests/test_psmatrixalgebra.py b/UnitTests/test_psmatrixalgebra.py index fea5551c..052dff8c 100644 --- a/UnitTests/test_psmatrixalgebra.py +++ b/UnitTests/test_psmatrixalgebra.py @@ -292,6 +292,10 @@ def test_scalediag(self): self.write_matrix(matrix, self.input_file1) self.CheckMat = deepcopy(matrix) + # Need a guard because otherwise we write a real matrix by mistake. + if self.complex1 and matrix.nnz == 0: + continue + tlist = self.TripletList() for i in range(matrix.shape[1]): t = self.Triplet() @@ -393,8 +397,8 @@ class TestPSMatrixAlgebra_rc(TestPSMatrixAlgebra, unittest.TestCase): complex1 = True # Whether the second matrix is complex or not complex2 = False - TripletList = nt.TripletList_r - Triplet = nt.Triplet_r + TripletList = nt.TripletList_c + Triplet = nt.Triplet_c def setUp(self): '''Set up a specific test.''' @@ -410,8 +414,8 @@ class TestPSMatrixAlgebra_cr(TestPSMatrixAlgebra, unittest.TestCase): complex1 = False # Whether the second matrix is complex or not complex2 = True - TripletList = nt.TripletList_c - Triplet = nt.Triplet_c + TripletList = nt.TripletList_r + Triplet = nt.Triplet_r def setUp(self): '''Set up a specific test.'''