Skip to content

[MLGO][IR2Vec] Integrating IR2Vec with MLInliner #143479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H

#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Compiler.h"
Expand All @@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
void updateAggregateStats(const Function &F, const LoopInfo &LI);
void reIncludeBB(const BasicBlock &BB);

ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to plumb this through FunctionPropertiesAnalysis instead of directly depending on the IR2Vec Analysis pass in MLInliner? I think the separation would keep things a bit cleaner unless there's a constraint compelling this design.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the patch-updating of IR2Vec embeddings uses the same contours as FPA. Also, eventually, I'd hope we can remove at least the expensive manually-computed features from FPA (those depending on LoopInfo for instance)

std::optional<ir2vec::Vocab> IR2VecVocab;

public:
LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
const LoopInfo &LI);
const LoopInfo &LI,
const IR2VecVocabResult *VocabResult);

LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);

bool operator==(const FunctionPropertiesInfo &FPI) const {
return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
}
bool operator==(const FunctionPropertiesInfo &FPI) const;

bool operator!=(const FunctionPropertiesInfo &FPI) const {
return !(*this == FPI);
Expand Down Expand Up @@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
int64_t CallReturnsVectorPointerCount = 0;
int64_t CallWithManyArgumentsCount = 0;
int64_t CallWithPointerArgumentCount = 0;

const ir2vec::Embedding &getFunctionEmbedding() const {
return FunctionEmbedding;
}

const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
return IR2VecVocab;
}

// Helper intended to be useful for unittests
void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
FunctionEmbedding = Embedding;
}
};

// Analysis pass
Expand Down Expand Up @@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {

DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;

DenseSet<const BasicBlock *> Successors;
DenseSet<const BasicBlock *> Successors, CallUsers;

// Edges we might potentially need to remove from the dominator tree.
SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/InlineAdvisor.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
};

Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }

private:
static bool initializeIR2VecVocabIfRequested(Module &M,
ModuleAnalysisManager &MAM);
};

/// Printer pass for the InlineAdvisorAnalysis results.
Expand Down
8 changes: 7 additions & 1 deletion llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ enum class FeatureIndex : size_t {
INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
#undef POPULATE_INDICES

// IR2Vec embeddings
// Dimensions of embeddings are not known at compile time (until the vocab
// is read). Hence macros cannot be used here.
callee_embedding,
caller_embedding,

NumberOfFeatures
};
// clang-format on
Expand All @@ -154,7 +160,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
constexpr size_t NumberOfFeatures =
static_cast<size_t>(FeatureIndex::NumberOfFeatures);

LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
LLVM_ABI extern std::vector<TensorSpec> FeatureMap;

LLVM_ABI extern const char *const DecisionName;
LLVM_ABI extern const TensorSpec InlineDecisionSpec;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/MLInlineAdvisor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
int64_t NodeCount = 0;
int64_t EdgeCount = 0;
int64_t EdgesOfLastSeenNodes = 0;
const bool UseIR2Vec;

std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
const int32_t InitialIRSize = 0;
Expand Down
115 changes: 112 additions & 3 deletions llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
#undef CHECK_OPERAND
}
}

if (IR2VecVocab) {
// We instantiate the IR2Vec embedder each time, as holding a unique
// pointer to the embedder as a member of the class would make it
// non-copyable. Instantiating the embedder itself is not costly.
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
if (Error Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
EI.message());
});
return;
}
auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.
if (Direction == -1)
FunctionEmbedding -= BBEmbedding;
else
FunctionEmbedding += BBEmbedding;
}
}

void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
Expand All @@ -220,21 +243,91 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,

FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
Function &F, FunctionAnalysisManager &FAM) {
// We use the cached result of the IR2VecVocabAnalysis run by
// InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
// use IR2Vec embeddings.
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
FAM.getResult<LoopAnalysis>(F));
FAM.getResult<LoopAnalysis>(F), VocabResult);
}

FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
const Function &F, const DominatorTree &DT, const LoopInfo &LI,
const IR2VecVocabResult *VocabResult) {

FunctionPropertiesInfo FPI;
if (VocabResult && VocabResult->isValid()) {
FPI.IR2VecVocab = VocabResult->getVocabulary();
FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
}
for (const auto &BB : F)
if (DT.isReachableFromEntry(&BB))
FPI.reIncludeBB(BB);
FPI.updateAggregateStats(F, LI);
return FPI;
}

// Structural equality for FunctionPropertiesInfo.
//
// Compares every scalar statistic for exact equality, then compares the
// IR2Vec function embeddings approximately (they are floating-point vectors
// accumulated per-BB, so exact bit equality cannot be expected after
// incremental +=/-= updates). The optional IR2VecVocab member is deliberately
// excluded from the comparison — see the note below.
//
// NOTE(review): this replaces the former memcmp-based operator==, which
// became invalid once non-trivially-copyable members (Embedding, optional
// Vocab) were added to the class.
bool FunctionPropertiesInfo::operator==(
const FunctionPropertiesInfo &FPI) const {
// All integer counters must match exactly; any mismatch means the two
// property snapshots differ.
if (BasicBlockCount != FPI.BasicBlockCount ||
BlocksReachedFromConditionalInstruction !=
FPI.BlocksReachedFromConditionalInstruction ||
Uses != FPI.Uses ||
DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
LoadInstCount != FPI.LoadInstCount ||
StoreInstCount != FPI.StoreInstCount ||
MaxLoopDepth != FPI.MaxLoopDepth ||
TopLevelLoopCount != FPI.TopLevelLoopCount ||
TotalInstructionCount != FPI.TotalInstructionCount ||
BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
BasicBlocksWithMoreThanTwoSuccessors !=
FPI.BasicBlocksWithMoreThanTwoSuccessors ||
BasicBlocksWithSinglePredecessor !=
FPI.BasicBlocksWithSinglePredecessor ||
BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
BasicBlocksWithMoreThanTwoPredecessors !=
FPI.BasicBlocksWithMoreThanTwoPredecessors ||
BigBasicBlocks != FPI.BigBasicBlocks ||
MediumBasicBlocks != FPI.MediumBasicBlocks ||
SmallBasicBlocks != FPI.SmallBasicBlocks ||
CastInstructionCount != FPI.CastInstructionCount ||
FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
IntegerInstructionCount != FPI.IntegerInstructionCount ||
ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
ConstantOperandCount != FPI.ConstantOperandCount ||
InstructionOperandCount != FPI.InstructionOperandCount ||
BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
ArgumentOperandCount != FPI.ArgumentOperandCount ||
UnknownOperandCount != FPI.UnknownOperandCount ||
CriticalEdgeCount != FPI.CriticalEdgeCount ||
ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
IntrinsicCount != FPI.IntrinsicCount ||
DirectCallCount != FPI.DirectCallCount ||
IndirectCallCount != FPI.IndirectCallCount ||
CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
return false;
}
// Check the equality of the function embeddings. We don't check the equality
// of Vocabulary as it remains the same.
if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
return false;

return true;
}

void FunctionPropertiesInfo::print(raw_ostream &OS) const {
#define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";

Expand Down Expand Up @@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
// The caller's entry BB may change due to new alloca instructions.
LikelyToChangeBBs.insert(&*Caller.begin());

// The users of the value returned by call instruction can change
// leading to the change in embeddings being computed, when used.
// We conservatively add the BBs with such uses to LikelyToChangeBBs.
for (const auto *User : CB.users())
CallUsers.insert(dyn_cast<Instruction>(User)->getParent());
// CallSiteBB can be removed from CallUsers if present, it's taken care
// separately.
CallUsers.erase(&CallSiteBB);
LikelyToChangeBBs.insert_range(CallUsers);

// The successors may become unreachable in the case of `invoke` inlining.
// We track successors separately, too, because they form a boundary, together
// with the CB BB ('Entry') between which the inlined callee will be pasted.
Expand Down Expand Up @@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
if (&CallSiteBB != &*Caller.begin())
Reinclude.insert(&*Caller.begin());

// Reinclude the BBs which use the values returned by the call instruction.
Reinclude.insert_range(CallUsers);

// Distribute the successors to the 2 buckets.
for (const auto *Succ : Successors)
if (DT.isReachableFromEntry(Succ))
Expand Down Expand Up @@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
return false;
DominatorTree DT(F);
LoopInfo LI(DT);
auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
auto Fresh =
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
return FPI == Fresh;
}
29 changes: 29 additions & 0 deletions llvm/lib/Analysis/InlineAdvisor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/EphemeralValuesCache.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
Expand Down Expand Up @@ -64,6 +65,13 @@ static cl::opt<bool>
cl::desc("If true, annotate inline advisor remarks "
"with LTO and pass information."));

// This flag is used to enable IR2Vec embeddings in the ML inliner; Only valid
// with ML inliner. The vocab file is used to initialize the embeddings.
static cl::opt<std::string> IR2VecVocabFile(
"ml-inliner-ir2vec-vocab-file", cl::Hidden,
cl::desc("Vocab file for IR2Vec; Setting this enables "
"configuring the model to use IR2Vec embeddings."));

namespace llvm {
extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
} // namespace llvm
Expand Down Expand Up @@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
AnalysisKey InlineAdvisorAnalysis::Key;
AnalysisKey PluginInlineAdvisorAnalysis::Key;

// Run IR2VecVocabAnalysis eagerly when the user asked for IR2Vec embeddings
// via -ml-inliner-ir2vec-vocab-file.
//
// Returns false only when a vocab file was requested but could not be
// loaded (an error is emitted on the module's context in that case).
// Not specifying a vocab file is perfectly fine: we simply run without
// IR2Vec embeddings, and this function reports success.
bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
    Module &M, ModuleAnalysisManager &MAM) {
  // Nothing requested — IR2Vec embeddings are simply not used.
  if (IR2VecVocabFile.empty())
    return true;

  auto VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  if (VocabResult.isValid())
    return true;

  M.getContext().emitError("Failed to load IR2Vec vocabulary");
  return false;
}

bool InlineAdvisorAnalysis::Result::tryCreate(
InlineParams Params, InliningAdvisorMode Mode,
const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
Expand All @@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
/* EmitRemarks =*/true, IC);
}
break;
// Run IR2VecVocabAnalysis once per module to get the vocabulary.
// We run it here because it is immutable and we want to avoid running it
// multiple times.
case InliningAdvisorMode::Development:
#ifdef LLVM_HAVE_TFLITE
LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
return false;
Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
#endif
break;
case InliningAdvisorMode::Release:
LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
return false;
Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
break;
}
Expand Down
34 changes: 33 additions & 1 deletion llvm/lib/Analysis/MLInlineAdvisor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
cl::init(false));

// clang-format off
const std::vector<TensorSpec> llvm::FeatureMap{
std::vector<TensorSpec> llvm::FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
// InlineCost features - these must come first
INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
Expand Down Expand Up @@ -144,6 +144,7 @@ MLInlineAdvisor::MLInlineAdvisor(
M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
assert(ModelRunner);
Expand Down Expand Up @@ -186,6 +187,19 @@ MLInlineAdvisor::MLInlineAdvisor(
EdgeCount += getLocalCalls(KVP.first->getFunction());
}
NodeCount = AllNodes.size();

if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
if (!IR2VecVocabResult->isValid()) {
M.getContext().emitError("IR2VecVocabAnalysis is not valid");
return;
}
// Add the IR2Vec features to the feature map
auto IR2VecDim = IR2VecVocabResult->getDimension();
FeatureMap.push_back(
TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
FeatureMap.push_back(
TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
}
}

unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
Expand Down Expand Up @@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
*ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
Caller.hasAvailableExternallyLinkage();

if (UseIR2Vec) {
// Python side expects float embeddings. The IR2Vec embeddings are doubles
// as of now due to the restriction of fromJSON method used by the
// readVocabulary method in ir2vec::Embeddings.
auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
FeatureIndex Index) {
auto Embedding_float =
std::vector<float>(Embedding.begin(), Embedding.end());
std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
Embedding.size() * sizeof(float));
};

setEmbedding(CalleeBefore.getFunctionEmbedding(),
FeatureIndex::callee_embedding);
setEmbedding(CallerBefore.getFunctionEmbedding(),
FeatureIndex::caller_embedding);
}

// Add the cost features
for (size_t I = 0;
I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
Expand Down
Loading
Loading