Skip to content

Commit 88c9a18

Browse files
committed
[Offload] Cache symbols in program
When creating a new symbol, check that it already exists. If it does, return that pointer rather than building a new symbol structure.
1 parent 584ef94 commit 88c9a18

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,18 @@ struct ol_program_impl_t {
8585
plugin::DeviceImageTy *Image;
8686
std::unique_ptr<llvm::MemoryBuffer> ImageData;
8787
std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
88+
std::mutex SymbolListMutex;
8889
__tgt_device_image DeviceImage;
8990
};
9091

9192
struct ol_symbol_impl_t {
92-
ol_symbol_impl_t(GenericKernelTy *Kernel)
93-
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
94-
ol_symbol_impl_t(GlobalTy &&Global)
95-
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
93+
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
94+
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
95+
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
96+
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
9697
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
9798
ol_symbol_kind_t Kind;
99+
const char *Name;
98100
};
99101

100102
namespace llvm {
@@ -714,6 +716,18 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
714716
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
715717
auto &Device = Program->Image->getDevice();
716718

719+
std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
720+
721+
// If it already exists, return an existing handle
722+
auto Check = std::find_if(
723+
Program->Symbols.begin(), Program->Symbols.end(), [&](auto &Sym) {
724+
return Sym->Kind == Kind && !std::strcmp(Sym->Name, Name);
725+
});
726+
if (Check != Program->Symbols.end()) {
727+
*Symbol = Check->get();
728+
return Error::success();
729+
}
730+
717731
switch (Kind) {
718732
case OL_SYMBOL_KIND_KERNEL: {
719733
auto KernelImpl = Device.constructKernel(Name);
@@ -723,10 +737,10 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
723737
if (auto Err = KernelImpl->init(Device, *Program->Image))
724738
return Err;
725739

726-
*Symbol =
727-
Program->Symbols
728-
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
729-
.get();
740+
*Symbol = Program->Symbols
741+
.emplace_back(std::make_unique<ol_symbol_impl_t>(
742+
KernelImpl->getName(), &*KernelImpl))
743+
.get();
730744
return Error::success();
731745
}
732746
case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
@@ -736,8 +750,8 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
736750
return Res;
737751

738752
*Symbol = Program->Symbols
739-
.emplace_back(
740-
std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
753+
.emplace_back(std::make_unique<ol_symbol_impl_t>(
754+
GlobalObj.getName().c_str(), std::move(GlobalObj)))
741755
.get();
742756

743757
return Error::success();

offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ TEST_P(olGetSymbolKernelTest, Success) {
4141
ASSERT_NE(Kernel, nullptr);
4242
}
4343

44+
TEST_P(olGetSymbolKernelTest, SuccessSamePtr) {
45+
ol_symbol_handle_t KernelA = nullptr;
46+
ol_symbol_handle_t KernelB = nullptr;
47+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelA));
48+
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelB));
49+
ASSERT_EQ(KernelA, KernelB);
50+
}
51+
4452
TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
4553
ol_symbol_handle_t Kernel = nullptr;
4654
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
@@ -72,6 +80,16 @@ TEST_P(olGetSymbolGlobalTest, Success) {
7280
ASSERT_NE(Global, nullptr);
7381
}
7482

83+
TEST_P(olGetSymbolGlobalTest, SuccessSamePtr) {
84+
ol_symbol_handle_t GlobalA = nullptr;
85+
ol_symbol_handle_t GlobalB = nullptr;
86+
ASSERT_SUCCESS(
87+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalA));
88+
ASSERT_SUCCESS(
89+
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalB));
90+
ASSERT_EQ(GlobalA, GlobalB);
91+
}
92+
7593
TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
7694
ol_symbol_handle_t Global = nullptr;
7795
ASSERT_ERROR(

0 commit comments

Comments
 (0)