Skip to content

[6.2][cxx-interop] Support Swiftifying C++ constructors #82061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions lib/ClangImporter/ImportDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ createFuncOrAccessor(ClangImporter::Implementation &impl, SourceLoc funcLoc,
genericParams, dc, clangNode);
}
impl.importSwiftAttrAttributes(decl);
impl.swiftify(decl);

return decl;
}
Expand Down Expand Up @@ -4361,6 +4360,7 @@ namespace {
}

recordObjCOverride(result);
Impl.swiftify(result);
}

static bool hasComputedPropertyAttr(const clang::Decl *decl) {
Expand Down Expand Up @@ -9131,7 +9131,7 @@ static bool SwiftifiablePointerType(Type swiftType) {
(nonnullType->getAnyPointerElementType(PTK) && PTK != PTK_AutoreleasingUnsafeMutablePointer);
}

void ClangImporter::Implementation::swiftify(FuncDecl *MappedDecl) {
void ClangImporter::Implementation::swiftify(AbstractFunctionDecl *MappedDecl) {
if (!SwiftContext.LangOpts.hasFeature(Feature::SafeInteropWrappers))
return;
if (importSymbolicCXXDecls)
Expand Down Expand Up @@ -9166,7 +9166,15 @@ void ClangImporter::Implementation::swiftify(FuncDecl *MappedDecl) {
return false;
};
SwiftifyInfoPrinter printer(getClangASTContext(), SwiftContext, out);
Type swiftReturnTy = MappedDecl->getResultInterfaceType();
Type swiftReturnTy;
if (const auto *funcDecl = dyn_cast<FuncDecl>(MappedDecl))
swiftReturnTy = funcDecl->getResultInterfaceType();
else if (const auto *ctorDecl = dyn_cast<ConstructorDecl>(MappedDecl))
swiftReturnTy = ctorDecl->getResultInterfaceType();
else {
assert(false && "Unexpected AbstractFunctionDecl subclass.");
return;
}
bool returnIsStdSpan = registerStdSpanTypeMapping(
swiftReturnTy, ClangDecl->getReturnType());
auto *CAT = ClangDecl->getReturnType()->getAs<clang::CountAttributedType>();
Expand Down
2 changes: 1 addition & 1 deletion lib/ClangImporter/ImporterImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1753,7 +1753,7 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
}

void importSwiftAttrAttributes(Decl *decl);
void swiftify(FuncDecl *MappedDecl);
void swiftify(AbstractFunctionDecl *MappedDecl);

/// Find the lookup table that corresponds to the given Clang module.
///
Expand Down
145 changes: 92 additions & 53 deletions lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible {
var dependencies: [LifetimeDependence] { get set }

func getBoundsCheckedThunkBuilder(
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts
) -> BoundsCheckedThunkBuilder
}

func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TokenSyntax? {
func tryGetParamName(_ funcDecl: FunctionParts, _ expr: SwiftifyExpr) -> TokenSyntax? {
switch expr {
case .param(let i):
let funcParam = getParam(funcDecl, i - 1)
Expand All @@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To
}
}

func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TypeSyntax {
func getSwiftifyExprType(_ funcDecl: FunctionParts, _ expr: SwiftifyExpr) -> TypeSyntax {
switch expr {
case .param(let i):
let funcParam = getParam(funcDecl, i - 1)
Expand All @@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo {
}

func getBoundsCheckedThunkBuilder(
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts
) -> BoundsCheckedThunkBuilder {
switch pointerIndex {
case .param(let i):
Expand Down Expand Up @@ -115,7 +115,7 @@ struct CountedBy: ParamInfo {
}

func getBoundsCheckedThunkBuilder(
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts
) -> BoundsCheckedThunkBuilder {
switch pointerIndex {
case .param(let i):
Expand Down Expand Up @@ -400,14 +400,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
}
}

func getParam(_ funcDecl: FunctionDeclSyntax, _ paramIndex: Int) -> FunctionParameterSyntax {
func getParam(_ funcDecl: FunctionParts, _ paramIndex: Int) -> FunctionParameterSyntax {
return getParam(funcDecl.signature, paramIndex)
}

struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
let base: FunctionDeclSyntax
let base: FunctionParts

init(_ function: FunctionDeclSyntax) {
init(_ function: FunctionParts) {
base = function
}

Expand Down Expand Up @@ -467,14 +467,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
FunctionCallExprSyntax(
calledExpression: functionRef, leftParen: .leftParenToken(),
arguments: LabeledExprListSyntax(labeledArgs), rightParen: .rightParenToken()))
return "unsafe \(call)"
if base.name.tokenKind == .keyword(.`init`) {
return "unsafe self.\(call)"
} else {
return "unsafe \(call)"
}
}
}

struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let index: Int
public let funcDecl: FunctionDeclSyntax
public let funcDecl: FunctionParts
public let typeMappings: [String: String]
public let node: SyntaxProtocol
public let nonescaping: Bool
Expand Down Expand Up @@ -525,7 +529,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {

struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let funcDecl: FunctionDeclSyntax
public let funcDecl: FunctionParts
public let typeMappings: [String: String]
public let node: SyntaxProtocol
let isParameter: Bool = false
Expand Down Expand Up @@ -564,7 +568,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
var oldType: TypeSyntax { get }
var newType: TypeSyntax { get throws }
var funcDecl: FunctionDeclSyntax { get }
var funcDecl: FunctionParts { get }
}

extension BoundsThunkBuilder {
Expand Down Expand Up @@ -675,7 +679,7 @@ extension ParamBoundsThunkBuilder {
struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder
public let countExpr: ExprSyntax
public let funcDecl: FunctionDeclSyntax
public let funcDecl: FunctionParts
public let nonescaping: Bool
public let isSizedBy: Bool
public let dependencies: [LifetimeDependence]
Expand Down Expand Up @@ -743,7 +747,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
public let base: BoundsCheckedThunkBuilder
public let index: Int
public let countExpr: ExprSyntax
public let funcDecl: FunctionDeclSyntax
public let funcDecl: FunctionParts
public let nonescaping: Bool
public let isSizedBy: Bool
let isParameter: Bool = true
Expand Down Expand Up @@ -1237,22 +1241,22 @@ func parseMacroParam(
}
}

func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws {
func checkArgs(_ args: [ParamInfo], _ funcComponents: FunctionParts) throws {
var argByIndex: [Int: ParamInfo] = [:]
var ret: ParamInfo? = nil
let paramCount = funcDecl.signature.parameterClause.parameters.count
let paramCount = funcComponents.signature.parameterClause.parameters.count
try args.forEach { pointerInfo in
switch pointerInfo.pointerIndex {
case .param(let i):
if i < 1 || i > paramCount {
let noteMessage =
paramCount > 0
? "function \(funcDecl.name) has parameter indices 1..\(paramCount)"
: "function \(funcDecl.name) has no parameters"
? "function \(funcComponents.name) has parameter indices 1..\(paramCount)"
: "function \(funcComponents.name) has no parameters"
throw DiagnosticError(
"pointer index out of bounds", node: pointerInfo.original,
notes: [
Note(node: Syntax(funcDecl.name), message: MacroExpansionNoteMessage(noteMessage))
Note(node: Syntax(funcComponents.name), message: MacroExpansionNoteMessage(noteMessage))
])
}
if argByIndex[i] != nil {
Expand Down Expand Up @@ -1316,7 +1320,7 @@ func isInout(_ type: TypeSyntax) -> Bool {
}

func getReturnLifetimeAttribute(
_ funcDecl: FunctionDeclSyntax,
_ funcDecl: FunctionParts,
_ dependencies: [SwiftifyExpr: [LifetimeDependence]]
) -> [AttributeListSyntax.Element] {
let returnDependencies = dependencies[.`return`, default: []]
Expand Down Expand Up @@ -1473,9 +1477,9 @@ class CountExprRewriter: SyntaxRewriter {
}
}

func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDeclSyntax, CountExprRewriter) {
let params = funcDecl.signature.parameterClause.parameters
let funcName = funcDecl.name.withoutBackticks.trimmed.text
func renameParameterNamesIfNeeded(_ funcComponents: FunctionParts) -> (FunctionParts, CountExprRewriter) {
let params = funcComponents.signature.parameterClause.parameters
let funcName = funcComponents.name.withoutBackticks.trimmed.text
let shouldRename = params.contains(where: { param in
let paramName = param.name.trimmed.text
return paramName == "_" || paramName == funcName || "`\(paramName)`" == funcName
Expand All @@ -1499,13 +1503,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe
}
return newParam
}
let newDecl = if renamedParams.count > 0 {
funcDecl.with(\.signature.parameterClause.parameters, FunctionParameterListSyntax(newParams))
let newSig = if renamedParams.count > 0 {
funcComponents.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams))
} else {
// Keeps source locations for diagnostics, in the common case where nothing was renamed
funcDecl
funcComponents.signature
}
return (FunctionParts(signature: newSig, name: funcComponents.name, attributes: funcComponents.attributes),
CountExprRewriter(renamedParams))
}

struct FunctionParts {
let signature: FunctionSignatureSyntax
let name: TokenSyntax
let attributes: AttributeListSyntax
}

func deconstructFunction(_ declaration: some DeclSyntaxProtocol) throws -> FunctionParts {
if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) {
return FunctionParts(signature: origFuncDecl.signature, name: origFuncDecl.name,
attributes: origFuncDecl.attributes)
}
if let origInitDecl = declaration.as(InitializerDeclSyntax.self) {
return FunctionParts(signature: origInitDecl.signature, name: origInitDecl.initKeyword,
attributes: origInitDecl.attributes)
}
return (newDecl, CountExprRewriter(renamedParams))
throw DiagnosticError("@_SwiftifyImport only works on functions and initializers", node: declaration)
}

/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
Expand All @@ -1521,10 +1544,8 @@ public struct SwiftifyImportMacro: PeerMacro {
in context: some MacroExpansionContext
) throws -> [DeclSyntax] {
do {
guard let origFuncDecl = declaration.as(FunctionDeclSyntax.self) else {
throw DiagnosticError("@_SwiftifyImport only works on functions", node: declaration)
}
let (funcDecl, rewriter) = renameParameterNamesIfNeeded(origFuncDecl)
let origFuncComponents = try deconstructFunction(declaration)
let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents)

let argumentList = node.arguments!.as(LabeledExprListSyntax.self)!
var arguments = [LabeledExprSyntax](argumentList)
Expand All @@ -1540,10 +1561,10 @@ public struct SwiftifyImportMacro: PeerMacro {
var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:]
var parsedArgs = try arguments.compactMap {
try parseMacroParam(
$0, funcDecl.signature, rewriter, nonescapingPointers: &nonescapingPointers,
$0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers,
lifetimeDependencies: &lifetimeDependencies)
}
parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcDecl.signature, typeMappings))
parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings))
setNonescapingPointers(&parsedArgs, nonescapingPointers)
setLifetimeDependencies(&parsedArgs, lifetimeDependencies)
// We only transform non-escaping spans.
Expand All @@ -1554,7 +1575,7 @@ public struct SwiftifyImportMacro: PeerMacro {
return true
}
}
try checkArgs(parsedArgs, funcDecl)
try checkArgs(parsedArgs, funcComponents)
parsedArgs.sort { a, b in
// make sure return value cast to Span happens last so that withUnsafeBufferPointer
// doesn't return a ~Escapable type
Expand All @@ -1566,12 +1587,12 @@ public struct SwiftifyImportMacro: PeerMacro {
}
return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex)
}
let baseBuilder = FunctionCallBuilder(funcDecl)
let baseBuilder = FunctionCallBuilder(funcComponents)

let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce(
baseBuilder,
{ (prev, parsedArg) in
parsedArg.getBoundsCheckedThunkBuilder(prev, funcDecl)
parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents)
})
let newSignature = try builder.buildFunctionSignature([:], nil)
var eliminatedArgs = Set<Int>()
Expand All @@ -1580,15 +1601,22 @@ public struct SwiftifyImportMacro: PeerMacro {
let checks = (basicChecks + compoundChecks).map { e in
CodeBlockItemSyntax(leadingTrivia: "\n", item: e)
}
let call = CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
ReturnStmtSyntax(
returnKeyword: .keyword(.return, trailingTrivia: " "),
expression: try builder.buildFunctionCall([:]))))
var call : CodeBlockItemSyntax
if declaration.is(InitializerDeclSyntax.self) {
call = CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
try builder.buildFunctionCall([:])))
} else {
call = CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
ReturnStmtSyntax(
returnKeyword: .keyword(.return, trailingTrivia: " "),
expression: try builder.buildFunctionCall([:]))))
}
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcDecl, lifetimeDependencies)
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies)
let lifetimeAttrs =
returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes)
returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes)
let availabilityAttr = try getAvailability(newSignature, spanAvailability)
let disfavoredOverload: [AttributeListSyntax.Element] =
[
Expand All @@ -1597,13 +1625,7 @@ public struct SwiftifyImportMacro: PeerMacro {
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload")))
]
let newFunc =
funcDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(
\.attributes,
funcDecl.attributes.filter { e in
let attributes = funcComponents.attributes.filter { e in
switch e {
case .attribute(let attr):
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
Expand All @@ -1619,9 +1641,23 @@ public struct SwiftifyImportMacro: PeerMacro {
]
+ availabilityAttr
+ lifetimeAttrs
+ disfavoredOverload)
.with(\.leadingTrivia, node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n"))
return [DeclSyntax(newFunc)]
+ disfavoredOverload
let trivia = node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n")
if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) {
return [DeclSyntax(origFuncDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))]
}
if let origInitDecl = declaration.as(InitializerDeclSyntax.self) {
return [DeclSyntax(origInitDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))]
}
return []
} catch let error as DiagnosticError {
context.diagnose(
Diagnostic(
Expand Down Expand Up @@ -1686,6 +1722,9 @@ extension FunctionParameterSyntax {

extension TokenSyntax {
public var withoutBackticks: TokenSyntax {
if self.identifier == nil {
return self
}
return .identifier(self.identifier!.name)
}

Expand Down
2 changes: 1 addition & 1 deletion test/Interop/Cxx/swiftify-import/counted-by-method.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ import Method
func test(s: UnsafeMutableBufferPointer<Float>) {
var foo = Foo()
foo.bar(s)
}
}
Loading