Skip to content

[IR2Vec] Simplifying creation of Embedder #143999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/svkeerthy/06-12-_ir2vec_scale_vocab
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions llvm/docs/MLGO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.

// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
std::unique_ptr<ir2vec::Embedder> Emb =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);

if (auto Err = EmbOrErr.takeError()) {
// Handle error in embedder creation
return;
}
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);

3. **Compute and Access Embeddings**:
Call ``getFunctionVector()`` to get the embedding for the function.
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ class Embedder {
virtual ~Embedder() = default;

/// Factory method to create an Embedder object.
static Expected<std::unique_ptr<Embedder>>
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
const Vocab &Vocabulary);

/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
Expand Down
10 changes: 3 additions & 7 deletions llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
// We instantiate the IR2Vec embedder each time, as having an unique
// pointer to the embedder as member of the class would make it
// non-copyable. Instantiating the embedder in itself is not costly.
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
if (Error Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
EI.message());
});
if (!Embedder) {
BB.getContext().emitError("Error creating IR2Vec embeddings");
return;
}
auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.
Expand Down
17 changes: 7 additions & 10 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}

Expected<std::unique_ptr<Embedder>>
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
llvm_unreachable("Unknown IR2Vec kind");
return nullptr;
}

// FIXME: Currently lookups are string based. Use numeric Keys
Expand Down Expand Up @@ -384,17 +385,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,

auto Vocab = IR2VecVocabResult.getVocabulary();
for (Function &F : M) {
Expected<std::unique_ptr<Embedder>> EmbOrErr =
std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
if (auto Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
});
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
}

std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);

OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
Emb->getFunctionVector().print(OS);
Expand Down
7 changes: 3 additions & 4 deletions llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
}

std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
auto EmbResult =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
EXPECT_TRUE(static_cast<bool>(EmbResult));
return std::move(*EmbResult);
auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
EXPECT_TRUE(static_cast<bool>(Emb));
return std::move(Emb);
}
};

Expand Down
44 changes: 17 additions & 27 deletions llvm/unittests/Analysis/IR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);

auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));

auto *Emb = Result->get();
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_NE(Emb, nullptr);
}

Expand All @@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);

// static_cast an invalid int to IR2VecKind
// static_cast an invalid int to IR2VecKind
#ifndef NDEBUG
#if GTEST_HAS_DEATH_TEST
EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
"Unknown IR2Vec kind");
#endif // GTEST_HAS_DEATH_TEST
#else
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));

std::string ErrMsg;
llvm::handleAllErrors(
Result.takeError(),
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
#endif // NDEBUG
}

TEST(IR2VecTest, LookupVocab) {
Expand Down Expand Up @@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;

float OriginalOpcWeight = ::OpcWeight;
float OriginalTypeWeight = ::TypeWeight;
float OriginalArgWeight = ::ArgWeight;

void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.25, 0.25}},
Expand All @@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
};

TEST_F(IR2VecTestFixture, GetInstVecMap) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));

const auto &InstMap = Emb->getInstVecMap();

Expand All @@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
}

TEST_F(IR2VecTestFixture, GetBBVecMap) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));

const auto &BBMap = Emb->getBBVecMap();

Expand All @@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
}

TEST_F(IR2VecTestFixture, GetBBVector) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));

const auto &BBVec = Emb->getBBVector(*BB);

Expand All @@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
}

TEST_F(IR2VecTestFixture, GetFunctionVector) {
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Result));
auto Emb = std::move(*Result);
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));

const auto &FuncVec = Emb->getFunctionVector();

Expand Down
Loading