Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_infer::traits::DynCompatibilityViolation;
use rustc_macros::{TypeFoldable, TypeVisitable};
use rustc_middle::middle::stability::AllowUnstable;
use rustc_middle::mir::interpret::LitToConstInput;
use rustc_middle::mir::interpret::{LitToConstInput, const_lit_matches_ty};
use rustc_middle::ty::print::PrintPolyTraitRefExt as _;
use rustc_middle::ty::{
self, Const, GenericArgKind, GenericArgsRef, GenericParamDefKind, Ty, TyCtxt,
Expand Down Expand Up @@ -2803,8 +2803,17 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
span: Span,
) -> Const<'tcx> {
let tcx = self.tcx();
if let LitKind::Err(guar) = *kind {
return ty::Const::new_error(tcx, guar);
}
let input = LitToConstInput { lit: *kind, ty, neg };
tcx.at(span).lit_to_const(input)
match tcx.at(span).lit_to_const(input) {
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
None => {
let e = tcx.dcx().span_err(span, "type annotations needed for the literal");
ty::Const::new_error(tcx, e)
}
}
}

#[instrument(skip(self), level = "debug")]
Expand Down Expand Up @@ -2833,11 +2842,15 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
_ => None,
};

lit_input
// Allow the `ty` to be an alias type, though we cannot handle it here, we just go through
// the more expensive anon const code path.
.filter(|l| !l.ty.has_aliases())
.map(|l| tcx.at(expr.span).lit_to_const(l))
lit_input.and_then(|l| {
if const_lit_matches_ty(tcx, &l.lit, l.ty, l.neg) {
tcx.at(expr.span)
.lit_to_const(l)
.map(|value| ty::Const::new_value(tcx, value.valtree, value.ty))
} else {
None
}
})
}

fn require_type_const_attribute(
Expand Down
38 changes: 38 additions & 0 deletions compiler/rustc_middle/src/mir/interpret/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,44 @@ pub struct LitToConstInput<'tcx> {
pub neg: bool,
}

pub fn const_lit_matches_ty<'tcx>(
tcx: TyCtxt<'tcx>,
kind: &LitKind,
ty: Ty<'tcx>,
neg: bool,
) -> bool {
match (*kind, ty.kind()) {
(LitKind::Str(..), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => true,
(LitKind::Str(..), ty::Str) if tcx.features().deref_patterns() => true,
(LitKind::ByteStr(..), ty::Ref(_, inner_ty, _))
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
&& matches!(ty.kind(), ty::Uint(ty::UintTy::U8)) =>
{
true
}
(LitKind::ByteStr(..), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
if tcx.features().deref_patterns()
&& matches!(inner_ty.kind(), ty::Uint(ty::UintTy::U8)) =>
{
true
}
(LitKind::Byte(..), ty::Uint(ty::UintTy::U8)) => true,
(LitKind::CStr(..), ty::Ref(_, inner_ty, _))
if matches!(inner_ty.kind(), ty::Adt(def, _)
if tcx.is_lang_item(def.did(), rustc_hir::LangItem::CStr)) =>
{
true
}
(LitKind::Int(..), ty::Uint(_)) if !neg => true,
(LitKind::Int(..), ty::Int(_)) => true,
(LitKind::Bool(..), ty::Bool) => true,
(LitKind::Float(..), ty::Float(_)) => true,
(LitKind::Char(..), ty::Char) => true,
(LitKind::Err(..), _) => true,
_ => false,
}
}

#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct AllocId(pub NonZero<u64>);

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ rustc_queries! {
// FIXME get rid of this with valtrees
query lit_to_const(
key: LitToConstInput<'tcx>
) -> ty::Const<'tcx> {
) -> Option<ty::Value<'tcx>> {
desc { "converting literal to const" }
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/query/erase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ impl Erasable for Option<ty::EarlyBinder<'_, Ty<'_>>> {
type Storage = [u8; size_of::<Option<ty::EarlyBinder<'static, Ty<'static>>>>()];
}

impl Erasable for Option<ty::Value<'_>> {
type Storage = [u8; size_of::<Option<ty::Value<'static>>>()];
}

impl Erasable for rustc_hir::MaybeOwner<'_> {
type Storage = [u8; size_of::<rustc_hir::MaybeOwner<'static>>()];
}
Expand Down
89 changes: 55 additions & 34 deletions compiler/rustc_mir_build/src/thir/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ use rustc_ast::{self as ast, UintTy};
use rustc_hir::LangItem;
use rustc_middle::bug;
use rustc_middle::mir::interpret::LitToConstInput;
use rustc_middle::ty::{self, ScalarInt, TyCtxt, TypeVisitableExt as _};
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt, TypeVisitableExt as _};
use tracing::trace;

use crate::builder::parse_float_into_scalar;

pub(crate) fn lit_to_const<'tcx>(
tcx: TyCtxt<'tcx>,
lit_input: LitToConstInput<'tcx>,
) -> ty::Const<'tcx> {
let LitToConstInput { lit, ty, neg } = lit_input;
) -> Option<ty::Value<'tcx>> {
let LitToConstInput { lit, ty: expected_ty, neg } = lit_input;

if let Err(guar) = ty.error_reported() {
return ty::Const::new_error(tcx, guar);
if expected_ty.error_reported().is_err() {
return None;
}

let trunc = |n, width: ty::UintTy| {
Expand All @@ -32,63 +32,84 @@ pub(crate) fn lit_to_const<'tcx>(
.unwrap_or_else(|| bug!("expected to create ScalarInt from uint {:?}", result))
};

let valtree = match (lit, ty.kind()) {
(ast::LitKind::Str(s, _), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => {
let (valtree, valtree_ty) = match (lit, expected_ty.kind()) {
(ast::LitKind::Str(s, _), _) => {
let str_bytes = s.as_str().as_bytes();
ty::ValTree::from_raw_bytes(tcx, str_bytes)
}
(ast::LitKind::Str(s, _), ty::Str) if tcx.features().deref_patterns() => {
// String literal patterns may have type `str` if `deref_patterns` is enabled, in order
// to allow `deref!("..."): String`.
let str_bytes = s.as_str().as_bytes();
ty::ValTree::from_raw_bytes(tcx, str_bytes)
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, tcx.types.str_);
(ty::ValTree::from_raw_bytes(tcx, str_bytes), valtree_ty)
}
(ast::LitKind::ByteStr(byte_sym, _), ty::Ref(_, inner_ty, _))
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
&& let ty::Uint(UintTy::U8) = ty.kind() =>
{
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
}
(ast::LitKind::ByteStr(byte_sym, _), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
if tcx.features().deref_patterns()
&& let ty::Uint(UintTy::U8) = inner_ty.kind() =>
{
// Byte string literal patterns may have type `[u8]` or `[u8; N]` if `deref_patterns` is
// enabled, in order to allow, e.g., `deref!(b"..."): Vec<u8>`.
ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str())
(ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()), expected_ty)
}
(ast::LitKind::Byte(n), ty::Uint(ty::UintTy::U8)) => {
ty::ValTree::from_scalar_int(tcx, n.into())
(ast::LitKind::ByteStr(byte_sym, _), _) => {
let valtree = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
let valtree_ty = Ty::new_array(tcx, tcx.types.u8, byte_sym.as_byte_str().len() as u64);
(valtree, valtree_ty)
}
(ast::LitKind::CStr(byte_sym, _), ty::Ref(_, inner_ty, _)) if matches!(inner_ty.kind(), ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::CStr)) =>
(ast::LitKind::Byte(n), _) => (ty::ValTree::from_scalar_int(tcx, n.into()), tcx.types.u8),
(ast::LitKind::CStr(byte_sym, _), _)
if let Some(cstr_def_id) = tcx.lang_items().get(LangItem::CStr) =>
{
// A CStr is a newtype around a byte slice, so we create the inner slice here.
// We need a branch for each "level" of the data structure.
let cstr_ty = tcx.type_of(cstr_def_id).skip_binder();
let bytes = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str());
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, *inner_ty)])
let valtree =
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, cstr_ty)]);
let valtree_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_static, cstr_ty);
(valtree, valtree_ty)
}
(ast::LitKind::Int(n, _), ty::Uint(ui)) if !neg => {
(ast::LitKind::Int(n, ast::LitIntType::Unsigned(ui)), _) if !neg => {
let scalar_int = trunc(n.get(), ui);
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, ui))
}
(ast::LitKind::Int(_, ast::LitIntType::Unsigned(_)), _) if neg => return None,
(ast::LitKind::Int(n, ast::LitIntType::Signed(i)), _) => {
let scalar_int =
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, i))
}
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Uint(ui)) if !neg => {
let scalar_int = trunc(n.get(), *ui);
ty::ValTree::from_scalar_int(tcx, scalar_int)
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_uint(tcx, *ui))
}
(ast::LitKind::Int(n, _), ty::Int(i)) => {
(ast::LitKind::Int(n, ast::LitIntType::Unsuffixed), ty::Int(i)) => {
// Unsigned "negation" has the same bitwise effect as signed negation,
// which gets the result we want without additional casts.
let scalar_int =
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
ty::ValTree::from_scalar_int(tcx, scalar_int)
(ty::ValTree::from_scalar_int(tcx, scalar_int), Ty::new_int(tcx, *i))
}
(ast::LitKind::Bool(b), _) => (ty::ValTree::from_scalar_int(tcx, b.into()), tcx.types.bool),
(ast::LitKind::Float(n, ast::LitFloatType::Suffixed(fty)), _) => {
let fty = match fty {
ast::FloatTy::F16 => ty::FloatTy::F16,
ast::FloatTy::F32 => ty::FloatTy::F32,
ast::FloatTy::F64 => ty::FloatTy::F64,
ast::FloatTy::F128 => ty::FloatTy::F128,
};
let bits = parse_float_into_scalar(n, fty, neg)?;
(ty::ValTree::from_scalar_int(tcx, bits), Ty::new_float(tcx, fty))
}
(ast::LitKind::Bool(b), ty::Bool) => ty::ValTree::from_scalar_int(tcx, b.into()),
(ast::LitKind::Float(n, _), ty::Float(fty)) => {
let bits = parse_float_into_scalar(n, *fty, neg).unwrap_or_else(|| {
tcx.dcx().bug(format!("couldn't parse float literal: {:?}", lit_input.lit))
});
ty::ValTree::from_scalar_int(tcx, bits)
(ast::LitKind::Float(n, ast::LitFloatType::Unsuffixed), ty::Float(fty)) => {
let bits = parse_float_into_scalar(n, *fty, neg)?;
(ty::ValTree::from_scalar_int(tcx, bits), Ty::new_float(tcx, *fty))
}
(ast::LitKind::Char(c), ty::Char) => ty::ValTree::from_scalar_int(tcx, c.into()),
(ast::LitKind::Err(guar), _) => return ty::Const::new_error(tcx, guar),
_ => return ty::Const::new_misc_error(tcx),
(ast::LitKind::Char(c), _) => (ty::ValTree::from_scalar_int(tcx, c.into()), tcx.types.char),
(ast::LitKind::Err(_), _) => return None,
_ => return None,
};

ty::Const::new_value(tcx, valtree, ty)
Some(ty::Value { ty: valtree_ty, valtree })
}
25 changes: 21 additions & 4 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use std::cmp::Ordering;
use std::sync::Arc;

use rustc_abi::{FieldIdx, Integer};
use rustc_ast::LitKind;
use rustc_data_structures::assert_matches;
use rustc_errors::codes::*;
use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
use rustc_hir::{self as hir, RangeEnd};
use rustc_index::Idx;
use rustc_middle::mir::interpret::LitToConstInput;
use rustc_middle::mir::interpret::{LitToConstInput, const_lit_matches_ty};
use rustc_middle::thir::{
Ascription, DerefPatBorrowMode, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary,
};
Expand Down Expand Up @@ -197,8 +198,6 @@ impl<'tcx> PatCtxt<'tcx> {
expr: Option<&'tcx hir::PatExpr<'tcx>>,
ty: Ty<'tcx>,
) -> Result<(), ErrorGuaranteed> {
use rustc_ast::ast::LitKind;

let Some(expr) = expr else {
return Ok(());
};
Expand Down Expand Up @@ -696,7 +695,25 @@ impl<'tcx> PatCtxt<'tcx> {

let pat_ty = self.typeck_results.node_type(pat.hir_id);
let lit_input = LitToConstInput { lit: lit.node, ty: pat_ty, neg: *negated };
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
let error_const = || {
if let Some(guar) = self.typeck_results.tainted_by_errors {
ty::Const::new_error(self.tcx, guar)
} else {
ty::Const::new_error_with_message(
self.tcx,
expr.span,
"literal does not match expected type",
)
}
};
let constant = if const_lit_matches_ty(self.tcx, &lit.node, pat_ty, *negated) {
match self.tcx.at(expr.span).lit_to_const(lit_input) {
Some(value) => ty::Const::new_value(self.tcx, value.valtree, pat_ty),
None => error_const(),
}
} else {
error_const()
};
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span)
}
}
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_ty_utils/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ fn recurse_build<'tcx>(
}
&ExprKind::Literal { lit, neg } => {
let sp = node.span;
tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg })
match tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg }) {
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
None => ty::Const::new_misc_error(tcx),
}
}
&ExprKind::NonHirLiteral { lit, user_ty: _ } => {
let val = ty::ValTree::from_scalar_int(tcx, lit);
Expand Down
3 changes: 0 additions & 3 deletions tests/crashes/133966.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
struct ConstBytes<const T: &'static [*mut u8; 3]>
//~^ ERROR rustc_dump_predicates
//~| NOTE Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }

where
ConstBytes<b"AAA">: Sized;
//~^ ERROR mismatched types
//~| NOTE expected `&[*mut u8; 3]`, found `&[u8; 3]`
//~| NOTE expected reference `&'static [*mut u8; 3]`
//~| NOTE found reference `&'static [u8; 3]`

fn main() {}
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
error: rustc_dump_predicates
--> $DIR/byte-string-u8-validation.rs:8:1
|
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
= note: Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }

error[E0308]: mismatched types
--> $DIR/byte-string-u8-validation.rs:13:16
--> $DIR/byte-string-u8-validation.rs:14:16
|
LL | ConstBytes<b"AAA">: Sized;
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
|
= note: expected reference `&'static [*mut u8; 3]`
found reference `&'static [u8; 3]`

error: rustc_dump_predicates
--> $DIR/byte-string-u8-validation.rs:8:1
|
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
= note: Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0308`.
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ LL | struct ConstBytes<const T: &'static [*mut u8; 3]>;
= note: `[*mut u8; 3]` must implement `ConstParamTy_`, but it does not

error[E0308]: mismatched types
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
|
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
|
= note: expected reference `&'static [*mut u8; 3]`
found reference `&'static [u8; 3]`

error[E0308]: mismatched types
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
|
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
|
= note: expected reference `&'static [*mut u8; 3]`
found reference `&'static [u8; 3]`
Expand Down
Loading
Loading