Skip to content

[SYCLomatic] Refine the cub API migration by using group local memory kernel scope allocation #2849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3130,7 +3130,7 @@ MemVarInfo::MemVarInfo(unsigned Offset,
auto DS1 = getParentDeclStmt(Var);
auto DS2 = getParentDeclStmt(DeclOfVarType);
if (DS1 && DS2 && DS1 == DS2) {
IsAnonymousType = true;
IsAnonymousType = !DeclOfVarType->hasNameForLinkage();
DeclStmtOfVarType = DS2;
const auto LocInfo = DpctGlobalInfo::getLocInfo(
getDefinitionRange(DS2->getBeginLoc(), DS2->getEndLoc())
Expand Down Expand Up @@ -3195,7 +3195,9 @@ std::string MemVarInfo::getDeclarationReplacement(const VarDecl *VD) {
OS << "auto &" << getName() << " = "
<< "*" << MapNames::getClNamespace()
<< "ext::oneapi::group_local_memory_for_overwrite<"
<< getType()->getBaseName();
<< ((isAnonymousType() && isShared() && isLocal())
? LocalTypeName
: getType()->getBaseName());
for (auto &ArraySize : getType()->getRange()) {
OS << "[" << ArraySize.getSize() << "]";
}
Expand Down
142 changes: 122 additions & 20 deletions clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,28 @@ TYPE_REWRITE_ENTRY(
WARNING_FACTORY(Diagnostics::UNSUPPORT_SYCLCOMPAT, TYPESTR),
HEADER_INSERTION_FACTORY(
HeaderType::HT_DPCT_GROUP_Utils,
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))
TYPE_CONDITIONAL_FACTORY(
UseGroupLocalMemory(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(9, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(4)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(10, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(4), TEMPLATE_ARG(8)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(4), TEMPLATE_ARG(8),
TEMPLATE_ARG(9)))),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))

// cub::BlockExchange
TYPE_REWRITE_ENTRY(
Expand All @@ -128,9 +147,28 @@ TYPE_REWRITE_ENTRY(
WARNING_FACTORY(Diagnostics::UNSUPPORT_SYCLCOMPAT, TYPESTR),
HEADER_INSERTION_FACTORY(
HeaderType::HT_DPCT_GROUP_Utils,
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))
TYPE_CONDITIONAL_FACTORY(
UseGroupLocalMemory(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(1)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(6, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(1), TEMPLATE_ARG(4)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(1), TEMPLATE_ARG(4),
TEMPLATE_ARG(5)))),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))

// cub::BlockShuffle
TYPE_REWRITE_ENTRY(
Expand Down Expand Up @@ -165,13 +203,45 @@ TYPE_REWRITE_ENTRY(
HEADER_INSERTION_FACTORY(
HeaderType::HT_DPCT_GROUP_Utils,
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2), TEMPLATE_ARG(3)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))
UseGroupLocalMemory(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
STR(MapNames::getDpctNamespace() +
"group::group_load_algorithm::blocked"),
TEMPLATE_ARG(1)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(6, false,
std::less<unsigned>()),
TYPE_FACTORY(
STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1),
TEMPLATE_ARG(4)),
TYPE_FACTORY(
STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1),
TEMPLATE_ARG(4), TEMPLATE_ARG(5))))),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_load"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))))
// cub::BlockStore
TYPE_REWRITE_ENTRY(
"cub::BlockStore",
Expand All @@ -181,13 +251,45 @@ TYPE_REWRITE_ENTRY(
HEADER_INSERTION_FACTORY(
HeaderType::HT_DPCT_GROUP_Utils,
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2), TEMPLATE_ARG(3)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))))
UseGroupLocalMemory(),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
STR(MapNames::getDpctNamespace() +
"group::group_store_algorithm::blocked"),
TEMPLATE_ARG(1)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(5, false, std::less<unsigned>()),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1)),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(6, false,
std::less<unsigned>()),
TYPE_FACTORY(
STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1),
TEMPLATE_ARG(4)),
TYPE_FACTORY(
STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3), TEMPLATE_ARG(1),
TEMPLATE_ARG(4), TEMPLATE_ARG(5))))),
TYPE_CONDITIONAL_FACTORY(
CheckTemplateArgCount(4),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2),
TEMPLATE_ARG(3)),
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
"group::group_store"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))))

FEATURE_REQUEST_FACTORY(
HelperFeatureEnum::device_ext,
Expand Down
66 changes: 66 additions & 0 deletions clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,9 +1152,75 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE,
case TypeLoc::Typedef:
case TypeLoc::Builtin:
case TypeLoc::Using:
case TypeLoc::DependentName:
case TypeLoc::Elaborated:
case TypeLoc::Record: {
TyName = DpctGlobalInfo::getTypeName(TL.getType());
if (DpctGlobalInfo::useGroupLocalMemory() &&
(TyName.find("TempStorage") != std::string::npos) &&
isPreserveCubVar(TL.getType())) {
const RecordDecl *RD = nullptr;
const TemplateDecl *TD = nullptr;
const TypedefNameDecl *TND = nullptr;
if (auto ETL = TL.getAs<ElaboratedTypeLoc>()) {
if (auto RTL = ETL.getNamedTypeLoc().getAs<RecordTypeLoc>()) {
RD = RTL.getDecl();
}
} else if (auto RTL = TL.getAs<RecordTypeLoc>()) {
RD = RTL.getDecl();
} else if (auto TTL = TL.getAs<TypedefTypeLoc>()) {
TND = TTL.getTypedefNameDecl();
} else if (auto DTL = TL.getAs<DependentNameTypeLoc>()) {
const DependentNameType *DT = DTL.getTypePtr();
auto *QNNS = DT->getQualifier();
if (QNNS->getKind() == NestedNameSpecifier::TypeSpec) {
if (auto *SpecType =
dyn_cast<TemplateSpecializationType>(QNNS->getAsType())) {
TD = SpecType->getTemplateName().getAsTemplateDecl();
} else if (auto *TT = dyn_cast<TypedefType>(QNNS->getAsType())) {
if (auto D = TT->getDecl()) {
if (auto *SpecType = D->getUnderlyingType()
.getCanonicalType()
.getTypePtr()
->getAs<TemplateSpecializationType>()) {
TD = SpecType->getTemplateName().getAsTemplateDecl();
}
}
}
}
}
if (RD) {
auto DC = RD->getDeclContext();
if (DC->getDeclKind() == Decl::Kind::ClassTemplateSpecialization) {
if (auto CTS = dyn_cast<ClassTemplateSpecializationDecl>(DC)) {
if (dpct::DpctGlobalInfo::isInCudaPath(
CTS->getSpecializedTemplate()->getLocation())) {
addReplacement(TL.getBeginLoc(), TL.getEndLoc(), CSCE,
"TempLocalMemory");
return;
}
}
}
}
if (TND) {
auto DC = TND->getDeclContext();
if (DC && DC->isRecord()) {
auto *RD = dyn_cast<RecordDecl>(DC);
if (dpct::DpctGlobalInfo::isInCudaPath(RD->getLocation())) {
addReplacement(TL.getBeginLoc(), TL.getEndLoc(), CSCE,
"TempLocalMemory");
return;
}
}
}
if (TD) {
if (dpct::DpctGlobalInfo::isInCudaPath(TD->getLocation())) {
addReplacement(TL.getAs<DependentNameTypeLoc>().getNameLoc(),
TL.getEndLoc(), CSCE, "TempLocalMemory");
return;
}
}
}
RewriteType(TyName, TL);
break;
}
Expand Down
17 changes: 13 additions & 4 deletions clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ inline auto UseSYCLCompat() {
return [](const TypeLoc) -> bool { return DpctGlobalInfo::useSYCLCompat(); };
}

/// Builds a predicate over a TypeLoc for use as a rewrite condition.
/// The predicate ignores its TypeLoc argument and simply reports the value
/// returned by DpctGlobalInfo::useGroupLocalMemory() at the time it is called.
inline auto UseGroupLocalMemory() {
  auto IsGroupLocalMemoryEnabled = [](const TypeLoc /*Unused*/) -> bool {
    return DpctGlobalInfo::useGroupLocalMemory();
  };
  return IsGroupLocalMemoryEnabled;
}

TemplateArgumentInfo getTemplateArg(const TypeLoc &TL, unsigned Idx) {
if (auto TSTL = TL.getAs<TemplateSpecializationTypeLoc>()) {
if (TSTL.getNumArgs() > Idx) {
Expand Down Expand Up @@ -103,23 +109,26 @@ makeUserDefinedTypeStrCreator(MetaRuleObject &R,
class CheckTemplateArgCount {
unsigned Count;
bool IsIncludeDefault;
std::function<bool(unsigned, unsigned)> CmpFunc;

public:
CheckTemplateArgCount(unsigned I, bool D = true)
: Count(I), IsIncludeDefault(D) {}
CheckTemplateArgCount(
unsigned I, bool D = true,
std::function<bool(unsigned, unsigned)> F = std::equal_to<unsigned>())
: Count(I), IsIncludeDefault(D), CmpFunc(F) {}
bool operator()(const TypeLoc TL) {
if (auto TSTL = TL.getAs<TemplateSpecializationTypeLoc>()) {
size_t Num = TSTL.getNumArgs();
if (IsIncludeDefault) {
return Num == Count;
return CmpFunc(Num, Count);
}
size_t NoneDefaultNum = 0;
for (size_t i = 0; i < Num; i++) {
if (!TSTL.getArgLoc(i).getArgument().getIsDefaulted()) {
NoneDefaultNum++;
}
}
return NoneDefaultNum == Count;
return CmpFunc(NoneDefaultNum, Count);
}
return false;
}
Expand Down
12 changes: 8 additions & 4 deletions clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,10 @@ void MemVarMigrationRule::processTypeDeclaredLocal(
std::string Ret;
llvm::raw_string_ostream OS(Ret);
OS << getNL(DS->getEndLoc().isMacroID()) << getIndent(InsertSL, SM);
if (DpctGlobalInfo::useGroupLocalMemory()) {
OS << Info->getDeclarationReplacement(MemVar);
return OS.str();
}
OS << TypeName << ' ';
if (IsReference)
OS << '&';
Expand Down Expand Up @@ -719,8 +723,7 @@ void MemVarMigrationRule::processTypeDeclaredLocal(
emplaceTransformation(new InsertText(InsertSL, GenDeclStmt(NewTypeName)));
} else if (DS) {
// remove var decl
emplaceTransformation(ReplaceVarDecl::getVarDeclReplacement(
MemVar, Info->getDeclarationReplacement(MemVar)));
emplaceTransformation(new ReplaceVarDecl(MemVar, ""));

Info->setLocalTypeName(Info->getType()->getBaseName());
emplaceTransformation(
Expand All @@ -731,7 +734,8 @@ void MemVarMigrationRule::processTypeDeclaredLocal(
void MemVarMigrationRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
if (auto MemVar = getAssistNodeAsType<VarDecl>(Result, "var")) {
if (isCubVar(MemVar) || MemVar->hasAttr<CUDAConstantAttr>()) {
if ((isCubVar(MemVar) && !isPreserveCubVar(MemVar->getType())) ||
MemVar->hasAttr<CUDAConstantAttr>()) {
return;
}
std::string CanonicalType =
Expand Down Expand Up @@ -787,7 +791,7 @@ void MemVarAnalysisRule::registerMatcher(MatchFinder &MF) {

void MemVarAnalysisRule::runRule(const MatchFinder::MatchResult &Result) {
if (auto MemVar = getAssistNodeAsType<VarDecl>(Result, "var")) {
if (isCubVar(MemVar)) {
if (isCubVar(MemVar) && !isPreserveCubVar(MemVar->getType())) {
return;
}
std::string CanonicalType =
Expand Down
Loading