diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 06929a8f5..4f9edbc0f 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -122,7 +122,7 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { request: CompleteRequestParam, context: RequestContext, ) -> impl Future> + Send + '_ { - std::future::ready(Err(McpError::method_not_found::())) + std::future::ready(Ok(CompleteResult::default())) } fn set_level( &self, diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index d2cd43db1..89762bd63 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -1091,17 +1091,66 @@ pub struct ModelHint { // COMPLETION AND AUTOCOMPLETE // ============================================================================= +/// Context for completion requests providing previously resolved arguments. +/// +/// This enables context-aware completion where subsequent argument completions +/// can take into account the values of previously resolved arguments. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CompletionContext { + /// Previously resolved argument values that can inform completion suggestions + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +impl CompletionContext { + /// Create a new empty completion context + pub fn new() -> Self { + Self::default() + } + + /// Create a completion context with the given arguments + pub fn with_arguments(arguments: std::collections::HashMap) -> Self { + Self { + arguments: Some(arguments), + } + } + + /// Get a specific argument value by name + pub fn get_argument(&self, name: &str) -> Option<&String> { + self.arguments.as_ref()?.get(name) + } + + /// Check if the context has any arguments + pub fn has_arguments(&self) -> bool { + self.arguments.as_ref().is_some_and(|args| !args.is_empty()) + } + + /// Get all argument names + pub fn argument_names(&self) -> impl Iterator { + self.arguments + .as_ref() + .into_iter() + .flat_map(|args| args.keys()) + .map(|k| k.as_str()) + } +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct CompleteRequestParam { pub r#ref: Reference, pub argument: ArgumentInfo, + /// Optional context containing previously resolved argument values + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, } pub type CompleteRequest = Request; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct CompletionInfo { @@ -1112,7 +1161,74 @@ pub struct CompletionInfo { pub has_more: Option, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +impl CompletionInfo { + /// Maximum number of completion values allowed per response according to MCP specification + pub const MAX_VALUES: usize = 100; + + /// Create a new CompletionInfo with validation for maximum values + pub fn new(values: Vec) -> Result { + if values.len() > Self::MAX_VALUES { + return Err(format!( + "Too many completion values: {} (max: {})", + values.len(), + Self::MAX_VALUES + )); + } + Ok(Self { + values, + total: None, + has_more: None, + }) + } + + /// Create CompletionInfo with all values and no pagination + pub fn with_all_values(values: Vec) -> Result { + let completion = Self::new(values)?; + Ok(Self { + total: Some(completion.values.len() as u32), + has_more: Some(false), + ..completion + }) + } + + /// Create CompletionInfo with pagination information + pub fn with_pagination( + values: Vec, + total: Option, + has_more: bool, + ) -> Result { + let completion = Self::new(values)?; + Ok(Self { + total, + has_more: Some(has_more), + ..completion + }) + } + + /// Check if this completion response indicates more results are available + pub fn has_more_results(&self) -> bool { + self.has_more.unwrap_or(false) + } + + /// Get the total number of available completions, if known + pub fn total_available(&self) -> Option { + self.total + } + + /// Validate that the completion info complies with MCP specification + pub fn validate(&self) -> Result<(), String> { + if self.values.len() > Self::MAX_VALUES { + return Err(format!( + "Too many completion values: {} (max: {})", + self.values.len(), + Self::MAX_VALUES + )); + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct CompleteResult { @@ -1129,6 +1245,42 @@ pub enum Reference { Prompt(PromptReference), } +impl Reference { + /// Create a prompt reference + pub fn for_prompt(name: impl Into) -> Self { + Self::Prompt(PromptReference { name: name.into() }) + } + + /// Create a resource reference + pub fn for_resource(uri: impl Into) -> Self { + Self::Resource(ResourceReference { uri: uri.into() }) + } + + /// Get the reference type as a string + pub fn reference_type(&self) -> &'static str { + match self { + Self::Prompt(_) => "ref/prompt", + Self::Resource(_) => "ref/resource", + } + } + + /// Extract prompt name if this is a prompt reference + pub fn as_prompt_name(&self) -> Option<&str> { + match self { + Self::Prompt(prompt_ref) => Some(&prompt_ref.name), + _ => None, + } + } + + /// Extract resource URI if this is a resource reference + pub fn as_resource_uri(&self) -> Option<&str> { + match self { + Self::Resource(resource_ref) => Some(&resource_ref.uri), + _ => None, + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct ResourceReference { diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index dba7d048e..062bd1b38 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -5,18 +5,18 @@ use thiserror::Error; use super::*; use crate::{ model::{ - CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, + ArgumentInfo, CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, - GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest, - InitializedNotification, JsonRpcResponse, ListPromptsRequest, ListPromptsResult, - ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest, - ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam, - ProgressNotification, ProgressNotificationParam, ReadResourceRequest, - ReadResourceRequestParam, ReadResourceResult, RequestId, RootsListChangedNotification, - ServerInfo, ServerJsonRpcMessage, ServerNotification, ServerRequest, ServerResult, - SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam, - UnsubscribeRequest, UnsubscribeRequestParam, + CompletionContext, CompletionInfo, GetPromptRequest, GetPromptRequestParam, + GetPromptResult, InitializeRequest, InitializedNotification, JsonRpcResponse, + ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, + ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, + ListToolsResult, PaginatedRequestParam, ProgressNotification, ProgressNotificationParam, + ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, Reference, RequestId, + RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, + ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest, + SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, }, transport::DynamicTransportError, }; @@ -390,4 +390,96 @@ impl Peer { } Ok(resource_templates) } + + /// Convenient method to get completion suggestions for a prompt argument + /// + /// # Arguments + /// * `prompt_name` - Name of the prompt being completed + /// * `argument_name` - Name of the argument being completed + /// * `current_value` - Current partial value of the argument + /// * `context` - Optional context with previously resolved arguments + /// + /// # Returns + /// CompletionInfo with suggestions for the specified prompt argument + pub async fn complete_prompt_argument( + &self, + prompt_name: impl Into, + argument_name: impl Into, + current_value: impl Into, + context: Option, + ) -> Result { + let request = CompleteRequestParam { + r#ref: Reference::for_prompt(prompt_name), + argument: ArgumentInfo { + name: argument_name.into(), + value: current_value.into(), + }, + context, + }; + + let result = self.complete(request).await?; + Ok(result.completion) + } + + /// Convenient method to get completion suggestions for a resource URI argument + /// + /// # Arguments + /// * `uri_template` - URI template pattern being completed + /// * `argument_name` - Name of the URI parameter being completed + /// * `current_value` - Current partial value of the parameter + /// * `context` - Optional context with previously resolved arguments + /// + /// # Returns + /// CompletionInfo with suggestions for the specified resource URI argument + pub async fn complete_resource_argument( + &self, + uri_template: impl Into, + argument_name: impl Into, + current_value: impl Into, + context: Option, + ) -> Result { + let request = CompleteRequestParam { + r#ref: Reference::for_resource(uri_template), + argument: ArgumentInfo { + name: argument_name.into(), + value: current_value.into(), + }, + context, + }; + + let result = self.complete(request).await?; + Ok(result.completion) + } + + /// Simple completion for a prompt argument without context + /// + /// This is a convenience wrapper around `complete_prompt_argument` for + /// simple completion scenarios that don't require context awareness. + pub async fn complete_prompt_simple( + &self, + prompt_name: impl Into, + argument_name: impl Into, + current_value: impl Into, + ) -> Result, ServiceError> { + let completion = self + .complete_prompt_argument(prompt_name, argument_name, current_value, None) + .await?; + Ok(completion.values) + } + + /// Simple completion for a resource URI argument without context + /// + /// This is a convenience wrapper around `complete_resource_argument` for + /// simple completion scenarios that don't require context awareness. + pub async fn complete_resource_simple( + &self, + uri_template: impl Into, + argument_name: impl Into, + current_value: impl Into, + ) -> Result, ServiceError> { + let completion = self + .complete_resource_argument(uri_template, argument_name, current_value, None) + .await?; + Ok(completion.values) + } } diff --git a/crates/rmcp/tests/test_completion.rs b/crates/rmcp/tests/test_completion.rs new file mode 100644 index 000000000..ea9f632fe --- /dev/null +++ b/crates/rmcp/tests/test_completion.rs @@ -0,0 +1,218 @@ +use std::collections::HashMap; + +use rmcp::model::*; +use serde_json::json; + +#[test] +fn test_completion_context_serialization() { + let mut args = HashMap::new(); + args.insert("key1".to_string(), "value1".to_string()); + args.insert("key2".to_string(), "value2".to_string()); + + let context = CompletionContext::with_arguments(args); + + // Test serialization + let json = serde_json::to_value(&context).unwrap(); + let expected = json!({ + "arguments": { + "key1": "value1", + "key2": "value2" + } + }); + assert_eq!(json, expected); + + // Test deserialization + let deserialized: CompletionContext = serde_json::from_value(expected).unwrap(); + assert_eq!(deserialized, context); +} + +#[test] +fn test_completion_context_methods() { + let mut args = HashMap::new(); + args.insert("city".to_string(), "San Francisco".to_string()); + args.insert("country".to_string(), "USA".to_string()); + + let context = CompletionContext::with_arguments(args); + + assert!(context.has_arguments()); + assert_eq!( + context.get_argument("city"), + Some(&"San Francisco".to_string()) + ); + assert_eq!(context.get_argument("missing"), None); + + let names: Vec<&str> = context.argument_names().collect(); + assert!(names.contains(&"city")); + assert!(names.contains(&"country")); + assert_eq!(names.len(), 2); +} + +#[test] +fn test_complete_request_param_serialization() { + let mut args = HashMap::new(); + args.insert("previous_input".to_string(), "test".to_string()); + + let request = CompleteRequestParam { + r#ref: Reference::for_prompt("weather_prompt"), + argument: ArgumentInfo { + name: "location".to_string(), + value: "San".to_string(), + }, + context: Some(CompletionContext::with_arguments(args)), + }; + + let json = serde_json::to_value(&request).unwrap(); + assert!(json["ref"]["name"].as_str().unwrap() == "weather_prompt"); + assert!(json["argument"]["name"].as_str().unwrap() == "location"); + assert!(json["argument"]["value"].as_str().unwrap() == "San"); + assert!( + json["context"]["arguments"]["previous_input"] + .as_str() + .unwrap() + == "test" + ); +} + +#[test] +fn test_completion_info_validation() { + // Valid completion with less than max values + let values = vec!["option1".to_string(), "option2".to_string()]; + let completion = CompletionInfo::new(values.clone()).unwrap(); + assert_eq!(completion.values, values); + assert!(completion.validate().is_ok()); + + // Test max values limit + let many_values: Vec = (0..=CompletionInfo::MAX_VALUES) + .map(|i| format!("option_{}", i)) + .collect(); + let result = CompletionInfo::new(many_values); + assert!(result.is_err()); +} + +#[test] +fn test_completion_info_helper_methods() { + let values = vec!["test1".to_string(), "test2".to_string()]; + + // Test with_all_values + let completion = CompletionInfo::with_all_values(values.clone()).unwrap(); + assert_eq!(completion.values, values); + assert_eq!(completion.total, Some(2)); + assert_eq!(completion.has_more, Some(false)); + assert!(!completion.has_more_results()); + assert_eq!(completion.total_available(), Some(2)); + + // Test with_pagination + let paginated = CompletionInfo::with_pagination(values.clone(), Some(10), true).unwrap(); + assert_eq!(paginated.values, values); + assert_eq!(paginated.total, Some(10)); + assert_eq!(paginated.has_more, Some(true)); + assert!(paginated.has_more_results()); + assert_eq!(paginated.total_available(), Some(10)); +} + +#[test] +fn test_completion_info_bounds() { + // Test exactly at the limit + let max_values: Vec = (0..CompletionInfo::MAX_VALUES) + .map(|i| format!("value_{}", i)) + .collect(); + assert!(CompletionInfo::new(max_values).is_ok()); + + // Test over the limit + let over_limit: Vec = (0..=CompletionInfo::MAX_VALUES) + .map(|i| format!("value_{}", i)) + .collect(); + assert!(CompletionInfo::new(over_limit).is_err()); +} + +#[test] +fn test_reference_convenience_methods() { + let prompt_ref = Reference::for_prompt("test_prompt"); + assert_eq!(prompt_ref.reference_type(), "ref/prompt"); + assert_eq!(prompt_ref.as_prompt_name(), Some("test_prompt")); + assert_eq!(prompt_ref.as_resource_uri(), None); + + let resource_ref = Reference::for_resource("file://path/to/resource"); + assert_eq!(resource_ref.reference_type(), "ref/resource"); + assert_eq!( + resource_ref.as_resource_uri(), + Some("file://path/to/resource") + ); + assert_eq!(resource_ref.as_prompt_name(), None); +} + +#[test] +fn test_completion_serialization_format() { + // Test that completion follows MCP 2025-06-18 specification format + let completion = CompletionInfo { + values: vec!["value1".to_string(), "value2".to_string()], + total: Some(2), + has_more: Some(false), + }; + + let json = serde_json::to_value(&completion).unwrap(); + + // Verify JSON structure matches specification + assert!(json.is_object()); + assert!(json["values"].is_array()); + assert_eq!(json["values"].as_array().unwrap().len(), 2); + assert_eq!(json["total"].as_u64().unwrap(), 2); + assert!(!json["hasMore"].as_bool().unwrap()); +} + +#[test] +fn test_resource_reference() { + // Test that ResourceReference works correctly + let resource_ref = ResourceReference { + uri: "test://uri".to_string(), + }; + + // Test that ResourceReference works correctly + let another_ref = ResourceReference { + uri: "test://uri".to_string(), + }; + + // They should be equivalent + assert_eq!(resource_ref.uri, another_ref.uri); +} + +#[test] +fn test_complete_result_default() { + let result = CompleteResult::default(); + assert!(result.completion.values.is_empty()); + assert_eq!(result.completion.total, None); + assert_eq!(result.completion.has_more, None); +} + +#[test] +fn test_completion_context_empty() { + let context = CompletionContext::new(); + assert!(!context.has_arguments()); + assert_eq!(context.get_argument("any"), None); + assert!(context.argument_names().count() == 0); +} + +#[test] +fn test_mcp_schema_compliance() { + // Test that our types serialize correctly according to MCP specification + let request = CompleteRequestParam { + r#ref: Reference::for_resource("file://{path}"), + argument: ArgumentInfo { + name: "path".to_string(), + value: "src/".to_string(), + }, + context: None, + }; + + let json_str = serde_json::to_string(&request).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + + // Verify key structure matches MCP spec + assert!(parsed["ref"].is_object()); + assert!(parsed["argument"].is_object()); + assert!(parsed["argument"]["name"].is_string()); + assert!(parsed["argument"]["value"].is_string()); + + // Verify type tag is correct + assert_eq!(parsed["ref"]["type"].as_str().unwrap(), "ref/resource"); +} diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json index d11c105cd..595cc122c 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json @@ -324,6 +324,17 @@ "argument": { "$ref": "#/definitions/ArgumentInfo" }, + "context": { + "description": "Optional context containing previously resolved argument values", + "anyOf": [ + { + "$ref": "#/definitions/CompletionContext" + }, + { + "type": "null" + } + ] + }, "ref": { "$ref": "#/definitions/Reference" } @@ -333,6 +344,22 @@ "argument" ] }, + "CompletionContext": { + "description": "Context for completion requests providing previously resolved arguments.\n\nThis enables context-aware completion where subsequent argument completions\ncan take into account the values of previously resolved arguments.", + "type": "object", + "properties": { + "arguments": { + "description": "Previously resolved argument values that can inform completion suggestions", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "string" + } + } + } + }, "CreateElicitationResult": { "description": "The result returned by a client in response to an elicitation request.\n\nContains the user's decision (accept/decline/cancel) and optionally their input data\nif they chose to accept the request.", "type": "object", diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 90b0cbd47..09e045005 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -98,3 +98,7 @@ path = "src/structured_output.rs" [[example]] name = "servers_elicitation_stdio" path = "src/elicitation_stdio.rs" + +[[example]] +name = "servers_completion_stdio" +path = "src/completion_stdio.rs" diff --git a/examples/servers/src/completion_stdio.rs b/examples/servers/src/completion_stdio.rs new file mode 100644 index 000000000..7beb1e1e3 --- /dev/null +++ b/examples/servers/src/completion_stdio.rs @@ -0,0 +1,455 @@ +//! MCP Server demonstrating code review completion functionality +//! +//! This example shows how to implement completion support for MCP prompts +//! with intelligent fuzzy matching for code review parameters. +//! +//! Run with MCP Inspector: +//! ```bash +//! npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example servers_completion_stdio +//! ``` + +use anyhow::Result; +use rmcp::{ + ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, + handler::server::{router::prompt::PromptRouter, wrapper::Parameters}, + model::*, + prompt, prompt_handler, prompt_router, + schemars::JsonSchema, + service::RequestContext, + transport::stdio, +}; +use serde::{Deserialize, Serialize}; +use tracing_subscriber::{self, EnvFilter}; + +/// Arguments for the SQL query builder prompt +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[schemars(description = "SQL query builder with progressive completion")] +pub struct SqlQueryArgs { + #[schemars(description = "SQL operation type (SELECT, INSERT, UPDATE, DELETE)")] + pub operation: String, + #[schemars(description = "Database table name")] + pub table: String, + #[schemars(description = "Columns to select/update (only for SELECT/UPDATE)")] + pub columns: Option, + #[schemars(description = "WHERE clause condition (optional for all operations)")] + pub where_clause: Option, + #[schemars(description = "Values to insert (only for INSERT)")] + pub values: Option, +} + +/// SQL query builder server with progressive completion +#[derive(Clone)] +pub struct SqlQueryServer { + prompt_router: PromptRouter, +} + +impl SqlQueryServer { + pub fn new() -> Self { + Self { + prompt_router: Self::prompt_router(), + } + } +} + +impl Default for SqlQueryServer { + fn default() -> Self { + Self::new() + } +} + +impl SqlQueryServer { + /// Fuzzy matching with scoring for completion suggestions + fn fuzzy_match(&self, query: &str, candidates: &[&str]) -> Vec { + if query.is_empty() { + return candidates.iter().take(10).map(|s| s.to_string()).collect(); + } + + let query_lower = query.to_lowercase(); + let mut scored_matches = Vec::new(); + + for candidate in candidates { + let candidate_lower = candidate.to_lowercase(); + + let score = if candidate_lower == query_lower { + 1000 // Exact match + } else if candidate_lower.starts_with(&query_lower) { + 900 // Prefix match + } else if candidate_lower.contains(&query_lower) { + 800 // Contains substring + } else if self.is_acronym_match(&query_lower, candidate) { + 700 // Acronym match (e.g., "js" → "JavaScript") + } else if self.is_subsequence_match(&query_lower, &candidate_lower) { + 680 // Subsequence match (e.g., "rs" → "Rust") + } else if self.is_single_letter_match(&query_lower, candidate) { + 650 // Single letter match (e.g., "j" → "Java") + } else { + continue; // No match + }; + + scored_matches.push((candidate.to_string(), score)); + } + + // Sort by score (desc) then alphabetically + scored_matches.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0))); + scored_matches + .into_iter() + .take(10) + .map(|(name, _)| name) + .collect() + } + + /// Check if query matches as acronym (first letters of words or camelCase) + fn is_acronym_match(&self, query: &str, candidate: &str) -> bool { + let query_chars: Vec = query.chars().collect(); + + // Extract first letters from words (split by whitespace) or uppercase letters (camelCase) + let mut first_chars: Vec; + + // Split by whitespace first + let words: Vec<&str> = candidate.split_whitespace().collect(); + if words.len() > 1 { + // Multi-word case (e.g., "Memory Safety" -> "MS") + first_chars = words + .into_iter() + .filter_map(|word| word.chars().next()) + .map(|c| c.to_lowercase().next().unwrap_or('\0')) + .collect(); + } else { + // Single word case - extract uppercase letters for camelCase (e.g., "JavaScript" -> "JS") + first_chars = candidate + .chars() + .filter(|c| c.is_uppercase()) + .map(|c| c.to_lowercase().next().unwrap_or('\0')) + .collect(); + + // If no uppercase letters found, just use first letter + if first_chars.is_empty() && !candidate.is_empty() { + if let Some(first) = candidate.chars().next() { + first_chars.push(first.to_lowercase().next().unwrap_or('\0')); + } + } + } + + // Special case: if query is 2 chars and we only got 1 char, try matching first 2 letters + if query_chars.len() == 2 && first_chars.len() == 1 { + if let Some(first) = candidate.chars().nth(0) { + if let Some(second) = candidate.chars().nth(1) { + first_chars = vec![ + first.to_lowercase().next().unwrap_or('\0'), + second.to_lowercase().next().unwrap_or('\0'), + ]; + } + } + } + + if query_chars.len() != first_chars.len() { + return false; + } + + query_chars + .iter() + .zip(first_chars.iter()) + .all(|(q, c)| q.to_lowercase().next().unwrap_or('\0') == *c) + } + + /// Check if query is a subsequence of candidate (e.g., "rs" in "rust") + fn is_subsequence_match(&self, query: &str, candidate_lower: &str) -> bool { + let query_chars: Vec = query.chars().collect(); + let candidate_chars: Vec = candidate_lower.chars().collect(); + + let mut query_idx = 0; + + for &candidate_char in &candidate_chars { + if query_idx < query_chars.len() && query_chars[query_idx] == candidate_char { + query_idx += 1; + } + } + + query_idx == query_chars.len() + } + + /// Check if query matches first letter of single word + fn is_single_letter_match(&self, query: &str, candidate: &str) -> bool { + if query.len() != 1 { + return false; + } + + let query_char = query + .chars() + .next() + .unwrap() + .to_lowercase() + .next() + .unwrap_or('\0'); + let first_char = candidate + .chars() + .next() + .unwrap_or('\0') + .to_lowercase() + .next() + .unwrap_or('\0'); + + query_char == first_char + } +} + +#[prompt_router] +impl SqlQueryServer { + #[prompt(name = "sql_query", description = "Smart SQL query builder")] + async fn sql_query( + &self, + Parameters(args): Parameters, + ) -> Result { + let messages = if args.operation.is_empty() { + vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "I need help building a SQL query. Where should I start?", + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll help you build a SQL query step by step. First, what type of operation do you want to perform? \ + Choose from: SELECT (to read data), INSERT (to add data), UPDATE (to modify data), or DELETE (to remove data).", + ), + ] + } else if args.table.is_empty() { + vec![ + PromptMessage::new_text( + PromptMessageRole::User, + format!("I want to {} data. What's next?", args.operation), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "Great! For a {} operation, I need to know which table you want to work with. \ + What's the name of your database table?", + args.operation + ), + ), + ] + } else { + // Build the SQL query based on filled arguments + let query = match args.operation.to_uppercase().as_str() { + "SELECT" => { + let cols = args + .columns + .as_ref() + .filter(|c| !c.is_empty()) + .map(|c| c.as_str()) + .unwrap_or("*"); + let where_part = args + .where_clause + .as_ref() + .map(|w| format!(" WHERE {}", w)) + .unwrap_or_default(); + format!("SELECT {} FROM {}{}", cols, args.table, where_part) + } + "INSERT" => match &args.values { + Some(vals) if !vals.is_empty() => { + format!("INSERT INTO {} VALUES ({})", args.table, vals) + } + _ => format!("INSERT INTO {} (...) VALUES (...)", args.table), + }, + "UPDATE" => { + let set_part = args + .columns + .as_ref() + .filter(|c| !c.is_empty()) + .map(|c| c.as_str()) + .unwrap_or("..."); + let where_part = args + .where_clause + .as_ref() + .map(|w| format!(" WHERE {}", w)) + .unwrap_or_default(); + format!("UPDATE {} SET {}{}", args.table, set_part, where_part) + } + "DELETE" => { + let where_part = args + .where_clause + .as_ref() + .map(|w| format!(" WHERE {}", w)) + .unwrap_or_default(); + format!("DELETE FROM {}{}", args.table, where_part) + } + _ => format!("{} FROM {}", args.operation, args.table), + }; + + vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "Generate the SQL query based on my parameters and explain what it does.", + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "Here's your SQL query:\n\n```sql\n{}\n```\n\nThis query will {} the {} table.", + query, + args.operation.to_lowercase(), + args.table + ), + ), + ] + }; + + Ok(GetPromptResult { + description: Some(format!( + "SQL Query: {} on {}", + if args.operation.is_empty() { + "Unknown" + } else { + &args.operation + }, + if args.table.is_empty() { + "table" + } else { + &args.table + } + )), + messages, + }) + } +} + +#[prompt_handler] +impl ServerHandler for SqlQueryServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_completions() + .enable_prompts() + .build(), + server_info: Implementation::from_build_env(), + instructions: Some( + "Smart SQL query builder with progressive completion that adapts based on your choices:\n\n\ + Step 1: Choose operation type ('sel' → SELECT, 'ins' → INSERT, 'upd' → UPDATE, 'del' → DELETE)\n\ + Step 2: Specify table name ('users', 'orders', 'products')\n\ + Step 3: Add relevant fields based on operation type:\n\ + • SELECT/UPDATE: columns ('name', 'email', 'id')\n\ + • INSERT: values to insert\n\ + • All: optional WHERE clause\n\n\ + The completion adapts - only relevant fields appear based on your SQL operation!" + .to_string(), + ), + ..Default::default() + } + } + + async fn complete( + &self, + request: CompleteRequestParam, + _context: RequestContext, + ) -> Result { + let candidates = match &request.r#ref { + Reference::Prompt(prompt_ref) if prompt_ref.name == "sql_query" => { + let filled_fields: Vec<&str> = request + .context + .as_ref() + .map(|ctx| ctx.argument_names().collect()) + .unwrap_or_default(); + + tracing::debug!( + "SQL completion - filled fields: {:?}, completing: {}", + filled_fields, + request.argument.name + ); + + match request.argument.name.as_str() { + "operation" => vec!["SELECT", "INSERT", "UPDATE", "DELETE"], + "table" => vec!["users", "orders", "products", "categories", "reviews"], + "columns" => { + // Only show columns completion if operation is SELECT or UPDATE + if let Some(context) = &request.context { + if let Some(operation) = context.get_argument("operation") { + match operation.to_uppercase().as_str() { + "SELECT" | "UPDATE" => { + vec!["id", "name", "email", "created_at", "updated_at", "*"] + } + _ => vec!["Not applicable for this operation"], + } + } else { + vec!["Choose operation first"] + } + } else { + vec!["Choose operation first"] + } + } + "values" => { + // Only show values completion for INSERT + if let Some(context) = &request.context { + if let Some(operation) = context.get_argument("operation") { + match operation.to_uppercase().as_str() { + "INSERT" => { + vec!["'John Doe'", "'jane@example.com'", "123", "NOW()"] + } + _ => vec!["Not applicable for this operation"], + } + } else { + vec!["Choose operation first"] + } + } else { + vec!["Choose operation first"] + } + } + "where_clause" => { + // WHERE clause suggestions based on filled fields count + match filled_fields.len() { + 0..=1 => vec!["Complete operation and table first"], + _ => vec![ + "id = 1", + "name = 'example'", + "created_at > '2023-01-01'", + "status = 'active'", + ], + } + } + _ => vec![], + } + } + _ => vec![], + }; + + let suggestions = self.fuzzy_match(&request.argument.value, &candidates); + + let completion = CompletionInfo { + values: suggestions, + total: None, + has_more: Some(false), + }; + + Ok(CompleteResult { completion }) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .init(); + + println!("MCP SQL Query Builder with Smart Completion"); + println!("=========================================="); + println!(); + println!("This server demonstrates argument_names() value with progressive completion:"); + println!("1. Start with operation type (SELECT, INSERT, UPDATE, DELETE)"); + println!("2. Choose table name (users, orders, products)"); + println!("3. Only relevant fields appear based on your operation!"); + println!(" • SELECT/UPDATE: shows columns field"); + println!(" • INSERT: shows values field"); + println!(" • All operations: optional WHERE clause after step 2"); + println!(); + println!("To test with MCP Inspector:"); + println!( + "npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example servers_completion_stdio" + ); + println!(); + + let server = SqlQueryServer::new(); + let service = server.serve(stdio()).await.inspect_err(|e| { + tracing::error!("Server error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +}