Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions unified-runtime/source/adapters/offload/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ struct ur_context_handle_t_ : RefCounted {
~ur_context_handle_t_() { urDeviceRelease(Device); }

ur_device_handle_t Device;
std::unordered_map<void *, alloc_info_t> AllocTypeMap;

std::optional<alloc_info_t> getAllocType(const void *UsmPtr) {
for (auto &pair : AllocTypeMap) {
if (UsmPtr >= pair.first &&
reinterpret_cast<uintptr_t>(UsmPtr) <
reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) {
return pair.second;
}
ol_result_t getAllocType(const void *UsmPtr, ol_alloc_type_t &Type) {
auto Err = olGetMemInfo(Device->Platform->OffloadPlatform, UsmPtr,
OL_MEM_INFO_TYPE, sizeof(Type), &Type);
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
// Treat unknown allocations as host
Type = OL_ALLOC_TYPE_HOST;
return OL_SUCCESS;
}
return std::nullopt;
return Err;
}
};
23 changes: 12 additions & 11 deletions unified-runtime/source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,17 +440,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
size_t size, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
auto GetDevice = [&](const void *Ptr) {
auto Res = hQueue->UrContext->getAllocType(Ptr);
if (!Res)
return Adapter->HostDevice;
return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
: hQueue->OffloadDevice;
};

return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
GetDevice(pSrc), size, blocking, numEventsInWaitList,
phEventWaitList, phEvent);
ol_alloc_type_t DstTy;
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pDst, DstTy));
ol_device_handle_t Dst =
DstTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;

ol_alloc_type_t SrcTy;
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pSrc, SrcTy));
ol_device_handle_t Src =
SrcTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;

return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, Dst, pSrc, Src, size,
blocking, numEventsInWaitList, phEventWaitList, phEvent);

return UR_RESULT_SUCCESS;
}
Expand Down
3 changes: 2 additions & 1 deletion unified-runtime/source/adapters/offload/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
// Subbuffers should not free their parents
if (!BufferImpl->Parent) {
// TODO: Handle registered host memory
OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr));
OL_RETURN_ON_ERR(olMemFree(
hMem->Context->Device->Platform->OffloadPlatform, BufferImpl->Ptr));
} else {
return urMemRelease(BufferImpl->Parent);
}
Expand Down
90 changes: 72 additions & 18 deletions unified-runtime/source/adapters/offload/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_HOST, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
return UR_RESULT_SUCCESS;
}

Expand All @@ -33,9 +30,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
ur_usm_pool_handle_t, size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_DEVICE, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
return UR_RESULT_SUCCESS;
}

Expand All @@ -44,23 +38,83 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
ur_usm_pool_handle_t, size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_MANAGED, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
void *pMem) {
hContext->AllocTypeMap.erase(pMem);
return offloadResultToUR(olMemFree(pMem));
return offloadResultToUR(
olMemFree(hContext->Device->Platform->OffloadPlatform, pMem));
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
[[maybe_unused]] ur_context_handle_t hContext,
[[maybe_unused]] const void *pMem,
[[maybe_unused]] ur_usm_alloc_info_t propName,
[[maybe_unused]] size_t propSize, [[maybe_unused]] void *pPropValue,
[[maybe_unused]] size_t *pPropSizeRet) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
UR_APIEXPORT ur_result_t UR_APICALL
urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
ur_usm_alloc_info_t propName, size_t propSize,
void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

auto Platform = hContext->Device->Platform->OffloadPlatform;
ol_mem_info_t olInfo;

switch (propName) {
case UR_USM_ALLOC_INFO_TYPE:
olInfo = OL_MEM_INFO_TYPE;
break;
case UR_USM_ALLOC_INFO_BASE_PTR:
olInfo = OL_MEM_INFO_BASE;
break;
case UR_USM_ALLOC_INFO_SIZE:
olInfo = OL_MEM_INFO_SIZE;
break;
case UR_USM_ALLOC_INFO_DEVICE:
// Contexts can only contain one device
return ReturnValue(hContext->Device);
case UR_USM_ALLOC_INFO_POOL:
default:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
break;
}

if (pPropSizeRet) {
OL_RETURN_ON_ERR(olGetMemInfoSize(Platform, pMem, olInfo, pPropSizeRet));
}

if (pPropValue) {
auto Err = olGetMemInfo(Platform, pMem, olInfo, propSize, pPropValue);
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
// If the device didn't allocate this object, return default values
switch (propName) {
case UR_USM_ALLOC_INFO_TYPE:
return ReturnValue(UR_USM_TYPE_UNKNOWN);
case UR_USM_ALLOC_INFO_BASE_PTR:
return ReturnValue(nullptr);
case UR_USM_ALLOC_INFO_SIZE:
return ReturnValue(0);
default:
return UR_RESULT_ERROR_UNKNOWN;
}
}
OL_RETURN_ON_ERR(Err);

if (propName == UR_USM_ALLOC_INFO_TYPE) {
auto *OlType = reinterpret_cast<ol_alloc_type_t *>(pPropValue);
auto *UrType = reinterpret_cast<ur_usm_type_t *>(pPropValue);
switch (*OlType) {
case OL_ALLOC_TYPE_HOST:
*UrType = UR_USM_TYPE_HOST;
break;
case OL_ALLOC_TYPE_DEVICE:
*UrType = UR_USM_TYPE_DEVICE;
break;
case OL_ALLOC_TYPE_MANAGED:
*UrType = UR_USM_TYPE_SHARED;
break;
default:
*UrType = UR_USM_TYPE_UNKNOWN;
break;
}
}
}

return UR_RESULT_SUCCESS;
}
Loading