Skip to content

[NFC][SYCL] Better "managed" ur_program_handle_t #19536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ kernel make_kernel(const context &TargetContext,
const device_image<bundle_state::executable> &DeviceImage =
*KernelBundle.begin();
device_image_impl &DeviceImageImpl = *getSyclObjImpl(DeviceImage);
UrProgram = DeviceImageImpl.get_ur_program_ref();
UrProgram = DeviceImageImpl.get_ur_program();
}

// Create UR kernel first.
Expand Down
50 changes: 50 additions & 0 deletions sycl/source/detail/adapter_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,56 @@ class adapter_impl {
UrFuncPtrMapT UrFuncPtrs;
}; // class adapter_impl

template <typename URResource> class Managed {
static constexpr auto Release = []() constexpr {
if constexpr (std::is_same_v<URResource, ur_program_handle_t>)
return UrApiKind::urProgramRelease;
}();

public:
Managed() = default;
Managed(URResource R, adapter_impl &Adapter) : R(R), Adapter(&Adapter) {}
Managed(adapter_impl &Adapter) : Adapter(&Adapter) {}
Managed(const Managed &) = delete;
Managed(Managed &&Other) : Adapter(Other.Adapter) {
R = Other.R;
Other.R = nullptr;
}
Managed &operator=(const Managed &) = delete;
Managed &operator=(Managed &&Other) {
if (R)
Adapter->call<Release>(R);
R = Other.R;
Other.R = nullptr;
Adapter = Other.Adapter;
return *this;
}

operator URResource() const { return R; }

URResource release() {
URResource Res = R;
R = nullptr;
return Res;
}

URResource *operator&() {
assert(!R && "Already initialized!");
assert(Adapter && "Adapter must be set for this API!");
return &R;
}

~Managed() {
if (!R)
return;

Adapter->call<Release>(R);
}

private:
URResource R = nullptr;
adapter_impl *Adapter = nullptr;
};
} // namespace detail
} // namespace _V1
} // namespace sycl
5 changes: 1 addition & 4 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ context_impl::~context_impl() {
DeviceGlobal);
DGEntry->removeAssociatedResources(this);
}
for (auto LibProg : MCachedLibPrograms) {
assert(LibProg.second && "Null program must not be kept in the cache");
getAdapter().call<UrApiKind::urProgramRelease>(LibProg.second);
}
MCachedLibPrograms.clear();
// TODO catch an exception and put it to list of asynchronous exceptions
getAdapter().call_nocheck<UrApiKind::urContextRelease>(MContext);
} catch (std::exception &e) {
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {

using CachedLibProgramsT =
std::map<std::pair<DeviceLibExt, ur_device_handle_t>,
ur_program_handle_t>;
Managed<ur_program_handle_t>>;

/// In contrast to user programs, which are compiled from user code, library
/// programs come from the SYCL runtime. They are identified by the
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/device_image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
if (!KID || !has_kernel(*KID))
continue;

auto UrProgram = get_ur_program_ref();
auto UrProgram = get_ur_program();
auto [UrKernel, CacheMutex, ArgMask] =
PM.getOrCreateKernel(Context, AdjustedName,
/*PropList=*/{}, UrProgram);
Expand All @@ -41,7 +41,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
return nullptr;
}

ur_program_handle_t UrProgram = get_ur_program_ref();
ur_program_handle_t UrProgram = get_ur_program();
detail::adapter_impl &Adapter = getSyclObjImpl(Context)->getAdapter();
ur_kernel_handle_t UrKernel = nullptr;
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
Expand Down
56 changes: 30 additions & 26 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ class device_image_impl
ur_program_handle_t Program, uint8_t Origins, private_tag)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
MProgram(Program), MKernelIDs(std::move(KernelIDs)),
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
MKernelIDs(std::move(KernelIDs)),
MSpecConstsDefValBlob(getSpecConstsDefValBlob()), MOrigins(Origins) {
updateSpecConstSymMap();
if (BinImage && (MOrigins & ImageOriginSYCLBIN)) {
Expand Down Expand Up @@ -294,8 +295,8 @@ class device_image_impl
std::unique_ptr<DynRTDeviceBinaryImage> &&MergedImageStorage, private_tag)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
MProgram(Program), MKernelIDs(std::move(KernelIDs)),
MKernelNames{std::move(KernelNames)},
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
MSpecConstsBlob(SpecConstsBlob),
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
Expand All @@ -311,7 +312,8 @@ class device_image_impl
private_tag)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
MProgram(Program), MKernelNames{std::move(KernelNames)},
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
MKernelNames{std::move(KernelNames)},
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
MOrigins(ImageOriginKernelCompiler),
Expand All @@ -329,8 +331,7 @@ class device_image_impl
private_tag)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
MProgram(nullptr), MKernelIDs(std::move(KernelIDs)),
MKernelNames{std::move(KernelNames)},
MKernelIDs(std::move(KernelIDs)), MKernelNames{std::move(KernelNames)},
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
MOrigins(ImageOriginKernelCompiler),
MRTCBinInfo(KernelCompilerBinaryInfo{
Expand All @@ -344,7 +345,7 @@ class device_image_impl
include_pairs_t &&IncludePairsVec, private_tag)
: MBinImage(Src), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()),
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
MState(bundle_state::ext_oneapi_source),
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
MOrigins(ImageOriginKernelCompiler),
MRTCBinInfo(
Expand All @@ -357,7 +358,7 @@ class device_image_impl
private_tag)
: MBinImage(Bytes), MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()),
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
MState(bundle_state::ext_oneapi_source),
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
MOrigins(ImageOriginKernelCompiler),
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
Expand All @@ -371,7 +372,8 @@ class device_image_impl
: MBinImage(static_cast<const RTDeviceBinaryImage *>(nullptr)),
MContext(std::move(Context)),
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
MProgram(Program), MKernelNames{std::move(KernelNames)},
MProgram(Program, getSyclObjImpl(MContext)->getAdapter()),
MKernelNames{std::move(KernelNames)},
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
MOrigins(ImageOriginKernelCompiler),
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
Expand Down Expand Up @@ -558,9 +560,7 @@ class device_image_impl
return get_devices().contains(Dev);
}

const ur_program_handle_t &get_ur_program_ref() const noexcept {
return MProgram;
}
ur_program_handle_t get_ur_program() const noexcept { return MProgram; }

const RTDeviceBinaryImage *const &get_bin_image_ref() const {
return std::get<const RTDeviceBinaryImage *>(MBinImage);
Expand Down Expand Up @@ -617,21 +617,25 @@ class device_image_impl
return NativeProgram;
}

~device_image_impl() {
try {
if (MProgram) {
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
Adapter.call<UrApiKind::urProgramRelease>(MProgram);
}
if (MSpecConstsBuffer) {
std::lock_guard<std::mutex> Lock{MSpecConstAccessMtx};
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
memReleaseHelper(Adapter, MSpecConstsBuffer);
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_image_impl", e);
#ifdef _MSC_VER
#pragma warning(push)
// https://developercommunity.visualstudio.com/t/False-C4297-warning-while-using-function/1130300
// https://godbolt.org/z/xsMvKf84f
#pragma warning(disable : 4297)
#endif
~device_image_impl() try {
if (MSpecConstsBuffer) {
std::lock_guard<std::mutex> Lock{MSpecConstAccessMtx};
adapter_impl &Adapter = getSyclObjImpl(MContext)->getAdapter();
memReleaseHelper(Adapter, MSpecConstsBuffer);
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_image_impl", e);
return; // Don't re-throw.
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif

std::string adjustKernelName(std::string_view Name) const {
if (MOrigins & ImageOriginSYCLBIN) {
Expand Down Expand Up @@ -1298,7 +1302,7 @@ class device_image_impl
std::vector<device_impl *> MDevices;
bundle_state MState;
// Native program handler which this device image represents
ur_program_handle_t MProgram = nullptr;
Managed<ur_program_handle_t> MProgram;

// List of kernel ids available in this image, elements should be sorted
// according to LessByNameComp. Shared between images for performance reasons
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1014,11 +1014,11 @@ class kernel_bundle_impl
auto [Kernel, CacheMutex, ArgMask] =
detail::ProgramManager::getInstance().getOrCreateKernel(
MContext, KernelID.get_name(), /*PropList=*/{},
SelectedImage->get_ur_program_ref());
SelectedImage->get_ur_program());

return std::make_shared<kernel_impl>(
Kernel, *detail::getSyclObjImpl(MContext), SelectedImage, *this,
ArgMask, SelectedImage->get_ur_program_ref(), CacheMutex);
ArgMask, SelectedImage->get_ur_program(), CacheMutex);
}

std::shared_ptr<kernel_impl>
Expand Down
25 changes: 10 additions & 15 deletions sycl/source/detail/kernel_name_based_cache_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,26 @@ struct FastKernelCacheVal {
caching is disabled, the pointer is
nullptr. */
const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */
ur_program_handle_t MProgramHandle; /* UR program handle corresponding to
this kernel. */
const adapter_impl &MAdapterPtr; /* We can keep reference to the adapter
because during 2-stage shutdown the kernel
cache is destroyed deliberately before the
adapter. */
Managed<ur_program_handle_t> MProgramHandle; /* UR program handle
corresponding to this kernel. */
adapter_impl &MAdapter; /* We can keep reference to the adapter
because during 2-stage shutdown the kernel
cache is destroyed deliberately before the
adapter. */

FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
const KernelArgMask *KernelArgMask,
ur_program_handle_t ProgramHandle,
const adapter_impl &AdapterPtr)
ur_program_handle_t ProgramHandle, adapter_impl &Adapter)
: MKernelHandle(KernelHandle), MMutex(Mutex),
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle),
MAdapterPtr(AdapterPtr) {}
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle, Adapter),
MAdapter(Adapter) {}

~FastKernelCacheVal() {
if (MKernelHandle)
MAdapterPtr.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
if (MProgramHandle)
MAdapterPtr.call<sycl::detail::UrApiKind::urProgramRelease>(
MProgramHandle);
MAdapter.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
MKernelHandle = nullptr;
MMutex = nullptr;
MKernelArgMask = nullptr;
MProgramHandle = nullptr;
}

FastKernelCacheVal(const FastKernelCacheVal &) = delete;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_program_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace sycl {
inline namespace _V1 {
namespace detail {
const adapter_impl &KernelProgramCache::getAdapter() {
adapter_impl &KernelProgramCache::getAdapter() {
return MParentContext->getAdapter();
}

Expand Down
44 changes: 23 additions & 21 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,28 +111,28 @@ class KernelProgramCache {
}
};

struct ProgramBuildResult : public BuildResult<ur_program_handle_t> {
const adapter_impl &MAdapter;
ProgramBuildResult(const adapter_impl &Adapter) : MAdapter(Adapter) {
Val = nullptr;
struct ProgramBuildResult : public BuildResult<Managed<ur_program_handle_t>> {
ProgramBuildResult(adapter_impl &Adapter) {
Val = Managed<ur_program_handle_t>{Adapter};
}
ProgramBuildResult(const adapter_impl &Adapter, BuildState InitialState)
: MAdapter(Adapter) {
Val = nullptr;
ProgramBuildResult(adapter_impl &Adapter, BuildState InitialState) {
Val = Managed<ur_program_handle_t>{Adapter};
this->State.store(InitialState);
}
~ProgramBuildResult() {
try {
if (Val) {
ur_result_t Err =
MAdapter.call_nocheck<UrApiKind::urProgramRelease>(Val);
__SYCL_CHECK_UR_CODE_NO_EXC(Err, MAdapter.getBackend());
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult",
e);
}
#ifdef _MSC_VER
#pragma warning(push)
// https://developercommunity.visualstudio.com/t/False-C4297-warning-while-using-function/1130300
// https://godbolt.org/z/xsMvKf84f
#pragma warning(disable : 4297)
#endif
~ProgramBuildResult() try {
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult", e);
return; // Don't re-throw.
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif
};
using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;

Expand Down Expand Up @@ -434,7 +434,7 @@ class KernelProgramCache {
if (DidInsert) {
It->second = std::make_shared<ProgramBuildResult>(getAdapter(),
BuildState::BS_Done);
It->second->Val = Program;
It->second->Val = Managed<ur_program_handle_t>{Program, getAdapter()};
// Save reference between the common key and the full key.
CommonProgramKeyT CommonKey =
std::make_pair(CacheKey.first.second, CacheKey.second);
Expand Down Expand Up @@ -794,7 +794,9 @@ class KernelProgramCache {

// only the building thread will run this
try {
BuildResult->Val = Build();
// Remove `adapter_impl` from `ProgramBuildResult`'s ctors once `Build`
// returns `Managed<ur_platform_handle_t`:
*(&BuildResult->Val) = Build();

if constexpr (!std::is_same_v<EvictFT, void *>)
EvictFunc(BuildResult->Val, /*IsBuilt=*/true);
Expand Down Expand Up @@ -868,7 +870,7 @@ class KernelProgramCache {

friend class ::MockKernelProgramCache;

const adapter_impl &getAdapter();
adapter_impl &getAdapter();
ur_context_handle_t getURContext() const;
};
} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ getOrBuildProgramForDeviceGlobal(queue_impl &Queue,
PM.getDeviceImageFromBinaryImage(&Img, Context, Device);
device_image_plain BuiltImage =
PM.build(std::move(DeviceImage), {std::move(Device)}, {});
return getSyclObjImpl(BuiltImage)->get_ur_program_ref();
return getSyclObjImpl(BuiltImage)->get_ur_program();
}

static void
Expand Down
Loading