-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLGO][IR2Vec] Integrating IR2Vec with MLInliner #143479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLGO][IR2Vec] Integrating IR2Vec with MLInliner #143479
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesChanges to use Symbolic embeddings in MLInliner. (Fixes #141836, Tracking issue - #141817) Patch is 36.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143479.diff 10 Files Affected:
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index babb6d9d6cf0c..06dbfc35a5294 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -15,6 +15,7 @@
#define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
void updateAggregateStats(const Function &F, const LoopInfo &LI);
void reIncludeBB(const BasicBlock &BB);
+ ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
+ std::optional<ir2vec::Vocab> IR2VecVocab;
+
public:
LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
- const LoopInfo &LI);
+ const LoopInfo &LI,
+ const IR2VecVocabResult *VocabResult);
LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
- bool operator==(const FunctionPropertiesInfo &FPI) const {
- return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
- }
+ bool operator==(const FunctionPropertiesInfo &FPI) const;
bool operator!=(const FunctionPropertiesInfo &FPI) const {
return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
int64_t CallReturnsVectorPointerCount = 0;
int64_t CallWithManyArgumentsCount = 0;
int64_t CallWithPointerArgumentCount = 0;
+
+ const ir2vec::Embedding &getFunctionEmbedding() const {
+ return FunctionEmbedding;
+ }
+
+ const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
+ return IR2VecVocab;
+ }
+
+ // Helper intended to be useful for unittests
+ void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
+ FunctionEmbedding = Embedding;
+ }
};
// Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
- DenseSet<const BasicBlock *> Successors;
+ DenseSet<const BasicBlock *> Successors, CallUsers;
// Edges we might potentially need to remove from the dominator tree.
SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 7976cc7470d5b..3f44c650e640d 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
public:
static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
- explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
+ explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab);
using Result = IR2VecVocabResult;
Result run(Module &M, ModuleAnalysisManager &MAM);
};
diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
index 9d15136e81d10..d2cad4717cbdb 100644
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -331,6 +331,9 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
};
Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
+
+private:
+ static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM);
};
/// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 961d5091bf9f3..91d3378565fc5 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -142,6 +142,10 @@ enum class FeatureIndex : size_t {
INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
#undef POPULATE_INDICES
+// IR2Vec embeddings
+ callee_embedding,
+ caller_embedding,
+
NumberOfFeatures
};
// clang-format on
@@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
constexpr size_t NumberOfFeatures =
static_cast<size_t>(FeatureIndex::NumberOfFeatures);
-LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
+LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
LLVM_ABI extern const char *const DecisionName;
LLVM_ABI extern const TensorSpec InlineDecisionSpec;
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
index 580dd5e95d760..935e4c56dfce6 100644
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
int64_t NodeCount = 0;
int64_t EdgeCount = 0;
int64_t EdgesOfLastSeenNodes = 0;
+ bool UseIR2Vec = false;
std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
const int32_t InitialIRSize = 0;
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 9d044c8a35910..29d3aaf46dc06 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
#undef CHECK_OPERAND
}
}
+
+ if (IR2VecVocab) {
+ // We instantiate the IR2Vec embedder each time, as having an unique
+ // pointer to the embedder as member of the class would make it
+ // non-copyable. Instantiating the embedder in itself is not costly.
+ auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+ *BB.getParent(), *IR2VecVocab);
+ if (Error Err = EmbOrErr.takeError()) {
+ handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+ BB.getContext().emitError("Error creating IR2Vec embeddings: " +
+ EI.message());
+ });
+ return;
+ }
+ auto Embedder = std::move(*EmbOrErr);
+ const auto &BBEmbedding = Embedder->getBBVector(BB);
+ // Subtract BBEmbedding from Function embedding if the direction is -1,
+ // and add it if the direction is +1.
+ if (Direction == -1)
+ FunctionEmbedding -= BBEmbedding;
+ else
+ FunctionEmbedding += BBEmbedding;
+ }
}
void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
Function &F, FunctionAnalysisManager &FAM) {
+ // We use the cached result of the IR2VecVocabAnalysis run by
+ // InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
+ // use IR2Vec embeddings.
+ auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+ .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
- FAM.getResult<LoopAnalysis>(F));
+ FAM.getResult<LoopAnalysis>(F), VocabResult);
}
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
- const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
+ const Function &F, const DominatorTree &DT, const LoopInfo &LI,
+ const IR2VecVocabResult *VocabResult) {
FunctionPropertiesInfo FPI;
+ if (VocabResult && VocabResult->isValid()) {
+ FPI.IR2VecVocab = VocabResult->getVocabulary();
+ FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+ }
for (const auto &BB : F)
if (DT.isReachableFromEntry(&BB))
FPI.reIncludeBB(BB);
@@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
return FPI;
}
+bool FunctionPropertiesInfo::operator==(
+ const FunctionPropertiesInfo &FPI) const {
+ if (BasicBlockCount != FPI.BasicBlockCount ||
+ BlocksReachedFromConditionalInstruction !=
+ FPI.BlocksReachedFromConditionalInstruction ||
+ Uses != FPI.Uses ||
+ DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
+ LoadInstCount != FPI.LoadInstCount ||
+ StoreInstCount != FPI.StoreInstCount ||
+ MaxLoopDepth != FPI.MaxLoopDepth ||
+ TopLevelLoopCount != FPI.TopLevelLoopCount ||
+ TotalInstructionCount != FPI.TotalInstructionCount ||
+ BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
+ BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
+ BasicBlocksWithMoreThanTwoSuccessors !=
+ FPI.BasicBlocksWithMoreThanTwoSuccessors ||
+ BasicBlocksWithSinglePredecessor !=
+ FPI.BasicBlocksWithSinglePredecessor ||
+ BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
+ BasicBlocksWithMoreThanTwoPredecessors !=
+ FPI.BasicBlocksWithMoreThanTwoPredecessors ||
+ BigBasicBlocks != FPI.BigBasicBlocks ||
+ MediumBasicBlocks != FPI.MediumBasicBlocks ||
+ SmallBasicBlocks != FPI.SmallBasicBlocks ||
+ CastInstructionCount != FPI.CastInstructionCount ||
+ FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
+ IntegerInstructionCount != FPI.IntegerInstructionCount ||
+ ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
+ ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
+ ConstantOperandCount != FPI.ConstantOperandCount ||
+ InstructionOperandCount != FPI.InstructionOperandCount ||
+ BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
+ GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
+ InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
+ ArgumentOperandCount != FPI.ArgumentOperandCount ||
+ UnknownOperandCount != FPI.UnknownOperandCount ||
+ CriticalEdgeCount != FPI.CriticalEdgeCount ||
+ ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
+ UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
+ IntrinsicCount != FPI.IntrinsicCount ||
+ DirectCallCount != FPI.DirectCallCount ||
+ IndirectCallCount != FPI.IndirectCallCount ||
+ CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
+ CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
+ CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
+ CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
+ CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
+ CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
+ CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
+ CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
+ return false;
+ }
+ // Check the equality of the function embeddings. We don't check the equality
+ // of Vocabulary as it remains the same.
+ if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
+ return false;
+
+ return true;
+}
+
void FunctionPropertiesInfo::print(raw_ostream &OS) const {
#define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
// The caller's entry BB may change due to new alloca instructions.
LikelyToChangeBBs.insert(&*Caller.begin());
+ // The users of the value returned by call instruction can change
+ // leading to the change in embeddings being computed, when used.
+ // We conservatively add the BBs with such uses to LikelyToChangeBBs.
+ for (const auto *User : CB.users())
+ CallUsers.insert(dyn_cast<Instruction>(User)->getParent());
+ // CallSiteBB can be removed from CallUsers if present, it's taken care
+ // separately.
+ CallUsers.erase(&CallSiteBB);
+ LikelyToChangeBBs.insert_range(CallUsers);
+
// The successors may become unreachable in the case of `invoke` inlining.
// We track successors separately, too, because they form a boundary, together
// with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
if (&CallSiteBB != &*Caller.begin())
Reinclude.insert(&*Caller.begin());
+ // Reinclude the BBs which use the values returned by call instruction
+ Reinclude.insert_range(CallUsers);
+
// Distribute the successors to the 2 buckets.
for (const auto *Succ : Successors)
if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
return false;
DominatorTree DT(F);
LoopInfo LI(DT);
- auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
+ auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+ .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+ auto Fresh =
+ FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
return FPI == Fresh;
}
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 8a392e0709c7f..5f2245ad6aafb 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -294,8 +294,8 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return Error::success();
}
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
- : Vocabulary(std::move(Vocabulary)) {}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
+ : Vocabulary(Vocabulary) {}
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 3d30f3d10a9d0..2e869dfd91713 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/EphemeralValuesCache.h"
+#include "llvm/Analysis/IR2Vec.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
cl::desc("If true, annotate inline advisor remarks "
"with LTO and pass information."));
+// This flag is used to enable IR2Vec embeddings in the ML inliner; Only valid
+// with ML inliner. The vocab file is used to initialize the embeddings.
+static cl::opt<std::string> IR2VecVocabFile(
+ "ml-inliner-ir2vec-vocab-file", cl::Hidden,
+ cl::desc("Vocab file for IR2Vec; Setting this enables "
+ "configuring the model to use IR2Vec embeddings."));
+
namespace llvm {
extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
} // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
AnalysisKey InlineAdvisorAnalysis::Key;
AnalysisKey PluginInlineAdvisorAnalysis::Key;
+bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M,
+ ModuleAnalysisManager &MAM) {
+ if (!IR2VecVocabFile.empty()) {
+ auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+ if (!IR2VecVocabResult.isValid()) {
+ M.getContext().emitError("Failed to load IR2Vec vocabulary");
+ return false;
+ }
+ }
+ // No vocab file specified is OK; We just don't use IR2Vec
+ // embeddings.
+ return true;
+}
+
bool InlineAdvisorAnalysis::Result::tryCreate(
InlineParams Params, InliningAdvisorMode Mode,
const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
/* EmitRemarks =*/true, IC);
}
break;
+ // Run IR2VecVocabAnalysis once per module to get the vocabulary.
+ // We run it here because it is immutable and we want to avoid running it
+ // multiple times.
case InliningAdvisorMode::Development:
#ifdef LLVM_HAVE_TFLITE
LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
+ if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+ return false;
Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
#endif
break;
case InliningAdvisorMode::Release:
LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
+ if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+ return false;
Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
break;
}
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 81a3bc94a6ad8..3a9a68670e852 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
cl::init(false));
// clang-format off
-const std::vector<TensorSpec> llvm::FeatureMap{
+std::vector<TensorSpec> llvm::FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
// InlineCost features - these must come first
INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -186,6 +186,20 @@ MLInlineAdvisor::MLInlineAdvisor(
EdgeCount += getLocalCalls(KVP.first->getFunction());
}
NodeCount = AllNodes.size();
+
+ if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
+ if (!IR2VecVocabResult->isValid()) {
+ M.getContext().emitError("IR2VecVocabAnalysis is not valid");
+ return;
+ }
+ // Add the IR2Vec features to the feature map
+ auto IR2VecDim = IR2VecVocabResult->getDimension();
+ FeatureMap.push_back(
+ TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
+ FeatureMap.push_back(
+ TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
+ UseIR2Vec = true;
+ }
}
unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
*ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
Caller.hasAvailableExternallyLinkage();
+ if (UseIR2Vec) {
+ // Python side expects float embeddings. The IR2Vec embeddings are doubles
+ // as of now due to the restriction of fromJSON method used by the
+ // readVocabulary method in ir2vec::Embeddings.
+ auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
+ FeatureIndex Index) {
+ auto Embedding_float =
+ std::vector<float>(Embedding.begin(), Embedding.end());
+ std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
+ Embedding.size() * sizeof(float));
+ };
+
+ setEmbedding(CalleeBefore.getFunctionEmbedding(),
+ FeatureIndex::callee_embedding);
+ setEmbedding(CallerBefore.getFunctionEmbedding(),
+ FeatureIndex::caller_embedding);
+ }
+
// Add the cost features
for (size_t I = 0;
I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index 0720d935b0362..3ef2964f2d170 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -8,6 +8,7 @@
#include "llvm/Analysis/FunctionPropertiesAnalysis.h"
#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IR2Vec.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Dominators.h"
@@ -20,15 +21,20 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Utils/Cloning.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <cstring>
using namespace llvm;
+using namespace testing;
namespace llvm {
LLVM_ABI extern cl::opt<bool> EnableDetailedFunctionProperties;
LLVM_ABI extern cl::opt<bool> BigBasicBlockInstructionThreshold;
LLVM_ABI extern cl::opt<bool> MediumBasicBlockInstrutionThreshold;
+LLVM_ABI extern cl::opt<float> ir2vec::OpcWeight;
+LLVM_ABI extern cl::opt<float> ir2vec::TypeWeight;
+LLVM_ABI...
[truncated]
|
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo { | |||
void updateAggregateStats(const Function &F, const LoopInfo &LI); | |||
void reIncludeBB(const BasicBlock &BB); | |||
|
|||
ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to plumb this through FunctionPropertiesAnalysis
instead of directly depending on the IR2Vec Analysis pass in MLInliner? I think the separation would keep things a bit cleaner unless there's a constraint compelling this design.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because the patch-updating of IR2Vec embeddings uses the same contours as FPA. Also, eventually, I'd hope we can remove at least the expensive manually-computed features from FPA (those depending on LoopInfo for instance)
float OriginalArgWeight = ir2vec::ArgWeight; | ||
|
||
void createTestVocabulary(unsigned Dim) { | ||
Vocabulary["add"] = ir2vec::Embedding(Dim, 0.1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you load this from one of the test JSON files rather than explicitly creating it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently the constructor of vocab analysis could only take in Vocabulary map and not JSON. It has to be via flags if we need to use JSON. This design of constructor would help in later patches where vocab map would be autogenerated during the build time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could populate it with a macro, easier to read. Like
#define ENTRY(NAME, DEFAULT_VALUE) Vocabulary[#NAME] = ir2vec::Embedding(Dim, DEFAULT_VALUE);
ENTRY(add, 0.1)
ENTRY(sub, 0.2)
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified it using lambda. Please let me know if this is better. (Just thought this would reduce the number of lines)
llvm/lib/Analysis/IR2Vec.cpp
Outdated
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary) | ||
: Vocabulary(std::move(Vocabulary)) {} | ||
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary) | ||
: Vocabulary(Vocabulary) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why?
can this change be made in a precursor PR?
or can you pass by reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, moved it to earlier PR (#143200)
79060df
to
abee4cd
Compare
58c85ec
to
6085399
Compare
abee4cd
to
cd2cdf5
Compare
6085399
to
579c3c0
Compare
cd2cdf5
to
716f3d2
Compare
579c3c0
to
b40003f
Compare
716f3d2
to
173c3b1
Compare
b40003f
to
5d7a2b0
Compare
5d7a2b0
to
b7ec652
Compare
173c3b1
to
6e44bb0
Compare
6e44bb0
to
40402d2
Compare
b7ec652
to
d151083
Compare
40402d2
to
ece3e21
Compare
d151083
to
a2bec77
Compare
a2bec77
to
9124e83
Compare
9124e83
to
c157323
Compare
Changes to use Symbolic embeddings in MLInliner.
(Fixes #141836, Tracking issue - #141817)