Skip to content

Commit

Permalink
Merge pull request #20 from gomlx/io
Browse files Browse the repository at this point in the history
v0.5.0: Added direct access to PJRT buffers when PJRT running in CPU; Benchmarks.
janpfeifer authored Dec 19, 2024
2 parents c8c81d9 + 4ef585a commit 3e4e41d
Showing 29 changed files with 1,975 additions and 521 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yaml
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.22.x"
go-version: "1.23.x"

- name: Install Gopjrt C library gomlx_xlabuilder and PJRT plugin
shell: bash
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -329,7 +329,7 @@ Also, see [this blog post](https://opensource.googleblog.com/2024/03/pjrt-plugin
Because of https://github.com/golang/go/issues/13467 : C API's cannot be exported across packages, even within the same repo.
Even a function as simple as `func Add(a, b C.int) C.int` in one package cannot be called from another.
So we need to wrap everything, and more than that, one cannot create separate sub-packages to handle separate concerns.
THis is also the reason the library `chelper.go` is copied in both `pjrt` and `xlabuilder` packages.
This is also the reason the library `chelper.go` is copied in both `pjrt` and `xlabuilder` packages.
* **Why does PJRT spits out so much logging ? Can we disable it ?**
This is a great question ... imagine if every library we use decided they also want to clutter our stderr?
I have [an open question in Abseil about it](https://github.com/abseil/abseil-cpp/discussions/1700).
@@ -340,6 +340,18 @@ Also, see [this blog post](https://opensource.googleblog.com/2024/03/pjrt-plugin
before calling `pjrt.GetPlugin`. But it may have unintended consequences, if some other library is depending
on the fd 2 to work, or if a real exceptional situation needs to be reported and is not.

## Environment Variables

That help control or debug how **gopjrt** work:

* `PJRT_PLUGIN_LIBRARY_PATH`: Path to search for PJRT plugins. **gopjrt** also searches in `/usr/local/lib/gomlx/pjrt`,
the standard library paths for the system and `$LD_LIBRARY_PATH`.
* `XLA_DEBUG_OPTIONS`: If set, it is parsed as a `DebugOptions` proto that
is passed during the JIT-compilation (`Client.Compile()`) of a computation graph.
It is not documented how it works in PJRT (e.g. I observed a great slow down when this is set,
even if set to the default values), but [the proto has some documentation](https://github.com/gomlx/gopjrt/blob/main/protos/xla.proto#L40).
* `GOPJRT_INSTALL_DIR` and `GOPJRT_NOSUDO`: used by the install scripts, see "Installing" section above.

## Links to documentation

* [Google Drive Directory with Design Docs](https://drive.google.com/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN): Some links are outdated or redirected, but very valuable information.
5 changes: 3 additions & 2 deletions c/WORKSPACE
Original file line number Diff line number Diff line change
@@ -21,11 +21,12 @@ http_archive(
# Notice bazel.sh scrape the line below for the OpenXLA version, the format
# of the line should remain the same (the hash in between quotes), or bazel.sh
# must be changed accordingly.
OPENXLA_XLA_COMMIT_HASH = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" # From 2024-11-24
# OPENXLA_XLA_COMMIT_HASH = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" # From 2024-11-24
OPENXLA_XLA_COMMIT_HASH = "e2e8952ad0fac8833e9a78f9b3689e803ff8524f" # From 2024-12-11

http_archive(
name = "xla",
sha256 = "a910124d546bc79edb685612edaa3d56153f0e0927f967e8defaf312b833d404", # From 2024-11-24
sha256 = "5ec6919a25952fa790904983481ccb51ebbe20bbc53e15ddbb6d3e0b3aa3dfe1", # From 2024-12-11
strip_prefix = "xla-" + OPENXLA_XLA_COMMIT_HASH,
urls = [
"https://github.com/openxla/xla/archive/{hash}.zip".format(hash = OPENXLA_XLA_COMMIT_HASH),
8 changes: 3 additions & 5 deletions chelper.go
Original file line number Diff line number Diff line change
@@ -31,17 +31,15 @@ func cSizeOf[T any]() C.size_t {
// It must be manually freed with cFree() by the user.
func cMalloc[T any]() (ptr *T) {
size := cSizeOf[T]()
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
cPtr := (*T)(C.calloc(1, size))
return cPtr
}

// cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero.
// It must be manually freed with C.free() by the user.
func cMallocArray[T any](n int) (ptr *T) {
size := cSizeOf[T]() * C.size_t(n)
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
size := cSizeOf[T]()
cPtr := (*T)(C.calloc(C.size_t(n), size))
return cPtr
}

4 changes: 2 additions & 2 deletions cmd/run_coverage.sh
Original file line number Diff line number Diff line change
@@ -2,6 +2,6 @@

# Run this from the root of gopjrt repository to generate docs/coverage.out with the coverage data.

PACKAGE_COVERAGE="./pjrt ./xlabuilder"
go test -v -cover -coverprofile docs/coverage.out -coverpkg ${PACKAGE_COVERAGE}
PACKAGE_COVERAGE="github.com/gomlx/gopjrt/pjrt,github.com/gomlx/gopjrt/xlabuilder"
go test -cover -coverprofile docs/coverage.out -coverpkg="${PACKAGE_COVERAGE}" ./... -test.count=1 -test.short
go tool cover -func docs/coverage.out -o docs/coverage.out
22 changes: 21 additions & 1 deletion docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
# Next
# v0.5.0 - 2024/12/19 - Adding direct access to PJRT buffers for CPU.

* Added `install_linux_amd64_amazonlinux.sh` and pre-built libraries for amazonlinux (built using old glibc support).
* Fixed installation scripts: s/sudo/$_SUDO. Also made them more verbose.
* Removed dependency on `xargs` in installation script for Linux.
* Improved documentation on Nvidia GPU card detection, and error message if not found.
* Updated GitHub action (`go.yaml`) to only change the README.md with the result of the change, if pushing to the
`main` branch.
* Added `prjt.arena` to avoid costly allocations for CGO calls, and merged some of CGO calls for general speed-ups.
The following functions had > 50% improvements on their fixed-cost (measured on transfers with 1 value, and minimal programs)
execution time (**not the variable part**):
* `Buffer.ToHost()`
* `Client.BufferFromHost()`
* `LoadedExecutable.Execute()`
* Added `BufferToHost` and `BufferFromHost` benchmarks.
* Added support for environment variable `XLA_DEBUG_OPTIONS`: if set, it is parsed as a `DebugOptions` proto that
is passed to the JIT-compilation of a computation graph.
* `LoadedExecutable.Execute()` now waits for the end of the execution (by setting
`PJRT_LoadedExecutable_Execute_Args.device_complete_events`).
Previous behavior lead to odd behavior and was undefined (not documented).
* Package `dtypes`:
* Added tests;
* Added `SizeForDimensions()` to be used for dtypes that uses fractions of bytes (like 4 bits).
* Added `Client.NewSharedBuffer` (and the lower level `client.CreateViewOfDeviceBuffer()`) to create buffers with shared
memory with the host, for faster input.
* Added `AlignedAlloc` and `AlignedFree` required by `client.CreateViewOfDeviceBuffer`.
* Added `Buffer.Data` for direct access to a buffer's data. Undocumented in PJRT, and likely only works on CPU.
* Fixed coverage script.

# v0.4.9 - 2024-11-25

473 changes: 348 additions & 125 deletions docs/coverage.out

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions dtypes/dtypes.go
Original file line number Diff line number Diff line change
@@ -152,10 +152,30 @@ func FromAny(value any) DType {
}

// Size returns the number of bytes for the given DType.
// If the size is < 1 (like a 4-bits quantity), consider the SizeForDimensions method.
func (dtype DType) Size() int {
return int(dtype.GoType().Size())
}

// SizeForDimensions returns the size in bytes used for the given dimensions.
// This is a safer method than Size in case the dtype uses an underlying size that is not multiple of 8 bits.
//
// It works also for scalar (one element), where dimensions list is empty.
func (dtype DType) SizeForDimensions(dimensions ...int) int {
numElements := 1
for _, dim := range dimensions {
if dim <= 0 {
panicf("cannot use dim <= 0 for SizeForDimensions, got %v", dimensions)
}
numElements *= dim
}

// Switch case for dtypes with size not multiple of 8 bits (1 byte).

// Default is simply the number of elements times the size in bytes per element.
return numElements * dtype.Size()
}

// Memory returns the number of bytes for the given DType.
// It's an alias to Size, converted to uintptr.
func (dtype DType) Memory() uintptr {
19 changes: 19 additions & 0 deletions dtypes/dtypes_test.go
Original file line number Diff line number Diff line change
@@ -33,3 +33,22 @@ func TestMapOfNames(t *testing.T) {
require.Equal(t, BFloat16, MapOfNames["BF16"])
require.Equal(t, BFloat16, MapOfNames["bf16"])
}

func TestFromAny(t *testing.T) {
require.Equal(t, Int64, FromAny(int64(7)))
require.Equal(t, Float32, FromAny(float32(13)))
require.Equal(t, BFloat16, FromAny(bfloat16.FromFloat32(1.0)))
require.Equal(t, Float16, FromAny(float16.Fromfloat32(3.0)))
}

func TestSize(t *testing.T) {
require.Equal(t, 8, Int64.Size())
require.Equal(t, 4, Float32.Size())
require.Equal(t, 2, BFloat16.Size())
}

func TestSizeForDimensions(t *testing.T) {
require.Equal(t, 2*3*8, Int64.SizeForDimensions(2, 3))
require.Equal(t, 4, Float32.SizeForDimensions())
require.Equal(t, 2, BFloat16.SizeForDimensions(1, 1, 1))
}
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@ go 1.23

require (
github.com/chewxy/math32 v1.11.1
github.com/janpfeifer/go-benchmarks v0.1.1
github.com/janpfeifer/gonb v0.10.6
github.com/janpfeifer/must v0.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
google.golang.org/protobuf v1.35.2
k8s.io/klog/v2 v2.130.1
@@ -18,6 +19,7 @@ require (
github.com/go-logr/logr v1.4.2 // indirect
github.com/gofrs/uuid v4.4.0+incompatible // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d // indirect
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -8,6 +8,8 @@ github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw=
github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM=
github.com/janpfeifer/gonb v0.10.6 h1:DyzaE8lLtJRrzu84N/XuUb0OC71EsIf9ZZhZc4mWxwA=
github.com/janpfeifer/gonb v0.10.6/go.mod h1:ZX+93yQa2s3td0JOML7x+Jo4mmuwi+XryQ0iQvrSpBQ=
github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ=
@@ -16,8 +18,10 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs=
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
54 changes: 54 additions & 0 deletions pjrt/alignedalloc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package pjrt

// This file defines an alignedAlloc and alignedFree, modelled after mm_malloc.

/*
#include <stdlib.h>
*/
import "C"
import (
"fmt"
"unsafe"
)

// BufferAlignment is the default alignment required for memory shared with CPU PJRT.
// See AlignedAlloc and FreeAlloc.
const BufferAlignment = 64

// AlignedAlloc assumes that malloc/calloc already aligns to 8 bytes. And that alignment is a multiple of 8.
// The pointer returned must be freed with AlignedFree.
//
// The allocation is filled with 0s.
func AlignedAlloc(size, alignment uintptr) unsafe.Pointer {
if alignment < 8 || alignment%8 != 0 {
panic(fmt.Sprintf("alignedAlloc: alignment must be a multiple of 8, got %d", alignment))
}

// It uses a strategy of allocating extra to allow the alignment, and it stores the pointer to the
// original allocation just before the alignedPtr.
totalSize := size + alignment
ptr := unsafe.Pointer(C.calloc(C.size_t(totalSize), C.size_t(1)))
if ptr == nil {
return nil
}

alignedPtr := ptr
offset := uintptr(ptr) % alignment
if offset != 0 {
alignedPtr = unsafe.Pointer(uintptr(ptr) + (alignment - offset))
} else {
alignedPtr = unsafe.Pointer(uintptr(ptr) + alignment) // This way we have the space to save the original ptr.
}

originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(alignedPtr) - unsafe.Sizeof(uintptr(0))))
*originalPtrPtr = uintptr(ptr)

return alignedPtr
}

// AlignedFree frees an allocation created with AlignedAlloc.
func AlignedFree(ptr unsafe.Pointer) {
originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(ptr) - unsafe.Sizeof(uintptr(0))))
originalPtr := unsafe.Pointer(*originalPtrPtr)
C.free(originalPtr)
}
25 changes: 25 additions & 0 deletions pjrt/alignedalloc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package pjrt

import (
"math/rand/v2"
"testing"
"unsafe"
)

func TestAlignedAlloc(t *testing.T) {
rng := rand.New(rand.NewPCG(42, 42))
numLivePointers := 1_000
maxAllocSize := 1_000
pointers := make([]unsafe.Pointer, numLivePointers)
for _ = range 1_000_000 {
idx := rng.IntN(numLivePointers)
if pointers[idx] != nil {
AlignedFree(pointers[idx])
}
size := uintptr(rng.IntN(maxAllocSize))
pointers[idx] = AlignedAlloc(size, BufferAlignment)
}
for _, ptr := range pointers {
AlignedFree(ptr)
}
}
117 changes: 117 additions & 0 deletions pjrt/arena.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package pjrt

/*
#include <string.h>
*/
import "C"
import (
"fmt"
"reflect"
"sync"
"unsafe"
)

// arenaContainer implements a trivial arena object to accelerate allocations that will be used in CGO calls.
//
// The issue it is trying to solve is that individual CGO calls are slow, including C.malloc().
//
// It pre-allocates the given size in bytes in C -- so it does not needs to be pinned when using CGO, and allow
// for fast sub-allocations.
// It can only be freed all at once.
//
// If you don't call Free at the end, it will leak the C allocated space.
//
// See newArena and arenaAlloc, and also arenaPool.
type arenaContainer struct {
buf []byte
size, current int
}

const arenaDefaultSize = 2048

var arenaPool sync.Pool = sync.Pool{
New: func() interface{} {
return newArena(arenaDefaultSize)
},
}

// getArenaFromPool gets an arena of the default size.
// Must be matched with a call returnArenaToPool when it's no longer used.
func getArenaFromPool() *arenaContainer {
return arenaPool.Get().(*arenaContainer)
}

// returnArenaToPool returns an arena acquired with getArenaFromPool.
// It also resets the arena.
func returnArenaToPool(a *arenaContainer) {
a.Reset()
arenaPool.Put(a)
}

// newArena creates a new Arena with the given fixed size.
//
// It provides fast sub-allocations, which can only be freed all at once.
//
// TODO: support memory-alignment.
//
// See arenaAlloc, arena.Free and arena.Reset.
func newArena(size int) *arenaContainer {
buf := cMallocArray[byte](size)
a := &arenaContainer{
buf: unsafe.Slice(buf, size),
size: size,
}
return a
}

const arenaAlignBytes = 8

// arenaAlloc allocates a type T from the arena. It panics if the arena run out of memory.
func arenaAlloc[T any](a *arenaContainer) (ptr *T) {
allocSize := cSizeOf[T]()
if a.current+int(allocSize) > a.size {
panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for %q", allocSize, reflect.TypeOf(ptr).Elem()))
}
ptr = (*T)(unsafe.Pointer(&a.buf[a.current]))
a.current += int(allocSize)
a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1)
return
}

// arenaAllocSlice allocates an array of n elements of type T from the arena.
//
// It panics if the arena run out of memory.
func arenaAllocSlice[T any](a *arenaContainer, n int) (slice []T) {
allocSize := C.size_t(n) * cSizeOf[T]()
if a.current+int(allocSize) > a.size {
panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for [%d]%s", allocSize, n, reflect.TypeOf(slice).Elem()))
}
ptr := (*T)(unsafe.Pointer(&a.buf[a.current]))
a.current += int(allocSize)
a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1)
slice = unsafe.Slice(ptr, n)
return
}

// Free invalidates all previous allocations of the arena and frees the C allocated area.
func (a *arenaContainer) Free() {
cFree(&a.buf[0])
a.buf = nil
a.size = 0
a.current = 0
}

// Reset invalidates all previous allocations with the arena, but does not free the C allocated area.
// This way the arena can be re-used.
func (a *arenaContainer) Reset() {
// Zero the values used.
if a.buf == nil || a.size == 0 {
a.current = 0
return
}
if a.current > 0 {
clearSize := min(a.size, a.current)
C.memset(unsafe.Pointer(&a.buf[0]), 0, C.size_t(clearSize))
}
a.current = 0
}
26 changes: 26 additions & 0 deletions pjrt/arena_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package pjrt

import (
"github.com/stretchr/testify/require"
"testing"
)

func TestArena(t *testing.T) {
arena := newArena(1024)
for _ = range 2 {
require.Equal(t, 1024, arena.size)
require.Equal(t, 0, arena.current)
_ = arenaAlloc[int](arena)
require.Equal(t, 8, arena.current)
_ = arenaAlloc[int32](arena)
require.Equal(t, 16, arena.current)

_ = arenaAllocSlice[byte](arena, 9) // Aligning, it will occupy 16 bytes total.
require.Equal(t, 32, arena.current)

require.Panics(t, func() { _ = arenaAlloc[[512]int](arena) }, "Arena out of memory")
require.Panics(t, func() { _ = arenaAllocSlice[float64](arena, 512) }, "Arena out of memory")
arena.Reset()
}
arena.Free()
}
394 changes: 394 additions & 0 deletions pjrt/benchmarks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,394 @@
package pjrt

// To run benchmarks: (and fix to P-Cores if running on a Intel i9-12900K)
// go test -c . && taskset 0xFF ./pjrt.test -test.v -test.run=Bench -test.count=1
//
// See results in https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=1369069161#gid=1369069161
import (
"flag"
"fmt"
"github.com/gomlx/gopjrt/dtypes"
"github.com/gomlx/gopjrt/xlabuilder"
benchmarks "github.com/janpfeifer/go-benchmarks"
"runtime"
"testing"
"time"
"unsafe"
)

var (
flagBenchDuration = flag.Duration("bench_duration", 1*time.Second, "Benchmark duration")

// testShapes used during benchmarks executing small computation graphs.
testShapes = []xlabuilder.Shape{
xlabuilder.MakeShape(dtypes.Float32, 1, 1),
xlabuilder.MakeShape(dtypes.Float32, 10, 10),
xlabuilder.MakeShape(dtypes.Float32, 100, 100),
xlabuilder.MakeShape(dtypes.Float32, 1000, 1000),
}
)

// TestBenchCGO benchmarks a minimal CGO call.
func TestBenchCGO(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
const repeats = 1000
repeatedCGO := func() {
for _ = range repeats {
dummyCGO(unsafe.Pointer(plugin.api))
}
}
benchmarks.New(benchmarks.NamedFunction{"CGOCall", repeatedCGO}).
WithInnerRepeats(repeats).
Done()
}

// Benchmark tests different methods to create temporary pointers to be passed to CGO.
func TestBenchArena(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

numAllocationsList := []int{1, 5, 10, 100}
allocations := make([]*int, 100)
testFns := make([]benchmarks.NamedFunction, 4*len(numAllocationsList))
const repeats = 10
idxFn := 0
for _, allocType := range []string{"arena", "arenaPool", "malloc", "go+pinner"} {
for _, numAllocations := range numAllocationsList {
testFns[idxFn].Name = fmt.Sprintf("%s/%s/%d", t.Name(), allocType, numAllocations)
var fn func()
switch allocType {
case "arena":
fn = func() {
for _ = range repeats {
arena := newArena(1024)
for idx := range numAllocations {
allocations[idx] = arenaAlloc[int](arena)
}
dummyCGO(unsafe.Pointer(allocations[numAllocations-1]))
arena.Free()
}
}
case "arenaPool":
fn = func() {
for _ = range repeats {
arena := getArenaFromPool()
for idx := range numAllocations {
allocations[idx] = arenaAlloc[int](arena)
}
dummyCGO(unsafe.Pointer(allocations[numAllocations-1]))
returnArenaToPool(arena)
}
}
case "malloc":
fn = func() {
for _ = range repeats {
for idx := range numAllocations {
allocations[idx] = cMalloc[int]()
}
dummyCGO(unsafe.Pointer(allocations[numAllocations-1]))
for idx := range numAllocations {
cFree(allocations[idx])
}
}
}
case "go+pinner":
fn = func() {
for _ = range repeats {
var pinner runtime.Pinner
for idx := range numAllocations {
v := idx
allocations[idx] = &v
pinner.Pin(allocations[idx])
}
dummyCGO(unsafe.Pointer(allocations[numAllocations-1]))
pinner.Unpin()
}
}
}
testFns[idxFn].Func = fn
idxFn++
}
}
benchmarks.New(testFns...).
WithInnerRepeats(repeats).
WithWarmUps(10).
Done()
}

// TestBenchBufferFromHost benchmarks host->buffer transfer time.
func TestBenchBufferFromHost(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

const repeats = 10
numShapes := len(testShapes)
inputData := make([][]float32, numShapes)
testFns := make([]benchmarks.NamedFunction, numShapes)
for shapeIdx, s := range testShapes {
inputData[shapeIdx] = make([]float32, s.Size())
for i := 0; i < s.Size(); i++ {
inputData[shapeIdx][i] = float32(i)
}
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s)
testFns[shapeIdx].Func = func() {
for _ = range repeats {
x := inputData[shapeIdx]
s := testShapes[shapeIdx]
buf := must1(ArrayToBuffer(client, x, s.Dimensions...))
must(buf.Destroy())
}
}
}
benchmarks.New(testFns...).
WithInnerRepeats(repeats).
WithDuration(*flagBenchDuration).
Done()
}

// TestBenchBufferToHost benchmarks time to transfer data from device buffer to host.
//
// Results on CPU:
//
// Benchmarks: Median 5%-tile 99%-tile
// TestBenchBufferToHost/shape=(Float32)[1 1] 1.684µs 1.555µs 3.762µs
// TestBenchBufferToHost/shape=(Float32)[10 10] 1.651µs 1.534µs 3.699µs
// TestBenchBufferToHost/shape=(Float32)[100 100] 5.393µs 5.002µs 7.271µs
// TestBenchBufferToHost/shape=(Float32)[1000 1000] 131.826µs 131.498µs 139.316µs
func TestBenchBufferToHost(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

const repeats = 10
numShapes := len(testShapes)
testFns := make([]benchmarks.NamedFunction, numShapes)
// Prepare output data (host destination array) and upload buffers to GPU
outputData := make([][]float32, numShapes)
buffers := make([]*Buffer, numShapes)
for shapeIdx, s := range testShapes {
outputData[shapeIdx] = make([]float32, s.Size())
for i := 0; i < s.Size(); i++ {
outputData[shapeIdx][i] = float32(i)
}
buffers[shapeIdx] = must1(ArrayToBuffer(client, outputData[shapeIdx], s.Dimensions...))
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s)
testFns[shapeIdx].Func = func() {
for _ = range repeats {
buf := buffers[shapeIdx]
rawData := unsafe.Slice((*byte)(unsafe.Pointer(&outputData[shapeIdx][0])), len(outputData[shapeIdx])*int(unsafe.Sizeof(outputData[shapeIdx][0])))
must(buf.ToHost(rawData))
}
}
}
defer func() {
for _, buf := range buffers {
must(buf.Destroy())
}
}()

benchmarks.New(testFns...).
WithInnerRepeats(repeats).
Done()
}

// BenchmarkAdd1Execution benchmarks the execution time for a minimal program.
func TestBenchAdd1Execution(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// Prepare input data, the uploaded buffers and the executables.
const repeats = 10
numShapes := len(testShapes)
execs := make([]*LoadedExecutable, numShapes)
inputData := make([][]float32, numShapes)
buffers := make([]*Buffer, numShapes)
testFns := make([]benchmarks.NamedFunction, numShapes)
for shapeIdx, s := range testShapes {
inputData[shapeIdx] = make([]float32, s.Size())
for i := 0; i < s.Size(); i++ {
inputData[shapeIdx][i] = float32(i)
}
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...))

builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s))
// f(x) = x + 1
x := must1(xlabuilder.Parameter(builder, "x", 0, s))
one := must1(xlabuilder.ScalarOne(builder, s.DType))
add1 := must1(xlabuilder.Add(x, one))
comp := must1(builder.Build(add1))
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done())
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s)
testFns[shapeIdx].Func = func() {
for _ = range repeats {
buf := buffers[shapeIdx]
exec := execs[shapeIdx]
output := must1(exec.Execute(buf).Done())[0]
must(output.Destroy())
}
}
}
defer func() {
// Clean up -- and don't wait for the GC.
for shapeIdx := range numShapes {
must(buffers[shapeIdx].Destroy())
must(execs[shapeIdx].Destroy())
}
}()

benchmarks.New(testFns...).
WithInnerRepeats(repeats).
WithWarmUps(100).
WithDuration(*flagBenchDuration).
Done()
}

// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2.
//
// Runtimes for cpu:
//
// Benchmarks: Median 5%-tile 99%-tile
// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs
// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs
// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs
// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs
func TestBenchAdd1Div2Execution(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// Prepare input data, the uploaded buffers and the executables.
const repeats = 10
numShapes := len(testShapes)
execs := make([]*LoadedExecutable, numShapes)
inputData := make([][]float32, numShapes)
buffers := make([]*Buffer, numShapes)
testFns := make([]benchmarks.NamedFunction, numShapes)
for shapeIdx, s := range testShapes {
inputData[shapeIdx] = make([]float32, s.Size())
for i := 0; i < s.Size(); i++ {
inputData[shapeIdx][i] = float32(i)
}
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...))

builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s))
// f(x) = x + 1
x := must1(xlabuilder.Parameter(builder, "x", 0, s))
one := must1(xlabuilder.ScalarOne(builder, s.DType))
add1 := must1(xlabuilder.Add(x, one))
half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5))))
div2 := must1(xlabuilder.Mul(add1, half))
comp := must1(builder.Build(div2))
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done())
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s)
testFns[shapeIdx].Func = func() {
for _ = range repeats {
buf := buffers[shapeIdx]
exec := execs[shapeIdx]
output := must1(exec.Execute(buf).Done())[0]
_ = output.Destroy()
}
}
}
defer func() {
// Clean up -- and don't wait for the GC.
for shapeIdx := range numShapes {
must(buffers[shapeIdx].Destroy())
must(execs[shapeIdx].Destroy())
}
}()

benchmarks.New(testFns...).
WithInnerRepeats(repeats).
WithWarmUps(100).
WithDuration(*flagBenchDuration).
Done()
}

// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2.
//
// Runtimes for cpu:
//
// Benchmarks: Median 5%-tile 99%-tile
// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs
// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs
// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs
// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs
func TestBenchMeanNormalizedExecution(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// Prepare input data, the uploaded buffers and the executables.
const repeats = 10
numShapes := len(testShapes)
execs := make([]*LoadedExecutable, numShapes)
inputData := make([][]float32, numShapes)
buffers := make([]*Buffer, numShapes)
testFns := make([]benchmarks.NamedFunction, numShapes)
for shapeIdx, s := range testShapes {
inputData[shapeIdx] = make([]float32, s.Size())
for i := 0; i < s.Size(); i++ {
inputData[shapeIdx][i] = float32(i)
}
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...))

builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s))
// f(x) = x + 1
x := must1(xlabuilder.Parameter(builder, "x", 0, s))
one := must1(xlabuilder.ScalarOne(builder, s.DType))
add1 := must1(xlabuilder.Add(x, one))
half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5))))
div2 := must1(xlabuilder.Mul(add1, half))
mean := must1(xlabuilder.ReduceSum(div2))
normalized := must1(xlabuilder.Sub(div2, mean))

comp := must1(builder.Build(normalized))
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done())
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s)
testFns[shapeIdx].Func = func() {
for _ = range repeats {
buf := buffers[shapeIdx]
exec := execs[shapeIdx]
output := must1(exec.Execute(buf).Done())[0]
_ = output.Destroy()
}
}
}
defer func() {
// Clean up -- and don't wait for the GC.
for shapeIdx := range numShapes {
must(buffers[shapeIdx].Destroy())
must(execs[shapeIdx].Destroy())
}
}()

benchmarks.New(testFns...).
WithInnerRepeats(repeats).
WithWarmUps(100).
WithDuration(*flagBenchDuration).
Done()
}
370 changes: 64 additions & 306 deletions pjrt/buffers.go

Large diffs are not rendered by default.

236 changes: 236 additions & 0 deletions pjrt/buffers_from_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
package pjrt

/*
#include "pjrt_c_api.h"
#include "gen_api_calls.h"
#include "gen_new_struct.h"
PJRT_Error* BufferFromHostAndWait(const PJRT_Api *api, PJRT_Client_BufferFromHostBuffer_Args *args) {
PJRT_Error* err = api->PJRT_Client_BufferFromHostBuffer(args);
if (err) {
return err;
}
PJRT_Event_Await_Args event_args = {0};
event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
event_args.event = args->done_with_host_buffer;
err = api->PJRT_Event_Await(&event_args);
PJRT_Event_Destroy_Args efree_args;
efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
efree_args.event = args->done_with_host_buffer;
api->PJRT_Event_Destroy(&efree_args);
return err;
}
PJRT_Error *dummy_error;
PJRT_Error *Dummy(void *api) {
if (api == NULL) {
return NULL;
}
return dummy_error;
}
*/
import "C"
import (
"github.com/gomlx/gopjrt/dtypes"
"github.com/pkg/errors"
"reflect"
"runtime"
"slices"
"unsafe"
)

// BufferFromHostConfig is used to configure the transfer from a buffer from host memory to on-device memory, it is
// created with Client.BufferFromHost.
//
// The data to transfer from host can be set up with one of the following methods:
//
// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions).
// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions).
//
// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum.
//
// At the end call BufferFromHostConfig.Done to actually initiate the transfer.
//
// TODO: Implement async transfers, arbitrary memory layout, etc.
type BufferFromHostConfig struct {
client *Client
data []byte
dtype dtypes.DType
dimensions []int
device *Device

hostBufferSemantics PJRT_HostBufferSemantics

// err stores the first error that happened during configuration.
// If it is not nil, it is immediately returned by the Done call.
err error
}

// FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant)
// during the call. The parameters dtype and dimensions provide the shape of the array.
func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig {
if b.err != nil {
return b
}
b.data = data
b.dtype = dtype
b.dimensions = dimensions
return b
}

// ToDevice configures which device to copy the host data to.
//
// If left un-configured, it will pick the first device returned by Client.AddressableDevices.
//
// You can also provide a device by their index in Client.AddressableDevices.
func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig {
if b.err != nil {
return b
}
if device == nil {
b.err = errors.New("BufferFromHost().ToDevice() given a nil device")
return b
}
addressable, err := device.IsAddressable()
if err != nil {
b.err = errors.WithMessagef(err, "BufferFromHost().ToDevice() failed to check whether device is addressable")
return b
}
if !addressable {
b.err = errors.New("BufferFromHost().ToDevice() given a non addressable device")
return b
}
b.device = device
return b
}

// ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the
// list returned by Client.AddressableDevices.
//
// If left un-configured, it will pick the first device returned by Client.AddressableDevices.
//
// You can also provide a device by their index in Client.AddressableDevices.
func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig {
if b.err != nil {
return b
}
if deviceNum < 0 || deviceNum >= len(b.client.addressableDevices) {
b.err = errors.Errorf("BufferFromHost().ToDeviceNum() invalid deviceNum=%d, only %d addressable devices available", deviceNum, len(b.client.addressableDevices))
return b
}
return b.ToDevice(b.client.addressableDevices[deviceNum])
}

// FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying
// dimensions.
// The flat slice size must match the product of the dimension.
// If no dimensions are given, it is assumed to be a scalar, and flat should have length 1.
func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig {
if b.err != nil {
return b
}
// Checks dimensions.
expectedSize := 1
for _, dim := range dimensions {
if dim <= 0 {
b.err = errors.Errorf("FromFlatDataWithDimensions cannot be given zero or negative dimensions, got %v", dimensions)
return b
}
expectedSize *= dim
}

// Check the flat slice has the right shape.
flatV := reflect.ValueOf(flat)
if flatV.Kind() != reflect.Slice {
b.err = errors.Errorf("FromFlatDataWithDimensions was given a %s for flat, but it requires a slice", flatV.Kind())
return b
}
if flatV.Len() != expectedSize {
b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions=%v) needs %d values to match dimensions, but got len(flat)=%d", dimensions, expectedSize, flatV.Len())
return b
}

// Check validity of the slice elements type.
element0 := flatV.Index(0)
element0Type := element0.Type()
dtype := dtypes.FromGoType(element0Type)
if dtype == dtypes.InvalidDType {
b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions%v) got flat=[]%s, expected a slice of a Go tyep that can be converted to a valid DType", dimensions, element0Type)
return b
}

// Create slice of bytes and use b.FromRawData.
sizeBytes := uintptr(flatV.Len()) * element0Type.Size()
data := unsafe.Slice((*byte)(element0.Addr().UnsafePointer()), sizeBytes)
return b.FromRawData(data, dtype, dimensions)
}

// Done will use the configuration to start the transfer from host to device.
// It's synchronous: it awaits the transfer to finish and then returns.
func (b *BufferFromHostConfig) Done() (*Buffer, error) {
if b.err != nil {
// Return first error saved during configuration.
return nil, b.err
}
if len(b.data) == 0 {
return nil, errors.New("BufferFromHost requires one to configure the host data to transfer, none was configured.")
}
defer runtime.KeepAlive(b)

// Makes sure program data is not moved around by the GC during the C/C++ call.
var pinner runtime.Pinner
defer pinner.Unpin()
dataPtr := unsafe.SliceData(b.data)
pinner.Pin(dataPtr)

// Set default device.
if b.device == nil {
devices := b.client.AddressableDevices()
if len(devices) == 0 {
return nil, errors.New("BufferFromHost can't find addressable device to transfer to")
}
b.device = devices[0]
}

// Arena for memory allocations used by CGO.
arena := getArenaFromPool()
defer returnArenaToPool(arena)

// Arguments to PJRT call.
var args *C.PJRT_Client_BufferFromHostBuffer_Args
args = arenaAlloc[C.PJRT_Client_BufferFromHostBuffer_Args](arena)
args.struct_size = C.PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE
args.client = b.client.client
args.data = unsafe.Pointer(dataPtr)
args._type = C.PJRT_Buffer_Type(b.dtype)
args.num_dims = C.size_t(len(b.dimensions))
if len(b.dimensions) > 0 {
dims := arenaAllocSlice[C.int64_t](arena, len(b.dimensions))
for ii, dim := range b.dimensions {
dims[ii] = C.int64_t(dim)
}
args.dims = unsafe.SliceData(dims)
}
args.host_buffer_semantics = C.PJRT_HostBufferSemantics(b.hostBufferSemantics)
args.device = b.device.cDevice
err := toError(b.client.plugin, C.BufferFromHostAndWait(b.client.plugin.api, args))
if err != nil {
return nil, err
}

buffer := newBuffer(b.client, args.buffer)
buffer.dims = slices.Clone(b.dimensions)
buffer.dimsSet = true
buffer.dtype = b.dtype
buffer.dtypeSet = true
return buffer, nil
}

// dummyCGO calls a minimal C function and doesn't do anything.
// Here for the purpose of benchmarking CGO calls.
func dummyCGO(pointer unsafe.Pointer) {
_ = C.Dummy(pointer)
}
201 changes: 201 additions & 0 deletions pjrt/buffers_shared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package pjrt

/*
#include "pjrt_c_api.h"
#include "gen_api_calls.h"
#include "gen_new_struct.h"
extern void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg);
// OnDeleteSharedBuffer is a no-op.
void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg) {
return;
}
extern void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg);
void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg) = &OnDeleteSharedBuffer;
*/
import "C"
import (
"github.com/gomlx/gopjrt/dtypes"
"github.com/pkg/errors"
"reflect"
"slices"
"unsafe"
)

// CreateViewOfDeviceBuffer creates a PJRT Buffer that is backed by storage on the same device given by the caller as flatData and shape.
// Consider using the simpler API NewSharedBuffer.
//
// Different PJRT may have different requirements on alignment, but for the CPU PJRT the library provide
// AlignedAlloc and AlignedFree, that can be used to allocate the aligned storage space.
//
// Example of how a typical usage where the same buffer is reused as input in loop:
//
// dtype := dtypes.Float32
// dimensions := []int{batchSize, sequenceLength, 384}
// rawData := pjrt.AlignedAlloc(dtype.SizeForDimensions(dimensions...), pjrt.BufferAlignment)
// defer pjrt.AlignedFree(rawData)
// buf := client.CreateViewOfDeviceBuffer(rawData, dtype, dimensions)
// flat := unsafe.Slice((*float32)(storage), batchSize*sequenceLength*384)
// for _, batch := range batches {
// // ... set flat values
// // ... use buf as input when executing a PJRT program
// }
//
// If device is not given (at most one can be given), the first device available for the client is used.
//
// The naming comes from PJRT and is unfortunate, since it's the name from PJRT's perspective (PJRT view of a
// users device buffer).
// Probably, it should have been better named by "ShareDeviceBuffer" or something similar.
//
// This may not be implemented for all hardware (or all PJRT plugins).
//
// This can be useful to avoid the copy of values, by mutating directly in the memory shared with PJRT, to be used
// as input to a computation.
//
// See: dtypes.SizeForDimensions() to calculate the size for an arbitrary shape; AlignedAlloc, AlignedFree and
// BufferAlignment (a constant with the required alignment size) to allocate and free aligned storage.
func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, device ...*Device) (*Buffer, error) {
var selectedDevice *Device
if len(device) > 1 {
return nil, errors.Errorf("only one device can be given to CreateViewOfDeviceBuffer, %d were given", len(device))
} else if len(device) == 1 {
selectedDevice = device[0]
} else {
devices := c.AddressableDevices()
if len(devices) == 0 {
return nil, errors.New("CreateViewOfDeviceBuffer can't find addressable device to transfer to")
}
selectedDevice = devices[0]
}

// Arena for memory allocations used by CGO.
arena := getArenaFromPool()
defer returnArenaToPool(arena)

// Arguments to PJRT call.
var args *C.PJRT_Client_CreateViewOfDeviceBuffer_Args
args = arenaAlloc[C.PJRT_Client_CreateViewOfDeviceBuffer_Args](arena)
args.struct_size = C.PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE
args.client = c.client
args.device_buffer_ptr = rawData
args.element_type = C.PJRT_Buffer_Type(dtype)
args.num_dims = C.size_t(len(dimensions))
if args.num_dims > 0 {
dims := arenaAllocSlice[C.int64_t](arena, int(args.num_dims))
for ii, dim := range dimensions {
dims[ii] = C.int64_t(dim)
}
args.dims = unsafe.SliceData(dims)
}
args.device = selectedDevice.cDevice
args.on_delete_callback = (*[0]byte)(unsafe.Pointer(C.OnDeleteSharedBufferPtr))
args.on_delete_callback_arg = nil
err := toError(c.plugin, C.call_PJRT_Client_CreateViewOfDeviceBuffer(c.plugin.api, args))
if err != nil {
return nil, err
}
buffer := newBuffer(c, args.buffer)
buffer.isShared = true
buffer.dims = slices.Clone(dimensions)
buffer.dimsSet = true
buffer.dtype = dtype
buffer.dtypeSet = true
return buffer, nil
}

// NewSharedBuffer returns a buffer that can be used for execution and share the underlying
// memory space with the host/local, which can be read and mutated directly.
//
// Shared buffers cannot be donated to executions.
//
// The buffer should not be mutated while it is used by an execution.
//
// When the buffer is finalized, the shared memory is also de-allocated.
//
// It returns a handle to the buffer and a slice of the corresponding data type pointing
// to the shared data.
func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error) {
memorySize := uintptr(dtype.SizeForDimensions(dimensions...))
rawStorage := AlignedAlloc(memorySize, BufferAlignment)
buffer, err = c.CreateViewOfDeviceBuffer(rawStorage, dtype, dimensions, device...)
if err != nil {
AlignedFree(rawStorage)
err = errors.WithMessagef(err, "NewSharedBuffer failed creating new buffer")
buffer = nil
return
}
buffer.sharedRawStorage = rawStorage
buffer.isShared = true

goDType := dtype.GoType()
flat = reflect.SliceAt(dtype.GoType(), rawStorage, int(memorySize/goDType.Size())).Interface()
return
}

// IsShared returns whether this buffer was created with Client.NewSharedBuffer.
// These buffers cannot be donated in execution.
func (b *Buffer) IsShared() bool {
return b.isShared
}

// UnsafePointer returns platform-dependent address for the given buffer that is often but
// not guaranteed to be the physical/device address.
// Consider using the more convenient DirectAccess.
//
// Probably, this should only be used by CPU plugins.
//
// To be on the safe side, only use this if Client.HasSharedBuffers is true.
// It uses the undocumented PJRT_Buffer_UnsafePointer.
func (b *Buffer) UnsafePointer() (unsafe.Pointer, error) {
plugin := b.client.plugin

// Arena for memory allocations used by CGO.
arena := getArenaFromPool()
defer returnArenaToPool(arena)

// Arguments to PJRT call.
var args *C.PJRT_Buffer_UnsafePointer_Args
args = arenaAlloc[C.PJRT_Buffer_UnsafePointer_Args](arena)
args.struct_size = C.PJRT_Buffer_UnsafePointer_Args_STRUCT_SIZE
args.buffer = b.cBuffer
err := toError(plugin, C.call_PJRT_Buffer_UnsafePointer(plugin.api, args))
if err != nil {
return nil, err
}
return unsafe.Pointer(uintptr(args.buffer_pointer)), nil
}

// Data returns the flat slice pointing to the underlying storage data for the buffer.
//
// This is an undocumented feature of PJRT and likely only works for CPU platforms.
// The flat slice returned is only valid while the buffer is alive.
func (b *Buffer) Data() (flat any, err error) {
var rawStorage unsafe.Pointer
rawStorage, err = b.UnsafePointer()
if err != nil {
return nil, err
}
dims := b.dims
dtype := b.dtype
if !b.dimsSet {
dims, err = b.Dimensions()
if err != nil {
return nil, err
}
}
if !b.dtypeSet {
dtype, err = b.DType()
if err != nil {
return nil, err
}
}

numElements := 1
for _, dim := range dims {
numElements *= dim
}
return reflect.SliceAt(dtype.GoType(), rawStorage, numElements).Interface(), nil
}
172 changes: 172 additions & 0 deletions pjrt/buffers_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package pjrt

import (
"flag"
"fmt"
"github.com/gomlx/gopjrt/dtypes"
"github.com/gomlx/gopjrt/xlabuilder"
"github.com/stretchr/testify/require"
"runtime"
"testing"
"unsafe"
)
@@ -23,6 +26,7 @@ func testTransfersImpl[T interface {
fmt.Printf("From %#v\n", input)
buffer, err := ArrayToBuffer(client, input, 3, 1)
require.NoError(t, err)
require.False(t, buffer.IsShared())

output, outputDims, err := BufferToArray[T](buffer)
require.NoError(t, err)
@@ -124,3 +128,171 @@ func TestBufferProperties(t *testing.T) {
require.Equal(t, dtypes.Uint8, dtype)
}
}

var flagForceSharedBuffer = flag.Bool(
"force_shared_buffer", false, "Force executing TestCreateViewOfDeviceBuffer and TestBufferUnsafePointer even if plugin is not \"cpu\".")

func TestCreateViewOfDeviceBuffer(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestCreateViewOfDeviceBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// f(x) = x + 1
dtype := dtypes.Float32
shape := xlabuilder.MakeShape(dtype, 2, 3)
builder := xlabuilder.New("Add1")
x := must1(xlabuilder.Parameter(builder, "x", 0, shape))
one := must1(xlabuilder.ScalarOne(builder, shape.DType))
add1 := must1(xlabuilder.Add(x, one))
comp := must1(builder.Build(add1))
exec := must1(client.Compile().WithComputation(comp).Done())

// Input is created as a "Device Buffer View"
storage := AlignedAlloc(shape.Memory(), BufferAlignment)
defer AlignedFree(storage)
inputBuffer, err := client.CreateViewOfDeviceBuffer(storage, dtype, shape.Dimensions)
require.NoError(t, err)
defer func() {
err := inputBuffer.Destroy()
if err != nil {
t.Logf("Failed Buffer.Destroy(): %+v", err)
}
}()

flatData := unsafe.Slice((*float32)(storage), shape.Size())
for ii := range flatData {
flatData[ii] = float32(ii)
}
require.True(t, inputBuffer.IsShared())

results, err := exec.Execute(inputBuffer).DonateNone().Done()
require.NoError(t, err)
require.Len(t, results, 1)

gotFlat, gotDims, err := BufferToArray[float32](results[0])
require.NoError(t, err)
require.Equal(t, shape.Dimensions, gotDims)
require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, gotFlat)

// Change the buffer directly, and see that we can reuse the buffer in PJRT, without the extra transfer.
flatData[1] = 11
results, err = exec.Execute(inputBuffer).DonateNone().Done()
require.NoError(t, err)
require.Len(t, results, 1)
gotFlat, gotDims, err = BufferToArray[float32](results[0])
require.NoError(t, err)
require.Equal(t, shape.Dimensions, gotDims)
require.Equal(t, []float32{1, 12, 3, 4, 5, 6}, gotFlat)
require.NoError(t, inputBuffer.Destroy())
}

func TestNewSharedBuffer(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// f(x) = x + 1
dtype := dtypes.Float32
shape := xlabuilder.MakeShape(dtype, 2, 3)
builder := xlabuilder.New("Add1")
x := must1(xlabuilder.Parameter(builder, "x", 0, shape))
one := must1(xlabuilder.ScalarOne(builder, shape.DType))
add1 := must1(xlabuilder.Add(x, one))
comp := must1(builder.Build(add1))
exec := must1(client.Compile().WithComputation(comp).Done())

// Input is created as a "Device Buffer View"
inputBuffer, flatAny, err := client.NewSharedBuffer(dtype, shape.Dimensions)
require.NoError(t, err)
defer func() {
err := inputBuffer.Destroy()
if err != nil {
t.Logf("Failed to destroy shared buffer: %+v", err)
}
}()

flatData := flatAny.([]float32)
for ii := range flatData {
flatData[ii] = float32(ii)
}
require.True(t, inputBuffer.IsShared())

results, err := exec.Execute(inputBuffer).DonateNone().Done()
require.NoError(t, err)
require.Len(t, results, 1)

gotFlat, gotDims, err := BufferToArray[float32](results[0])
require.NoError(t, err)
require.Equal(t, shape.Dimensions, gotDims)
require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, gotFlat)

// Change the buffer directly, and see that we can reuse the buffer in PJRT, without the extra transfer.
flatData[1] = 11
results, err = exec.Execute(inputBuffer).DonateNone().Done()
require.NoError(t, err)
require.Len(t, results, 1)
gotFlat, gotDims, err = BufferToArray[float32](results[0])
require.NoError(t, err)
require.Equal(t, shape.Dimensions, gotDims)
require.Equal(t, []float32{1, 12, 3, 4, 5, 6}, gotFlat)

require.NoError(t, inputBuffer.Destroy())
}

func TestBufferData(t *testing.T) {
if *flagPluginName != "cpu" && !*flagForceSharedBuffer {
t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " +
"Set --force_create_view to force executing the test anyway")
}

// Create plugin.
plugin := must1(GetPlugin(*flagPluginName))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)

// f(x) = x + 1
dtype := dtypes.Float32
shape := xlabuilder.MakeShape(dtype, 2, 3)
builder := xlabuilder.New("Add1")
x := must1(xlabuilder.Parameter(builder, "x", 0, shape))
one := must1(xlabuilder.ScalarOne(builder, shape.DType))
add1 := must1(xlabuilder.Add(x, one))
comp := must1(builder.Build(add1))
exec := must1(client.Compile().WithComputation(comp).Done())

// Input is created as a "Device Buffer View"
inputBuffer, flatAny, err := client.NewSharedBuffer(dtype, shape.Dimensions)
require.NoError(t, err)
defer func() {
err := inputBuffer.Destroy()
if err != nil {
t.Logf("Failed to destroy shared buffer: %+v", err)
}
}()

flatData := flatAny.([]float32)
for ii := range flatData {
flatData[ii] = float32(ii)
}
require.True(t, inputBuffer.IsShared())

results, err := exec.Execute(inputBuffer).DonateNone().Done()
require.NoError(t, err)
require.Len(t, results, 1)

flatOutput, err := results[0].Data()
require.NoError(t, err)
require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, flatOutput.([]float32))
}
83 changes: 83 additions & 0 deletions pjrt/buffers_to_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pjrt

import (
"github.com/pkg/errors"
"runtime"
"unsafe"
)

/*
#include "pjrt_c_api.h"
#include "gen_api_calls.h"
#include "gen_new_struct.h"
PJRT_Error* BufferToHost(const PJRT_Api *api, PJRT_Buffer *buffer, void *dst, int64_t dst_size, int rank) {
PJRT_Buffer_ToHostBuffer_Args args = {0};
args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
args.src = buffer;
args.dst = dst;
args.dst_size = dst_size;
PJRT_Buffer_MemoryLayout layout_args = {0};
layout_args.struct_size = PJRT_Buffer_MemoryLayout_STRUCT_SIZE;
args.host_layout = &layout_args;
layout_args.type = PJRT_Buffer_MemoryLayout_Type_Tiled;
layout_args.tiled.minor_to_major_size = rank;
int64_t minor_to_major[rank > 0 ? rank : 1];
if (rank > 0) {
for (int axisIdx = 0; axisIdx < rank; axisIdx++) {
minor_to_major[axisIdx] = rank - axisIdx - 1;
}
layout_args.tiled.minor_to_major = &minor_to_major[0];
}
PJRT_Error* err = api->PJRT_Buffer_ToHostBuffer(&args);
if (err) {
return err;
}
PJRT_Event_Await_Args event_args = {0};
event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
event_args.event = args.event;
err = api->PJRT_Event_Await(&event_args);
PJRT_Event_Destroy_Args efree_args;
efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
efree_args.event = args.event;
api->PJRT_Event_Destroy(&efree_args);
return err;
}
*/
import "C"

// ToHost transfers the contents of buffer stored on device to the host.
// The space in dst has to hold enough space (see Buffer.Size) to hold the required data, or an error is returned.
//
// This always request a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to
// reorganize the layout.
func (b *Buffer) ToHost(dst []byte) error {
defer runtime.KeepAlive(b)
if b == nil || b.client.plugin == nil || b.cBuffer == nil {
// Already destroyed ?
return errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?")
}

// We'll need the buffer rank to set up the layout.
dims, err := b.Dimensions()
if err != nil {
return err
}
rank := len(dims)

dstBytes := unsafe.Pointer(unsafe.SliceData(dst))
var pinner runtime.Pinner
pinner.Pin(dstBytes)
defer pinner.Unpin()

pErr := C.BufferToHost(b.client.plugin.api, b.cBuffer, dstBytes, C.int64_t(len(dst)), C.int(rank))
err = toError(b.client.plugin, pErr)
if err != nil {
return errors.WithMessage(err, "Failed to call PJRT_Buffer_ToHostBuffer to transfer the buffer to host")
}
return nil
}
3 changes: 2 additions & 1 deletion pjrt/clients.go
Original file line number Diff line number Diff line change
@@ -115,6 +115,7 @@ type Client struct {
platform, platformVersion string
processIndex int
addressableDevices []*Device
allowBufferViews bool
}

// newClient is called by Plugin.NewClient to create a new PJRT_Client wrapper.
@@ -276,6 +277,6 @@ func (c *Client) BufferFromHost() *BufferFromHostConfig {
return &BufferFromHostConfig{
client: c,
device: nil,
hostBufferSemantics: PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
hostBufferSemantics: PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
}
}
27 changes: 27 additions & 0 deletions pjrt/compile.go
Original file line number Diff line number Diff line change
@@ -3,14 +3,20 @@ package pjrt
import (
"github.com/gomlx/gopjrt/cbuffer"
"github.com/gomlx/gopjrt/protos/compile_options"
"github.com/gomlx/gopjrt/protos/xla"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
"k8s.io/klog/v2"
"os"
"runtime"
"unsafe"
)

// EnvXlaDebugOptions is an environment variable that can be defined to set XLA DebugOptions proto when compiling
// a program.
const EnvXlaDebugOptions = "XLA_DEBUG_OPTIONS"

// CompileConfig is created with Client.Compile, and is a "builder pattern" to configure a compilation call.
//
// At a minimum one has to set the program to compile (use CompileConfig.WithHLO or CompileConfig.WithComputation).
@@ -65,6 +71,27 @@ func newCompileConfig(client *Client) (cc *CompileConfig) {
NumReplicas: 1,
NumPartitions: 1,
}

debugOptionsStr := os.Getenv(EnvXlaDebugOptions)
if debugOptionsStr != "" {
debugOptions := &xla.DebugOptions{}
err := prototext.Unmarshal([]byte(debugOptionsStr), debugOptions)
if err != nil {
cc.err = errors.Wrapf(err, "Failed to parse xla.DebugOptions protobuf from $%s=%q", EnvXlaDebugOptions, debugOptionsStr)
return cc
}
// Print configuration.
textBytes, err := prototext.MarshalOptions{Multiline: true}.Marshal(debugOptions)
if err != nil {
klog.Infof("Failed to convert xla.DebugOptions proto to text: %v", err)
} else {
klog.Infof("This adds 10ms!! of time to the execution of the compiled graph!")
klog.Infof("%s=%s -> %s", EnvXlaDebugOptions, debugOptionsStr, textBytes)
}

// Set parsed configuration.
cc.options.ExecutableBuildOptions.DebugOptions = debugOptions
}
return cc
}

8 changes: 3 additions & 5 deletions pjrt/gen_chelper.go
Original file line number Diff line number Diff line change
@@ -33,17 +33,15 @@ func cSizeOf[T any]() C.size_t {
// It must be manually freed with cFree() by the user.
func cMalloc[T any]() (ptr *T) {
size := cSizeOf[T]()
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
cPtr := (*T)(C.calloc(1, size))
return cPtr
}

// cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero.
// It must be manually freed with C.free() by the user.
func cMallocArray[T any](n int) (ptr *T) {
size := cSizeOf[T]() * C.size_t(n)
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
size := cSizeOf[T]()
cPtr := (*T)(C.calloc(C.size_t(n), size))
return cPtr
}

156 changes: 105 additions & 51 deletions pjrt/loadedexecutables.go
Original file line number Diff line number Diff line change
@@ -4,13 +4,41 @@ package pjrt
#include "pjrt_c_api.h"
#include "gen_api_calls.h"
#include "gen_new_struct.h"
PJRT_Error* ExecuteAndWait(const PJRT_Api *api, PJRT_LoadedExecutable_Execute_Args* args) {
PJRT_Error *err = api->PJRT_LoadedExecutable_Execute(args);
if (err) {
return err;
}
if (args->device_complete_events) {
// Wait for devices to complete executions.
for (int ii = 0; ii < args->num_devices; ii++) {
PJRT_Event_Await_Args event_args = {0};
event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
event_args.event = args->device_complete_events[ii];
err = api->PJRT_Event_Await(&event_args);
PJRT_Event_Destroy_Args efree_args;
efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
efree_args.event = args->device_complete_events[ii];
api->PJRT_Event_Destroy(&efree_args);
if (err) {
return err;
}
}
}
return NULL;
}
*/
import "C"
import (
"github.com/pkg/errors"
"k8s.io/klog/v2"
"runtime"
"slices"
"sync/atomic"
"unsafe"
)

@@ -32,13 +60,21 @@ type LoadedExecutable struct {
NumOutputs int
}

var numLoadedExecutables atomic.Int64

// LoadedExecutablesAlive returns a count of the numbers of LoadedExecutables currently in memory and tracked by gopjrt.
func LoadedExecutablesAlive() int64 {
return numLoadedExecutables.Load()
}

// newLoadedExecutable creates LoadedExecutable and registers it for freeing.
func newLoadedExecutable(plugin *Plugin, client *Client, cLoadedExecutable *C.PJRT_LoadedExecutable) (*LoadedExecutable, error) {
e := &LoadedExecutable{
plugin: plugin,
client: client,
cLoadedExecutable: cLoadedExecutable,
}
numLoadedExecutables.Add(1)
runtime.SetFinalizer(e, func(e *LoadedExecutable) { e.destroyOrLog() })

// Gather information about executable:
@@ -78,6 +114,8 @@ func (e *LoadedExecutable) Destroy() error {
err := toError(e.plugin, C.call_PJRT_LoadedExecutable_Destroy(e.plugin.api, args))
e.plugin = nil
e.cLoadedExecutable = nil

numLoadedExecutables.Add(-1)
return err
}

@@ -277,82 +315,98 @@ func (c *ExecutionConfig) Done() ([]*Buffer, error) {
c.devices = []*Device{devices[0]}
}

// Dimensions of inputs/outputs.
numInputs := len(c.inputs)
numOutputs := e.NumOutputs

// Allocations that will be used by CGO.
// Except if the number of inputs/outputs is very large, used the default arena size.
var arena *arenaContainer
minSize := (numInputs+numOutputs)*3*8 /*pointer size*/ + 1024
if minSize > arenaDefaultSize {
arena = newArena(arenaDefaultSize + minSize)
defer arena.Free()
} else {
arena = getArenaFromPool()
defer returnArenaToPool(arena)
}

// Create arguments structures for call to Execute.
args := C.new_PJRT_LoadedExecutable_Execute_Args()
defer cFree(args)
var args *C.PJRT_LoadedExecutable_Execute_Args
args = arenaAlloc[C.PJRT_LoadedExecutable_Execute_Args](arena)
args.struct_size = C.PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE
args.executable = e.cLoadedExecutable
options := C.new_PJRT_ExecuteOptions() // Like more args that for some reason(?) go on a separate struct.
defer cFree(options)

var options *C.PJRT_ExecuteOptions
options = arenaAlloc[C.PJRT_ExecuteOptions](arena) // Extra args that for some reason(?) go on a separate struct.
options.struct_size = C.PJRT_ExecuteOptions_STRUCT_SIZE
args.options = options

// Configure (non-)donatable inputs.
if len(c.nonDonatableInputs) > 0 {
options.num_non_donatable_input_indices = C.size_t(len(c.nonDonatableInputs))
options.non_donatable_input_indices = cMallocArrayAndSet[C.int64_t](len(c.nonDonatableInputs), func(ii int) C.int64_t {
return C.int64_t(c.nonDonatableInputs[ii])
})
defer cFree(options.non_donatable_input_indices)
nonDonatableIndices := arenaAllocSlice[C.int64_t](arena, len(c.nonDonatableInputs))
for ii := range nonDonatableIndices {
nonDonatableIndices[ii] = C.int64_t(c.nonDonatableInputs[ii])
}
options.non_donatable_input_indices = &nonDonatableIndices[0]
}

numDevices := 1
args.num_devices = C.size_t(numDevices)
args.execute_device = c.devices[0].cDevice
args.num_args = C.size_t(len(c.inputs))
args.argument_lists = allocatePerDeviceBufferList(numDevices, c.inputs)
defer freePerDeviceBufferList(args.argument_lists, numDevices)
args.output_lists = allocatePerDeviceBufferList(numDevices, make([]*Buffer, e.NumOutputs))
defer freePerDeviceBufferList(args.output_lists, numDevices)

args.num_args = C.size_t(numInputs)
if args.num_args > 0 {
args.argument_lists = allocatePerDeviceBufferListWithArena(arena, numDevices, numInputs, c.inputs)
}

// For some reason the line below doesn't work. I think something is wrong with PJRT ... but I'm not sure.
if numOutputs > 0 {
args.output_lists = allocatePerDeviceBufferListWithArena(arena, numDevices, numOutputs, nil)
}

// Create events to wait for the end of execution: leaving this as NULL is allowed, but what happens then
// (does it wait or not, and then what?) is not documented in PJRT.
perDeviceEvents := arenaAllocSlice[*C.PJRT_Event](arena, numDevices)
args.device_complete_events = (**C.PJRT_Event)(unsafe.SliceData(perDeviceEvents))
//args.device_complete_events = cMallocArray[*C.PJRT_Event](numDevices)
//defer cFree(args.device_complete_events)

err := toError(e.plugin, C.call_PJRT_LoadedExecutable_Execute(e.plugin.api, args))
err := toError(e.plugin, C.ExecuteAndWait(e.plugin.api, args))
if err != nil {
return nil, err
}

perDevice := gatherPerDeviceBufferList(e.client, args.output_lists, numDevices, e.NumOutputs)
return perDevice[0], nil
// We only support one device for now, so we return the results from the first device.
outputs := make([]*Buffer, numOutputs)
outputBuffers := unsafe.Slice(*args.output_lists, numOutputs)
for ii := range outputs {
outputs[ii] = newBuffer(e.client, outputBuffers[ii])
}
return outputs, nil
}

// Allocate [numDevices][numBuffers]*Buffer C 2D-array to be used by PJRT C API, with the given Buffer pointers.
func allocatePerDeviceBufferList(numDevices int, buffers []*Buffer) ***C.PJRT_Buffer {
func allocatePerDeviceBufferListWithArena(arena *arenaContainer, numDevices int, numBuffers int, buffers []*Buffer) ***C.PJRT_Buffer {
// Top level:
perDevice := make([]**C.PJRT_Buffer, numDevices)
perDevice := arenaAllocSlice[**C.PJRT_Buffer](arena, numDevices)
for deviceIdx := range perDevice {
perDevice[deviceIdx] = cMallocArrayAndSet[*C.PJRT_Buffer](len(buffers), func(idxBuffer int) *C.PJRT_Buffer {
if buffers[idxBuffer] == nil {
// No buffer given for structure.
return nil
}
if buffers[idxBuffer].cBuffer == nil {
// Buffer given, but it's cBuffer is nil -> probably it has already been destroyed.
panicf("buffers[%d].cBuffer is nil, has it already been destroyed!?", idxBuffer)
deviceBuffers := arenaAllocSlice[*C.PJRT_Buffer](arena, numBuffers)
perDevice[deviceIdx] = &deviceBuffers[0]
if buffers != nil {
for bufferIdx := range deviceBuffers {
if buffers[bufferIdx] == nil {
deviceBuffers[bufferIdx] = nil
continue
}
if buffers[bufferIdx].cBuffer == nil {
// Buffer given, but it's cBuffer is nil -> probably it has already been destroyed.
panicf("buffers[%d].cBuffer is nil, has it already been destroyed!?", bufferIdx)
}
deviceBuffers[bufferIdx] = buffers[bufferIdx].cBuffer
}
return buffers[idxBuffer].cBuffer
})
}
return cMallocArrayFromSlice(perDevice)
}

// freePerDeviceBufferList frees the intermediary array pointers, but it doesn't touch the buffers themselves.
func freePerDeviceBufferList(data ***C.PJRT_Buffer, numDevices int) {
perDevice := cDataToSlice[**C.PJRT_Buffer](unsafe.Pointer(data), numDevices)
for _, list := range perDevice {
cFree(list)
}
cFree(data)
}

// gatherPerDeviceBufferList returns a [numDevices][numBuffers]*Buffer given the C 2D array.
func gatherPerDeviceBufferList(client *Client, data ***C.PJRT_Buffer, numDevices, numBuffers int) [][]*Buffer {
perDevice := make([][]*Buffer, numDevices)
cPerDevice := cDataToSlice[**C.PJRT_Buffer](unsafe.Pointer(data), numDevices)
for ii, cBufferListPtr := range cPerDevice {
perDevice[ii] = make([]*Buffer, numBuffers)
cBuffers := cDataToSlice[*C.PJRT_Buffer](unsafe.Pointer(cBufferListPtr), numBuffers)
for jj, cBuffer := range cBuffers {
perDevice[ii][jj] = newBuffer(client, cBuffer)
}
}
return perDevice
return &perDevice[0]
}
17 changes: 8 additions & 9 deletions pjrt/minimal_test.go
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@ import (
"flag"
"fmt"
"github.com/gomlx/gopjrt/protos/hlo"
"github.com/janpfeifer/must"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
@@ -52,32 +51,32 @@ func TestMinimal(t *testing.T) {
var hloModule hlo.HloModuleProto
if *flagLoadHLO != "" {
fmt.Printf("Loading HLO program from %s...\n", *flagLoadHLO)
hloSerialized = must.M1(os.ReadFile(*flagLoadHLO))
must.M(proto.Unmarshal(hloSerialized, &hloModule))
hloSerialized = must1(os.ReadFile(*flagLoadHLO))
must(proto.Unmarshal(hloSerialized, &hloModule))
} else {
// Serialize HLO program from hloText:
must.M(prototext.Unmarshal([]byte(hloText), &hloModule))
hloSerialized = must.M1(proto.Marshal(&hloModule))
must(prototext.Unmarshal([]byte(hloText), &hloModule))
hloSerialized = must1(proto.Marshal(&hloModule))
}
fmt.Printf("HLO Program:\n%s\n\n", hloModule.String())

// `dlopen` PJRT plugin.
plugin := must.M1(GetPlugin(*flagPluginName))
plugin := must1(GetPlugin(*flagPluginName))
defer runtime.KeepAlive(plugin)
fmt.Printf("PJRT: %s\n", plugin.String())

// Create client.
client := must.M1(plugin.NewClient(nil))
client := must1(plugin.NewClient(nil))
defer runtime.KeepAlive(client)
devices := client.AddressableDevices()
for ii, dev := range devices {
desc := must.M1(dev.GetDescription())
desc := must1(dev.GetDescription())
fmt.Printf("\tDevice #%d: %s\n", ii, desc.DebugString())
}

// Compile.
defer runtime.KeepAlive(hloSerialized)
loadedExec := must.M1(client.Compile().WithHLO(hloSerialized).Done())
loadedExec := must1(client.Compile().WithHLO(hloSerialized).Done())
defer runtime.KeepAlive(loadedExec)
fmt.Printf("\t- program compiled successfully.\n")

8 changes: 4 additions & 4 deletions pjrt/pjrt_c_api.h
Original file line number Diff line number Diff line change
@@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 57
#define PJRT_API_MINOR 58

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -214,9 +214,9 @@ struct PJRT_Plugin_Attributes_Args {
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes);

// Returns an array of plugin attributes which are key-value pairs. One example
// attribute is the minimum supported StableHLO version.
// TODO(b/280349977): standardize the list of attributes.
// Returns an array of plugin attributes which are key-value pairs. Common keys
// include `xla_version`, `stablehlo_current_version`, and
// `stablehlo_minimum_version`.
typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args);

// ---------------------------------- Events -----------------------------------
12 changes: 12 additions & 0 deletions pjrt/pjrt_test.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import (
"flag"
"fmt"
"github.com/gomlx/gopjrt/dtypes"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"k8s.io/klog/v2"
"testing"
@@ -32,6 +33,17 @@ func (e errTester[T]) Test(t *testing.T) T {
return e.value
}

func must(err error) {
if err != nil {
panicf("Failed: %+v", errors.WithStack(err))
}
}

func must1[T any](t T, err error) T {
must(err)
return t
}

// getPJRTClient loads a PJRT plugin and create a client to run tests on.
// It exits the test if anything goes wrong.
func getPJRTClient(t *testing.T) *Client {
8 changes: 3 additions & 5 deletions xlabuilder/gen_chelper.go
Original file line number Diff line number Diff line change
@@ -33,17 +33,15 @@ func cSizeOf[T any]() C.size_t {
// It must be manually freed with cFree() by the user.
func cMalloc[T any]() (ptr *T) {
size := cSizeOf[T]()
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
cPtr := (*T)(C.calloc(1, size))
return cPtr
}

// cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero.
// It must be manually freed with C.free() by the user.
func cMallocArray[T any](n int) (ptr *T) {
size := cSizeOf[T]() * C.size_t(n)
cPtr := (*T)(C.malloc(size))
C.memset(unsafe.Pointer(cPtr), 0, size)
size := cSizeOf[T]()
cPtr := (*T)(C.calloc(C.size_t(n), size))
return cPtr
}

0 comments on commit 3e4e41d

Please sign in to comment.