Skip to content

Commit

Permalink
Fix BMM benchmark and add TRANSPOSE_B and threads
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689029942
  • Loading branch information
alankelly authored and xnnpack-bot committed Oct 23, 2024
1 parent e0698d1 commit 542d281
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions bench/batch-matrix-multiply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) {
const size_t batch_size = state.range(0);
const size_t m = state.range(1);
const size_t k = state.range(1);
const size_t n = state.range(1);
const size_t k = state.range(2);
const size_t n = state.range(3);
const bool transpose_b = state.range(4);
const size_t num_threads = state.range(5);

std::random_device random_device;
auto rng = std::mt19937(random_device());
Expand All @@ -53,8 +55,9 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net)

xnnpack::Buffer<xnn_operator_t> ops(num_buffers);

uint32_t flags = transpose_b ? XNN_FLAG_TRANSPOSE_B : 0;
for (xnn_operator_t& op : ops) {
status = xnn_create_batch_matrix_multiply_nc_f32(/*flags=*/0, &op);
status = xnn_create_batch_matrix_multiply_nc_f32(flags, &op);
if (status != xnn_status_success) {
state.SkipWithError("failed to create FP32 Convolution operator");
return;
Expand Down Expand Up @@ -85,12 +88,14 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net)
}

size_t buffer_index = 0;
pthreadpool_t threadpool = pthreadpool_create(num_threads);

for (auto _ : state) {
state.PauseTiming();
buffer_index = (buffer_index + 1) % num_buffers;
state.ResumeTiming();

status = xnn_run_operator(ops[buffer_index], /*threadpool=*/nullptr);
status = xnn_run_operator(ops[buffer_index], threadpool);
if (status != xnn_status_success) {
state.SkipWithError("failed to run FP32 Convolution operator");
return;
Expand All @@ -114,6 +119,8 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net)
state.counters["FLOPS"] = benchmark::Counter(
uint64_t(state.iterations()) * batch_size * m * k * n,
benchmark::Counter::kIsRate);

pthreadpool_destroy(threadpool);
}

#ifdef BENCHMARK_TENSORFLOW_LITE
Expand Down

0 comments on commit 542d281

Please sign in to comment.