Skip to content

fix: Simplify macro statement expansion handling #12668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions crates/hir-def/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use la_arena::{Arena, ArenaMap};
use limit::Limit;
use profile::Count;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use syntax::{ast, AstNode, AstPtr, SyntaxNodePtr};

use crate::{
Expand Down Expand Up @@ -294,10 +293,6 @@ pub struct BodySourceMap {
field_map: FxHashMap<InFile<AstPtr<ast::RecordExprField>>, ExprId>,
field_map_back: FxHashMap<ExprId, InFile<AstPtr<ast::RecordExprField>>>,

/// Maps a macro call to its lowered expressions, a single one if it expands to an expression,
/// or multiple if it expands to MacroStmts.
macro_call_to_exprs: FxHashMap<InFile<AstPtr<ast::MacroCall>>, SmallVec<[ExprId; 1]>>,

expansions: FxHashMap<InFile<AstPtr<ast::MacroCall>>, HirFileId>,

/// Diagnostics accumulated during body lowering. These contain `AstPtr`s and so are stored in
Expand Down Expand Up @@ -466,9 +461,9 @@ impl BodySourceMap {
self.field_map.get(&src).cloned()
}

pub fn macro_expansion_expr(&self, node: InFile<&ast::MacroCall>) -> Option<&[ExprId]> {
let src = node.map(AstPtr::new);
self.macro_call_to_exprs.get(&src).map(|it| &**it)
pub fn macro_expansion_expr(&self, node: InFile<&ast::MacroExpr>) -> Option<ExprId> {
let src = node.map(AstPtr::new).map(AstPtr::upcast::<ast::MacroExpr>).map(AstPtr::upcast);
self.expr_map.get(&src).copied()
}

/// Get a reference to the body source map's diagnostics.
Expand Down
154 changes: 72 additions & 82 deletions crates/hir-def/src/body/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use hir_expand::{
use la_arena::Arena;
use profile::Count;
use rustc_hash::FxHashMap;
use smallvec::smallvec;
use syntax::{
ast::{
self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind,
Expand Down Expand Up @@ -97,7 +96,6 @@ pub(super) fn lower(
or_pats: Default::default(),
},
expander,
statements_in_scope: Vec::new(),
name_to_pat_grouping: Default::default(),
is_lowering_inside_or_pat: false,
}
Expand All @@ -109,7 +107,6 @@ struct ExprCollector<'a> {
expander: Expander,
body: Body,
source_map: BodySourceMap,
statements_in_scope: Vec<Statement>,
// a poor-mans union-find?
name_to_pat_grouping: FxHashMap<Name, Vec<PatId>>,
is_lowering_inside_or_pat: bool,
Expand Down Expand Up @@ -514,27 +511,25 @@ impl ExprCollector<'_> {
ast::Expr::MacroExpr(e) => {
let e = e.macro_call()?;
let macro_ptr = AstPtr::new(&e);
let id = self.collect_macro_call(e, macro_ptr.clone(), true, |this, expansion| {
let id = self.collect_macro_call(e, macro_ptr, true, |this, expansion| {
expansion.map(|it| this.collect_expr(it))
});
match id {
Some(id) => {
self.source_map
.macro_call_to_exprs
.insert(self.expander.to_source(macro_ptr), smallvec![id]);
// Make the macro-call point to its expanded expression so we can query
// semantics on syntax pointers to the macro
let src = self.expander.to_source(syntax_ptr);
self.source_map.expr_map.insert(src, id);
id
}
None => self.alloc_expr(Expr::Missing, syntax_ptr.clone()),
None => self.alloc_expr(Expr::Missing, syntax_ptr),
}
}
ast::Expr::MacroStmts(e) => {
e.statements().for_each(|s| self.collect_stmt(s));
let tail = e
.expr()
.map(|e| self.collect_expr(e))
.unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone()));
let statements = e.statements().filter_map(|s| self.collect_stmt(s)).collect();
let tail = e.expr().map(|e| self.collect_expr(e));

self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr)
self.alloc_expr(Expr::MacroStmts { tail, statements }, syntax_ptr)
}
ast::Expr::UnderscoreExpr(_) => self.alloc_expr(Expr::Underscore, syntax_ptr),
})
Expand Down Expand Up @@ -607,11 +602,11 @@ impl ExprCollector<'_> {
}
}

fn collect_stmt(&mut self, s: ast::Stmt) {
fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Statement> {
match s {
ast::Stmt::LetStmt(stmt) => {
if self.check_cfg(&stmt).is_none() {
return;
return None;
}
let pat = self.collect_pat_opt(stmt.pat());
let type_ref =
Expand All @@ -621,70 +616,61 @@ impl ExprCollector<'_> {
.let_else()
.and_then(|let_else| let_else.block_expr())
.map(|block| self.collect_block(block));
self.statements_in_scope.push(Statement::Let {
pat,
type_ref,
initializer,
else_branch,
});
Some(Statement::Let { pat, type_ref, initializer, else_branch })
}
ast::Stmt::ExprStmt(stmt) => {
if let Some(expr) = stmt.expr() {
if self.check_cfg(&expr).is_none() {
return;
let expr = stmt.expr();
if let Some(expr) = &expr {
if self.check_cfg(expr).is_none() {
return None;
}
}
let has_semi = stmt.semicolon_token().is_some();
// Note that macro could be expended to multiple statements
if let Some(ast::Expr::MacroExpr(e)) = stmt.expr() {
let m = match e.macro_call() {
Some(it) => it,
None => return,
};
let macro_ptr = AstPtr::new(&m);
let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());

let prev_stmt = self.statements_in_scope.len();
self.collect_macro_call(m, macro_ptr.clone(), false, |this, expansion| {
match expansion {
// Note that macro could be expanded to multiple statements
if let Some(expr @ ast::Expr::MacroExpr(mac)) = &expr {
let mac_call = mac.macro_call()?;
let syntax_ptr = AstPtr::new(expr);
let macro_ptr = AstPtr::new(&mac_call);
let stmt = self.collect_macro_call(
mac_call,
macro_ptr,
false,
|this, expansion: Option<ast::MacroStmts>| match expansion {
Some(expansion) => {
let statements: ast::MacroStmts = expansion;

statements.statements().for_each(|stmt| this.collect_stmt(stmt));
if let Some(expr) = statements.expr() {
let expr = this.collect_expr(expr);
this.statements_in_scope
.push(Statement::Expr { expr, has_semi });
}
let statements = expansion
.statements()
.filter_map(|stmt| this.collect_stmt(stmt))
.collect();
let tail = expansion.expr().map(|expr| this.collect_expr(expr));

let mac_stmts = this.alloc_expr(
Expr::MacroStmts { tail, statements },
AstPtr::new(&ast::Expr::MacroStmts(expansion)),
);

Some(mac_stmts)
}
None => {
let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
this.statements_in_scope.push(Statement::Expr { expr, has_semi });
}
}
});
None => None,
},
);

let mut macro_exprs = smallvec![];
for stmt in &self.statements_in_scope[prev_stmt..] {
match *stmt {
Statement::Let { initializer, else_branch, .. } => {
macro_exprs.extend(initializer);
macro_exprs.extend(else_branch);
}
Statement::Expr { expr, .. } => macro_exprs.push(expr),
let expr = match stmt {
Some(expr) => {
// Make the macro-call point to its expanded expression so we can query
// semantics on syntax pointers to the macro
let src = self.expander.to_source(syntax_ptr);
self.source_map.expr_map.insert(src, expr);
expr
}
}
if !macro_exprs.is_empty() {
self.source_map
.macro_call_to_exprs
.insert(self.expander.to_source(macro_ptr), macro_exprs);
}
None => self.alloc_expr(Expr::Missing, syntax_ptr),
};
Some(Statement::Expr { expr, has_semi })
} else {
let expr = self.collect_expr_opt(stmt.expr());
self.statements_in_scope.push(Statement::Expr { expr, has_semi });
let expr = self.collect_expr_opt(expr);
Some(Statement::Expr { expr, has_semi })
}
}
ast::Stmt::Item(_item) => {}
ast::Stmt::Item(_item) => None,
}
}

Expand All @@ -703,25 +689,27 @@ impl ExprCollector<'_> {
};
let prev_def_map = mem::replace(&mut self.expander.def_map, def_map);
let prev_local_module = mem::replace(&mut self.expander.module, module);
let prev_statements = std::mem::take(&mut self.statements_in_scope);

block.statements().for_each(|s| self.collect_stmt(s));
block.tail_expr().and_then(|e| {
let expr = self.maybe_collect_expr(e)?;
self.statements_in_scope.push(Statement::Expr { expr, has_semi: false });
Some(())
let mut statements: Vec<_> =
block.statements().filter_map(|s| self.collect_stmt(s)).collect();
let tail = block.tail_expr().and_then(|e| self.maybe_collect_expr(e));
let tail = tail.or_else(|| {
let stmt = statements.pop()?;
if let Statement::Expr { expr, has_semi: false } = stmt {
return Some(expr);
}
statements.push(stmt);
None
});

let mut tail = None;
if let Some(Statement::Expr { expr, has_semi: false }) = self.statements_in_scope.last() {
tail = Some(*expr);
self.statements_in_scope.pop();
}
let tail = tail;
let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements).into();
let syntax_node_ptr = AstPtr::new(&block.into());
let expr_id = self.alloc_expr(
Expr::Block { id: block_id, statements, tail, label: None },
Expr::Block {
id: block_id,
statements: statements.into_boxed_slice(),
tail,
label: None,
},
syntax_node_ptr,
);

Expand Down Expand Up @@ -903,10 +891,12 @@ impl ExprCollector<'_> {
ast::Pat::MacroPat(mac) => match mac.macro_call() {
Some(call) => {
let macro_ptr = AstPtr::new(&call);
let src = self.expander.to_source(Either::Left(AstPtr::new(&pat)));
let pat =
self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
this.collect_pat_opt_(expanded_pat)
});
self.source_map.pat_map.insert(src, pat);
return pat;
}
None => Pat::Missing,
Expand Down
22 changes: 13 additions & 9 deletions crates/hir-def/src/body/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,27 +145,28 @@ fn compute_block_scopes(
tail: Option<ExprId>,
body: &Body,
scopes: &mut ExprScopes,
mut scope: ScopeId,
scope: &mut ScopeId,
) {
for stmt in statements {
match stmt {
Statement::Let { pat, initializer, else_branch, .. } => {
if let Some(expr) = initializer {
compute_expr_scopes(*expr, body, scopes, &mut scope);
compute_expr_scopes(*expr, body, scopes, scope);
}
if let Some(expr) = else_branch {
compute_expr_scopes(*expr, body, scopes, &mut scope);
compute_expr_scopes(*expr, body, scopes, scope);
}
scope = scopes.new_scope(scope);
scopes.add_bindings(body, scope, *pat);

*scope = scopes.new_scope(*scope);
scopes.add_bindings(body, *scope, *pat);
}
Statement::Expr { expr, .. } => {
compute_expr_scopes(*expr, body, scopes, &mut scope);
compute_expr_scopes(*expr, body, scopes, scope);
}
}
}
if let Some(expr) = tail {
compute_expr_scopes(expr, body, scopes, &mut scope);
compute_expr_scopes(expr, body, scopes, scope);
}
}

Expand All @@ -175,12 +176,15 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope

scopes.set_scope(expr, *scope);
match &body[expr] {
Expr::MacroStmts { statements, tail } => {
compute_block_scopes(statements, *tail, body, scopes, scope);
}
Expr::Block { statements, tail, id, label } => {
let scope = scopes.new_block_scope(*scope, *id, make_label(label));
let mut scope = scopes.new_block_scope(*scope, *id, make_label(label));
// Overwrite the old scope for the block expr, so that every block scope can be found
// via the block itself (important for blocks that only contain items, no expressions).
scopes.set_scope(expr, scope);
compute_block_scopes(statements, *tail, body, scopes, scope);
compute_block_scopes(statements, *tail, body, scopes, &mut scope);
}
Expr::For { iterable, pat, body: body_expr, label } => {
compute_expr_scopes(*iterable, body, scopes, scope);
Expand Down
6 changes: 3 additions & 3 deletions crates/hir-def/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ pub enum Expr {
body: ExprId,
},
MacroStmts {
tail: ExprId,
statements: Box<[Statement]>,
tail: Option<ExprId>,
},
Array(Array),
Literal(Literal),
Expand Down Expand Up @@ -254,7 +255,7 @@ impl Expr {
Expr::Let { expr, .. } => {
f(*expr);
}
Expr::Block { statements, tail, .. } => {
Expr::MacroStmts { tail, statements } | Expr::Block { statements, tail, .. } => {
for stmt in statements.iter() {
match stmt {
Statement::Let { initializer, .. } => {
Expand Down Expand Up @@ -344,7 +345,6 @@ impl Expr {
f(*repeat)
}
},
Expr::MacroStmts { tail } => f(*tail),
Expr::Literal(_) => {}
Expr::Underscore => {}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,9 @@ impl<'a> InferenceContext<'a> {
None => self.table.new_float_var(),
},
},
Expr::MacroStmts { tail } => self.infer_expr_inner(*tail, expected),
Expr::MacroStmts { tail, statements } => {
self.infer_block(tgt_expr, statements, *tail, expected)
}
Expr::Underscore => {
// Underscore expressions may only appear in assignee expressions,
// which are handled by `infer_assignee_expr()`, so any underscore
Expand Down
Loading