-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from gomlx/io
v0.5.0: Added direct access to PJRT buffers when PJRT running in CPU; Benchmarks.
Showing
29 changed files
with
1,975 additions
and
521 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package pjrt | ||
|
||
// This file defines an alignedAlloc and alignedFree, modelled after mm_malloc. | ||
|
||
/* | ||
#include <stdlib.h> | ||
*/ | ||
import "C" | ||
import ( | ||
"fmt" | ||
"unsafe" | ||
) | ||
|
||
// BufferAlignment is the default alignment, in bytes, required for memory shared with CPU PJRT.
// See AlignedAlloc and AlignedFree.
const BufferAlignment = 64
|
||
// AlignedAlloc assumes that malloc/calloc already aligns to 8 bytes. And that alignment is a multiple of 8. | ||
// The pointer returned must be freed with AlignedFree. | ||
// | ||
// The allocation is filled with 0s. | ||
func AlignedAlloc(size, alignment uintptr) unsafe.Pointer { | ||
if alignment < 8 || alignment%8 != 0 { | ||
panic(fmt.Sprintf("alignedAlloc: alignment must be a multiple of 8, got %d", alignment)) | ||
} | ||
|
||
// It uses a strategy of allocating extra to allow the alignment, and it stores the pointer to the | ||
// original allocation just before the alignedPtr. | ||
totalSize := size + alignment | ||
ptr := unsafe.Pointer(C.calloc(C.size_t(totalSize), C.size_t(1))) | ||
if ptr == nil { | ||
return nil | ||
} | ||
|
||
alignedPtr := ptr | ||
offset := uintptr(ptr) % alignment | ||
if offset != 0 { | ||
alignedPtr = unsafe.Pointer(uintptr(ptr) + (alignment - offset)) | ||
} else { | ||
alignedPtr = unsafe.Pointer(uintptr(ptr) + alignment) // This way we have the space to save the original ptr. | ||
} | ||
|
||
originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(alignedPtr) - unsafe.Sizeof(uintptr(0)))) | ||
*originalPtrPtr = uintptr(ptr) | ||
|
||
return alignedPtr | ||
} | ||
|
||
// AlignedFree frees an allocation created with AlignedAlloc. | ||
func AlignedFree(ptr unsafe.Pointer) { | ||
originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(ptr) - unsafe.Sizeof(uintptr(0)))) | ||
originalPtr := unsafe.Pointer(*originalPtrPtr) | ||
C.free(originalPtr) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package pjrt | ||
|
||
import ( | ||
"math/rand/v2" | ||
"testing" | ||
"unsafe" | ||
) | ||
|
||
func TestAlignedAlloc(t *testing.T) { | ||
rng := rand.New(rand.NewPCG(42, 42)) | ||
numLivePointers := 1_000 | ||
maxAllocSize := 1_000 | ||
pointers := make([]unsafe.Pointer, numLivePointers) | ||
for _ = range 1_000_000 { | ||
idx := rng.IntN(numLivePointers) | ||
if pointers[idx] != nil { | ||
AlignedFree(pointers[idx]) | ||
} | ||
size := uintptr(rng.IntN(maxAllocSize)) | ||
pointers[idx] = AlignedAlloc(size, BufferAlignment) | ||
} | ||
for _, ptr := range pointers { | ||
AlignedFree(ptr) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package pjrt | ||
|
||
/* | ||
#include <string.h> | ||
*/ | ||
import "C" | ||
import ( | ||
"fmt" | ||
"reflect" | ||
"sync" | ||
"unsafe" | ||
) | ||
|
||
// arenaContainer implements a trivial arena object to accelerate allocations that will be used in CGO calls.
//
// The issue it is trying to solve is that individual CGO calls are slow, including C.malloc().
//
// It pre-allocates the given size in bytes in C -- so the memory does not need to be pinned when used with CGO --
// and allows for fast sub-allocations.
// The sub-allocations can only be freed all at once.
//
// If you don't call Free at the end, it will leak the C allocated space.
//
// See newArena, arenaAlloc and arenaAllocSlice, and also arenaPool.
type arenaContainer struct {
	// buf is the slice view over the pre-allocated area from which sub-allocations are carved.
	buf []byte

	// size is the total capacity of the arena, in bytes.
	size int

	// current is the offset into buf where the next sub-allocation starts.
	current int
}
|
||
const arenaDefaultSize = 2048 | ||
|
||
var arenaPool sync.Pool = sync.Pool{ | ||
New: func() interface{} { | ||
return newArena(arenaDefaultSize) | ||
}, | ||
} | ||
|
||
// getArenaFromPool gets an arena of the default size. | ||
// Must be matched with a call returnArenaToPool when it's no longer used. | ||
func getArenaFromPool() *arenaContainer { | ||
return arenaPool.Get().(*arenaContainer) | ||
} | ||
|
||
// returnArenaToPool returns an arena acquired with getArenaFromPool. | ||
// It also resets the arena. | ||
func returnArenaToPool(a *arenaContainer) { | ||
a.Reset() | ||
arenaPool.Put(a) | ||
} | ||
|
||
// newArena creates a new Arena with the given fixed size. | ||
// | ||
// It provides fast sub-allocations, which can only be freed all at once. | ||
// | ||
// TODO: support memory-alignment. | ||
// | ||
// See arenaAlloc, arena.Free and arena.Reset. | ||
func newArena(size int) *arenaContainer { | ||
buf := cMallocArray[byte](size) | ||
a := &arenaContainer{ | ||
buf: unsafe.Slice(buf, size), | ||
size: size, | ||
} | ||
return a | ||
} | ||
|
||
const arenaAlignBytes = 8 | ||
|
||
// arenaAlloc allocates a type T from the arena. It panics if the arena run out of memory. | ||
func arenaAlloc[T any](a *arenaContainer) (ptr *T) { | ||
allocSize := cSizeOf[T]() | ||
if a.current+int(allocSize) > a.size { | ||
panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for %q", allocSize, reflect.TypeOf(ptr).Elem())) | ||
} | ||
ptr = (*T)(unsafe.Pointer(&a.buf[a.current])) | ||
a.current += int(allocSize) | ||
a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1) | ||
return | ||
} | ||
|
||
// arenaAllocSlice allocates an array of n elements of type T from the arena. | ||
// | ||
// It panics if the arena run out of memory. | ||
func arenaAllocSlice[T any](a *arenaContainer, n int) (slice []T) { | ||
allocSize := C.size_t(n) * cSizeOf[T]() | ||
if a.current+int(allocSize) > a.size { | ||
panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for [%d]%s", allocSize, n, reflect.TypeOf(slice).Elem())) | ||
} | ||
ptr := (*T)(unsafe.Pointer(&a.buf[a.current])) | ||
a.current += int(allocSize) | ||
a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1) | ||
slice = unsafe.Slice(ptr, n) | ||
return | ||
} | ||
|
||
// Free invalidates all previous allocations of the arena and frees the C allocated area. | ||
func (a *arenaContainer) Free() { | ||
cFree(&a.buf[0]) | ||
a.buf = nil | ||
a.size = 0 | ||
a.current = 0 | ||
} | ||
|
||
// Reset invalidates all previous allocations with the arena, but does not free the C allocated area. | ||
// This way the arena can be re-used. | ||
func (a *arenaContainer) Reset() { | ||
// Zero the values used. | ||
if a.buf == nil || a.size == 0 { | ||
a.current = 0 | ||
return | ||
} | ||
if a.current > 0 { | ||
clearSize := min(a.size, a.current) | ||
C.memset(unsafe.Pointer(&a.buf[0]), 0, C.size_t(clearSize)) | ||
} | ||
a.current = 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package pjrt | ||
|
||
import ( | ||
"github.com/stretchr/testify/require" | ||
"testing" | ||
) | ||
|
||
func TestArena(t *testing.T) { | ||
arena := newArena(1024) | ||
for _ = range 2 { | ||
require.Equal(t, 1024, arena.size) | ||
require.Equal(t, 0, arena.current) | ||
_ = arenaAlloc[int](arena) | ||
require.Equal(t, 8, arena.current) | ||
_ = arenaAlloc[int32](arena) | ||
require.Equal(t, 16, arena.current) | ||
|
||
_ = arenaAllocSlice[byte](arena, 9) // Aligning, it will occupy 16 bytes total. | ||
require.Equal(t, 32, arena.current) | ||
|
||
require.Panics(t, func() { _ = arenaAlloc[[512]int](arena) }, "Arena out of memory") | ||
require.Panics(t, func() { _ = arenaAllocSlice[float64](arena, 512) }, "Arena out of memory") | ||
arena.Reset() | ||
} | ||
arena.Free() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,394 @@ | ||
package pjrt | ||
|
||
// To run benchmarks: (and fix to P-Cores if running on a Intel i9-12900K) | ||
// go test -c . && taskset 0xFF ./pjrt.test -test.v -test.run=Bench -test.count=1 | ||
// | ||
// See results in https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=1369069161#gid=1369069161 | ||
import ( | ||
"flag" | ||
"fmt" | ||
"github.com/gomlx/gopjrt/dtypes" | ||
"github.com/gomlx/gopjrt/xlabuilder" | ||
benchmarks "github.com/janpfeifer/go-benchmarks" | ||
"runtime" | ||
"testing" | ||
"time" | ||
"unsafe" | ||
) | ||
|
||
var ( | ||
flagBenchDuration = flag.Duration("bench_duration", 1*time.Second, "Benchmark duration") | ||
|
||
// testShapes used during benchmarks executing small computation graphs. | ||
testShapes = []xlabuilder.Shape{ | ||
xlabuilder.MakeShape(dtypes.Float32, 1, 1), | ||
xlabuilder.MakeShape(dtypes.Float32, 10, 10), | ||
xlabuilder.MakeShape(dtypes.Float32, 100, 100), | ||
xlabuilder.MakeShape(dtypes.Float32, 1000, 1000), | ||
} | ||
) | ||
|
||
// TestBenchCGO benchmarks a minimal CGO call. | ||
func TestBenchCGO(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
const repeats = 1000 | ||
repeatedCGO := func() { | ||
for _ = range repeats { | ||
dummyCGO(unsafe.Pointer(plugin.api)) | ||
} | ||
} | ||
benchmarks.New(benchmarks.NamedFunction{"CGOCall", repeatedCGO}). | ||
WithInnerRepeats(repeats). | ||
Done() | ||
} | ||
|
||
// Benchmark tests different methods to create temporary pointers to be passed to CGO. | ||
func TestBenchArena(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
numAllocationsList := []int{1, 5, 10, 100} | ||
allocations := make([]*int, 100) | ||
testFns := make([]benchmarks.NamedFunction, 4*len(numAllocationsList)) | ||
const repeats = 10 | ||
idxFn := 0 | ||
for _, allocType := range []string{"arena", "arenaPool", "malloc", "go+pinner"} { | ||
for _, numAllocations := range numAllocationsList { | ||
testFns[idxFn].Name = fmt.Sprintf("%s/%s/%d", t.Name(), allocType, numAllocations) | ||
var fn func() | ||
switch allocType { | ||
case "arena": | ||
fn = func() { | ||
for _ = range repeats { | ||
arena := newArena(1024) | ||
for idx := range numAllocations { | ||
allocations[idx] = arenaAlloc[int](arena) | ||
} | ||
dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) | ||
arena.Free() | ||
} | ||
} | ||
case "arenaPool": | ||
fn = func() { | ||
for _ = range repeats { | ||
arena := getArenaFromPool() | ||
for idx := range numAllocations { | ||
allocations[idx] = arenaAlloc[int](arena) | ||
} | ||
dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) | ||
returnArenaToPool(arena) | ||
} | ||
} | ||
case "malloc": | ||
fn = func() { | ||
for _ = range repeats { | ||
for idx := range numAllocations { | ||
allocations[idx] = cMalloc[int]() | ||
} | ||
dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) | ||
for idx := range numAllocations { | ||
cFree(allocations[idx]) | ||
} | ||
} | ||
} | ||
case "go+pinner": | ||
fn = func() { | ||
for _ = range repeats { | ||
var pinner runtime.Pinner | ||
for idx := range numAllocations { | ||
v := idx | ||
allocations[idx] = &v | ||
pinner.Pin(allocations[idx]) | ||
} | ||
dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) | ||
pinner.Unpin() | ||
} | ||
} | ||
} | ||
testFns[idxFn].Func = fn | ||
idxFn++ | ||
} | ||
} | ||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
WithWarmUps(10). | ||
Done() | ||
} | ||
|
||
// TestBenchBufferFromHost benchmarks host->buffer transfer time. | ||
func TestBenchBufferFromHost(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
const repeats = 10 | ||
numShapes := len(testShapes) | ||
inputData := make([][]float32, numShapes) | ||
testFns := make([]benchmarks.NamedFunction, numShapes) | ||
for shapeIdx, s := range testShapes { | ||
inputData[shapeIdx] = make([]float32, s.Size()) | ||
for i := 0; i < s.Size(); i++ { | ||
inputData[shapeIdx][i] = float32(i) | ||
} | ||
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) | ||
testFns[shapeIdx].Func = func() { | ||
for _ = range repeats { | ||
x := inputData[shapeIdx] | ||
s := testShapes[shapeIdx] | ||
buf := must1(ArrayToBuffer(client, x, s.Dimensions...)) | ||
must(buf.Destroy()) | ||
} | ||
} | ||
} | ||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
WithDuration(*flagBenchDuration). | ||
Done() | ||
} | ||
|
||
// TestBenchBufferToHost benchmarks time to transfer data from device buffer to host. | ||
// | ||
// Results on CPU: | ||
// | ||
// Benchmarks: Median 5%-tile 99%-tile | ||
// TestBenchBufferToHost/shape=(Float32)[1 1] 1.684µs 1.555µs 3.762µs | ||
// TestBenchBufferToHost/shape=(Float32)[10 10] 1.651µs 1.534µs 3.699µs | ||
// TestBenchBufferToHost/shape=(Float32)[100 100] 5.393µs 5.002µs 7.271µs | ||
// TestBenchBufferToHost/shape=(Float32)[1000 1000] 131.826µs 131.498µs 139.316µs | ||
func TestBenchBufferToHost(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
const repeats = 10 | ||
numShapes := len(testShapes) | ||
testFns := make([]benchmarks.NamedFunction, numShapes) | ||
// Prepare output data (host destination array) and upload buffers to GPU | ||
outputData := make([][]float32, numShapes) | ||
buffers := make([]*Buffer, numShapes) | ||
for shapeIdx, s := range testShapes { | ||
outputData[shapeIdx] = make([]float32, s.Size()) | ||
for i := 0; i < s.Size(); i++ { | ||
outputData[shapeIdx][i] = float32(i) | ||
} | ||
buffers[shapeIdx] = must1(ArrayToBuffer(client, outputData[shapeIdx], s.Dimensions...)) | ||
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) | ||
testFns[shapeIdx].Func = func() { | ||
for _ = range repeats { | ||
buf := buffers[shapeIdx] | ||
rawData := unsafe.Slice((*byte)(unsafe.Pointer(&outputData[shapeIdx][0])), len(outputData[shapeIdx])*int(unsafe.Sizeof(outputData[shapeIdx][0]))) | ||
must(buf.ToHost(rawData)) | ||
} | ||
} | ||
} | ||
defer func() { | ||
for _, buf := range buffers { | ||
must(buf.Destroy()) | ||
} | ||
}() | ||
|
||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
Done() | ||
} | ||
|
||
// BenchmarkAdd1Execution benchmarks the execution time for a minimal program. | ||
func TestBenchAdd1Execution(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
// Prepare input data, the uploaded buffers and the executables. | ||
const repeats = 10 | ||
numShapes := len(testShapes) | ||
execs := make([]*LoadedExecutable, numShapes) | ||
inputData := make([][]float32, numShapes) | ||
buffers := make([]*Buffer, numShapes) | ||
testFns := make([]benchmarks.NamedFunction, numShapes) | ||
for shapeIdx, s := range testShapes { | ||
inputData[shapeIdx] = make([]float32, s.Size()) | ||
for i := 0; i < s.Size(); i++ { | ||
inputData[shapeIdx][i] = float32(i) | ||
} | ||
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) | ||
|
||
builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) | ||
// f(x) = x + 1 | ||
x := must1(xlabuilder.Parameter(builder, "x", 0, s)) | ||
one := must1(xlabuilder.ScalarOne(builder, s.DType)) | ||
add1 := must1(xlabuilder.Add(x, one)) | ||
comp := must1(builder.Build(add1)) | ||
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) | ||
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) | ||
testFns[shapeIdx].Func = func() { | ||
for _ = range repeats { | ||
buf := buffers[shapeIdx] | ||
exec := execs[shapeIdx] | ||
output := must1(exec.Execute(buf).Done())[0] | ||
must(output.Destroy()) | ||
} | ||
} | ||
} | ||
defer func() { | ||
// Clean up -- and don't wait for the GC. | ||
for shapeIdx := range numShapes { | ||
must(buffers[shapeIdx].Destroy()) | ||
must(execs[shapeIdx].Destroy()) | ||
} | ||
}() | ||
|
||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
WithWarmUps(100). | ||
WithDuration(*flagBenchDuration). | ||
Done() | ||
} | ||
|
||
// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2. | ||
// | ||
// Runtimes for cpu: | ||
// | ||
// Benchmarks: Median 5%-tile 99%-tile | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs | ||
func TestBenchAdd1Div2Execution(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
// Prepare input data, the uploaded buffers and the executables. | ||
const repeats = 10 | ||
numShapes := len(testShapes) | ||
execs := make([]*LoadedExecutable, numShapes) | ||
inputData := make([][]float32, numShapes) | ||
buffers := make([]*Buffer, numShapes) | ||
testFns := make([]benchmarks.NamedFunction, numShapes) | ||
for shapeIdx, s := range testShapes { | ||
inputData[shapeIdx] = make([]float32, s.Size()) | ||
for i := 0; i < s.Size(); i++ { | ||
inputData[shapeIdx][i] = float32(i) | ||
} | ||
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) | ||
|
||
builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) | ||
// f(x) = x + 1 | ||
x := must1(xlabuilder.Parameter(builder, "x", 0, s)) | ||
one := must1(xlabuilder.ScalarOne(builder, s.DType)) | ||
add1 := must1(xlabuilder.Add(x, one)) | ||
half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5)))) | ||
div2 := must1(xlabuilder.Mul(add1, half)) | ||
comp := must1(builder.Build(div2)) | ||
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) | ||
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) | ||
testFns[shapeIdx].Func = func() { | ||
for _ = range repeats { | ||
buf := buffers[shapeIdx] | ||
exec := execs[shapeIdx] | ||
output := must1(exec.Execute(buf).Done())[0] | ||
_ = output.Destroy() | ||
} | ||
} | ||
} | ||
defer func() { | ||
// Clean up -- and don't wait for the GC. | ||
for shapeIdx := range numShapes { | ||
must(buffers[shapeIdx].Destroy()) | ||
must(execs[shapeIdx].Destroy()) | ||
} | ||
}() | ||
|
||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
WithWarmUps(100). | ||
WithDuration(*flagBenchDuration). | ||
Done() | ||
} | ||
|
||
// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2. | ||
// | ||
// Runtimes for cpu: | ||
// | ||
// Benchmarks: Median 5%-tile 99%-tile | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs | ||
// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs | ||
func TestBenchMeanNormalizedExecution(t *testing.T) { | ||
if testing.Short() { | ||
t.SkipNow() | ||
} | ||
plugin := must1(GetPlugin(*flagPluginName)) | ||
client := must1(plugin.NewClient(nil)) | ||
defer runtime.KeepAlive(client) | ||
|
||
// Prepare input data, the uploaded buffers and the executables. | ||
const repeats = 10 | ||
numShapes := len(testShapes) | ||
execs := make([]*LoadedExecutable, numShapes) | ||
inputData := make([][]float32, numShapes) | ||
buffers := make([]*Buffer, numShapes) | ||
testFns := make([]benchmarks.NamedFunction, numShapes) | ||
for shapeIdx, s := range testShapes { | ||
inputData[shapeIdx] = make([]float32, s.Size()) | ||
for i := 0; i < s.Size(); i++ { | ||
inputData[shapeIdx][i] = float32(i) | ||
} | ||
buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) | ||
|
||
builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) | ||
// f(x) = x + 1 | ||
x := must1(xlabuilder.Parameter(builder, "x", 0, s)) | ||
one := must1(xlabuilder.ScalarOne(builder, s.DType)) | ||
add1 := must1(xlabuilder.Add(x, one)) | ||
half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5)))) | ||
div2 := must1(xlabuilder.Mul(add1, half)) | ||
mean := must1(xlabuilder.ReduceSum(div2)) | ||
normalized := must1(xlabuilder.Sub(div2, mean)) | ||
|
||
comp := must1(builder.Build(normalized)) | ||
execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) | ||
testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) | ||
testFns[shapeIdx].Func = func() { | ||
for _ = range repeats { | ||
buf := buffers[shapeIdx] | ||
exec := execs[shapeIdx] | ||
output := must1(exec.Execute(buf).Done())[0] | ||
_ = output.Destroy() | ||
} | ||
} | ||
} | ||
defer func() { | ||
// Clean up -- and don't wait for the GC. | ||
for shapeIdx := range numShapes { | ||
must(buffers[shapeIdx].Destroy()) | ||
must(execs[shapeIdx].Destroy()) | ||
} | ||
}() | ||
|
||
benchmarks.New(testFns...). | ||
WithInnerRepeats(repeats). | ||
WithWarmUps(100). | ||
WithDuration(*flagBenchDuration). | ||
Done() | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
package pjrt | ||
|
||
/* | ||
#include "pjrt_c_api.h" | ||
#include "gen_api_calls.h" | ||
#include "gen_new_struct.h" | ||
// BufferFromHostAndWait calls PJRT_Client_BufferFromHostBuffer and then blocks until the
// done_with_host_buffer event it returns is ready, destroying the event afterwards.
//
// It returns the first error encountered, or NULL if everything succeeded.
//
// Only line comments are used here: this code lives inside a Go block comment (the cgo preamble), which a
// C block comment terminator would close.
PJRT_Error* BufferFromHostAndWait(const PJRT_Api *api, PJRT_Client_BufferFromHostBuffer_Args *args) {
	PJRT_Error* err = api->PJRT_Client_BufferFromHostBuffer(args);
	if (err) {
		return err;
	}
	PJRT_Event_Await_Args event_args = {0};
	event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
	event_args.event = args->done_with_host_buffer;
	err = api->PJRT_Event_Await(&event_args);

	// The event is destroyed whether or not the await succeeded. The arguments are zero-initialized, so no
	// field is passed uninitialized, and use the struct size of the Destroy (not the Await) arguments.
	PJRT_Event_Destroy_Args efree_args = {0};
	efree_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
	efree_args.event = args->done_with_host_buffer;
	PJRT_Error* destroy_err = api->PJRT_Event_Destroy(&efree_args);
	if (destroy_err) {
		if (!err) {
			return destroy_err;
		}
		// The await error takes precedence: release the secondary error so it is not leaked.
		PJRT_Error_Destroy_Args derr_args = {0};
		derr_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
		derr_args.error = destroy_err;
		api->PJRT_Error_Destroy(&derr_args);
	}
	return err;
}
PJRT_Error *dummy_error; | ||
PJRT_Error *Dummy(void *api) { | ||
if (api == NULL) { | ||
return NULL; | ||
} | ||
return dummy_error; | ||
} | ||
*/ | ||
import "C" | ||
import ( | ||
"github.com/gomlx/gopjrt/dtypes" | ||
"github.com/pkg/errors" | ||
"reflect" | ||
"runtime" | ||
"slices" | ||
"unsafe" | ||
) | ||
|
||
// BufferFromHostConfig is used to configure the transfer from a buffer from host memory to on-device memory, it is | ||
// created with Client.BufferFromHost. | ||
// | ||
// The data to transfer from host can be set up with one of the following methods: | ||
// | ||
// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions). | ||
// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions). | ||
// | ||
// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum. | ||
// | ||
// At the end call BufferFromHostConfig.Done to actually initiate the transfer. | ||
// | ||
// TODO: Implement async transfers, arbitrary memory layout, etc. | ||
type BufferFromHostConfig struct { | ||
client *Client | ||
data []byte | ||
dtype dtypes.DType | ||
dimensions []int | ||
device *Device | ||
|
||
hostBufferSemantics PJRT_HostBufferSemantics | ||
|
||
// err stores the first error that happened during configuration. | ||
// If it is not nil, it is immediately returned by the Done call. | ||
err error | ||
} | ||
|
||
// FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant) | ||
// during the call. The parameters dtype and dimensions provide the shape of the array. | ||
func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig { | ||
if b.err != nil { | ||
return b | ||
} | ||
b.data = data | ||
b.dtype = dtype | ||
b.dimensions = dimensions | ||
return b | ||
} | ||
|
||
// ToDevice configures which device to copy the host data to. | ||
// | ||
// If left un-configured, it will pick the first device returned by Client.AddressableDevices. | ||
// | ||
// You can also provide a device by their index in Client.AddressableDevices. | ||
func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig { | ||
if b.err != nil { | ||
return b | ||
} | ||
if device == nil { | ||
b.err = errors.New("BufferFromHost().ToDevice() given a nil device") | ||
return b | ||
} | ||
addressable, err := device.IsAddressable() | ||
if err != nil { | ||
b.err = errors.WithMessagef(err, "BufferFromHost().ToDevice() failed to check whether device is addressable") | ||
return b | ||
} | ||
if !addressable { | ||
b.err = errors.New("BufferFromHost().ToDevice() given a non addressable device") | ||
return b | ||
} | ||
b.device = device | ||
return b | ||
} | ||
|
||
// ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the | ||
// list returned by Client.AddressableDevices. | ||
// | ||
// If left un-configured, it will pick the first device returned by Client.AddressableDevices. | ||
// | ||
// You can also provide a device by their index in Client.AddressableDevices. | ||
func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig { | ||
if b.err != nil { | ||
return b | ||
} | ||
if deviceNum < 0 || deviceNum >= len(b.client.addressableDevices) { | ||
b.err = errors.Errorf("BufferFromHost().ToDeviceNum() invalid deviceNum=%d, only %d addressable devices available", deviceNum, len(b.client.addressableDevices)) | ||
return b | ||
} | ||
return b.ToDevice(b.client.addressableDevices[deviceNum]) | ||
} | ||
|
||
// FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying | ||
// dimensions. | ||
// The flat slice size must match the product of the dimension. | ||
// If no dimensions are given, it is assumed to be a scalar, and flat should have length 1. | ||
func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig { | ||
if b.err != nil { | ||
return b | ||
} | ||
// Checks dimensions. | ||
expectedSize := 1 | ||
for _, dim := range dimensions { | ||
if dim <= 0 { | ||
b.err = errors.Errorf("FromFlatDataWithDimensions cannot be given zero or negative dimensions, got %v", dimensions) | ||
return b | ||
} | ||
expectedSize *= dim | ||
} | ||
|
||
// Check the flat slice has the right shape. | ||
flatV := reflect.ValueOf(flat) | ||
if flatV.Kind() != reflect.Slice { | ||
b.err = errors.Errorf("FromFlatDataWithDimensions was given a %s for flat, but it requires a slice", flatV.Kind()) | ||
return b | ||
} | ||
if flatV.Len() != expectedSize { | ||
b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions=%v) needs %d values to match dimensions, but got len(flat)=%d", dimensions, expectedSize, flatV.Len()) | ||
return b | ||
} | ||
|
||
// Check validity of the slice elements type. | ||
element0 := flatV.Index(0) | ||
element0Type := element0.Type() | ||
dtype := dtypes.FromGoType(element0Type) | ||
if dtype == dtypes.InvalidDType { | ||
b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions%v) got flat=[]%s, expected a slice of a Go tyep that can be converted to a valid DType", dimensions, element0Type) | ||
return b | ||
} | ||
|
||
// Create slice of bytes and use b.FromRawData. | ||
sizeBytes := uintptr(flatV.Len()) * element0Type.Size() | ||
data := unsafe.Slice((*byte)(element0.Addr().UnsafePointer()), sizeBytes) | ||
return b.FromRawData(data, dtype, dimensions) | ||
} | ||
|
||
// Done will use the configuration to start the transfer from host to device. | ||
// It's synchronous: it awaits the transfer to finish and then returns. | ||
func (b *BufferFromHostConfig) Done() (*Buffer, error) { | ||
if b.err != nil { | ||
// Return first error saved during configuration. | ||
return nil, b.err | ||
} | ||
if len(b.data) == 0 { | ||
return nil, errors.New("BufferFromHost requires one to configure the host data to transfer, none was configured.") | ||
} | ||
defer runtime.KeepAlive(b) | ||
|
||
// Makes sure program data is not moved around by the GC during the C/C++ call. | ||
var pinner runtime.Pinner | ||
defer pinner.Unpin() | ||
dataPtr := unsafe.SliceData(b.data) | ||
pinner.Pin(dataPtr) | ||
|
||
// Set default device. | ||
if b.device == nil { | ||
devices := b.client.AddressableDevices() | ||
if len(devices) == 0 { | ||
return nil, errors.New("BufferFromHost can't find addressable device to transfer to") | ||
} | ||
b.device = devices[0] | ||
} | ||
|
||
// Arena for memory allocations used by CGO. | ||
arena := getArenaFromPool() | ||
defer returnArenaToPool(arena) | ||
|
||
// Arguments to PJRT call. | ||
var args *C.PJRT_Client_BufferFromHostBuffer_Args | ||
args = arenaAlloc[C.PJRT_Client_BufferFromHostBuffer_Args](arena) | ||
args.struct_size = C.PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE | ||
args.client = b.client.client | ||
args.data = unsafe.Pointer(dataPtr) | ||
args._type = C.PJRT_Buffer_Type(b.dtype) | ||
args.num_dims = C.size_t(len(b.dimensions)) | ||
if len(b.dimensions) > 0 { | ||
dims := arenaAllocSlice[C.int64_t](arena, len(b.dimensions)) | ||
for ii, dim := range b.dimensions { | ||
dims[ii] = C.int64_t(dim) | ||
} | ||
args.dims = unsafe.SliceData(dims) | ||
} | ||
args.host_buffer_semantics = C.PJRT_HostBufferSemantics(b.hostBufferSemantics) | ||
args.device = b.device.cDevice | ||
err := toError(b.client.plugin, C.BufferFromHostAndWait(b.client.plugin.api, args)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
buffer := newBuffer(b.client, args.buffer) | ||
buffer.dims = slices.Clone(b.dimensions) | ||
buffer.dimsSet = true | ||
buffer.dtype = b.dtype | ||
buffer.dtypeSet = true | ||
return buffer, nil | ||
} | ||
|
||
// dummyCGO calls a minimal C function and doesn't do anything. | ||
// Here for the purpose of benchmarking CGO calls. | ||
func dummyCGO(pointer unsafe.Pointer) { | ||
_ = C.Dummy(pointer) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
package pjrt | ||
|
||
/* | ||
#include "pjrt_c_api.h" | ||
#include "gen_api_calls.h" | ||
#include "gen_new_struct.h" | ||
// OnDeleteSharedBuffer is the on-delete callback used for shared buffers.
// It is intentionally a no-op: both arguments are ignored.
extern void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg);
void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg) {
	(void)device_buffer_ptr;
	(void)user_arg;
}
// OnDeleteSharedBufferPtr holds the address of OnDeleteSharedBuffer, so it can be read as a plain value.
extern void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg);
void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg) = OnDeleteSharedBuffer;
*/ | ||
import "C" | ||
import ( | ||
"github.com/gomlx/gopjrt/dtypes" | ||
"github.com/pkg/errors" | ||
"reflect" | ||
"slices" | ||
"unsafe" | ||
) | ||
|
||
// CreateViewOfDeviceBuffer creates a PJRT Buffer that is backed by storage on the same device given by the caller as flatData and shape. | ||
// Consider using the simpler API NewSharedBuffer. | ||
// | ||
// Different PJRT may have different requirements on alignment, but for the CPU PJRT the library provide | ||
// AlignedAlloc and AlignedFree, that can be used to allocate the aligned storage space. | ||
// | ||
// Example of how a typical usage where the same buffer is reused as input in loop: | ||
// | ||
// dtype := dtypes.Float32 | ||
// dimensions := []int{batchSize, sequenceLength, 384} | ||
// rawData := pjrt.AlignedAlloc(dtype.SizeForDimensions(dimensions...), pjrt.BufferAlignment) | ||
// defer pjrt.AlignedFree(rawData) | ||
// buf := client.CreateViewOfDeviceBuffer(rawData, dtype, dimensions) | ||
// flat := unsafe.Slice((*float32)(storage), batchSize*sequenceLength*384) | ||
// for _, batch := range batches { | ||
// // ... set flat values | ||
// // ... use buf as input when executing a PJRT program | ||
// } | ||
// | ||
// If device is not given (at most one can be given), the first device available for the client is used. | ||
// | ||
// The naming comes from PJRT and is unfortunate, since it's the name from PJRT's perspective (PJRT view of a | ||
// users device buffer). | ||
// Probably, it should have been better named by "ShareDeviceBuffer" or something similar. | ||
// | ||
// This may not be implemented for all hardware (or all PJRT plugins). | ||
// | ||
// This can be useful to avoid the copy of values, by mutating directly in the memory shared with PJRT, to be used | ||
// as input to a computation. | ||
// | ||
// See: dtypes.SizeForDimensions() to calculate the size for an arbitrary shape; AlignedAlloc, AlignedFree and | ||
// BufferAlignment (a constant with the required alignment size) to allocate and free aligned storage. | ||
func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, device ...*Device) (*Buffer, error) { | ||
var selectedDevice *Device | ||
if len(device) > 1 { | ||
return nil, errors.Errorf("only one device can be given to CreateViewOfDeviceBuffer, %d were given", len(device)) | ||
} else if len(device) == 1 { | ||
selectedDevice = device[0] | ||
} else { | ||
devices := c.AddressableDevices() | ||
if len(devices) == 0 { | ||
return nil, errors.New("CreateViewOfDeviceBuffer can't find addressable device to transfer to") | ||
} | ||
selectedDevice = devices[0] | ||
} | ||
|
||
// Arena for memory allocations used by CGO. | ||
arena := getArenaFromPool() | ||
defer returnArenaToPool(arena) | ||
|
||
// Arguments to PJRT call. | ||
var args *C.PJRT_Client_CreateViewOfDeviceBuffer_Args | ||
args = arenaAlloc[C.PJRT_Client_CreateViewOfDeviceBuffer_Args](arena) | ||
args.struct_size = C.PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE | ||
args.client = c.client | ||
args.device_buffer_ptr = rawData | ||
args.element_type = C.PJRT_Buffer_Type(dtype) | ||
args.num_dims = C.size_t(len(dimensions)) | ||
if args.num_dims > 0 { | ||
dims := arenaAllocSlice[C.int64_t](arena, int(args.num_dims)) | ||
for ii, dim := range dimensions { | ||
dims[ii] = C.int64_t(dim) | ||
} | ||
args.dims = unsafe.SliceData(dims) | ||
} | ||
args.device = selectedDevice.cDevice | ||
args.on_delete_callback = (*[0]byte)(unsafe.Pointer(C.OnDeleteSharedBufferPtr)) | ||
args.on_delete_callback_arg = nil | ||
err := toError(c.plugin, C.call_PJRT_Client_CreateViewOfDeviceBuffer(c.plugin.api, args)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
buffer := newBuffer(c, args.buffer) | ||
buffer.isShared = true | ||
buffer.dims = slices.Clone(dimensions) | ||
buffer.dimsSet = true | ||
buffer.dtype = dtype | ||
buffer.dtypeSet = true | ||
return buffer, nil | ||
} | ||
|
||
// NewSharedBuffer returns a buffer that can be used for execution and share the underlying | ||
// memory space with the host/local, which can be read and mutated directly. | ||
// | ||
// Shared buffers cannot be donated to executions. | ||
// | ||
// The buffer should not be mutated while it is used by an execution. | ||
// | ||
// When the buffer is finalized, the shared memory is also de-allocated. | ||
// | ||
// It returns a handle to the buffer and a slice of the corresponding data type pointing | ||
// to the shared data. | ||
func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error) { | ||
memorySize := uintptr(dtype.SizeForDimensions(dimensions...)) | ||
rawStorage := AlignedAlloc(memorySize, BufferAlignment) | ||
buffer, err = c.CreateViewOfDeviceBuffer(rawStorage, dtype, dimensions, device...) | ||
if err != nil { | ||
AlignedFree(rawStorage) | ||
err = errors.WithMessagef(err, "NewSharedBuffer failed creating new buffer") | ||
buffer = nil | ||
return | ||
} | ||
buffer.sharedRawStorage = rawStorage | ||
buffer.isShared = true | ||
|
||
goDType := dtype.GoType() | ||
flat = reflect.SliceAt(dtype.GoType(), rawStorage, int(memorySize/goDType.Size())).Interface() | ||
return | ||
} | ||
|
||
// IsShared returns whether this buffer was created with Client.NewSharedBuffer. | ||
// These buffers cannot be donated in execution. | ||
func (b *Buffer) IsShared() bool { | ||
return b.isShared | ||
} | ||
|
||
// UnsafePointer returns platform-dependent address for the given buffer that is often but | ||
// not guaranteed to be the physical/device address. | ||
// Consider using the more convenient DirectAccess. | ||
// | ||
// Probably, this should only be used by CPU plugins. | ||
// | ||
// To be on the safe side, only use this if Client.HasSharedBuffers is true. | ||
// It uses the undocumented PJRT_Buffer_UnsafePointer. | ||
func (b *Buffer) UnsafePointer() (unsafe.Pointer, error) { | ||
plugin := b.client.plugin | ||
|
||
// Arena for memory allocations used by CGO. | ||
arena := getArenaFromPool() | ||
defer returnArenaToPool(arena) | ||
|
||
// Arguments to PJRT call. | ||
var args *C.PJRT_Buffer_UnsafePointer_Args | ||
args = arenaAlloc[C.PJRT_Buffer_UnsafePointer_Args](arena) | ||
args.struct_size = C.PJRT_Buffer_UnsafePointer_Args_STRUCT_SIZE | ||
args.buffer = b.cBuffer | ||
err := toError(plugin, C.call_PJRT_Buffer_UnsafePointer(plugin.api, args)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return unsafe.Pointer(uintptr(args.buffer_pointer)), nil | ||
} | ||
|
||
// Data returns the flat slice pointing to the underlying storage data for the buffer. | ||
// | ||
// This is an undocumented feature of PJRT and likely only works for CPU platforms. | ||
// The flat slice returned is only valid while the buffer is alive. | ||
func (b *Buffer) Data() (flat any, err error) { | ||
var rawStorage unsafe.Pointer | ||
rawStorage, err = b.UnsafePointer() | ||
if err != nil { | ||
return nil, err | ||
} | ||
dims := b.dims | ||
dtype := b.dtype | ||
if !b.dimsSet { | ||
dims, err = b.Dimensions() | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
if !b.dtypeSet { | ||
dtype, err = b.DType() | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
numElements := 1 | ||
for _, dim := range dims { | ||
numElements *= dim | ||
} | ||
return reflect.SliceAt(dtype.GoType(), rawStorage, numElements).Interface(), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package pjrt | ||
|
||
import ( | ||
"github.com/pkg/errors" | ||
"runtime" | ||
"unsafe" | ||
) | ||
|
||
/* | ||
#include "pjrt_c_api.h" | ||
#include "gen_api_calls.h" | ||
#include "gen_new_struct.h" | ||
// BufferToHost copies the contents of buffer into dst (of dst_size bytes), and waits for the copy to finish.
// rank is the number of axes of the buffer: it is used to request a host layout where the last axis is the
// most minor one (minor_to_major = [rank-1, ..., 1, 0]).
// It returns NULL on success, or the error from either starting the transfer or awaiting its event.
PJRT_Error* BufferToHost(const PJRT_Api *api, PJRT_Buffer *buffer, void *dst, int64_t dst_size, int rank) {
	PJRT_Buffer_ToHostBuffer_Args args = {0};
	args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
	args.src = buffer;
	args.dst = dst;
	args.dst_size = dst_size;

	// Requested layout of the data in host memory.
	PJRT_Buffer_MemoryLayout layout_args = {0};
	layout_args.struct_size = PJRT_Buffer_MemoryLayout_STRUCT_SIZE;
	args.host_layout = &layout_args;
	layout_args.type = PJRT_Buffer_MemoryLayout_Type_Tiled;
	layout_args.tiled.minor_to_major_size = rank;
	// At least one element, since a zero-length array is not valid; it is left unused for scalars.
	int64_t minor_to_major[rank > 0 ? rank : 1];
	if (rank > 0) {
		for (int axisIdx = 0; axisIdx < rank; axisIdx++) {
			minor_to_major[axisIdx] = rank - axisIdx - 1;
		}
		layout_args.tiled.minor_to_major = &minor_to_major[0];
	}

	PJRT_Error* err = api->PJRT_Buffer_ToHostBuffer(&args);
	if (err) {
		return err;
	}

	// Wait for the transfer to finish.
	PJRT_Event_Await_Args event_args = {0};
	event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
	event_args.event = args.event;
	err = api->PJRT_Event_Await(&event_args);

	// Always release the event, whether awaiting succeeded or not.
	// The args are zero-initialized and carry the struct size of the destroy call itself.
	PJRT_Event_Destroy_Args efree_args = {0};
	efree_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
	efree_args.event = args.event;
	api->PJRT_Event_Destroy(&efree_args);
	return err;
}
*/ | ||
import "C" | ||
|
||
// ToHost transfers the contents of buffer stored on device to the host. | ||
// The space in dst has to hold enough space (see Buffer.Size) to hold the required data, or an error is returned. | ||
// | ||
// This always request a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to | ||
// reorganize the layout. | ||
func (b *Buffer) ToHost(dst []byte) error { | ||
defer runtime.KeepAlive(b) | ||
if b == nil || b.client.plugin == nil || b.cBuffer == nil { | ||
// Already destroyed ? | ||
return errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?") | ||
} | ||
|
||
// We'll need the buffer rank to set up the layout. | ||
dims, err := b.Dimensions() | ||
if err != nil { | ||
return err | ||
} | ||
rank := len(dims) | ||
|
||
dstBytes := unsafe.Pointer(unsafe.SliceData(dst)) | ||
var pinner runtime.Pinner | ||
pinner.Pin(dstBytes) | ||
defer pinner.Unpin() | ||
|
||
pErr := C.BufferToHost(b.client.plugin.api, b.cBuffer, dstBytes, C.int64_t(len(dst)), C.int(rank)) | ||
err = toError(b.client.plugin, pErr) | ||
if err != nil { | ||
return errors.WithMessage(err, "Failed to call PJRT_Buffer_ToHostBuffer to transfer the buffer to host") | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters