From eecdedf4290a2ff0cf99b18c20743ce3bf7623aa Mon Sep 17 00:00:00 2001 From: kyule7 Date: Wed, 1 Mar 2023 20:34:38 -0800 Subject: [PATCH 1/2] Add support for converting DLTensor for NPU device type --- onnxruntime/core/dlpack/dlpack_converter.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index ace85615ee912..68c6aef6a6d72 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -86,6 +86,8 @@ OrtDevice GetOrtDevice(const DLDevice& device) { case DLDeviceType::kDLCUDA: case DLDeviceType::kDLROCM: return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); + case DLDeviceType::kDLExtDev: + return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); default: ORT_THROW("Unsupported device type"); } @@ -149,6 +151,10 @@ const char* GetOrtDeviceName(const OrtDevice& device) { return CPU; case OrtDevice::GPU: return CUDA; + case OrtDevice::FPGA: + return "fpga"; + case OrtDevice::NPU: + return "npu"; default: ORT_THROW("Unknown device type: ", device.Type()); } @@ -194,6 +200,9 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) { device.device_type = DLDeviceType::kDLCUDA; #endif break; + case OrtDevice::FPGA: + case OrtDevice::NPU: + device.device_type = DLDeviceType::kDLExtDev; default: ORT_THROW("Cannot pack tensors on this device."); } From 58c51e40e96d9bcd396426b6604ec3976a81da4f Mon Sep 17 00:00:00 2001 From: kyule7 Date: Thu, 22 Feb 2024 15:52:48 -0800 Subject: [PATCH 2/2] Updated DL device name for MAIA --- onnxruntime/core/dlpack/dlpack_converter.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index e1cf6fbbe4cd2..864b656140d23 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -86,7 +86,7 @@ OrtDevice GetOrtDevice(const DLDevice& device) { case DLDeviceType::kDLCUDA: case DLDeviceType::kDLROCM: return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); - case DLDeviceType::kDLExtDev: + case DLDeviceType::kDLMAIA: return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); default: ORT_THROW("Unsupported device type"); @@ -202,7 +202,7 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) { break; case OrtDevice::FPGA: case OrtDevice::NPU: - device.device_type = DLDeviceType::kDLExtDev; + device.device_type = DLDeviceType::kDLMAIA; default: ORT_THROW("Cannot pack tensors on this device."); }