Skip to content

Commit

Permalink
Adds a minimal but viable implementation of string arrays (with `nump…
Browse files Browse the repository at this point in the history
…y.dtypes.StringDType`) in JAX. Currently this only supports creating a string array via either `jax.numpy.asarray` or `jax.device_put` and reading it back with `jax.device_get`.

PiperOrigin-RevId: 716042460
  • Loading branch information
Google-ML-Automation committed Feb 3, 2025
1 parent a42a623 commit cdddc20
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 5 deletions.
2 changes: 1 addition & 1 deletion xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ cc_library(
features = ["-use_header_modules"],
visibility = [":friends"],
deps = [
":nb_helpers",
":nb_numpy",
"//xla:literal",
"//xla:shape_util",
Expand Down Expand Up @@ -351,6 +350,7 @@ cc_library(
# placeholder for index annotation deps
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand Down
83 changes: 83 additions & 0 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -97,6 +98,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -1700,15 +1702,96 @@ absl::StatusOr<nb::object> PyHostValue::AsNumPyArray(
} else {
TF_RETURN_IF_ERROR(ready_.Await());
}
if (string_array_contents_ != nullptr) {
TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array));
}
return value_;
}

// Converts the cords held in `string_array_contents_` into a read-only NumPy
// array of dtype `NPY_VSTRING` shaped like `ifrt_array`, and stores it in
// `value_`.
//
// Returns:
//   - FailedPreconditionError if the NumPy headers or the NumPy runtime
//     predate the 2.0 API.
//   - InternalError if a NumPy / CPython call fails, or if there are more
//     cords than array elements.
//
// `value_` is only replaced, and `string_array_contents_` only released, once
// every element has been written successfully; on error both are left as they
// were so that a failed conversion never exposes a partially filled array.
absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray(
    ifrt::Array* ifrt_array) {
#ifdef NPY_2_0_API_VERSION
  if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) {
    return absl::FailedPreconditionError(
        absl::StrCat("String arrays are not supported in NumPy version: ",
                     PyArray_RUNTIME_VERSION));
  }
  auto numpy_dtype = nb::steal<nb_dtype>(
      reinterpret_cast<PyObject*>(PyArray_DescrFromType(NPY_VSTRING)));
  if (numpy_dtype.ptr() == nullptr) {
    return absl::InternalError("PyArray_DescrFromType(NPY_VSTRING) failed");
  }

  // Fill a local array first; it is committed to `value_` only on success.
  nb_numpy_ndarray result(numpy_dtype, ifrt_array->shape().dims(),
                          /*strides=*/std::nullopt);

  auto* dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(result.ptr());
  auto iter =
      nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(dst_py_array_obj)));
  if (iter.ptr() == nullptr) {
    return absl::InternalError("PyArray_IterNew failed");
  }

  for (auto& cord : *string_array_contents_) {
    // Never write through an exhausted iterator, even if the number of cords
    // and the number of array elements were to disagree.
    if (!PyArray_ITER_NOTDONE(iter.ptr())) {
      return absl::InternalError(
          "String array contents hold more elements than the NumPy array");
    }
    absl::string_view input_str_view = cord.Flatten();
    auto py_unicode = nb::steal(PyUnicode_FromStringAndSize(
        input_str_view.data(), input_str_view.size()));
    if (py_unicode.ptr() == nullptr) {
      return absl::InternalError("PyUnicode_FromStringAndSize failed");
    }
    if (PyArray_SETITEM(dst_py_array_obj,
                        static_cast<char*>(PyArray_ITER_DATA(iter.ptr())),
                        py_unicode.ptr()) != 0) {
      return absl::InternalError("PyArray_SETITEM failed");
    }
    PyArray_ITER_NEXT(iter.ptr());
  }

  // Mark the array read-only before publishing it.
  result.attr("flags").attr("writeable") = nb::bool_(false);

  value_ = result;
  // The cords are no longer needed once their contents live in `value_`.
  string_array_contents_.reset();

  return absl::OkStatus();
#else
  return absl::FailedPreconditionError(
      "String arrays are not supported in this NumPy version.");
#endif
}

// Starts the device-to-host copy of `ifrt_array` into a freshly allocated
// vector of cords (`string_array_contents_`) and records the copy's future in
// `ready_`. The cords are turned into a NumPy array later, as part of the
// `AsNumPyArray` call.
//
// Returns an error if the device-to-host transfer guard rejects the transfer
// or if the array's dtype cannot be mapped to a NumPy dtype.
// `dynamic_shape_holder` is not used by this function.
absl::Status PyHostValue::CopyStringArrayToHostAsync(
    std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
  // Describes the array being transferred for the transfer guard.
  auto describe_transfer = [ifrt_array] {
    return absl::StrCat(
        "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","),
        "), dtype=", ifrt_array->dtype().DebugString(), ", device=",
        ifrt_array->sharding().devices()->devices().front()->DebugString());
  };
  TF_RETURN_IF_ERROR(jax::ApplyTransferGuardToDeviceToHost(describe_transfer));

  // Only the success of the dtype conversion matters here; the converted
  // dtype itself is not needed.
  TF_RETURN_IF_ERROR(IfrtDtypeToNbDtype(ifrt_array->dtype()).status());

  // One cord per array element. The vector holds the copied contents until
  // they are converted to a NumPy array.
  const auto num_elements = ifrt_array->shape().num_elements();
  string_array_contents_ =
      std::make_shared<std::vector<absl::Cord>>(num_elements);

  ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(),
                                        /*byte_strides=*/std::nullopt,
                                        ifrt::ArrayCopySemantics::kAlwaysCopy);

  // The callback does nothing; it exists only to hold a reference to the
  // cords so that they stay alive until the copy is done.
  ready_.OnReady([keep_alive = string_array_contents_](absl::Status) {});

  return absl::OkStatus();
}

absl::Status PyHostValue::CopyToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array) {
if (ready_.IsValid()) {
// The array value has been populated, so CopyToHostAsync has been called.
return absl::OkStatus();
}

// Copying in Arrays of type kString requires some special handling
if (ifrt_array->dtype().kind() == ifrt::DType::kString) {
return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array);
}

auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() &&
IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) {
Expand Down
12 changes: 12 additions & 0 deletions xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -69,8 +70,19 @@ class PyHostValue {
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

private:
absl::Status CopyStringArrayToHostAsync(
std::optional<Shape>& dynamic_shape_holder, ifrt::Array* ifrt_array);

absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array);

ifrt::Future<> ready_;
nb_numpy_ndarray value_;

// Optional field, only used for arrays of type kString. This vector of cords
// serves as the host buffer for the CopyToHostBuffer call. It holds the
// contents until they are lazily converted to a numpy array when the user
// calls `AsNumPyArray`.
std::shared_ptr<std::vector<absl::Cord>> string_array_contents_;
};

// Private to PyArray, but you cannot forward declare member classes.
Expand Down
71 changes: 71 additions & 0 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@ limitations under the License.
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -66,6 +71,32 @@ namespace xla {

namespace {

// Copies every element of the NumPy string array `py_array_obj` into an
// `absl::Cord`, visiting the elements in the order produced by a
// `PyArray_IterNew` iterator.
//
// Returns:
//   - InvalidArgumentError if the array has no elements.
//   - InternalError if the iterator cannot be created, an element cannot be
//     fetched, or an element cannot be read as UTF-8 (e.g. it is not a `str`).
absl::StatusOr<std::vector<absl::Cord>> StringDTypeArrayToCords(
    PyArrayObject* py_array_obj) {
  const auto num_elements = PyArray_SIZE(py_array_obj);
  if (num_elements == 0) {
    return absl::InvalidArgumentError("empty numpy array");
  }

  std::vector<absl::Cord> cords;
  cords.reserve(num_elements);

  auto iter =
      nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(py_array_obj)));
  if (iter.ptr() == nullptr) {
    return absl::InternalError("PyArray_IterNew failed");
  }

  while (PyArray_ITER_NOTDONE(iter.ptr())) {
    auto* iter_data = PyArray_ITER_DATA(iter.ptr());
    // PyArray_GETITEM hands back a new reference. Owning it here releases it
    // on every path; holding it as a raw pointer leaked one object per
    // element.
    auto item = nb::steal(
        PyArray_GETITEM(py_array_obj, static_cast<char*>(iter_data)));
    if (item.ptr() == nullptr) {
      return absl::InternalError(
          "Failed to get elements out of the ndarray iter.");
    }
    Py_ssize_t len = 0;
    const char* str = PyUnicode_AsUTF8AndSize(item.ptr(), &len);
    if (str == nullptr) {
      // Without this check a null pointer and an unset length would be fed
      // into the string_view below.
      return absl::InternalError(
          "Failed to read an ndarray element as a UTF-8 string.");
    }
    // The cord copies the bytes, so `item` may be released at the end of this
    // iteration.
    cords.emplace_back(absl::string_view(str, len));
    PyArray_ITER_NEXT(iter.ptr());
  }
  return cords;
}

using DevicePutFunc = std::function<absl::StatusOr<DevicePutResultFn>(
nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind)>;
Expand Down Expand Up @@ -252,10 +283,50 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
};
}

// Prepares a device put for a NumPy string array: the elements of `h` are
// copied into cords up front, and the returned callable creates a kString
// IFRT array on `to_device` (in `to_memory_kind`) from those cords when it is
// invoked. `options` is not consulted.
//
// Returns an error if `h`'s elements cannot be converted to cords.
absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
    nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
    const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
  xla::nb_numpy_ndarray ndarray = nb::cast<xla::nb_numpy_ndarray>(h);
  TF_ASSIGN_OR_RETURN(
      std::vector<absl::Cord> cords,
      StringDTypeArrayToCords(reinterpret_cast<PyArrayObject*>(ndarray.ptr())));

  // Gather the remaining arguments of MakeArrayFromHostBuffer.
  ifrt::Shape shape(absl::MakeSpan(
      static_cast<const int64_t*>(ndarray.shape()), ndarray.ndim()));
  std::shared_ptr<xla::ifrt::Sharding> sharding =
      xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);

  // Take the element pointer before the vector is moved into the keep-alive
  // closure; moving a vector hands over its storage, so the pointer keeps
  // referring to the same cords.
  void* host_buffer = cords.data();
  auto keep_cords_alive = [cords = std::move(cords)] {};

  return [client, host_buffer, shape = std::move(shape),
          sharding = std::move(sharding),
          keep_cords_alive = std::move(keep_cords_alive)]() mutable
             -> absl::StatusOr<DevicePutResult> {
    // The keep-alive closure is passed as the "done with host buffer"
    // callback, so the cords live at least until that callback is destroyed.
    TF_ASSIGN_OR_RETURN(
        auto ifrt_array,
        client->MakeArrayFromHostBuffer(
            host_buffer, ifrt::DType(ifrt::DType::kString), std::move(shape),
            /*byte_strides=*/std::nullopt, std::move(sharding),
            ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
            std::move(keep_cords_alive)));

    return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
  };
}

absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);

// String numpy arrays require substantially different processing.
if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') {
return HandleStringNumpyArray(h, client, to_device, options,
to_memory_kind);
}

TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

PrimitiveType squashed_type;
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 309
_version = 310

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
8 changes: 5 additions & 3 deletions xla/tsl/python/lib/core/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ limitations under the License.
#include <Python.h>
// clang-format on

#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export
#include "numpy/arrayobject.h" // IWYU pragma: export
#include "numpy/ndarraytypes.h" // IWYU pragma: export
#include "numpy/npy_common.h" // IWYU pragma: export
#include "numpy/numpyconfig.h" // IWYU pragma: export
#include "numpy/ufuncobject.h" // IWYU pragma: export

namespace tsl {

Expand Down

0 comments on commit cdddc20

Please sign in to comment.