Skip to content

Commit

Permalink
Merge pull request #14 from gomlx/protos
Browse files Browse the repository at this point in the history
Updated proto definitions
  • Loading branch information
janpfeifer authored Nov 17, 2024
2 parents 0a2298a + 80156f6 commit 56ec5a2
Show file tree
Hide file tree
Showing 48 changed files with 8,811 additions and 4,852 deletions.
15 changes: 2 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ computations (with large data) from Go using various [backends supported by Open
It can be used to power Machine Learning frameworks (e.g. [GoMLX](github.com/gomlx/gomlx)), image processing, scientific
computation, game AIs, etc.

**NEW**: Experimental, and somewhat limited **Apple/Metal support**.

And because Jax, TensorFlow and [optionally PyTorch](https://pytorch.org/xla/release/2.3/index.html) run on XLA,
it is possible to run Jax functions in Go with `gopjrt`, and probably TensorFlow and PyTorch as well.
See example 2 below.
Expand Down Expand Up @@ -204,16 +202,7 @@ For Linux (or Windows+WSL)+CUDA (NVidia GPU) support, in addition also run ([see
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_cuda.sh | bash
```

For Darwin/arm64 (M1, M2) GPU support, run the following script ([see source](https://github.com/gomlx/gopjrt/blob/main/cmd/install_darwin_arm64.sh)) to install under `/usr/local/{lib,include}`:

* **VERY EXPERIMENTAL**: only a subset of the operations and types supported (`float64` doesn't work). See https://developer.apple.com/metal/jax/.
And the CPU version of XLA is not working either. More of a `gopjrt` developer version.

```bash
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_darwin_arm64.sh | bash
```

**TODO(Darwin)**: Create a Homebrew version.
* ** 🚧🛠️ Mac (Darwin) support currently broken 🛠🚧️**: follow discussion in [XLA's issue #19152](https://github.com/openxla/xla/issues/19152) (and on XLA's discord channels)

**That's it**. The next sections explains in more details for those interested in special cases.

Expand Down Expand Up @@ -244,7 +233,7 @@ The installation scripts download the Linux/CUDA PJRT or the Darwin/arm64 and Da
If you have any questions, or want a custom installation of hte XLA Builder library, check and modify
[`cmd/install_linux_amd64.sh`](https://github.com/gomlx/gopjrt/blob/main/cmd/install_linux_amd64.sh),
[`cmd/install_cuda.sh`](https://github.com/gomlx/gopjrt/blob/main/cmd/install_cuda.sh) or
[`cmd/install_darwin_arm64.sh`](https://github.com/gomlx/gopjrt/blob/main/cmd/install_darwin_arm64.sh) (**VERY EXPERIMENTAL, GPU ONLY**)
[`cmd/install_darwin_arm64.sh`](https://github.com/gomlx/gopjrt/blob/main/cmd/install_darwin_arm64.sh) (🚧🛠️ **broken see note above** 🛠🚧️)
they are self-explaining.

### Installing PJRT plugins
Expand Down
4 changes: 2 additions & 2 deletions c/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ http_archive(
# Notice bazel.sh scrape the line below for the OpenXLA version, the format
# of the line should remain the same (the hash in between quotes), or bazel.sh
# must be changed accordingly.
OPENXLA_XLA_COMMIT_HASH = "5389d1b4b97942d71f8b598e88d4a6d4648229d8" # From 2024-10-18
OPENXLA_XLA_COMMIT_HASH = "9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea" # From 2024-11-16
http_archive(
name = "xla",
sha256 = "01556591a05a802ead29e67e5366d2056c3eafb27ed2f0abf7ce7978f1f4a32e", # From 2024-10-18
sha256 = "29e3e69bbbcce846d4f18564870629e4680c46789b8ddcaceea1b0f25c233468", # From 2024-11-16
strip_prefix = "xla-" + OPENXLA_XLA_COMMIT_HASH,
urls = [
"https://github.com/openxla/xla/archive/{hash}.zip".format(hash = OPENXLA_XLA_COMMIT_HASH),
Expand Down
15 changes: 12 additions & 3 deletions c/bazel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export USE_BAZEL_VERSION=7.4.0 # First version allowing cc_static_library rule.

DEBUG=0
OUTPUT_DIR=""
USE_STABLEHLO="false"
while [[ $# -gt 0 ]]; do
case $1 in
--debug)
Expand All @@ -28,6 +29,11 @@ while [[ $# -gt 0 ]]; do
OUTPUT_DIR="--output_base=$1"
shift
;;
--stablehlo)
echo "Linking StableHLO support."
USE_STABLEHLO="true"
shift
;;
-*|--*)
echo "Unknown flag $1"
exit 1
Expand Down Expand Up @@ -99,15 +105,15 @@ case "${TARGET_PLATFORM}" in
echo "Building for macOS amd64"
STARTUP_FLAGS="${STARTUP_FLAGS} --bazelrc=custom_darwin_amd64.bazelrc"
BUILD_FLAGS="${BUILD_FLAGS} --config=macos_amd64"
# Apple/Metal PJRT only works with StableHLO, so we link it along.
BUILD_FLAGS="${BUILD_FLAGS} --define use_stablehlo=false"
if [[ "$USE_TABLE_HLO" == "false" ]] ; then
echo "*** Apple/Metal PJRT (maintained by Apple) only works with StableHLO, consider adding --USE_STABLEHLO"
fi
;;

"darwin_arm64")
echo "Building for macOS arm64"
BUILD_FLAGS="${BUILD_FLAGS} --config=macos_arm64"
# Apple/Metal PJRT only works with StableHLO, so we link it along.
BUILD_FLAGS="${BUILD_FLAGS} --define use_stablehlo=true"
;;

*)
Expand All @@ -130,6 +136,9 @@ BUILD_FLAGS="${BUILD_FLAGS} --define tsl_protobuf_header_only=false"
# We need the dependencies to be linked statically -- they won't come from some external .so:
BUILD_FLAGS="${BUILD_FLAGS} --define framework_shared_object=false"

# Link-in StableHLO support: this multiplies by 8 the size of the gomlx_xlabuilder library, because
# it links in LLVM. But it's required for the Apple/Metal support.
BUILD_FLAGS="${BUILD_FLAGS} --define use_stablehlo=${USE_STABLEHLO}"

# XLA rules weren't meant to be exported, so we overrule their visibility
# constraints.
Expand Down
2 changes: 1 addition & 1 deletion c/gomlx/xlabuilder/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ gomlx_xlabuilder_deps = [
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/status",
"@xla//xla:comparison_util",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla:types",
"@xla//xla:util",
Expand Down
8 changes: 4 additions & 4 deletions c/gomlx/xlabuilder/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ int initCall = initSetTfLogs();
int xla_wrapper_version() { return XlaWrapperVersion; }

// Cast C pointer type to C++ object pointer.
xla::Status *XlaStatusCast(XlaStatus *s) {
return static_cast<xla::Status *>(s);
absl::Status *XlaStatusCast(XlaStatus *s) {
return static_cast<absl::Status *>(s);
}

char *c_str(const std::string &s) { return strdup(s.c_str()); }
Expand Down Expand Up @@ -82,8 +82,8 @@ int XlaStatusCode(XlaStatus *status) {
return int(XlaStatusCast(status)->code());
}

XlaStatus *FromStatus(const xla::Status &status) {
return static_cast<XlaStatus *>(new xla::Status(status));
XlaStatus *FromStatus(const absl::Status &status) {
return static_cast<XlaStatus *>(new absl::Status(status));
}

void XlaStatusFree(XlaStatus *xla_status) {
Expand Down
14 changes: 7 additions & 7 deletions c/gomlx/xlabuilder/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

// utils.h holds several small C/Go connector tools:
//
// - handling of xla::Status and xla::StatusOr.
// - handling of absl::Status and xla::StatusOr.
// - C definitions of VectorPointers and VectorData
// - Memory stats, usage and heap checker for leaks.

#ifndef _GOMLX_XLABUILDER_STATUS_H
#define _GOMLX_XLABUILDER_STATUS_H
// utils.h holds the simplified C interface to xla::Status and xla::StatusOr
// utils.h holds the simplified C interface to absl::Status and xla::StatusOr
// objects.
#include <stdlib.h>

Expand All @@ -35,7 +35,7 @@ typedef _Bool bool;
extern "C" {
#endif

// XlaStatus behind the scenes is a xla::Status type.
// XlaStatus is an *absl::Status cast to void* (is used to be `xla::Status`, now `absl::Status`).
typedef void XlaStatus;

// StatusOr contains status or the value from the C++ StatusOr.
Expand Down Expand Up @@ -75,7 +75,7 @@ typedef struct {
#include <string>
#include <vector>

#include "xla/status.h"
#include "absl/status/status.h"
#include "xla/statusor.h"
// #include "xla/xla/shape_util.h"

Expand Down Expand Up @@ -113,12 +113,12 @@ extern VectorPointers *c_vector_str(const std::vector<std::string> &v);

// FromStatus creates a dynamically allocated status (aliased to *XlaStatus)
// from the given one -- contents are transferred.
XlaStatus *FromStatus(const xla::Status &status);
XlaStatus *FromStatus(const absl::Status &status);

template <typename T> StatusOr FromStatusOr(xla::StatusOr<std::unique_ptr<T>> &status_or) {
StatusOr r;
r.status =
static_cast<XlaStatus *>(new xla::Status(std::move(status_or.status())));
static_cast<XlaStatus *>(new absl::Status(std::move(status_or.status())));
if (status_or.ok()) {
r.value = static_cast<void *>(status_or->get());
status_or->release(); // Ownership should go to StatusOr.
Expand All @@ -129,7 +129,7 @@ template <typename T> StatusOr FromStatusOr(xla::StatusOr<std::unique_ptr<T>> &s
template <typename T> StatusOr FromStatusOr(xla::StatusOr<T *> &status_or) {
StatusOr r;
r.status =
static_cast<XlaStatus *>(new xla::Status(std::move(status_or.status())));
static_cast<XlaStatus *>(new absl::Status(std::move(status_or.status())));
if (status_or.ok()) {
r.value = static_cast<void *>(status_or.Value());
}
Expand Down
8 changes: 4 additions & 4 deletions c/gomlx/xlabuilder/xlabuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) {
// Create the update computation: only Add supported for now.
auto shape_or = builder->GetShape(*inputs[0]);
if (!shape_or.ok()) {
return new xla::Status(std::move(shape_or.status()));
return new absl::Status(std::move(shape_or.status()));
}
xla::PrimitiveType primitive_type = shape_or.value().element_type();
auto update_computation =
Expand Down Expand Up @@ -619,17 +619,17 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) {
break;

default:
return new xla::Status(
return new absl::Status(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("unknown op_type=%d for XlaBuilderAddOp",
serialized_op->op_type));
}
if (!op.valid()) {
auto status = builder->first_error();
if (!status.ok()) {
return new xla::Status(status);
return new absl::Status(status);
}
return new xla::Status(
return new absl::Status(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("failed to convert serialized_op to XLA: op_type=%d",
serialized_op->op_type));
Expand Down
18 changes: 10 additions & 8 deletions c/xla_configure.darwin_amd64.bazelrc
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
build --action_env CLANG_COMPILER_PATH=/usr/local/Cellar/llvm/19.1.3/bin/clang-19
build --repo_env CC=/usr/local/Cellar/llvm/19.1.3/bin/clang-19
build --repo_env BAZEL_COMPILER=/usr/local/Cellar/llvm/19.1.3/bin/clang-19
build --linkopt --ld-path=/usr/local/bin/ld.lld
build --action_env CLANG_COMPILER_PATH=/usr/bin/clang
build --repo_env CC=/usr/bin/clang
build --repo_env BAZEL_COMPILER=/usr/bin/clang
build --linkopt --ld-path=/usr/bin/ld
build --action_env LD_LIBRARY_PATH=/usr/local/lib
build --action_env PYTHON_BIN_PATH=/usr/local/opt/[email protected]/bin/python3.13
build --python_path /usr/local/opt/[email protected]/bin/python3.13
test --test_env LD_LIBRARY_PATH
test --test_size_filters small,medium
build --copt -Wno-sign-compare
build --copt -Wno-error=unused-command-line-argument
build --build_tag_filters -no_oss,-gpu
build --test_tag_filters -no_oss,-gpu
test --build_tag_filters -no_oss,-gpu
test --test_tag_filters -no_oss,-gpu
build --copt -Wno-gnu-offsetof-extensions
build --build_tag_filters -no_oss,-no_mac,-gpu
build --test_tag_filters -no_oss,-no_mac,-gpu
test --build_tag_filters -no_oss,-no_mac,-gpu
test --test_tag_filters -no_oss,-no_mac,-gpu
15 changes: 7 additions & 8 deletions c/xla_configure.darwin_arm64.bazelrc
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
build --action_env CLANG_COMPILER_PATH=/opt/homebrew/Cellar/llvm/19.1.2/bin/clang-19
build --repo_env CC=/opt/homebrew/Cellar/llvm/19.1.2/bin/clang-19
build --repo_env BAZEL_COMPILER=/opt/homebrew/Cellar/llvm/19.1.2/bin/clang-19
build --linkopt --ld-path=/opt/homebrew/bin/ld.lld
build --action_env CLANG_COMPILER_PATH=/usr/bin/clang
build --repo_env CC=/usr/bin/clang
build --repo_env BAZEL_COMPILER=/usr/bin/clang
build --action_env PYTHON_BIN_PATH=/opt/homebrew/opt/[email protected]/bin/python3.13
build --python_path /opt/homebrew/opt/[email protected]/bin/python3.13
test --test_env LD_LIBRARY_PATH
test --test_size_filters small,medium
build --copt -Wno-sign-compare
build --copt -Wno-error=unused-command-line-argument
build --build_tag_filters -no_oss,-gpu
build --test_tag_filters -no_oss,-gpu
test --build_tag_filters -no_oss,-gpu
test --test_tag_filters -no_oss,-gpu
build --build_tag_filters -no_oss,-no_mac,-gpu
build --test_tag_filters -no_oss,-no_mac,-gpu
test --build_tag_filters -no_oss,-no_mac,-gpu
test --test_tag_filters -no_oss,-no_mac,-gpu
29 changes: 19 additions & 10 deletions cmd/pjrt_codegen/main.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
// codegen parses the pjrt_c_api.h and generates boilerplate code for creating the various C structures.
// pjrt_codegen copies prjt_c_api.h from github.com/openxla/xla source (pointed by XLA_SRC env variable),
// parses it and generates boilerplate code for creating the various C structures.
package main

import (
"bytes"
"github.com/janpfeifer/gonb/common"
"github.com/janpfeifer/must"
"io"
"log"
"os"
"path"
)

const pjrtCAPIHFilePath = "pjrt_c_api.h"
const (
xlaSrcEnvVar = "XLA_SRC"
pjrtAPIFileName = "pjrt_c_api.h"
)

func main() {
// Read pjrt_c_api.h
f := must.M1(os.OpenFile(pjrtCAPIHFilePath, os.O_RDONLY, os.ModePerm))
var b bytes.Buffer
_ = must.M1(io.Copy(&b, f))
must.M(f.Close())
contents := b.String()
xlaSrc := os.Getenv(xlaSrcEnvVar)
if xlaSrc == "" {
log.Fatalf("Please set %s to the directory containing the cloned github.com/openxla/xla repository.\n", xlaSrcEnvVar)
}
xlaSrc = common.ReplaceTildeInDir(xlaSrc)

// Copy pjrt_c_api.h.
contentsBytes := must.M1(os.ReadFile(path.Join(xlaSrc, "xla", "pjrt", "c", pjrtAPIFileName)))
must.M(os.WriteFile(pjrtAPIFileName, contentsBytes, 0644))

// Create various Go generate files.
contents := string(contentsBytes)
generateNewStruct(contents)
generateAPICalls(contents)
generateEnums(contents)
Expand Down
41 changes: 32 additions & 9 deletions cmd/protoc_xla_protos/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package main

import (
"fmt"
"github.com/pkg/errors"
"log"
"os"
"os/exec"
"path"
Expand All @@ -30,6 +32,7 @@ var protos = []string{
"xla/autotuning.proto",
"xla/pjrt/compile_options.proto",
"xla/service/hlo.proto",
"xla/service/metrics.proto",
"xla/stream_executor/device_description.proto",
"xla/xla.proto",
"xla/xla_data.proto",
Expand All @@ -38,8 +41,7 @@ var protos = []string{
func main() {
xlaSrc := os.Getenv(xlaSrcEnvVar)
if xlaSrc == "" {
fmt.Fprintf(os.Stderr, "Please set %s to the directory containing the cloned github.com/openxla/xla repository.\n", xlaSrcEnvVar)
os.Exit(1)
log.Fatalf("Please set %s to the directory containing the cloned github.com/openxla/xla repository.\n", xlaSrcEnvVar)
}

// Generate the --go_opt=M... flags
Expand All @@ -54,14 +56,12 @@ func main() {
packageName := protoPackage(proto)
err := os.Mkdir(packageName, 0755)
if err != nil && !os.IsExist(err) {
fmt.Fprintf(os.Stderr, "Failed to create sub-directory %q: %+v", packageName, err)
os.Exit(1)
log.Fatalf("Failed to create sub-directory %q: %+v", packageName, err)
}
// Remove go_package options from the proto file
protoPath := filepath.Join(xlaSrc, proto)
if err := removeGoPackageOption(protoPath); err != nil {
fmt.Fprintf(os.Stderr, "Error removing go_package option from %s: %v\n", proto, err)
os.Exit(1)
log.Fatalf("Error removing go_package option from %s: %v\n", proto, err)
}

// Construct the protoc command
Expand All @@ -79,9 +79,17 @@ func main() {
cmd.Stderr = os.Stderr

if err := cmd.Run(); err != nil {
fmt.Fprintf(os.Stderr, "Error executing protoc for %s: %v\n", proto, err)
fmt.Fprintf(os.Stderr, "Command:\n%s\n", cmd)
os.Exit(1)
log.Printf("Command:\n%s\n", cmd)
log.Fatalf("Error executing protoc for %s: %v\n", proto, err)
}

currentDir, err := os.Getwd()
if err != nil {
log.Fatalf("Failed to get current directory: %v", err)
}
localCopyPath := path.Join(currentDir, path.Base(protoPath))
if err := copyFile(localCopyPath, protoPath); err != nil {
log.Fatalf("Failed to copy file: %v", err)
}
}
}
Expand All @@ -103,3 +111,18 @@ func removeGoPackageOption(protoPath string) error {

return os.WriteFile(protoPath, newContent, 0644)
}

func copyFile(dst, src string) error {
// Read all content of src to data, may cause OOM for a large file.
data, err := os.ReadFile(src)
if err != nil {
return errors.Wrapf(err, "failed to read %q", src)
}

// Write data to dst
err = os.WriteFile(dst, data, 0644)
if err != nil {
return errors.Wrapf(err, "failed to read %q", src)
}
return nil
}
Loading

0 comments on commit 56ec5a2

Please sign in to comment.