Skip to content

[ASTGen] Several improvements to generalize node handling #69894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 62 additions & 127 deletions lib/ASTGen/Sources/ASTGen/ASTGen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ enum ASTNode {
case decl(BridgedDecl)
case stmt(BridgedStmt)
case expr(BridgedExpr)
case type(BridgedTypeRepr)

var castToExpr: BridgedExpr {
guard case .expr(let bridged) = self else {
Expand All @@ -51,13 +50,6 @@ enum ASTNode {
return bridged
}

var castToType: BridgedTypeRepr {
guard case .type(let bridged) = self else {
fatalError("Expected a type")
}
return bridged
}

var bridged: BridgedASTNode {
switch self {
case .expr(let e):
Expand All @@ -66,21 +58,6 @@ enum ASTNode {
return BridgedASTNode(raw: s.raw, kind: .stmt)
case .decl(let d):
return BridgedASTNode(raw: d.raw, kind: .decl)
default:
fatalError("Must be expr, stmt, or decl.")
}
}

var raw: UnsafeMutableRawPointer {
switch self {
case .expr(let e):
return e.raw
case .stmt(let s):
return s.raw
case .decl(let d):
return d.raw
case .type(let t):
return t.raw
}
}
}
Expand All @@ -97,8 +74,6 @@ class Boxed<Value> {
}

struct ASTGenVisitor {
typealias ResultType = ASTNode

fileprivate let diagnosticEngine: BridgedDiagnosticEngine

let base: UnsafeBufferPointer<UInt8>
Expand Down Expand Up @@ -187,47 +162,9 @@ extension ASTGenVisitor {
}
}

extension ASTGenVisitor {
/// Generate ASTNode from a Syntax node. The node must be a decl, stmt, expr, or
/// type.
func generate(_ node: Syntax) -> ASTNode {
if let decl = node.as(DeclSyntax.self) {
return .decl(self.generate(decl: decl))
}
if let stmt = node.as(StmtSyntax.self) {
return .stmt(self.generate(stmt: stmt))
}
if let expr = node.as(ExprSyntax.self) {
return .expr(self.generate(expr: expr))
}
if let type = node.as(TypeSyntax.self) {
return .type(self.generate(type: type))
}

// --- Special cases where `node` doesn't belong to one of the base kinds.

// CodeBlockSyntax -> BraceStmt.
if let node = node.as(CodeBlockSyntax.self) {
return .stmt(self.generate(codeBlock: node).asStmt)
}
// CodeBlockItemSyntax -> ASTNode.
if let node = node.as(CodeBlockItemSyntax.self) {
return self.generate(codeBlockItem: node)
}

fatalError("node does not correspond to an ASTNode \(node.kind)")
}
}

// Misc visits.
// TODO: Some of these are called within a single file/method; we may want to move them to the respective files.
extension ASTGenVisitor {

/// Do NOT introduce another usage of this. Not all choices can produce 'ASTNode'.
func generate(choices node: some SyntaxChildChoices) -> ASTNode {
return self.generate(Syntax(node))
}

public func generate(memberBlockItem node: MemberBlockItemSyntax) -> BridgedDecl {
generate(decl: node.decl)
}
Expand All @@ -237,11 +174,29 @@ extension ASTGenVisitor {
}

public func generate(conditionElement node: ConditionElementSyntax) -> ASTNode {
generate(choices: node.condition)
// FIXME: returning ASTNode is wrong, non-expression conditions are not ASTNode.
switch node.condition {
case .availability(_):
break
case .expression(let node):
return .expr(self.generate(expr: node))
case .matchingPattern(_):
break
case .optionalBinding(_):
break
}
fatalError("unimplemented")
}

public func generate(codeBlockItem node: CodeBlockItemSyntax) -> ASTNode {
generate(choices: node.item)
switch node.item {
case .decl(let node):
return .decl(self.generate(decl: node))
case .stmt(let node):
return .stmt(self.generate(stmt: node))
case .expr(let node):
return .expr(self.generate(expr: node))
}
}

public func generate(arrayElement node: ArrayElementSyntax) -> BridgedExpr {
Expand All @@ -255,79 +210,59 @@ extension ASTGenVisitor {
}

// Forwarding overloads that take optional syntax nodes. These are defined on demand to achieve a consistent
// 'self.visit(<expr>)' recursion pattern between optional and non-optional inputs.
// 'self.generate(foo: FooSyntax)' recursion pattern between optional and non-optional inputs.
extension ASTGenVisitor {
@inline(__always)
func generate(optional node: TypeSyntax?) -> BridgedTypeRepr? {
guard let node else {
return nil
}

return self.generate(type: node)
func generate(type node: TypeSyntax?) -> BridgedNullableTypeRepr {
self.map(node, generate(type:))
}

@inline(__always)
func generate(optional node: ExprSyntax?) -> BridgedExpr? {
guard let node else {
return nil
}

return self.generate(expr: node)
func generate(expr node: ExprSyntax?) -> BridgedNullableExpr {
self.map(node, generate(expr:))
}

/// DO NOT introduce another usage of this. Not all choices can produce 'ASTNode'.
@inline(__always)
func generate(optional node: (some SyntaxChildChoices)?) -> ASTNode? {
guard let node else {
return nil
}

return self.generate(choices: node)
func generate(genericParameterClause node: GenericParameterClauseSyntax?) -> BridgedNullableGenericParamList {
self.map(node, generate(genericParameterClause:))
}

@inline(__always)
func generate(optional node: GenericParameterClauseSyntax?) -> BridgedGenericParamList? {
guard let node else {
return nil
}

return self.generate(genericParameterClause: node)
func generate(genericWhereClause node: GenericWhereClauseSyntax?) -> BridgedNullableTrailingWhereClause {
self.map(node, generate(genericWhereClause:))
}

@inline(__always)
func generate(optional node: GenericWhereClauseSyntax?) -> BridgedTrailingWhereClause? {
guard let node else {
return nil
}

return self.generate(genericWhereClause: node)
func generate(enumCaseParameterClause node: EnumCaseParameterClauseSyntax?) -> BridgedNullableParameterList {
self.map(node, generate(enumCaseParameterClause:))
}

@inline(__always)
func generate(optional node: EnumCaseParameterClauseSyntax?) -> BridgedParameterList? {
guard let node else {
return nil
}

return self.generate(enumCaseParameterClause: node)
func generate(inheritedTypeList node: InheritedTypeListSyntax?) -> BridgedArrayRef {
self.map(node, generate(inheritedTypeList:))
}

@inline(__always)
func generate(optional node: InheritedTypeListSyntax?) -> BridgedArrayRef {
guard let node else {
return .init()
}

return self.generate(inheritedTypeList: node)
func generate(precedenceGroupNameList node: PrecedenceGroupNameListSyntax?) -> BridgedArrayRef {
self.map(node, generate(precedenceGroupNameList:))
}

// Helper function for `generate(foo: FooSyntax?)` methods.
@inline(__always)
func generate(optional node: PrecedenceGroupNameListSyntax?) -> BridgedArrayRef {
guard let node else {
return .init()
}
private func map<Node: SyntaxProtocol, Result: HasNullable>(
_ node: Node?,
_ body: (Node) -> Result
) -> Result.Nullable {
return Result.asNullable(node.map(body))
}

return self.generate(precedenceGroupNameList: node)
// Helper function for `generate(barList: BarListSyntax?)` methods for collection nodes.
@inline(__always)
private func map<Node: SyntaxCollection>(
_ node: Node?,
_ body: (Node) -> BridgedArrayRef
) -> BridgedArrayRef {
return node.map(body) ?? .init()
}
}

Expand Down Expand Up @@ -422,16 +357,16 @@ public func buildTopLevelASTNodes(

/// Generate an AST node at the given source location. Returns the generated
/// ASTNode and mutate the pointee of `endLocPtr` to the end of the node.
private func _build<Node: SyntaxProtocol>(
kind: Node.Type,
private func _build<Node: SyntaxProtocol, Result>(
generator: (ASTGenVisitor) -> (Node) -> Result,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would name this generateFunction. generator sounds to much like an object to me. Also, I think a doc comment for this parameter wouldn’t hurt.

diagEngine: BridgedDiagnosticEngine,
sourceFilePtr: UnsafeRawPointer,
sourceLoc: BridgedSourceLoc,
declContext: BridgedDeclContext,
astContext: BridgedASTContext,
legacyParser: BridgedLegacyParser,
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
) -> UnsafeMutableRawPointer? {
) -> Result? {
let sourceFile = sourceFilePtr.assumingMemoryBound(to: ExportedSourceFile.self)

// Find the type syntax node.
Expand All @@ -452,13 +387,13 @@ private func _build<Node: SyntaxProtocol>(
endLocPtr.pointee = sourceLoc.advanced(by: node.totalLength.utf8Length)

// Convert the syntax node.
return ASTGenVisitor(
return generator(ASTGenVisitor(
diagnosticEngine: diagEngine,
sourceBuffer: sourceFile.pointee.buffer,
declContext: declContext,
astContext: astContext,
legacyParser: legacyParser
).generate(Syntax(node)).raw
))(node)
}

@_cdecl("swift_ASTGen_buildTypeRepr")
Expand All @@ -473,15 +408,15 @@ func buildTypeRepr(
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
) -> UnsafeMutableRawPointer? {
return _build(
kind: TypeSyntax.self,
generator: ASTGenVisitor.generate(type:),
diagEngine: diagEngine,
sourceFilePtr: sourceFilePtr,
sourceLoc: sourceLoc,
declContext: declContext,
astContext: astContext,
legacyParser: legacyParser,
endLocPtr: endLocPtr
)
)?.raw
}

@_cdecl("swift_ASTGen_buildDecl")
Expand All @@ -496,15 +431,15 @@ func buildDecl(
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
) -> UnsafeMutableRawPointer? {
return _build(
kind: DeclSyntax.self,
generator: ASTGenVisitor.generate(decl:),
diagEngine: diagEngine,
sourceFilePtr: sourceFilePtr,
sourceLoc: sourceLoc,
declContext: declContext,
astContext: astContext,
legacyParser: legacyParser,
endLocPtr: endLocPtr
)
)?.raw
}

@_cdecl("swift_ASTGen_buildExpr")
Expand All @@ -519,15 +454,15 @@ func buildExpr(
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
) -> UnsafeMutableRawPointer? {
return _build(
kind: ExprSyntax.self,
generator: ASTGenVisitor.generate(expr:),
diagEngine: diagEngine,
sourceFilePtr: sourceFilePtr,
sourceLoc: sourceLoc,
declContext: declContext,
astContext: astContext,
legacyParser: legacyParser,
endLocPtr: endLocPtr
)
)?.raw
}

@_cdecl("swift_ASTGen_buildStmt")
Expand All @@ -542,13 +477,13 @@ func buildStmt(
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
) -> UnsafeMutableRawPointer? {
return _build(
kind: StmtSyntax.self,
generator: ASTGenVisitor.generate(stmt:),
diagEngine: diagEngine,
sourceFilePtr: sourceFilePtr,
sourceLoc: sourceLoc,
declContext: declContext,
astContext: astContext,
legacyParser: legacyParser,
endLocPtr: endLocPtr
)
)?.raw
}
Loading