diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index af07a6786cfea..88f163961ec1e 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -84,17 +84,20 @@ struct ol_program_impl_t { DeviceImage(DeviceImage) {} plugin::DeviceImageTy *Image; std::unique_ptr ImageData; - std::vector> Symbols; + std::mutex SymbolListMutex; __tgt_device_image DeviceImage; + llvm::StringMap> KernelSymbols; + llvm::StringMap> GlobalSymbols; }; struct ol_symbol_impl_t { - ol_symbol_impl_t(GenericKernelTy *Kernel) - : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {} - ol_symbol_impl_t(GlobalTy &&Global) - : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {} + ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel) + : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {} + ol_symbol_impl_t(const char *Name, GlobalTy &&Global) + : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {} std::variant PluginImpl; ol_symbol_kind_t Kind; + llvm::StringRef Name; }; namespace llvm { @@ -714,32 +717,40 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name, ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) { auto &Device = Program->Image->getDevice(); + std::lock_guard Lock{Program->SymbolListMutex}; + switch (Kind) { case OL_SYMBOL_KIND_KERNEL: { - auto KernelImpl = Device.constructKernel(Name); - if (!KernelImpl) - return KernelImpl.takeError(); + auto &Kernel = Program->KernelSymbols[Name]; + if (!Kernel) { + auto KernelImpl = Device.constructKernel(Name); + if (!KernelImpl) + return KernelImpl.takeError(); - if (auto Err = KernelImpl->init(Device, *Program->Image)) - return Err; + if (auto Err = KernelImpl->init(Device, *Program->Image)) + return Err; + + Kernel = std::make_unique(KernelImpl->getName(), + &*KernelImpl); + } - *Symbol = - Program->Symbols - .emplace_back(std::make_unique(&*KernelImpl)) - .get(); + *Symbol = Kernel.get(); return Error::success(); } case OL_SYMBOL_KIND_GLOBAL_VARIABLE: { - GlobalTy GlobalObj{Name}; - if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( - Device, *Program->Image, GlobalObj)) - return Res; - - *Symbol = Program->Symbols - .emplace_back( - std::make_unique(std::move(GlobalObj))) - .get(); + auto &Global = Program->KernelSymbols[Name]; + if (!Global) { + GlobalTy GlobalObj{Name}; + if (auto Res = + Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( + Device, *Program->Image, GlobalObj)) + return Res; + + Global = std::make_unique(GlobalObj.getName().c_str(), + std::move(GlobalObj)); + } + *Symbol = Global.get(); return Error::success(); } default: diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp index 5e87ab5b29621..1f496b9c6e1ae 100644 --- a/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp +++ b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp @@ -41,6 +41,14 @@ TEST_P(olGetSymbolKernelTest, Success) { ASSERT_NE(Kernel, nullptr); } +TEST_P(olGetSymbolKernelTest, SuccessSamePtr) { + ol_symbol_handle_t KernelA = nullptr; + ol_symbol_handle_t KernelB = nullptr; + ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelA)); + ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelB)); + ASSERT_EQ(KernelA, KernelB); +} + TEST_P(olGetSymbolKernelTest, InvalidNullProgram) { ol_symbol_handle_t Kernel = nullptr; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, @@ -72,6 +80,16 @@ TEST_P(olGetSymbolGlobalTest, Success) { ASSERT_NE(Global, nullptr); } +TEST_P(olGetSymbolGlobalTest, SuccessSamePtr) { + ol_symbol_handle_t GlobalA = nullptr; + ol_symbol_handle_t GlobalB = nullptr; + ASSERT_SUCCESS( + olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalA)); + ASSERT_SUCCESS( + olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalB)); + ASSERT_EQ(GlobalA, GlobalB); +} + TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) { ol_symbol_handle_t Global = nullptr; ASSERT_ERROR(