Skip to content

Commit

Permalink
method_meta add uses_backend function
Browse files Browse the repository at this point in the history
Differential Revision: D69143825

Pull Request resolved: pytorch#8198
  • Loading branch information
cmt0 authored Feb 6, 2025
1 parent 4a1bf29 commit edf3952
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
11 changes: 11 additions & 0 deletions runtime/executor/method_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ Result<int64_t> MethodMeta::memory_planned_buffer_size(size_t index) const {
return s_plan_->non_const_buffer_sizes()->Get(index + 1);
}

bool MethodMeta::uses_backend(const char* backend_name) const {
const auto delegates = s_plan_->delegates();
for (size_t i = 0; i < delegates->size(); i++) {
auto delegate = delegates->Get(i);
if (strcmp(delegate->id()->c_str(), backend_name) == 0) {
return true;
}
}
return false;
}

size_t MethodMeta::num_instructions() const {
const auto chains = s_plan_->chains();
if (chains == nullptr) {
Expand Down
8 changes: 8 additions & 0 deletions runtime/executor/method_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ class MethodMeta final {
*/
Result<int64_t> memory_planned_buffer_size(size_t index) const;

/**
* Check to see if a backend is used in this method.
*
* @param[in] backend_name The name of the backend to search for.
* @returns true if a backend is used in this method, otherwise false.
*/
bool uses_backend(const char* backend_name) const;

/**
* Get the number of instructions in this method.
*
Expand Down
4 changes: 4 additions & 0 deletions runtime/executor/test/backend_integration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ TEST_P(BackendIntegrationTest, BasicInitSucceeds) {
Result<Program> program = Program::load(&loader.get());
ASSERT_EQ(program.error(), Error::Ok);

auto method_meta = program->method_meta("forward");
EXPECT_EQ(method_meta->uses_backend(StubBackend::kName), true);
EXPECT_EQ(method_meta->uses_backend("INVALID_BACKEND_NAME"), false);

ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method_res = program->load_method("forward", &mmm.get());
EXPECT_EQ(method_res.error(), Error::Ok);
Expand Down

0 comments on commit edf3952

Please sign in to comment.