diff --git a/bench/batch-matrix-multiply.cc b/bench/batch-matrix-multiply.cc index b09cdebc34b..c2affbc14e3 100644 --- a/bench/batch-matrix-multiply.cc +++ b/bench/batch-matrix-multiply.cc @@ -28,8 +28,10 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) { const size_t batch_size = state.range(0); const size_t m = state.range(1); - const size_t k = state.range(1); - const size_t n = state.range(1); + const size_t k = state.range(2); + const size_t n = state.range(3); + const bool transpose_b = state.range(4); + const size_t num_threads = state.range(5); std::random_device random_device; auto rng = std::mt19937(random_device()); @@ -53,8 +55,9 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) xnnpack::Buffer ops(num_buffers); + uint32_t flags = transpose_b ? XNN_FLAG_TRANSPOSE_B : 0; for (xnn_operator_t& op : ops) { - status = xnn_create_batch_matrix_multiply_nc_f32(/*flags=*/0, &op); + status = xnn_create_batch_matrix_multiply_nc_f32(flags, &op); if (status != xnn_status_success) { state.SkipWithError("failed to create FP32 Convolution operator"); return; @@ -85,12 +88,14 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) } size_t buffer_index = 0; + pthreadpool_t threadpool = pthreadpool_create(num_threads); + for (auto _ : state) { state.PauseTiming(); buffer_index = (buffer_index + 1) % num_buffers; state.ResumeTiming(); - status = xnn_run_operator(ops[buffer_index], /*threadpool=*/nullptr); + status = xnn_run_operator(ops[buffer_index], threadpool); if (status != xnn_status_success) { state.SkipWithError("failed to run FP32 Convolution operator"); return; @@ -114,6 +119,8 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) state.counters["FLOPS"] = benchmark::Counter( uint64_t(state.iterations()) * batch_size * m * k * n, benchmark::Counter::kIsRate); + + pthreadpool_destroy(threadpool); } #ifdef BENCHMARK_TENSORFLOW_LITE