Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 152 additions & 77 deletions compiler/rustc_ast_lowering/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// into:
//
// let __postcond = if contract_checks {
// CONTRACT_DECLARATIONS;
// contract_check_requires(PRECOND);
// Some(|ret_val| POSTCOND)
// let __ensures_builder = || {
// CONTRACT_DECLARATIONS;
// contract_check_requires(|| PRECOND);
// build_check_ensures(Some(|| { POSTCOND_DECLS; |ret_val| POSTCOND }))
// };
// contract_check_requires_and_build_ensures(__ensures_builder)
// } else {
// None
// };
Expand All @@ -50,7 +53,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// }
// }

let precond = self.lower_precond(req);
let lowered_req = self.lower_expr_mut(&req);
let precond = self.lower_precond(lowered_req);
let postcond_checker = self.lower_postcond_checker(ens);

let contract_check = self.lower_contract_check_with_postcond(
Expand All @@ -71,15 +75,18 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// into:
//
// let __postcond = if contract_checks {
// Some(|ret_val| POSTCOND)
// let __ensures_builder = || {
// CONTRACT_DECLARATIONS;
// build_check_ensures(Some(|| { POSTCOND_DECLS; |ret_val| POSTCOND }))
// };
// contract_check_requires_and_build_ensures(__ensures_builder)
// } else {
// None
// };
// {
// let ret = { body };
//
// if contract_checks {
// CONTRACT_DECLARATIONS;
// contract_check_ensures(__postcond, ret)
// } else {
// ret
Expand All @@ -102,13 +109,13 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
//
// {
// if contracts_checks {
// CONTRACT_DECLARATIONS;
// contract_requires(PRECOND);
// contract_requires(|| { CONTRACT_DECLARATIONS; PRECOND });
// }
// body
// }
let precond = self.lower_precond(req);
let precond_check = self.lower_contract_check_just_precond(contract_decls, precond);
let lowered_req = self.lower_expr(&req);
let precond = self.block_decls_with_precond(contract_decls, lowered_req);
let precond_check = self.lower_contract_check_just_precond(precond);

let body = self.arena.alloc(body(self));

Expand Down Expand Up @@ -137,17 +144,17 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}

/// Lower the precondition check intrinsic.
fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
let lowered_req = self.lower_expr_mut(&req);
fn lower_precond(&mut self, req: rustc_hir::Expr<'hir>) -> rustc_hir::Stmt<'hir> {
let req_span = self.mark_span_with_reason(
rustc_span::DesugaringKind::Contract,
lowered_req.span,
req.span,
Some(Arc::clone(&self.allow_contracts)),
);
let req_closure = self.expr_closure(req_span, req);
let precond = self.expr_call_lang_item_fn_mut(
req_span,
rustc_hir::LangItem::ContractCheckRequires,
&*arena_vec![self; lowered_req],
&*arena_vec![self; req_closure],
);
self.stmt_expr(req.span, precond)
}
Expand All @@ -163,33 +170,42 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
Some(Arc::clone(&self.allow_contracts)),
);
let lowered_ens = self.lower_expr_mut(&ens);
let ens_closure = self.expr_closure(ens_span, lowered_ens);
self.expr_call_lang_item_fn(
ens_span,
rustc_hir::LangItem::ContractBuildCheckEnsures,
&*arena_vec![self; lowered_ens],
&*arena_vec![self; ens_closure],
)
}

fn block_decls_with_precond(
&mut self,
contract_decls: &'hir [rustc_hir::Stmt<'_>],
lowered_req: &'hir rustc_hir::Expr<'_>,
) -> rustc_hir::Stmt<'hir> {
let req_span = span_of_stmts(contract_decls, lowered_req.span);

let precond_stmts = self.block_all(req_span, contract_decls, Some(lowered_req));
let precond_stmts = self.expr_block(precond_stmts);
self.lower_precond(precond_stmts)
}

fn lower_contract_check_just_precond(
&mut self,
contract_decls: &'hir [rustc_hir::Stmt<'hir>],
precond: rustc_hir::Stmt<'hir>,
) -> rustc_hir::Stmt<'hir> {
let stmts = self
.arena
.alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain([precond].into_iter()));

let then_block_stmts = self.block_all(precond.span, stmts, None);
let span = precond.span;
let then_block_stmts = self.block_all(span, &*arena_vec![self; precond], None);
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));

let precond_check = rustc_hir::ExprKind::If(
self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
then_block,
None,
);

let precond_check = self.expr(precond.span, precond_check);
self.stmt_expr(precond.span, precond_check)
let precond_check = self.expr(span, precond_check);
self.stmt_expr(span, precond_check)
}

fn lower_contract_check_with_postcond(
Expand All @@ -201,26 +217,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let stmts = self
.arena
.alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain(precond.into_iter()));
let span = match precond {
Some(precond) => precond.span,
None => postcond_checker.span,
};

let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
postcond_checker.span,
rustc_hir::lang_items::LangItem::OptionSome,
&*arena_vec![self; *postcond_checker],
));
let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));

let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
postcond_checker.span,
rustc_hir::lang_items::LangItem::OptionNone,
Default::default(),
));
let else_block = self.block_expr(none_expr);
let else_block = self.arena.alloc(self.expr_block(else_block));
let span = self.contract_check_with_postcond_span(stmts, postcond_checker);

let then_block = self.contract_check_with_postcond_block(stmts, postcond_checker, span);
let else_block = self.option_none_block(span);

let contract_check = rustc_hir::ExprKind::If(
self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
Expand All @@ -230,32 +231,89 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.arena.alloc(self.expr(span, contract_check))
}

fn contract_check_with_postcond_span(
&mut self,
stmts: &mut [rustc_hir::Stmt<'hir>],
postcond_checker: &rustc_hir::Expr<'_>,
) -> rustc_span::Span {
// For error diagnostics, span is set to decls + precondition, because
// those will determine the well-typedness of the __ensures_builder
// closure. postcond_checker is already type-checked as part of the
// call to build_check_ensures.
let span =
span_of_stmts(stmts, stmts.last().map(|s| s.span).unwrap_or(postcond_checker.span));
self.mark_span_with_reason(
rustc_span::DesugaringKind::Contract,
span,
Some(Arc::clone(&self.allow_contracts)),
)
}

fn contract_check_with_postcond_block(
&mut self,
stmts: &'hir mut [rustc_hir::Stmt<'hir>],
postcond_checker: &'hir rustc_hir::Expr<'_>,
span: rustc_span::Span,
) -> &'hir mut rustc_hir::Expr<'hir> {
let (builder_decl, builder_ident_expr) =
self.contract_check_with_postcond_builder(stmts, postcond_checker, span);

let build_postcond_call = self.expr_call_lang_item_fn(
span,
rustc_hir::LangItem::ContractCheckRequiresAndBuildEnsures,
&*arena_vec![self; *builder_ident_expr],
);
let block_stmts =
self.block_all(span, arena_vec![self; builder_decl], Some(build_postcond_call));
self.arena.alloc(self.expr_block(block_stmts))
}

fn contract_check_with_postcond_builder(
&mut self,
stmts: &'hir mut [rustc_hir::Stmt<'hir>],
postcond_checker: &'hir rustc_hir::Expr<'_>,
span: rustc_span::Span,
) -> (rustc_hir::Stmt<'hir>, &'hir rustc_hir::Expr<'hir>) {
let block_closure =
self.contract_check_with_postcond_builder_closure(stmts, postcond_checker, span);

let (builder_ident, builder_hir_id, builder_decl) =
self.bind_expression(block_closure, span, "__ensures_builder");
let builder_ident_expr = self.expr_ident(span, builder_ident, builder_hir_id);

(builder_decl, builder_ident_expr)
}

fn contract_check_with_postcond_builder_closure(
&mut self,
stmts: &'hir mut [rustc_hir::Stmt<'hir>],
postcond_checker: &'hir rustc_hir::Expr<'_>,
span: rustc_span::Span,
) -> &'hir mut rustc_hir::Expr<'hir> {
let stmts = self.block_all(span, stmts, Some(postcond_checker));
let stmts = self.expr_block(stmts);
let closure = self.expr_closure(span, stmts);
self.arena.alloc(closure)
}

fn option_none_block(&mut self, span: rustc_span::Span) -> &'hir mut rustc_hir::Expr<'hir> {
let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
span,
rustc_hir::lang_items::LangItem::OptionNone,
Default::default(),
));
let else_block = self.block_expr(none_expr);
self.arena.alloc(self.expr_block(else_block))
}

fn wrap_body_with_contract_check(
&mut self,
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
contract_check: &'hir rustc_hir::Expr<'hir>,
postcond_span: rustc_span::Span,
) -> &'hir rustc_hir::Block<'hir> {
let check_ident: rustc_span::Ident =
rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
let (check_hir_id, postcond_decl) = {
// Set up the postcondition `let` statement.
let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
postcond_span,
check_ident,
rustc_hir::BindingMode::NONE,
);
(
check_hir_id,
self.stmt_let_pat(
None,
postcond_span,
Some(contract_check),
self.arena.alloc(checker_pat),
rustc_hir::LocalSource::Contract,
),
)
};
let (check_ident, check_hir_id, postcond_decl) =
self.bind_expression(contract_check, postcond_span, "__ensures_checker");

// Install contract_ensures so we will intercept `return` statements,
// then lower the body.
Expand All @@ -274,6 +332,26 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
wrapped_body
}

fn bind_expression(
&mut self,
expr: &'hir rustc_hir::Expr<'hir>,
span: rustc_span::Span,
var_name: &str,
) -> (rustc_span::Ident, rustc_hir::HirId, rustc_hir::Stmt<'hir>) {
let ident = rustc_span::Ident::from_str_and_span(var_name, span);
let (pat, hir_id) =
self.pat_ident_binding_mode_mut(span, ident, rustc_hir::BindingMode::NONE);

let decl = self.stmt_let_pat(
None,
span,
Some(expr),
self.arena.alloc(pat),
rustc_hir::LocalSource::Contract,
);
(ident, hir_id, decl)
}

/// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
/// a contract ensures clause, if it exists.
pub(super) fn checked_return(
Expand Down Expand Up @@ -307,20 +385,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// ret
// }
// }
let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);

// Set up the return `let` statement.
let (ret_pat, ret_hir_id) =
self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);

let ret_stmt = self.stmt_let_pat(
None,
span,
Some(expr),
self.arena.alloc(ret_pat),
rustc_hir::LocalSource::Contract,
);

let (ret_ident, ret_hir_id, ret_stmt) = self.bind_expression(expr, span, "__ret");
let ret = self.expr_ident(span, ret_ident, ret_hir_id);

let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
Expand Down Expand Up @@ -355,3 +420,13 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
}
}

fn span_of_stmts<'hir>(
stmts: &'hir [rustc_hir::Stmt<'_>],
default_span: rustc_span::Span,
) -> rustc_span::Span {
match stmts {
[] => default_span,
[first, ..] => first.span.to(default_span),
}
}
42 changes: 42 additions & 0 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,48 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.expr(span, hir::ExprKind::Lit(Spanned { node: LitKind::Bool(val), span }))
}

pub(super) fn expr_closure(
&mut self,
span: Span,
body_expr: hir::Expr<'hir>,
) -> hir::Expr<'hir> {
let closure_node_id = self.next_node_id();

let closure_def_id = self.create_def(
closure_node_id,
None,
hir::def::DefKind::Closure,
hir::definitions::DefPathData::LateClosure,
span,
);

let hir_id = self.lower_node_id(closure_node_id);
let body_id = self.lower_body(|_| (Default::default(), body_expr));

let fn_decl = self.arena.alloc(hir::FnDecl {
inputs: &[],
output: hir::FnRetTy::DefaultReturn(span),
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
lifetime_elision_allowed: true,
});

let closure = self.arena.alloc(hir::Closure {
def_id: closure_def_id,
binder: hir::ClosureBinder::Default,
constness: hir::Constness::NotConst,
capture_clause: hir::CaptureBy::Ref,
bound_generic_params: &[],
fn_decl,
body: body_id,
fn_decl_span: span,
fn_arg_span: None,
kind: hir::ClosureKind::Closure,
});

hir::Expr { hir_id, kind: hir::ExprKind::Closure(closure), span }
}

pub(super) fn expr(&mut self, span: Span, kind: hir::ExprKind<'hir>) -> hir::Expr<'hir> {
let hir_id = self.next_id();
hir::Expr { hir_id, kind, span: self.lower_span(span) }
Expand Down
Loading
Loading