diff --git a/Cargo.toml b/Cargo.toml index bfb938a..ed8a596 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ categories = ["api-bindings", "asynchronous", "development-tools"] [dependencies] anyhow = "1" base64 = "0.22" +lingua = { git = "https://github.com/braintrustdata/lingua.git" } bon = "3" arc-swap = "1" async-trait = "0.1" diff --git a/src/lib.rs b/src/lib.rs index 7a784e1..d91a11c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod json_merge; mod log_queue; mod logger; mod logs3; +mod prompt; mod span; mod span_components; mod stream; @@ -21,12 +22,15 @@ pub use experiments::{ GitMetadataSettings, MetricSummary, ProjectMetadata, RepoInfo, ScoreSummary, }; pub use extractors::{extract_anthropic_usage, extract_openai_usage}; +pub use lingua::universal::{AssistantContent, TokenBudget, UserContent}; +pub use lingua::{Message, UniversalParams, UniversalRequest}; pub use log_queue::LogQueueConfig; pub use logger::{ BraintrustClient, BraintrustClientBuilder, LoginState, OrgInfo, DEFAULT_API_URL, DEFAULT_APP_URL, }; pub use logs3::{Logs3BatchUploader, Logs3UploadResult}; +pub use prompt::{Prompt, PromptBuilder, PromptBuilderError}; pub use span::{SpanBuilder, SpanHandle, SpanLog, SpanLogBuilder, SpanLogBuilderError}; pub use span_components::SpanComponents; pub use stream::{ diff --git a/src/logger.rs b/src/logger.rs index 3db8902..3657315 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -15,6 +15,7 @@ use crate::experiments::api::{ }; use crate::experiments::{BaseExperimentInfo, ExperimentBuilder}; use crate::log_queue::{LogQueue, LogQueueConfig}; +use crate::prompt::{PromptBuilder, PromptFetchRequest, PromptFetchResponse, PromptFetcher}; use crate::span::SpanSubmitter; use crate::types::{ParentSpanInfo, SpanPayload}; @@ -505,6 +506,13 @@ impl BraintrustClient { crate::span::SpanBuilder::new(submitter, token, org_id) } + /// Create a prompt builder with an explicit API token. 
+ /// + /// Use this if you already have the token and don't want to use the login state. + pub fn prompt_builder_with_credentials(&self, token: impl Into<String>) -> PromptBuilder { + PromptBuilder::new(Arc::new(self.clone()), token) + } + /// Perform login synchronously. async fn perform_login(&self, api_key: &str, org_name: Option<&str>) -> Result<()> { let login_url = self .inner .api_url @@ -910,6 +918,64 @@ impl ExperimentComparisonFetcher for BraintrustClient { } } +#[async_trait::async_trait] +impl PromptFetcher for BraintrustClient { + async fn fetch_prompt( + &self, + token: &str, + request: PromptFetchRequest, + ) -> Result<PromptFetchResponse> { + let mut url = self + .inner + .api_url + .join("v1/prompt") + .map_err(|e| BraintrustError::InvalidConfig(e.to_string()))?; + + { + let mut query = url.query_pairs_mut(); + query.append_pair("slug", &request.slug); + if let Some(project_id) = &request.project_id { + query.append_pair("project_id", project_id); + } + if let Some(project_name) = &request.project_name { + query.append_pair("project_name", project_name); + } + if let Some(version) = &request.version { + query.append_pair("version", version); + } + if let Some(environment) = &request.environment { + query.append_pair("environment", environment); + } + } + + let response = self + .inner + .http_client + .get(url) + .bearer_auth(token) + .send() + .await + .map_err(|e| BraintrustError::Network(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(BraintrustError::Api { + status: status.as_u16(), + message: body, + }); + } + + response + .json::<PromptFetchResponse>() + .await + .map_err(|e| BraintrustError::Api { + status: 200, + message: format!("Failed to parse prompt response: {}", e), + }) + } +} + +#[allow(dead_code)] +#[derive(Debug, Default)] +struct MutableSpanEvent { diff --git a/src/prompt.rs b/src/prompt.rs new file mode 100644 index 0000000..5370f36 --- /dev/null +++ b/src/prompt.rs @@ -0,0 +1,400 @@ 
+//! Prompt support for loading versioned prompt templates. +//! +//! Prompts allow you to: +//! - Fetch versioned prompt templates from Braintrust +//! - Convert prompts to lingua `UniversalRequest` for use with any LLM provider +//! - Use environment-based versioning (e.g., "production") +//! +//! # Example +//! +//! ```ignore +//! let prompt = client +//! .prompt_builder_with_credentials(&api_key) +//! .project_name("my-project") +//! .slug("greeting-prompt") +//! .environment("production") +//! .build() +//! .await?; +//! +//! if let Some(request) = prompt.to_request() { +//! // Use request with your LLM client +//! } +//! ``` + +use std::fmt; +use std::sync::Arc; + +use async_trait::async_trait; +use lingua::universal::{AssistantContent, TokenBudget, UserContent}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::error::Result; + +// ============================================================================ +// PromptBuilder Error Types +// ============================================================================ + +/// Error type for PromptBuilder validation. +#[derive(Debug, Clone, PartialEq)] +#[non_exhaustive] +pub enum PromptBuilderError { + /// Slug is required but was not provided. + MissingSlug, + /// Neither project_name nor project_id was provided. + MissingProject, + /// Both version and environment were specified (mutually exclusive). + VersionAndEnvironmentConflict, + /// The API request to fetch the prompt failed. 
+ FetchFailed(String), +} + +impl fmt::Display for PromptBuilderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MissingSlug => write!(f, "slug is required but was not provided"), + Self::MissingProject => { + write!(f, "either project_name or project_id is required") + } + Self::VersionAndEnvironmentConflict => { + write!(f, "version and environment are mutually exclusive") + } + Self::FetchFailed(msg) => write!(f, "failed to fetch prompt: {}", msg), + } + } +} + +impl std::error::Error for PromptBuilderError {} + +// ============================================================================ +// Prompt Fetch Types (Internal) +// ============================================================================ + +/// Request parameters for fetching a prompt. +#[derive(Debug, Clone, Serialize)] +pub(crate) struct PromptFetchRequest { + /// Project ID to fetch from. + #[serde(skip_serializing_if = "Option::is_none")] + pub project_id: Option<String>, + /// Project name to fetch from. + #[serde(skip_serializing_if = "Option::is_none")] + pub project_name: Option<String>, + /// Prompt slug identifier. + pub slug: String, + /// Specific version (transaction ID). + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option<String>, + /// Environment name (e.g., "production"). + #[serde(skip_serializing_if = "Option::is_none")] + pub environment: Option<String>, +} + +/// Response from fetching a prompt. +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct PromptFetchResponse { + /// Unique prompt ID. + pub id: String, + /// Prompt name. + pub name: String, + /// Prompt slug. + pub slug: String, + /// Current version (transaction ID). + #[serde(rename = "_xact_id")] + pub version: Option<String>, + /// Project ID containing this prompt. + pub project_id: String, + /// Prompt template data (raw JSON). 
+ pub prompt_data: Option<Value>, +} + +// ============================================================================ +// PromptFetcher Trait +// ============================================================================ + +/// Trait for fetching prompts from the API. +#[async_trait] +pub(crate) trait PromptFetcher: Send + Sync { + /// Fetch a prompt from the API. + async fn fetch_prompt( + &self, + token: &str, + request: PromptFetchRequest, + ) -> Result<PromptFetchResponse>; +} + +// ============================================================================ +// Prompt - The main prompt handle +// ============================================================================ + +/// A loaded prompt template. +/// +/// Use [`PromptBuilder`] to load a prompt, then call [`Prompt::to_request`] to +/// convert it to a [`lingua::UniversalRequest`] for use with any LLM provider. +#[derive(Debug, Clone)] +pub struct Prompt { + id: String, + name: String, + slug: String, + version: String, + project_id: String, + prompt_data: Value, +} + +impl Prompt { + /// Get the prompt ID. + pub fn id(&self) -> &str { + &self.id + } + + /// Get the prompt name. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the prompt slug. + pub fn slug(&self) -> &str { + &self.slug + } + + /// Get the prompt version (transaction ID). + pub fn version(&self) -> &str { + &self.version + } + + /// Get the project ID containing this prompt. + pub fn project_id(&self) -> &str { + &self.project_id + } + + /// Get the raw prompt data from the Braintrust API. + pub fn data(&self) -> &Value { + &self.prompt_data + } + + /// Convert this prompt to a `lingua::UniversalRequest`. + /// + /// Returns `None` if the prompt is not a chat prompt (completion prompts are + /// not yet supported) or if the prompt data cannot be parsed. 
+ pub fn to_request(&self) -> Option<lingua::UniversalRequest> { + let prompt_obj = self.prompt_data.get("prompt")?; + let prompt_type = prompt_obj.get("type")?.as_str()?; + + if prompt_type != "chat" { + return None; + } + + let messages_arr = prompt_obj.get("messages")?.as_array()?; + let mut messages = Vec::new(); + + for msg in messages_arr { + let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or(""); + let content_val = msg.get("content").unwrap_or(&Value::Null); + + let content_str = match content_val { + Value::String(s) => s.clone(), + Value::Array(_) => { + // Multi-part content: concatenate text parts + content_val + .as_array() + .unwrap() + .iter() + .filter_map(|part| { + if part.get("type").and_then(|t| t.as_str()) == Some("text") { + part.get("text") + .and_then(|t| t.as_str()) + .map(|s| s.to_string()) + } else { + None + } + }) + .collect::<Vec<String>>() + .join("") + } + _ => continue, + }; + + let message = match role { + "system" => lingua::Message::System { + content: UserContent::String(content_str), + }, + "developer" => lingua::Message::Developer { + content: UserContent::String(content_str), + }, + "user" => lingua::Message::User { + content: UserContent::String(content_str), + }, + "assistant" => lingua::Message::Assistant { + content: AssistantContent::String(content_str), + id: None, + }, + _ => continue, + }; + messages.push(message); + } + + let options = self.prompt_data.get("options"); + let model = options + .and_then(|o| o.get("model")) + .and_then(|m| m.as_str()) + .map(|s| s.to_string()); + + let mut params = lingua::UniversalParams::default(); + if let Some(params_val) = options.and_then(|o| o.get("params")) { + if let Some(t) = params_val.get("temperature").and_then(|v| v.as_f64()) { + params.temperature = Some(t); + } + if let Some(tp) = params_val.get("top_p").and_then(|v| v.as_f64()) { + params.top_p = Some(tp); + } + if let Some(mt) = params_val.get("max_tokens").and_then(|v| v.as_i64()) { + params.token_budget = Some(TokenBudget::OutputTokens(mt)); + 
} + if let Some(fp) = params_val.get("frequency_penalty").and_then(|v| v.as_f64()) { + params.frequency_penalty = Some(fp); + } + if let Some(pp) = params_val.get("presence_penalty").and_then(|v| v.as_f64()) { + params.presence_penalty = Some(pp); + } + if let Some(stop) = params_val.get("stop").and_then(|v| v.as_array()) { + let stop_strs: Vec<String> = stop + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + if !stop_strs.is_empty() { + params.stop = Some(stop_strs); + } + } + } + + Some(lingua::UniversalRequest { + model, + messages, + params, + }) + } +} + +// ============================================================================ +// PromptBuilder - Builder for loading prompts +// ============================================================================ + +/// Builder for loading prompts from Braintrust. +/// +/// Use `client.prompt_builder_with_credentials()` to create a builder, +/// then configure it and call `build()` to fetch the prompt. +#[allow(private_bounds)] +pub struct PromptBuilder { + fetcher: Arc<dyn PromptFetcher>, + token: String, + project_name: Option<String>, + project_id: Option<String>, + slug: Option<String>, + version: Option<String>, + environment: Option<String>, +} + +impl Clone for PromptBuilder { + fn clone(&self) -> Self { + Self { + fetcher: Arc::clone(&self.fetcher), + token: self.token.clone(), + project_name: self.project_name.clone(), + project_id: self.project_id.clone(), + slug: self.slug.clone(), + version: self.version.clone(), + environment: self.environment.clone(), + } + } +} + +#[allow(private_bounds)] +impl PromptBuilder { + /// Create a new PromptBuilder (internal use). + pub(crate) fn new(fetcher: Arc<dyn PromptFetcher>, token: impl Into<String>) -> Self { + Self { + fetcher, + token: token.into(), + project_name: None, + project_id: None, + slug: None, + version: None, + environment: None, + } + } + + /// Set the project name. 
+ pub fn project_name(mut self, name: impl Into<String>) -> Self { + self.project_name = Some(name.into()); + self + } + + /// Set the project ID (takes precedence over project_name). + pub fn project_id(mut self, id: impl Into<String>) -> Self { + self.project_id = Some(id.into()); + self + } + + /// Set the prompt slug (required). + pub fn slug(mut self, slug: impl Into<String>) -> Self { + self.slug = Some(slug.into()); + self + } + + /// Set a specific version (transaction ID) to fetch. + /// + /// Mutually exclusive with `environment`. + pub fn version(mut self, version: impl Into<String>) -> Self { + self.version = Some(version.into()); + self + } + + /// Set the environment to fetch (e.g., "production"). + /// + /// Mutually exclusive with `version`. + pub fn environment(mut self, env: impl Into<String>) -> Self { + self.environment = Some(env.into()); + self + } + + /// Build and fetch the prompt. + /// + /// Returns an error if required fields are missing or if the fetch fails. + pub async fn build(self) -> std::result::Result<Prompt, PromptBuilderError> { + let slug = self.slug.ok_or(PromptBuilderError::MissingSlug)?; + + if self.project_name.is_none() && self.project_id.is_none() { + return Err(PromptBuilderError::MissingProject); + } + + if self.version.is_some() && self.environment.is_some() { + return Err(PromptBuilderError::VersionAndEnvironmentConflict); + } + + let request = PromptFetchRequest { + project_name: if self.project_id.is_some() { + None + } else { + self.project_name.clone() + }, + project_id: self.project_id.clone(), + slug, + version: self.version, + environment: self.environment, + }; + + let response = self + .fetcher + .fetch_prompt(&self.token, request) + .await + .map_err(|e| PromptBuilderError::FetchFailed(e.to_string()))?; + + Ok(Prompt { + id: response.id, + name: response.name, + slug: response.slug, + version: response.version.unwrap_or_default(), + project_id: response.project_id, + prompt_data: response.prompt_data.unwrap_or(Value::Null), + }) + } +} diff --git a/tests/prompt_test.rs 
b/tests/prompt_test.rs new file mode 100644 index 0000000..a67cbce --- /dev/null +++ b/tests/prompt_test.rs @@ -0,0 +1,197 @@ +use braintrust_sdk_rust::{BraintrustClient, Message, PromptBuilderError, TokenBudget}; +use serde_json::json; +use wiremock::matchers::{method, path, query_param}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +#[tokio::test] +async fn prompt_builder_requires_slug() { + let client = BraintrustClient::builder() + .api_key("token") + .api_url("https://api.braintrust.dev") + .build() + .await + .expect("client"); + + let result = client + .prompt_builder_with_credentials("token") + .project_name("test-project") + // Missing slug + .build() + .await; + + assert!(matches!(result, Err(PromptBuilderError::MissingSlug))); +} + +#[tokio::test] +async fn prompt_builder_requires_project() { + let client = BraintrustClient::builder() + .api_key("token") + .api_url("https://api.braintrust.dev") + .build() + .await + .expect("client"); + + let result = client + .prompt_builder_with_credentials("token") + .slug("test-prompt") + // Missing project + .build() + .await; + + assert!(matches!(result, Err(PromptBuilderError::MissingProject))); +} + +#[tokio::test] +async fn prompt_builder_rejects_version_and_environment() { + let client = BraintrustClient::builder() + .api_key("token") + .api_url("https://api.braintrust.dev") + .build() + .await + .expect("client"); + + let result = client + .prompt_builder_with_credentials("token") + .project_name("test-project") + .slug("test-prompt") + .version("v1") + .environment("production") // Both version and environment + .build() + .await; + + assert!(matches!( + result, + Err(PromptBuilderError::VersionAndEnvironmentConflict) + )); +} + +#[tokio::test] +async fn prompt_fetches_and_builds_chat_prompt() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/v1/prompt")) + .and(query_param("slug", "greeting-prompt")) + .and(query_param("project_name", "test-project")) + 
.respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "prompt-id", + "name": "Greeting Prompt", + "slug": "greeting-prompt", + "_xact_id": "1000000001", + "project_id": "proj-id", + "prompt_data": { + "prompt": { + "type": "chat", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant for {{company}}." + }, + { + "role": "user", + "content": "Hello, my name is {{name}}." + } + ] + }, + "options": { + "model": "gpt-4", + "params": { + "temperature": 0.7, + "max_tokens": 1000 + } + }, + "template_format": "mustache" + } + }))) + .expect(1) + .mount(&server) + .await; + + let client = BraintrustClient::builder() + .api_key("token") + .app_url(server.uri()) + .api_url(server.uri()) + .build() + .await + .expect("client"); + + let prompt = client + .prompt_builder_with_credentials("token") + .project_name("test-project") + .slug("greeting-prompt") + .build() + .await + .expect("build prompt"); + + assert_eq!(prompt.id(), "prompt-id"); + assert_eq!(prompt.name(), "Greeting Prompt"); + assert_eq!(prompt.slug(), "greeting-prompt"); + assert_eq!(prompt.version(), "1000000001"); + + let request = prompt.to_request().expect("to_request should succeed"); + + assert_eq!(request.model, Some("gpt-4".to_string())); + assert_eq!(request.params.temperature, Some(0.7)); + assert_eq!( + request.params.token_budget, + Some(TokenBudget::OutputTokens(1000)) + ); + assert_eq!(request.messages.len(), 2); + + // Check message roles (content has unsubstituted template variables) + assert!(matches!(request.messages[0], Message::System { .. })); + assert!(matches!(request.messages[1], Message::User { .. 
})); +} + +#[tokio::test] +async fn prompt_fetches_completion_prompt() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/v1/prompt")) + .and(query_param("slug", "complete-prompt")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "id": "prompt-id-2", + "name": "Completion Prompt", + "slug": "complete-prompt", + "_xact_id": "1000000002", + "project_id": "proj-id", + "prompt_data": { + "prompt": { + "type": "completion", + "prompt": "Complete this sentence: {{prefix}}" + }, + "options": { + "model": "gpt-3.5-turbo-instruct" + }, + "template_format": "mustache" + } + }))) + .expect(1) + .mount(&server) + .await; + + let client = BraintrustClient::builder() + .api_key("token") + .app_url(server.uri()) + .api_url(server.uri()) + .build() + .await + .expect("client"); + + let prompt = client + .prompt_builder_with_credentials("token") + .project_name("test-project") + .slug("complete-prompt") + .build() + .await + .expect("build prompt"); + + assert_eq!(prompt.slug(), "complete-prompt"); + + // Raw prompt data is available + assert!(!prompt.data().is_null()); + + // Completion prompts are not yet supported via to_request() + assert!(prompt.to_request().is_none()); +}