Skip to content

Commit

Permalink
Minor bug fixes post c++ conversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
FinnWilkinson committed Jan 24, 2024
1 parent 8d05a4c commit cabe7d1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ArmPL/gemm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class gemm_cpu : public gemm<T> {
/** Make a class to the BLAS Library Kernel. */
virtual void callKernel() override {
if constexpr (std::is_same_v<T, float>) {
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, ALPHA,
A_.data(), MAX(1, m), B.data(), MAX(1, k), BETA, C.data(),
MAX(1, m));
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m_, n_, k_, ALPHA,
A_.data(), MAX(1, m_), B_.data(), MAX(1, k_), BETA, C_.data(),
MAX(1, m_));
} else if constexpr (std::is_same_v<T, double>) {
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, ALPHA,
A_.data(), MAX(1, m), B.data(), MAX(1, k), BETA, C.data(),
MAX(1, m));
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m_, n_, k_, ALPHA,
A_.data(), MAX(1, m_), B_.data(), MAX(1, k_), BETA, C_.data(),
MAX(1, m_));
} else {
// Un-specialised class will not do any work - print error and exit.
std::cout << "ERROR - Datatype for ArmPL CPU GEMM kernel not supported."
Expand Down
1 change: 1 addition & 0 deletions DefaultGPU/gemm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <time.h>

#include <cmath>
#include <vector>

#include "../include/GPU/gemm.hh"
Expand Down
2 changes: 2 additions & 0 deletions include/helpers.hh
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#pragma once

#include <cmath>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>

Expand Down

0 comments on commit cabe7d1

Please sign in to comment.