diff --git a/unified-runtime/source/adapters/offload/context.hpp b/unified-runtime/source/adapters/offload/context.hpp index b40d17ad3ae9c..36ddf1141a4ea 100644 --- a/unified-runtime/source/adapters/offload/context.hpp +++ b/unified-runtime/source/adapters/offload/context.hpp @@ -29,16 +29,15 @@ struct ur_context_handle_t_ : RefCounted { ~ur_context_handle_t_() { urDeviceRelease(Device); } ur_device_handle_t Device; - std::unordered_map<void *, alloc_info_t> AllocTypeMap; - std::optional<alloc_info_t> getAllocType(const void *UsmPtr) { - for (auto &pair : AllocTypeMap) { - if (UsmPtr >= pair.first && - reinterpret_cast<uintptr_t>(UsmPtr) < - reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) { - return pair.second; - } + ol_result_t getAllocType(const void *UsmPtr, ol_alloc_type_t &Type) { + auto Err = olGetMemInfo(Device->Platform->OffloadPlatform, UsmPtr, + OL_MEM_INFO_TYPE, sizeof(Type), &Type); + if (Err && Err->Code == OL_ERRC_NOT_FOUND) { + // Treat unknown allocations as host + Type = OL_ALLOC_TYPE_HOST; + return OL_SUCCESS; } - return std::nullopt; + return Err; } }; diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp index 3d419a30dc500..0e995b83eebab 100644 --- a/unified-runtime/source/adapters/offload/enqueue.cpp +++ b/unified-runtime/source/adapters/offload/enqueue.cpp @@ -440,17 +440,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc, size_t size, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - auto GetDevice = [&](const void *Ptr) { - auto Res = hQueue->UrContext->getAllocType(Ptr); - if (!Res) - return Adapter->HostDevice; - return Res->Type == OL_ALLOC_TYPE_HOST ?
Adapter->HostDevice - : hQueue->OffloadDevice; - }; - - return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc, - GetDevice(pSrc), size, blocking, numEventsInWaitList, - phEventWaitList, phEvent); + ol_alloc_type_t DstTy; + OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pDst, DstTy)); + ol_device_handle_t Dst = + DstTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice; + + ol_alloc_type_t SrcTy; + OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pSrc, SrcTy)); + ol_device_handle_t Src = + SrcTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice; + + return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, Dst, pSrc, Src, size, + blocking, numEventsInWaitList, phEventWaitList, phEvent); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/offload/memory.cpp b/unified-runtime/source/adapters/offload/memory.cpp index e27a032a61451..99db441983da8 100644 --- a/unified-runtime/source/adapters/offload/memory.cpp +++ b/unified-runtime/source/adapters/offload/memory.cpp @@ -80,7 +80,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { // Subbuffers should not free their parents if (!BufferImpl->Parent) { // TODO: Handle registered host memory - OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr)); + OL_RETURN_ON_ERR(olMemFree( + hMem->Context->Device->Platform->OffloadPlatform, BufferImpl->Ptr)); } else { return urMemRelease(BufferImpl->Parent); } diff --git a/unified-runtime/source/adapters/offload/usm.cpp b/unified-runtime/source/adapters/offload/usm.cpp index e457c59896d53..06c1c87a513a4 100644 --- a/unified-runtime/source/adapters/offload/usm.cpp +++ b/unified-runtime/source/adapters/offload/usm.cpp @@ -22,9 +22,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext, size_t size, void **ppMem) { OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_HOST, size, ppMem)); - - hContext->AllocTypeMap.insert_or_assign( - *ppMem, 
alloc_info_t{OL_ALLOC_TYPE_HOST, size}); return UR_RESULT_SUCCESS; } @@ -33,9 +30,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( ur_usm_pool_handle_t, size_t size, void **ppMem) { OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, ppMem)); - - hContext->AllocTypeMap.insert_or_assign( - *ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size}); return UR_RESULT_SUCCESS; } @@ -44,23 +38,83 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( ur_usm_pool_handle_t, size_t size, void **ppMem) { OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_MANAGED, size, ppMem)); - - hContext->AllocTypeMap.insert_or_assign( - *ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size}); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, void *pMem) { - hContext->AllocTypeMap.erase(pMem); - return offloadResultToUR(olMemFree(pMem)); + return offloadResultToUR( + olMemFree(hContext->Device->Platform->OffloadPlatform, pMem)); } -UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo( - [[maybe_unused]] ur_context_handle_t hContext, - [[maybe_unused]] const void *pMem, - [[maybe_unused]] ur_usm_alloc_info_t propName, - [[maybe_unused]] size_t propSize, [[maybe_unused]] void *pPropValue, - [[maybe_unused]] size_t *pPropSizeRet) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL +urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, + ur_usm_alloc_info_t propName, size_t propSize, + void *pPropValue, size_t *pPropSizeRet) { + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + auto Platform = hContext->Device->Platform->OffloadPlatform; + ol_mem_info_t olInfo; + + switch (propName) { + case UR_USM_ALLOC_INFO_TYPE: + olInfo = OL_MEM_INFO_TYPE; + break; + case UR_USM_ALLOC_INFO_BASE_PTR: + olInfo = OL_MEM_INFO_BASE; + break; + case UR_USM_ALLOC_INFO_SIZE: + olInfo = OL_MEM_INFO_SIZE; + break; + case 
UR_USM_ALLOC_INFO_DEVICE: + // Contexts can only contain one device + return ReturnValue(hContext->Device); + case UR_USM_ALLOC_INFO_POOL: + default: + return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + break; + } + + if (pPropSizeRet) { + OL_RETURN_ON_ERR(olGetMemInfoSize(Platform, pMem, olInfo, pPropSizeRet)); + } + + if (pPropValue) { + auto Err = olGetMemInfo(Platform, pMem, olInfo, propSize, pPropValue); + if (Err && Err->Code == OL_ERRC_NOT_FOUND) { + // If the device didn't allocate this object, return default values + switch (propName) { + case UR_USM_ALLOC_INFO_TYPE: + return ReturnValue(UR_USM_TYPE_UNKNOWN); + case UR_USM_ALLOC_INFO_BASE_PTR: + return ReturnValue(nullptr); + case UR_USM_ALLOC_INFO_SIZE: + return ReturnValue(0); + default: + return UR_RESULT_ERROR_UNKNOWN; + } + } + OL_RETURN_ON_ERR(Err); + + if (propName == UR_USM_ALLOC_INFO_TYPE) { + auto *OlType = reinterpret_cast<ol_alloc_type_t *>(pPropValue); + auto *UrType = reinterpret_cast<ur_usm_type_t *>(pPropValue); + switch (*OlType) { + case OL_ALLOC_TYPE_HOST: + *UrType = UR_USM_TYPE_HOST; + break; + case OL_ALLOC_TYPE_DEVICE: + *UrType = UR_USM_TYPE_DEVICE; + break; + case OL_ALLOC_TYPE_MANAGED: + *UrType = UR_USM_TYPE_SHARED; + break; + default: + *UrType = UR_USM_TYPE_UNKNOWN; + break; + } + } + } + + return UR_RESULT_SUCCESS; }