[MLGO][IR2Vec] Integrating IR2Vec with MLInliner #143479

Open
wants to merge 1 commit into
base: main

Conversation

svkeerthy
Contributor

@svkeerthy svkeerthy commented Jun 10, 2025

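Changes MLInliner to use IR2Vec Symbolic embeddings of the callee and caller as additional features. The embeddings are enabled by setting the hidden ml-inliner-ir2vec-vocab-file option.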

(Fixes #141836, Tracking issue - #141817)

Contributor Author

svkeerthy commented Jun 10, 2025

@svkeerthy svkeerthy marked this pull request as ready for review June 10, 2025 05:50
@llvmbot llvmbot added mlgo llvm:analysis Includes value tracking, cost tables and constant folding labels Jun 10, 2025
@llvmbot
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)


Patch is 36.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143479.diff

10 Files Affected:

  • (modified) llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h (+21-5)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+1-1)
  • (modified) llvm/include/llvm/Analysis/InlineAdvisor.h (+3)
  • (modified) llvm/include/llvm/Analysis/InlineModelFeatureMaps.h (+5-1)
  • (modified) llvm/include/llvm/Analysis/MLInlineAdvisor.h (+1)
  • (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+112-3)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+2-2)
  • (modified) llvm/lib/Analysis/InlineAdvisor.cpp (+29)
  • (modified) llvm/lib/Analysis/MLInlineAdvisor.cpp (+33-1)
  • (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+154-25)
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index babb6d9d6cf0c..06dbfc35a5294 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
   void updateAggregateStats(const Function &F, const LoopInfo &LI);
   void reIncludeBB(const BasicBlock &BB);
 
+  ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
+  std::optional<ir2vec::Vocab> IR2VecVocab;
+
 public:
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
-                            const LoopInfo &LI);
+                            const LoopInfo &LI,
+                            const IR2VecVocabResult *VocabResult);
 
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
 
-  bool operator==(const FunctionPropertiesInfo &FPI) const {
-    return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
-  }
+  bool operator==(const FunctionPropertiesInfo &FPI) const;
 
   bool operator!=(const FunctionPropertiesInfo &FPI) const {
     return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
   int64_t CallReturnsVectorPointerCount = 0;
   int64_t CallWithManyArgumentsCount = 0;
   int64_t CallWithPointerArgumentCount = 0;
+
+  const ir2vec::Embedding &getFunctionEmbedding() const {
+    return FunctionEmbedding;
+  }
+
+  const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
+    return IR2VecVocab;
+  }
+
+  // Helper intended to be useful for unittests
+  void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
+    FunctionEmbedding = Embedding;
+  }
 };
 
 // Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
 
   DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
 
-  DenseSet<const BasicBlock *> Successors;
+  DenseSet<const BasicBlock *> Successors, CallUsers;
 
   // Edges we might potentially need to remove from the dominator tree.
   SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 7976cc7470d5b..3f44c650e640d 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
-  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
+  explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab);
   using Result = IR2VecVocabResult;
   Result run(Module &M, ModuleAnalysisManager &MAM);
 };
diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
index 9d15136e81d10..d2cad4717cbdb 100644
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -331,6 +331,9 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
   };
 
   Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
+
+private:
+  static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM);
 };
 
 /// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 961d5091bf9f3..91d3378565fc5 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -142,6 +142,10 @@ enum class FeatureIndex : size_t {
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 
+// IR2Vec embeddings
+  callee_embedding,
+  caller_embedding,
+
   NumberOfFeatures
 };
 // clang-format on
@@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
 constexpr size_t NumberOfFeatures =
     static_cast<size_t>(FeatureIndex::NumberOfFeatures);
 
-LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
+LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
 
 LLVM_ABI extern const char *const DecisionName;
 LLVM_ABI extern const TensorSpec InlineDecisionSpec;
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
index 580dd5e95d760..935e4c56dfce6 100644
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
   int64_t NodeCount = 0;
   int64_t EdgeCount = 0;
   int64_t EdgesOfLastSeenNodes = 0;
+  bool UseIR2Vec = false;
 
   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
   const int32_t InitialIRSize = 0;
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 9d044c8a35910..29d3aaf46dc06 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
 #undef CHECK_OPERAND
     }
   }
+
+  if (IR2VecVocab) {
+    // We instantiate the IR2Vec embedder each time, as having a unique
+    // pointer to the embedder as a member of the class would make it
+    // non-copyable. Instantiating the embedder itself is not costly.
+    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+                                             *BB.getParent(), *IR2VecVocab);
+    if (Error Err = EmbOrErr.takeError()) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
+                                  EI.message());
+      });
+      return;
+    }
+    auto Embedder = std::move(*EmbOrErr);
+    const auto &BBEmbedding = Embedder->getBBVector(BB);
+    // Subtract BBEmbedding from Function embedding if the direction is -1,
+    // and add it if the direction is +1.
+    if (Direction == -1)
+      FunctionEmbedding -= BBEmbedding;
+    else
+      FunctionEmbedding += BBEmbedding;
+  }
 }
 
 void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
     Function &F, FunctionAnalysisManager &FAM) {
+  // We use the cached result of the IR2VecVocabAnalysis run by
+  // InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
+  // use IR2Vec embeddings.
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
-                                   FAM.getResult<LoopAnalysis>(F));
+                                   FAM.getResult<LoopAnalysis>(F), VocabResult);
 }
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
-    const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
+    const Function &F, const DominatorTree &DT, const LoopInfo &LI,
+    const IR2VecVocabResult *VocabResult) {
 
   FunctionPropertiesInfo FPI;
+  if (VocabResult && VocabResult->isValid()) {
+    FPI.IR2VecVocab = VocabResult->getVocabulary();
+    FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+  }
   for (const auto &BB : F)
     if (DT.isReachableFromEntry(&BB))
       FPI.reIncludeBB(BB);
@@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
   return FPI;
 }
 
+bool FunctionPropertiesInfo::operator==(
+    const FunctionPropertiesInfo &FPI) const {
+  if (BasicBlockCount != FPI.BasicBlockCount ||
+      BlocksReachedFromConditionalInstruction !=
+          FPI.BlocksReachedFromConditionalInstruction ||
+      Uses != FPI.Uses ||
+      DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
+      LoadInstCount != FPI.LoadInstCount ||
+      StoreInstCount != FPI.StoreInstCount ||
+      MaxLoopDepth != FPI.MaxLoopDepth ||
+      TopLevelLoopCount != FPI.TopLevelLoopCount ||
+      TotalInstructionCount != FPI.TotalInstructionCount ||
+      BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
+      BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
+      BasicBlocksWithMoreThanTwoSuccessors !=
+          FPI.BasicBlocksWithMoreThanTwoSuccessors ||
+      BasicBlocksWithSinglePredecessor !=
+          FPI.BasicBlocksWithSinglePredecessor ||
+      BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
+      BasicBlocksWithMoreThanTwoPredecessors !=
+          FPI.BasicBlocksWithMoreThanTwoPredecessors ||
+      BigBasicBlocks != FPI.BigBasicBlocks ||
+      MediumBasicBlocks != FPI.MediumBasicBlocks ||
+      SmallBasicBlocks != FPI.SmallBasicBlocks ||
+      CastInstructionCount != FPI.CastInstructionCount ||
+      FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
+      IntegerInstructionCount != FPI.IntegerInstructionCount ||
+      ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
+      ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
+      ConstantOperandCount != FPI.ConstantOperandCount ||
+      InstructionOperandCount != FPI.InstructionOperandCount ||
+      BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
+      GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
+      InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
+      ArgumentOperandCount != FPI.ArgumentOperandCount ||
+      UnknownOperandCount != FPI.UnknownOperandCount ||
+      CriticalEdgeCount != FPI.CriticalEdgeCount ||
+      ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
+      UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
+      IntrinsicCount != FPI.IntrinsicCount ||
+      DirectCallCount != FPI.DirectCallCount ||
+      IndirectCallCount != FPI.IndirectCallCount ||
+      CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
+      CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
+      CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
+      CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
+      CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
+      CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
+      CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
+      CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
+    return false;
+  }
+  // Check the equality of the function embeddings. We don't check the equality
+  // of Vocabulary as it remains the same.
+  if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
+    return false;
+
+  return true;
+}
+
 void FunctionPropertiesInfo::print(raw_ostream &OS) const {
 #define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
 
@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
   // The caller's entry BB may change due to new alloca instructions.
   LikelyToChangeBBs.insert(&*Caller.begin());
 
+  // The users of the value returned by the call instruction can change,
+  // which changes the embeddings being computed, when used.
+  // We conservatively add the BBs with such uses to LikelyToChangeBBs.
+  for (const auto *User : CB.users())
+    CallUsers.insert(cast<Instruction>(User)->getParent());
+  // CallSiteBB can be removed from CallUsers if present; it's taken care of
+  // separately.
+  CallUsers.erase(&CallSiteBB);
+  LikelyToChangeBBs.insert_range(CallUsers);
+
   // The successors may become unreachable in the case of `invoke` inlining.
   // We track successors separately, too, because they form a boundary, together
   // with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
   if (&CallSiteBB != &*Caller.begin())
     Reinclude.insert(&*Caller.begin());
 
+  // Reinclude the BBs which use the values returned by call instruction
+  Reinclude.insert_range(CallUsers);
+
   // Distribute the successors to the 2 buckets.
   for (const auto *Succ : Successors)
     if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
     return false;
   DominatorTree DT(F);
   LoopInfo LI(DT);
-  auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Fresh =
+      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
   return FPI == Fresh;
 }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 8a392e0709c7f..5f2245ad6aafb 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -294,8 +294,8 @@ Error IR2VecVocabAnalysis::readVocabulary() {
   return Error::success();
 }
 
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)) {}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
+    : Vocabulary(Vocabulary) {}
 
 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
   handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 3d30f3d10a9d0..2e869dfd91713 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/EphemeralValuesCache.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
                         cl::desc("If true, annotate inline advisor remarks "
                                  "with LTO and pass information."));
 
+// This flag is used to enable IR2Vec embeddings in the ML inliner; it is only
+// valid with the ML inliner. The vocab file is used to initialize the embeddings.
+static cl::opt<std::string> IR2VecVocabFile(
+    "ml-inliner-ir2vec-vocab-file", cl::Hidden,
+    cl::desc("Vocab file for IR2Vec; Setting this enables "
+             "configuring the model to use IR2Vec embeddings."));
+
 namespace llvm {
 extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
 } // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
 AnalysisKey InlineAdvisorAnalysis::Key;
 AnalysisKey PluginInlineAdvisorAnalysis::Key;
 
+bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M,
+                                                  ModuleAnalysisManager &MAM) {
+  if (!IR2VecVocabFile.empty()) {
+    auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+    if (!IR2VecVocabResult.isValid()) {
+      M.getContext().emitError("Failed to load IR2Vec vocabulary");
+      return false;
+    }
+  }
+  // Not specifying a vocab file is OK; we just don't use IR2Vec
+  // embeddings.
+  return true;
+}
+
 bool InlineAdvisorAnalysis::Result::tryCreate(
     InlineParams Params, InliningAdvisorMode Mode,
     const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
                                              /* EmitRemarks =*/true, IC);
     }
     break;
+    // Run IR2VecVocabAnalysis once per module to get the vocabulary.
+    // We run it here because it is immutable and we want to avoid running it
+    // multiple times.
   case InliningAdvisorMode::Development:
 #ifdef LLVM_HAVE_TFLITE
     LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
 #endif
     break;
   case InliningAdvisorMode::Release:
     LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
     break;
   }
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 81a3bc94a6ad8..3a9a68670e852 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
     cl::init(false));
 
 // clang-format off
-const std::vector<TensorSpec> llvm::FeatureMap{
+std::vector<TensorSpec> llvm::FeatureMap{
 #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
 // InlineCost features - these must come first
   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -186,6 +186,20 @@ MLInlineAdvisor::MLInlineAdvisor(
     EdgeCount += getLocalCalls(KVP.first->getFunction());
   }
   NodeCount = AllNodes.size();
+
+  if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
+    if (!IR2VecVocabResult->isValid()) {
+      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
+      return;
+    }
+    // Add the IR2Vec features to the feature map
+    auto IR2VecDim = IR2VecVocabResult->getDimension();
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
+    UseIR2Vec = true;
+  }
 }
 
 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
       Caller.hasAvailableExternallyLinkage();
 
+  if (UseIR2Vec) {
+    // Python side expects float embeddings. The IR2Vec embeddings are doubles
+    // as of now due to the restriction of fromJSON method used by the
+    // readVocabulary method in ir2vec::Embeddings.
+    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
+                            FeatureIndex Index) {
+      auto Embedding_float =
+          std::vector<float>(Embedding.begin(), Embedding.end());
+      std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
+                  Embedding.size() * sizeof(float));
+    };
+
+    setEmbedding(CalleeBefore.getFunctionEmbedding(),
+                 FeatureIndex::callee_embedding);
+    setEmbedding(CallerBefore.getFunctionEmbedding(),
+                 FeatureIndex::caller_embedding);
+  }
+
   // Add the cost features
   for (size_t I = 0;
        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index 0720d935b0362..3ef2964f2d170 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Dominators.h"
@@ -20,15 +21,20 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <cstring>
 
 using namespace llvm;
+using namespace testing;
 
 namespace llvm {
 LLVM_ABI extern cl::opt<bool> EnableDetailedFunctionProperties;
 LLVM_ABI extern cl::opt<bool> BigBasicBlockInstructionThreshold;
 LLVM_ABI extern cl::opt<bool> MediumBasicBlockInstrutionThreshold;
+LLVM_ABI extern cl::opt<float> ir2vec::OpcWeight;
+LLVM_ABI extern cl::opt<float> ir2vec::TypeWeight;
+LLVM_ABI...
[truncated]
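The FunctionPropertiesAnalysis hunks above replace the memcmp-based operator== with a member-wise comparison. A likely reason (an inference from the diff, not something the author states) is that the class now holds an embedding and an optional vocabulary, whose storage lives on the heap, so comparing raw bytes would compare pointers rather than contents. The sketch below illustrates the idea with hypothetical stand-ins: `Props`, `approximatelyEqual`, and the `1e-6` tolerance are assumptions for illustration and are not the LLVM types or the tolerance used by `ir2vec::Embedding::approximatelyEquals`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for FunctionPropertiesInfo: one scalar counter plus
// an embedding whose elements live on the heap.
struct Props {
  int64_t BasicBlockCount;
  std::vector<double> Embedding;
};

// Element-wise comparison with an absolute tolerance (the tolerance value is
// an assumption made for this sketch).
bool approximatelyEqual(const std::vector<double> &A,
                        const std::vector<double> &B, double Tol = 1e-6) {
  assert(Tol >= 0.0 && "tolerance must be non-negative");
  if (A.size() != B.size())
    return false;
  for (std::size_t I = 0; I < A.size(); ++I)
    if (std::fabs(A[I] - B[I]) > Tol)
      return false;
  return true;
}

// Member-wise equality: exact for the counter, approximate for the embedding.
// A memcmp over the whole struct would instead compare the vectors' internal
// pointers, which differ even when the contents match.
bool operator==(const Props &L, const Props &R) {
  return L.BasicBlockCount == R.BasicBlockCount &&
         approximatelyEqual(L.Embedding, R.Embedding);
}
```

With this shape, two objects built independently from the same function compare equal even though their embeddings are stored at different addresses, which is what a check such as isUpdateValid needs.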

@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
   void updateAggregateStats(const Function &F, const LoopInfo &LI);
   void reIncludeBB(const BasicBlock &BB);
 
+  ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
Contributor

Why do we need to plumb this through FunctionPropertiesAnalysis instead of directly depending on the IR2Vec Analysis pass in MLInliner? I think the separation would keep things a bit cleaner unless there's a constraint compelling this design.

Member

Because the patch-updating of IR2Vec embeddings follows the same contours as FPA. Also, eventually, I'd hope we can remove at least the expensive, manually computed features from FPA (those depending on LoopInfo, for instance).
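The patch-updating mentioned here can be sketched in isolation. In the diff, updateForBB adds a basic block's embedding to the function embedding when the direction is +1 and subtracts it when the direction is -1, so a changed block is handled by removing its old contribution and re-including the new one. The code below is a minimal sketch under that reading; `Embedding` is a plain `std::vector<double>` standing in for `ir2vec::Embedding`, and `applyBlock` and `replaceBlock` are hypothetical helper names, not functions from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ir2vec::Embedding.
using Embedding = std::vector<double>;

// Adds (Direction == +1) or subtracts (Direction == -1) one basic block's
// embedding from the running function embedding.
void applyBlock(Embedding &FunctionEmb, const Embedding &BlockEmb,
                int Direction) {
  assert(FunctionEmb.size() == BlockEmb.size() && "dimension mismatch");
  assert((Direction == 1 || Direction == -1) && "direction must be +1 or -1");
  for (std::size_t I = 0; I < FunctionEmb.size(); ++I)
    FunctionEmb[I] += Direction * BlockEmb[I];
}

// Patch-update for a block that changed: drop the stale contribution, then
// add the fresh one, without revisiting any other block of the function.
void replaceBlock(Embedding &FunctionEmb, const Embedding &OldBlock,
                  const Embedding &NewBlock) {
  applyBlock(FunctionEmb, OldBlock, -1);
  applyBlock(FunctionEmb, NewBlock, +1);
}
```

Because the function embedding is a sum over blocks, the cost of an update depends only on the blocks that changed, which is the same shape as the counter updates FPA already performs.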

float OriginalArgWeight = ir2vec::ArgWeight;

void createTestVocabulary(unsigned Dim) {
Vocabulary["add"] = ir2vec::Embedding(Dim, 0.1);
Contributor

Can you load this from one of the test JSON files rather than explicitly creating it here?

Contributor Author

Currently, the constructor of the vocab analysis can only take a Vocabulary map, not JSON; using JSON requires going through flags. This constructor design will help in later patches, where the vocab map will be autogenerated at build time.

Member

You could populate it with a macro, which is easier to read. For example:

#define ENTRY(NAME, DEFAULT_VALUE) Vocabulary[#NAME] = ir2vec::Embedding(Dim, DEFAULT_VALUE);

ENTRY(add, 0.1)
ENTRY(sub, 0.2)
...

Contributor Author

Simplified it using a lambda. Please let me know if this is better. (I thought this would reduce the number of lines.)
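For readers following the thread, the lambda variant could look roughly like the sketch below. It is not the code from the updated patch, which is not shown on this page. `Embedding` and `Vocab` are standard-library stand-ins for the ir2vec types, `AddEntry` is a hypothetical name, and only the "add" and "sub" entries with their values come from the discussion above.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for ir2vec::Embedding and ir2vec::Vocab.
using Embedding = std::vector<double>;
using Vocab = std::map<std::string, Embedding>;

// Builds a small test vocabulary in which every entry is a Dim-dimensional
// vector filled with a single value.
Vocab createTestVocabulary(unsigned Dim) {
  assert(Dim > 0 && "embedding dimension must be positive");
  Vocab Vocabulary;
  // The lambda plays the role of the ENTRY macro suggested above: one short
  // line per vocabulary entry, without a preprocessor definition.
  auto AddEntry = [&](const std::string &Name, double Value) {
    Vocabulary[Name] = Embedding(Dim, Value);
  };
  AddEntry("add", 0.1);
  AddEntry("sub", 0.2);
  return Vocabulary;
}
```

Compared with the macro, the lambda keeps the entries type-checked and scoped to the function, at the cost of quoting each name as a string.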

@svkeerthy svkeerthy changed the title [MLIniner][IR2Vec] Integrating IR2Vec with MLInliner [MLInliner][IR2Vec] Integrating IR2Vec with MLInliner Jun 10, 2025
@svkeerthy svkeerthy changed the title [MLInliner][IR2Vec] Integrating IR2Vec with MLInliner [MLGO][IR2Vec] Integrating IR2Vec with MLInliner Jun 10, 2025
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)) {}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
+    : Vocabulary(Vocabulary) {}
Member

@mtrofin mtrofin Jun 10, 2025

Why?

Can this change be made in a precursor PR?

Or can you pass by reference?

Contributor Author

Yes, moved it to an earlier PR (#143200).
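As background to the question above, the usual trade-off between an rvalue-reference parameter, a const reference, and a by-value parameter can be sketched as follows. The example is generic C++ rather than the code that landed in #143200. `VocabHolder` is a hypothetical class and `Vocab` is a standard-library stand-in for `ir2vec::Vocab`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir2vec::Vocab.
using Vocab = std::map<std::string, std::vector<double>>;

class VocabHolder {
  Vocab Vocabulary;

public:
  // By-value "sink" parameter: an lvalue argument is copied into V and an
  // rvalue argument is moved into V. The std::move below then transfers V
  // into the member, so no second copy is made in either case.
  explicit VocabHolder(Vocab V) : Vocabulary(std::move(V)) {}

  const Vocab &get() const { return Vocabulary; }
};
```

A constructor taking `Vocab &&` accepts only rvalues, so a caller that wants to keep its map must copy explicitly. A `const Vocab &` parameter accepts both but always copies. The by-value form accepts both and copies only when the caller passes an lvalue, provided the member is initialized with std::move.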

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from 79060df to abee4cd Compare June 10, 2025 18:14
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from 58c85ec to 6085399 Compare June 10, 2025 18:14
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from abee4cd to cd2cdf5 Compare June 10, 2025 19:54
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from 6085399 to 579c3c0 Compare June 10, 2025 19:54
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from cd2cdf5 to 716f3d2 Compare June 10, 2025 21:22
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from 579c3c0 to b40003f Compare June 10, 2025 21:22
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from 716f3d2 to 173c3b1 Compare June 10, 2025 22:14
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from b40003f to 5d7a2b0 Compare June 10, 2025 22:15
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from 5d7a2b0 to b7ec652 Compare June 12, 2025 21:48
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from 173c3b1 to 6e44bb0 Compare June 12, 2025 21:48
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from 6e44bb0 to 40402d2 Compare June 13, 2025 17:44
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from b7ec652 to d151083 Compare June 13, 2025 17:45
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-reachable_bb branch from 40402d2 to ece3e21 Compare June 13, 2025 18:18
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from d151083 to a2bec77 Compare June 13, 2025 18:18
Base automatically changed from users/svkeerthy/06-10-reachable_bb to main June 17, 2025 17:57
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-10-_mlininer_ir2vec_integrating_ir2vec_with_mlinliner branch from a2bec77 to 9124e83 Compare June 17, 2025 18:00
Labels
llvm:analysis Includes value tracking, cost tables and constant folding mlgo
Development

Successfully merging this pull request may close these issues.

[MLGO][IR2Vec] IR2Vec embeddings for MLInliner
4 participants