From bac16ab98978b26c6f1b5be31cdaabb177797181 Mon Sep 17 00:00:00 2001 From: Tim Davis Date: Wed, 3 Jul 2024 09:49:25 -0500 Subject: [PATCH] arminmax --- experimental/algorithm/LAGraph_argminmax.c | 26 ++++-- .../hpec24_notes/LAGraph_argminmax.c | 89 +++++++++++++++++++ experimental/benchmark/argmax_tests.m | 70 +++++++++++++++ experimental/test/test_argminmax.c | 5 +- 4 files changed, 183 insertions(+), 7 deletions(-) create mode 100644 experimental/algorithm/hpec24_notes/LAGraph_argminmax.c create mode 100644 experimental/benchmark/argmax_tests.m diff --git a/experimental/algorithm/LAGraph_argminmax.c b/experimental/algorithm/LAGraph_argminmax.c index 79b543b012..07511d88f5 100644 --- a/experimental/algorithm/LAGraph_argminmax.c +++ b/experimental/algorithm/LAGraph_argminmax.c @@ -102,10 +102,26 @@ int argminmax // for dim=2: find the position of the min/max entry in each row: // p = G*y, so that p(i) = j if x(i) = A(i,j) = min/max (A (i,:)). - // Use the SECONDI operator since built-in indexing is 0-based. The ANY - // monoid would be faster, but this uses MIN monoid so that the result for - // the user is repeatable. - GRB_TRY (GrB_mxm (*p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, desc)) ; + #if 0 + printf ("argmin/max with 2ndi\n") ; + // Use the SECONDI operator since built-in indexing is 0-based. The + // ANY monoid would be faster, but this uses MIN monoid so that the + // result for the user is repeatable. + // p = G*y or G'*y using the MIN_SECONDI semiring + GRB_TRY (GrB_mxm (*p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, desc)) ; + #else + printf ("argmin/max without 2ndi\n") ; + // H = rowindex (G) if dim is 1, or colindex (G) if dim is 2. + GrB_Matrix H = NULL ; + GRB_TRY (GrB_Matrix_new (&H, GrB_INT64, nrows, ncols)) ; + GRB_TRY (GrB_apply (H, NULL, NULL, + (dim == 1) ? 
GrB_ROWINDEX_INT64 : GrB_COLINDEX_INT64, G, + (int64_t) 0, NULL)) ; + // p = H*y or H'*y using the MIN_FIRST semiring + GRB_TRY (GrB_mxm (*p, NULL, NULL, GrB_MIN_FIRST_SEMIRING_INT64, H, y, + desc)) ; + GRB_TRY (GrB_Matrix_free (&H)) ; + #endif //-------------------------------------------------------------------------- // free workspace @@ -312,7 +328,7 @@ int LAGraph_argminmax GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [0]), *p, 0, 0)) ; // I [1] = p [I [0]-1] (use -1 since I[0] is 1-based), // which is the column index of the global argmin/max of A - GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [1]), p1, I [0] - 1, 0)) ; + GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [1]), p1, I [0], 0)) ; } // free workspace and create p = [row, col] diff --git a/experimental/algorithm/hpec24_notes/LAGraph_argminmax.c b/experimental/algorithm/hpec24_notes/LAGraph_argminmax.c new file mode 100644 index 0000000000..0a5b822d94 --- /dev/null +++ b/experimental/algorithm/hpec24_notes/LAGraph_argminmax.c @@ -0,0 +1,89 @@ +// A simplified algorithm for HPEC'24 +// assume the matrix type is FP64 +// assume argmax +// use mxv where appropriate +// don't use the ANY monoid. 
+ +//------------------------------------------------------------------------------ +// argmax: compute x(i) = max (A (i,:)) and p(i) = its 0-based column index +//------------------------------------------------------------------------------ + +int argmax +( + // output + GrB_Vector *x_handle, // max value in each row of A + GrB_Vector *p_handle, // index of max value in each row of A + // input + GrB_Matrix A // assumed to be GrB_FP64 +) +{ + + //-------------------------------------------------------------------------- + // create outputs x and p, and the iso full vector y + //-------------------------------------------------------------------------- + + GrB_Index nrows, ncols ; + GrB_Matrix_nrows (&nrows, A) ; + GrB_Matrix_ncols (&ncols, A) ; + GrB_Vector y = NULL, x = NULL, p = NULL ; + GrB_Matrix G = NULL, D = NULL ; + GrB_Vector_new (&x, GrB_FP64, nrows) ; + GrB_Vector_new (&y, GrB_FP64, ncols) ; + GrB_Vector_new (&p, GrB_INT64, nrows) ; + + // y (:) = 1, a full vector with all entries equal to 1 (y is a GrB_Vector) + GrB_Vector_assign_FP64 (y, NULL, NULL, 1, GrB_ALL, ncols, NULL) ; + + //-------------------------------------------------------------------------- + // compute x = max(A) + //-------------------------------------------------------------------------- + + // x = max (A) where x(i) = max (A (i,:)) + GrB_mxv (x, NULL, NULL, GrB_MAX_FIRST_SEMIRING_FP64, A, y, NULL) ; + + //-------------------------------------------------------------------------- + // compute G, where G(i,j)=1 if A(i,j) is the max in its row + //-------------------------------------------------------------------------- + + // D = diag (x) + GrB_Matrix_diag (&D, x, 0) ; + GrB_Matrix_new (&G, GrB_BOOL, nrows, ncols) ; + // G = D*A using the EQ_EQ_FP64 semiring + GrB_mxm (G, NULL, NULL, GxB_EQ_EQ_FP64, D, A, NULL) ; + // drop explicit zeros from G + GrB_Matrix_select_BOOL (G, NULL, NULL, GrB_VALUENE_BOOL, G, 0, NULL) ; + + //-------------------------------------------------------------------------- + // extract the positions of the 
entries in G + //-------------------------------------------------------------------------- + + // find the position of the max entry in each row: + // p = G*y, so that p(i) = j if x(i) = A(i,j) = max (A (i,:)). + + if (no 2ndI op) + { + // H = colindex (G), so that H(i,j) = j for each entry G(i,j) + GrB_Matrix H = NULL ; + GrB_Matrix_new (&H, GrB_INT64, nrows, ncols) ; + GrB_apply (H, NULL, NULL, GrB_COLINDEX_INT64, G, (int64_t) 0, NULL) ; + // p = H*y using the MIN_FIRST semiring: p(i) = smallest j with G(i,j) present + GrB_mxv (p, NULL, NULL, GrB_MIN_FIRST_SEMIRING_INT64, H, y, NULL) ; + GrB_free (&H) ; + } + else + { + // p = G*y using the SECONDI operator (mxv, since p and y are vectors) + GrB_mxv (p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, NULL) ; + } + + //-------------------------------------------------------------------------- + // free workspace and return x and p via the handles; returns GrB_SUCCESS + //-------------------------------------------------------------------------- + + GrB_Matrix_free (&D) ; + GrB_Matrix_free (&G) ; + GrB_Vector_free (&y) ; + (*x_handle) = x ; + (*p_handle) = p ; + return (GrB_SUCCESS) ; +} diff --git a/experimental/benchmark/argmax_tests.m b/experimental/benchmark/argmax_tests.m new file mode 100644 index 0000000000..ccc8b9f008 --- /dev/null +++ b/experimental/benchmark/argmax_tests.m @@ -0,0 +1,70 @@ + +% for HPEC'24 paper + +if (0) + clear all + Prob = ssget ('GAP/GAP-twitter') +end +A = Prob.A ; +nz = nnz (A) +n = size (A,1) ; +A = A + speye (n) ; +nz = nnz (A) +G = GrB (A) ; + +% time the GraphBLAS max and argmax methods +for thr = [1 40] + GrB.threads (thr) ; + for trial = 1:3 + fprintf ('\ntrial %d, threads %g\n', trial, thr) ; + + t = tic ; + x = max (G, [ ], 1) ; + t1 = toc (t) ; + fprintf ('GrB colwise max: %g sec\n', t1) ; + t = tic ; + [x,p] = GrB.argmax (G, 1) ; + t2 = toc (t) ; + fprintf ('GrB colwise argmax: %g sec\n', t2) ; + fprintf ('GrB colwise argmax time / max time: %g\n', t2/t1) ; + + t = tic ; + x = max (G, [ ], 2) ; + t1 = toc (t) ; + fprintf ('GrB rowwise max: %g sec\n', t1) ; + t = tic ; + [x,p] = GrB.argmax (G, 2) ; + t2 = toc (t) ; + fprintf ('GrB rowwise argmax: %g sec\n', t2) ; + fprintf ('GrB rowwise argmax time / max time: 
%g\n', t2/t1) ; + + end +end + + +% time the built-in MATLAB max and argmax (two-output max) on A, 3 trials, for comparison with the GrB timings above +for trial = 1:3 + fprintf ('\ntrial %d\n', trial) ; + + t = tic ; + x = max (A, [ ], 1) ; + t1 = toc (t) ; + fprintf ('MATLAB colwise max: %g sec\n', t1) ; + t = tic ; + [x,p] = max (A, [ ], 1) ; + t2 = toc (t) ; + fprintf ('MATLAB colwise argmax: %g sec\n', t2) ; + fprintf ('MATLAB colwise argmax time / max time: %g\n', t2/t1) ; + + t = tic ; + x = max (A, [ ], 2) ; + t1 = toc (t) ; + fprintf ('MATLAB rowwise max: %g sec\n', t1) ; + t = tic ; + [x,p] = max (A, [ ], 2) ; + t2 = toc (t) ; + fprintf ('MATLAB rowwise argmax: %g sec\n', t2) ; + fprintf ('MATLAB rowwise argmax time / max time: %g\n', t2/t1) ; + +end + diff --git a/experimental/test/test_argminmax.c b/experimental/test/test_argminmax.c index 566d91c1f6..f161e995bc 100644 --- a/experimental/test/test_argminmax.c +++ b/experimental/test/test_argminmax.c @@ -40,8 +40,9 @@ void test_argminmax (void) printf ("\nInput of Matrix:\n") ; GxB_print(A, 2); // test the algorithm - OK (LAGraph_argminmax (&x,&p, A,dim,is_min, msg)); - printf("\n") ; + int info = LAGraph_argminmax (&x,&p, A,dim,is_min, msg); + printf("%s\n", msg) ; + OK (info) ; GxB_print(x,3); GxB_print(p,3); // print the result