Skip to content

Commit 58c85ec

Browse files
committed
[MLInliner][IR2Vec] Integrating IR2Vec with MLInliner
1 parent 79060df commit 58c85ec

File tree

10 files changed

+361
-38
lines changed

10 files changed

+361
-38
lines changed

llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
1616

1717
#include "llvm/ADT/DenseSet.h"
18+
#include "llvm/Analysis/IR2Vec.h"
1819
#include "llvm/IR/Dominators.h"
1920
#include "llvm/IR/PassManager.h"
2021
#include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
3233
void updateAggregateStats(const Function &F, const LoopInfo &LI);
3334
void reIncludeBB(const BasicBlock &BB);
3435

36+
ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
37+
std::optional<ir2vec::Vocab> IR2VecVocab;
38+
3539
public:
3640
LLVM_ABI static FunctionPropertiesInfo
3741
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
38-
const LoopInfo &LI);
42+
const LoopInfo &LI,
43+
const IR2VecVocabResult *VocabResult);
3944

4045
LLVM_ABI static FunctionPropertiesInfo
4146
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
4247

43-
bool operator==(const FunctionPropertiesInfo &FPI) const {
44-
return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
45-
}
48+
bool operator==(const FunctionPropertiesInfo &FPI) const;
4649

4750
bool operator!=(const FunctionPropertiesInfo &FPI) const {
4851
return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
137140
int64_t CallReturnsVectorPointerCount = 0;
138141
int64_t CallWithManyArgumentsCount = 0;
139142
int64_t CallWithPointerArgumentCount = 0;
143+
144+
/// Read-only access to the IR2Vec embedding currently stored for this
/// function (the FunctionEmbedding member).
const ir2vec::Embedding &getFunctionEmbedding() const {
145+
return FunctionEmbedding;
146+
}
147+
148+
/// Read-only access to the optional IR2Vec vocabulary held by this object;
/// the optional is empty when no vocabulary has been stored.
const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
149+
return IR2VecVocab;
150+
}
151+
152+
// Helper intended to be useful for unittests: overwrites the stored function
// embedding with a copy of Embedding.
153+
void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
154+
FunctionEmbedding = Embedding;
155+
}
140156
};
141157

142158
// Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
192208

193209
DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
194210

195-
DenseSet<const BasicBlock *> Successors;
211+
DenseSet<const BasicBlock *> Successors, CallUsers;
196212

197213
// Edges we might potentially need to remove from the dominator tree.
198214
SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
239239
public:
240240
static AnalysisKey Key;
241241
IR2VecVocabAnalysis() = default;
242-
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
242+
explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab);
243243
using Result = IR2VecVocabResult;
244244
Result run(Module &M, ModuleAnalysisManager &MAM);
245245
};

llvm/include/llvm/Analysis/InlineAdvisor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,9 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
331331
};
332332

333333
Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
334+
335+
private:
336+
static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM);
334337
};
335338

336339
/// Printer pass for the InlineAdvisorAnalysis results.

llvm/include/llvm/Analysis/InlineModelFeatureMaps.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ enum class FeatureIndex : size_t {
142142
INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
143143
#undef POPULATE_INDICES
144144

145+
// IR2Vec embeddings
146+
callee_embedding,
147+
caller_embedding,
148+
145149
NumberOfFeatures
146150
};
147151
// clang-format on
@@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
154158
constexpr size_t NumberOfFeatures =
155159
static_cast<size_t>(FeatureIndex::NumberOfFeatures);
156160

157-
LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
161+
LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
158162

159163
LLVM_ABI extern const char *const DecisionName;
160164
LLVM_ABI extern const TensorSpec InlineDecisionSpec;

llvm/include/llvm/Analysis/MLInlineAdvisor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
8282
int64_t NodeCount = 0;
8383
int64_t EdgeCount = 0;
8484
int64_t EdgesOfLastSeenNodes = 0;
85+
bool UseIR2Vec = false;
8586

8687
std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
8788
const int32_t InitialIRSize = 0;

llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
199199
#undef CHECK_OPERAND
200200
}
201201
}
202+
203+
if (IR2VecVocab) {
204+
// We instantiate the IR2Vec embedder each time, as having a unique
205+
// pointer to the embedder as member of the class would make it
206+
// non-copyable. Instantiating the embedder in itself is not costly.
207+
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
208+
*BB.getParent(), *IR2VecVocab);
209+
if (Error Err = EmbOrErr.takeError()) {
210+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
211+
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
212+
EI.message());
213+
});
214+
return;
215+
}
216+
auto Embedder = std::move(*EmbOrErr);
217+
const auto &BBEmbedding = Embedder->getBBVector(BB);
218+
// Subtract BBEmbedding from Function embedding if the direction is -1,
219+
// and add it if the direction is +1.
220+
if (Direction == -1)
221+
FunctionEmbedding -= BBEmbedding;
222+
else
223+
FunctionEmbedding += BBEmbedding;
224+
}
202225
}
203226

204227
void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,21 +243,91 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
220243

221244
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
222245
Function &F, FunctionAnalysisManager &FAM) {
246+
// We use the cached result of the IR2VecVocabAnalysis run by
247+
// InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
248+
// use IR2Vec embeddings.
249+
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
250+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
223251
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
224-
FAM.getResult<LoopAnalysis>(F));
252+
FAM.getResult<LoopAnalysis>(F), VocabResult);
225253
}
226254

227255
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
228-
const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
256+
const Function &F, const DominatorTree &DT, const LoopInfo &LI,
257+
const IR2VecVocabResult *VocabResult) {
229258

230259
FunctionPropertiesInfo FPI;
260+
if (VocabResult && VocabResult->isValid()) {
261+
FPI.IR2VecVocab = VocabResult->getVocabulary();
262+
FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
263+
}
231264
for (const auto &BB : F)
232265
if (DT.isReachableFromEntry(&BB))
233266
FPI.reIncludeBB(BB);
234267
FPI.updateAggregateStats(F, LI);
235268
return FPI;
236269
}
237270

271+
bool FunctionPropertiesInfo::operator==(
272+
const FunctionPropertiesInfo &FPI) const {
273+
if (BasicBlockCount != FPI.BasicBlockCount ||
274+
BlocksReachedFromConditionalInstruction !=
275+
FPI.BlocksReachedFromConditionalInstruction ||
276+
Uses != FPI.Uses ||
277+
DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
278+
LoadInstCount != FPI.LoadInstCount ||
279+
StoreInstCount != FPI.StoreInstCount ||
280+
MaxLoopDepth != FPI.MaxLoopDepth ||
281+
TopLevelLoopCount != FPI.TopLevelLoopCount ||
282+
TotalInstructionCount != FPI.TotalInstructionCount ||
283+
BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
284+
BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
285+
BasicBlocksWithMoreThanTwoSuccessors !=
286+
FPI.BasicBlocksWithMoreThanTwoSuccessors ||
287+
BasicBlocksWithSinglePredecessor !=
288+
FPI.BasicBlocksWithSinglePredecessor ||
289+
BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
290+
BasicBlocksWithMoreThanTwoPredecessors !=
291+
FPI.BasicBlocksWithMoreThanTwoPredecessors ||
292+
BigBasicBlocks != FPI.BigBasicBlocks ||
293+
MediumBasicBlocks != FPI.MediumBasicBlocks ||
294+
SmallBasicBlocks != FPI.SmallBasicBlocks ||
295+
CastInstructionCount != FPI.CastInstructionCount ||
296+
FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
297+
IntegerInstructionCount != FPI.IntegerInstructionCount ||
298+
ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
299+
ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
300+
ConstantOperandCount != FPI.ConstantOperandCount ||
301+
InstructionOperandCount != FPI.InstructionOperandCount ||
302+
BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
303+
GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
304+
InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
305+
ArgumentOperandCount != FPI.ArgumentOperandCount ||
306+
UnknownOperandCount != FPI.UnknownOperandCount ||
307+
CriticalEdgeCount != FPI.CriticalEdgeCount ||
308+
ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
309+
UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
310+
IntrinsicCount != FPI.IntrinsicCount ||
311+
DirectCallCount != FPI.DirectCallCount ||
312+
IndirectCallCount != FPI.IndirectCallCount ||
313+
CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
314+
CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
315+
CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
316+
CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
317+
CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
318+
CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
319+
CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
320+
CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
321+
return false;
322+
}
323+
// Check the equality of the function embeddings. We don't check the equality
324+
// of Vocabulary as it remains the same.
325+
if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
326+
return false;
327+
328+
return true;
329+
}
330+
238331
void FunctionPropertiesInfo::print(raw_ostream &OS) const {
239332
#define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
240333

@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
322415
// The caller's entry BB may change due to new alloca instructions.
323416
LikelyToChangeBBs.insert(&*Caller.begin());
324417

418+
// The users of the value returned by call instruction can change
419+
// leading to the change in embeddings being computed, when used.
420+
// We conservatively add the BBs with such uses to LikelyToChangeBBs.
421+
for (const auto *User : CB.users())
422+
// Users of the call's result are expected to be instructions; cast<> asserts
// that instead of dereferencing a possibly-null dyn_cast<> result.
CallUsers.insert(cast<Instruction>(User)->getParent());
423+
// CallSiteBB can be removed from CallUsers if present; it's taken care of
424+
// separately.
425+
CallUsers.erase(&CallSiteBB);
426+
LikelyToChangeBBs.insert_range(CallUsers);
427+
325428
// The successors may become unreachable in the case of `invoke` inlining.
326429
// We track successors separately, too, because they form a boundary, together
327430
// with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
435538
if (&CallSiteBB != &*Caller.begin())
436539
Reinclude.insert(&*Caller.begin());
437540

541+
// Reinclude the BBs which use the values returned by call instruction
542+
Reinclude.insert_range(CallUsers);
543+
438544
// Distribute the successors to the 2 buckets.
439545
for (const auto *Succ : Successors)
440546
if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
486592
return false;
487593
DominatorTree DT(F);
488594
LoopInfo LI(DT);
489-
auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
595+
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
596+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
597+
auto Fresh =
598+
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
490599
return FPI == Fresh;
491600
}

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ Error IR2VecVocabAnalysis::readVocabulary() {
294294
return Error::success();
295295
}
296296

297-
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
298-
: Vocabulary(std::move(Vocabulary)) {}
297+
// Vocabulary is a by-value sink parameter: callers may pass an lvalue (copied
// at the call site) or an rvalue. Move it into the member so that no second
// copy is made here.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
    : Vocabulary(std::move(Vocabulary)) {}
299299

300300
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
301301
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {

llvm/lib/Analysis/InlineAdvisor.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/StringExtras.h"
1717
#include "llvm/Analysis/AssumptionCache.h"
1818
#include "llvm/Analysis/EphemeralValuesCache.h"
19+
#include "llvm/Analysis/IR2Vec.h"
1920
#include "llvm/Analysis/InlineCost.h"
2021
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
2122
#include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
6465
cl::desc("If true, annotate inline advisor remarks "
6566
"with LTO and pass information."));
6667

68+
// This flag is used to enable IR2Vec embeddings in the ML inliner; Only valid
69+
// with ML inliner. The vocab file is used to initialize the embeddings.
70+
static cl::opt<std::string> IR2VecVocabFile(
71+
"ml-inliner-ir2vec-vocab-file", cl::Hidden,
72+
cl::desc("Vocab file for IR2Vec; Setting this enables "
73+
"configuring the model to use IR2Vec embeddings."));
74+
6775
namespace llvm {
6876
extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
6977
} // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
206214
AnalysisKey InlineAdvisorAnalysis::Key;
207215
AnalysisKey PluginInlineAdvisorAnalysis::Key;
208216

217+
/// Makes sure the IR2Vec vocabulary is usable when a vocab file was requested.
///
/// If the ml-inliner-ir2vec-vocab-file option is empty, no analysis is run and
/// true is returned. Otherwise IR2VecVocabAnalysis is requested from \p MAM
/// for \p M.
///
/// \returns true when no vocab file was given or the analysis result is valid;
/// false (after emitting an error on the module's context) when the result is
/// not valid.
bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M,
                                                  ModuleAnalysisManager &MAM) {
  // No vocab file specified is OK; We just don't use IR2Vec embeddings.
  if (IR2VecVocabFile.empty())
    return true;

  // Bind by reference rather than copying the result object out of the
  // analysis manager; only its validity is inspected here. The local is also
  // named so that it does not shadow the IR2VecVocabResult type.
  const auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  if (VocabResult.isValid())
    return true;

  M.getContext().emitError("Failed to load IR2Vec vocabulary");
  return false;
}
230+
209231
bool InlineAdvisorAnalysis::Result::tryCreate(
210232
InlineParams Params, InliningAdvisorMode Mode,
211233
const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
231253
/* EmitRemarks =*/true, IC);
232254
}
233255
break;
256+
// Run IR2VecVocabAnalysis once per module to get the vocabulary.
257+
// We run it here because it is immutable and we want to avoid running it
258+
// multiple times.
234259
case InliningAdvisorMode::Development:
235260
#ifdef LLVM_HAVE_TFLITE
236261
LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
262+
if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
263+
return false;
237264
Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
238265
#endif
239266
break;
240267
case InliningAdvisorMode::Release:
241268
LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
269+
if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
270+
return false;
242271
Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
243272
break;
244273
}

0 commit comments

Comments
 (0)