diff --git a/include/phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h b/include/phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h index fbb2693c95..b14ca90f75 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h +++ b/include/phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h @@ -26,13 +26,11 @@ #include "phasar/PhasarLLVM/ControlFlow/LLVMVFTableProvider.h" #include "phasar/PhasarLLVM/Pointer/LLVMAliasInfo.h" #include "phasar/PhasarLLVM/Utils/LLVMBasedContainerConfig.h" -#include "phasar/Utils/MaybeUniquePtr.h" #include "phasar/Utils/Soundness.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Value.h" #include "llvm/Support/raw_ostream.h" diff --git a/include/phasar/PhasarLLVM/ControlFlow/LLVMVFTableProvider.h b/include/phasar/PhasarLLVM/ControlFlow/LLVMVFTableProvider.h index 646ff71322..accb0dfcb9 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/LLVMVFTableProvider.h +++ b/include/phasar/PhasarLLVM/ControlFlow/LLVMVFTableProvider.h @@ -11,7 +11,13 @@ #define PHASAR_PHASARLLVM_CONTROLFLOW_LLVMVFTABLEPROVIDER_H #include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h" +#include "phasar/Utils/HashUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/IR/DebugInfoMetadata.h" + +#include #include namespace llvm { @@ -35,11 +41,35 @@ class LLVMVFTableProvider { explicit LLVMVFTableProvider(const LLVMProjectIRDB &IRDB); [[nodiscard]] bool hasVFTable(const llvm::DIType *Type) const; - [[nodiscard]] const LLVMVFTable * - getVFTableOrNull(const llvm::DIType *Type) const; + [[nodiscard]] const LLVMVFTable *getVFTableOrNull(const llvm::DIType *Type, + uint32_t Index = 0) const; + + [[nodiscard]] const llvm::GlobalVariable * + getVFTableGlobal(const llvm::DIType *Type) const; + + [[nodiscard]] const llvm::GlobalVariable * + getVFTableGlobal(llvm::StringRef ClearTypeName) const; + + [[nodiscard]] const llvm::SmallDenseSet & + getVTableIndexInHierarchy(const llvm::DIType *DerivedType, + const llvm::DIType *BaseType) const; + + /// Supercedes DIBasedTypeHierarchy::removeVTablePrefix + [[nodiscard]] static llvm::StringRef + removeVTablePrefix(llvm::StringRef GlobName) noexcept; + + /// Supercedes DIBasedTypeHierarchy::isVTable + [[nodiscard]] static bool isVTable(llvm::StringRef MangledVarName); private: - std::unordered_map TypeVFTMap; + llvm::StringMap ClearNameTVMap; + std::unordered_map, LLVMVFTable, + PairHash> + TypeVFTMap; + std::unordered_map< + const llvm::DIType *, + llvm::SmallDenseMap>> + BasesOfVirt; }; } // namespace psr diff --git a/include/phasar/PhasarLLVM/ControlFlow/Resolver/CHAResolver.h b/include/phasar/PhasarLLVM/ControlFlow/Resolver/CHAResolver.h index dc6f7c8ff1..8186bdba16 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/Resolver/CHAResolver.h +++ b/include/phasar/PhasarLLVM/ControlFlow/Resolver/CHAResolver.h @@ -38,7 +38,8 @@ class CHAResolver : public Resolver { // dtor in CHAResolver.cpp ~CHAResolver() override; - FunctionSetTy resolveVirtualCall(const llvm::CallBase *CallSite) override; + void resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; [[nodiscard]] std::string str() const override; diff --git a/include/phasar/PhasarLLVM/ControlFlow/Resolver/NOResolver.h b/include/phasar/PhasarLLVM/ControlFlow/Resolver/NOResolver.h index 88afa796e5..044815e287 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/Resolver/NOResolver.h +++ b/include/phasar/PhasarLLVM/ControlFlow/Resolver/NOResolver.h @@ -25,9 +25,11 @@ class NOResolver final : public Resolver { ~NOResolver() override = default; - FunctionSetTy resolveVirtualCall(const llvm::CallBase *CallSite) override; + void resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; - FunctionSetTy resolveFunctionPointer(const llvm::CallBase *CallSite) override; + void resolveFunctionPointer(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; [[nodiscard]] std::string str() const override; diff --git a/include/phasar/PhasarLLVM/ControlFlow/Resolver/OTFResolver.h b/include/phasar/PhasarLLVM/ControlFlow/Resolver/OTFResolver.h index eca760ae77..41b3f6e878 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/Resolver/OTFResolver.h +++ b/include/phasar/PhasarLLVM/ControlFlow/Resolver/OTFResolver.h @@ -48,9 +48,11 @@ class OTFResolver : public Resolver { void handlePossibleTargets(const llvm::CallBase *CallSite, FunctionSetTy &CalleeTargets) override; - FunctionSetTy resolveVirtualCall(const llvm::CallBase *CallSite) override; + void resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; - FunctionSetTy resolveFunctionPointer(const llvm::CallBase *CallSite) override; + void resolveFunctionPointer(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; static std::set getReachableTypes(const LLVMAliasInfo::AliasSetTy &Values); diff --git a/include/phasar/PhasarLLVM/ControlFlow/Resolver/RTAResolver.h b/include/phasar/PhasarLLVM/ControlFlow/Resolver/RTAResolver.h index c6e003211f..6e89063a9d 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/Resolver/RTAResolver.h +++ b/include/phasar/PhasarLLVM/ControlFlow/Resolver/RTAResolver.h @@ -38,7 +38,8 @@ class RTAResolver : public CHAResolver { ~RTAResolver() override = default; - FunctionSetTy resolveVirtualCall(const llvm::CallBase *CallSite) override; + void resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) override; [[nodiscard]] std::string str() const override; diff --git a/include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h b/include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h index c59717c25f..09b8147424 100644 --- a/include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h +++ b/include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h @@ -52,10 +52,9 @@ getReceiverType(const llvm::CallBase *CallSite); /// Assuming that `CallSite` is a virtual call, where `Idx` is retrieved through /// `getVFTIndex()` and `T` through `getReceiverType()` -[[nodiscard]] const llvm::Function * -getNonPureVirtualVFTEntry(const llvm::DIType *T, unsigned Idx, - const llvm::CallBase *CallSite, - const psr::LLVMVFTableProvider &VTP); +[[nodiscard]] const llvm::Function *getNonPureVirtualVFTEntry( + const llvm::DIType *T, unsigned Idx, const llvm::CallBase *CallSite, + const psr::LLVMVFTableProvider &VTP, const llvm::DIType *ReceiverType); [[nodiscard]] std::string getReceiverTypeName(const llvm::CallBase *CallSite); @@ -79,11 +78,12 @@ class Resolver { const llvm::Function * getNonPureVirtualVFTEntry(const llvm::DIType *T, unsigned Idx, - const llvm::CallBase *CallSite) { + const llvm::CallBase *CallSite, + const llvm::DIType *ReceiverType) { if (!VTP) { return nullptr; } - return psr::getNonPureVirtualVFTEntry(T, Idx, CallSite, *VTP); + return psr::getNonPureVirtualVFTEntry(T, Idx, CallSite, *VTP, ReceiverType); } public: @@ -103,16 +103,14 @@ class Resolver { [[nodiscard]] FunctionSetTy resolveIndirectCall(const llvm::CallBase *CallSite); - [[nodiscard]] virtual FunctionSetTy - resolveVirtualCall(const llvm::CallBase *CallSite) = 0; - - [[nodiscard]] virtual FunctionSetTy - resolveFunctionPointer(const llvm::CallBase *CallSite); - virtual void otherInst(const llvm::Instruction *Inst); [[nodiscard]] virtual std::string str() const = 0; + /// Whether the ICFG needs to reconsider all dynamic call-sites once there + /// have been changes through handlePossibleTargets(). + /// + /// Make false for performance (may be less sound then) [[nodiscard]] virtual bool mutatesHelperAnalysisInformation() const noexcept { // Conservatively returns true. Override if possible return true; @@ -122,6 +120,13 @@ class Resolver { const LLVMVFTableProvider *VTP, const DIBasedTypeHierarchy *TH, LLVMAliasInfoRef PT = nullptr); + +protected: + virtual void resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) = 0; + + virtual void resolveFunctionPointer(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite); }; } // namespace psr diff --git a/include/phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h b/include/phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h index d4554958a8..fccd023f85 100644 --- a/include/phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h +++ b/include/phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h @@ -58,7 +58,9 @@ class DIBasedTypeHierarchy const DIBasedTypeHierarchyData &SerializedData); ~DIBasedTypeHierarchy() override = default; + [[deprecated("Use LLVMVFTableProvider::isVTable() instead")]] static bool isVTable(llvm::StringRef VarName); + [[deprecated("Use LLVMVFTableProvider::removeVTablePrefix() instead")]] static std::string removeVTablePrefix(llvm::StringRef VarName); [[nodiscard]] bool hasType(ClassType Type) const override { diff --git a/include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h b/include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h index a68c1b341a..e4955ab05b 100644 --- a/include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h +++ b/include/phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h @@ -73,26 +73,26 @@ class LLVMVFTable : public VFTable { void printAsJson(llvm::raw_ostream &OS) const override; - [[nodiscard]] std::vector::iterator begin() { + [[nodiscard]] std::vector::iterator begin() noexcept { return VFT.begin(); } [[nodiscard]] std::vector::const_iterator - begin() const { + begin() const noexcept { return VFT.begin(); }; - [[nodiscard]] std::vector::iterator end() { + [[nodiscard]] std::vector::iterator end() noexcept { return VFT.end(); }; [[nodiscard]] std::vector::const_iterator - end() const { + end() const noexcept { return VFT.end(); }; [[nodiscard]] static std::vector - getVFVectorFromIRVTable(const llvm::ConstantStruct &); + getVFVectorFromIRVTable(const llvm::ConstantStruct &VT, uint32_t Index = 0); }; } // namespace psr diff --git a/include/phasar/PhasarLLVM/Utils/LLVMIRToSrc.h b/include/phasar/PhasarLLVM/Utils/LLVMIRToSrc.h index fd57c81e0d..915787ff87 100644 --- a/include/phasar/PhasarLLVM/Utils/LLVMIRToSrc.h +++ b/include/phasar/PhasarLLVM/Utils/LLVMIRToSrc.h @@ -33,7 +33,6 @@ class Value; class GlobalVariable; class Module; class DIFile; -class DIType; class DILocation; } // namespace llvm diff --git a/include/phasar/Utils/DefaultValue.h b/include/phasar/Utils/DefaultValue.h index 3dde71aeb3..e63b10018a 100644 --- a/include/phasar/Utils/DefaultValue.h +++ b/include/phasar/Utils/DefaultValue.h @@ -40,6 +40,26 @@ getDefaultValue() noexcept(std::is_nothrow_default_constructible_v) { } } +namespace detail { +struct DefaultCast { + template >> + operator To() && { + return psr::getDefaultValue(); + } + + template >> + operator const To &() && { + return psr::getDefaultValue(); + } +}; +} // namespace detail + +/// Provides a value that automatically converts to (a const-ref to) the +/// default-constructed object of the expected receiver type. +static constexpr detail::DefaultCast default_value() noexcept { return {}; } + } // namespace psr #endif // PHASAR_UTILS_DEFAULTVALUE_H diff --git a/include/phasar/Utils/HashUtils.h b/include/phasar/Utils/HashUtils.h new file mode 100644 index 0000000000..3815d29ff7 --- /dev/null +++ b/include/phasar/Utils/HashUtils.h @@ -0,0 +1,27 @@ +/****************************************************************************** + * Copyright (c) 2025 Fabian Schiebel. + * All rights reserved. This program and the accompanying materials are made + * available under the terms of LICENSE.txt. + * + * Contributors: + * Fabian Schiebel and others + *****************************************************************************/ + +#ifndef PHASAR_UTILS_HASHUTILS_H +#define PHASAR_UTILS_HASHUTILS_H + +#include "llvm/ADT/DenseMapInfo.h" + +#include +#include + +namespace psr { +struct PairHash { + template + size_t operator()(const std::pair &Pair) const noexcept { + return llvm::DenseMapInfo>::getHashValue(Pair); + } +}; +} // namespace psr + +#endif // PHASAR_UTILS_HASHUTILS_H diff --git a/include/phasar/Utils/MapUtils.h b/include/phasar/Utils/MapUtils.h new file mode 100644 index 0000000000..c087aaa0ee --- /dev/null +++ b/include/phasar/Utils/MapUtils.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2025 Fabian Schiebel. + * All rights reserved. This program and the accompanying materials are made + * available under the terms of LICENSE.txt. + * + * Contributors: + * Fabian Schiebel and others + *****************************************************************************/ + +#ifndef PHASAR_UTILS_MAPUTILS_H +#define PHASAR_UTILS_MAPUTILS_H + +#include "phasar/Utils/ByRef.h" +#include "phasar/Utils/DefaultValue.h" +#include "phasar/Utils/Macros.h" + +#include "llvm/ADT/STLForwardCompat.h" + +#include +#include + +namespace psr { + +template >> +static auto getOrDefault(MapT &&Map, KeyT &&Key) -> ByConstRef< + llvm::remove_cvref_tsecond)>> { + auto It = Map.find(PSR_FWD(Key)); + if (It == Map.end()) { + return default_value(); + } + + return It->second; +} + +template < + typename MapT, typename KeyT, + typename = std::enable_if_t>, + std::enable_if_t< + !psr::CanEfficientlyPassByValue>, int> = 0> +static auto getOrNull(MapT &&Map, KeyT &&Key) + -> decltype(&Map.find(PSR_FWD(Key))->second) { + auto It = Map.find(PSR_FWD(Key)); + decltype(&It->second) Ret = nullptr; + if (It != Map.end()) { + Ret = &It->second; + } + + return Ret; +} + +template < + typename MapT, typename KeyT, + typename = std::enable_if_t>, + std::enable_if_t>, + int> = 0> +static auto getOrNull(MapT &&Map, KeyT Key) + -> decltype(&Map.find(Key)->second) { + auto It = Map.find(Key); + decltype(&It->second) Ret = nullptr; + if (It != Map.end()) { + Ret = &It->second; + } + + return Ret; +} +} // namespace psr + +#endif // PHASAR_UTILS_MAPUTILS_H diff --git a/lib/PhasarLLVM/ControlFlow/LLVMBasedCallGraphBuilder.cpp b/lib/PhasarLLVM/ControlFlow/LLVMBasedCallGraphBuilder.cpp index 10ea6c257c..e3f65c00ec 100644 --- a/lib/PhasarLLVM/ControlFlow/LLVMBasedCallGraphBuilder.cpp +++ b/lib/PhasarLLVM/ControlFlow/LLVMBasedCallGraphBuilder.cpp @@ -6,7 +6,7 @@ #include "phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h" #include "phasar/PhasarLLVM/DB/LLVMProjectIRDB.h" #include "phasar/PhasarLLVM/Pointer/LLVMAliasSet.h" -#include "phasar/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.h" +#include "phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h" #include "phasar/PhasarLLVM/Utils/LLVMShorthands.h" #include "phasar/Utils/PAMMMacros.h" #include "phasar/Utils/Soundness.h" @@ -111,8 +111,8 @@ static bool fillPossibleTargets( PossibleTargets.insert(StaticCallee); PHASAR_LOG_LEVEL_CAT(DEBUG, "LLVMBasedICFG", - "Found static call-site: " - << " " << llvmIRToString(CS)); + "Found static call-site: " << " " + << llvmIRToString(CS)); return true; } @@ -122,8 +122,8 @@ static bool fillPossibleTargets( // the function call must be resolved dynamically PHASAR_LOG_LEVEL_CAT(DEBUG, "LLVMBasedICFG", - "Found dynamic call-site: " - << " " << llvmIRToString(CS)); + "Found dynamic call-site: " << " " + << llvmIRToString(CS)); PossibleTargets = Res.resolveIndirectCall(CS); diff --git a/lib/PhasarLLVM/ControlFlow/LLVMBasedICFG.cpp b/lib/PhasarLLVM/ControlFlow/LLVMBasedICFG.cpp index addd88c957..7d6bd38fd3 100644 --- a/lib/PhasarLLVM/ControlFlow/LLVMBasedICFG.cpp +++ b/lib/PhasarLLVM/ControlFlow/LLVMBasedICFG.cpp @@ -27,6 +27,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include diff --git a/lib/PhasarLLVM/ControlFlow/LLVMVFTableProvider.cpp b/lib/PhasarLLVM/ControlFlow/LLVMVFTableProvider.cpp index 45820e2793..67246c938e 100644 --- a/lib/PhasarLLVM/ControlFlow/LLVMVFTableProvider.cpp +++ b/lib/PhasarLLVM/ControlFlow/LLVMVFTableProvider.cpp @@ -4,7 +4,7 @@ #include "phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h" #include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h" #include "phasar/PhasarLLVM/Utils/LLVMIRToSrc.h" -#include "phasar/Utils/Logger.h" +#include "phasar/Utils/MapUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/BinaryFormat/Dwarf.h" @@ -17,48 +17,82 @@ using namespace psr; +static constexpr llvm::StringLiteral TSPrefixDemang = "typeinfo name for "; +static constexpr llvm::StringLiteral VTablePrefixDemang = "vtable for "; +static constexpr llvm::StringLiteral VTablePrefix = "_ZTV"; + static std::string getTypeName(const llvm::DIType *DITy) { - if (const auto *CompTy = llvm::dyn_cast(DITy)) { - auto Ident = CompTy->getIdentifier(); - return Ident.empty() ? llvm::demangle(CompTy->getName().str()) - : llvm::demangle(Ident.str()); + auto TypeName = [DITy] { + if (const auto *CompTy = llvm::dyn_cast(DITy)) { + if (auto Ident = CompTy->getIdentifier(); !Ident.empty()) { + return Ident; + } + } + return DITy->getName(); + }(); + + // In LLVM 17 demangle() takes a StringRef + auto Ret = llvm::demangle(TypeName.str()); + + if (llvm::StringRef(Ret).startswith(TSPrefixDemang)) { + Ret.erase(0, TSPrefixDemang.size()); } - return llvm::demangle(DITy->getName().str()); -} -static std::vector getVirtualFunctions( - const llvm::StringMap &ClearNameTVMap, - const llvm::DIType *Type) { - auto ClearName = getTypeName(Type); + return Ret; +} - static constexpr llvm::StringLiteral TIPrefix = "typeinfo name for "; - if (llvm::StringRef(ClearName).startswith(TIPrefix)) { - ClearName = ClearName.substr(TIPrefix.size()); +static void insertVirtualFunctions( + std::unordered_map, LLVMVFTable, + PairHash> &Into, + const llvm::DIType *Type, const llvm::GlobalVariable *VTableGlobal) { + if (!VTableGlobal) { + return; } - auto It = ClearNameTVMap.find(ClearName); + if (const auto *VT = llvm::dyn_cast( + VTableGlobal->getInitializer())) { + auto NumElems = VT->getNumOperands(); - if (It != ClearNameTVMap.end()) { - if (!It->second->hasInitializer()) { - PHASAR_LOG_LEVEL_CAT(DEBUG, "DIBasedTypeHierarchy", - ClearName << " does not have initializer"); - return {}; + // llvm::errs() << "[insertVirtualFunctions]: VT: " << *VT << '\n'; + // llvm::errs() << "[insertVirtualFunctions]: > NumElems: " << NumElems + // << '\n'; + for (uint32_t I = 0; I != NumElems; ++I) { + Into[{Type, I}] = LLVMVFTable::getVFVectorFromIRVTable(*VT, I); } - if (const auto *I = llvm::dyn_cast( - It->second->getInitializer())) { - return LLVMVFTable::getVFVectorFromIRVTable(*I); + } +} + +static void getBasesOfVirt( + llvm::SmallDenseMap> + &Into, + const llvm::DICompositeType *VirtTy, uint32_t CurrIdx = 0) { + Into[VirtTy].insert(CurrIdx); + for (const auto *Elem : VirtTy->getElements()) { + const auto *Inher = llvm::dyn_cast(Elem); + if (!Inher) { + // Inheritance is always at the front of the member-list + break; + } + if (Inher->getTag() != llvm::dwarf::DW_TAG_inheritance) { + continue; + } + + const auto *BaseClass = + llvm::dyn_cast(Inher->getBaseType()); + if (!BaseClass || !BaseClass->getVTableHolder()) { + continue; } + getBasesOfVirt(Into, BaseClass, CurrIdx); + CurrIdx++; } - return {}; } LLVMVFTableProvider::LLVMVFTableProvider(const llvm::Module &Mod) { - llvm::StringMap ClearNameTVMap; - for (const auto &Glob : Mod.globals()) { - if (DIBasedTypeHierarchy::isVTable(Glob.getName())) { + if (isVTable(Glob.getName())) { auto Demang = llvm::demangle(Glob.getName().str()); - auto ClearName = DIBasedTypeHierarchy::removeVTablePrefix(Demang); + auto ClearName = removeVTablePrefix(Demang); + // llvm::errs() << "> ClearName: " << ClearName << '\n'; ClearNameTVMap.try_emplace(ClearName, &Glob); } } @@ -69,8 +103,14 @@ LLVMVFTableProvider::LLVMVFTableProvider(const llvm::Module &Mod) { if (const auto *CompTy = llvm::dyn_cast(Ty)) { if (CompTy->getTag() == llvm::dwarf::DW_TAG_class_type || CompTy->getTag() == llvm::dwarf::DW_TAG_structure_type) { - TypeVFTMap.try_emplace(CompTy, - getVirtualFunctions(ClearNameTVMap, CompTy)); + insertVirtualFunctions( + TypeVFTMap, CompTy, + getOrDefault(ClearNameTVMap, getTypeName(CompTy))); + + if (CompTy->getVTableHolder()) { + auto &BaseTys = BasesOfVirt[CompTy]; + getBasesOfVirt(BaseTys, CompTy); + } } } } @@ -80,11 +120,65 @@ LLVMVFTableProvider::LLVMVFTableProvider(const LLVMProjectIRDB &IRDB) : LLVMVFTableProvider(*IRDB.getModule()) {} bool LLVMVFTableProvider::hasVFTable(const llvm::DIType *Type) const { - return TypeVFTMap.count(Type); + return TypeVFTMap.count({Type, 0}); } const LLVMVFTable * -LLVMVFTableProvider::getVFTableOrNull(const llvm::DIType *Type) const { - auto It = TypeVFTMap.find(Type); +LLVMVFTableProvider::getVFTableOrNull(const llvm::DIType *Type, + uint32_t Index) const { + auto It = TypeVFTMap.find({Type, Index}); return It != TypeVFTMap.end() ? &It->second : nullptr; } + +const llvm::GlobalVariable * +LLVMVFTableProvider::getVFTableGlobal(const llvm::DIType *Type) const { + auto Name = getTypeName(Type); + return getVFTableGlobal(Name); +} + +const llvm::GlobalVariable * +LLVMVFTableProvider::getVFTableGlobal(llvm::StringRef ClearTypeName) const { + // llvm::errs() << "[getVFTableGlobal]: " << ClearTypeName << '\n'; + if (auto It = ClearNameTVMap.find(ClearTypeName); + It != ClearNameTVMap.end()) { + return It->second; + } + return nullptr; +} + +static const auto &getDefaultIndices() { + static const llvm::SmallDenseSet DefaultIndices = {0}; + return DefaultIndices; +} + +const llvm::SmallDenseSet & +LLVMVFTableProvider::getVTableIndexInHierarchy( + const llvm::DIType *DerivedType, const llvm::DIType *BaseType) const { + auto OuterIt = BasesOfVirt.find(DerivedType); + if (OuterIt == BasesOfVirt.end()) { + return getDefaultIndices(); + } + + auto InnerIt = OuterIt->second.find(BaseType); + if (InnerIt == OuterIt->second.end()) { + return getDefaultIndices(); + } + + return InnerIt->second; +} + +llvm::StringRef +LLVMVFTableProvider::removeVTablePrefix(llvm::StringRef GlobName) noexcept { + if (GlobName.startswith(VTablePrefixDemang)) { + return GlobName.drop_front(VTablePrefixDemang.size()); + } + if (GlobName.startswith(VTablePrefix)) { + return GlobName.drop_front(VTablePrefix.size()); + } + return GlobName; +} + +/// Supercedes DIBasedTypeHierarchy::isVTable() + removeVTablePrefix +bool LLVMVFTableProvider::isVTable(llvm::StringRef MangledVarName) { + return MangledVarName.startswith(VTablePrefix); +} diff --git a/lib/PhasarLLVM/ControlFlow/Resolver/CHAResolver.cpp b/lib/PhasarLLVM/ControlFlow/Resolver/CHAResolver.cpp index ba464cd0a6..25f5c0fd02 100644 --- a/lib/PhasarLLVM/ControlFlow/Resolver/CHAResolver.cpp +++ b/lib/PhasarLLVM/ControlFlow/Resolver/CHAResolver.cpp @@ -43,8 +43,8 @@ CHAResolver::CHAResolver(const LLVMProjectIRDB *IRDB, CHAResolver::~CHAResolver() = default; -auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) - -> FunctionSetTy { +void CHAResolver::resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) { PHASAR_LOG_LEVEL(DEBUG, "Call virtual function: "); // Leading to SEGFAULT in Unittests. Error only when run in Debug mode // << llvmIRToString(CallSite)); @@ -59,7 +59,7 @@ auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) // run in Debug mode // << llvmIRToString(CallSite) << "\n"); - return {}; + return; } auto VtableIndex = RetrievedVtableIndex.value(); @@ -71,16 +71,13 @@ auto CHAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) // also insert all possible subtypes vtable entries auto FallbackTys = TH->getSubTypes(ReceiverTy); - FunctionSetTy PossibleCallees; - for (const auto &FallbackTy : FallbackTys) { - const auto *Target = - getNonPureVirtualVFTEntry(FallbackTy, VtableIndex, CallSite); + const auto *Target = getNonPureVirtualVFTEntry(FallbackTy, VtableIndex, + CallSite, ReceiverTy); if (Target) { - PossibleCallees.insert(Target); + PossibleTargets.insert(Target); } } - return PossibleCallees; } std::string CHAResolver::str() const { return "CHA"; } diff --git a/lib/PhasarLLVM/ControlFlow/Resolver/NOResolver.cpp b/lib/PhasarLLVM/ControlFlow/Resolver/NOResolver.cpp index f825f52549..9dc6a56c28 100644 --- a/lib/PhasarLLVM/ControlFlow/Resolver/NOResolver.cpp +++ b/lib/PhasarLLVM/ControlFlow/Resolver/NOResolver.cpp @@ -26,15 +26,11 @@ NOResolver::NOResolver(const LLVMProjectIRDB *IRDB, const LLVMVFTableProvider *VTP) : Resolver(IRDB, VTP) {} -auto NOResolver::resolveVirtualCall(const llvm::CallBase * /*CallSite*/) - -> FunctionSetTy { - return {}; -} - -auto NOResolver::resolveFunctionPointer(const llvm::CallBase * /*CallSite*/) - -> FunctionSetTy { - return {}; -} +void NOResolver::resolveVirtualCall(FunctionSetTy & /*PossibleTargets*/, + const llvm::CallBase * /*CallSite*/) {} + +void NOResolver::resolveFunctionPointer(FunctionSetTy & /*PossibleTargets*/, + const llvm::CallBase * /*CallSite*/) {} std::string NOResolver::str() const { return "NOResolver"; } diff --git a/lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp b/lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp index 6e70e7de01..9a206c9e21 100644 --- a/lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp +++ b/lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp @@ -65,9 +65,8 @@ void OTFResolver::handlePossibleTargets(const llvm::CallBase *CallSite, } } -auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite) - -> FunctionSetTy { - FunctionSetTy PossibleCallTargets; +void OTFResolver::resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) { PHASAR_LOG_LEVEL(DEBUG, "Call virtual function: " << llvmIRToString(CallSite)); @@ -79,7 +78,7 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite) "Error with resolveVirtualCall : impossible to retrieve " "the vtable index\n" << llvmIRToString(CallSite) << "\n"); - return {}; + return; } auto VtableIndex = RetrievedVtableIndex.value(); @@ -104,23 +103,19 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite) !isConsistentCall(CallSite, Callee)) { continue; } - PossibleCallTargets.insert(Callee); + PossibleTargets.insert(Callee); } } } } - - return PossibleCallTargets; } -auto OTFResolver::resolveFunctionPointer(const llvm::CallBase *CallSite) - -> FunctionSetTy { +void OTFResolver::resolveFunctionPointer(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) { if (!CallSite->getCalledOperand()) { - return {}; + return; } - FunctionSetTy Callees; - auto PTS = PT.getAliasSet(CallSite->getCalledOperand(), CallSite); llvm::SmallVector GlobalVariableWL; @@ -138,7 +133,7 @@ auto OTFResolver::resolveFunctionPointer(const llvm::CallBase *CallSite) if (const auto *F = llvm::dyn_cast(P)) { if (isConsistentCall(CallSite, F)) { - Callees.insert(F); + PossibleTargets.insert(F); } } @@ -181,14 +176,14 @@ auto OTFResolver::resolveFunctionPointer(const llvm::CallBase *CallSite) if (const auto *F = llvm::dyn_cast(CE->getOperand(0)); F && isConsistentCall(CallSite, F)) { - Callees.insert(F); + PossibleTargets.insert(F); } } } if (const auto *F = llvm::dyn_cast(Op)) { if (isConsistentCall(CallSite, F)) { - Callees.insert(F); + PossibleTargets.insert(F); } } else if (auto *CA = llvm::dyn_cast(Op)) { ConstantAggregateWL.push_back(CA); @@ -204,8 +199,6 @@ auto OTFResolver::resolveFunctionPointer(const llvm::CallBase *CallSite) } } } - - return Callees; } std::set diff --git a/lib/PhasarLLVM/ControlFlow/Resolver/RTAResolver.cpp b/lib/PhasarLLVM/ControlFlow/Resolver/RTAResolver.cpp index 05342b05ac..216118a1c8 100644 --- a/lib/PhasarLLVM/ControlFlow/Resolver/RTAResolver.cpp +++ b/lib/PhasarLLVM/ControlFlow/Resolver/RTAResolver.cpp @@ -18,22 +18,19 @@ #include "phasar/PhasarLLVM/DB/LLVMProjectIRDB.h" #include "phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h" +#include "phasar/PhasarLLVM/Utils/LLVMIRToSrc.h" #include "phasar/PhasarLLVM/Utils/LLVMShorthands.h" #include "phasar/Utils/Logger.h" -#include "phasar/Utils/Utilities.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/BinaryFormat/Dwarf.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" -using namespace std; using namespace psr; RTAResolver::RTAResolver(const LLVMProjectIRDB *IRDB, @@ -43,10 +40,8 @@ RTAResolver::RTAResolver(const LLVMProjectIRDB *IRDB, resolveAllocatedCompositeTypes(); } -auto RTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) - -> FunctionSetTy { - - FunctionSetTy PossibleCallTargets; +void RTAResolver::resolveVirtualCall(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) { PHASAR_LOG_LEVEL(DEBUG, "Call virtual function: " << llvmIRToString(CallSite)); @@ -58,7 +53,7 @@ auto RTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) "Error with resolveVirtualCall : impossible to retrieve " "the vtable index\n" << llvmIRToString(CallSite) << "\n"); - return {}; + return; } auto VtableIndex = RetrievedVtableIndex.value(); @@ -74,35 +69,63 @@ auto RTAResolver::resolveVirtualCall(const llvm::CallBase *CallSite) auto EndIt = ReachableTypes.end(); for (const auto *PossibleType : AllocatedCompositeTypes) { if (ReachableTypes.find(PossibleType) != EndIt) { - const auto *Target = - getNonPureVirtualVFTEntry(PossibleType, VtableIndex, CallSite); + + const auto *Target = getNonPureVirtualVFTEntry(PossibleType, VtableIndex, + CallSite, ReceiverType); if (Target) { - PossibleCallTargets.insert(Target); + PossibleTargets.insert(Target); } } } - if (PossibleCallTargets.empty()) { - return CHAResolver::resolveVirtualCall(CallSite); + if (PossibleTargets.empty()) { + CHAResolver::resolveVirtualCall(PossibleTargets, CallSite); } - - return PossibleCallTargets; } std::string RTAResolver::str() const { return "RTA"; } -/// More or less copied from GeneralStatisticsAnalysis +static const llvm::DICompositeType * +isCompositeStructType(const llvm::DIType *Ty) { + if (const auto *CompTy = llvm::dyn_cast_if_present(Ty); + CompTy && (CompTy->getTag() == llvm::dwarf::DW_TAG_structure_type || + CompTy->getTag() == llvm::dwarf::DW_TAG_class_type)) { + + return CompTy; + } + + return nullptr; +} + void RTAResolver::resolveAllocatedCompositeTypes() { if (!AllocatedCompositeTypes.empty()) { return; } - llvm::DebugInfoFinder DIF; - DIF.processModule(*IRDB->getModule()); + llvm::DenseSet AllocatedTypes; - for (const auto *Ty : DIF.types()) { - if (const auto *CompTy = llvm::dyn_cast(Ty)) { - AllocatedCompositeTypes.push_back(CompTy); + for (const auto *Inst : IRDB->getAllInstructions()) { + if (const auto *Alloca = llvm::dyn_cast(Inst)) { + if (const auto *Ty = isCompositeStructType(getVarTypeFromIR(Alloca))) { + AllocatedTypes.insert(Ty); + } + } else if (const auto *Call = llvm::dyn_cast(Inst)) { + if (const auto *Callee = llvm::dyn_cast( + Call->getCalledOperand()->stripPointerCastsAndAliases())) { + if (psr::isHeapAllocatingFunction(Callee)) { + const auto *MDNode = Call->getMetadata("heapallocsite"); + if (const auto *CompTy = + llvm::dyn_cast_if_present(MDNode); + isCompositeStructType(CompTy)) { + + AllocatedTypes.insert(CompTy); + } + } + } } } + + AllocatedCompositeTypes.reserve(AllocatedTypes.size()); + AllocatedCompositeTypes.insert(AllocatedCompositeTypes.end(), + AllocatedTypes.begin(), AllocatedTypes.end()); } diff --git a/lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp b/lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp index 9065a43415..fb9eaa55dd 100644 --- a/lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp +++ b/lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp @@ -61,6 +61,15 @@ std::optional psr::getVFTIndex(const llvm::CallBase *CallSite) { return std::nullopt; } +static const llvm::DIType *stripPointerTypes(const llvm::DIType *DITy) { + while (const auto *DerivedTy = + llvm::dyn_cast_if_present(DITy)) { + // get rid of the pointer + DITy = DerivedTy->getBaseType(); + } + return DITy; +} + const llvm::DIType *psr::getReceiverType(const llvm::CallBase *CallSite) { if (CallSite->arg_empty() || (CallSite->hasStructRetAttr() && CallSite->arg_size() < 2)) { @@ -75,23 +84,23 @@ const llvm::DIType *psr::getReceiverType(const llvm::CallBase *CallSite) { } if (const auto *DITy = getVarTypeFromIR(Receiver)) { - while (const auto *DerivedTy = - llvm::dyn_cast_if_present(DITy)) { - // get rid of the pointer - DITy = DerivedTy->getBaseType(); - } - return DITy; + return stripPointerTypes(DITy); + } + + if (const auto *Var = + getDILocalVariable(Receiver->stripPointerCastsAndAliases())) { + return stripPointerTypes(Var->getType()); } return nullptr; } -const llvm::Function * -psr::getNonPureVirtualVFTEntry(const llvm::DIType *T, unsigned Idx, - const llvm::CallBase *CallSite, - const LLVMVFTableProvider &VTP) { +const llvm::Function *psr::getNonPureVirtualVFTEntry( + const llvm::DIType *T, unsigned Idx, const llvm::CallBase *CallSite, + const LLVMVFTableProvider &VTP, const llvm::DIType *ReceiverType) { + auto VTIndex = *VTP.getVTableIndexInHierarchy(T, ReceiverType).begin(); - if (const auto *VT = VTP.getVFTableOrNull(T)) { + if (const auto *VT = VTP.getVFTableOrNull(T, VTIndex)) { const auto *Target = VT->getFunction(Idx); if (Target && Target->getName() != DIBasedTypeHierarchy::PureVirtualCallName && @@ -135,6 +144,8 @@ bool psr::isVirtualCall(const llvm::Instruction *Inst, // check potential receiver type const auto *RecType = getReceiverType(CallSite); if (!RecType) { + llvm::errs() << "No receiver type found for call at " + << llvmIRToString(Inst) << '\n'; return false; } @@ -149,7 +160,6 @@ namespace psr { Resolver::Resolver(const LLVMProjectIRDB *IRDB, const LLVMVFTableProvider *VTP) : IRDB(IRDB), VTP(VTP) { assert(IRDB != nullptr); - assert(VTP != nullptr); } void Resolver::preCall(const llvm::Instruction *Inst) {} @@ -161,28 +171,31 @@ void Resolver::postCall(const llvm::Instruction *Inst) {} auto Resolver::resolveIndirectCall(const llvm::CallBase *CallSite) -> FunctionSetTy { + FunctionSetTy PossibleTargets; if (VTP && isVirtualCall(CallSite, *VTP)) { - return resolveVirtualCall(CallSite); + resolveVirtualCall(PossibleTargets, CallSite); + } + + if (PossibleTargets.empty()) { + resolveFunctionPointer(PossibleTargets, CallSite); } - return resolveFunctionPointer(CallSite); + + return PossibleTargets; } -auto Resolver::resolveFunctionPointer(const llvm::CallBase *CallSite) - -> FunctionSetTy { +void Resolver::resolveFunctionPointer(FunctionSetTy &PossibleTargets, + const llvm::CallBase *CallSite) { // we may wish to optimise this function // naive implementation that considers every function whose signature // matches the call-site's signature as a callee target PHASAR_LOG_LEVEL(DEBUG, "Call function pointer: " << llvmIRToString(CallSite)); - FunctionSetTy CalleeTargets; for (const auto *F : IRDB->getAllFunctions()) { if (F->hasAddressTaken() && isConsistentCall(CallSite, F)) { - CalleeTargets.insert(F); + PossibleTargets.insert(F); } } - - return CalleeTargets; } void Resolver::otherInst(const llvm::Instruction *Inst) {} diff --git a/lib/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.cpp b/lib/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.cpp index 36319dc05a..4967c9faf0 100644 --- a/lib/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.cpp +++ b/lib/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.cpp @@ -23,9 +23,7 @@ #include "llvm/Demangle/Demangle.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" @@ -317,7 +315,7 @@ bool DIBasedTypeHierarchy::isVTable(llvm::StringRef VarName) { if (VarName.startswith(VTablePrefix)) { return true; } - // In LLVM 16 demangle() takes a StringRef + // In LLVM 17 demangle() takes a StringRef auto Demang = llvm::demangle(VarName.str()); return llvm::StringRef(Demang).startswith(VTablePrefixDemang); } diff --git a/lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp b/lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp index 028abe8659..716f759af0 100644 --- a/lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp +++ b/lib/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.cpp @@ -20,9 +20,6 @@ #include "phasar/PhasarLLVM/DB/LLVMProjectIRDB.h" #include "phasar/PhasarLLVM/Utils/LLVMShorthands.h" #include "phasar/Utils/Logger.h" -#include "phasar/Utils/NlohmannLogging.h" -#include "phasar/Utils/PAMMMacros.h" -#include "phasar/Utils/Utilities.h" #include "llvm/ADT/StringMap.h" #include "llvm/Demangle/Demangle.h" @@ -30,18 +27,12 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" -#include "llvm/Support/Format.h" #include "boost/graph/graphviz.hpp" #include "boost/graph/transitive_closure.hpp" -#include #include -#include #include using namespace std; @@ -346,8 +337,7 @@ LLVMTypeHierarchy::getType(llvm::StringRef TypeName) const { } // Sometimes, clang adds a .base suffix - std::string TN = TypeName.str() + ".base"; - return getTypeImpl(TypeGraph, TN); + return getTypeImpl(TypeGraph, (TypeName + ".base").str()); } std::vector LLVMTypeHierarchy::getAllTypes() const { diff --git a/lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp b/lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp index f2a5a849bd..7a7e487b94 100644 --- a/lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp +++ b/lib/PhasarLLVM/TypeHierarchy/LLVMVFTable.cpp @@ -10,15 +10,12 @@ #include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTable.h" #include "phasar/PhasarLLVM/TypeHierarchy/LLVMVFTableData.h" -#include "phasar/Utils/NlohmannLogging.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalAlias.h" -#include "llvm/IR/Operator.h" #include "llvm/Support/raw_ostream.h" #include -#include using namespace psr; @@ -51,7 +48,7 @@ void LLVMVFTable::print(llvm::raw_ostream &OS) const { [[nodiscard]] LLVMVFTableData LLVMVFTable::getVFTableData() const { LLVMVFTableData Data; - for (const auto &Curr : VFT) { + for (const auto *Curr : VFT) { if (Curr) { Data.VFT.push_back(Curr->getName().str()); continue; @@ -69,22 +66,26 @@ void LLVMVFTable::printAsJson(llvm::raw_ostream &OS) const { } std::vector -LLVMVFTable::getVFVectorFromIRVTable(const llvm::ConstantStruct &VT) { +LLVMVFTable::getVFVectorFromIRVTable(const llvm::ConstantStruct &VT, + uint32_t Index) { std::vector VFS; - for (const auto &Op : VT.operands()) { - if (const auto *CA = llvm::dyn_cast(Op)) { - // Start iterating at offset 2, because offset 0 is vbase offset, offset 1 - // is RTTI - for (const auto *It = std::next(CA->operands().begin(), 2); - It != CA->operands().end(); ++It) { - const auto *Entry = It->get()->stripPointerCastsAndAliases(); - - const auto *F = llvm::dyn_cast(Entry); - VFS.push_back(F); - } + if (Index >= VT.getNumOperands()) { + return VFS; + } + + const auto *Op = VT.getOperand(Index); + + if (const auto *CA = llvm::dyn_cast(Op)) { + // Start iterating at offset 2, because offset 0 is vbase offset, offset 1 + // is RTTI + for (const auto *It = std::next(CA->operands().begin(), 2); + It != CA->operands().end(); ++It) { + const auto *Entry = It->get()->stripPointerCastsAndAliases(); + + const auto *F = llvm::dyn_cast(Entry); + VFS.push_back(F); } } return VFS; } - } // namespace psr diff --git a/lib/PhasarLLVM/Utils/LLVMIRToSrc.cpp b/lib/PhasarLLVM/Utils/LLVMIRToSrc.cpp index ac36114d33..5493b9dcc9 100644 --- a/lib/PhasarLLVM/Utils/LLVMIRToSrc.cpp +++ b/lib/PhasarLLVM/Utils/LLVMIRToSrc.cpp @@ -31,6 +31,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -190,7 +191,7 @@ static llvm::DIType *getStructElementType(llvm::DIType *BaseTy, size_t Offset) { if (const auto *CompositeTy = llvm::dyn_cast(StructTy)) { - if (Offset > CompositeTy->getElements().size()) { + if (Offset >= CompositeTy->getElements().size()) { return nullptr; } auto Elems = CompositeTy->getElements(); diff --git a/test/llvm_test_code/call_graphs/CMakeLists.txt b/test/llvm_test_code/call_graphs/CMakeLists.txt index 9a4557b178..833ff39d2e 100644 --- a/test/llvm_test_code/call_graphs/CMakeLists.txt +++ b/test/llvm_test_code/call_graphs/CMakeLists.txt @@ -27,6 +27,8 @@ set(NoMem2regSources virtual_call_7.cpp virtual_call_8.cpp virtual_call_9.cpp + virtual_call_10.cpp + virtual_call_11.cpp global_ctor_dtor_1.cpp global_ctor_dtor_2.cpp global_ctor_dtor_3.cpp diff --git a/test/llvm_test_code/call_graphs/virtual_call_10.cpp b/test/llvm_test_code/call_graphs/virtual_call_10.cpp new file mode 100644 index 0000000000..121eee4d41 --- /dev/null +++ b/test/llvm_test_code/call_graphs/virtual_call_10.cpp @@ -0,0 +1,22 @@ +// handle virtual function call on a pointer to an interface implementation + +struct A { + virtual ~A() = default; + virtual void foo() = 0; +}; + +struct B { + virtual ~B() = default; + virtual void bar() = 0; +}; + +struct ABImpl : A, B { + void foo() override {} + void bar() override {} +}; + +int main() { + B *ABptr = new ABImpl; + ABptr->bar(); + delete ABptr; +} diff --git a/test/llvm_test_code/call_graphs/virtual_call_11.cpp b/test/llvm_test_code/call_graphs/virtual_call_11.cpp new file mode 100644 index 0000000000..e98782e23f --- /dev/null +++ b/test/llvm_test_code/call_graphs/virtual_call_11.cpp @@ -0,0 +1,45 @@ +// handle virtual function call on a pointer to an interface implementation + +struct A { + virtual void foo() = 0; +}; + +struct B { + virtual void bar() = 0; +}; + +struct ABImpl : A, B { + void foo() override {} + void bar() override {} +}; + +struct C { + + virtual void baz() {} +}; + +struct ABCImpl : C, ABImpl { + void foo() override {} + void bar() override {} + void baz() override {} +}; + +void callFoo(A &a) { // + a.foo(); +} + +void callBar(B &b) { // + b.bar(); +} + +void callBaz(C &c) { // + c.baz(); +} + +int main() { + ABCImpl abc; + + callFoo(abc); + callBar(abc); + callBaz(abc); +} diff --git a/test/llvm_test_code/virtual_callsites/CMakeLists.txt b/test/llvm_test_code/virtual_callsites/CMakeLists.txt index 4a329e510e..a4f9d3e0c8 100644 --- a/test/llvm_test_code/virtual_callsites/CMakeLists.txt +++ b/test/llvm_test_code/virtual_callsites/CMakeLists.txt @@ -11,4 +11,5 @@ set(NoMem2regSources foreach(TEST_SRC ${NoMem2regSources}) generate_ll_file(FILE ${TEST_SRC}) + generate_ll_file(FILE ${TEST_SRC} DEBUG) endforeach(TEST_SRC) diff --git a/unittests/PhasarLLVM/ControlFlow/CMakeLists.txt b/unittests/PhasarLLVM/ControlFlow/CMakeLists.txt index 11a79b8389..95d3b00e67 100644 --- a/unittests/PhasarLLVM/ControlFlow/CMakeLists.txt +++ b/unittests/PhasarLLVM/ControlFlow/CMakeLists.txt @@ -4,6 +4,7 @@ set(ControlFlowSources LLVMBasedICFG_CHATest.cpp LLVMBasedICFG_OTFTest.cpp LLVMBasedICFG_RTATest.cpp + LLVMBasedICFG_RTA_MultipleInheritanceTest.cpp LLVMBasedBackwardCFGTest.cpp LLVMBasedBackwardICFGTest.cpp LLVMBasedICFGExportTest.cpp diff --git a/unittests/PhasarLLVM/ControlFlow/LLVMBasedICFG_RTA_MultipleInheritanceTest.cpp b/unittests/PhasarLLVM/ControlFlow/LLVMBasedICFG_RTA_MultipleInheritanceTest.cpp new file mode 100644 index 0000000000..594aa828cc --- /dev/null +++ b/unittests/PhasarLLVM/ControlFlow/LLVMBasedICFG_RTA_MultipleInheritanceTest.cpp @@ -0,0 +1,122 @@ +#include "phasar/Config/Configuration.h" +#include "phasar/ControlFlow/CallGraphAnalysisType.h" +#include "phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h" +#include "phasar/PhasarLLVM/DB/LLVMProjectIRDB.h" +#include "phasar/PhasarLLVM/Pointer/LLVMAliasSet.h" +#include "phasar/PhasarLLVM/TypeHierarchy/DIBasedTypeHierarchy.h" +#include "phasar/PhasarLLVM/Utils/LLVMIRToSrc.h" +#include "phasar/PhasarLLVM/Utils/LLVMShorthands.h" + +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/Support/Casting.h" + +#include "TestConfig.h" +#include "gtest/gtest.h" + +using namespace psr; + +static const llvm::CallBase *getCallInLine(const llvm::Function &F, + uint32_t Line) { + for (const auto &I : llvm::instructions(F)) { + const auto *CB = llvm::dyn_cast(&I); + if (!CB) { + continue; + } + + auto CBLine = getLineFromIR(CB); + if (CBLine == Line) { + return CB; + } + } + return nullptr; +} + +TEST(LLVMBasedICFG_RTATest, VirtualCallSite_10) { + // TODO: test if getNonPureVirtualVFTEntry gets the correct inheritance and + // not just the first one every time + + LLVMProjectIRDB IRDB(unittest::PathToLLTestFiles + + "call_graphs/virtual_call_10_cpp_dbg.ll"); + DIBasedTypeHierarchy TH(IRDB); + LLVMBasedICFG ICFG(&IRDB, CallGraphAnalysisType::RTA, {"main"}, &TH); + const llvm::Function *MainF = IRDB.getFunctionDefinition("main"); + ASSERT_TRUE(MainF); + + // --- At Line 20: ABptr->bar(); + + const auto *CallToBar = getCallInLine(*MainF, 20); + ASSERT_TRUE(CallToBar); + const auto *BarF = IRDB.getFunction("_ZThn8_N6ABImpl3barEv"); + ASSERT_TRUE(BarF); + + auto BarCallees = ICFG.getCalleesOfCallAt(CallToBar); + // non-virtual thunk to ABImpl::bar() + EXPECT_EQ(llvm::ArrayRef{BarF}, BarCallees); + + // --- At Line 21: delete ABptr; + + const auto *CallToDtor = getCallInLine(*MainF, 21); + ASSERT_TRUE(CallToDtor); + const auto *DtorF = IRDB.getFunction("_ZThn8_N6ABImplD0Ev"); + ASSERT_TRUE(DtorF); + + auto DtorCallees = ICFG.getCalleesOfCallAt(CallToDtor); + + // non-virtual thunk to ABImpl::~ABImpl() + EXPECT_EQ(llvm::ArrayRef{DtorF}, DtorCallees); +} + +TEST(LLVMBasedICFG_RTATest, VirtualCallSite_11) { + // TODO: test if getNonPureVirtualVFTEntry gets the correct inheritance and + // not just the first one every time + + LLVMProjectIRDB IRDB(unittest::PathToLLTestFiles + + "call_graphs/virtual_call_11_cpp_dbg.ll"); + DIBasedTypeHierarchy TH(IRDB); + LLVMBasedICFG ICFG(&IRDB, CallGraphAnalysisType::RTA, {"main"}, &TH); + const llvm::Function *CallFooF = IRDB.getFunctionDefinition("_Z7callFooR1A"); + const llvm::Function *CallBarF = IRDB.getFunctionDefinition("_Z7callBarR1B"); + const llvm::Function *CallBazF = IRDB.getFunctionDefinition("_Z7callBazR1C"); + ASSERT_TRUE(CallFooF); + ASSERT_TRUE(CallBarF); + ASSERT_TRUE(CallBazF); + + // --- At Line 28: a.foo(); + + const auto *CallToFoo = getCallInLine(*CallFooF, 28); + ASSERT_TRUE(CallToFoo); + const auto *FooF = IRDB.getFunction("_ZThn8_N7ABCImpl3fooEv"); + ASSERT_TRUE(FooF); + + auto FooCallees = ICFG.getCalleesOfCallAt(CallToFoo); + // non-virtual thunk to ABCImpl::foo() + EXPECT_EQ(llvm::ArrayRef{FooF}, FooCallees); + + // --- At Line 32: b.bar(); + + const auto *CallToBar = getCallInLine(*CallBarF, 32); + ASSERT_TRUE(CallToBar); + const auto *BarF = IRDB.getFunction("_ZThn16_N7ABCImpl3barEv"); + ASSERT_TRUE(BarF); + + auto BarCallees = ICFG.getCalleesOfCallAt(CallToBar); + // non-virtual thunk to ABCImpl::bar() + EXPECT_EQ(llvm::ArrayRef{BarF}, BarCallees); + + // --- At Line 36: c.baz(); + + const auto *CallToBaz = getCallInLine(*CallBazF, 36); + ASSERT_TRUE(CallToBaz); + const auto *BazF = IRDB.getFunction("_ZN7ABCImpl3bazEv"); + ASSERT_TRUE(BazF); + + auto BazCallees = ICFG.getCalleesOfCallAt(CallToBaz); + // ABCImpl::baz() + EXPECT_EQ(llvm::ArrayRef{BazF}, BazCallees); +} + +int main(int Argc, char **Argv) { + ::testing::InitGoogleTest(&Argc, Argv); + return RUN_ALL_TESTS(); +}