1use rustc_errors::{Applicability, Diag};
2use rustc_hir::def::{CtorOf, DefKind, Res};
3use rustc_hir::def_id::LocalDefId;
4use rustc_hir::{self as hir, ExprKind, HirId, PatKind};
5use rustc_hir_pretty::ty_to_string;
6use rustc_middle::ty::{self, Ty};
7use rustc_span::Span;
8use rustc_trait_selection::traits::{
9 MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
10};
11use tracing::{debug, instrument};
12
13use crate::coercion::{AsCoercionSite, CoerceMany};
14use crate::{Diverges, Expectation, FnCtxt, GatherLocalsVisitor, Needs};
15
16impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
17 #[instrument(skip(self), level = "debug", ret)]
18 pub(crate) fn check_expr_match(
19 &self,
20 expr: &'tcx hir::Expr<'tcx>,
21 scrut: &'tcx hir::Expr<'tcx>,
22 arms: &'tcx [hir::Arm<'tcx>],
23 orig_expected: Expectation<'tcx>,
24 match_src: hir::MatchSource,
25 ) -> Ty<'tcx> {
26 let tcx = self.tcx;
27
28 let acrb = arms_contain_ref_bindings(arms);
29 let scrutinee_ty = self.demand_scrutinee_type(scrut, acrb, arms.is_empty());
30 debug!(?scrutinee_ty);
31
32 if arms.is_empty() {
34 self.diverges.set(self.diverges.get() | Diverges::always(expr.span));
35 return tcx.types.never;
36 }
37
38 self.warn_arms_when_scrutinee_diverges(arms);
39
40 let scrut_diverges = self.diverges.replace(Diverges::Maybe);
42
43 let scrut_span = scrut.span.find_ancestor_inside(expr.span).unwrap_or(scrut.span);
45 for arm in arms {
46 GatherLocalsVisitor::gather_from_arm(self, arm);
47
48 self.check_pat_top(arm.pat, scrutinee_ty, Some(scrut_span), Some(scrut), None);
49 }
50
51 let mut all_arms_diverge = Diverges::WarnedAlways;
61
62 let expected =
63 orig_expected.try_structurally_resolve_and_adjust_for_branches(self, expr.span);
64 debug!(?expected);
65
66 let mut coercion = {
67 let coerce_first = match expected {
68 Expectation::ExpectHasType(ety) if ety != tcx.types.unit => ety,
74 _ => self.next_ty_var(expr.span),
75 };
76 CoerceMany::with_coercion_sites(coerce_first, arms)
77 };
78
79 let mut prior_non_diverging_arms = vec![]; let mut prior_arm = None;
81 for arm in arms {
82 self.diverges.set(Diverges::Maybe);
83
84 if let Some(e) = &arm.guard {
85 self.check_expr_has_type_or_error(e, tcx.types.bool, |_| {});
86
87 }
92
93 let arm_ty = self.check_expr_with_expectation(arm.body, expected);
98 all_arms_diverge &= self.diverges.get();
99 let tail_defines_return_position_impl_trait =
100 self.return_position_impl_trait_from_match_expectation(orig_expected);
101
102 let (arm_block_id, arm_span) = if let hir::ExprKind::Block(blk, _) = arm.body.kind {
103 (Some(blk.hir_id), self.find_block_span(blk))
104 } else {
105 (None, arm.body.span)
106 };
107
108 let code = match prior_arm {
109 None => ObligationCauseCode::BlockTailExpression(arm.body.hir_id, match_src),
112 Some((prior_arm_block_id, prior_arm_ty, prior_arm_span)) => {
113 ObligationCauseCode::MatchExpressionArm(Box::new(MatchExpressionArmCause {
114 arm_block_id,
115 arm_span,
116 arm_ty,
117 prior_arm_block_id,
118 prior_arm_ty,
119 prior_arm_span,
120 scrut_span: scrut.span,
121 expr_span: expr.span,
122 source: match_src,
123 prior_non_diverging_arms: prior_non_diverging_arms.clone(),
124 tail_defines_return_position_impl_trait,
125 }))
126 }
127 };
128 let cause = self.cause(arm_span, code);
129
130 coercion.coerce_inner(
135 self,
136 &cause,
137 Some(arm.body),
138 arm_ty,
139 |err| {
140 self.explain_never_type_coerced_to_unit(err, arm, arm_ty, prior_arm, expr);
141 },
142 false,
143 );
144
145 if !arm_ty.is_never() {
146 prior_arm = Some((arm_block_id, arm_ty, arm_span));
150
151 prior_non_diverging_arms.push(arm_span);
152 if prior_non_diverging_arms.len() > 5 {
153 prior_non_diverging_arms.remove(0);
154 }
155 }
156 }
157
158 if let (Diverges::Always { .. }, hir::MatchSource::Normal) = (all_arms_diverge, match_src) {
165 all_arms_diverge = Diverges::Always {
166 span: expr.span,
167 custom_note: Some(
168 "any code following this `match` expression is unreachable, as all arms diverge",
169 ),
170 };
171 }
172
173 self.diverges.set(scrut_diverges | all_arms_diverge);
175
176 coercion.complete(self)
177 }
178
179 fn explain_never_type_coerced_to_unit(
180 &self,
181 err: &mut Diag<'_>,
182 arm: &hir::Arm<'tcx>,
183 arm_ty: Ty<'tcx>,
184 prior_arm: Option<(Option<hir::HirId>, Ty<'tcx>, Span)>,
185 expr: &hir::Expr<'tcx>,
186 ) {
187 if let hir::ExprKind::Block(block, _) = arm.body.kind
188 && let Some(expr) = block.expr
189 && let arm_tail_ty = self.node_ty(expr.hir_id)
190 && arm_tail_ty.is_never()
191 && !arm_ty.is_never()
192 {
193 err.span_label(
194 expr.span,
195 format!(
196 "this expression is of type `!`, but it is coerced to `{arm_ty}` due to its \
197 surrounding expression",
198 ),
199 );
200 self.suggest_mismatched_types_on_tail(
201 err,
202 expr,
203 arm_ty,
204 prior_arm.map_or(arm_tail_ty, |(_, ty, _)| ty),
205 expr.hir_id,
206 );
207 }
208 self.suggest_removing_semicolon_for_coerce(err, expr, arm_ty, prior_arm)
209 }
210
211 fn suggest_removing_semicolon_for_coerce(
212 &self,
213 diag: &mut Diag<'_>,
214 expr: &hir::Expr<'tcx>,
215 arm_ty: Ty<'tcx>,
216 prior_arm: Option<(Option<hir::HirId>, Ty<'tcx>, Span)>,
217 ) {
218 let Some(body) = self.tcx.hir_maybe_body_owned_by(self.body_id) else {
220 return;
221 };
222 let hir::ExprKind::Block(block, _) = body.value.kind else {
223 return;
224 };
225 let Some(hir::Stmt { kind: hir::StmtKind::Semi(last_expr), span: semi_span, .. }) =
226 block.innermost_block().stmts.last()
227 else {
228 return;
229 };
230 if last_expr.hir_id != expr.hir_id {
231 return;
232 }
233
234 let Some(ret) =
236 self.tcx.hir_node_by_def_id(self.body_id).fn_decl().map(|decl| decl.output.span())
237 else {
238 return;
239 };
240
241 let can_coerce_to_return_ty = match self.ret_coercion.as_ref() {
242 Some(ret_coercion) => {
243 let ret_ty = ret_coercion.borrow().expected_ty();
244 let ret_ty = self.infcx.shallow_resolve(ret_ty);
245 self.may_coerce(arm_ty, ret_ty)
246 && prior_arm.is_none_or(|(_, ty, _)| self.may_coerce(ty, ret_ty))
247 && !matches!(ret_ty.kind(), ty::Alias(ty::Opaque, ..))
249 }
250 _ => false,
251 };
252 if !can_coerce_to_return_ty {
253 return;
254 }
255
256 let semi = expr.span.shrink_to_hi().with_hi(semi_span.hi());
257 let sugg = crate::errors::RemoveSemiForCoerce { expr: expr.span, ret, semi };
258 diag.subdiagnostic(sugg);
259 }
260
261 fn warn_arms_when_scrutinee_diverges(&self, arms: &'tcx [hir::Arm<'tcx>]) {
264 for arm in arms {
265 self.warn_if_unreachable(arm.body.hir_id, arm.body.span, "arm");
266 }
267 }
268
269 pub(super) fn if_fallback_coercion<T>(
273 &self,
274 if_span: Span,
275 cond_expr: &'tcx hir::Expr<'tcx>,
276 then_expr: &'tcx hir::Expr<'tcx>,
277 coercion: &mut CoerceMany<'tcx, '_, T>,
278 ) -> bool
279 where
280 T: AsCoercionSite,
281 {
282 let hir_id = self.tcx.parent_hir_id(self.tcx.parent_hir_id(then_expr.hir_id));
285 let ret_reason = self.maybe_get_coercion_reason(hir_id, if_span);
286 let cause = self.cause(if_span, ObligationCauseCode::IfExpressionWithNoElse);
287 let mut error = false;
288 coercion.coerce_forced_unit(
289 self,
290 &cause,
291 |err| self.explain_if_expr(err, ret_reason, if_span, cond_expr, then_expr, &mut error),
292 false,
293 );
294 error
295 }
296
297 fn explain_if_expr(
300 &self,
301 err: &mut Diag<'_>,
302 ret_reason: Option<(Span, String)>,
303 if_span: Span,
304 cond_expr: &'tcx hir::Expr<'tcx>,
305 then_expr: &'tcx hir::Expr<'tcx>,
306 error: &mut bool,
307 ) {
308 if let Some((if_span, msg)) = ret_reason {
309 err.span_label(if_span, msg);
310 } else if let ExprKind::Block(block, _) = then_expr.kind
311 && let Some(expr) = block.expr
312 {
313 err.span_label(expr.span, "found here");
314 }
315 err.note("`if` expressions without `else` evaluate to `()`");
316 err.help("consider adding an `else` block that evaluates to the expected type");
317 *error = true;
318 if let ExprKind::Let(hir::LetExpr { span, pat, init, .. }) = cond_expr.kind
319 && let ExprKind::Block(block, _) = then_expr.kind
320 && let PatKind::TupleStruct(qpath, ..) | PatKind::Struct(qpath, ..) = pat.kind
323 && let hir::QPath::Resolved(_, path) = qpath
324 {
325 match path.res {
326 Res::Def(DefKind::Ctor(CtorOf::Struct, _), _) => {
327 }
330 Res::Def(DefKind::Ctor(CtorOf::Variant, _), def_id)
331 if self
332 .tcx
333 .adt_def(self.tcx.parent(self.tcx.parent(def_id)))
334 .variants()
335 .len()
336 == 1 =>
337 {
338 }
341 _ => return,
342 }
343
344 let mut sugg = vec![
345 (if_span.until(*span), String::new()),
347 ];
348 match (block.stmts, block.expr) {
349 ([first, ..], Some(expr)) => {
350 let padding = self
351 .tcx
352 .sess
353 .source_map()
354 .indentation_before(first.span)
355 .unwrap_or_else(|| String::new());
356 sugg.extend([
357 (init.span.between(first.span), format!(";\n{padding}")),
358 (expr.span.shrink_to_hi().with_hi(block.span.hi()), String::new()),
359 ]);
360 }
361 ([], Some(expr)) => {
362 let padding = self
363 .tcx
364 .sess
365 .source_map()
366 .indentation_before(expr.span)
367 .unwrap_or_else(|| String::new());
368 sugg.extend([
369 (init.span.between(expr.span), format!(";\n{padding}")),
370 (expr.span.shrink_to_hi().with_hi(block.span.hi()), String::new()),
371 ]);
372 }
373 (_, None) => return,
376 }
377 err.multipart_suggestion(
378 "consider using an irrefutable `let` binding instead",
379 sugg,
380 Applicability::MaybeIncorrect,
381 );
382 }
383 }
384
385 pub(crate) fn maybe_get_coercion_reason(
386 &self,
387 hir_id: hir::HirId,
388 sp: Span,
389 ) -> Option<(Span, String)> {
390 let node = self.tcx.hir_node(hir_id);
391 if let hir::Node::Block(block) = node {
392 let parent = self.tcx.parent_hir_node(self.tcx.parent_hir_id(block.hir_id));
394 if let (Some(expr), hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn { .. }, .. })) =
395 (&block.expr, parent)
396 {
397 if expr.span == sp {
399 return self.get_fn_decl(hir_id).map(|(_, fn_decl)| {
400 let (ty, span) = match fn_decl.output {
401 hir::FnRetTy::DefaultReturn(span) => ("()".to_string(), span),
402 hir::FnRetTy::Return(ty) => (ty_to_string(&self.tcx, ty), ty.span),
403 };
404 (span, format!("expected `{ty}` because of this return type"))
405 });
406 }
407 }
408 }
409 if let hir::Node::LetStmt(hir::LetStmt { ty: Some(_), pat, .. }) = node {
410 return Some((pat.span, "expected because of this assignment".to_string()));
411 }
412 None
413 }
414
415 pub(crate) fn if_cause(
416 &self,
417 expr_id: HirId,
418 else_expr: &'tcx hir::Expr<'tcx>,
419 tail_defines_return_position_impl_trait: Option<LocalDefId>,
420 ) -> ObligationCause<'tcx> {
421 let error_sp = self.find_block_span_from_hir_id(else_expr.hir_id);
422
423 self.cause(
425 error_sp,
426 ObligationCauseCode::IfExpression { expr_id, tail_defines_return_position_impl_trait },
427 )
428 }
429
430 pub(super) fn demand_scrutinee_type(
431 &self,
432 scrut: &'tcx hir::Expr<'tcx>,
433 contains_ref_bindings: Option<hir::Mutability>,
434 no_arms: bool,
435 ) -> Ty<'tcx> {
436 if let Some(m) = contains_ref_bindings {
489 self.check_expr_with_needs(scrut, Needs::maybe_mut_place(m))
490 } else if no_arms {
491 self.check_expr(scrut)
492 } else {
493 let scrut_ty = self.next_ty_var(scrut.span);
497 self.check_expr_has_type_or_error(scrut, scrut_ty, |_| {});
498 scrut_ty
499 }
500 }
501
502 pub(crate) fn return_position_impl_trait_from_match_expectation(
507 &self,
508 expectation: Expectation<'tcx>,
509 ) -> Option<LocalDefId> {
510 let expected_ty = expectation.to_option(self)?;
511 let (def_id, args) = match *expected_ty.kind() {
512 ty::Alias(ty::Opaque, alias_ty) => (alias_ty.def_id.as_local()?, alias_ty.args),
514 ty::Infer(ty::TyVar(_)) => self
516 .inner
517 .borrow_mut()
518 .opaque_types()
519 .iter_opaque_types()
520 .find(|(_, v)| v.ty == expected_ty)
521 .map(|(k, _)| (k.def_id, k.args))?,
522 _ => return None,
523 };
524 let hir::OpaqueTyOrigin::FnReturn { parent: parent_def_id, .. } =
525 self.tcx.local_opaque_ty_origin(def_id)
526 else {
527 return None;
528 };
529 if &args[0..self.tcx.generics_of(parent_def_id).count()]
530 != ty::GenericArgs::identity_for_item(self.tcx, parent_def_id).as_slice()
531 {
532 return None;
533 }
534 Some(def_id)
535 }
536}
537
538fn arms_contain_ref_bindings<'tcx>(arms: &'tcx [hir::Arm<'tcx>]) -> Option<hir::Mutability> {
539 arms.iter().filter_map(|a| a.pat.contains_explicit_ref_binding()).max()
540}