Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upstream static #807

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion static/include/cuda_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ inline DeviceError QueryEvent(EventType event) {
return cudaEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return cudaGetErrorString(err);
}

Expand Down
1 change: 0 additions & 1 deletion static/include/debug_utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "device_functions-generated.h"

Expand Down
2 changes: 1 addition & 1 deletion static/include/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#define DEVICE_CHECK(call) \
if ((call) != GetDeviceSuccess()) { \
throw std::runtime_error( \
#call " API call failed: " + GetLastErrorString() + " at " + \
#call " API call failed: " + GetErrorString(call) + " at " + \
__FILE__ + ", line" + std::to_string(__LINE__)); \
}

Expand Down
29 changes: 25 additions & 4 deletions static/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ namespace ait {
inline void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") +
cudaGetErrorString(device_error) +
std::string msg = std::string("Got error: ") + GetErrorString(device_error) +
" enum: " + std::to_string(device_error) + " at " + file + ": " +
std::to_string(line);
LOG(ERROR) << msg;
Expand Down Expand Up @@ -217,6 +216,29 @@ class ModelBase {
}

void RunAsGraph(StreamType stream) {
#ifdef __HIP_PLATFORM_HCC__
if (graph_exec_ == nullptr) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING)
<< "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#else
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
Expand All @@ -230,13 +252,11 @@ class ModelBase {
}
throw;
}

// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });

if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (
Expand All @@ -247,6 +267,7 @@ class ModelBase {
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#endif

DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
Expand Down
15 changes: 8 additions & 7 deletions static/include/rocm_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace ait {

inline thread_local bool target_has_graph_mode = false;
inline thread_local bool target_has_graph_mode = true;

using DeviceError = hipError_t;
using DevicePropertyType = hipDeviceProp_t;
Expand Down Expand Up @@ -57,7 +57,7 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has 32-bit integer atomics for shared memory: "
<< (arch.hasSharedInt32Atomics ? "yes" : "no")
<< "\n Has 32-bit float atomic exch for shared memory: "
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no"
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no")
<< "\n Has 32-bit float atomic add in global and shared memory: "
<< (arch.hasFloatAtomicAdd ? "yes" : "no")
<< "\n Has 64-bit integer atomics for global memory: "
Expand All @@ -67,9 +67,9 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has double-precision floating point: "
<< (arch.hasDoubles ? "yes" : "no")
<< "\n Has warp vote instructions (__any, __all): "
<< (arch.hasWarpVote: ? "yes" : "no")
<< (arch.hasWarpVote ? "yes" : "no")
<< "\n Has warp ballot instructions (__ballot): "
<< (arch.hasWarpBallot: ? "yes" : "no")
<< (arch.hasWarpBallot ? "yes" : "no")
<< "\n Has warp shuffle operations. (__shfl_*): "
<< (arch.hasWarpShuffle ? "yes" : "no")
<< "\n Has funnel two words into one with shift&mask caps: "
Expand Down Expand Up @@ -187,7 +187,7 @@ inline DeviceError StreamDestroy(StreamType stream) {
}

inline DeviceError StreamWaitEvent(StreamType stream, EventType event) {
return hipStreamWaitEvent(stream, event);
return hipStreamWaitEvent(stream, event, 0);
}

inline DeviceError GraphInstantiate(
Expand All @@ -202,7 +202,8 @@ inline DeviceError GraphDestroy(GraphType graph) {

inline DeviceError GraphExecUpdate(GraphExecType graph_exec, GraphType graph) {
// We don't have hipGraphExecUpdate in some versions of rocm
return hipErrorUnknown;
hipGraphExecUpdateResult update;
return hipGraphExecUpdate(graph_exec, graph, nullptr, &update);
}

inline DeviceError GraphExecDestroy(GraphExecType graph_exec) {
Expand Down Expand Up @@ -314,7 +315,7 @@ inline DeviceError QueryEvent(EventType event) {
return hipEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return hipGetErrorString(err);
}

Expand Down