Skip to content

Commit

Permalink
Fix test conditions to handle the case an edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
william-dawson committed Apr 18, 2024
1 parent 3f0a31a commit 9826c0a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 11 deletions.
3 changes: 2 additions & 1 deletion Source/C/PSMatrix_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void ScaleMatrix_ps_wrp(int *ih_this, const double *constant);
double MatrixNorm_ps_wrp(const int *ih_this);
void MatrixTrace_ps_wrp(const int *ih_this, double *trace_val);
int IsIdentity_ps_wrp(const int *ih_this);
void MatrixDiagonalScale_ps_wrp(int *ih_mat, const int *ih_tlist);
void MatrixDiagonalScale_psr_wrp(int *ih_mat, const int *ih_tlist);
void MatrixDiagonalScale_psc_wrp(int *ih_mat, const int *ih_tlist);

#endif
7 changes: 6 additions & 1 deletion Source/CPlusPlus/PSMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,12 @@ void Matrix_ps::Scale(double constant) {

//////////////////////////////////////////////////////////////////////////////
void Matrix_ps::DiagonalScale(const TripletList_r &tlist) {
MatrixDiagonalScale_ps_wrp(ih_this, tlist.ih_this);
MatrixDiagonalScale_psr_wrp(ih_this, tlist.ih_this);
}

//////////////////////////////////////////////////////////////////////////////
void Matrix_ps::DiagonalScale(const TripletList_c &tlist) {
MatrixDiagonalScale_psc_wrp(ih_this, tlist.ih_this);
}

//////////////////////////////////////////////////////////////////////////////
Expand Down
3 changes: 3 additions & 0 deletions Source/CPlusPlus/PSMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ class Matrix_ps {
//!\param tlist the triplet list.
//!\param threshold for flushing small values.
void DiagonalScale(const NTPoly::TripletList_r &tlist);
//!\param tlist the triplet list.
//!\param threshold for flushing small values.
void DiagonalScale(const NTPoly::TripletList_c &tlist);

public:
//! Destructor.
Expand Down
25 changes: 20 additions & 5 deletions Source/Wrapper/PSMatrixAlgebraModule_wrp.F90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ MODULE PSMatrixAlgebraModule_wrp
USE PSMatrixAlgebraModule
USE PSMatrixModule_wrp, ONLY : Matrix_ps_wrp
USE PermutationModule_wrp, ONLY : Permutation_wrp
USE TripletListModule_wrp, ONLY : TripletList_r_wrp
USE TripletListModule_wrp, ONLY : TripletList_r_wrp, TripletList_c_wrp
USE WrapperModule, ONLY : SIZE_wrp
USE ISO_C_BINDING, ONLY : c_int, c_char, c_bool
IMPLICIT NONE
Expand All @@ -20,7 +20,8 @@ MODULE PSMatrixAlgebraModule_wrp
PUBLIC :: ScaleMatrix_ps_wrp
PUBLIC :: MatrixNorm_ps_wrp
PUBLIC :: MatrixTrace_ps_wrp
PUBLIC :: MatrixDiagonalScale_ps_wrp
PUBLIC :: MatrixDiagonalScale_psr_wrp
PUBLIC :: MatrixDiagonalScale_psc_wrp
CONTAINS!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!> Matrix B = alpha*Matrix A + Matrix B (AXPY)
SUBROUTINE IncrementMatrix_ps_wrp(ih_matA, ih_matB, alpha_in,threshold_in) &
Expand Down Expand Up @@ -146,8 +147,8 @@ SUBROUTINE MatrixTrace_ps_wrp(ih_this, trace_value) &
END SUBROUTINE MatrixTrace_ps_wrp
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!> Scale a matrix using a diagonal matrix (triplet list form).
SUBROUTINE MatrixDiagonalScale_ps_wrp(ih_mat, ih_tlist) &
& BIND(c,name="MatrixDiagonalScale_ps_wrp")
SUBROUTINE MatrixDiagonalScale_psr_wrp(ih_mat, ih_tlist) &
& BIND(c,name="MatrixDiagonalScale_psr_wrp")
INTEGER(kind=c_int), INTENT(INOUT) :: ih_mat(SIZE_wrp)
INTEGER(kind=c_int), INTENT(IN) :: ih_tlist(SIZE_wrp)
TYPE(Matrix_ps_wrp) :: h_mat
Expand All @@ -157,6 +158,20 @@ SUBROUTINE MatrixDiagonalScale_ps_wrp(ih_mat, ih_tlist) &
h_tlist = TRANSFER(ih_tlist, h_tlist)

CALL MatrixDiagonalScale(h_mat%DATA, h_tlist%DATA)
END SUBROUTINE MatrixDiagonalScale_ps_wrp
END SUBROUTINE MatrixDiagonalScale_psr_wrp
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!> Scale a matrix using a diagonal matrix (triplet list form).
SUBROUTINE MatrixDiagonalScale_psc_wrp(ih_mat, ih_tlist) &
& BIND(c,name="MatrixDiagonalScale_psc_wrp")
INTEGER(kind=c_int), INTENT(INOUT) :: ih_mat(SIZE_wrp)
INTEGER(kind=c_int), INTENT(IN) :: ih_tlist(SIZE_wrp)
TYPE(Matrix_ps_wrp) :: h_mat
TYPE(TripletList_c_wrp) :: h_tlist

h_mat = TRANSFER(ih_mat, h_mat)
h_tlist = TRANSFER(ih_tlist, h_tlist)

CALL MatrixDiagonalScale(h_mat%DATA, h_tlist%DATA)
END SUBROUTINE MatrixDiagonalScale_psc_wrp
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
END MODULE PSMatrixAlgebraModule_wrp
12 changes: 8 additions & 4 deletions UnitTests/test_psmatrixalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ def test_scalediag(self):
self.write_matrix(matrix, self.input_file1)
self.CheckMat = deepcopy(matrix)

# Need a guard because otherwise we write a real matrix by mistake.
if self.complex1 and matrix.nnz == 0:
continue

tlist = self.TripletList()
for i in range(matrix.shape[1]):
t = self.Triplet()
Expand Down Expand Up @@ -393,8 +397,8 @@ class TestPSMatrixAlgebra_rc(TestPSMatrixAlgebra, unittest.TestCase):
complex1 = True
# Whether the second matrix is complex or not
complex2 = False
TripletList = nt.TripletList_r
Triplet = nt.Triplet_r
TripletList = nt.TripletList_c
Triplet = nt.Triplet_c

def setUp(self):
'''Set up a specific test.'''
Expand All @@ -410,8 +414,8 @@ class TestPSMatrixAlgebra_cr(TestPSMatrixAlgebra, unittest.TestCase):
complex1 = False
# Whether the second matrix is complex or not
complex2 = True
TripletList = nt.TripletList_c
Triplet = nt.Triplet_c
TripletList = nt.TripletList_r
Triplet = nt.Triplet_r

def setUp(self):
'''Set up a specific test.'''
Expand Down

0 comments on commit 9826c0a

Please sign in to comment.