@@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible {
40
40
var dependencies : [ LifetimeDependence ] { get set }
41
41
42
42
func getBoundsCheckedThunkBuilder(
43
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
43
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
44
44
) -> BoundsCheckedThunkBuilder
45
45
}
46
46
47
- func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
47
+ func tryGetParamName( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
48
48
switch expr {
49
49
case . param( let i) :
50
50
let funcParam = getParam ( funcDecl, i - 1 )
@@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To
55
55
}
56
56
}
57
57
58
- func getSwiftifyExprType( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TypeSyntax {
58
+ func getSwiftifyExprType( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TypeSyntax {
59
59
switch expr {
60
60
case . param( let i) :
61
61
let funcParam = getParam ( funcDecl, i - 1 )
@@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo {
79
79
}
80
80
81
81
func getBoundsCheckedThunkBuilder(
82
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
82
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
83
83
) -> BoundsCheckedThunkBuilder {
84
84
switch pointerIndex {
85
85
case . param( let i) :
@@ -115,7 +115,7 @@ struct CountedBy: ParamInfo {
115
115
}
116
116
117
117
func getBoundsCheckedThunkBuilder(
118
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
118
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
119
119
) -> BoundsCheckedThunkBuilder {
120
120
switch pointerIndex {
121
121
case . param( let i) :
@@ -424,14 +424,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
424
424
}
425
425
}
426
426
427
- func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
427
+ func getParam( _ funcDecl: FunctionParts , _ paramIndex: Int ) -> FunctionParameterSyntax {
428
428
return getParam ( funcDecl. signature, paramIndex)
429
429
}
430
430
431
431
struct FunctionCallBuilder : BoundsCheckedThunkBuilder {
432
- let base : FunctionDeclSyntax
432
+ let base : FunctionParts
433
433
434
- init ( _ function: FunctionDeclSyntax ) {
434
+ init ( _ function: FunctionParts ) {
435
435
base = function
436
436
}
437
437
@@ -491,14 +491,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
491
491
FunctionCallExprSyntax (
492
492
calledExpression: functionRef, leftParen: . leftParenToken( ) ,
493
493
arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
494
- return " unsafe \( call) "
494
+ if base. name. tokenKind == . keyword( . `init`) {
495
+ return " unsafe self. \( call) "
496
+ } else {
497
+ return " unsafe \( call) "
498
+ }
495
499
}
496
500
}
497
501
498
502
struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
499
503
public let base : BoundsCheckedThunkBuilder
500
504
public let index : Int
501
- public let funcDecl : FunctionDeclSyntax
505
+ public let funcDecl : FunctionParts
502
506
public let typeMappings : [ String : String ]
503
507
public let node : SyntaxProtocol
504
508
public let nonescaping : Bool
@@ -549,7 +553,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
549
553
550
554
struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
551
555
public let base : BoundsCheckedThunkBuilder
552
- public let funcDecl : FunctionDeclSyntax
556
+ public let funcDecl : FunctionParts
553
557
public let typeMappings : [ String : String ]
554
558
public let node : SyntaxProtocol
555
559
let isParameter : Bool = false
@@ -588,7 +592,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
588
592
protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
589
593
var oldType : TypeSyntax { get }
590
594
var newType : TypeSyntax { get throws }
591
- var funcDecl : FunctionDeclSyntax { get }
595
+ var funcDecl : FunctionParts { get }
592
596
}
593
597
594
598
extension BoundsThunkBuilder {
@@ -700,7 +704,7 @@ extension ParamBoundsThunkBuilder {
700
704
struct CountedOrSizedReturnPointerThunkBuilder : PointerBoundsThunkBuilder {
701
705
public let base : BoundsCheckedThunkBuilder
702
706
public let countExpr : ExprSyntax
703
- public let funcDecl : FunctionDeclSyntax
707
+ public let funcDecl : FunctionParts
704
708
public let nonescaping : Bool
705
709
public let isSizedBy : Bool
706
710
public let dependencies : [ LifetimeDependence ]
@@ -768,7 +772,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
768
772
public let base : BoundsCheckedThunkBuilder
769
773
public let index : Int
770
774
public let countExpr : ExprSyntax
771
- public let funcDecl : FunctionDeclSyntax
775
+ public let funcDecl : FunctionParts
772
776
public let nonescaping : Bool
773
777
public let isSizedBy : Bool
774
778
let isParameter : Bool = true
@@ -1267,22 +1271,22 @@ func parseMacroParam(
1267
1271
}
1268
1272
}
1269
1273
1270
- func checkArgs( _ args: [ ParamInfo ] , _ funcDecl : FunctionDeclSyntax ) throws {
1274
+ func checkArgs( _ args: [ ParamInfo ] , _ funcComponents : FunctionParts ) throws {
1271
1275
var argByIndex : [ Int : ParamInfo ] = [ : ]
1272
1276
var ret : ParamInfo ? = nil
1273
- let paramCount = funcDecl . signature. parameterClause. parameters. count
1277
+ let paramCount = funcComponents . signature. parameterClause. parameters. count
1274
1278
try args. forEach { pointerInfo in
1275
1279
switch pointerInfo. pointerIndex {
1276
1280
case . param( let i) :
1277
1281
if i < 1 || i > paramCount {
1278
1282
let noteMessage =
1279
1283
paramCount > 0
1280
- ? " function \( funcDecl . name) has parameter indices 1.. \( paramCount) "
1281
- : " function \( funcDecl . name) has no parameters "
1284
+ ? " function \( funcComponents . name) has parameter indices 1.. \( paramCount) "
1285
+ : " function \( funcComponents . name) has no parameters "
1282
1286
throw DiagnosticError (
1283
1287
" pointer index out of bounds " , node: pointerInfo. original,
1284
1288
notes: [
1285
- Note ( node: Syntax ( funcDecl . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1289
+ Note ( node: Syntax ( funcComponents . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1286
1290
] )
1287
1291
}
1288
1292
if argByIndex [ i] != nil {
@@ -1346,7 +1350,7 @@ func isInout(_ type: TypeSyntax) -> Bool {
1346
1350
}
1347
1351
1348
1352
func getReturnLifetimeAttribute(
1349
- _ funcDecl: FunctionDeclSyntax ,
1353
+ _ funcDecl: FunctionParts ,
1350
1354
_ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ]
1351
1355
) -> [ AttributeListSyntax . Element ] {
1352
1356
let returnDependencies = dependencies [ . `return`, default: [ ] ]
@@ -1503,9 +1507,9 @@ class CountExprRewriter: SyntaxRewriter {
1503
1507
}
1504
1508
}
1505
1509
1506
- func renameParameterNamesIfNeeded( _ funcDecl : FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1507
- let params = funcDecl . signature. parameterClause. parameters
1508
- let funcName = funcDecl . name. withoutBackticks. trimmed. text
1510
+ func renameParameterNamesIfNeeded( _ funcComponents : FunctionParts ) -> ( FunctionParts , CountExprRewriter ) {
1511
+ let params = funcComponents . signature. parameterClause. parameters
1512
+ let funcName = funcComponents . name. withoutBackticks. trimmed. text
1509
1513
let shouldRename = params. contains ( where: { param in
1510
1514
let paramName = param. name. trimmed. text
1511
1515
return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
@@ -1529,13 +1533,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe
1529
1533
}
1530
1534
return newParam
1531
1535
}
1532
- let newDecl = if renamedParams. count > 0 {
1533
- funcDecl . with ( \. signature . parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1536
+ let newSig = if renamedParams. count > 0 {
1537
+ funcComponents . signature . with ( \. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1534
1538
} else {
1535
1539
// Keeps source locations for diagnostics, in the common case where nothing was renamed
1536
- funcDecl
1540
+ funcComponents. signature
1541
+ }
1542
+ return ( FunctionParts ( signature: newSig, name: funcComponents. name, attributes: funcComponents. attributes) ,
1543
+ CountExprRewriter ( renamedParams) )
1544
+ }
1545
+
1546
+ struct FunctionParts {
1547
+ let signature : FunctionSignatureSyntax
1548
+ let name : TokenSyntax
1549
+ let attributes : AttributeListSyntax
1550
+ }
1551
+
1552
+ func deconstructFunction( _ declaration: some DeclSyntaxProtocol ) throws -> FunctionParts {
1553
+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1554
+ return FunctionParts ( signature: origFuncDecl. signature, name: origFuncDecl. name,
1555
+ attributes: origFuncDecl. attributes)
1556
+ }
1557
+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1558
+ return FunctionParts ( signature: origInitDecl. signature, name: origInitDecl. initKeyword,
1559
+ attributes: origInitDecl. attributes)
1537
1560
}
1538
- return ( newDecl , CountExprRewriter ( renamedParams ) )
1561
+ throw DiagnosticError ( " @_SwiftifyImport only works on functions and initializers " , node : declaration )
1539
1562
}
1540
1563
1541
1564
/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
@@ -1551,10 +1574,8 @@ public struct SwiftifyImportMacro: PeerMacro {
1551
1574
in context: some MacroExpansionContext
1552
1575
) throws -> [ DeclSyntax ] {
1553
1576
do {
1554
- guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
1555
- throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
1556
- }
1557
- let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
1577
+ let origFuncComponents = try deconstructFunction ( declaration)
1578
+ let ( funcComponents, rewriter) = renameParameterNamesIfNeeded ( origFuncComponents)
1558
1579
1559
1580
let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
1560
1581
var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1570,10 +1591,10 @@ public struct SwiftifyImportMacro: PeerMacro {
1570
1591
var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
1571
1592
var parsedArgs = try arguments. compactMap {
1572
1593
try parseMacroParam (
1573
- $0, funcDecl . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1594
+ $0, funcComponents . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1574
1595
lifetimeDependencies: & lifetimeDependencies)
1575
1596
}
1576
- parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl . signature, typeMappings) )
1597
+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcComponents . signature, typeMappings) )
1577
1598
setNonescapingPointers ( & parsedArgs, nonescapingPointers)
1578
1599
setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
1579
1600
// We only transform non-escaping spans.
@@ -1584,7 +1605,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1584
1605
return true
1585
1606
}
1586
1607
}
1587
- try checkArgs ( parsedArgs, funcDecl )
1608
+ try checkArgs ( parsedArgs, funcComponents )
1588
1609
parsedArgs. sort { a, b in
1589
1610
// make sure return value cast to Span happens last so that withUnsafeBufferPointer
1590
1611
// doesn't return a ~Escapable type
@@ -1596,12 +1617,12 @@ public struct SwiftifyImportMacro: PeerMacro {
1596
1617
}
1597
1618
return paramOrReturnIndex ( a. pointerIndex) < paramOrReturnIndex ( b. pointerIndex)
1598
1619
}
1599
- let baseBuilder = FunctionCallBuilder ( funcDecl )
1620
+ let baseBuilder = FunctionCallBuilder ( funcComponents )
1600
1621
1601
1622
let builder : BoundsCheckedThunkBuilder = parsedArgs. reduce (
1602
1623
baseBuilder,
1603
1624
{ ( prev, parsedArg) in
1604
- parsedArg. getBoundsCheckedThunkBuilder ( prev, funcDecl )
1625
+ parsedArg. getBoundsCheckedThunkBuilder ( prev, funcComponents )
1605
1626
} )
1606
1627
let newSignature = try builder. buildFunctionSignature ( [ : ] , nil )
1607
1628
var eliminatedArgs = Set < Int > ( )
@@ -1610,15 +1631,22 @@ public struct SwiftifyImportMacro: PeerMacro {
1610
1631
let checks = ( basicChecks + compoundChecks) . map { e in
1611
1632
CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
1612
1633
}
1613
- let call = CodeBlockItemSyntax (
1614
- item: CodeBlockItemSyntax . Item (
1615
- ReturnStmtSyntax (
1616
- returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1617
- expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1634
+ var call : CodeBlockItemSyntax
1635
+ if declaration. is ( InitializerDeclSyntax . self) {
1636
+ call = CodeBlockItemSyntax (
1637
+ item: CodeBlockItemSyntax . Item (
1638
+ try builder. buildFunctionCall ( [ : ] ) ) )
1639
+ } else {
1640
+ call = CodeBlockItemSyntax (
1641
+ item: CodeBlockItemSyntax . Item (
1642
+ ReturnStmtSyntax (
1643
+ returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1644
+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1645
+ }
1618
1646
let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1619
- let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl , lifetimeDependencies)
1647
+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcComponents , lifetimeDependencies)
1620
1648
let lifetimeAttrs =
1621
- returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl . attributes)
1649
+ returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcComponents . attributes)
1622
1650
let availabilityAttr = try getAvailability ( newSignature, spanAvailability)
1623
1651
let disfavoredOverload : [ AttributeListSyntax . Element ] =
1624
1652
[
@@ -1627,13 +1655,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1627
1655
atSign: . atSignToken( ) ,
1628
1656
attributeName: IdentifierTypeSyntax ( name: " _disfavoredOverload " ) ) )
1629
1657
]
1630
- let newFunc =
1631
- funcDecl
1632
- . with ( \. signature, newSignature)
1633
- . with ( \. body, body)
1634
- . with (
1635
- \. attributes,
1636
- funcDecl. attributes. filter { e in
1658
+ let attributes = funcComponents. attributes. filter { e in
1637
1659
switch e {
1638
1660
case . attribute( let attr) :
1639
1661
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
@@ -1649,9 +1671,23 @@ public struct SwiftifyImportMacro: PeerMacro {
1649
1671
]
1650
1672
+ availabilityAttr
1651
1673
+ lifetimeAttrs
1652
- + disfavoredOverload)
1653
- . with ( \. leadingTrivia, node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " ) )
1654
- return [ DeclSyntax ( newFunc) ]
1674
+ + disfavoredOverload
1675
+ let trivia = node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " )
1676
+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1677
+ return [ DeclSyntax ( origFuncDecl
1678
+ . with ( \. signature, newSignature)
1679
+ . with ( \. body, body)
1680
+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1681
+ . with ( \. leadingTrivia, trivia) ) ]
1682
+ }
1683
+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1684
+ return [ DeclSyntax ( origInitDecl
1685
+ . with ( \. signature, newSignature)
1686
+ . with ( \. body, body)
1687
+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1688
+ . with ( \. leadingTrivia, trivia) ) ]
1689
+ }
1690
+ return [ ]
1655
1691
} catch let error as DiagnosticError {
1656
1692
context. diagnose (
1657
1693
Diagnostic (
@@ -1716,6 +1752,9 @@ extension FunctionParameterSyntax {
1716
1752
1717
1753
extension TokenSyntax {
1718
1754
public var withoutBackticks : TokenSyntax {
1755
+ if self . identifier == nil {
1756
+ return self
1757
+ }
1719
1758
return . identifier( self . identifier!. name)
1720
1759
}
1721
1760
0 commit comments