Skip to content

Commit

Permalink
replace non-test/trace fprintf with new hwy::Warn/HWY_WARN
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698367210
  • Loading branch information
jan-wassenberg authored and copybara-github committed Nov 20, 2024
1 parent 3cb3a91 commit 3926b93
Show file tree
Hide file tree
Showing 17 changed files with 163 additions and 110 deletions.
50 changes: 46 additions & 4 deletions hwy/abort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stdio.h>
#include <stdlib.h>

#include <atomic>
#include <string>

#include "hwy/base.h"
Expand All @@ -20,21 +21,62 @@
namespace hwy {

namespace {

std::atomic<WarnFunc>& AtomicWarnFunc() {
static std::atomic<WarnFunc> func;
return func;
}

std::atomic<AbortFunc>& AtomicAbortFunc() {
static std::atomic<AbortFunc> func;
return func;
}

std::string GetBaseName(std::string const& file_name) {
auto last_slash = file_name.find_last_of("/\\");
return file_name.substr(last_slash + 1);
}

} // namespace

// Returning a reference is unfortunately incompatible with `std::atomic`, which
// is required to safely implement `SetWarnFunc`. As a workaround, we store a
// copy here, update it when called, and return a reference to the copy. This
// has the added benefit of protecting the actual pointer from modification.
HWY_DLLEXPORT WarnFunc& GetWarnFunc() {
static WarnFunc func;
func = AtomicWarnFunc().load();
return func;
}

HWY_DLLEXPORT AbortFunc& GetAbortFunc() {
static AbortFunc func;
func = AtomicAbortFunc().load();
return func;
}

HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func) {
return AtomicWarnFunc().exchange(func);
}

HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func) {
const AbortFunc prev = GetAbortFunc();
GetAbortFunc() = func;
return prev;
return AtomicAbortFunc().exchange(func);
}

HWY_DLLEXPORT void HWY_FORMAT(3, 4)
Warn(const char* file, int line, const char* format, ...) {
char buf[800];
va_list args;
va_start(args, format);
vsnprintf(buf, sizeof(buf), format, args);
va_end(args);

WarnFunc handler = AtomicWarnFunc().load();
if (handler != nullptr) {
handler(file, line, buf);
} else {
fprintf(stderr, "Warn at %s:%d: %s\n", GetBaseName(file).data(), line, buf);
}
}

HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
Expand All @@ -45,7 +87,7 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
vsnprintf(buf, sizeof(buf), format, args);
va_end(args);

AbortFunc handler = GetAbortFunc();
AbortFunc handler = AtomicAbortFunc().load();
if (handler != nullptr) {
handler(file, line, buf);
} else {
Expand Down
34 changes: 25 additions & 9 deletions hwy/abort.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,36 @@

namespace hwy {

// Interface for custom abort handler
typedef void (*AbortFunc)(const char* file, int line,
const char* formatted_err);
// Interfaces for custom Warn/Abort handlers.
typedef void (*WarnFunc)(const char* file, int line, const char* message);

// Retrieve current abort handler
// Returns null if no abort handler registered, indicating Highway should print and abort
typedef void (*AbortFunc)(const char* file, int line, const char* message);

// Returns current Warn() handler, or nullptr if no handler was yet registered,
// indicating Highway should print to stderr.
// DEPRECATED because this is thread-hostile and prone to misuse (modifying the
// underlying pointer through the reference).
HWY_DLLEXPORT WarnFunc& GetWarnFunc();

// Returns current Abort() handler, or nullptr if no handler was yet registered,
// indicating Highway should print to stderr and abort.
// DEPRECATED because this is thread-hostile and prone to misuse (modifying the
// underlying pointer through the reference).
HWY_DLLEXPORT AbortFunc& GetAbortFunc();

// Sets a new abort handler and returns the previous abort handler
// If this handler does not do the aborting itself Highway will use its own abort mechanism
// which allows this to be used to customize the handling of the error itself.
// Returns null if no previous abort handler registered
// Sets a new Warn() handler and returns the previous handler, which is nullptr
// if no previous handler was registered, and should otherwise be called from
// the new handler. Thread-safe.
HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func);

// Sets a new Abort() handler and returns the previous handler, which is nullptr
// if no previous handler was registered, and should otherwise be called from
// the new handler. If all handlers return, then Highway will terminate the app.
// Thread-safe.
HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func);

// Abort()/Warn() and HWY_ABORT/HWY_WARN are declared in base.h.

} // namespace hwy

#endif // HIGHWAY_HWY_ABORT_H_
21 changes: 21 additions & 0 deletions hwy/abort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,28 @@
namespace hwy {
namespace {

TEST(AbortTest, WarnOverrideChain) {
WarnFunc FirstHandler = [](const char* file, int line,
const char* formatted_err) -> void {
fprintf(stderr, "%s from %d of %s", formatted_err, line, file);
};
WarnFunc SecondHandler = [](const char* file, int line,
const char* formatted_err) -> void {
fprintf(stderr, "%s from %d of %s", formatted_err, line, file);
};

// Do not check that the first SetWarnFunc returns nullptr, because it is
// not guaranteed to be the first call - other TEST may come first.
(void)SetWarnFunc(FirstHandler);
HWY_ASSERT(GetWarnFunc() == FirstHandler);
HWY_ASSERT(SetWarnFunc(SecondHandler) == FirstHandler);
HWY_ASSERT(GetWarnFunc() == SecondHandler);
HWY_ASSERT(SetWarnFunc(nullptr) == SecondHandler);
HWY_ASSERT(GetWarnFunc() == nullptr);
}

#ifdef GTEST_HAS_DEATH_TEST

std::string GetBaseName(std::string const& file_name) {
auto last_slash = file_name.find_last_of("/\\");
return file_name.substr(last_slash + 1);
Expand Down
6 changes: 6 additions & 0 deletions hwy/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ namespace hwy {
// 4 instances of a given literal value, useful as input to LoadDup128.
#define HWY_REP4(literal) literal, literal, literal, literal

HWY_DLLEXPORT void HWY_FORMAT(3, 4)
Warn(const char* file, int line, const char* format, ...);

#define HWY_WARN(format, ...) \
::hwy::Warn(__FILE__, __LINE__, format, ##__VA_ARGS__)

HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
Abort(const char* file, int line, const char* format, ...);

Expand Down
2 changes: 1 addition & 1 deletion hwy/contrib/bit_pack/bit_pack_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ struct TestPack {
},
inputs, kNumInputs, results, p);
if (num_results != kNumInputs) {
fprintf(stderr, "MeasureClosure failed.\n");
HWY_WARN("MeasureClosure failed.\n");
return;
}
// Print throughput for pack+unpack round trip
Expand Down
3 changes: 1 addition & 2 deletions hwy/contrib/math/math_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
static bool once = true;
if (once) {
once = false;
fprintf(stderr,
"Skipping math_test due to GCC issue with excess precision.\n");
HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
}
return;
}
Expand Down
5 changes: 2 additions & 3 deletions hwy/contrib/sort/algo-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,7 @@ void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared,
case Algo::kVXSort: {
#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
(!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
fprintf(stderr, "Do not call for target %s\n",
hwy::TargetName(HWY_TARGET));
HWY_WARN("Do not call for target %s\n", hwy::TargetName(HWY_TARGET));
return;
#else
#if VXSORT_AVX3
Expand All @@ -566,7 +565,7 @@ void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared,
if (kAscending) {
return vx.sort(inout, inout + num_keys - 1);
} else {
fprintf(stderr, "Skipping VX - does not support descending order\n");
HWY_WARN("Skipping VX - does not support descending order\n");
return;
}
#endif // enabled for this target
Expand Down
3 changes: 1 addition & 2 deletions hwy/contrib/sort/bench_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ HWY_NOINLINE void BenchAllColdSort() {

char cpu100[100];
if (!platform::HaveTimerStop(cpu100)) {
fprintf(stderr, "CPU '%s' does not support RDTSCP, skipping benchmark.\n",
cpu100);
HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100);
return;
}

Expand Down
1 change: 0 additions & 1 deletion hwy/contrib/sort/sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ void TestAllSortIota() {
if (hwy::HaveFloat64()) {
TestSortIota<double>(pool);
}
fprintf(stderr, "Iota OK\n");
#endif
}

Expand Down
10 changes: 5 additions & 5 deletions hwy/contrib/sort/vqsort-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,8 @@ HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
if (partial_128 || huge_vec) {
if (VQSORT_PRINT >= 1) {
fprintf(stderr, "WARNING: using slow HeapSort: partial %d huge %d\n",
partial_128, huge_vec);
HWY_WARN("using slow HeapSort: partial %d huge %d\n", partial_128,
huge_vec);
}
HeapSort(st, keys, num);
return true;
Expand Down Expand Up @@ -1998,7 +1998,7 @@ void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
(void)d;
(void)buf;
if (VQSORT_PRINT >= 1) {
fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
HWY_WARN("using slow HeapSort because vqsort disabled\n");
}
detail::HeapSort(st, keys, num);
#endif // VQSORT_ENABLED
Expand Down Expand Up @@ -2043,7 +2043,7 @@ void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k,
(void)d;
(void)buf;
if (VQSORT_PRINT >= 1) {
fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
HWY_WARN("using slow HeapSort because vqsort disabled\n");
}
detail::HeapPartialSort(st, keys, num, k);
#endif // VQSORT_ENABLED
Expand Down Expand Up @@ -2084,7 +2084,7 @@ void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
(void)d;
(void)buf;
if (VQSORT_PRINT >= 1) {
fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n");
HWY_WARN("using slow HeapSort because vqsort disabled\n");
}
detail::HeapSelect(st, keys, num, k);
#endif // VQSORT_ENABLED
Expand Down
Loading

0 comments on commit 3926b93

Please sign in to comment.