Skip to content

Commit

Permalink
[JAX] Keep CPU host callbacks alive via IFRT, rather than by attachin…
Browse files Browse the repository at this point in the history
…g them to the Python object.

We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.

PiperOrigin-RevId: 571141106
  • Loading branch information
hawkinsp authored and copybara-github committed Oct 5, 2023
1 parent 2650ed7 commit ee02f9e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 11 deletions.
11 changes: 1 addition & 10 deletions xla/python/pjrt_ifrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,6 @@ PjRtLoadedExecutable::CreateInternal(
host_send_and_recv_callbacks.push_back(host_send_and_recv_callback);
}
}
if (!loaded_host_callbacks.empty() &&
!client->pjrt_client()->SupportsSendRecvCallbacks()) {
return InternalError("Host callback not supported for runtime type: %s",
client->runtime_type());
}

return std::unique_ptr<LoadedExecutable>(new PjRtLoadedExecutable(
client, std::move(pjrt_loaded_executable), std::move(devices),
Expand Down Expand Up @@ -473,11 +468,7 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
output_shapes_(std::move(output_shapes)),
output_shardings_(std::move(output_shardings)) {}

PjRtLoadedExecutable::~PjRtLoadedExecutable() {
// Reset the PjRt executable before host callbacks.
pjrt_loaded_executable_ = nullptr;
all_loaded_host_callbacks_->clear();
}
PjRtLoadedExecutable::~PjRtLoadedExecutable() = default;

StatusOr<PjRtLoadedExecutable::ExecuteResult> PjRtLoadedExecutable::Execute(
absl::Span<tsl::RCReference<Array>> args, const ExecuteOptions& options,
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 201
_version = 202

# Version number for MLIR:Python components.
mlir_api_version = 54
Expand Down

0 comments on commit ee02f9e

Please sign in to comment.