diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index 5307bb32de7d0..864b656140d23 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -86,6 +86,8 @@ OrtDevice GetOrtDevice(const DLDevice& device) { case DLDeviceType::kDLCUDA: case DLDeviceType::kDLROCM: return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); + case DLDeviceType::kDLMAIA: + return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); default: ORT_THROW("Unsupported device type"); } @@ -149,6 +151,10 @@ const char* GetOrtDeviceName(const OrtDevice& device) { return CPU; case OrtDevice::GPU: return CUDA; + case OrtDevice::FPGA: + return "fpga"; + case OrtDevice::NPU: + return "npu"; default: ORT_THROW("Unknown device type: ", device.Type()); } @@ -194,6 +200,9 @@ DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) { device.device_type = DLDeviceType::kDLCUDA; #endif break; + case OrtDevice::FPGA: + case OrtDevice::NPU: + device.device_type = DLDeviceType::kDLMAIA; default: ORT_THROW("Cannot pack tensors on this device."); }