diff --git a/Source/C/PSMatrix_c.h b/Source/C/PSMatrix_c.h index b351b217..c07bd7a0 100644 --- a/Source/C/PSMatrix_c.h +++ b/Source/C/PSMatrix_c.h @@ -62,5 +62,6 @@ void ScaleMatrix_ps_wrp(int *ih_this, const double *constant); double MatrixNorm_ps_wrp(const int *ih_this); void MatrixTrace_ps_wrp(const int *ih_this, double *trace_val); int IsIdentity_ps_wrp(const int *ih_this); +void MatrixDiagonalScale_ps_wrp(int *ih_mat, const int *ih_tlist); #endif diff --git a/Source/CPlusPlus/PSMatrix.cc b/Source/CPlusPlus/PSMatrix.cc index c2c296bd..d421c4df 100644 --- a/Source/CPlusPlus/PSMatrix.cc +++ b/Source/CPlusPlus/PSMatrix.cc @@ -217,6 +217,11 @@ void Matrix_ps::Scale(double constant) { ScaleMatrix_ps_wrp(ih_this, &constant); } +////////////////////////////////////////////////////////////////////////////// +void Matrix_ps::DiagonalScale(const TripletList_r &tlist) { + MatrixDiagonalScale_ps_wrp(ih_this, tlist.ih_this); +} + ////////////////////////////////////////////////////////////////////////////// double Matrix_ps::Norm() const { return MatrixNorm_ps_wrp(ih_this); } diff --git a/Source/CPlusPlus/PSMatrix.h b/Source/CPlusPlus/PSMatrix.h index eb2bd2ab..d6e3bc90 100644 --- a/Source/CPlusPlus/PSMatrix.h +++ b/Source/CPlusPlus/PSMatrix.h @@ -160,13 +160,16 @@ class Matrix_ps { void Gemm(const Matrix_ps &matA, const Matrix_ps &matB, PMatrixMemoryPool &memory_pool, double alpha = 1.0, double beta = 0.0, double threshold = 0.0); - //! scale the matrix by a constatn. + //! scale the matrix by a constant. //! constant the value to scale by. void Scale(double constant); //! compute the norm of a matrix. double Norm() const; //! compute the trace of a matrix. double Trace() const; + //!\param tlist the triplet list. + //!\param threshold for flushing small values. + void DiagonalScale(const NTPoly::TripletList_r &tlist); public: //! Destructor. diff --git a/Source/Fortran/PSMatrixAlgebraModule.F90 b/Source/Fortran/PSMatrixAlgebraModule.F90 index c2b13f86..6ecea2da 100644 --- a/Source/Fortran/PSMatrixAlgebraModule.F90 +++ b/Source/Fortran/PSMatrixAlgebraModule.F90 @@ -13,13 +13,16 @@ MODULE PSMatrixAlgebraModule & ConstructMatrixMemoryPool USE PSMatrixModule, ONLY : Matrix_ps, ConstructEmptyMatrix, CopyMatrix, & & DestructMatrix, ConvertMatrixToComplex, ConjugateMatrix, & - & MergeMatrixLocalBlocks, IsIdentity + & MergeMatrixLocalBlocks, IsIdentity, SplitMatrixToLocalBlocks USE SMatrixAlgebraModule, ONLY : MatrixMultiply, MatrixGrandSum, & & PairwiseMultiplyMatrix, IncrementMatrix, ScaleMatrix, & - & MatrixColumnNorm + & MatrixColumnNorm, MatrixDiagonalScale USE SMatrixModule, ONLY : Matrix_lsr, Matrix_lsc, DestructMatrix, CopyMatrix,& & TransposeMatrix, ComposeMatrixColumns, MatrixToTripletList - USE TripletListModule, ONLY : TripletList_r, TripletList_c + USE TripletListModule, ONLY : TripletList_r, TripletList_c, & + & ConstructTripletList, AppendToTripletList, DestructTripletList, & + & GetTripletAt + USE TripletModule, ONLY : Triplet_r, Triplet_c USE NTMPIModule IMPLICIT NONE PRIVATE @@ -34,6 +37,7 @@ MODULE PSMatrixAlgebraModule PUBLIC :: ScaleMatrix PUBLIC :: MatrixTrace PUBLIC :: SimilarityTransform + PUBLIC :: MatrixDiagonalScale !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! INTERFACE MatrixSigma MODULE PROCEDURE MatrixSigma_ps @@ -62,6 +66,10 @@ MODULE PSMatrixAlgebraModule MODULE PROCEDURE ScaleMatrix_psr MODULE PROCEDURE ScaleMatrix_psc END INTERFACE ScaleMatrix + INTERFACE MatrixDiagonalScale + MODULE PROCEDURE MatrixDiagonalScale_psr + MODULE PROCEDURE MatrixDiagonalScale_psc + END INTERFACE MatrixDiagonalScale INTERFACE MatrixTrace MODULE PROCEDURE MatrixTrace_psr END INTERFACE MatrixTrace @@ -491,6 +499,34 @@ RECURSIVE SUBROUTINE ScaleMatrix_psc(this, constant) END IF END SUBROUTINE ScaleMatrix_psc +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + !> Will scale a distributed sparse matrix by a constant. + SUBROUTINE MatrixDiagonalScale_psr(this, tlist) + !> Matrix to scale. + TYPE(Matrix_ps), INTENT(INOUT) :: this + !> A constant scale factor. + TYPE(TripletList_r), INTENT(IN) :: tlist + !! Local Data + TYPE(Matrix_lsr) :: lmat + TYPE(TripletList_r) :: filtered + TYPE(Triplet_r) :: trip + +#include "distributed_algebra_includes/ScaleDiagonal.f90" + END SUBROUTINE MatrixDiagonalScale_psr +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + !> Will scale a distributed sparse matrix by a constant. + RECURSIVE SUBROUTINE MatrixDiagonalScale_psc(this, tlist) + !> Matrix to scale. + TYPE(Matrix_ps), INTENT(INOUT) :: this + !> A constant scale factor. + TYPE(TripletList_c), INTENT(IN) :: tlist + !! Local Data + TYPE(Matrix_lsc) :: lmat + TYPE(TripletList_c) :: filtered + TYPE(Triplet_c) :: trip + +#include "distributed_algebra_includes/ScaleDiagonal.f90" + END SUBROUTINE MatrixDiagonalScale_psc !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !> Compute the trace of the matrix. SUBROUTINE MatrixTrace_psr(this, trace_value) diff --git a/Source/Fortran/PSMatrixModule.F90 b/Source/Fortran/PSMatrixModule.F90 index 19812a28..423254e8 100644 --- a/Source/Fortran/PSMatrixModule.F90 +++ b/Source/Fortran/PSMatrixModule.F90 @@ -1443,7 +1443,7 @@ SUBROUTINE TransposeMatrix_psr(AMat, TransMat) !! Local Variables TYPE(TripletList_r) :: tlist TYPE(TripletList_r) :: new_list - TYPE(Triplet_r) :: trip, trip_t + TYPE(Triplet_r) :: trip #include "distributed_includes/TransposeMatrix.f90" @@ -1458,7 +1458,7 @@ SUBROUTINE TransposeMatrix_psc(AMat, TransMat) !! Local Variables TYPE(TripletList_c) :: tlist TYPE(TripletList_c) :: new_list - TYPE(Triplet_c) :: trip, trip_t + TYPE(Triplet_c) :: trip #include "distributed_includes/TransposeMatrix.f90" diff --git a/Source/Fortran/distributed_algebra_includes/ScaleDiagonal.f90 b/Source/Fortran/distributed_algebra_includes/ScaleDiagonal.f90 new file mode 100644 index 00000000..ee3b4db5 --- /dev/null +++ b/Source/Fortran/distributed_algebra_includes/ScaleDiagonal.f90 @@ -0,0 +1,24 @@ +INTEGER :: II, row + +!! Merge to the local block +CALL MergeMatrixLocalBlocks(this, lmat) + +!! Filter out the triplets that aren't stored locally +CALL ConstructTripletList(filtered) +DO II = 1, tlist%CurrentSize + CALL GetTripletAt(tlist, II, trip) + row = trip%index_row + IF (row .GE. this%start_row .AND. row .LT. this%end_row) THEN + trip%index_row = trip%index_row - this%start_row + 1 + trip%index_column = trip%index_row + CALL AppendToTripletList(filtered, trip) + END IF +END DO + +!! Scale +CALL MatrixDiagonalScale(lmat, filtered) + +!! Split +CALL SplitMatrixToLocalBlocks(this, lmat) +CALL DestructMatrix(lmat) +CALL DestructTripletList(filtered) \ No newline at end of file diff --git a/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 b/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 index 3fe86b2d..e333d6ac 100644 --- a/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 +++ b/Source/Wrapper/PSMatrixAlgebraModule_wrp.F90 @@ -20,6 +20,7 @@ MODULE PSMatrixAlgebraModule_wrp PUBLIC :: ScaleMatrix_ps_wrp PUBLIC :: MatrixNorm_ps_wrp PUBLIC :: MatrixTrace_ps_wrp + PUBLIC :: MatrixDiagonalScale_ps_wrp CONTAINS!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !> Matrix B = alpha*Matrix A + Matrix B (AXPY) SUBROUTINE IncrementMatrix_ps_wrp(ih_matA, ih_matB, alpha_in,threshold_in) & @@ -143,5 +144,19 @@ SUBROUTINE MatrixTrace_ps_wrp(ih_this, trace_value) & h_this = TRANSFER(ih_this,h_this) CALL MatrixTrace(h_this%DATA, trace_value) END SUBROUTINE MatrixTrace_ps_wrp +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + !> Scale a matrix using a diagonal matrix (triplet list form). + SUBROUTINE MatrixDiagonalScale_ps_wrp(ih_mat, ih_tlist) & + & BIND(c,name="MatrixDiagonalScale_ps_wrp") + INTEGER(kind=c_int), INTENT(INOUT) :: ih_mat(SIZE_wrp) + INTEGER(kind=c_int), INTENT(IN) :: ih_tlist(SIZE_wrp) + TYPE(Matrix_ps_wrp) :: h_mat + TYPE(TripletList_r_wrp) :: h_tlist + + h_mat = TRANSFER(ih_mat, h_mat) + h_tlist = TRANSFER(ih_tlist, h_tlist) + + CALL MatrixDiagonalScale(h_mat%DATA, h_tlist%DATA) + END SUBROUTINE MatrixDiagonalScale_ps_wrp !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! END MODULE PSMatrixAlgebraModule_wrp diff --git a/UnitTests/test_psmatrixalgebra.py b/UnitTests/test_psmatrixalgebra.py index 1d09c88b..f653aa6f 100644 --- a/UnitTests/test_psmatrixalgebra.py +++ b/UnitTests/test_psmatrixalgebra.py @@ -284,6 +284,29 @@ def test_reverse(self): self.check_result() + def test_scalediag(self): + '''Test routines to scale by a diagonal matrix.''' + from copy import deepcopy + for param in self.parameters: + matrix = param.create_matrix(complex=self.complex) + mmwrite(self.file1, matrix) + CheckMat = deepcopy(matrix) + + tlist = self.TripletList() + for i in range(matrix.shape[1]): + t = self.Triplet() + t.index_column = i + 1 + t.index_row = i + 1 + t.point_value = i + tlist.Append(t) + CheckMat[:, i] *= i + ntmatrix = self.PSMatrix(self.file1) + ntmatrix.DiagonalScale(tlist) + ntmatrix.WriteToMatrixMarket(self.file2) + + ResultMat = mmread(self.file2) + self._compare_mat(CheckMat, ResultMat) + class TestPSMatrixAlgebra_r(TestPSMatrixAlgebra, unittest.TestCase): '''Special routines for real algebra''' @@ -291,6 +314,8 @@ class TestPSMatrixAlgebra_r(TestPSMatrixAlgebra, unittest.TestCase): complex1 = False # Whether the second matrix is complex or not complex2 = False + TripletList = nt.TripletList_r + Triplet = nt.Triplet_r def test_dot(self): '''Test routines to add together matrices.''' @@ -327,6 +352,8 @@ class TestPSMatrixAlgebra_c(TestPSMatrixAlgebra, unittest.TestCase): complex1 = True # Whether the second matrix is complex or not complex2 = True + TripletList = nt.TripletList_c + Triplet = nt.Triplet_c def test_dot(self): '''Test routines to add together matrices.''' @@ -366,6 +393,8 @@ class TestPSMatrixAlgebra_rc(TestPSMatrixAlgebra, unittest.TestCase): complex1 = True # Whether the second matrix is complex or not complex2 = False + TripletList = nt.TripletList_r + Triplet = nt.Triplet_r def setUp(self): '''Set up a specific test.''' @@ -381,6 +410,8 @@ class TestPSMatrixAlgebra_cr(TestPSMatrixAlgebra, unittest.TestCase): complex1 = False # Whether the second matrix is complex or not complex2 = True + TripletList = nt.TripletList_c + Triplet = nt.Triplet_c def setUp(self): '''Set up a specific test.'''