diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 9c797a0ec3..1b7d19fa67 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -273,11 +273,17 @@ namespace ur_loader %endif %endif - ## Before we can re-enable the releases we will need ref-counted object_t. - ## See unified-runtime github issue #1784 - ##%if item['release']: - ##// release loader handle - ##${item['factory']}.release( ${item['name']} ); + ## Possibly handle release/retain ref counting - there are no ur_exp-image factories + %if 'factory' in item and '_exp_image_' not in item['factory']: + %if item['release']: + // release loader handle + context->factories.${item['factory']}.release( ${item['name']} ); + %endif + %if item['retain']: + // increment refcount of handle + context->factories.${item['factory']}.retain( ${item['name']} ); + %endif + %endif %if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { diff --git a/source/common/ur_singleton.hpp b/source/common/ur_singleton.hpp index b469c8b8a7..057d58c067 100644 --- a/source/common/ur_singleton.hpp +++ b/source/common/ur_singleton.hpp @@ -11,6 +11,7 @@ #ifndef UR_SINGLETON_H #define UR_SINGLETON_H 1 +#include #include #include #include @@ -18,13 +19,18 @@ ////////////////////////////////////////////////////////////////////////// /// a abstract factory for creation of singleton objects template class singleton_factory_t { + struct entry_t { + std::unique_ptr ptr; + size_t ref_count; + }; + protected: using singleton_t = singleton_tn; using key_t = typename std::conditional::value, size_t, key_tn>::type; using ptr_t = std::unique_ptr; - using map_t = std::unordered_map; + using map_t = std::unordered_map; std::mutex mut; ///< lock for thread-safety map_t map; ///< single instance of singleton for each unique key @@ -60,16 +66,31 @@ template class singleton_factory_t { if (map.end() == iter) { auto ptr = std::make_unique(std::forward(params)...); - iter = map.emplace(key, std::move(ptr)).first; + iter = map.emplace(key, entry_t{std::move(ptr), 0}).first; + } else { + iter->second.ref_count++; } - return iter->second.get(); + return iter->second.ptr.get(); + } + + void retain(key_tn key) { + std::lock_guard lk(mut); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + iter->second.ref_count++; } ////////////////////////////////////////////////////////////////////////// /// once the key is no longer valid, release the singleton void release(key_tn key) { std::lock_guard lk(mut); - map.erase(getKey(key)); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + if (iter->second.ref_count == 0) { + map.erase(iter); + } else { + iter->second.ref_count--; + } } void clear() { diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 7696489d97..08e3a1b2a7 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( // forward to device-platform result = pfnAdapterRelease(hAdapter); + // release loader handle + context->factories.ur_adapter_factory.release(hAdapter); + return result; } @@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); + // increment refcount of handle + context->factories.ur_adapter_factory.retain(hAdapter); + return result; } @@ -647,6 +653,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); + // increment refcount of handle + context->factories.ur_device_factory.retain(hDevice); + return result; } @@ -673,6 +682,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( // forward to device-platform result = pfnRelease(hDevice); + // release loader handle + context->factories.ur_device_factory.release(hDevice); + return result; } @@ -943,6 +955,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); + // increment refcount of handle + context->factories.ur_context_factory.retain(hContext); + return result; } @@ -969,6 +984,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( // forward to device-platform result = pfnRelease(hContext); + // release loader handle + context->factories.ur_context_factory.release(hContext); + return result; } @@ -1271,6 +1289,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); + // increment refcount of handle + context->factories.ur_mem_factory.retain(hMem); + return result; } @@ -1297,6 +1318,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( // forward to device-platform result = pfnRelease(hMem); + // release loader handle + context->factories.ur_mem_factory.release(hMem); + return result; } @@ -1648,6 +1672,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); + // increment refcount of handle + context->factories.ur_sampler_factory.retain(hSampler); + return result; } @@ -1674,6 +1701,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( // forward to device-platform result = pfnRelease(hSampler); + // release loader handle + context->factories.ur_sampler_factory.release(hSampler); + return result; } @@ -2107,6 +2137,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); + // increment refcount of handle + context->factories.ur_usm_pool_factory.retain(pPool); + return result; } @@ -2132,6 +2165,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( // forward to device-platform result = pfnPoolRelease(pPool); + // release loader handle + context->factories.ur_usm_pool_factory.release(pPool); + return result; } @@ -2517,6 +2553,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); + // increment refcount of handle + context->factories.ur_physical_mem_factory.retain(hPhysicalMem); + return result; } @@ -2545,6 +2584,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( // forward to device-platform result = pfnRelease(hPhysicalMem); + // release loader handle + context->factories.ur_physical_mem_factory.release(hPhysicalMem); + return result; } @@ -2876,6 +2918,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); + // increment refcount of handle + context->factories.ur_program_factory.retain(hProgram); + return result; } @@ -2902,6 +2947,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( // forward to device-platform result = pfnRelease(hProgram); + // release loader handle + context->factories.ur_program_factory.release(hProgram); + return result; } @@ -3499,6 +3547,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); + // increment refcount of handle + context->factories.ur_kernel_factory.retain(hKernel); + return result; } @@ -3525,6 +3576,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( // forward to device-platform result = pfnRelease(hKernel); + // release loader handle + context->factories.ur_kernel_factory.release(hKernel); + return result; } @@ -3975,6 +4029,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); + // increment refcount of handle + context->factories.ur_queue_factory.retain(hQueue); + return result; } @@ -4001,6 +4058,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( // forward to device-platform result = pfnRelease(hQueue); + // release loader handle + context->factories.ur_queue_factory.release(hQueue); + return result; } @@ -4305,6 +4365,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); + // increment refcount of handle + context->factories.ur_event_factory.retain(hEvent); + return result; } @@ -4330,6 +4393,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( // forward to device-platform result = pfnRelease(hEvent); + // release loader handle + context->factories.ur_event_factory.release(hEvent); + return result; } @@ -6862,6 +6928,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp( // forward to device-platform result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem); + // release loader handle + context->factories.ur_exp_external_mem_factory.release(hExternalMem); + return result; } @@ -6952,6 +7021,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp( result = pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore); + // release loader handle + context->factories.ur_exp_external_semaphore_factory.release( + hExternalSemaphore); + return result; } @@ -7179,6 +7252,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); + // increment refcount of handle + context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer); + return result; } @@ -7209,6 +7285,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( // forward to device-platform result = pfnReleaseExp(hCommandBuffer); + // release loader handle + context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer); + return result; } @@ -8525,6 +8604,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( // forward to device-platform result = pfnRetainCommandExp(hCommand); + // increment refcount of handle + context->factories.ur_exp_command_buffer_command_factory.retain(hCommand); + return result; } @@ -8556,6 +8638,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( // forward to device-platform result = pfnReleaseCommandExp(hCommand); + // release loader handle + context->factories.ur_exp_command_buffer_command_factory.release(hCommand); + return result; }