diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index f191ec8..8ccd98b 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -42,7 +42,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.22.x" + go-version: "1.23.x" - name: Install Gopjrt C library gomlx_xlabuilder and PJRT plugin shell: bash diff --git a/README.md b/README.md index 8e74e50..c02903d 100644 --- a/README.md +++ b/README.md @@ -329,7 +329,7 @@ Also, see [this blog post](https://opensource.googleblog.com/2024/03/pjrt-plugin Because of https://github.com/golang/go/issues/13467 : C API's cannot be exported across packages, even within the same repo. Even a function as simple as `func Add(a, b C.int) C.int` in one package cannot be called from another. So we need to wrap everything, and more than that, one cannot create separate sub-packages to handle separate concerns. - THis is also the reason the library `chelper.go` is copied in both `pjrt` and `xlabuilder` packages. + This is also the reason the library `chelper.go` is copied in both `pjrt` and `xlabuilder` packages. * **Why does PJRT spits out so much logging ? Can we disable it ?** This is a great question ... imagine if every library we use decided they also want to clutter our stderr? I have [an open question in Abseil about it](https://github.com/abseil/abseil-cpp/discussions/1700). @@ -340,6 +340,18 @@ Also, see [this blog post](https://opensource.googleblog.com/2024/03/pjrt-plugin before calling `pjrt.GetPlugin`. But it may have unintended consequences, if some other library is depending on the fd 2 to work, or if a real exceptional situation needs to be reported and is not. +## Environment Variables + +That help control or debug how **gopjrt** work: + +* `PJRT_PLUGIN_LIBRARY_PATH`: Path to search for PJRT plugins. **gopjrt** also searches in `/usr/local/lib/gomlx/pjrt`, + the standard library paths for the system and `$LD_LIBRARY_PATH`. +* `XLA_DEBUG_OPTIONS`: If set, it is parsed as a `DebugOptions` proto that + is passed during the JIT-compilation (`Client.Compile()`) of a computation graph. + It is not documented how it works in PJRT (e.g. I observed a great slow down when this is set, + even if set to the default values), but [the proto has some documentation](https://github.com/gomlx/gopjrt/blob/main/protos/xla.proto#L40). +* `GOPJRT_INSTALL_DIR` and `GOPJRT_NOSUDO`: used by the install scripts, see "Installing" section above. + ## Links to documentation * [Google Drive Directory with Design Docs](https://drive.google.com/drive/folders/18M944-QQPk1E34qRyIjkqDRDnpMa3miN): Some links are outdated or redirected, but very valuable information. diff --git a/c/WORKSPACE b/c/WORKSPACE index 65d6335..d7d85bc 100644 --- a/c/WORKSPACE +++ b/c/WORKSPACE @@ -21,11 +21,12 @@ http_archive( # Notice bazel.sh scrape the line below for the OpenXLA version, the format # of the line should remain the same (the hash in between quotes), or bazel.sh # must be changed accordingly. -OPENXLA_XLA_COMMIT_HASH = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" # From 2024-11-24 +# OPENXLA_XLA_COMMIT_HASH = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" # From 2024-11-24 +OPENXLA_XLA_COMMIT_HASH = "e2e8952ad0fac8833e9a78f9b3689e803ff8524f" # From 2024-12-11 http_archive( name = "xla", - sha256 = "a910124d546bc79edb685612edaa3d56153f0e0927f967e8defaf312b833d404", # From 2024-11-24 + sha256 = "5ec6919a25952fa790904983481ccb51ebbe20bbc53e15ddbb6d3e0b3aa3dfe1", # From 2024-12-11 strip_prefix = "xla-" + OPENXLA_XLA_COMMIT_HASH, urls = [ "https://github.com/openxla/xla/archive/{hash}.zip".format(hash = OPENXLA_XLA_COMMIT_HASH), diff --git a/chelper.go b/chelper.go index ab1ab26..3a63b0e 100644 --- a/chelper.go +++ b/chelper.go @@ -31,17 +31,15 @@ func cSizeOf[T any]() C.size_t { // It must be manually freed with cFree() by the user. func cMalloc[T any]() (ptr *T) { size := cSizeOf[T]() - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + cPtr := (*T)(C.calloc(1, size)) return cPtr } // cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero. // It must be manually freed with C.free() by the user. func cMallocArray[T any](n int) (ptr *T) { - size := cSizeOf[T]() * C.size_t(n) - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + size := cSizeOf[T]() + cPtr := (*T)(C.calloc(C.size_t(n), size)) return cPtr } diff --git a/cmd/run_coverage.sh b/cmd/run_coverage.sh index edf8c23..6d9b5b4 100755 --- a/cmd/run_coverage.sh +++ b/cmd/run_coverage.sh @@ -2,6 +2,6 @@ # Run this from the root of gopjrt repository to generate docs/coverage.out with the coverage data. -PACKAGE_COVERAGE="./pjrt ./xlabuilder" -go test -v -cover -coverprofile docs/coverage.out -coverpkg ${PACKAGE_COVERAGE} +PACKAGE_COVERAGE="github.com/gomlx/gopjrt/pjrt,github.com/gomlx/gopjrt/xlabuilder" +go test -cover -coverprofile docs/coverage.out -coverpkg="${PACKAGE_COVERAGE}" ./... -test.count=1 -test.short go tool cover -func docs/coverage.out -o docs/coverage.out \ No newline at end of file diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index a46b221..79598df 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,4 +1,4 @@ -# Next +# v0.5.0 - 2024/12/19 - Adding direct access to PJRT buffers for CPU. * Added `install_linux_amd64_amazonlinux.sh` and pre-built libraries for amazonlinux (built using old glibc support). * Fixed installation scripts: s/sudo/$_SUDO. Also made them more verbose. @@ -6,6 +6,26 @@ * Improved documentation on Nvidia GPU card detection, and error message if not found. * Updated GitHub action (`go.yaml`) to only change the README.md with the result of the change, if pushing to the `main` branch. +* Added `prjt.arena` to avoid costly allocations for CGO calls, and merged some of CGO calls for general speed-ups. + The following functions had > 50% improvements on their fixed-cost (measured on transfers with 1 value, and minimal programs) + execution time (**not the variable part**): + * `Buffer.ToHost()` + * `Client.BufferFromHost()` + * `LoadedExecutable.Execute()` +* Added `BufferToHost` and `BufferFromHost` benchmarks. +* Added support for environment variable `XLA_DEBUG_OPTIONS`: if set, it is parsed as a `DebugOptions` proto that + is passed to the JIT-compilation of a computation graph. +* `LoadedExecutable.Execute()` now waits for the end of the execution (by setting + `PJRT_LoadedExecutable_Execute_Args.device_complete_events`). + Previous behavior lead to odd behavior and was undefined (not documented). +* Package `dtypes`: + * Added tests; + * Added `SizeForDimensions()` to be used for dtypes that uses fractions of bytes (like 4 bits). +* Added `Client.NewSharedBuffer` (and the lower level `client.CreateViewOfDeviceBuffer()`) to create buffers with shared + memory with the host, for faster input. + * Added `AlignedAlloc` and `AlignedFree` required by `client.CreateViewOfDeviceBuffer`. +* Added `Buffer.Data` for direct access to a buffer's data. Undocumented in PJRT, and likely only works on CPU. +* Fixed coverage script. # v0.4.9 - 2024-11-25 diff --git a/docs/coverage.out b/docs/coverage.out index 2ede3de..cd7b984 100644 --- a/docs/coverage.out +++ b/docs/coverage.out @@ -1,125 +1,348 @@ -github.com/gomlx/gopjrt/pjrt/buffers.go:31: newBuffer 50.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:56: Destroy 90.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:72: Dimensions 78.6% -github.com/gomlx/gopjrt/pjrt/buffers.go:94: DType 76.9% -github.com/gomlx/gopjrt/pjrt/buffers.go:114: Device 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:133: Client 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:166: FromRawData 83.3% -github.com/gomlx/gopjrt/pjrt/buffers.go:181: ToDevice 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:208: ToDeviceNum 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:221: Done 82.2% -github.com/gomlx/gopjrt/pjrt/buffers.go:294: FromFlatDataWithDimensions 62.5% -github.com/gomlx/gopjrt/pjrt/buffers.go:335: ScalarToRaw 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:342: Size 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:365: ToHost 89.5% -github.com/gomlx/gopjrt/pjrt/buffers.go:428: BufferToScalar 100.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:443: ScalarToBuffer 100.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:458: ScalarToBufferOnDeviceNum 0.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:474: ArrayToBuffer 100.0% -github.com/gomlx/gopjrt/pjrt/buffers.go:479: BufferToArray 76.9% -github.com/gomlx/gopjrt/pjrt/buffers.go:522: ToFlatDataAndDimensions 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:17: pjrtClientPlatformName 85.7% -github.com/gomlx/gopjrt/pjrt/clients.go:28: pjrtClientPlatformVersion 85.7% -github.com/gomlx/gopjrt/pjrt/clients.go:39: pjrtClientProcessIndex 85.7% -github.com/gomlx/gopjrt/pjrt/clients.go:50: pjrtClientDevices 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:66: pjrtClientAddressableDevices 90.9% -github.com/gomlx/gopjrt/pjrt/clients.go:84: pjrtClientCompile 95.0% -github.com/gomlx/gopjrt/pjrt/clients.go:121: newClient 69.2% -github.com/gomlx/gopjrt/pjrt/clients.go:166: finalizeClient 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:174: Plugin 100.0% -github.com/gomlx/gopjrt/pjrt/clients.go:180: Destroy 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:196: String 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:212: Platform 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:217: PlatformVersion 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:223: ProcessIndex 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:229: Devices 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:238: AddressableDevices 100.0% -github.com/gomlx/gopjrt/pjrt/clients.go:246: NumForDevice 0.0% -github.com/gomlx/gopjrt/pjrt/clients.go:266: Compile 100.0% -github.com/gomlx/gopjrt/pjrt/clients.go:275: BufferFromHost 100.0% -github.com/gomlx/gopjrt/pjrt/common.go:6: keys 0.0% -github.com/gomlx/gopjrt/pjrt/compile.go:47: newCompileConfig 100.0% -github.com/gomlx/gopjrt/pjrt/compile.go:73: Done 75.9% -github.com/gomlx/gopjrt/pjrt/compile.go:133: WithHLO 62.5% -github.com/gomlx/gopjrt/pjrt/compile.go:154: WithStableHLO 0.0% -github.com/gomlx/gopjrt/pjrt/compile.go:193: WithComputation 40.0% -github.com/gomlx/gopjrt/pjrt/cuda.go:14: isCuda 100.0% -github.com/gomlx/gopjrt/pjrt/cuda.go:22: hasNvidiaGPU 0.0% -github.com/gomlx/gopjrt/pjrt/cuda.go:43: cudaPluginCheckDrivers 30.8% -github.com/gomlx/gopjrt/pjrt/devices.go:14: pjrtDeviceLocalHardwareId 85.7% -github.com/gomlx/gopjrt/pjrt/devices.go:25: pjrtDeviceDescriptionProcessIndex 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:62: newDevice 83.3% -github.com/gomlx/gopjrt/pjrt/devices.go:73: IsAddressable 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:86: LocalHardwareId 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:91: GetDescription 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:116: newDeviceDescription 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:131: ProcessIndex 0.0% -github.com/gomlx/gopjrt/pjrt/devices.go:140: DebugString 0.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:74: init 60.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:89: loadNamedPlugin 74.2% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:149: pathToPluginName 80.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:168: AvailablePlugins 0.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:172: searchPlugin 100.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:177: searchPlugins 84.2% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:213: checkPlugin 60.9% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:262: SuppressAbseilLoggingHack 0.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib.go:279: suppressLogging 0.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:63: osDefaultLibraryPaths 83.3% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:77: loadLibraryPaths 78.6% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:118: loadPlugin 51.7% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:168: GetPJRTApiFn 100.0% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:173: GetSymbolPointer 87.5% -github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:188: Close 83.3% -github.com/gomlx/gopjrt/pjrt/error.go:17: pjrtErrorDestroy 0.0% -github.com/gomlx/gopjrt/pjrt/error.go:25: pjrtErrorMessage 0.0% -github.com/gomlx/gopjrt/pjrt/error.go:37: pjrtErrorGetCode 0.0% -github.com/gomlx/gopjrt/pjrt/error.go:49: toError 33.3% -github.com/gomlx/gopjrt/pjrt/events.go:25: newEvent 50.0% -github.com/gomlx/gopjrt/pjrt/events.go:41: Destroy 90.0% -github.com/gomlx/gopjrt/pjrt/events.go:57: Await 85.7% -github.com/gomlx/gopjrt/pjrt/events.go:71: AwaitAndFree 0.0% -github.com/gomlx/gopjrt/pjrt/executables.go:25: newExecutable 75.0% -github.com/gomlx/gopjrt/pjrt/executables.go:36: Destroy 0.0% -github.com/gomlx/gopjrt/pjrt/executables.go:52: destroyOrLog 0.0% -github.com/gomlx/gopjrt/pjrt/executables.go:60: NumOutputs 80.0% -github.com/gomlx/gopjrt/pjrt/executables.go:76: Name 80.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:21: cFree 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:27: cSizeOf 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:34: cMalloc 0.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:43: cMallocArray 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:52: cMallocArrayFromSlice 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:62: cMallocArrayAndSet 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:73: cDataToSlice 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:79: cCharArray 100.0% -github.com/gomlx/gopjrt/pjrt/gen_chelper.go:86: cStrFree 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:36: newLoadedExecutable 58.8% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:66: Destroy 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:85: destroyOrLog 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:93: getExecutable 80.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:121: Execute 100.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:151: OnDevices 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:186: OnDevicesByNum 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:210: DonateAll 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:219: DonateNone 100.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:233: Donate 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:244: SetDonate 0.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:261: Done 88.6% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:318: allocatePerDeviceBufferList 88.9% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:338: freePerDeviceBufferList 100.0% -github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:347: gatherPerDeviceBufferList 100.0% -github.com/gomlx/gopjrt/pjrt/namedvalues.go:20: pjrtNamedValuesToMap 66.7% -github.com/gomlx/gopjrt/pjrt/namedvalues.go:50: mallocArrayPJRT_NamedValue 6.7% -github.com/gomlx/gopjrt/pjrt/namedvalues.go:101: destroyPJRT_NamedValue 0.0% -github.com/gomlx/gopjrt/pjrt/pjrt.go:17: panicf 0.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:45: pjrtPluginInitialize 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:53: pjrtPluginAttributes 87.5% -github.com/gomlx/gopjrt/pjrt/plugins.go:67: newPlugin 75.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:93: RegisterPreloadedPlugin 0.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:111: GetPlugin 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:116: Name 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:121: Path 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:126: Version 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:131: Attributes 100.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:136: String 75.0% -github.com/gomlx/gopjrt/pjrt/plugins.go:146: NewClient 100.0% -total: (statements) 53.0% +github.com/gomlx/gopjrt/pjrt/alignedalloc.go:22: AlignedAlloc 85.7% +github.com/gomlx/gopjrt/pjrt/alignedalloc.go:50: AlignedFree 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:40: getArenaFromPool 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:46: returnArenaToPool 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:58: newArena 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:70: arenaAlloc 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:84: arenaAllocSlice 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:97: Free 100.0% +github.com/gomlx/gopjrt/pjrt/arena.go:106: Reset 71.4% +github.com/gomlx/gopjrt/pjrt/buffers.go:41: BuffersAlive 0.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:46: newBuffer 85.7% +github.com/gomlx/gopjrt/pjrt/buffers.go:73: Destroy 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:100: Dimensions 84.2% +github.com/gomlx/gopjrt/pjrt/buffers.go:128: DType 84.2% +github.com/gomlx/gopjrt/pjrt/buffers.go:155: Device 78.6% +github.com/gomlx/gopjrt/pjrt/buffers.go:176: Client 0.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:181: ScalarToRaw 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:188: Size 0.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:212: BufferToScalar 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:222: ScalarToBuffer 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:232: ScalarToBufferOnDeviceNum 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:243: ArrayToBuffer 100.0% +github.com/gomlx/gopjrt/pjrt/buffers.go:248: BufferToArray 86.4% +github.com/gomlx/gopjrt/pjrt/buffers.go:285: ToFlatDataAndDimensions 85.7% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:74: FromRawData 83.3% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:89: ToDevice 50.0% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:116: ToDeviceNum 50.0% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:131: FromFlatDataWithDimensions 62.5% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:173: Done 89.7% +github.com/gomlx/gopjrt/pjrt/buffers_from_host.go:234: dummyCGO 0.0% +github.com/gomlx/gopjrt/pjrt/buffers_shared.go:60: CreateViewOfDeviceBuffer 88.9% +github.com/gomlx/gopjrt/pjrt/buffers_shared.go:120: NewSharedBuffer 69.2% +github.com/gomlx/gopjrt/pjrt/buffers_shared.go:140: IsShared 100.0% +github.com/gomlx/gopjrt/pjrt/buffers_shared.go:152: UnsafePointer 90.9% +github.com/gomlx/gopjrt/pjrt/buffers_shared.go:175: Data 83.3% +github.com/gomlx/gopjrt/pjrt/buffers_to_host.go:58: ToHost 81.2% +github.com/gomlx/gopjrt/pjrt/clients.go:17: pjrtClientPlatformName 85.7% +github.com/gomlx/gopjrt/pjrt/clients.go:28: pjrtClientPlatformVersion 85.7% +github.com/gomlx/gopjrt/pjrt/clients.go:39: pjrtClientProcessIndex 85.7% +github.com/gomlx/gopjrt/pjrt/clients.go:50: pjrtClientDevices 90.9% +github.com/gomlx/gopjrt/pjrt/clients.go:66: pjrtClientAddressableDevices 90.9% +github.com/gomlx/gopjrt/pjrt/clients.go:84: pjrtClientCompile 95.0% +github.com/gomlx/gopjrt/pjrt/clients.go:122: newClient 69.2% +github.com/gomlx/gopjrt/pjrt/clients.go:167: finalizeClient 0.0% +github.com/gomlx/gopjrt/pjrt/clients.go:175: Plugin 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:181: Destroy 90.0% +github.com/gomlx/gopjrt/pjrt/clients.go:197: String 75.0% +github.com/gomlx/gopjrt/pjrt/clients.go:213: Platform 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:218: PlatformVersion 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:224: ProcessIndex 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:230: Devices 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:239: AddressableDevices 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:247: NumForDevice 75.0% +github.com/gomlx/gopjrt/pjrt/clients.go:267: Compile 100.0% +github.com/gomlx/gopjrt/pjrt/clients.go:276: BufferFromHost 100.0% +github.com/gomlx/gopjrt/pjrt/common.go:6: keys 0.0% +github.com/gomlx/gopjrt/pjrt/compile.go:53: newCompileConfig 31.2% +github.com/gomlx/gopjrt/pjrt/compile.go:100: Done 75.9% +github.com/gomlx/gopjrt/pjrt/compile.go:160: WithHLO 62.5% +github.com/gomlx/gopjrt/pjrt/compile.go:181: WithStableHLO 0.0% +github.com/gomlx/gopjrt/pjrt/compile.go:220: WithComputation 40.0% +github.com/gomlx/gopjrt/pjrt/cuda.go:16: isCuda 100.0% +github.com/gomlx/gopjrt/pjrt/cuda.go:26: hasNvidiaGPU 33.3% +github.com/gomlx/gopjrt/pjrt/cuda.go:68: cudaPluginCheckDrivers 30.8% +github.com/gomlx/gopjrt/pjrt/devices.go:14: pjrtDeviceLocalHardwareId 85.7% +github.com/gomlx/gopjrt/pjrt/devices.go:25: pjrtDeviceDescriptionProcessIndex 85.7% +github.com/gomlx/gopjrt/pjrt/devices.go:62: newDevice 83.3% +github.com/gomlx/gopjrt/pjrt/devices.go:73: IsAddressable 85.7% +github.com/gomlx/gopjrt/pjrt/devices.go:86: LocalHardwareId 100.0% +github.com/gomlx/gopjrt/pjrt/devices.go:91: GetDescription 85.7% +github.com/gomlx/gopjrt/pjrt/devices.go:116: newDeviceDescription 83.3% +github.com/gomlx/gopjrt/pjrt/devices.go:131: ProcessIndex 0.0% +github.com/gomlx/gopjrt/pjrt/devices.go:140: DebugString 85.7% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:74: init 60.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:89: loadNamedPlugin 87.1% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:149: pathToPluginName 80.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:168: AvailablePlugins 100.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:172: searchPlugin 100.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:177: searchPlugins 84.2% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:213: checkPlugin 60.9% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:262: SuppressAbseilLoggingHack 0.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib.go:279: suppressLogging 0.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:63: osDefaultLibraryPaths 83.3% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:77: loadLibraryPaths 78.6% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:118: loadPlugin 51.7% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:168: GetPJRTApiFn 100.0% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:173: GetSymbolPointer 87.5% +github.com/gomlx/gopjrt/pjrt/dynamiclib_dlopen.go:188: Close 83.3% +github.com/gomlx/gopjrt/pjrt/error.go:17: pjrtErrorDestroy 100.0% +github.com/gomlx/gopjrt/pjrt/error.go:25: pjrtErrorMessage 100.0% +github.com/gomlx/gopjrt/pjrt/error.go:37: pjrtErrorGetCode 100.0% +github.com/gomlx/gopjrt/pjrt/error.go:49: toError 100.0% +github.com/gomlx/gopjrt/pjrt/events.go:25: newEvent 0.0% +github.com/gomlx/gopjrt/pjrt/events.go:41: Destroy 0.0% +github.com/gomlx/gopjrt/pjrt/events.go:57: Await 0.0% +github.com/gomlx/gopjrt/pjrt/events.go:71: AwaitAndFree 0.0% +github.com/gomlx/gopjrt/pjrt/executables.go:25: newExecutable 100.0% +github.com/gomlx/gopjrt/pjrt/executables.go:36: Destroy 100.0% +github.com/gomlx/gopjrt/pjrt/executables.go:52: destroyOrLog 66.7% +github.com/gomlx/gopjrt/pjrt/executables.go:60: NumOutputs 80.0% +github.com/gomlx/gopjrt/pjrt/executables.go:76: Name 80.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:21: cFree 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:27: cSizeOf 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:34: cMalloc 0.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:42: cMallocArray 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:50: cMallocArrayFromSlice 0.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:60: cMallocArrayAndSet 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:71: cDataToSlice 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:77: cCharArray 100.0% +github.com/gomlx/gopjrt/pjrt/gen_chelper.go:84: cStrFree 0.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:66: LoadedExecutablesAlive 0.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:71: newLoadedExecutable 66.7% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:102: Destroy 100.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:123: destroyOrLog 66.7% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:131: getExecutable 80.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:159: Execute 100.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:189: OnDevices 52.6% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:224: OnDevicesByNum 61.5% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:248: DonateAll 0.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:257: DonateNone 100.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:271: Donate 100.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:282: SetDonate 0.0% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:299: Done 90.4% +github.com/gomlx/gopjrt/pjrt/loadedexecutables.go:391: allocatePerDeviceBufferListWithArena 76.9% +github.com/gomlx/gopjrt/pjrt/namedvalues.go:20: pjrtNamedValuesToMap 91.7% +github.com/gomlx/gopjrt/pjrt/namedvalues.go:50: mallocArrayPJRT_NamedValue 100.0% +github.com/gomlx/gopjrt/pjrt/namedvalues.go:101: destroyPJRT_NamedValue 100.0% +github.com/gomlx/gopjrt/pjrt/pjrt.go:17: panicf 0.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:45: pjrtPluginInitialize 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:53: pjrtPluginAttributes 87.5% +github.com/gomlx/gopjrt/pjrt/plugins.go:67: newPlugin 75.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:93: RegisterPreloadedPlugin 0.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:111: GetPlugin 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:116: Name 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:121: Path 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:126: Version 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:131: Attributes 100.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:136: String 75.0% +github.com/gomlx/gopjrt/pjrt/plugins.go:146: NewClient 100.0% +github.com/gomlx/gopjrt/xlabuilder/convolve.go:30: ConvGeneralDilated 88.5% +github.com/gomlx/gopjrt/xlabuilder/convolve.go:78: DecodeConvGeneralDilated 89.5% +github.com/gomlx/gopjrt/xlabuilder/errorcode_string.go:7: _ 0.0% +github.com/gomlx/gopjrt/xlabuilder/errorcode_string.go:35: String 0.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:21: cFree 100.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:27: cSizeOf 100.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:34: cMalloc 100.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:42: cMallocArray 100.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:50: cMallocArrayFromSlice 0.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:60: cMallocArrayAndSet 100.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:71: cDataToSlice 0.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:77: cCharArray 0.0% +github.com/gomlx/gopjrt/xlabuilder/gen_chelper.go:84: cStrFree 80.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:11: Abs 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:23: Neg 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:35: Exp 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:47: Expm1 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:59: Floor 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:71: Ceil 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:83: Round 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:95: Log 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:107: Log1p 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:119: LogicalNot 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:131: Logistic 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:143: Sign 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:155: Clz 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:167: Cos 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:179: Sin 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:191: Tanh 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:203: Sqrt 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:215: Rsqrt 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:227: Imag 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:239: Real 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:251: Conj 0.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:264: Add 80.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:283: Mul 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:302: Sub 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:321: Div 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:340: Rem 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:358: And 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:376: Or 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:394: Xor 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:425: Dot 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:443: Min 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:461: Max 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:479: Pow 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:502: Complex 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:520: Equal 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:538: NotEqual 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:556: GreaterOrEqual 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:574: GreaterThan 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:592: LessOrEqual 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:610: LessThan 80.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:632: EqualTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:654: NotEqualTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:676: GreaterOrEqualTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:698: GreaterThanTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:720: LessOrEqualTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:742: LessThanTotalOrder 70.0% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:760: Erf 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:774: IsFinite 83.3% +github.com/gomlx/gopjrt/xlabuilder/gen_simple_ops.go:786: PopulationCount 83.3% +github.com/gomlx/gopjrt/xlabuilder/literal.go:27: NewLiteralFromShape 66.7% +github.com/gomlx/gopjrt/xlabuilder/literal.go:47: NewArrayLiteral 72.7% +github.com/gomlx/gopjrt/xlabuilder/literal.go:70: Data 0.0% +github.com/gomlx/gopjrt/xlabuilder/literal.go:76: NewScalarLiteral 83.3% +github.com/gomlx/gopjrt/xlabuilder/literal.go:91: NewScalarLiteralFromFloat64 68.8% +github.com/gomlx/gopjrt/xlabuilder/literal.go:122: NewScalarLiteralFromAny 75.0% +github.com/gomlx/gopjrt/xlabuilder/literal.go:142: NewArrayLiteralFromAny 81.8% +github.com/gomlx/gopjrt/xlabuilder/literal.go:177: newLiteral 83.3% +github.com/gomlx/gopjrt/xlabuilder/literal.go:189: Destroy 100.0% +github.com/gomlx/gopjrt/xlabuilder/literal.go:199: IsNil 100.0% +github.com/gomlx/gopjrt/xlabuilder/literal.go:204: Shape 0.0% +github.com/gomlx/gopjrt/xlabuilder/op.go:60: newOp 100.0% +github.com/gomlx/gopjrt/xlabuilder/op.go:70: Builder 0.0% +github.com/gomlx/gopjrt/xlabuilder/op.go:75: opFinalizer 80.0% +github.com/gomlx/gopjrt/xlabuilder/op.go:87: serializeToC 100.0% +github.com/gomlx/gopjrt/xlabuilder/op.go:125: destroyCSerializedOp 100.0% +github.com/gomlx/gopjrt/xlabuilder/optype_string.go:7: _ 0.0% +github.com/gomlx/gopjrt/xlabuilder/optype_string.go:100: String 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:43: GetReduceComputationAndInitialValue 65.6% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:148: Reduce 80.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:164: simpleReduceImpl 72.7% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:185: ReduceMax 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:192: ReduceMin 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:199: ReduceSum 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:206: ReduceProduct 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:214: ReduceAnd 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:222: ReduceOr 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:253: ReduceWindow 45.5% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:276: standardReduction 62.5% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:294: Max 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:302: Min 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:310: Sum 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:318: Product 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:330: UseComputation 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:349: WithStrides 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:365: WithBaseDilations 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:381: WithWindowDilations 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:397: WithPadding 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:410: sliceWithValue 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:419: Done 69.2% +github.com/gomlx/gopjrt/xlabuilder/reduce.go:482: DecodeReduceWindow 100.0% +github.com/gomlx/gopjrt/xlabuilder/reduceoptype_string.go:7: _ 0.0% +github.com/gomlx/gopjrt/xlabuilder/reduceoptype_string.go:24: String 66.7% +github.com/gomlx/gopjrt/xlabuilder/shape.go:37: MakeShape 80.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:48: MakeShapeOrError 80.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:59: IsScalar 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:63: Rank 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:68: Size 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:78: Memory 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:83: Clone 87.5% +github.com/gomlx/gopjrt/xlabuilder/shape.go:98: TupleSize 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:103: String 50.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:121: cShapeFromShape 100.0% +github.com/gomlx/gopjrt/xlabuilder/shape.go:144: shapeFromCShape 93.3% +github.com/gomlx/gopjrt/xlabuilder/shape.go:168: Equal 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:22: Parameter 87.5% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:38: DecodeParameter 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:45: Tuple 83.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:56: GetTupleElement 85.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:68: DecodeGetTupleElement 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:73: SplitTuple 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:92: Iota 72.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:110: DecodeIota 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:118: Identity 85.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:130: Constant 70.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:148: ScalarZero 75.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:168: ScalarOne 75.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:187: ConvertDType 75.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:201: DecodeConvertDType 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:204: Where 70.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:224: Reshape 84.6% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:246: DecodeReshape 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:259: Broadcast 83.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:270: DecodeBroadcast 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:289: BroadcastInDim 66.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:329: DecodeBroadcastInDim 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:336: Transpose 73.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:363: DecodeTranspose 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:368: Call 75.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:387: Concatenate 73.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:413: DecodeConcatenate 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:426: Slice 88.2% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:454: DecodeSlice 85.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:466: boolToInt 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:484: ArgMinMax 81.8% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:501: DecodeArgMinMax 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:518: Pad 82.4% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:544: DecodePad 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:587: Gather 87.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:630: DecodeGather 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:662: ScatterCustom 81.8% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:708: ScatterAdd 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:717: ScatterMax 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:726: ScatterMin 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:734: scatterImpl 77.8% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:753: DecodeScatter 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:783: SelectAndScatterCustom 73.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:841: SelectAndScatterMax 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:850: SelectAndScatterMin 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:859: SelectAndScatterSum 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:867: selectAndScatterImpl 76.9% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:891: GetSelectAndScatterComputation 59.7% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:983: DecodeSelectAndScatter 78.9% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1026: DotGeneral 76.9% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1072: DecodeDotGeneral 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1092: Reverse 77.8% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1107: DecodeReverse 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1118: BatchNormForInference 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1131: DecodeBatchNormForInference 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1149: BatchNormForTraining 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1173: DecodeBatchNormForTraining 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1192: BatchNormGradient 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1216: DecodeBatchNormGrad 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1231: FFT 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1244: DecodeFFT 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1269: RngBitGenerator 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1297: DecodeRngBitGenerator 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1314: While 87.5% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1327: DecodeWhile 100.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1335: scalarStartIndices 75.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1376: DynamicSlice 83.3% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1394: DecodeDynamicSlice 0.0% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1407: DynamicUpdateSlice 76.9% +github.com/gomlx/gopjrt/xlabuilder/special_ops.go:1427: DecodeDynamicUpdateSlice 0.0% +github.com/gomlx/gopjrt/xlabuilder/status.go:31: unsafePointerOrError 66.7% +github.com/gomlx/gopjrt/xlabuilder/status.go:40: pointerOrError 83.3% +github.com/gomlx/gopjrt/xlabuilder/status.go:53: errorFromStatus 22.2% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:24: panicf 0.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:61: New 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:69: newXlaBuilder 75.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:81: IsNil 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:87: Destroy 87.5% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:102: Name 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:108: addOp 84.0% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:149: Build 77.8% +github.com/gomlx/gopjrt/xlabuilder/xlabuilder.go:171: CreateSubBuilder 88.9% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:34: newXlaComputation 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:45: Destroy 75.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:55: IsNil 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:60: Name 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:70: SerializedHLO 75.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:84: HasStableHLO 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:89: HasStableHLO 100.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:99: SerializedStableHLO 0.0% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:121: TextHLO 66.7% +github.com/gomlx/gopjrt/xlabuilder/xlacomputation.go:133: TextStableHLO 0.0% +total: (statements) 70.0% diff --git a/dtypes/dtypes.go b/dtypes/dtypes.go index 22ee8ba..0ca3710 100644 --- a/dtypes/dtypes.go +++ b/dtypes/dtypes.go @@ -152,10 +152,30 @@ func FromAny(value any) DType { } // Size returns the number of bytes for the given DType. +// If the size is < 1 (like a 4-bits quantity), consider the SizeForDimensions method. func (dtype DType) Size() int { return int(dtype.GoType().Size()) } +// SizeForDimensions returns the size in bytes used for the given dimensions. +// This is a safer method than Size in case the dtype uses an underlying size that is not multiple of 8 bits. +// +// It works also for scalar (one element), where dimensions list is empty. +func (dtype DType) SizeForDimensions(dimensions ...int) int { + numElements := 1 + for _, dim := range dimensions { + if dim <= 0 { + panicf("cannot use dim <= 0 for SizeForDimensions, got %v", dimensions) + } + numElements *= dim + } + + // Switch case for dtypes with size not multiple of 8 bits (1 byte). + + // Default is simply the number of elements times the size in bytes per element. + return numElements * dtype.Size() +} + // Memory returns the number of bytes for the given DType. // It's an alias to Size, converted to uintptr. func (dtype DType) Memory() uintptr { diff --git a/dtypes/dtypes_test.go b/dtypes/dtypes_test.go index d7f033e..670f555 100644 --- a/dtypes/dtypes_test.go +++ b/dtypes/dtypes_test.go @@ -33,3 +33,22 @@ func TestMapOfNames(t *testing.T) { require.Equal(t, BFloat16, MapOfNames["BF16"]) require.Equal(t, BFloat16, MapOfNames["bf16"]) } + +func TestFromAny(t *testing.T) { + require.Equal(t, Int64, FromAny(int64(7))) + require.Equal(t, Float32, FromAny(float32(13))) + require.Equal(t, BFloat16, FromAny(bfloat16.FromFloat32(1.0))) + require.Equal(t, Float16, FromAny(float16.Fromfloat32(3.0))) +} + +func TestSize(t *testing.T) { + require.Equal(t, 8, Int64.Size()) + require.Equal(t, 4, Float32.Size()) + require.Equal(t, 2, BFloat16.Size()) +} + +func TestSizeForDimensions(t *testing.T) { + require.Equal(t, 2*3*8, Int64.SizeForDimensions(2, 3)) + require.Equal(t, 4, Float32.SizeForDimensions()) + require.Equal(t, 2, BFloat16.SizeForDimensions(1, 1, 1)) +} diff --git a/go.mod b/go.mod index 2a246b6..2cfbcab 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,11 @@ go 1.23 require ( github.com/chewxy/math32 v1.11.1 + github.com/janpfeifer/go-benchmarks v0.1.1 github.com/janpfeifer/gonb v0.10.6 github.com/janpfeifer/must v0.2.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/x448/float16 v0.8.4 google.golang.org/protobuf v1.35.2 k8s.io/klog/v2 v2.130.1 @@ -18,6 +19,7 @@ require ( github.com/go-logr/logr v1.4.2 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d // indirect golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f10e171..f09ea1b 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1 github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= +github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM= github.com/janpfeifer/gonb v0.10.6 h1:DyzaE8lLtJRrzu84N/XuUb0OC71EsIf9ZZhZc4mWxwA= github.com/janpfeifer/gonb v0.10.6/go.mod h1:ZX+93yQa2s3td0JOML7x+Jo4mmuwi+XryQ0iQvrSpBQ= github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ= @@ -16,8 +18,10 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= +github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= diff --git a/pjrt/alignedalloc.go b/pjrt/alignedalloc.go new file mode 100644 index 0000000..7d6c2d1 --- /dev/null +++ b/pjrt/alignedalloc.go @@ -0,0 +1,54 @@ +package pjrt + +// This file defines an alignedAlloc and alignedFree, modelled after mm_malloc. + +/* +#include +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +// BufferAlignment is the default alignment required for memory shared with CPU PJRT. +// See AlignedAlloc and FreeAlloc. +const BufferAlignment = 64 + +// AlignedAlloc assumes that malloc/calloc already aligns to 8 bytes. And that alignment is a multiple of 8. +// The pointer returned must be freed with AlignedFree. +// +// The allocation is filled with 0s. +func AlignedAlloc(size, alignment uintptr) unsafe.Pointer { + if alignment < 8 || alignment%8 != 0 { + panic(fmt.Sprintf("alignedAlloc: alignment must be a multiple of 8, got %d", alignment)) + } + + // It uses a strategy of allocating extra to allow the alignment, and it stores the pointer to the + // original allocation just before the alignedPtr. + totalSize := size + alignment + ptr := unsafe.Pointer(C.calloc(C.size_t(totalSize), C.size_t(1))) + if ptr == nil { + return nil + } + + alignedPtr := ptr + offset := uintptr(ptr) % alignment + if offset != 0 { + alignedPtr = unsafe.Pointer(uintptr(ptr) + (alignment - offset)) + } else { + alignedPtr = unsafe.Pointer(uintptr(ptr) + alignment) // This way we have the space to save the original ptr. + } + + originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(alignedPtr) - unsafe.Sizeof(uintptr(0)))) + *originalPtrPtr = uintptr(ptr) + + return alignedPtr +} + +// AlignedFree frees an allocation created with AlignedAlloc. +func AlignedFree(ptr unsafe.Pointer) { + originalPtrPtr := (*uintptr)(unsafe.Pointer(uintptr(ptr) - unsafe.Sizeof(uintptr(0)))) + originalPtr := unsafe.Pointer(*originalPtrPtr) + C.free(originalPtr) +} diff --git a/pjrt/alignedalloc_test.go b/pjrt/alignedalloc_test.go new file mode 100644 index 0000000..0d3a047 --- /dev/null +++ b/pjrt/alignedalloc_test.go @@ -0,0 +1,25 @@ +package pjrt + +import ( + "math/rand/v2" + "testing" + "unsafe" +) + +func TestAlignedAlloc(t *testing.T) { + rng := rand.New(rand.NewPCG(42, 42)) + numLivePointers := 1_000 + maxAllocSize := 1_000 + pointers := make([]unsafe.Pointer, numLivePointers) + for _ = range 1_000_000 { + idx := rng.IntN(numLivePointers) + if pointers[idx] != nil { + AlignedFree(pointers[idx]) + } + size := uintptr(rng.IntN(maxAllocSize)) + pointers[idx] = AlignedAlloc(size, BufferAlignment) + } + for _, ptr := range pointers { + AlignedFree(ptr) + } +} diff --git a/pjrt/arena.go b/pjrt/arena.go new file mode 100644 index 0000000..8959c75 --- /dev/null +++ b/pjrt/arena.go @@ -0,0 +1,117 @@ +package pjrt + +/* +#include +*/ +import "C" +import ( + "fmt" + "reflect" + "sync" + "unsafe" +) + +// arenaContainer implements a trivial arena object to accelerate allocations that will be used in CGO calls. +// +// The issue it is trying to solve is that individual CGO calls are slow, including C.malloc(). +// +// It pre-allocates the given size in bytes in C -- so it does not needs to be pinned when using CGO, and allow +// for fast sub-allocations. +// It can only be freed all at once. +// +// If you don't call Free at the end, it will leak the C allocated space. +// +// See newArena and arenaAlloc, and also arenaPool. +type arenaContainer struct { + buf []byte + size, current int +} + +const arenaDefaultSize = 2048 + +var arenaPool sync.Pool = sync.Pool{ + New: func() interface{} { + return newArena(arenaDefaultSize) + }, +} + +// getArenaFromPool gets an arena of the default size. +// Must be matched with a call returnArenaToPool when it's no longer used. +func getArenaFromPool() *arenaContainer { + return arenaPool.Get().(*arenaContainer) +} + +// returnArenaToPool returns an arena acquired with getArenaFromPool. +// It also resets the arena. +func returnArenaToPool(a *arenaContainer) { + a.Reset() + arenaPool.Put(a) +} + +// newArena creates a new Arena with the given fixed size. +// +// It provides fast sub-allocations, which can only be freed all at once. +// +// TODO: support memory-alignment. +// +// See arenaAlloc, arena.Free and arena.Reset. +func newArena(size int) *arenaContainer { + buf := cMallocArray[byte](size) + a := &arenaContainer{ + buf: unsafe.Slice(buf, size), + size: size, + } + return a +} + +const arenaAlignBytes = 8 + +// arenaAlloc allocates a type T from the arena. It panics if the arena run out of memory. +func arenaAlloc[T any](a *arenaContainer) (ptr *T) { + allocSize := cSizeOf[T]() + if a.current+int(allocSize) > a.size { + panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for %q", allocSize, reflect.TypeOf(ptr).Elem())) + } + ptr = (*T)(unsafe.Pointer(&a.buf[a.current])) + a.current += int(allocSize) + a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1) + return +} + +// arenaAllocSlice allocates an array of n elements of type T from the arena. +// +// It panics if the arena run out of memory. +func arenaAllocSlice[T any](a *arenaContainer, n int) (slice []T) { + allocSize := C.size_t(n) * cSizeOf[T]() + if a.current+int(allocSize) > a.size { + panic(fmt.Sprintf("Arena out of memory while allocating %d bytes for [%d]%s", allocSize, n, reflect.TypeOf(slice).Elem())) + } + ptr := (*T)(unsafe.Pointer(&a.buf[a.current])) + a.current += int(allocSize) + a.current = (a.current + arenaAlignBytes - 1) &^ (arenaAlignBytes - 1) + slice = unsafe.Slice(ptr, n) + return +} + +// Free invalidates all previous allocations of the arena and frees the C allocated area. +func (a *arenaContainer) Free() { + cFree(&a.buf[0]) + a.buf = nil + a.size = 0 + a.current = 0 +} + +// Reset invalidates all previous allocations with the arena, but does not free the C allocated area. +// This way the arena can be re-used. +func (a *arenaContainer) Reset() { + // Zero the values used. + if a.buf == nil || a.size == 0 { + a.current = 0 + return + } + if a.current > 0 { + clearSize := min(a.size, a.current) + C.memset(unsafe.Pointer(&a.buf[0]), 0, C.size_t(clearSize)) + } + a.current = 0 +} diff --git a/pjrt/arena_test.go b/pjrt/arena_test.go new file mode 100644 index 0000000..0c89c10 --- /dev/null +++ b/pjrt/arena_test.go @@ -0,0 +1,26 @@ +package pjrt + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestArena(t *testing.T) { + arena := newArena(1024) + for _ = range 2 { + require.Equal(t, 1024, arena.size) + require.Equal(t, 0, arena.current) + _ = arenaAlloc[int](arena) + require.Equal(t, 8, arena.current) + _ = arenaAlloc[int32](arena) + require.Equal(t, 16, arena.current) + + _ = arenaAllocSlice[byte](arena, 9) // Aligning, it will occupy 16 bytes total. + require.Equal(t, 32, arena.current) + + require.Panics(t, func() { _ = arenaAlloc[[512]int](arena) }, "Arena out of memory") + require.Panics(t, func() { _ = arenaAllocSlice[float64](arena, 512) }, "Arena out of memory") + arena.Reset() + } + arena.Free() +} diff --git a/pjrt/benchmarks_test.go b/pjrt/benchmarks_test.go new file mode 100644 index 0000000..1c29ad2 --- /dev/null +++ b/pjrt/benchmarks_test.go @@ -0,0 +1,394 @@ +package pjrt + +// To run benchmarks: (and fix to P-Cores if running on a Intel i9-12900K) +// go test -c . && taskset 0xFF ./pjrt.test -test.v -test.run=Bench -test.count=1 +// +// See results in https://docs.google.com/spreadsheets/d/1ikpJH6rVVHq8ES-IA8U4lkKH4XsTSpRyZewXwGTgits/edit?gid=1369069161#gid=1369069161 +import ( + "flag" + "fmt" + "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/xlabuilder" + benchmarks "github.com/janpfeifer/go-benchmarks" + "runtime" + "testing" + "time" + "unsafe" +) + +var ( + flagBenchDuration = flag.Duration("bench_duration", 1*time.Second, "Benchmark duration") + + // testShapes used during benchmarks executing small computation graphs. + testShapes = []xlabuilder.Shape{ + xlabuilder.MakeShape(dtypes.Float32, 1, 1), + xlabuilder.MakeShape(dtypes.Float32, 10, 10), + xlabuilder.MakeShape(dtypes.Float32, 100, 100), + xlabuilder.MakeShape(dtypes.Float32, 1000, 1000), + } +) + +// TestBenchCGO benchmarks a minimal CGO call. +func TestBenchCGO(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + const repeats = 1000 + repeatedCGO := func() { + for _ = range repeats { + dummyCGO(unsafe.Pointer(plugin.api)) + } + } + benchmarks.New(benchmarks.NamedFunction{"CGOCall", repeatedCGO}). + WithInnerRepeats(repeats). + Done() +} + +// Benchmark tests different methods to create temporary pointers to be passed to CGO. +func TestBenchArena(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + numAllocationsList := []int{1, 5, 10, 100} + allocations := make([]*int, 100) + testFns := make([]benchmarks.NamedFunction, 4*len(numAllocationsList)) + const repeats = 10 + idxFn := 0 + for _, allocType := range []string{"arena", "arenaPool", "malloc", "go+pinner"} { + for _, numAllocations := range numAllocationsList { + testFns[idxFn].Name = fmt.Sprintf("%s/%s/%d", t.Name(), allocType, numAllocations) + var fn func() + switch allocType { + case "arena": + fn = func() { + for _ = range repeats { + arena := newArena(1024) + for idx := range numAllocations { + allocations[idx] = arenaAlloc[int](arena) + } + dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) + arena.Free() + } + } + case "arenaPool": + fn = func() { + for _ = range repeats { + arena := getArenaFromPool() + for idx := range numAllocations { + allocations[idx] = arenaAlloc[int](arena) + } + dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) + returnArenaToPool(arena) + } + } + case "malloc": + fn = func() { + for _ = range repeats { + for idx := range numAllocations { + allocations[idx] = cMalloc[int]() + } + dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) + for idx := range numAllocations { + cFree(allocations[idx]) + } + } + } + case "go+pinner": + fn = func() { + for _ = range repeats { + var pinner runtime.Pinner + for idx := range numAllocations { + v := idx + allocations[idx] = &v + pinner.Pin(allocations[idx]) + } + dummyCGO(unsafe.Pointer(allocations[numAllocations-1])) + pinner.Unpin() + } + } + } + testFns[idxFn].Func = fn + idxFn++ + } + } + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + WithWarmUps(10). + Done() +} + +// TestBenchBufferFromHost benchmarks host->buffer transfer time. +func TestBenchBufferFromHost(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + const repeats = 10 + numShapes := len(testShapes) + inputData := make([][]float32, numShapes) + testFns := make([]benchmarks.NamedFunction, numShapes) + for shapeIdx, s := range testShapes { + inputData[shapeIdx] = make([]float32, s.Size()) + for i := 0; i < s.Size(); i++ { + inputData[shapeIdx][i] = float32(i) + } + testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) + testFns[shapeIdx].Func = func() { + for _ = range repeats { + x := inputData[shapeIdx] + s := testShapes[shapeIdx] + buf := must1(ArrayToBuffer(client, x, s.Dimensions...)) + must(buf.Destroy()) + } + } + } + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + WithDuration(*flagBenchDuration). + Done() +} + +// TestBenchBufferToHost benchmarks time to transfer data from device buffer to host. +// +// Results on CPU: +// +// Benchmarks: Median 5%-tile 99%-tile +// TestBenchBufferToHost/shape=(Float32)[1 1] 1.684µs 1.555µs 3.762µs +// TestBenchBufferToHost/shape=(Float32)[10 10] 1.651µs 1.534µs 3.699µs +// TestBenchBufferToHost/shape=(Float32)[100 100] 5.393µs 5.002µs 7.271µs +// TestBenchBufferToHost/shape=(Float32)[1000 1000] 131.826µs 131.498µs 139.316µs +func TestBenchBufferToHost(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + const repeats = 10 + numShapes := len(testShapes) + testFns := make([]benchmarks.NamedFunction, numShapes) + // Prepare output data (host destination array) and upload buffers to GPU + outputData := make([][]float32, numShapes) + buffers := make([]*Buffer, numShapes) + for shapeIdx, s := range testShapes { + outputData[shapeIdx] = make([]float32, s.Size()) + for i := 0; i < s.Size(); i++ { + outputData[shapeIdx][i] = float32(i) + } + buffers[shapeIdx] = must1(ArrayToBuffer(client, outputData[shapeIdx], s.Dimensions...)) + testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) + testFns[shapeIdx].Func = func() { + for _ = range repeats { + buf := buffers[shapeIdx] + rawData := unsafe.Slice((*byte)(unsafe.Pointer(&outputData[shapeIdx][0])), len(outputData[shapeIdx])*int(unsafe.Sizeof(outputData[shapeIdx][0]))) + must(buf.ToHost(rawData)) + } + } + } + defer func() { + for _, buf := range buffers { + must(buf.Destroy()) + } + }() + + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + Done() +} + +// BenchmarkAdd1Execution benchmarks the execution time for a minimal program. +func TestBenchAdd1Execution(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // Prepare input data, the uploaded buffers and the executables. + const repeats = 10 + numShapes := len(testShapes) + execs := make([]*LoadedExecutable, numShapes) + inputData := make([][]float32, numShapes) + buffers := make([]*Buffer, numShapes) + testFns := make([]benchmarks.NamedFunction, numShapes) + for shapeIdx, s := range testShapes { + inputData[shapeIdx] = make([]float32, s.Size()) + for i := 0; i < s.Size(); i++ { + inputData[shapeIdx][i] = float32(i) + } + buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) + + builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) + // f(x) = x + 1 + x := must1(xlabuilder.Parameter(builder, "x", 0, s)) + one := must1(xlabuilder.ScalarOne(builder, s.DType)) + add1 := must1(xlabuilder.Add(x, one)) + comp := must1(builder.Build(add1)) + execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) + testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) + testFns[shapeIdx].Func = func() { + for _ = range repeats { + buf := buffers[shapeIdx] + exec := execs[shapeIdx] + output := must1(exec.Execute(buf).Done())[0] + must(output.Destroy()) + } + } + } + defer func() { + // Clean up -- and don't wait for the GC. + for shapeIdx := range numShapes { + must(buffers[shapeIdx].Destroy()) + must(execs[shapeIdx].Destroy()) + } + }() + + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + WithWarmUps(100). + WithDuration(*flagBenchDuration). + Done() +} + +// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2. +// +// Runtimes for cpu: +// +// Benchmarks: Median 5%-tile 99%-tile +// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs +// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs +// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs +// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs +func TestBenchAdd1Div2Execution(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // Prepare input data, the uploaded buffers and the executables. + const repeats = 10 + numShapes := len(testShapes) + execs := make([]*LoadedExecutable, numShapes) + inputData := make([][]float32, numShapes) + buffers := make([]*Buffer, numShapes) + testFns := make([]benchmarks.NamedFunction, numShapes) + for shapeIdx, s := range testShapes { + inputData[shapeIdx] = make([]float32, s.Size()) + for i := 0; i < s.Size(); i++ { + inputData[shapeIdx][i] = float32(i) + } + buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) + + builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) + // f(x) = x + 1 + x := must1(xlabuilder.Parameter(builder, "x", 0, s)) + one := must1(xlabuilder.ScalarOne(builder, s.DType)) + add1 := must1(xlabuilder.Add(x, one)) + half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5)))) + div2 := must1(xlabuilder.Mul(add1, half)) + comp := must1(builder.Build(div2)) + execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) + testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) + testFns[shapeIdx].Func = func() { + for _ = range repeats { + buf := buffers[shapeIdx] + exec := execs[shapeIdx] + output := must1(exec.Execute(buf).Done())[0] + _ = output.Destroy() + } + } + } + defer func() { + // Clean up -- and don't wait for the GC. + for shapeIdx := range numShapes { + must(buffers[shapeIdx].Destroy()) + must(execs[shapeIdx].Destroy()) + } + }() + + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + WithWarmUps(100). + WithDuration(*flagBenchDuration). + Done() +} + +// TestBenchAdd1Div2Execution benchmarks the execution time for f(x) = (x+1)/2. +// +// Runtimes for cpu: +// +// Benchmarks: Median 5%-tile 99%-tile +// TestBenchAdd1Div2Execution/shape=(Float32)[1 1] 1.536µs 1.374µs 3.522µs +// TestBenchAdd1Div2Execution/shape=(Float32)[10 10] 1.536µs 1.333µs 3.449µs +// TestBenchAdd1Div2Execution/shape=(Float32)[100 100] 2.973µs 2.638µs 5.282µs +// TestBenchAdd1Div2Execution/shape=(Float32)[1000 1000] 38.513µs 36.434µs 86.827µs +func TestBenchMeanNormalizedExecution(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // Prepare input data, the uploaded buffers and the executables. + const repeats = 10 + numShapes := len(testShapes) + execs := make([]*LoadedExecutable, numShapes) + inputData := make([][]float32, numShapes) + buffers := make([]*Buffer, numShapes) + testFns := make([]benchmarks.NamedFunction, numShapes) + for shapeIdx, s := range testShapes { + inputData[shapeIdx] = make([]float32, s.Size()) + for i := 0; i < s.Size(); i++ { + inputData[shapeIdx][i] = float32(i) + } + buffers[shapeIdx] = must1(ArrayToBuffer(client, inputData[shapeIdx], s.Dimensions...)) + + builder := xlabuilder.New(fmt.Sprintf("Add1/%s", s)) + // f(x) = x + 1 + x := must1(xlabuilder.Parameter(builder, "x", 0, s)) + one := must1(xlabuilder.ScalarOne(builder, s.DType)) + add1 := must1(xlabuilder.Add(x, one)) + half := must1(xlabuilder.Constant(builder, xlabuilder.NewScalarLiteral(float32(0.5)))) + div2 := must1(xlabuilder.Mul(add1, half)) + mean := must1(xlabuilder.ReduceSum(div2)) + normalized := must1(xlabuilder.Sub(div2, mean)) + + comp := must1(builder.Build(normalized)) + execs[shapeIdx] = must1(client.Compile().WithComputation(comp).Done()) + testFns[shapeIdx].Name = fmt.Sprintf("%s/shape=%s", t.Name(), s) + testFns[shapeIdx].Func = func() { + for _ = range repeats { + buf := buffers[shapeIdx] + exec := execs[shapeIdx] + output := must1(exec.Execute(buf).Done())[0] + _ = output.Destroy() + } + } + } + defer func() { + // Clean up -- and don't wait for the GC. + for shapeIdx := range numShapes { + must(buffers[shapeIdx].Destroy()) + must(execs[shapeIdx].Destroy()) + } + }() + + benchmarks.New(testFns...). + WithInnerRepeats(repeats). + WithWarmUps(100). + WithDuration(*flagBenchDuration). + Done() +} diff --git a/pjrt/buffers.go b/pjrt/buffers.go index 39b798b..1490900 100644 --- a/pjrt/buffers.go +++ b/pjrt/buffers.go @@ -4,10 +4,6 @@ package pjrt #include "pjrt_c_api.h" #include "gen_api_calls.h" #include "gen_new_struct.h" - -PJRT_Buffer_MemoryLayout_Tiled *GetTiledLayoutUnion(PJRT_Buffer_MemoryLayout *layout) { - return &(layout->tiled); -} */ import "C" import ( @@ -17,6 +13,7 @@ import ( "reflect" "runtime" "slices" + "sync/atomic" "unsafe" ) @@ -24,9 +21,27 @@ import ( type Buffer struct { cBuffer *C.PJRT_Buffer client *Client + + // For "shared buffers", with a direct pointer to the underlying data. + // This is nil for non-shared-buffers. + isShared bool + sharedRawStorage unsafe.Pointer + + dimsSet bool // Whether dims is set. + dims []int + + dtypeSet bool // Whether dtype is set. + dtype dtypes.DType // DEBUG: creationStackTrace error } +var buffersAlive atomic.Int64 + +// BuffersAlive returns the number of PJRT Buffers in memory and currently tracked by gopjrt. +func BuffersAlive() int64 { + return buffersAlive.Load() +} + // newBuffer creates Buffer and registers it for freeing. func newBuffer(client *Client, cBuffer *C.PJRT_Buffer) *Buffer { b := &Buffer{ @@ -34,6 +49,8 @@ func newBuffer(client *Client, cBuffer *C.PJRT_Buffer) *Buffer { cBuffer: cBuffer, // DEBUG: creationStackTrace: errors.New("bufferCreation"), } + buffersAlive.Add(1) + runtime.SetFinalizer(b, func(b *Buffer) { /* DEBUG: if b != nil && cBuffer != nil && b.client != nil && b.client.plugin != nil { @@ -59,25 +76,41 @@ func (b *Buffer) Destroy() error { return nil } defer runtime.KeepAlive(b) - args := C.new_PJRT_Buffer_Destroy_Args() - defer cFree(args) + + arena := getArenaFromPool() + defer returnArenaToPool(arena) + args := arenaAlloc[C.PJRT_Buffer_Destroy_Args](arena) + args.struct_size = C.PJRT_Buffer_Destroy_Args_STRUCT_SIZE args.buffer = b.cBuffer err := toError(b.client.plugin, C.call_PJRT_Buffer_Destroy(b.client.plugin.api, args)) b.client = nil b.cBuffer = nil + buffersAlive.Add(-1) + + if b.sharedRawStorage != nil { + // Shared storage can only be freed after the buffer is destroyed. + AlignedFree(b.sharedRawStorage) + b.sharedRawStorage = nil + } return err } // Dimensions of the Buffer. +// Returned slice is owned by the buffer, to avoid creating a copy. Don't change it. func (b *Buffer) Dimensions() (dims []int, err error) { if b == nil || b.client == nil || b.client.plugin == nil || b.cBuffer == nil { err = errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?") return } + if b.dimsSet { + return b.dims, nil + } defer runtime.KeepAlive(b) - args := C.new_PJRT_Buffer_Dimensions_Args() - defer cFree(args) + arena := getArenaFromPool() + defer returnArenaToPool(arena) + args := arenaAlloc[C.PJRT_Buffer_Dimensions_Args](arena) + args.struct_size = C.PJRT_Buffer_Dimensions_Args_STRUCT_SIZE args.buffer = b.cBuffer err = toError(b.client.plugin, C.call_PJRT_Buffer_Dimensions(b.client.plugin.api, args)) if err != nil { @@ -86,8 +119,9 @@ func (b *Buffer) Dimensions() (dims []int, err error) { if args.num_dims == 0 { return // dims = nil } - dims = slices.Clone(cDataToSlice[int](unsafe.Pointer(args.dims), int(args.num_dims))) - return + b.dims = slices.Clone(cDataToSlice[int](unsafe.Pointer(args.dims), int(args.num_dims))) + b.dimsSet = true + return b.dims, nil } // DType of the Buffer (PJRT_Buffer_ElementType). @@ -98,15 +132,22 @@ func (b *Buffer) DType() (dtype dtypes.DType, err error) { return } defer runtime.KeepAlive(b) + if b.dtypeSet { + return b.dtype, nil + } - args := C.new_PJRT_Buffer_ElementType_Args() - defer cFree(args) + arena := getArenaFromPool() + defer returnArenaToPool(arena) + args := arenaAlloc[C.PJRT_Buffer_ElementType_Args](arena) + args.struct_size = C.PJRT_Buffer_ElementType_Args_STRUCT_SIZE args.buffer = b.cBuffer err = toError(b.client.plugin, C.call_PJRT_Buffer_ElementType(b.client.plugin.api, args)) if err != nil { return } dtype = dtypes.DType(args._type) + b.dtype = dtype + b.dtypeSet = true return } @@ -118,8 +159,10 @@ func (b *Buffer) Device() (device *Device, err error) { } defer runtime.KeepAlive(b) - args := C.new_PJRT_Buffer_Device_Args() - defer cFree(args) + arena := getArenaFromPool() + defer returnArenaToPool(arena) + args := arenaAlloc[C.PJRT_Buffer_Device_Args](arena) + args.struct_size = C.PJRT_Buffer_Device_Args_STRUCT_SIZE args.buffer = b.cBuffer err = toError(b.client.plugin, C.call_PJRT_Buffer_Device(b.client.plugin.api, args)) if err != nil { @@ -134,203 +177,6 @@ func (b *Buffer) Client() *Client { return b.client } -// BufferFromHostConfig is used to configure the transfer from a buffer from host memory to on-device memory, it is -// created with Client.BufferFromHost. -// -// The data to transfer from host can be set up with one of the following methods: -// -// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions). -// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions). -// -// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum. -// -// At the end call BufferFromHostConfig.Done to actually initiate the transfer. -// -// TODO: Implement async transfers, arbitrary memory layout, etc. -type BufferFromHostConfig struct { - client *Client - data []byte - dtype dtypes.DType - dimensions []int - device *Device - - hostBufferSemantics PJRT_HostBufferSemantics - - // err stores the first error that happened during configuration. - // If it is not nil, it is immediately returned by the Done call. - err error -} - -// FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant) -// during the call. The parameters dtype and dimensions provide the shape of the array. -func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig { - if b.err != nil { - return b - } - b.data = data - b.dtype = dtype - b.dimensions = dimensions - return b -} - -// ToDevice configures which device to copy the host data to. -// -// If left un-configured, it will pick the first device returned by Client.AddressableDevices. -// -// You can also provide a device by their index in Client.AddressableDevices. -func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig { - if b.err != nil { - return b - } - if device == nil { - b.err = errors.New("BufferFromHost().ToDevice() given a nil device") - return b - } - addressable, err := device.IsAddressable() - if err != nil { - b.err = errors.WithMessagef(err, "BufferFromHost().ToDevice() failed to check whether device is addressable") - return b - } - if !addressable { - b.err = errors.New("BufferFromHost().ToDevice() given a non addressable device") - return b - } - b.device = device - return b -} - -// ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the -// list returned by Client.AddressableDevices. -// -// If left un-configured, it will pick the first device returned by Client.AddressableDevices. -// -// You can also provide a device by their index in Client.AddressableDevices. -func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig { - if b.err != nil { - return b - } - if deviceNum < 0 || deviceNum >= len(b.client.addressableDevices) { - b.err = errors.Errorf("BufferFromHost().ToDeviceNum() invalid deviceNum=%d, only %d addressable devices available", deviceNum, len(b.client.addressableDevices)) - return b - } - return b.ToDevice(b.client.addressableDevices[deviceNum]) -} - -// Done will use the configuration to start the transfer from host to device. -// It's synchronous: it awaits the transfer to finish and then returns. -func (b *BufferFromHostConfig) Done() (*Buffer, error) { - if b.err != nil { - // Return first error saved during configuration. - return nil, b.err - } - if len(b.data) == 0 { - return nil, errors.New("BufferFromHost requires one to configure the host data to transfer, none was configured.") - } - - // Makes sure program data is not moved around by the GC during the C/C++ call. - var pinner runtime.Pinner - dataPtr := unsafe.SliceData(b.data) - pinner.Pin(b) - pinner.Pin(dataPtr) - defer func() { - pinner.Unpin() - }() - - // Set default device. - if b.device == nil { - devices := b.client.AddressableDevices() - if len(devices) == 0 { - return nil, errors.New("BufferFromHost can't find addressable device to transfer to") - } - b.device = devices[0] - } - pinner.Pin(b.device) - - // Start the call. - args := C.new_PJRT_Client_BufferFromHostBuffer_Args() - defer cFree(args) - pinner.Pin(b.client) - args.client = b.client.client - args.data = unsafe.Pointer(dataPtr) - args._type = C.PJRT_Buffer_Type(b.dtype) - args.num_dims = C.size_t(len(b.dimensions)) - if len(b.dimensions) > 0 { - args.dims = cMallocArrayAndSet[C.int64_t](len(b.dimensions), func(i int) C.int64_t { - return C.int64_t(b.dimensions[i]) - }) - } - if args.dims != nil { - defer cFree(args.dims) - } - args.host_buffer_semantics = C.PJRT_HostBufferSemantics(b.hostBufferSemantics) - args.device = b.device.cDevice - pinner.Pin(b.client.plugin) - err := toError(b.client.plugin, C.call_PJRT_Client_BufferFromHostBuffer(b.client.plugin.api, args)) - if err != nil { - return nil, err - } - - // We get a PJRT_Buffer even before it's fully transferred. - buffer := newBuffer(b.client, args.buffer) - - // Await for transfer to finish. - doneEvent := newEvent(b.client.plugin, args.done_with_host_buffer) - defer func() { _ = doneEvent.Destroy() }() - err = doneEvent.Await() - if err != nil { - err2 := buffer.Destroy() - if err2 != nil { - klog.Errorf("Failed to destroy buffer that didn't finish to transfer from host: %+v", err2) - } - return nil, errors.WithMessage(err, "Failed to finish Client.BufferFromHost transfer") - } - return buffer, nil -} - -// FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying -// dimensions. -// The flat slice size must match the product of the dimension. -// If no dimensions are given, it is assumed to be a scalar, and flat should have length 1. -func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig { - if b.err != nil { - return b - } - // Checks dimensions. - expectedSize := 1 - for _, dim := range dimensions { - if dim <= 0 { - b.err = errors.Errorf("FromFlatDataWithDimensions cannot be given zero or negative dimensions, got %v", dimensions) - return b - } - expectedSize *= dim - } - - // Check the flat slice has the right shape. - flatV := reflect.ValueOf(flat) - if flatV.Kind() != reflect.Slice { - b.err = errors.Errorf("FromFlatDataWithDimensions was given a %s for flat, but it requires a slice", flatV.Kind()) - return b - } - if flatV.Len() != expectedSize { - b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions=%v) needs %d values to match dimensions, but got len(flat)=%d", dimensions, expectedSize, flatV.Len()) - return b - } - - // Check validity of the slice elements type. - element0 := flatV.Index(0) - element0Type := element0.Type() - dtype := dtypes.FromGoType(element0Type) - if dtype == dtypes.InvalidDType { - b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions%v) got flat=[]%s, expected a slice of a Go tyep that can be converted to a valid DType", dimensions, element0Type) - return b - } - - // Create slice of bytes and use b.FromRawData. - sizeBytes := uintptr(flatV.Len()) * element0Type.Size() - data := unsafe.Slice((*byte)(element0.Addr().UnsafePointer()), sizeBytes) - return b.FromRawData(data, dtype, dimensions) -} - // ScalarToRaw generates the raw values needed by BufferFromHostConfig.FromRawData to feed a simple scalar value. func ScalarToRaw[T dtypes.Supported](value T) ([]byte, dtypes.DType, []int) { dtype := dtypes.FromGenericsType[T]() @@ -346,8 +192,13 @@ func (b *Buffer) Size() (int, error) { return 0, errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?") } defer runtime.KeepAlive(b) - args := C.new_PJRT_Buffer_ToHostBuffer_Args() - defer cFree(args) + + arena := getArenaFromPool() + defer returnArenaToPool(arena) + + // It uses a PJRT_Buffer_ToHostBuffer_Args but it doesn't transfer, only inquire about size. + args := arenaAlloc[C.PJRT_Buffer_ToHostBuffer_Args](arena) + args.struct_size = C.PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE args.src = b.cBuffer args.dst = nil // Don't transfer, only inquire about size. err := toError(b.client.plugin, C.call_PJRT_Buffer_ToHostBuffer(b.client.plugin.api, args)) @@ -357,80 +208,8 @@ func (b *Buffer) Size() (int, error) { return int(args.dst_size), nil } -// ToHost transfers the contents of buffer stored on device to the host. -// The space in dst has to hold enough space (see Buffer.Size) to hold the required data, or an error is returned. -// -// This always request a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to -// reorganize the layout. -func (b *Buffer) ToHost(dst []byte) error { - defer runtime.KeepAlive(b) - if b == nil || b.client.plugin == nil || b.cBuffer == nil { - // Already destroyed ? - return errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?") - } - - // Make sure garbage collection doesn't free or move data before they are used by C/C++. - var pinner runtime.Pinner - pinner.Pin(b) - pinner.Pin(unsafe.SliceData(dst)) - defer pinner.Unpin() - - // We'll need the buffer rank to set up the layout. - dims, err := b.Dimensions() - if err != nil { - return err - } - rank := len(dims) - - // Prepare arguments for the buffer-to-host call. - args := C.new_PJRT_Buffer_ToHostBuffer_Args() - defer cFree(args) - args.src = b.cBuffer - args.dst = unsafe.Pointer(unsafe.SliceData(dst)) - args.dst_size = C.size_t(len(dst)) - - // Layout argument. - layoutArgs := C.new_PJRT_Buffer_MemoryLayout() - defer cFree(layoutArgs) - args.host_layout = layoutArgs - - // Tiled layout must be present, even if there are no tiles (tileArgs.num_tiles==0). - layoutArgs._type = C.PJRT_Buffer_MemoryLayout_Type_Tiled - tileArgs := C.GetTiledLayoutUnion(layoutArgs) - - // Configure major-to-minor layout into tileArgs, if not scalar. - tileArgs.minor_to_major_size = C.size_t(rank) - if rank > 0 { - tileArgs.minor_to_major = cMallocArray[C.int64_t](rank) - minorToMajorMapping := unsafe.Slice(tileArgs.minor_to_major, rank) - defer cFree(tileArgs.minor_to_major) - for axisIdx := range len(dims) { - minorToMajorMapping[axisIdx] = C.int64_t(rank - axisIdx - 1) - } - } - - err = toError(b.client.plugin, C.call_PJRT_Buffer_ToHostBuffer(b.client.plugin.api, args)) - if err != nil { - return errors.WithMessage(err, "Failed to call PJRT_Buffer_ToHostBuffer to transfer the buffer to host") - } - - // Await for transfer to finish. - doneEvent := newEvent(b.client.plugin, args.event) - defer func() { _ = doneEvent.Destroy() }() - err = doneEvent.Await() - if err != nil { - return errors.WithMessage(err, "Failed to wait Buffer.ToHost transfer to finish") - } - return nil -} - // BufferToScalar is a generic function that transfer a Buffer back to host as a scalar of the given type. func BufferToScalar[T dtypes.Supported](b *Buffer) (value T, err error) { - var pinner runtime.Pinner - pinner.Pin(b) - pinner.Pin(&value) - defer pinner.Unpin() - dst := unsafe.Slice((*byte)(unsafe.Pointer(&value)), unsafe.Sizeof(value)) err = b.ToHost(dst) return @@ -441,11 +220,6 @@ func BufferToScalar[T dtypes.Supported](b *Buffer) (value T, err error) { // It is a shortcut to Client.BufferFromHost call with default parameters. // If you need more control where the value will be used you'll have to use Client.BufferFromHost instead. func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err error) { - var pinner runtime.Pinner - pinner.Pin(client) - pinner.Pin(&value) - defer pinner.Unpin() - dtype := dtypes.FromGenericsType[T]() src := unsafe.Slice((*byte)(unsafe.Pointer(&value)), unsafe.Sizeof(value)) return client.BufferFromHost().FromRawData(src, dtype, nil).Done() @@ -456,11 +230,6 @@ func ScalarToBuffer[T dtypes.Supported](client *Client, value T) (b *Buffer, err // It is a shortcut to Client.BufferFromHost call with default parameters. // If you need more control where the value will be used you'll have to use Client.BufferFromHost instead. func ScalarToBufferOnDeviceNum[T dtypes.Supported](client *Client, deviceNum int, value T) (b *Buffer, err error) { - var pinner runtime.Pinner - pinner.Pin(client) - pinner.Pin(&value) - defer pinner.Unpin() - dtype := dtypes.FromGenericsType[T]() src := unsafe.Slice((*byte)(unsafe.Pointer(&value)), unsafe.Sizeof(value)) return client.BufferFromHost().FromRawData(src, dtype, nil).ToDeviceNum(deviceNum).Done() @@ -503,12 +272,6 @@ func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensio } flatValues = make([]T, totalSize) flatValuesPtr := unsafe.SliceData(flatValues) - - var pinner runtime.Pinner - pinner.Pin(buffer) - pinner.Pin(flatValuesPtr) - defer pinner.Unpin() - dst := unsafe.Slice((*byte)( unsafe.Pointer(flatValuesPtr)), totalSize*int(unsafe.Sizeof(flatValues[0]))) @@ -520,7 +283,6 @@ func BufferToArray[T dtypes.Supported](buffer *Buffer) (flatValues []T, dimensio // // Similar to the generic BufferToArray[T], but this returns an anonymous typed (`any`) flat slice instead of using generics. func (b *Buffer) ToFlatDataAndDimensions() (flat any, dimensions []int, err error) { - defer runtime.KeepAlive(b) var dtype dtypes.DType dtype, err = b.DType() if err != nil { @@ -546,10 +308,6 @@ func (b *Buffer) ToFlatDataAndDimensions() (flat any, dimensions []int, err erro flatValuesPtr := element0.Addr().UnsafePointer() sizeBytes := uintptr(flatV.Len()) * element0.Type().Size() - var pinner runtime.Pinner - pinner.Pin(b) - pinner.Pin(flatValuesPtr) - defer pinner.Unpin() dst := unsafe.Slice((*byte)(flatValuesPtr), sizeBytes) err = b.ToHost(dst) flat = flatV.Interface() diff --git a/pjrt/buffers_from_host.go b/pjrt/buffers_from_host.go new file mode 100644 index 0000000..c4316cf --- /dev/null +++ b/pjrt/buffers_from_host.go @@ -0,0 +1,236 @@ +package pjrt + +/* +#include "pjrt_c_api.h" +#include "gen_api_calls.h" +#include "gen_new_struct.h" + +PJRT_Error* BufferFromHostAndWait(const PJRT_Api *api, PJRT_Client_BufferFromHostBuffer_Args *args) { + PJRT_Error* err = api->PJRT_Client_BufferFromHostBuffer(args); + if (err) { + return err; + } + PJRT_Event_Await_Args event_args = {0}; + event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + event_args.event = args->done_with_host_buffer; + err = api->PJRT_Event_Await(&event_args); + + PJRT_Event_Destroy_Args efree_args; + efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + efree_args.event = args->done_with_host_buffer; + api->PJRT_Event_Destroy(&efree_args); + + return err; +} + +PJRT_Error *dummy_error; +PJRT_Error *Dummy(void *api) { + if (api == NULL) { + return NULL; + } + return dummy_error; +} + +*/ +import "C" +import ( + "github.com/gomlx/gopjrt/dtypes" + "github.com/pkg/errors" + "reflect" + "runtime" + "slices" + "unsafe" +) + +// BufferFromHostConfig is used to configure the transfer from a buffer from host memory to on-device memory, it is +// created with Client.BufferFromHost. +// +// The data to transfer from host can be set up with one of the following methods: +// +// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions). +// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions). +// +// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum. +// +// At the end call BufferFromHostConfig.Done to actually initiate the transfer. +// +// TODO: Implement async transfers, arbitrary memory layout, etc. +type BufferFromHostConfig struct { + client *Client + data []byte + dtype dtypes.DType + dimensions []int + device *Device + + hostBufferSemantics PJRT_HostBufferSemantics + + // err stores the first error that happened during configuration. + // If it is not nil, it is immediately returned by the Done call. + err error +} + +// FromRawData configures the data from host to copy: a pointer to bytes that must be kept alive (and constant) +// during the call. The parameters dtype and dimensions provide the shape of the array. +func (b *BufferFromHostConfig) FromRawData(data []byte, dtype dtypes.DType, dimensions []int) *BufferFromHostConfig { + if b.err != nil { + return b + } + b.data = data + b.dtype = dtype + b.dimensions = dimensions + return b +} + +// ToDevice configures which device to copy the host data to. +// +// If left un-configured, it will pick the first device returned by Client.AddressableDevices. +// +// You can also provide a device by their index in Client.AddressableDevices. +func (b *BufferFromHostConfig) ToDevice(device *Device) *BufferFromHostConfig { + if b.err != nil { + return b + } + if device == nil { + b.err = errors.New("BufferFromHost().ToDevice() given a nil device") + return b + } + addressable, err := device.IsAddressable() + if err != nil { + b.err = errors.WithMessagef(err, "BufferFromHost().ToDevice() failed to check whether device is addressable") + return b + } + if !addressable { + b.err = errors.New("BufferFromHost().ToDevice() given a non addressable device") + return b + } + b.device = device + return b +} + +// ToDeviceNum configures which device to copy the host data to, given a deviceNum pointing to the device in the +// list returned by Client.AddressableDevices. +// +// If left un-configured, it will pick the first device returned by Client.AddressableDevices. +// +// You can also provide a device by their index in Client.AddressableDevices. +func (b *BufferFromHostConfig) ToDeviceNum(deviceNum int) *BufferFromHostConfig { + if b.err != nil { + return b + } + if deviceNum < 0 || deviceNum >= len(b.client.addressableDevices) { + b.err = errors.Errorf("BufferFromHost().ToDeviceNum() invalid deviceNum=%d, only %d addressable devices available", deviceNum, len(b.client.addressableDevices)) + return b + } + return b.ToDevice(b.client.addressableDevices[deviceNum]) +} + +// FromFlatDataWithDimensions configures the data to come from a flat slice of the desired data type, and the underlying +// dimensions. +// The flat slice size must match the product of the dimension. +// If no dimensions are given, it is assumed to be a scalar, and flat should have length 1. +func (b *BufferFromHostConfig) FromFlatDataWithDimensions(flat any, dimensions []int) *BufferFromHostConfig { + if b.err != nil { + return b + } + // Checks dimensions. + expectedSize := 1 + for _, dim := range dimensions { + if dim <= 0 { + b.err = errors.Errorf("FromFlatDataWithDimensions cannot be given zero or negative dimensions, got %v", dimensions) + return b + } + expectedSize *= dim + } + + // Check the flat slice has the right shape. + flatV := reflect.ValueOf(flat) + if flatV.Kind() != reflect.Slice { + b.err = errors.Errorf("FromFlatDataWithDimensions was given a %s for flat, but it requires a slice", flatV.Kind()) + return b + } + if flatV.Len() != expectedSize { + b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions=%v) needs %d values to match dimensions, but got len(flat)=%d", dimensions, expectedSize, flatV.Len()) + return b + } + + // Check validity of the slice elements type. + element0 := flatV.Index(0) + element0Type := element0.Type() + dtype := dtypes.FromGoType(element0Type) + if dtype == dtypes.InvalidDType { + b.err = errors.Errorf("FromFlatDataWithDimensions(flat, dimensions%v) got flat=[]%s, expected a slice of a Go tyep that can be converted to a valid DType", dimensions, element0Type) + return b + } + + // Create slice of bytes and use b.FromRawData. + sizeBytes := uintptr(flatV.Len()) * element0Type.Size() + data := unsafe.Slice((*byte)(element0.Addr().UnsafePointer()), sizeBytes) + return b.FromRawData(data, dtype, dimensions) +} + +// Done will use the configuration to start the transfer from host to device. +// It's synchronous: it awaits the transfer to finish and then returns. +func (b *BufferFromHostConfig) Done() (*Buffer, error) { + if b.err != nil { + // Return first error saved during configuration. + return nil, b.err + } + if len(b.data) == 0 { + return nil, errors.New("BufferFromHost requires one to configure the host data to transfer, none was configured.") + } + defer runtime.KeepAlive(b) + + // Makes sure program data is not moved around by the GC during the C/C++ call. + var pinner runtime.Pinner + defer pinner.Unpin() + dataPtr := unsafe.SliceData(b.data) + pinner.Pin(dataPtr) + + // Set default device. + if b.device == nil { + devices := b.client.AddressableDevices() + if len(devices) == 0 { + return nil, errors.New("BufferFromHost can't find addressable device to transfer to") + } + b.device = devices[0] + } + + // Arena for memory allocations used by CGO. + arena := getArenaFromPool() + defer returnArenaToPool(arena) + + // Arguments to PJRT call. + var args *C.PJRT_Client_BufferFromHostBuffer_Args + args = arenaAlloc[C.PJRT_Client_BufferFromHostBuffer_Args](arena) + args.struct_size = C.PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE + args.client = b.client.client + args.data = unsafe.Pointer(dataPtr) + args._type = C.PJRT_Buffer_Type(b.dtype) + args.num_dims = C.size_t(len(b.dimensions)) + if len(b.dimensions) > 0 { + dims := arenaAllocSlice[C.int64_t](arena, len(b.dimensions)) + for ii, dim := range b.dimensions { + dims[ii] = C.int64_t(dim) + } + args.dims = unsafe.SliceData(dims) + } + args.host_buffer_semantics = C.PJRT_HostBufferSemantics(b.hostBufferSemantics) + args.device = b.device.cDevice + err := toError(b.client.plugin, C.BufferFromHostAndWait(b.client.plugin.api, args)) + if err != nil { + return nil, err + } + + buffer := newBuffer(b.client, args.buffer) + buffer.dims = slices.Clone(b.dimensions) + buffer.dimsSet = true + buffer.dtype = b.dtype + buffer.dtypeSet = true + return buffer, nil +} + +// dummyCGO calls a minimal C function and doesn't do anything. +// Here for the purpose of benchmarking CGO calls. +func dummyCGO(pointer unsafe.Pointer) { + _ = C.Dummy(pointer) +} diff --git a/pjrt/buffers_shared.go b/pjrt/buffers_shared.go new file mode 100644 index 0000000..e8eed1d --- /dev/null +++ b/pjrt/buffers_shared.go @@ -0,0 +1,201 @@ +package pjrt + +/* +#include "pjrt_c_api.h" +#include "gen_api_calls.h" +#include "gen_new_struct.h" + +extern void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg); + +// OnDeleteSharedBuffer is a no-op. +void OnDeleteSharedBuffer(void* device_buffer_ptr, void* user_arg) { + return; +} + +extern void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg); +void (*OnDeleteSharedBufferPtr)(void* device_buffer_ptr, void* user_arg) = &OnDeleteSharedBuffer; + +*/ +import "C" +import ( + "github.com/gomlx/gopjrt/dtypes" + "github.com/pkg/errors" + "reflect" + "slices" + "unsafe" +) + +// CreateViewOfDeviceBuffer creates a PJRT Buffer that is backed by storage on the same device given by the caller as flatData and shape. +// Consider using the simpler API NewSharedBuffer. +// +// Different PJRT may have different requirements on alignment, but for the CPU PJRT the library provide +// AlignedAlloc and AlignedFree, that can be used to allocate the aligned storage space. +// +// Example of how a typical usage where the same buffer is reused as input in loop: +// +// dtype := dtypes.Float32 +// dimensions := []int{batchSize, sequenceLength, 384} +// rawData := pjrt.AlignedAlloc(dtype.SizeForDimensions(dimensions...), pjrt.BufferAlignment) +// defer pjrt.AlignedFree(rawData) +// buf := client.CreateViewOfDeviceBuffer(rawData, dtype, dimensions) +// flat := unsafe.Slice((*float32)(storage), batchSize*sequenceLength*384) +// for _, batch := range batches { +// // ... set flat values +// // ... use buf as input when executing a PJRT program +// } +// +// If device is not given (at most one can be given), the first device available for the client is used. +// +// The naming comes from PJRT and is unfortunate, since it's the name from PJRT's perspective (PJRT view of a +// users device buffer). +// Probably, it should have been better named by "ShareDeviceBuffer" or something similar. +// +// This may not be implemented for all hardware (or all PJRT plugins). +// +// This can be useful to avoid the copy of values, by mutating directly in the memory shared with PJRT, to be used +// as input to a computation. +// +// See: dtypes.SizeForDimensions() to calculate the size for an arbitrary shape; AlignedAlloc, AlignedFree and +// BufferAlignment (a constant with the required alignment size) to allocate and free aligned storage. +func (c *Client) CreateViewOfDeviceBuffer(rawData unsafe.Pointer, dtype dtypes.DType, dimensions []int, device ...*Device) (*Buffer, error) { + var selectedDevice *Device + if len(device) > 1 { + return nil, errors.Errorf("only one device can be given to CreateViewOfDeviceBuffer, %d were given", len(device)) + } else if len(device) == 1 { + selectedDevice = device[0] + } else { + devices := c.AddressableDevices() + if len(devices) == 0 { + return nil, errors.New("CreateViewOfDeviceBuffer can't find addressable device to transfer to") + } + selectedDevice = devices[0] + } + + // Arena for memory allocations used by CGO. + arena := getArenaFromPool() + defer returnArenaToPool(arena) + + // Arguments to PJRT call. + var args *C.PJRT_Client_CreateViewOfDeviceBuffer_Args + args = arenaAlloc[C.PJRT_Client_CreateViewOfDeviceBuffer_Args](arena) + args.struct_size = C.PJRT_Client_CreateViewOfDeviceBuffer_Args_STRUCT_SIZE + args.client = c.client + args.device_buffer_ptr = rawData + args.element_type = C.PJRT_Buffer_Type(dtype) + args.num_dims = C.size_t(len(dimensions)) + if args.num_dims > 0 { + dims := arenaAllocSlice[C.int64_t](arena, int(args.num_dims)) + for ii, dim := range dimensions { + dims[ii] = C.int64_t(dim) + } + args.dims = unsafe.SliceData(dims) + } + args.device = selectedDevice.cDevice + args.on_delete_callback = (*[0]byte)(unsafe.Pointer(C.OnDeleteSharedBufferPtr)) + args.on_delete_callback_arg = nil + err := toError(c.plugin, C.call_PJRT_Client_CreateViewOfDeviceBuffer(c.plugin.api, args)) + if err != nil { + return nil, err + } + buffer := newBuffer(c, args.buffer) + buffer.isShared = true + buffer.dims = slices.Clone(dimensions) + buffer.dimsSet = true + buffer.dtype = dtype + buffer.dtypeSet = true + return buffer, nil +} + +// NewSharedBuffer returns a buffer that can be used for execution and share the underlying +// memory space with the host/local, which can be read and mutated directly. +// +// Shared buffers cannot be donated to executions. +// +// The buffer should not be mutated while it is used by an execution. +// +// When the buffer is finalized, the shared memory is also de-allocated. +// +// It returns a handle to the buffer and a slice of the corresponding data type pointing +// to the shared data. +func (c *Client) NewSharedBuffer(dtype dtypes.DType, dimensions []int, device ...*Device) (buffer *Buffer, flat any, err error) { + memorySize := uintptr(dtype.SizeForDimensions(dimensions...)) + rawStorage := AlignedAlloc(memorySize, BufferAlignment) + buffer, err = c.CreateViewOfDeviceBuffer(rawStorage, dtype, dimensions, device...) + if err != nil { + AlignedFree(rawStorage) + err = errors.WithMessagef(err, "NewSharedBuffer failed creating new buffer") + buffer = nil + return + } + buffer.sharedRawStorage = rawStorage + buffer.isShared = true + + goDType := dtype.GoType() + flat = reflect.SliceAt(dtype.GoType(), rawStorage, int(memorySize/goDType.Size())).Interface() + return +} + +// IsShared returns whether this buffer was created with Client.NewSharedBuffer. +// These buffers cannot be donated in execution. +func (b *Buffer) IsShared() bool { + return b.isShared +} + +// UnsafePointer returns platform-dependent address for the given buffer that is often but +// not guaranteed to be the physical/device address. +// Consider using the more convenient DirectAccess. +// +// Probably, this should only be used by CPU plugins. +// +// To be on the safe side, only use this if Client.HasSharedBuffers is true. +// It uses the undocumented PJRT_Buffer_UnsafePointer. +func (b *Buffer) UnsafePointer() (unsafe.Pointer, error) { + plugin := b.client.plugin + + // Arena for memory allocations used by CGO. + arena := getArenaFromPool() + defer returnArenaToPool(arena) + + // Arguments to PJRT call. + var args *C.PJRT_Buffer_UnsafePointer_Args + args = arenaAlloc[C.PJRT_Buffer_UnsafePointer_Args](arena) + args.struct_size = C.PJRT_Buffer_UnsafePointer_Args_STRUCT_SIZE + args.buffer = b.cBuffer + err := toError(plugin, C.call_PJRT_Buffer_UnsafePointer(plugin.api, args)) + if err != nil { + return nil, err + } + return unsafe.Pointer(uintptr(args.buffer_pointer)), nil +} + +// Data returns the flat slice pointing to the underlying storage data for the buffer. +// +// This is an undocumented feature of PJRT and likely only works for CPU platforms. +// The flat slice returned is only valid while the buffer is alive. +func (b *Buffer) Data() (flat any, err error) { + var rawStorage unsafe.Pointer + rawStorage, err = b.UnsafePointer() + if err != nil { + return nil, err + } + dims := b.dims + dtype := b.dtype + if !b.dimsSet { + dims, err = b.Dimensions() + if err != nil { + return nil, err + } + } + if !b.dtypeSet { + dtype, err = b.DType() + if err != nil { + return nil, err + } + } + + numElements := 1 + for _, dim := range dims { + numElements *= dim + } + return reflect.SliceAt(dtype.GoType(), rawStorage, numElements).Interface(), nil +} diff --git a/pjrt/buffers_test.go b/pjrt/buffers_test.go index c471831..3ebabc0 100644 --- a/pjrt/buffers_test.go +++ b/pjrt/buffers_test.go @@ -1,9 +1,12 @@ package pjrt import ( + "flag" "fmt" "github.com/gomlx/gopjrt/dtypes" + "github.com/gomlx/gopjrt/xlabuilder" "github.com/stretchr/testify/require" + "runtime" "testing" "unsafe" ) @@ -23,6 +26,7 @@ func testTransfersImpl[T interface { fmt.Printf("From %#v\n", input) buffer, err := ArrayToBuffer(client, input, 3, 1) require.NoError(t, err) + require.False(t, buffer.IsShared()) output, outputDims, err := BufferToArray[T](buffer) require.NoError(t, err) @@ -124,3 +128,171 @@ func TestBufferProperties(t *testing.T) { require.Equal(t, dtypes.Uint8, dtype) } } + +var flagForceSharedBuffer = flag.Bool( + "force_shared_buffer", false, "Force executing TestCreateViewOfDeviceBuffer and TestBufferUnsafePointer even if plugin is not \"cpu\".") + +func TestCreateViewOfDeviceBuffer(t *testing.T) { + if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + t.Skip("Skipping TestCreateViewOfDeviceBuffer because -plugin != \"cpu\". " + + "Set --force_create_view to force executing the test anyway") + } + + // Create plugin. + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // f(x) = x + 1 + dtype := dtypes.Float32 + shape := xlabuilder.MakeShape(dtype, 2, 3) + builder := xlabuilder.New("Add1") + x := must1(xlabuilder.Parameter(builder, "x", 0, shape)) + one := must1(xlabuilder.ScalarOne(builder, shape.DType)) + add1 := must1(xlabuilder.Add(x, one)) + comp := must1(builder.Build(add1)) + exec := must1(client.Compile().WithComputation(comp).Done()) + + // Input is created as a "Device Buffer View" + storage := AlignedAlloc(shape.Memory(), BufferAlignment) + defer AlignedFree(storage) + inputBuffer, err := client.CreateViewOfDeviceBuffer(storage, dtype, shape.Dimensions) + require.NoError(t, err) + defer func() { + err := inputBuffer.Destroy() + if err != nil { + t.Logf("Failed Buffer.Destroy(): %+v", err) + } + }() + + flatData := unsafe.Slice((*float32)(storage), shape.Size()) + for ii := range flatData { + flatData[ii] = float32(ii) + } + require.True(t, inputBuffer.IsShared()) + + results, err := exec.Execute(inputBuffer).DonateNone().Done() + require.NoError(t, err) + require.Len(t, results, 1) + + gotFlat, gotDims, err := BufferToArray[float32](results[0]) + require.NoError(t, err) + require.Equal(t, shape.Dimensions, gotDims) + require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, gotFlat) + + // Change the buffer directly, and see that we can reuse the buffer in PJRT, without the extra transfer. + flatData[1] = 11 + results, err = exec.Execute(inputBuffer).DonateNone().Done() + require.NoError(t, err) + require.Len(t, results, 1) + gotFlat, gotDims, err = BufferToArray[float32](results[0]) + require.NoError(t, err) + require.Equal(t, shape.Dimensions, gotDims) + require.Equal(t, []float32{1, 12, 3, 4, 5, 6}, gotFlat) + require.NoError(t, inputBuffer.Destroy()) +} + +func TestNewSharedBuffer(t *testing.T) { + if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " + + "Set --force_create_view to force executing the test anyway") + } + + // Create plugin. + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // f(x) = x + 1 + dtype := dtypes.Float32 + shape := xlabuilder.MakeShape(dtype, 2, 3) + builder := xlabuilder.New("Add1") + x := must1(xlabuilder.Parameter(builder, "x", 0, shape)) + one := must1(xlabuilder.ScalarOne(builder, shape.DType)) + add1 := must1(xlabuilder.Add(x, one)) + comp := must1(builder.Build(add1)) + exec := must1(client.Compile().WithComputation(comp).Done()) + + // Input is created as a "Device Buffer View" + inputBuffer, flatAny, err := client.NewSharedBuffer(dtype, shape.Dimensions) + require.NoError(t, err) + defer func() { + err := inputBuffer.Destroy() + if err != nil { + t.Logf("Failed to destroy shared buffer: %+v", err) + } + }() + + flatData := flatAny.([]float32) + for ii := range flatData { + flatData[ii] = float32(ii) + } + require.True(t, inputBuffer.IsShared()) + + results, err := exec.Execute(inputBuffer).DonateNone().Done() + require.NoError(t, err) + require.Len(t, results, 1) + + gotFlat, gotDims, err := BufferToArray[float32](results[0]) + require.NoError(t, err) + require.Equal(t, shape.Dimensions, gotDims) + require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, gotFlat) + + // Change the buffer directly, and see that we can reuse the buffer in PJRT, without the extra transfer. + flatData[1] = 11 + results, err = exec.Execute(inputBuffer).DonateNone().Done() + require.NoError(t, err) + require.Len(t, results, 1) + gotFlat, gotDims, err = BufferToArray[float32](results[0]) + require.NoError(t, err) + require.Equal(t, shape.Dimensions, gotDims) + require.Equal(t, []float32{1, 12, 3, 4, 5, 6}, gotFlat) + + require.NoError(t, inputBuffer.Destroy()) +} + +func TestBufferData(t *testing.T) { + if *flagPluginName != "cpu" && !*flagForceSharedBuffer { + t.Skip("Skipping TestNewSharedBuffer because -plugin != \"cpu\". " + + "Set --force_create_view to force executing the test anyway") + } + + // Create plugin. + plugin := must1(GetPlugin(*flagPluginName)) + client := must1(plugin.NewClient(nil)) + defer runtime.KeepAlive(client) + + // f(x) = x + 1 + dtype := dtypes.Float32 + shape := xlabuilder.MakeShape(dtype, 2, 3) + builder := xlabuilder.New("Add1") + x := must1(xlabuilder.Parameter(builder, "x", 0, shape)) + one := must1(xlabuilder.ScalarOne(builder, shape.DType)) + add1 := must1(xlabuilder.Add(x, one)) + comp := must1(builder.Build(add1)) + exec := must1(client.Compile().WithComputation(comp).Done()) + + // Input is created as a "Device Buffer View" + inputBuffer, flatAny, err := client.NewSharedBuffer(dtype, shape.Dimensions) + require.NoError(t, err) + defer func() { + err := inputBuffer.Destroy() + if err != nil { + t.Logf("Failed to destroy shared buffer: %+v", err) + } + }() + + flatData := flatAny.([]float32) + for ii := range flatData { + flatData[ii] = float32(ii) + } + require.True(t, inputBuffer.IsShared()) + + results, err := exec.Execute(inputBuffer).DonateNone().Done() + require.NoError(t, err) + require.Len(t, results, 1) + + flatOutput, err := results[0].Data() + require.NoError(t, err) + require.Equal(t, []float32{1, 2, 3, 4, 5, 6}, flatOutput.([]float32)) +} diff --git a/pjrt/buffers_to_host.go b/pjrt/buffers_to_host.go new file mode 100644 index 0000000..6869706 --- /dev/null +++ b/pjrt/buffers_to_host.go @@ -0,0 +1,83 @@ +package pjrt + +import ( + "github.com/pkg/errors" + "runtime" + "unsafe" +) + +/* +#include "pjrt_c_api.h" +#include "gen_api_calls.h" +#include "gen_new_struct.h" + +PJRT_Error* BufferToHost(const PJRT_Api *api, PJRT_Buffer *buffer, void *dst, int64_t dst_size, int rank) { + PJRT_Buffer_ToHostBuffer_Args args = {0}; + + args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE; + args.src = buffer; + args.dst = dst; + args.dst_size = dst_size; + PJRT_Buffer_MemoryLayout layout_args = {0}; + layout_args.struct_size = PJRT_Buffer_MemoryLayout_STRUCT_SIZE; + args.host_layout = &layout_args; + layout_args.type = PJRT_Buffer_MemoryLayout_Type_Tiled; + layout_args.tiled.minor_to_major_size = rank; + int64_t minor_to_major[rank > 0 ? rank : 1]; + if (rank > 0) { + for (int axisIdx = 0; axisIdx < rank; axisIdx++) { + minor_to_major[axisIdx] = rank - axisIdx - 1; + } + layout_args.tiled.minor_to_major = &minor_to_major[0]; + } + PJRT_Error* err = api->PJRT_Buffer_ToHostBuffer(&args); + if (err) { + return err; + } + + PJRT_Event_Await_Args event_args = {0}; + event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + event_args.event = args.event; + err = api->PJRT_Event_Await(&event_args); + PJRT_Event_Destroy_Args efree_args; + efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + efree_args.event = args.event; + api->PJRT_Event_Destroy(&efree_args); + + return err; +} + +*/ +import "C" + +// ToHost transfers the contents of buffer stored on device to the host. +// The space in dst has to hold enough space (see Buffer.Size) to hold the required data, or an error is returned. +// +// This always request a major-to-minor layout, the assumption of the layout in host memory -- TPUs are known to +// reorganize the layout. +func (b *Buffer) ToHost(dst []byte) error { + defer runtime.KeepAlive(b) + if b == nil || b.client.plugin == nil || b.cBuffer == nil { + // Already destroyed ? + return errors.New("Buffer is nil, or its plugin or wrapped C representation is nil -- has it been destroyed already?") + } + + // We'll need the buffer rank to set up the layout. + dims, err := b.Dimensions() + if err != nil { + return err + } + rank := len(dims) + + dstBytes := unsafe.Pointer(unsafe.SliceData(dst)) + var pinner runtime.Pinner + pinner.Pin(dstBytes) + defer pinner.Unpin() + + pErr := C.BufferToHost(b.client.plugin.api, b.cBuffer, dstBytes, C.int64_t(len(dst)), C.int(rank)) + err = toError(b.client.plugin, pErr) + if err != nil { + return errors.WithMessage(err, "Failed to call PJRT_Buffer_ToHostBuffer to transfer the buffer to host") + } + return nil +} diff --git a/pjrt/clients.go b/pjrt/clients.go index e509eaa..4689013 100644 --- a/pjrt/clients.go +++ b/pjrt/clients.go @@ -115,6 +115,7 @@ type Client struct { platform, platformVersion string processIndex int addressableDevices []*Device + allowBufferViews bool } // newClient is called by Plugin.NewClient to create a new PJRT_Client wrapper. @@ -276,6 +277,6 @@ func (c *Client) BufferFromHost() *BufferFromHostConfig { return &BufferFromHostConfig{ client: c, device: nil, - hostBufferSemantics: PJRT_HostBufferSemantics_kImmutableOnlyDuringCall, + hostBufferSemantics: PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes, } } diff --git a/pjrt/compile.go b/pjrt/compile.go index 19989f4..60fa5a3 100644 --- a/pjrt/compile.go +++ b/pjrt/compile.go @@ -3,14 +3,20 @@ package pjrt import ( "github.com/gomlx/gopjrt/cbuffer" "github.com/gomlx/gopjrt/protos/compile_options" + "github.com/gomlx/gopjrt/protos/xla" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "k8s.io/klog/v2" + "os" "runtime" "unsafe" ) +// EnvXlaDebugOptions is an environment variable that can be defined to set XLA DebugOptions proto when compiling +// a program. +const EnvXlaDebugOptions = "XLA_DEBUG_OPTIONS" + // CompileConfig is created with Client.Compile, and is a "builder pattern" to configure a compilation call. // // At a minimum one has to set the program to compile (use CompileConfig.WithHLO or CompileConfig.WithComputation). @@ -65,6 +71,27 @@ func newCompileConfig(client *Client) (cc *CompileConfig) { NumReplicas: 1, NumPartitions: 1, } + + debugOptionsStr := os.Getenv(EnvXlaDebugOptions) + if debugOptionsStr != "" { + debugOptions := &xla.DebugOptions{} + err := prototext.Unmarshal([]byte(debugOptionsStr), debugOptions) + if err != nil { + cc.err = errors.Wrapf(err, "Failed to parse xla.DebugOptions protobuf from $%s=%q", EnvXlaDebugOptions, debugOptionsStr) + return cc + } + // Print configuration. + textBytes, err := prototext.MarshalOptions{Multiline: true}.Marshal(debugOptions) + if err != nil { + klog.Infof("Failed to convert xla.DebugOptions proto to text: %v", err) + } else { + klog.Infof("This adds 10ms!! of time to the execution of the compiled graph!") + klog.Infof("%s=%s -> %s", EnvXlaDebugOptions, debugOptionsStr, textBytes) + } + + // Set parsed configuration. + cc.options.ExecutableBuildOptions.DebugOptions = debugOptions + } return cc } diff --git a/pjrt/gen_chelper.go b/pjrt/gen_chelper.go index d293a14..9d6e698 100644 --- a/pjrt/gen_chelper.go +++ b/pjrt/gen_chelper.go @@ -33,17 +33,15 @@ func cSizeOf[T any]() C.size_t { // It must be manually freed with cFree() by the user. func cMalloc[T any]() (ptr *T) { size := cSizeOf[T]() - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + cPtr := (*T)(C.calloc(1, size)) return cPtr } // cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero. // It must be manually freed with C.free() by the user. func cMallocArray[T any](n int) (ptr *T) { - size := cSizeOf[T]() * C.size_t(n) - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + size := cSizeOf[T]() + cPtr := (*T)(C.calloc(C.size_t(n), size)) return cPtr } diff --git a/pjrt/loadedexecutables.go b/pjrt/loadedexecutables.go index 6dcafe6..d566404 100644 --- a/pjrt/loadedexecutables.go +++ b/pjrt/loadedexecutables.go @@ -4,6 +4,33 @@ package pjrt #include "pjrt_c_api.h" #include "gen_api_calls.h" #include "gen_new_struct.h" + +PJRT_Error* ExecuteAndWait(const PJRT_Api *api, PJRT_LoadedExecutable_Execute_Args* args) { + PJRT_Error *err = api->PJRT_LoadedExecutable_Execute(args); + if (err) { + return err; + } + + if (args->device_complete_events) { + // Wait for devices to complete executions. + for (int ii = 0; ii < args->num_devices; ii++) { + PJRT_Event_Await_Args event_args = {0}; + event_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + event_args.event = args->device_complete_events[ii]; + err = api->PJRT_Event_Await(&event_args); + PJRT_Event_Destroy_Args efree_args; + efree_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE; + efree_args.event = args->device_complete_events[ii]; + api->PJRT_Event_Destroy(&efree_args); + if (err) { + return err; + } + } + } + return NULL; +} + + */ import "C" import ( @@ -11,6 +38,7 @@ import ( "k8s.io/klog/v2" "runtime" "slices" + "sync/atomic" "unsafe" ) @@ -32,6 +60,13 @@ type LoadedExecutable struct { NumOutputs int } +var numLoadedExecutables atomic.Int64 + +// LoadedExecutablesAlive returns a count of the numbers of LoadedExecutables currently in memory and tracked by gopjrt. +func LoadedExecutablesAlive() int64 { + return numLoadedExecutables.Load() +} + // newLoadedExecutable creates LoadedExecutable and registers it for freeing. func newLoadedExecutable(plugin *Plugin, client *Client, cLoadedExecutable *C.PJRT_LoadedExecutable) (*LoadedExecutable, error) { e := &LoadedExecutable{ @@ -39,6 +74,7 @@ func newLoadedExecutable(plugin *Plugin, client *Client, cLoadedExecutable *C.PJ client: client, cLoadedExecutable: cLoadedExecutable, } + numLoadedExecutables.Add(1) runtime.SetFinalizer(e, func(e *LoadedExecutable) { e.destroyOrLog() }) // Gather information about executable: @@ -78,6 +114,8 @@ func (e *LoadedExecutable) Destroy() error { err := toError(e.plugin, C.call_PJRT_LoadedExecutable_Destroy(e.plugin.api, args)) e.plugin = nil e.cLoadedExecutable = nil + + numLoadedExecutables.Add(-1) return err } @@ -277,82 +315,98 @@ func (c *ExecutionConfig) Done() ([]*Buffer, error) { c.devices = []*Device{devices[0]} } + // Dimensions of inputs/outputs. + numInputs := len(c.inputs) + numOutputs := e.NumOutputs + + // Allocations that will be used by CGO. + // Except if the number of inputs/outputs is very large, used the default arena size. + var arena *arenaContainer + minSize := (numInputs+numOutputs)*3*8 /*pointer size*/ + 1024 + if minSize > arenaDefaultSize { + arena = newArena(arenaDefaultSize + minSize) + defer arena.Free() + } else { + arena = getArenaFromPool() + defer returnArenaToPool(arena) + } + // Create arguments structures for call to Execute. - args := C.new_PJRT_LoadedExecutable_Execute_Args() - defer cFree(args) + var args *C.PJRT_LoadedExecutable_Execute_Args + args = arenaAlloc[C.PJRT_LoadedExecutable_Execute_Args](arena) + args.struct_size = C.PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE args.executable = e.cLoadedExecutable - options := C.new_PJRT_ExecuteOptions() // Like more args that for some reason(?) go on a separate struct. - defer cFree(options) + + var options *C.PJRT_ExecuteOptions + options = arenaAlloc[C.PJRT_ExecuteOptions](arena) // Extra args that for some reason(?) go on a separate struct. + options.struct_size = C.PJRT_ExecuteOptions_STRUCT_SIZE args.options = options // Configure (non-)donatable inputs. if len(c.nonDonatableInputs) > 0 { options.num_non_donatable_input_indices = C.size_t(len(c.nonDonatableInputs)) - options.non_donatable_input_indices = cMallocArrayAndSet[C.int64_t](len(c.nonDonatableInputs), func(ii int) C.int64_t { - return C.int64_t(c.nonDonatableInputs[ii]) - }) - defer cFree(options.non_donatable_input_indices) + nonDonatableIndices := arenaAllocSlice[C.int64_t](arena, len(c.nonDonatableInputs)) + for ii := range nonDonatableIndices { + nonDonatableIndices[ii] = C.int64_t(c.nonDonatableInputs[ii]) + } + options.non_donatable_input_indices = &nonDonatableIndices[0] } numDevices := 1 args.num_devices = C.size_t(numDevices) args.execute_device = c.devices[0].cDevice - args.num_args = C.size_t(len(c.inputs)) - args.argument_lists = allocatePerDeviceBufferList(numDevices, c.inputs) - defer freePerDeviceBufferList(args.argument_lists, numDevices) - args.output_lists = allocatePerDeviceBufferList(numDevices, make([]*Buffer, e.NumOutputs)) - defer freePerDeviceBufferList(args.output_lists, numDevices) + + args.num_args = C.size_t(numInputs) + if args.num_args > 0 { + args.argument_lists = allocatePerDeviceBufferListWithArena(arena, numDevices, numInputs, c.inputs) + } + + // For some reason the line below doesn't work. I think something is wrong with PJRT ... but I'm not sure. + if numOutputs > 0 { + args.output_lists = allocatePerDeviceBufferListWithArena(arena, numDevices, numOutputs, nil) + } + + // Create events to wait for the end of execution: leaving this as NULL is allowed, but what happens then + // (does it wait or not, and then what?) is not documented in PJRT. + perDeviceEvents := arenaAllocSlice[*C.PJRT_Event](arena, numDevices) + args.device_complete_events = (**C.PJRT_Event)(unsafe.SliceData(perDeviceEvents)) //args.device_complete_events = cMallocArray[*C.PJRT_Event](numDevices) //defer cFree(args.device_complete_events) - err := toError(e.plugin, C.call_PJRT_LoadedExecutable_Execute(e.plugin.api, args)) + err := toError(e.plugin, C.ExecuteAndWait(e.plugin.api, args)) if err != nil { return nil, err } - perDevice := gatherPerDeviceBufferList(e.client, args.output_lists, numDevices, e.NumOutputs) - return perDevice[0], nil + // We only support one device for now, so we return the results from the first device. + outputs := make([]*Buffer, numOutputs) + outputBuffers := unsafe.Slice(*args.output_lists, numOutputs) + for ii := range outputs { + outputs[ii] = newBuffer(e.client, outputBuffers[ii]) + } + return outputs, nil } // Allocate [numDevices][numBuffers]*Buffer C 2D-array to be used by PJRT C API, with the given Buffer pointers. -func allocatePerDeviceBufferList(numDevices int, buffers []*Buffer) ***C.PJRT_Buffer { +func allocatePerDeviceBufferListWithArena(arena *arenaContainer, numDevices int, numBuffers int, buffers []*Buffer) ***C.PJRT_Buffer { // Top level: - perDevice := make([]**C.PJRT_Buffer, numDevices) + perDevice := arenaAllocSlice[**C.PJRT_Buffer](arena, numDevices) for deviceIdx := range perDevice { - perDevice[deviceIdx] = cMallocArrayAndSet[*C.PJRT_Buffer](len(buffers), func(idxBuffer int) *C.PJRT_Buffer { - if buffers[idxBuffer] == nil { - // No buffer given for structure. - return nil - } - if buffers[idxBuffer].cBuffer == nil { - // Buffer given, but it's cBuffer is nil -> probably it has already been destroyed. - panicf("buffers[%d].cBuffer is nil, has it already been destroyed!?", idxBuffer) + deviceBuffers := arenaAllocSlice[*C.PJRT_Buffer](arena, numBuffers) + perDevice[deviceIdx] = &deviceBuffers[0] + if buffers != nil { + for bufferIdx := range deviceBuffers { + if buffers[bufferIdx] == nil { + deviceBuffers[bufferIdx] = nil + continue + } + if buffers[bufferIdx].cBuffer == nil { + // Buffer given, but it's cBuffer is nil -> probably it has already been destroyed. + panicf("buffers[%d].cBuffer is nil, has it already been destroyed!?", bufferIdx) + } + deviceBuffers[bufferIdx] = buffers[bufferIdx].cBuffer } - return buffers[idxBuffer].cBuffer - }) - } - return cMallocArrayFromSlice(perDevice) -} - -// freePerDeviceBufferList frees the intermediary array pointers, but it doesn't touch the buffers themselves. -func freePerDeviceBufferList(data ***C.PJRT_Buffer, numDevices int) { - perDevice := cDataToSlice[**C.PJRT_Buffer](unsafe.Pointer(data), numDevices) - for _, list := range perDevice { - cFree(list) - } - cFree(data) -} - -// gatherPerDeviceBufferList returns a [numDevices][numBuffers]*Buffer given the C 2D array. -func gatherPerDeviceBufferList(client *Client, data ***C.PJRT_Buffer, numDevices, numBuffers int) [][]*Buffer { - perDevice := make([][]*Buffer, numDevices) - cPerDevice := cDataToSlice[**C.PJRT_Buffer](unsafe.Pointer(data), numDevices) - for ii, cBufferListPtr := range cPerDevice { - perDevice[ii] = make([]*Buffer, numBuffers) - cBuffers := cDataToSlice[*C.PJRT_Buffer](unsafe.Pointer(cBufferListPtr), numBuffers) - for jj, cBuffer := range cBuffers { - perDevice[ii][jj] = newBuffer(client, cBuffer) } } - return perDevice + return &perDevice[0] } diff --git a/pjrt/minimal_test.go b/pjrt/minimal_test.go index 041b5f2..c1ec490 100644 --- a/pjrt/minimal_test.go +++ b/pjrt/minimal_test.go @@ -4,7 +4,6 @@ import ( "flag" "fmt" "github.com/gomlx/gopjrt/protos/hlo" - "github.com/janpfeifer/must" "github.com/stretchr/testify/require" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" @@ -52,32 +51,32 @@ func TestMinimal(t *testing.T) { var hloModule hlo.HloModuleProto if *flagLoadHLO != "" { fmt.Printf("Loading HLO program from %s...\n", *flagLoadHLO) - hloSerialized = must.M1(os.ReadFile(*flagLoadHLO)) - must.M(proto.Unmarshal(hloSerialized, &hloModule)) + hloSerialized = must1(os.ReadFile(*flagLoadHLO)) + must(proto.Unmarshal(hloSerialized, &hloModule)) } else { // Serialize HLO program from hloText: - must.M(prototext.Unmarshal([]byte(hloText), &hloModule)) - hloSerialized = must.M1(proto.Marshal(&hloModule)) + must(prototext.Unmarshal([]byte(hloText), &hloModule)) + hloSerialized = must1(proto.Marshal(&hloModule)) } fmt.Printf("HLO Program:\n%s\n\n", hloModule.String()) // `dlopen` PJRT plugin. - plugin := must.M1(GetPlugin(*flagPluginName)) + plugin := must1(GetPlugin(*flagPluginName)) defer runtime.KeepAlive(plugin) fmt.Printf("PJRT: %s\n", plugin.String()) // Create client. - client := must.M1(plugin.NewClient(nil)) + client := must1(plugin.NewClient(nil)) defer runtime.KeepAlive(client) devices := client.AddressableDevices() for ii, dev := range devices { - desc := must.M1(dev.GetDescription()) + desc := must1(dev.GetDescription()) fmt.Printf("\tDevice #%d: %s\n", ii, desc.DebugString()) } // Compile. defer runtime.KeepAlive(hloSerialized) - loadedExec := must.M1(client.Compile().WithHLO(hloSerialized).Done()) + loadedExec := must1(client.Compile().WithHLO(hloSerialized).Done()) defer runtime.KeepAlive(loadedExec) fmt.Printf("\t- program compiled successfully.\n") diff --git a/pjrt/pjrt_c_api.h b/pjrt/pjrt_c_api.h index 75b83df..b2a81c4 100644 --- a/pjrt/pjrt_c_api.h +++ b/pjrt/pjrt_c_api.h @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 57 +#define PJRT_API_MINOR 58 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -214,9 +214,9 @@ struct PJRT_Plugin_Attributes_Args { }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes); -// Returns an array of plugin attributes which are key-value pairs. One example -// attribute is the minimum supported StableHLO version. -// TODO(b/280349977): standardize the list of attributes. +// Returns an array of plugin attributes which are key-value pairs. Common keys +// include `xla_version`, `stablehlo_current_version`, and +// `stablehlo_minimum_version`. typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); // ---------------------------------- Events ----------------------------------- diff --git a/pjrt/pjrt_test.go b/pjrt/pjrt_test.go index a1937e8..10faef2 100644 --- a/pjrt/pjrt_test.go +++ b/pjrt/pjrt_test.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "github.com/gomlx/gopjrt/dtypes" + "github.com/pkg/errors" "github.com/stretchr/testify/require" "k8s.io/klog/v2" "testing" @@ -32,6 +33,17 @@ func (e errTester[T]) Test(t *testing.T) T { return e.value } +func must(err error) { + if err != nil { + panicf("Failed: %+v", errors.WithStack(err)) + } +} + +func must1[T any](t T, err error) T { + must(err) + return t +} + // getPJRTClient loads a PJRT plugin and create a client to run tests on. // It exits the test if anything goes wrong. func getPJRTClient(t *testing.T) *Client { diff --git a/xlabuilder/gen_chelper.go b/xlabuilder/gen_chelper.go index 890ad52..72a62ac 100644 --- a/xlabuilder/gen_chelper.go +++ b/xlabuilder/gen_chelper.go @@ -33,17 +33,15 @@ func cSizeOf[T any]() C.size_t { // It must be manually freed with cFree() by the user. func cMalloc[T any]() (ptr *T) { size := cSizeOf[T]() - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + cPtr := (*T)(C.calloc(1, size)) return cPtr } // cMallocArray allocates space to hold n copies of T in the C heap and initializes it to zero. // It must be manually freed with C.free() by the user. func cMallocArray[T any](n int) (ptr *T) { - size := cSizeOf[T]() * C.size_t(n) - cPtr := (*T)(C.malloc(size)) - C.memset(unsafe.Pointer(cPtr), 0, size) + size := cSizeOf[T]() + cPtr := (*T)(C.calloc(C.size_t(n), size)) return cPtr }