diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7bad7a6ca5..c9c24188be 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -150,7 +150,7 @@ if(UR_FORMAT_CPP_STYLE)
         message(STATUS "Found clang-format: ${CLANG_FORMAT} (version: ${CLANG_FORMAT_VERSION})")
 
         set(CLANG_FORMAT_REQUIRED "15.0")
-        if(NOT (CLANG_FORMAT_VERSION VERSION_EQUAL CLANG_FORMAT_REQUIRED))
-            message(FATAL_ERROR "required clang-format version is ${CLANG_FORMAT_REQUIRED}")
+        if(CLANG_FORMAT_VERSION VERSION_LESS CLANG_FORMAT_REQUIRED)
+            message(FATAL_ERROR "required clang-format version is ${CLANG_FORMAT_REQUIRED} or newer (found: ${CLANG_FORMAT_VERSION})")
         endif()
     else()
diff --git a/include/ur.py b/include/ur.py
index 09b7955e07..085cff4e6c 100644
--- a/include/ur.py
+++ b/include/ur.py
@@ -206,6 +206,7 @@ class ur_function_v(IntEnum):
     COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 213 ## Enumerator for ::urCommandBufferAppendUSMAdviseExp
     ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 214 ## Enumerator for ::urEnqueueCooperativeKernelLaunchExp
     KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 215## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
+    DEVICE_GET_SELECTED = 300 ## Enumerator for ::urDeviceGetSelected
 
 class ur_function_t(c_int):
     def __str__(self):
diff --git a/include/ur_api.h b/include/ur_api.h
index 63f5fc8083..a751673942 100644
--- a/include/ur_api.h
+++ b/include/ur_api.h
@@ -215,6 +215,7 @@ typedef enum ur_function_t {
     UR_FUNCTION_COMMAND_BUFFER_APPEND_USM_ADVISE_EXP = 213, ///< Enumerator for ::urCommandBufferAppendUSMAdviseExp
     UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 214, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp
     UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 215, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
+    UR_FUNCTION_DEVICE_GET_SELECTED = 300, ///< Enumerator for ::urDeviceGetSelected
     /// @cond
     UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
     /// @endcond
@@ -1376,6 +1377,46 @@ urDeviceGet(
                                    ///< pNumDevices will be updated with the total number of devices available.
); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR +/// +/// @details +/// - Multiple calls to this function will return identical device handles, +/// in the same order. +/// - The number and order of handles returned from this function will be +/// affected by environment variables that filter or select which devices +/// are exposed through this API. +/// - A reference is taken for each returned device and must be released +/// with a subsequent call to ::urDeviceRelease. +/// - The application may call this function from simultaneous threads, the +/// implementation must be thread-safe. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +UR_APIEXPORT ur_result_t UR_APICALL +urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices is not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE, + ///< will be returned. + ur_device_handle_t *phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. 
+);
+
 ///////////////////////////////////////////////////////////////////////////////
 /// @brief Supported device info
 typedef enum ur_device_info_t {
@@ -10814,6 +10855,18 @@ typedef struct ur_device_get_params_t {
     uint32_t **ppNumDevices;
 } ur_device_get_params_t;
 
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for urDeviceGetSelected (one member per argument, in declaration order)
+/// @details Each entry is a pointer to the parameter passed to the function;
+///     allowing the callback the ability to modify the parameter's value
+typedef struct ur_device_get_selected_params_t {
+    ur_platform_handle_t *phPlatform;
+    ur_device_type_t *pDeviceType;
+    uint32_t *pNumEntries;
+    ur_device_handle_t **pphDevices;
+    uint32_t **ppNumDevices;
+} ur_device_get_selected_params_t;
+
 ///////////////////////////////////////////////////////////////////////////////
 /// @brief Function parameters for urDeviceGetInfo
 /// @details Each entry is a pointer to the parameter passed to the function;
diff --git a/include/ur_print.hpp b/include/ur_print.hpp
index 9a0ce9e657..730a90134b 100644
--- a/include/ur_print.hpp
+++ b/include/ur_print.hpp
@@ -879,6 +879,9 @@ inline std::ostream &operator<<(std::ostream &os, ur_function_t value) {
     case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP:
         os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP";
         break;
+    case UR_FUNCTION_DEVICE_GET_SELECTED:
+        os << "UR_FUNCTION_DEVICE_GET_SELECTED";
+        break;
     default:
         os << "unknown enumerator";
         break;
@@ -15623,6 +15626,46 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
     return os;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Print operator for the ur_device_get_selected_params_t type; writes ".name = value" pairs separated by ", "
+/// @returns
+///     std::ostream & (the os argument)
+inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_device_get_selected_params_t *params) {
+
+    os << ".hPlatform = ";
+
+    ur::details::printPtr(os, *(params->phPlatform));
+
+    os << ", ";
+    os << ".DeviceType = ";
+
+    os << *(params->pDeviceType);
+
+    os << ", ";
+    os << ".NumEntries = ";
+
+    os << *(params->pNumEntries);
+
+    os << ", ";
+    os << ".phDevices = {";
+    for (size_t i = 0; // the array is walked only when phDevices is non-NULL, for at most NumEntries elements
+         *(params->pphDevices) != NULL && i < *params->pNumEntries; ++i) {
+        if (i != 0) {
+            os << ", ";
+        }
+
+        ur::details::printPtr(os, (*(params->pphDevices))[i]);
+    }
+    os << "}";
+
+    os << ", ";
+    os << ".pNumDevices = ";
+
+    ur::details::printPtr(os, *(params->ppNumDevices));
+
+    return os;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 /// @brief Print operator for the ur_device_get_info_params_t type
 /// @returns
@@ -16406,6 +16449,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_
     case UR_FUNCTION_DEVICE_GET: {
         os << (const struct ur_device_get_params_t *)params;
     } break;
+    case UR_FUNCTION_DEVICE_GET_SELECTED: {
+        os << (const struct ur_device_get_selected_params_t *)params;
+    } break;
     case UR_FUNCTION_DEVICE_GET_INFO: {
         os << (const struct ur_device_get_info_params_t *)params;
     } break;
diff --git a/scripts/core/device.yml b/scripts/core/device.yml
index 3999fa70f2..0e747041fc 100644
--- a/scripts/core/device.yml
+++ b/scripts/core/device.yml
@@ -150,6 +150,45 @@ returns:
         - "`NumEntries > 0 && phDevices == NULL`"
     - $X_RESULT_ERROR_INVALID_VALUE
 --- #--------------------------------------------------------------------------
+type: function
+desc: "Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR"
+class: $xDevice
+loader_only: True
+name: GetSelected
+decl: static
+ordinal: "0"
+details:
+    - "Multiple calls to this function will return identical device handles, in the same order."
+    - "The number and order of handles returned from this function will be affected by environment variables that filter or select which devices are exposed through this API."
+ - "A reference is taken for each returned device and must be released with a subsequent call to $xDeviceRelease." + - "The application may call this function from simultaneous threads, the implementation must be thread-safe." +params: + - type: $x_platform_handle_t + name: hPlatform + desc: "[in] handle of the platform instance" + - type: "$x_device_type_t" + name: DeviceType + desc: | + [in] the type of the devices. + - type: "uint32_t" + name: NumEntries + desc: | + [in] the number of devices to be added to phDevices. + If phDevices is not NULL then NumEntries should be greater than zero, otherwise $X_RESULT_ERROR_INVALID_VALUE, + will be returned. + - type: "$x_device_handle_t*" + name: phDevices + desc: | + [out][optional][range(0, NumEntries)] array of handle of devices. + If NumEntries is less than the number of devices available, then only that number of devices will be retrieved. + - type: "uint32_t*" + name: pNumDevices + desc: | + [out][optional] pointer to the number of devices. + pNumDevices will be updated with the total number of selected devices available for the given platform. 
+returns:
+    - $X_RESULT_ERROR_INVALID_VALUE
+--- #--------------------------------------------------------------------------
 type: enum
 desc: "Supported device info"
 class: $xDevice
diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml
index deb5ee9604..1d2acf7dd7 100644
--- a/scripts/core/registry.yml
+++ b/scripts/core/registry.yml
@@ -559,6 +559,9 @@ etors:
 - name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP
   desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCountExp
   value: '215'
+- name: DEVICE_GET_SELECTED
+  desc: Enumerator for $xDeviceGetSelected
+  value: '300'
 ---
 type: enum
 desc: Defines structure types
diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp
index 34531ca8b1..79079ca85b 100644
--- a/source/loader/ur_lib.cpp
+++ b/source/loader/ur_lib.cpp
@@ -9,11 +9,25 @@
  * @file ur_lib.cpp
  *
  */
+
+// avoids windows.h from defining macros for min and max
+// which avoids playing havoc with std::min and std::max
+// (not quite sure why windows.h is being included here)
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+
 #include "ur_lib.hpp"
 #include "logger/ur_logger.hpp"
 #include "ur_loader.hpp"
 
-#include
+#include
+#include <algorithm>
+#include <cctype>
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
 
 namespace ur_lib {
 ///////////////////////////////////////////////////////////////////////////////
@@ -206,4 +220,432 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig,
     return UR_RESULT_SUCCESS;
 }
 
+/// @brief Loader implementation of ::urDeviceGetSelected.
+///
+/// @details Lists the root devices of hPlatform, partitions them by affinity
+///     domain into sub-devices and sub-sub-devices, then applies the terms of
+///     the ONEAPI_DEVICE_SELECTOR (ODS) environment variable that address
+///     this platform's backend. An unset variable behaves like "*:*".
+///
+///     ODS grammar (full BNF:
+///     https://github.com/intel/llvm/blob/sycl/sycl/doc/EnvironmentVariables.md#oneapi_device_selector)
+///         discardFilter = "!acceptFilter"
+///         acceptFilter  = "backend:filterStrings"
+///         filterStrings = "filterString[,filterString[,...]]"
+///         filterString  = "root[.sub[.subsub]]"
+///         root          = "*|int|cpu|gpu|fpga"
+///         sub           = "*|int"
+///         subsub        = "*|int"
+///
+///     Discard terms are applied first, each at its own partition level; the
+///     accept terms then pick from what is left, so a handle is returned at
+///     most once. The order of filterStrings inside an accept term matters:
+///     with four devices "2,*" yields 2,0,1,3 whereas "*,2" yields 0,1,2,3.
+///
+/// @param[in] hPlatform platform whose devices are filtered
+/// @param[in] DeviceType restricts the candidate root devices;
+///     UR_DEVICE_TYPE_ALL and UR_DEVICE_TYPE_DEFAULT apply no restriction
+/// @param[in] NumEntries capacity of phDevices, 0 to query the count only
+/// @param[out] phDevices receives at most NumEntries selected handles
+/// @param[out] pNumDevices number of selected devices when NumEntries is 0,
+///     otherwise the number of handles written (may then be NULL)
+///
+/// @returns
+///     - ::UR_RESULT_SUCCESS
+///     - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE when hPlatform is NULL
+///     - ::UR_RESULT_ERROR_INVALID_NULL_POINTER when the output pointer the
+///       request needs is NULL
+///     - ::UR_RESULT_ERROR_INVALID_PLATFORM when the backend query fails
+///     - ::UR_RESULT_ERROR_DEVICE_NOT_FOUND when root devices cannot be listed
+ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
+                                ur_device_type_t DeviceType,
+                                uint32_t NumEntries,
+                                ur_device_handle_t *phDevices,
+                                uint32_t *pNumDevices) {
+    // Step 0: argument validation. NumEntries is unsigned, so there is no
+    // negative size to reject.
+    if (!hPlatform) {
+        return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
+    }
+    if (NumEntries > 0 && !phDevices) {
+        return UR_RESULT_ERROR_INVALID_NULL_POINTER;
+    }
+    if (NumEntries == 0 && !pNumDevices) {
+        return UR_RESULT_ERROR_INVALID_NULL_POINTER;
+    }
+
+    // Step 1: translate the platform backend into the name used by ODS terms.
+    ur_platform_backend_t platformBackend;
+    if (UR_RESULT_SUCCESS !=
+        urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND,
+                          sizeof(ur_platform_backend_t), &platformBackend,
+                          nullptr)) {
+        return UR_RESULT_ERROR_INVALID_PLATFORM;
+    }
+    const std::string platformBackendName = [platformBackend]() {
+        switch (platformBackend) {
+        case UR_PLATFORM_BACKEND_LEVEL_ZERO:
+            return "level_zero";
+        case UR_PLATFORM_BACKEND_OPENCL:
+            return "opencl";
+        case UR_PLATFORM_BACKEND_CUDA:
+            return "cuda";
+        case UR_PLATFORM_BACKEND_HIP:
+            return "hip";
+        case UR_PLATFORM_BACKEND_UNKNOWN:
+        case UR_PLATFORM_BACKEND_NATIVE_CPU:
+            return "*"; // the only ODS string that matches
+        default:
+            return ""; // no ODS string matches this
+        }
+    }();
+
+    // Step 2: read the ODS terms. std::map is sorted by key, so the order in
+    // which the user wrote the terms is lost. That is acceptable for a
+    // single-platform request: only one backend is of interest and discard
+    // terms always override accept terms.
+    using EnvVarMap = std::map<std::string, std::vector<std::string>>;
+    const auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);
+    // an unset variable is treated as the default "*:*"
+    const EnvVarMap mapODS = maybeEnvVarMap.has_value()
+                                 ? maybeEnvVarMap.value()
+                                 : EnvVarMap{{"*", {"*"}}};
+
+    using DeviceHardwareType = ur_device_type_t;
+    enum class DevicePartLevel { ROOT, SUB, SUBSUB };
+    using DeviceIdType = unsigned long;
+    // wildcard marker: the largest value DeviceIdType can hold
+    constexpr DeviceIdType DeviceIdTypeALL = static_cast<DeviceIdType>(-1);
+
+    // One parsed filterString, or one concrete device offered by the platform.
+    struct DeviceSpec {
+        DevicePartLevel level;
+        DeviceHardwareType hwType = ::UR_DEVICE_TYPE_ALL;
+        DeviceIdType rootId = DeviceIdTypeALL;
+        DeviceIdType subId = DeviceIdTypeALL;
+        DeviceIdType subsubId = DeviceIdTypeALL;
+        ur_device_handle_t urDeviceHandle;
+    };
+
+    // "cpu", "gpu" and "fpga" (any letter case) name a hardware type;
+    // every other root part leaves the type unrestricted.
+    const auto getRootHardwareType =
+        [](const std::string &input) -> DeviceHardwareType {
+        std::string lowerInput(input);
+        // go through unsigned char: std::tolower is undefined for negatives
+        std::transform(lowerInput.cbegin(), lowerInput.cend(),
+                       lowerInput.begin(), [](unsigned char c) {
+                           return static_cast<char>(std::tolower(c));
+                       });
+        if (lowerInput == "cpu") {
+            return ::UR_DEVICE_TYPE_CPU;
+        }
+        if (lowerInput == "gpu") {
+            return ::UR_DEVICE_TYPE_GPU;
+        }
+        if (lowerInput == "fpga") {
+            return ::UR_DEVICE_TYPE_FPGA;
+        }
+        return ::UR_DEVICE_TYPE_ALL;
+    };
+
+    // A non-empty run of digits is a device index; anything else (including
+    // "*" and an empty part) is the wildcard.
+    const auto getDeviceId = [&](const std::string &input) -> DeviceIdType {
+        const bool isNumber =
+            !input.empty() &&
+            input.find_first_not_of("0123456789") == std::string::npos;
+        return isNumber ? std::stoul(input) : DeviceIdTypeALL;
+    };
+
+    // Every "*." must be the start of "*.*":
+    //   GOOD: "*.*", "1.*.*", "*.*.*"
+    //   BAD:  "*.1", "*.", "1.*.2", "*.gpu"
+    const auto hasBadWildcard = [](const std::string &s) {
+        const std::string prefix = "*.";
+        const std::string whole = "*.*";
+        std::string::size_type pos = 0;
+        while ((pos = s.find(prefix, pos)) != std::string::npos) {
+            if (s.substr(pos, whole.size()) != whole) {
+                return true;
+            }
+            pos += prefix.size();
+        }
+        return false;
+    };
+
+    std::vector<DeviceSpec> acceptDeviceList;
+    std::vector<DeviceSpec> discardDeviceList;
+    // true once an accept term addressed to a different backend was seen
+    bool acceptTermForOtherBackend = false;
+
+    // Step 3: convert the terms for this backend into accept/discard filters.
+    for (const auto &termPair : mapODS) {
+        std::string backend = termPair.first;
+        const auto &filterStrings = termPair.second;
+        if (backend.empty()) {
+            // malformed term: missing backend -- output ERROR, then continue
+            // TODO: replace std::cout with URT message output mechanism
+            std::cout << "ERROR: missing backend, format of filter = "
+                         "'[!]backend:filterStrings'";
+            continue;
+        }
+        const bool isDiscard = (backend.front() == '!');
+        if (isDiscard) {
+            backend.erase(backend.begin());
+        }
+        // bind the target list once; assigning through a reference later
+        // would copy one vector over the other instead of re-seating it
+        auto &deviceList = isDiscard ? discardDeviceList : acceptDeviceList;
+
+        // platformBackendName comes from the switch above, which keeps the
+        // case-insensitive comparison with the user-supplied text sane
+        const bool backendMatches =
+            backend == "*" ||
+            std::equal(platformBackendName.cbegin(),
+                       platformBackendName.cend(), backend.cbegin(),
+                       backend.cend(), [](unsigned char a, unsigned char b) {
+                           return std::tolower(a) == std::tolower(b);
+                       });
+        if (!backendMatches) {
+            // irrelevant term for current request: different backend --
+            // silently ignore, but remember that the user asked for devices
+            // elsewhere so that no implicit "*" is added for this platform
+            if (!isDiscard) {
+                acceptTermForOtherBackend = true;
+            }
+            continue;
+        }
+        if (filterStrings.empty()) {
+            // malformed term: missing filterStrings -- output ERROR, then continue
+            // TODO: replace std::cout with URT message output mechanism
+            std::cout << "ERROR missing filterStrings, format of filter = "
+                         "'[!]backend:filterStrings'";
+            continue;
+        }
+        if (std::any_of(filterStrings.cbegin(), filterStrings.cend(),
+                        [](const std::string &s) { return s.empty(); })) {
+            // malformed term: missing filterString -- output warning, then continue
+            // TODO: replace std::cout with URT message output mechanism
+            std::cout << "WARNING: empty filterString, format of filterStrings "
+                         "= 'filterString[,filterString[,...]]'";
+            continue;
+        }
+        if (std::any_of(filterStrings.cbegin(), filterStrings.cend(),
+                        [](const std::string &s) {
+                            return std::count(s.cbegin(), s.cend(), '.') > 2;
+                        })) {
+            // malformed term: too many dots in filterString -- output warning, then continue
+            // TODO: replace std::cout with URT message output mechanism
+            std::cout << "WARNING: too many dots in filterString, format of "
+                         "filterString = 'root[.sub[.subsub]]'";
+            continue;
+        }
+        if (std::any_of(filterStrings.cbegin(), filterStrings.cend(),
+                        hasBadWildcard)) {
+            // malformed term: star dot no-star in filterString -- output warning, then continue
+            // TODO: replace std::cout with URT message output mechanism
+            std::cout
+                << "WARNING: invalid wildcard in filterString, '*.' => '*.*'";
+            continue;
+        }
+
+        // TODO -- validate the remaining syntax of each filterString against the grammar
+
+        for (const auto &filterString : filterStrings) {
+            const auto dot1 = filterString.find('.');
+            // without a dot, substr(0, npos) is the whole string
+            const std::string rootPart = filterString.substr(0, dot1);
+            DeviceSpec spec{DevicePartLevel::ROOT,
+                            getRootHardwareType(rootPart),
+                            getDeviceId(rootPart)};
+            if (dot1 != std::string::npos) {
+                // look for the second dot strictly after the first one
+                const auto dot2 = filterString.find('.', dot1 + 1);
+                const auto subLength = (dot2 == std::string::npos)
+                                           ? std::string::npos
+                                           : dot2 - dot1 - 1;
+                spec.level = DevicePartLevel::SUB;
+                spec.subId =
+                    getDeviceId(filterString.substr(dot1 + 1, subLength));
+                if (dot2 != std::string::npos) {
+                    spec.level = DevicePartLevel::SUBSUB;
+                    spec.subsubId =
+                        getDeviceId(filterString.substr(dot2 + 1));
+                }
+            }
+            deviceList.push_back(spec);
+        }
+    }
+
+    // Step 4: discard-only (or empty) input implies "accept every root device",
+    // unless the user explicitly asked for devices of another backend only.
+    if (acceptDeviceList.empty() && !acceptTermForOtherBackend) {
+        acceptDeviceList.push_back(DeviceSpec{DevicePartLevel::ROOT,
+                                              ::UR_DEVICE_TYPE_ALL,
+                                              DeviceIdTypeALL});
+    }
+
+    std::vector<DeviceSpec> rootDevices;
+    std::vector<DeviceSpec> subDevices;
+    std::vector<DeviceSpec> subSubDevices;
+
+    // Step 5: query the platform for all root devices.
+    uint32_t numRootDevices = 0;
+    if (UR_RESULT_SUCCESS != urDeviceGet(hPlatform, UR_DEVICE_TYPE_ALL, 0,
+                                         nullptr, &numRootDevices)) {
+        return UR_RESULT_ERROR_DEVICE_NOT_FOUND;
+    }
+    std::vector<ur_device_handle_t> rootDeviceHandles(numRootDevices);
+    if (numRootDevices > 0 &&
+        UR_RESULT_SUCCESS != urDeviceGet(hPlatform, UR_DEVICE_TYPE_ALL,
+                                         numRootDevices,
+                                         rootDeviceHandles.data(), nullptr)) {
+        return UR_RESULT_ERROR_DEVICE_NOT_FOUND;
+    }
+
+    // Root ids count every device of the platform, so they are assigned
+    // before the DeviceType argument filters the candidates.
+    const bool anyDeviceType = (DeviceType == UR_DEVICE_TYPE_ALL) ||
+                               (DeviceType == UR_DEVICE_TYPE_DEFAULT);
+    for (DeviceIdType rootId = 0; rootId < rootDeviceHandles.size(); ++rootId) {
+        // obtain the device type from the platform (errors are squashed)
+        ur_device_type_t hardwareType = ::UR_DEVICE_TYPE_DEFAULT;
+        urDeviceGetInfo(rootDeviceHandles[rootId], UR_DEVICE_INFO_TYPE,
+                        sizeof(ur_device_type_t), &hardwareType, nullptr);
+        if (anyDeviceType || DeviceType == hardwareType) {
+            rootDevices.push_back(DeviceSpec{
+                DevicePartLevel::ROOT, hardwareType, rootId, DeviceIdTypeALL,
+                DeviceIdTypeALL, rootDeviceHandles[rootId]});
+        }
+    }
+
+    // Steps 6-7: partition every parent by next-partitionable affinity domain.
+    // A parent that cannot be partitioned simply contributes no children.
+    // NOTE(review): handles created here that end up unselected are not
+    // released in this function -- confirm the intended ownership.
+    const auto appendPartitions = [&](const std::vector<DeviceSpec> &parents,
+                                      DevicePartLevel childLevel,
+                                      std::vector<DeviceSpec> &children) {
+        for (const DeviceSpec &parent : parents) {
+            ur_device_partition_property_t propNextPart{
+                UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN,
+                {UR_DEVICE_AFFINITY_DOMAIN_FLAG_NEXT_PARTITIONABLE}};
+            ur_device_partition_properties_t partitionProperties{
+                UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES, nullptr,
+                &propNextPart, 1};
+            uint32_t numChildren = 0;
+            if (UR_RESULT_SUCCESS !=
+                    urDevicePartition(parent.urDeviceHandle,
+                                      &partitionProperties, 0, nullptr,
+                                      &numChildren) ||
+                numChildren == 0) {
+                continue;
+            }
+            std::vector<ur_device_handle_t> childHandles(numChildren);
+            if (UR_RESULT_SUCCESS !=
+                urDevicePartition(parent.urDeviceHandle, &partitionProperties,
+                                  numChildren, childHandles.data(), nullptr)) {
+                continue;
+            }
+            for (DeviceIdType childId = 0; childId < childHandles.size();
+                 ++childId) {
+                DeviceSpec child{childLevel, parent.hwType,
+                                 parent.rootId, parent.subId,
+                                 DeviceIdTypeALL, childHandles[childId]};
+                if (childLevel == DevicePartLevel::SUB) {
+                    child.subId = childId;
+                } else {
+                    child.subsubId = childId;
+                }
+                children.push_back(child);
+            }
+        }
+    };
+    appendPartitions(rootDevices, DevicePartLevel::SUB, subDevices);
+    appendPartitions(subDevices, DevicePartLevel::SUBSUB, subSubDevices);
+
+    // Decides whether a concrete device is addressed by a parsed filter.
+    const auto matchesFilter = [&](const DeviceSpec &filter,
+                                   const DeviceSpec &device) -> bool {
+        if (filter.rootId == DeviceIdTypeALL) {
+            // '*', 'cpu', 'gpu', 'fpga' (or '*.*', '*.*.*' at deeper levels):
+            // only the hardware type can restrict the match
+            return filter.hwType == device.hwType ||
+                   filter.hwType == ::UR_DEVICE_TYPE_ALL;
+        }
+        if (filter.rootId != device.rootId) {
+            return false; // numeric root part names another root device
+        }
+        if (filter.level == DevicePartLevel::ROOT ||
+            filter.subId == DeviceIdTypeALL) {
+            return true; // 'N', 'N.*' or 'N.*.*'
+        }
+        if (filter.subId != device.subId) {
+            return false; // numeric sub part names another sub-device
+        }
+        if (filter.level == DevicePartLevel::SUB ||
+            filter.subsubId == DeviceIdTypeALL) {
+            return true; // 'N.M' or 'N.M.*'
+        }
+        return filter.subsubId == device.subsubId; // 'N.M.K'
+    };
+
+    // The candidate list a filter operates on is chosen by the filter's level.
+    const auto devicesAtLevel =
+        [&](DevicePartLevel level) -> std::vector<DeviceSpec> & {
+        switch (level) {
+        case DevicePartLevel::SUB:
+            return subDevices;
+        case DevicePartLevel::SUBSUB:
+            return subSubDevices;
+        default:
+            return rootDevices;
+        }
+    };
+
+    // Step A (runs first): drop every candidate matched by a discard filter.
+    for (const DeviceSpec &discard : discardDeviceList) {
+        auto &candidates = devicesAtLevel(discard.level);
+        candidates.erase(
+            std::remove_if(candidates.begin(), candidates.end(),
+                           [&](const DeviceSpec &device) {
+                               return matchesFilter(discard, device);
+                           }),
+            candidates.end());
+    }
+
+    // Step 9: move every candidate matched by an accept filter, in filter
+    // order, into the result. Taking a device out of its candidate list as it
+    // is selected means a later overlapping filter cannot select it twice.
+    std::vector<ur_device_handle_t> selectedDevices;
+    for (const DeviceSpec &accept : acceptDeviceList) {
+        auto &candidates = devicesAtLevel(accept.level);
+        std::vector<DeviceSpec> remaining;
+        remaining.reserve(candidates.size());
+        for (const DeviceSpec &device : candidates) {
+            if (matchesFilter(accept, device)) {
+                selectedDevices.push_back(device.urDeviceHandle);
+            } else {
+                remaining.push_back(device);
+            }
+        }
+        candidates.swap(remaining);
+    }
+
+    // Report the count and/or copy out as many handles as the caller has room
+    // for; pNumDevices is optional once NumEntries is non-zero.
+    const auto numSelected = static_cast<uint32_t>(selectedDevices.size());
+    if (NumEntries == 0) {
+        *pNumDevices = numSelected;
+    } else {
+        const uint32_t numReturned = std::min(NumEntries, numSelected);
+        std::copy_n(selectedDevices.cbegin(), numReturned, phDevices);
+        if (pNumDevices) {
+            *pNumDevices = numReturned;
+        }
+    }
+
+    return UR_RESULT_SUCCESS;
+}
 } // namespace ur_lib
diff --git a/source/loader/ur_lib.hpp b/source/loader/ur_lib.hpp
index 9d1e02a67e..4a9733b421 100644
--- a/source/loader/ur_lib.hpp
+++ b/source/loader/ur_lib.hpp
@@ -99,5 +99,10 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig,
                                       ur_code_location_callback_t pfnCodeloc,
                                       void *pUserData);
 
+ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
+                                ur_device_type_t DeviceType,
+                                uint32_t NumEntries,
+                                ur_device_handle_t *phDevices,
+                                uint32_t *pNumDevices);
 } // namespace ur_lib
 #endif /* UR_LOADER_LIB_H */
diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp
index 80d1bc3fb6..de3a655c9b 100644
--- a/source/loader/ur_libapi.cpp
+++ b/source/loader/ur_libapi.cpp
@@ -777,6 +777,52 @@ ur_result_t UR_APICALL urDeviceGet(
     return exceptionToResult(std::current_exception());
 }
 
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR
+///
+/// @details
+///     - Multiple calls to this function will return identical device handles,
+///       in the same order.
+///     - The number and order of handles returned from this function will be
+///       affected by environment variables that filter or select which devices
+///       are exposed through this API.
+///     - A reference is taken for each returned device and must be released
+///       with a subsequent call to ::urDeviceRelease.
+///     - The application may call this function from simultaneous threads, the
+///       implementation must be thread-safe.
+/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +ur_result_t UR_APICALL urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t + NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices is not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE + ///< will be returned. + ur_device_handle_t * + phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. + ) try { + return ur_lib::urDeviceGetSelected(hPlatform, DeviceType, NumEntries, + phDevices, pNumDevices); +} catch (...) 
{ + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Retrieves various information about device /// diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 3a7cebca8c..731a6fb574 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -673,6 +673,50 @@ ur_result_t UR_APICALL urDeviceGet( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR +/// +/// @details +/// - Multiple calls to this function will return identical device handles, +/// in the same order. +/// - The number and order of handles returned from this function will be +/// affected by environment variables that filter or select which devices +/// are exposed through this API. +/// - A reference is taken for each returned device and must be released +/// with a subsequent call to ::urDeviceRelease. +/// - The application may call this function from simultaneous threads, the +/// implementation must be thread-safe. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hPlatform` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_DEVICE_TYPE_VPU < DeviceType` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +ur_result_t UR_APICALL urDeviceGetSelected( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t + NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices is not NULL then NumEntries should be greater than zero, + ///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE + ///< will be returned. 
+ ur_device_handle_t * + phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then only + ///< that number of devices will be retrieved. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of selected devices + ///< available for the given platform. +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Retrieves various information about device /// diff --git a/test/conformance/device/CMakeLists.txt b/test/conformance/device/CMakeLists.txt index 23ff5b4ebc..0f7da3d80c 100644 --- a/test/conformance/device/CMakeLists.txt +++ b/test/conformance/device/CMakeLists.txt @@ -9,6 +9,7 @@ add_conformance_test_with_platform_environment(device urDeviceGetGlobalTimestamps.cpp urDeviceGetInfo.cpp urDeviceGetNativeHandle.cpp + urDeviceGetSelected.cpp urDevicePartition.cpp urDeviceRelease.cpp urDeviceRetain.cpp diff --git a/test/conformance/device/urDeviceGetSelected.cpp b/test/conformance/device/urDeviceGetSelected.cpp new file mode 100644 index 0000000000..0cb2738cf3 --- /dev/null +++ b/test/conformance/device/urDeviceGetSelected.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2022-2023 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include <uur/fixtures.h> + +using urDeviceGetSelectedTest = uur::urPlatformTest; + +TEST_F(urDeviceGetSelectedTest, Success) { + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector<ur_device_handle_t> devices(count); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count, + devices.data(), nullptr)); + for (auto device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, SuccessSubsetOfDevices) { + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + if (count < 2) { + GTEST_SKIP(); // requesting count - 1 devices needs at least two + } + std::vector<ur_device_handle_t> devices(count - 1); + ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count - 1, + devices.data(), nullptr)); + for (auto device : devices) { + ASSERT_NE(nullptr, device); + } +} + +TEST_F(urDeviceGetSelectedTest, InvalidNullHandlePlatform) { + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_NULL_HANDLE, + urDeviceGetSelected(nullptr, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); +} + +TEST_F(urDeviceGetSelectedTest, InvalidEnumerationDevicesType) { + uint32_t count = 0; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_ENUMERATION, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_FORCE_UINT32, + 0, nullptr, &count)); +} + +TEST_F(urDeviceGetSelectedTest, InvalidValueNumEntries) { + uint32_t count = 0; + ASSERT_SUCCESS( + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count)); + ASSERT_NE(count, 0); + std::vector<ur_device_handle_t> devices(count); + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_SIZE, + urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, + devices.data(), nullptr)); +}