diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs index 242a0b0b4ff9..d999613cf190 100644 --- a/crates/hir-def/src/expr_store/lower.rs +++ b/crates/hir-def/src/expr_store/lower.rs @@ -204,7 +204,7 @@ pub(crate) fn lower_type_ref( (store, source_map, type_ref) } -pub(crate) fn lower_generic_params( +pub fn lower_generic_params( db: &dyn DefDatabase, module: ModuleId, def: GenericDefId, diff --git a/crates/hir-ty/src/consteval.rs b/crates/hir-ty/src/consteval.rs index 2c43feeb3b1a..d6580d3752f6 100644 --- a/crates/hir-ty/src/consteval.rs +++ b/crates/hir-ty/src/consteval.rs @@ -23,6 +23,7 @@ use crate::{ db::{AnonConstId, AnonConstLoc, GeneralConstId, HirDatabase}, display::DisplayTarget, generics::Generics, + lower::LoweringMode, mir::{MirEvalError, MirLowerError, pad16}, next_solver::{ Allocation, Const, ConstKind, Consts, DbInterner, DefaultAny, GenericArgs, ParamConst, @@ -305,6 +306,7 @@ pub(crate) enum CreateConstError<'db> { DoesNotResolve, ConstHasGenerics, UnderscoreExpr, + AnonConstInterningDisabled, TypeMismatch { #[expect(unused, reason = "will need this for diagnostics")] actual: Ty<'db>, @@ -355,6 +357,7 @@ pub(crate) fn create_anon_const<'a, 'db>( expected_ty: Ty<'db>, generics: &dyn Fn() -> &'a Generics<'db>, create_var: Option<&mut dyn FnMut(Span) -> Const<'db>>, + lowering_mode: LoweringMode, forbid_params_after: Option, ) -> Result, CreateConstError<'db>> { match &store[expr] { @@ -374,6 +377,10 @@ pub(crate) fn create_anon_const<'a, 'db>( konst } _ => { + let Some(token) = lowering_mode.allow_tracked_structs() else { + return Err(CreateConstError::AnonConstInterningDisabled); + }; + let allow_using_generic_params = forbid_params_after.is_none(); let konst = AnonConstId::new( interner.db, @@ -383,6 +390,7 @@ pub(crate) fn create_anon_const<'a, 'db>( ty: StoredEarlyBinder::bind(expected_ty.store()), allow_using_generic_params, }, + token, ); let args = if allow_using_generic_params { GenericArgs::identity_for_item(interner, owner.generic_def(interner.db).into()) diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index 511ab856107f..99a8bfe7f0ef 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -27,7 +27,7 @@ use crate::{ consteval::ConstEvalError, dyn_compatibility::DynCompatibilityViolation, layout::{Layout, LayoutError}, - lower::{GenericDefaults, TypeAliasBounds}, + lower::{GenericDefaults, TrackedStructToken, TypeAliasBounds}, mir::{BorrowckResult, MirBody, MirLowerError}, next_solver::{ Allocation, Clause, EarlyBinder, GenericArgs, ParamEnv, PolyFnSig, StoredClauses, @@ -421,13 +421,20 @@ pub struct AnonConstLoc { pub(crate) allow_using_generic_params: bool, } -#[salsa_macros::interned(debug, no_lifetime, revisions = usize::MAX)] +#[salsa_macros::interned(debug, no_lifetime, revisions = usize::MAX, constructor = new_)] #[derive(PartialOrd, Ord)] pub struct AnonConstId { #[returns(ref)] pub loc: AnonConstLoc, } +impl AnonConstId { + pub(crate) fn new(db: &dyn DefDatabase, loc: AnonConstLoc, token: TrackedStructToken) -> Self { + _ = token; + AnonConstId::new_(db, loc) + } +} + impl HasModule for AnonConstId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { self.loc(db).owner.module(db) diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 2df2789a2eee..4f2ad2e3de6c 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -91,7 +91,8 @@ use crate::{ unify::resolve_completely::WriteBackCtxt, }, lower::{ - ImplTraitIdx, ImplTraitLoweringMode, LifetimeElisionKind, diagnostics::TyLoweringDiagnostic, + ImplTraitIdx, ImplTraitLoweringMode, LifetimeElisionKind, LoweringMode, + diagnostics::TyLoweringDiagnostic, }, method_resolution::CandidateId, next_solver::{ @@ -116,13 +117,14 @@ use cast::{CastCheck, CastError}; /// The entry point of type inference. fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { - infer_query_with_inspect(db, def, None) + infer_query_with_inspect(db, def, None, LoweringMode::Analysis) } pub fn infer_query_with_inspect<'db>( db: &'db dyn HirDatabase, def: DefWithBodyId, inspect: Option>, + lowering_mode: LoweringMode, ) -> InferenceResult { let _p = tracing::info_span!("infer_query").entered(); let resolver = def.resolver(db); @@ -135,6 +137,7 @@ pub fn infer_query_with_inspect<'db>( &body.store, resolver, true, + lowering_mode, ); if let Some(inspect) = inspect { @@ -202,6 +205,7 @@ fn infer_anon_const_query(db: &dyn HirDatabase, def: AnonConstId) -> InferenceRe store, resolver, loc.allow_using_generic_params, + LoweringMode::Analysis, ); ctx.infer_expr( @@ -1236,6 +1240,7 @@ pub(crate) struct InferenceContext<'body, 'db> { pub(crate) store_owner: ExpressionStoreOwnerId, pub(crate) generic_def: GenericDefId, pub(crate) store: &'body ExpressionStore, + pub(crate) lowering_mode: LoweringMode, /// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext /// and resolve the path via its methods. This will ensure proper error reporting. pub(crate) resolver: Resolver<'db>, @@ -1335,6 +1340,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { store: &'body ExpressionStore, resolver: Resolver<'db>, allow_using_generic_params: bool, + lowering_mode: LoweringMode, ) -> Self { let trait_env = db.trait_environment(generic_def); let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), store_owner); @@ -1369,6 +1375,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { vars_emitted_type_must_be_known_for: FxHashSet::default(), deferred_call_resolutions: FxHashMap::default(), defined_anon_consts: RefCell::new(ThinVec::new()), + lowering_mode, } } @@ -1969,6 +1976,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { expected_ty, &|| self.generics(), Some(&mut |span| self.table.next_const_var(span)), + self.lowering_mode, (!(allow_using_generic_params && self.allow_using_generic_params)).then_some(0), ); diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index f612bdc26697..f1ee91f3ee82 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -111,8 +111,8 @@ pub use infer::{ }; pub use lower::{ GenericDefaults, GenericDefaultsRef, GenericPredicates, ImplTraits, LifetimeElisionKind, - TyDefId, TyLoweringContext, TyLoweringInferVarsCtx, TyLoweringResult, ValueTyDefId, - diagnostics::*, + LoweringMode, TyDefId, TyLoweringContext, TyLoweringInferVarsCtx, TyLoweringResult, + ValueTyDefId, diagnostics::*, }; pub use next_solver::interner::{attach_db, attach_db_allow_change, with_attached_db}; pub use target_feature::TargetFeatures; diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index df83b2abb870..96e1ef117f0b 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -199,6 +199,33 @@ pub trait TyLoweringInferVarsCtx<'db> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LoweringMode { + Analysis, + Ide, +} + +pub(crate) use self::tracked_struct_token::TrackedStructToken; +mod tracked_struct_token { + use super::LoweringMode; + + /// A token that is required to construct tracked structs. + /// This exists to prevent one from accidentally creating a tracked struct outside of a query which may happen for some codepaths. + pub(crate) struct TrackedStructToken { + // #[non_exhaustive] doesn't work for us here, we want it module focused. + _private: (), + } + + impl LoweringMode { + pub(crate) fn allow_tracked_structs(self) -> Option { + match self { + LoweringMode::Analysis => Some(TrackedStructToken { _private: () }), + LoweringMode::Ide => None, + } + } + } +} + pub struct TyLoweringContext<'db, 'a> { pub db: &'db dyn HirDatabase, pub(crate) interner: DbInterner<'db>, @@ -211,6 +238,7 @@ pub struct TyLoweringContext<'db, 'a> { generics: &'a OnceCell>, in_binders: DebruijnIndex, impl_trait_mode: ImplTraitLoweringState, + interning_mode: LoweringMode, /// Tracks types with explicit `?Sized` bounds. pub(crate) unsized_types: FxHashSet>, pub(crate) diagnostics: ThinVec, @@ -247,6 +275,7 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { store, in_binders, impl_trait_mode, + interning_mode: LoweringMode::Analysis, unsized_types: FxHashSet::default(), diagnostics: ThinVec::new(), lifetime_elision, @@ -261,6 +290,11 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { self.lifetime_elision = lifetime_elision; } + pub(crate) fn with_interning_mode(mut self, interning_mode: LoweringMode) -> Self { + self.interning_mode = interning_mode; + self + } + pub(crate) fn with_debruijn( &mut self, debruijn: DebruijnIndex, @@ -384,6 +418,7 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { const_type, &|| self.generics.get_or_init(|| generics(self.db, self.generic_def)), create_var, + self.interning_mode, self.forbid_params_after, ); diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index f6b5adfb6fff..4c76ae901da8 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -1,12 +1,15 @@ //! Trait solving using next trait solver. -use std::hash::Hash; +use std::{cell::OnceCell, hash::Hash}; use base_db::Crate; use hir_def::{ - AdtId, AssocItemId, HasModule, ImplId, Lookup, TraitId, + AdtId, AssocItemId, ExpressionStoreOwnerId, GenericDefId, HasModule, ImplId, Lookup, TraitId, + expr_store::ExpressionStore, + hir::generics::WherePredicate, lang_item::LangItems, nameres::DefMap, + resolver::Resolver, signatures::{ ConstFlags, ConstSignature, EnumFlags, EnumSignature, FnFlags, FunctionSignature, StructFlags, StructSignature, TraitFlags, TraitSignature, TypeAliasFlags, @@ -16,17 +19,20 @@ use hir_def::{ use hir_expand::name::Name; use intern::sym; use rustc_type_ir::{ - TypingMode, + TypeVisitableExt, TypingMode, inherent::{BoundExistentialPredicates, IntoKind}, }; use crate::{ - Span, + LifetimeElisionKind, Span, TyLoweringContext, db::HirDatabase, + generics::Generics, + lower::LoweringMode, next_solver::{ DbInterner, GenericArgs, ParamEnv, StoredClauses, Ty, TyKind, infer::{ DbInternerInferExt, InferCtxt, + select::EvaluationResult, traits::{Obligation, ObligationCause}, }, obligation_ctxt::ObligationCtxt, @@ -153,6 +159,90 @@ pub fn implements_trait_unique_with_infcx<'db>( infcx.predicate_must_hold_modulo_regions(&obligation) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WherePredicateEvaluation { + Holds, + NotProven, + HasErrors, + NoObligations, +} + +/// This should not be used in `hir-ty`, only in `hir`. +/// This is exposed to allow the IDE to evaluate arbitrary predicates. +pub fn where_predicate_must_hold<'db>( + db: &'db dyn HirDatabase, + resolver: &Resolver<'db>, + store: &ExpressionStore, + def: ExpressionStoreOwnerId, + generic_def: GenericDefId, + env: ParamEnvAndCrate<'db>, + predicate: &WherePredicate, +) -> WherePredicateEvaluation { + let interner = DbInterner::new_with(db, env.krate); + let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis); + let generics = OnceCell::>::new(); + let mut ctx = TyLoweringContext::new( + db, + resolver, + store, + def, + generic_def, + &generics, + LifetimeElisionKind::Infer, + ) + .with_interning_mode(LoweringMode::Ide); + let clauses = + ctx.lower_where_predicate(predicate, false).map(|(clause, _)| clause).collect::>(); + + if !ctx.diagnostics.is_empty() + || clauses.iter().any(|clause| clause.as_predicate().references_error()) + { + return WherePredicateEvaluation::HasErrors; + } + + if clauses.is_empty() { + return if ctx.unsized_types.is_empty() { + WherePredicateEvaluation::HasErrors + } else { + WherePredicateEvaluation::NoObligations + }; + } + + let result = infcx.probe(|snapshot| { + let mut ocx = ObligationCtxt::new(&infcx); + for clause in clauses { + let obligation = Obligation::new( + interner, + ObligationCause::dummy(), + env.param_env, + clause.as_predicate(), + ); + ocx.register_obligation(obligation); + } + + let mut result = EvaluationResult::EvaluatedToOk; + for error in ocx.evaluate_obligations_error_on_ambiguity() { + if error.is_true_error() { + return EvaluationResult::EvaluatedToErr; + } + result = result.max(EvaluationResult::EvaluatedToAmbig); + } + if infcx.opaque_types_added_in_snapshot(snapshot) { + result.max(EvaluationResult::EvaluatedToOkModuloOpaqueTypes) + } else if infcx.region_constraints_added_in_snapshot(snapshot) { + result.max(EvaluationResult::EvaluatedToOkModuloRegions) + } else { + result + } + }); + + if result.must_apply_modulo_regions() { + WherePredicateEvaluation::Holds + } else { + WherePredicateEvaluation::NotProven + } +} + pub fn is_inherent_impl_coherent(db: &dyn HirDatabase, def_map: &DefMap, impl_id: ImplId) -> bool { let self_ty = db.impl_self_ty(impl_id).instantiate_identity().skip_norm_wip(); let self_ty = self_ty.kind(); diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index d187763151a2..0bc0fe08deb1 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -117,6 +117,38 @@ use triomphe::Arc; use crate::db::{DefDatabase, HirDatabase}; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PredicateEvaluationStatus { + Holds, + NotProven, + Invalid, + Unsupported, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PredicateEvaluationResult { + pub status: PredicateEvaluationStatus, + pub message: String, +} + +impl PredicateEvaluationResult { + pub fn holds(message: impl Into) -> Self { + Self { status: PredicateEvaluationStatus::Holds, message: message.into() } + } + + pub fn not_proven(message: impl Into) -> Self { + Self { status: PredicateEvaluationStatus::NotProven, message: message.into() } + } + + pub fn invalid(message: impl Into) -> Self { + Self { status: PredicateEvaluationStatus::Invalid, message: message.into() } + } + + pub fn unsupported(message: impl Into) -> Self { + Self { status: PredicateEvaluationStatus::Unsupported, message: message.into() } + } +} + pub use crate::{ attrs::{AttrsWithOwner, HasAttrs, resolve_doc_path_on}, diagnostics::*, diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index f633bb063fdd..dd4cc7b0df7f 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -32,7 +32,7 @@ use hir_expand::{ name::AsName, }; use hir_ty::{ - InferBodyId, InferenceResult, + InferBodyId, InferenceResult, LoweringMode, db::AnonConstId, diagnostics::unsafe_operations, infer_query_with_inspect, @@ -2564,6 +2564,20 @@ impl<'db> SemanticsImpl<'db> { Some(locals) } + pub fn evaluate_where_clause_at( + &self, + node: &SyntaxNode, + offset: TextSize, + where_clause: ast::WhereClause, + ) -> crate::PredicateEvaluationResult { + let Some(analyzer) = self.analyze_with_offset_no_infer(node, offset) else { + return crate::PredicateEvaluationResult::unsupported( + "predicate evaluation is only supported in files that belong to a crate", + ); + }; + analyzer.evaluate_where_clause(self.db, where_clause) + } + pub fn get_failed_obligations(&self, token: SyntaxToken) -> Option { let node = token.parent()?; let node = self.find_file(&node); @@ -2587,6 +2601,7 @@ impl<'db> SemanticsImpl<'db> { RESULT.with(|ctx| ctx.borrow_mut().push(data)); } }), + LoweringMode::Ide, ); let data: Vec = RESULT.with(|data| data.borrow_mut().drain(..).collect()); diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index 1f9520d780f0..17e65b68db1e 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -13,10 +13,10 @@ use std::{ use either::Either; use hir_def::{ AdtId, AssocItemId, CallableDefId, ConstId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, - FunctionId, GenericDefId, LocalFieldId, ModuleDefId, StructId, VariantId, + FunctionId, GenericDefId, HasModule, LocalFieldId, ModuleDefId, StructId, VariantId, expr_store::{ Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap, HygieneId, - lower::ExprCollector, + lower::{ExprCollector, lower_generic_params}, path::Path, scope::{ExprScopes, ScopeId}, }, @@ -44,7 +44,7 @@ use hir_ty::{ AliasTy, DbInterner, DefaultAny, EarlyBinder, ErrorGuaranteed, GenericArgs, ParamEnv, Region, Ty, TyKind, TypingMode, infer::DbInternerInferExt, }, - traits::structurally_normalize_ty, + traits::{WherePredicateEvaluation, structurally_normalize_ty, where_predicate_must_hold}, }; use intern::sym; use itertools::Itertools; @@ -63,7 +63,8 @@ use syntax::{ use crate::{ Adt, AnyFunctionId, AssocItem, BindingMode, BuiltinAttr, BuiltinType, Callable, Const, DeriveHelper, EnumVariant, Field, Function, GenericSubstitution, Local, Macro, ModuleDef, - SemanticsImpl, Static, Struct, ToolModule, Trait, TupleField, Type, TypeAlias, TypeOwnerId, + PredicateEvaluationResult, SemanticsImpl, Static, Struct, ToolModule, Trait, TupleField, Type, + TypeAlias, TypeOwnerId, db::HirDatabase, semantics::{PathResolution, PathResolutionPerNs}, }; @@ -364,6 +365,52 @@ impl<'db> SourceAnalyzer<'db> { )) } + pub(crate) fn evaluate_where_clause( + &self, + db: &'db dyn HirDatabase, + where_clause: ast::WhereClause, + ) -> PredicateEvaluationResult { + let Some(owner) = self.owner() else { + // FIXME + return PredicateEvaluationResult::unsupported( + "predicate evaluation is only supported inside an item", + ); + }; + let generic_def = owner.generic_def(db); + let module = generic_def.module(db); + let (store, params, _) = + lower_generic_params(db, module, generic_def, self.file_id, None, Some(where_clause)); + let predicates = params.where_predicates(); + if predicates.is_empty() { + return PredicateEvaluationResult::holds("predicate does not impose any obligations"); + } + + let env = self.trait_environment(db); + for predicate in predicates { + match where_predicate_must_hold( + db, + &self.resolver, + &store, + owner, + generic_def, + env, + predicate, + ) { + WherePredicateEvaluation::Holds | WherePredicateEvaluation::NoObligations => {} + WherePredicateEvaluation::HasErrors => { + return PredicateEvaluationResult::invalid( + "predicate contains unresolved names or invalid type syntax", + ); + } + WherePredicateEvaluation::NotProven => { + return PredicateEvaluationResult::not_proven("predicate is not known to hold"); + } + } + } + + PredicateEvaluationResult::holds("predicate holds") + } + pub(crate) fn expr_id(&self, expr: ast::Expr) -> Option { let src = InFile { file_id: self.file_id, value: expr }; self.store_sm()?.node_expr(src.as_ref()) diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index e131e7bdd17d..88cb570c6b0f 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -41,6 +41,7 @@ mod matching_brace; mod moniker; mod move_item; mod parent_module; +mod predicate_eval; mod references; mod rename; mod runnables; @@ -122,7 +123,7 @@ pub use crate::{ }, test_explorer::{TestItem, TestItemKind}, }; -pub use hir::Semantics; +pub use hir::{PredicateEvaluationResult, PredicateEvaluationStatus, Semantics}; pub use ide_assists::{ Assist, AssistConfig, AssistId, AssistKind, AssistResolveStrategy, SingleResolve, }; @@ -391,6 +392,14 @@ impl Analysis { self.with_db(|db| view_hir::view_hir(db, position)) } + pub fn evaluate_predicate( + &self, + text: String, + position: FilePosition, + ) -> Cancellable { + self.with_db(|db| predicate_eval::evaluate_predicate(db, text, position)) + } + pub fn view_mir(&self, position: FilePosition) -> Cancellable { self.with_db(|db| view_mir::view_mir(db, position)) } diff --git a/crates/ide/src/predicate_eval.rs b/crates/ide/src/predicate_eval.rs new file mode 100644 index 000000000000..8ae340bd954b --- /dev/null +++ b/crates/ide/src/predicate_eval.rs @@ -0,0 +1,163 @@ +use hir::{PredicateEvaluationResult, Semantics}; +use ide_db::{FilePosition, RootDatabase}; +use syntax::{AstNode, SourceFile, ast}; + +pub(crate) fn evaluate_predicate( + db: &RootDatabase, + text: String, + position: FilePosition, +) -> PredicateEvaluationResult { + let sema = Semantics::new(db); + let source_file = sema.parse_guess_edition(position.file_id); + let edition = sema.attach_first_edition(position.file_id).edition(db); + + let Some(where_clause) = parse_where_clause(&text, edition) else { + return PredicateEvaluationResult::invalid("expected a single where-clause predicate"); + }; + + let node = source_file + .syntax() + .token_at_offset(position.offset) + .next() + .and_then(|token| token.parent()) + .unwrap_or_else(|| source_file.syntax().clone()); + sema.evaluate_where_clause_at(&node, position.offset, where_clause) +} + +fn parse_where_clause(text: &str, edition: span::Edition) -> Option { + let text = text.trim().trim_end_matches(',').trim_end(); + let wrapped = format!("fn __ra_evaluate_predicate() where {text}, {{}}"); + let parse = SourceFile::parse(&wrapped, edition); + if !parse.errors().is_empty() { + return None; + } + + let where_clause = parse.tree().syntax().descendants().find_map(ast::WhereClause::cast)?; + if where_clause.predicates().count() == 1 { Some(where_clause) } else { None } +} + +#[cfg(test)] +mod tests { + use hir::PredicateEvaluationStatus; + + use crate::fixture; + + fn check(ra_fixture: &str, predicate: &str, status: PredicateEvaluationStatus) { + let (analysis, position) = fixture::position(ra_fixture); + let result = analysis.evaluate_predicate(predicate.to_owned(), position).unwrap(); + assert_eq!(result.status, status, "{}", result.message); + } + + #[test] + fn evaluates_concrete_trait_predicate() { + check( + r#" +trait Trait {} +struct S; +impl Trait for S {} +fn f() { $0 } +"#, + "S: Trait", + PredicateEvaluationStatus::Holds, + ); + } + + #[test] + fn evaluates_generic_bound_from_environment() { + check( + r#" +trait Trait {} +fn f() { $0 } +"#, + "T: Trait", + PredicateEvaluationStatus::Holds, + ); + } + + #[test] + fn reports_missing_generic_bound_as_not_proven() { + check( + r#" +trait Trait {} +fn f() { $0 } +"#, + "T: Trait", + PredicateEvaluationStatus::NotProven, + ); + } + + #[test] + fn evaluates_associated_type_binding() { + check( + r#" +trait Iterator { type Item; } +fn f>() { $0 } +"#, + "I: Iterator", + PredicateEvaluationStatus::Holds, + ); + } + + #[test] + fn reports_unresolved_type_as_invalid() { + check( + r#" +trait Trait {} +fn f() { $0 } +"#, + "Type: Trait", + PredicateEvaluationStatus::Invalid, + ); + } + + #[test] + fn reports_unresolved_trait_as_invalid() { + check( + r#" +struct Type; +fn f() { $0 } +"#, + "Type: Trait", + PredicateEvaluationStatus::Invalid, + ); + } + + #[test] + fn evaluates_lifetime_predicate() { + check( + r#" +fn f<'a, 'b>() +where + 'a: 'b, +{ + $0 +} +"#, + "'a: 'b", + PredicateEvaluationStatus::Holds, + ); + } + + #[test] + fn evaluates_type_outlives_predicate() { + check( + r#" +fn f() { $0 } +"#, + "T: 'static", + PredicateEvaluationStatus::Holds, + ); + } + + #[test] + fn rejects_invalid_predicate() { + check( + r#" +trait Trait {} +fn f() { $0 } +"#, + "u32 Trait", + PredicateEvaluationStatus::Invalid, + ); + } +} diff --git a/crates/rust-analyzer/src/handlers/request.rs b/crates/rust-analyzer/src/handlers/request.rs index 5bc0f5f0a72a..cf85db39f380 100644 --- a/crates/rust-analyzer/src/handlers/request.rs +++ b/crates/rust-analyzer/src/handlers/request.rs @@ -2598,6 +2598,28 @@ pub(crate) fn internal_testing_fetch_config( })) } +pub(crate) fn handle_evaluate_predicate( + snap: GlobalStateSnapshot, + params: lsp_ext::EvaluatePredicateParams, +) -> anyhow::Result { + let _p = tracing::info_span!("handle_evaluate_predicate").entered(); + let file_id = try_default!(from_proto::file_id(&snap, ¶ms.text_document.uri)?); + let line_index = snap.file_line_index(file_id)?; + let offset = from_proto::offset(&line_index, params.position)?; + + let result = snap.analysis.evaluate_predicate(params.text, FilePosition { file_id, offset })?; + let status = match result.status { + ide::PredicateEvaluationStatus::Holds => lsp_ext::PredicateEvaluationStatus::Holds, + ide::PredicateEvaluationStatus::NotProven => lsp_ext::PredicateEvaluationStatus::NotProven, + ide::PredicateEvaluationStatus::Invalid => lsp_ext::PredicateEvaluationStatus::Invalid, + ide::PredicateEvaluationStatus::Unsupported => { + lsp_ext::PredicateEvaluationStatus::Unsupported + } + }; + + Ok(lsp_ext::EvaluatePredicateResult { status, message: result.message }) +} + pub(crate) fn get_failed_obligations( snap: GlobalStateSnapshot, params: GetFailedObligationsParams, diff --git a/crates/rust-analyzer/src/lsp/ext.rs b/crates/rust-analyzer/src/lsp/ext.rs index 754d6e65fea9..444715891f4d 100644 --- a/crates/rust-analyzer/src/lsp/ext.rs +++ b/crates/rust-analyzer/src/lsp/ext.rs @@ -859,6 +859,39 @@ pub struct ClientCommandOptions { pub commands: Vec, } +pub enum EvaluatePredicate {} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct EvaluatePredicateParams { + pub text: String, + pub text_document: TextDocumentIdentifier, + pub position: Position, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +#[serde(rename_all = "camelCase")] +pub struct EvaluatePredicateResult { + pub status: PredicateEvaluationStatus, + pub message: String, +} + +#[derive(Deserialize, Serialize, Debug, Default, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub enum PredicateEvaluationStatus { + Holds, + #[default] + NotProven, + Invalid, + Unsupported, +} + +impl Request for EvaluatePredicate { + type Params = EvaluatePredicateParams; + type Result = EvaluatePredicateResult; + const METHOD: &'static str = "rust-analyzer/evaluatePredicate"; +} + pub enum GetFailedObligations {} #[derive(Deserialize, Serialize, Debug)] diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index 5ed522ceee4c..31728289e9bd 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -1375,6 +1375,7 @@ impl GlobalState { .on::(handlers::handle_move_item) // .on::(handlers::internal_testing_fetch_config) + .on::(handlers::handle_evaluate_predicate) .on::(handlers::get_failed_obligations) .finish(); } diff --git a/crates/rust-analyzer/tests/slow-tests/main.rs b/crates/rust-analyzer/tests/slow-tests/main.rs index a8632630784b..b91bde842806 100644 --- a/crates/rust-analyzer/tests/slow-tests/main.rs +++ b/crates/rust-analyzer/tests/slow-tests/main.rs @@ -1542,6 +1542,45 @@ version = "0.0.0" ); } +#[test] +fn test_evaluate_predicate() { + if skip_slow_tests() { + return; + } + + let server = Project::with_fixture( + r#" +//- /Cargo.toml +[package] +name = "foo" +version = "0.0.0" + +//- /src/lib.rs +trait Trait {} +struct S; +impl Trait for S {} + +fn test() { + let _ = 0;$0 +} +"#, + ) + .server() + .wait_until_workspace_is_loaded(); + + let res = server.send_request::( + rust_analyzer::lsp::ext::EvaluatePredicateParams { + text: "T: Trait".to_owned(), + text_document: server.doc_id("src/lib.rs"), + position: Position::new(5, 14), + }, + ); + + let res: rust_analyzer::lsp::ext::EvaluatePredicateResult = + serde_json::from_value(res).unwrap(); + assert_eq!(res.status, rust_analyzer::lsp::ext::PredicateEvaluationStatus::Holds); +} + #[test] fn test_get_failed_obligations() { use expect_test::expect; diff --git a/docs/book/src/contributing/lsp-extensions.md b/docs/book/src/contributing/lsp-extensions.md index b74c40c42246..a3189402a94e 100644 --- a/docs/book/src/contributing/lsp-extensions.md +++ b/docs/book/src/contributing/lsp-extensions.md @@ -1,5 +1,5 @@