diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp index 8942556abce0a..b68f0090c3583 100644 --- a/lib/ClangImporter/ImportDecl.cpp +++ b/lib/ClangImporter/ImportDecl.cpp @@ -135,7 +135,6 @@ createFuncOrAccessor(ClangImporter::Implementation &impl, SourceLoc funcLoc, genericParams, dc, clangNode); } impl.importSwiftAttrAttributes(decl); - impl.swiftify(decl); return decl; } @@ -4361,6 +4360,7 @@ namespace { } recordObjCOverride(result); + Impl.swiftify(result); } static bool hasComputedPropertyAttr(const clang::Decl *decl) { @@ -9131,7 +9131,7 @@ static bool SwiftifiablePointerType(Type swiftType) { (nonnullType->getAnyPointerElementType(PTK) && PTK != PTK_AutoreleasingUnsafeMutablePointer); } -void ClangImporter::Implementation::swiftify(FuncDecl *MappedDecl) { +void ClangImporter::Implementation::swiftify(AbstractFunctionDecl *MappedDecl) { if (!SwiftContext.LangOpts.hasFeature(Feature::SafeInteropWrappers)) return; if (importSymbolicCXXDecls) @@ -9166,7 +9166,15 @@ void ClangImporter::Implementation::swiftify(FuncDecl *MappedDecl) { return false; }; SwiftifyInfoPrinter printer(getClangASTContext(), SwiftContext, out); - Type swiftReturnTy = MappedDecl->getResultInterfaceType(); + Type swiftReturnTy; + if (const auto *funcDecl = dyn_cast(MappedDecl)) + swiftReturnTy = funcDecl->getResultInterfaceType(); + else if (const auto *ctorDecl = dyn_cast(MappedDecl)) + swiftReturnTy = ctorDecl->getResultInterfaceType(); + else { + assert(false && "Unexpected AbstractFunctionDecl subclass."); + return; + } bool returnIsStdSpan = registerStdSpanTypeMapping( swiftReturnTy, ClangDecl->getReturnType()); auto *CAT = ClangDecl->getReturnType()->getAs(); diff --git a/lib/ClangImporter/ImporterImpl.h b/lib/ClangImporter/ImporterImpl.h index b718cc679b0e2..67b5348d88e82 100644 --- a/lib/ClangImporter/ImporterImpl.h +++ b/lib/ClangImporter/ImporterImpl.h @@ -1753,7 +1753,7 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation } void importSwiftAttrAttributes(Decl *decl); - void swiftify(FuncDecl *MappedDecl); + void swiftify(AbstractFunctionDecl *MappedDecl); /// Find the lookup table that corresponds to the given Clang module. /// diff --git a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift index 5be714860742a..c7e0f41dbcfa8 100644 --- a/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift +++ b/lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift @@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible { var dependencies: [LifetimeDependence] { get set } func getBoundsCheckedThunkBuilder( - _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax + _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts ) -> BoundsCheckedThunkBuilder } -func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TokenSyntax? { +func tryGetParamName(_ funcDecl: FunctionParts, _ expr: SwiftifyExpr) -> TokenSyntax? { switch expr { case .param(let i): let funcParam = getParam(funcDecl, i - 1) @@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To } } -func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TypeSyntax { +func getSwiftifyExprType(_ funcDecl: FunctionParts, _ expr: SwiftifyExpr) -> TypeSyntax { switch expr { case .param(let i): let funcParam = getParam(funcDecl, i - 1) @@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo { } func getBoundsCheckedThunkBuilder( - _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax + _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts ) -> BoundsCheckedThunkBuilder { switch pointerIndex { case .param(let i): @@ -115,7 +115,7 @@ struct CountedBy: ParamInfo { } func getBoundsCheckedThunkBuilder( - _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax + _ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionParts ) -> BoundsCheckedThunkBuilder { switch pointerIndex { case .param(let i): @@ -400,14 +400,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi } } -func getParam(_ funcDecl: FunctionDeclSyntax, _ paramIndex: Int) -> FunctionParameterSyntax { +func getParam(_ funcDecl: FunctionParts, _ paramIndex: Int) -> FunctionParameterSyntax { return getParam(funcDecl.signature, paramIndex) } struct FunctionCallBuilder: BoundsCheckedThunkBuilder { - let base: FunctionDeclSyntax + let base: FunctionParts - init(_ function: FunctionDeclSyntax) { + init(_ function: FunctionParts) { base = function } @@ -467,14 +467,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder { FunctionCallExprSyntax( calledExpression: functionRef, leftParen: .leftParenToken(), arguments: LabeledExprListSyntax(labeledArgs), rightParen: .rightParenToken())) - return "unsafe \(call)" + if base.name.tokenKind == .keyword(.`init`) { + return "unsafe self.\(call)" + } else { + return "unsafe \(call)" + } } } struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder { public let base: BoundsCheckedThunkBuilder public let index: Int - public let funcDecl: FunctionDeclSyntax + public let funcDecl: FunctionParts public let typeMappings: [String: String] public let node: SyntaxProtocol public let nonescaping: Bool @@ -525,7 +529,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder { struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder { public let base: BoundsCheckedThunkBuilder - public let funcDecl: FunctionDeclSyntax + public let funcDecl: FunctionParts public let typeMappings: [String: String] public let node: SyntaxProtocol let isParameter: Bool = false @@ -564,7 +568,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder { protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder { var oldType: TypeSyntax { get } var newType: TypeSyntax { get throws } - var funcDecl: FunctionDeclSyntax { get } + var funcDecl: FunctionParts { get } } extension BoundsThunkBuilder { @@ -675,7 +679,7 @@ extension ParamBoundsThunkBuilder { struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder { public let base: BoundsCheckedThunkBuilder public let countExpr: ExprSyntax - public let funcDecl: FunctionDeclSyntax + public let funcDecl: FunctionParts public let nonescaping: Bool public let isSizedBy: Bool public let dependencies: [LifetimeDependence] @@ -743,7 +747,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds public let base: BoundsCheckedThunkBuilder public let index: Int public let countExpr: ExprSyntax - public let funcDecl: FunctionDeclSyntax + public let funcDecl: FunctionParts public let nonescaping: Bool public let isSizedBy: Bool let isParameter: Bool = true @@ -1237,22 +1241,22 @@ func parseMacroParam( } } -func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws { +func checkArgs(_ args: [ParamInfo], _ funcComponents: FunctionParts) throws { var argByIndex: [Int: ParamInfo] = [:] var ret: ParamInfo? = nil - let paramCount = funcDecl.signature.parameterClause.parameters.count + let paramCount = funcComponents.signature.parameterClause.parameters.count try args.forEach { pointerInfo in switch pointerInfo.pointerIndex { case .param(let i): if i < 1 || i > paramCount { let noteMessage = paramCount > 0 - ? "function \(funcDecl.name) has parameter indices 1..\(paramCount)" - : "function \(funcDecl.name) has no parameters" + ? "function \(funcComponents.name) has parameter indices 1..\(paramCount)" + : "function \(funcComponents.name) has no parameters" throw DiagnosticError( "pointer index out of bounds", node: pointerInfo.original, notes: [ - Note(node: Syntax(funcDecl.name), message: MacroExpansionNoteMessage(noteMessage)) + Note(node: Syntax(funcComponents.name), message: MacroExpansionNoteMessage(noteMessage)) ]) } if argByIndex[i] != nil { @@ -1316,7 +1320,7 @@ func isInout(_ type: TypeSyntax) -> Bool { } func getReturnLifetimeAttribute( - _ funcDecl: FunctionDeclSyntax, + _ funcDecl: FunctionParts, _ dependencies: [SwiftifyExpr: [LifetimeDependence]] ) -> [AttributeListSyntax.Element] { let returnDependencies = dependencies[.`return`, default: []] @@ -1473,9 +1477,9 @@ class CountExprRewriter: SyntaxRewriter { } } -func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDeclSyntax, CountExprRewriter) { - let params = funcDecl.signature.parameterClause.parameters - let funcName = funcDecl.name.withoutBackticks.trimmed.text +func renameParameterNamesIfNeeded(_ funcComponents: FunctionParts) -> (FunctionParts, CountExprRewriter) { + let params = funcComponents.signature.parameterClause.parameters + let funcName = funcComponents.name.withoutBackticks.trimmed.text let shouldRename = params.contains(where: { param in let paramName = param.name.trimmed.text return paramName == "_" || paramName == funcName || "`\(paramName)`" == funcName @@ -1499,13 +1503,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe } return newParam } - let newDecl = if renamedParams.count > 0 { - funcDecl.with(\.signature.parameterClause.parameters, FunctionParameterListSyntax(newParams)) + let newSig = if renamedParams.count > 0 { + funcComponents.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams)) } else { // Keeps source locations for diagnostics, in the common case where nothing was renamed - funcDecl + funcComponents.signature + } + return (FunctionParts(signature: newSig, name: funcComponents.name, attributes: funcComponents.attributes), + CountExprRewriter(renamedParams)) +} + +struct FunctionParts { + let signature: FunctionSignatureSyntax + let name: TokenSyntax + let attributes: AttributeListSyntax +} + +func deconstructFunction(_ declaration: some DeclSyntaxProtocol) throws -> FunctionParts { + if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) { + return FunctionParts(signature: origFuncDecl.signature, name: origFuncDecl.name, + attributes: origFuncDecl.attributes) + } + if let origInitDecl = declaration.as(InitializerDeclSyntax.self) { + return FunctionParts(signature: origInitDecl.signature, name: origInitDecl.initKeyword, + attributes: origInitDecl.attributes) } - return (newDecl, CountExprRewriter(renamedParams)) + throw DiagnosticError("@_SwiftifyImport only works on functions and initializers", node: declaration) } /// A macro that adds safe(r) wrappers for functions with unsafe pointer types. @@ -1521,10 +1544,8 @@ public struct SwiftifyImportMacro: PeerMacro { in context: some MacroExpansionContext ) throws -> [DeclSyntax] { do { - guard let origFuncDecl = declaration.as(FunctionDeclSyntax.self) else { - throw DiagnosticError("@_SwiftifyImport only works on functions", node: declaration) - } - let (funcDecl, rewriter) = renameParameterNamesIfNeeded(origFuncDecl) + let origFuncComponents = try deconstructFunction(declaration) + let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents) let argumentList = node.arguments!.as(LabeledExprListSyntax.self)! var arguments = [LabeledExprSyntax](argumentList) @@ -1540,10 +1561,10 @@ public struct SwiftifyImportMacro: PeerMacro { var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:] var parsedArgs = try arguments.compactMap { try parseMacroParam( - $0, funcDecl.signature, rewriter, nonescapingPointers: &nonescapingPointers, + $0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers, lifetimeDependencies: &lifetimeDependencies) } - parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcDecl.signature, typeMappings)) + parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings)) setNonescapingPointers(&parsedArgs, nonescapingPointers) setLifetimeDependencies(&parsedArgs, lifetimeDependencies) // We only transform non-escaping spans. @@ -1554,7 +1575,7 @@ public struct SwiftifyImportMacro: PeerMacro { return true } } - try checkArgs(parsedArgs, funcDecl) + try checkArgs(parsedArgs, funcComponents) parsedArgs.sort { a, b in // make sure return value cast to Span happens last so that withUnsafeBufferPointer // doesn't return a ~Escapable type @@ -1566,12 +1587,12 @@ public struct SwiftifyImportMacro: PeerMacro { } return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex) } - let baseBuilder = FunctionCallBuilder(funcDecl) + let baseBuilder = FunctionCallBuilder(funcComponents) let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce( baseBuilder, { (prev, parsedArg) in - parsedArg.getBoundsCheckedThunkBuilder(prev, funcDecl) + parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents) }) let newSignature = try builder.buildFunctionSignature([:], nil) var eliminatedArgs = Set() @@ -1580,15 +1601,22 @@ public struct SwiftifyImportMacro: PeerMacro { let checks = (basicChecks + compoundChecks).map { e in CodeBlockItemSyntax(leadingTrivia: "\n", item: e) } - let call = CodeBlockItemSyntax( - item: CodeBlockItemSyntax.Item( - ReturnStmtSyntax( - returnKeyword: .keyword(.return, trailingTrivia: " "), - expression: try builder.buildFunctionCall([:])))) + var call : CodeBlockItemSyntax + if declaration.is(InitializerDeclSyntax.self) { + call = CodeBlockItemSyntax( + item: CodeBlockItemSyntax.Item( + try builder.buildFunctionCall([:]))) + } else { + call = CodeBlockItemSyntax( + item: CodeBlockItemSyntax.Item( + ReturnStmtSyntax( + returnKeyword: .keyword(.return, trailingTrivia: " "), + expression: try builder.buildFunctionCall([:])))) + } let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call])) - let returnLifetimeAttribute = getReturnLifetimeAttribute(funcDecl, lifetimeDependencies) + let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies) let lifetimeAttrs = - returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes) + returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes) let availabilityAttr = try getAvailability(newSignature, spanAvailability) let disfavoredOverload: [AttributeListSyntax.Element] = [ @@ -1597,13 +1625,7 @@ public struct SwiftifyImportMacro: PeerMacro { atSign: .atSignToken(), attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload"))) ] - let newFunc = - funcDecl - .with(\.signature, newSignature) - .with(\.body, body) - .with( - \.attributes, - funcDecl.attributes.filter { e in + let attributes = funcComponents.attributes.filter { e in switch e { case .attribute(let attr): // don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient @@ -1619,9 +1641,23 @@ public struct SwiftifyImportMacro: PeerMacro { ] + availabilityAttr + lifetimeAttrs - + disfavoredOverload) - .with(\.leadingTrivia, node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n")) - return [DeclSyntax(newFunc)] + + disfavoredOverload + let trivia = node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n") + if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) { + return [DeclSyntax(origFuncDecl + .with(\.signature, newSignature) + .with(\.body, body) + .with(\.attributes, AttributeListSyntax(attributes)) + .with(\.leadingTrivia, trivia))] + } + if let origInitDecl = declaration.as(InitializerDeclSyntax.self) { + return [DeclSyntax(origInitDecl + .with(\.signature, newSignature) + .with(\.body, body) + .with(\.attributes, AttributeListSyntax(attributes)) + .with(\.leadingTrivia, trivia))] + } + return [] } catch let error as DiagnosticError { context.diagnose( Diagnostic( @@ -1686,6 +1722,9 @@ extension FunctionParameterSyntax { extension TokenSyntax { public var withoutBackticks: TokenSyntax { + if self.identifier == nil { + return self + } return .identifier(self.identifier!.name) } diff --git a/test/Interop/Cxx/swiftify-import/counted-by-method.swift b/test/Interop/Cxx/swiftify-import/counted-by-method.swift index c15cd8564bc88..625b225ee30b9 100644 --- a/test/Interop/Cxx/swiftify-import/counted-by-method.swift +++ b/test/Interop/Cxx/swiftify-import/counted-by-method.swift @@ -27,4 +27,4 @@ import Method func test(s: UnsafeMutableBufferPointer) { var foo = Foo() foo.bar(s) -} +} diff --git a/test/Interop/Cxx/swiftify-import/span-in-ctor.swift b/test/Interop/Cxx/swiftify-import/span-in-ctor.swift new file mode 100644 index 0000000000000..7e824cd0ea9ec --- /dev/null +++ b/test/Interop/Cxx/swiftify-import/span-in-ctor.swift @@ -0,0 +1,40 @@ +// REQUIRES: swift_feature_SafeInteropWrappers + +// FIXME swift-ci linux tests do not support std::span +// UNSUPPORTED: OS=linux-gnu, OS=linux-android, OS=linux-androideabi + +// RUN: rm -rf %t +// RUN: split-file %s %t +// RUN: %target-swift-frontend -c -plugin-path %swift-plugin-dir -I %t/Inputs -Xcc -std=c++20 -cxx-interoperability-mode=default -enable-experimental-feature SafeInteropWrappers %t/method.swift -dump-macro-expansions -verify 2>&1 | %FileCheck %s + +// CHECK: @_alwaysEmitIntoClient +// CHECK: public init(_ sp: Span) { +// CHECK: unsafe self.init(IntSpan(sp)) +// CHECK: } + + +//--- Inputs/module.modulemap +module Method { + header "method.h" + requires cplusplus +} + +//--- Inputs/method.h + +#include + +using IntSpan = std::span; + +class Foo { +public: + Foo(); + Foo(IntSpan sp [[clang::noescape]]); +}; + +//--- method.swift +import CxxStdlib +import Method + +func test(s: Span) { + var _ = Foo(s) +}