diff --git a/src/log_queue/queue.rs b/src/log_queue/queue.rs index 23e874d..4a94d83 100644 --- a/src/log_queue/queue.rs +++ b/src/log_queue/queue.rs @@ -401,7 +401,7 @@ impl LogQueueCore { merge_paths: None, span_id, root_span_id, - span_parents: None, + span_parents: span_components.span_parents.clone(), destination, org_id, org_name, @@ -470,6 +470,7 @@ impl LogQueueCore { compute_object_metadata_args, span_id: parent_span_id, root_span_id: parent_root_span_id, + span_parents: _, propagated_event: _, }) => { let span_parents = Some(vec![parent_span_id]); @@ -804,7 +805,9 @@ impl LogQueueCore { is_merge: if payload.is_merge { Some(true) } else { None }, merge_paths: None, root_span_id, - span_parents: None, + span_parents: span_components + .as_ref() + .and_then(|components| components.span_parents.clone()), destination, org_id: payload.org_id, org_name: payload.org_name, diff --git a/src/span.rs b/src/span.rs index b4fd51d..84e6e9f 100644 --- a/src/span.rs +++ b/src/span.rs @@ -477,6 +477,10 @@ impl SpanHandle { Some(ParentSpanInfo::FullSpan { root_span_id, .. }) => Some(root_span_id.clone()), _ => Some(inner.span_id.clone()), }; + let span_parents = match &self.parent_info { + Some(ParentSpanInfo::FullSpan { span_id, .. }) => Some(vec![span_id.clone()]), + _ => None, + }; let compute_object_metadata_args = if object_id.is_none() { inherited_compute_object_metadata_args.or_else(|| { @@ -500,6 +504,7 @@ impl SpanHandle { row_id: Some(inner.row_id.clone()), span_id: Some(inner.span_id.clone()), root_span_id, + span_parents, propagated_event: inner.propagated_event.clone(), }) } @@ -876,6 +881,7 @@ mod tests { span_id: "parent-span-id".to_string(), root_span_id: "root-span-id".to_string(), compute_object_metadata_args: None, + span_parents: None, propagated_event: Some(parent_propagated), }; @@ -924,6 +930,7 @@ mod tests { span_id: "span-456".to_string(), root_span_id: "root-789".to_string(), compute_object_metadata_args: None, + span_parents: None, propagated_event: Some(propagated), }; @@ -940,6 +947,7 @@ mod tests { event.get("test_key").and_then(|v| v.as_str()), Some("test_value") ); + assert_eq!(exported.span_parents, Some(vec!["span-456".to_string()])); } #[tokio::test] diff --git a/src/span_components.rs b/src/span_components.rs index 6b764ec..c1609d5 100644 --- a/src/span_components.rs +++ b/src/span_components.rs @@ -70,6 +70,10 @@ pub struct SpanComponents { #[serde(skip_serializing_if = "Option::is_none")] pub root_span_id: Option, + /// Direct parent span IDs for this span. + #[serde(skip_serializing_if = "Option::is_none")] + pub span_parents: Option>, + /// Event data to propagate to child spans (e.g., prompt versions, metadata) #[serde(skip_serializing_if = "Option::is_none")] pub propagated_event: Option>, @@ -85,6 +89,7 @@ impl SpanComponents { row_id: None, span_id: None, root_span_id: None, + span_parents: None, propagated_event: None, } } @@ -194,6 +199,12 @@ impl SpanComponents { if let Some(ref event) = self.propagated_event { json_obj.insert("propagated_event".to_string(), Value::Object(event.clone())); } + if let Some(ref span_parents) = self.span_parents { + json_obj.insert( + "span_parents".to_string(), + Value::Array(span_parents.iter().cloned().map(Value::String).collect()), + ); + } if !json_obj.is_empty() { let json_str = serde_json::to_string(&json_obj).unwrap(); @@ -365,6 +376,25 @@ impl SpanComponents { root_span_id: json_obj .remove("root_span_id") .and_then(|v| v.as_str().map(String::from)), + span_parents: match json_obj.remove("span_parents") { + None => None, + Some(Value::Array(values)) => Some( + values + .into_iter() + .map(|value| match value { + Value::String(value) => Ok(value), + _ => Err(BraintrustError::InvalidConfig( + "span_parents must be an array of strings".to_string(), + )), + }) + .collect::>>()?, + ), + Some(_) => { + return Err(BraintrustError::InvalidConfig( + "span_parents must be an array of strings".to_string(), + )) + } + }, propagated_event: json_obj .remove("propagated_event") .and_then(|v| v.as_object().cloned()), @@ -404,6 +434,7 @@ impl SpanComponents { compute_object_metadata_args: self.compute_object_metadata_args.clone(), span_id, root_span_id, + span_parents: self.span_parents.clone(), propagated_event: self.propagated_event.clone(), }) } @@ -417,6 +448,7 @@ impl SpanComponents { compute_object_metadata_args, span_id, root_span_id, + span_parents, propagated_event, } => Some(Self { object_type: *object_type, @@ -425,6 +457,7 @@ impl SpanComponents { row_id: None, span_id: Some(span_id.clone()), root_span_id: Some(root_span_id.clone()), + span_parents: span_parents.clone(), propagated_event: propagated_event.clone(), }), _ => None, @@ -529,6 +562,7 @@ mod tests { components.object_id = Some("550e8400-e29b-41d4-a716-446655440000".to_string()); components.span_id = Some("0123456789abcdef".to_string()); components.root_span_id = Some("0123456789abcdef0123456789abcdef".to_string()); + components.span_parents = Some(vec!["parent-a".to_string()]); let encoded = components.to_str(); let decoded = SpanComponents::parse(&encoded).unwrap(); @@ -537,6 +571,7 @@ mod tests { assert_eq!(decoded.object_id, components.object_id); assert_eq!(decoded.span_id, components.span_id); assert_eq!(decoded.root_span_id, components.root_span_id); + assert_eq!(decoded.span_parents, components.span_parents); } #[test] @@ -584,6 +619,7 @@ mod tests { components.object_id = Some("project-123".to_string()); components.span_id = Some("span-456".to_string()); components.root_span_id = Some("root-789".to_string()); + components.span_parents = Some(vec!["parent-a".to_string()]); let mut propagated = Map::new(); propagated.insert( @@ -600,6 +636,7 @@ mod tests { object_id, span_id, root_span_id, + span_parents, propagated_event, .. } => { @@ -607,6 +644,7 @@ mod tests { assert_eq!(object_id, Some("project-123".to_string())); assert_eq!(span_id, "span-456"); assert_eq!(root_span_id, "root-789"); + assert_eq!(span_parents, Some(vec!["parent-a".to_string()])); assert!(propagated_event.is_some()); let event = propagated_event.unwrap(); assert_eq!( @@ -629,6 +667,7 @@ mod tests { span_id: "span-456".to_string(), root_span_id: "root-789".to_string(), compute_object_metadata_args: None, + span_parents: Some(vec!["parent-a".to_string()]), propagated_event: Some(propagated), }; @@ -638,6 +677,48 @@ mod tests { assert_eq!(components.object_id, Some("exp-123".to_string())); assert_eq!(components.span_id, Some("span-456".to_string())); assert_eq!(components.root_span_id, Some("root-789".to_string())); + assert_eq!(components.span_parents, Some(vec!["parent-a".to_string()])); assert!(components.propagated_event.is_some()); } + + #[test] + fn test_parse_v3_json_remainder_preserves_span_parents() { + let payload = serde_json::json!({ + "object_type": SpanObjectType::ProjectLogs as u8, + "object_id": "project-123", + "row_id": "row-123", + "span_id": "span-123", + "root_span_id": "root-123", + "span_parents": ["parent-a"], + }); + let encoded = BASE64.encode( + [ + vec![ENCODING_VERSION_V3, SpanObjectType::ProjectLogs as u8, 0], + serde_json::to_vec(&payload).unwrap(), + ] + .concat(), + ); + + let decoded = SpanComponents::parse(&encoded).unwrap(); + + assert_eq!(decoded.span_parents, Some(vec!["parent-a".to_string()])); + } + + #[test] + fn test_parse_rejects_invalid_span_parents() { + let payload = serde_json::json!({ + "object_type": SpanObjectType::ProjectLogs as u8, + "span_parents": [123], + }); + let encoded = BASE64.encode( + [ + vec![ENCODING_VERSION_V4, SpanObjectType::ProjectLogs as u8, 0], + serde_json::to_vec(&payload).unwrap(), + ] + .concat(), + ); + + let err = SpanComponents::parse(&encoded).unwrap_err(); + assert!(matches!(err, BraintrustError::InvalidConfig(_))); + } } diff --git a/src/types.rs b/src/types.rs index 956a873..44b5bac 100644 --- a/src/types.rs +++ b/src/types.rs @@ -364,6 +364,8 @@ pub enum ParentSpanInfo { span_id: String, root_span_id: String, #[serde(skip_serializing_if = "Option::is_none")] + span_parents: Option>, + #[serde(skip_serializing_if = "Option::is_none")] propagated_event: Option>, }, } @@ -876,6 +878,7 @@ mod tests { span_id: "span-1".to_string(), root_span_id: "root-1".to_string(), compute_object_metadata_args: None, + span_parents: Some(vec!["parent-1".to_string()]), propagated_event: None, }; @@ -884,6 +887,7 @@ mod tests { // SpanObjectType serializes as u8 for wire compatibility assert_eq!(obj.get("object_type").unwrap(), 1); + assert_eq!(obj.get("span_parents").unwrap(), &json!(["parent-1"])); } #[test] @@ -894,15 +898,21 @@ mod tests { "object_type": 1, "object_id": "exp-123", "span_id": "span-1", - "root_span_id": "root-1" + "root_span_id": "root-1", + "span_parents": ["parent-1"] } }); let parent: ParentSpanInfo = serde_json::from_value(json).unwrap(); match parent { - ParentSpanInfo::FullSpan { object_type, .. } => { + ParentSpanInfo::FullSpan { + object_type, + span_parents, + .. + } => { assert_eq!(object_type, SpanObjectType::Experiment); + assert_eq!(span_parents, Some(vec!["parent-1".to_string()])); } _ => panic!("Expected FullSpan variant"), } @@ -916,6 +926,7 @@ mod tests { span_id: "span-1".to_string(), root_span_id: "root-1".to_string(), compute_object_metadata_args: None, + span_parents: Some(vec!["parent-1".to_string()]), propagated_event: Some(Map::from_iter([( "metrics".to_string(), json!({ "foo": 0.1 }), @@ -928,6 +939,7 @@ mod tests { obj.get("propagated_event").unwrap(), &json!({ "metrics": { "foo": 0.1 } }) ); + assert_eq!(obj.get("span_parents").unwrap(), &json!(["parent-1"])); } #[test] diff --git a/tests/span_lifecycle.rs b/tests/span_lifecycle.rs index 3a52bef..2f86e77 100644 --- a/tests/span_lifecycle.rs +++ b/tests/span_lifecycle.rs @@ -94,6 +94,7 @@ async fn client_update_span_uses_exported_ids_for_project_logs() { row_id: Some("row-id".to_string()), span_id: Some("span-id".to_string()), root_span_id: Some("root-id".to_string()), + span_parents: None, propagated_event: None, } .to_str(); @@ -157,6 +158,7 @@ async fn client_update_span_with_credentials_works_without_priming_login_state() row_id: Some("row-id".to_string()), span_id: Some("span-id".to_string()), root_span_id: Some("root-id".to_string()), + span_parents: None, propagated_event: None, } .to_str(); @@ -192,6 +194,127 @@ async fn client_update_span_with_credentials_works_without_priming_login_state() assert_eq!(row["project_id"], "proj-id"); assert_eq!(row["span_id"], "span-id"); assert_eq!(row["root_span_id"], "root-id"); + assert!(row.get("span_parents").is_none()); +} + +#[tokio::test] +async fn client_update_span_includes_exported_span_parents() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/logs3")) + .respond_with(ResponseTemplate::new(200).set_body_string("{}")) + .mount(&server) + .await; + + let client = BraintrustClient::builder() + .skip_login(true) + .api_url(server.uri()) + .app_url(server.uri()) + .build() + .await + .expect("client"); + let _ = client.span_builder_with_credentials("token", "org-id"); + + let exported = SpanComponents { + object_type: SpanObjectType::ProjectLogs, + object_id: Some("proj-id".to_string()), + compute_object_metadata_args: None, + row_id: Some("row-id".to_string()), + span_id: Some("span-id".to_string()), + root_span_id: Some("root-id".to_string()), + span_parents: Some(vec!["parent-id".to_string()]), + propagated_event: None, + } + .to_str(); + + client + .update_span( + &exported, + SpanLog::builder() + .output(json!({"status": "updated"})) + .build() + .expect("build"), + ) + .await + .expect("update"); + client.flush().await.expect("flush"); + + let logs_requests: Vec<_> = server + .received_requests() + .await + .unwrap() + .into_iter() + .filter(|request| request.url.path() == "/logs3") + .collect(); + assert_eq!(logs_requests.len(), 1); + + let body: Value = serde_json::from_slice(&logs_requests[0].body).expect("json body"); + let row = body["rows"] + .as_array() + .and_then(|rows| rows.first()) + .expect("row"); + assert_eq!(row["span_parents"], json!(["parent-id"])); +} + +#[tokio::test] +async fn client_update_span_with_credentials_includes_exported_span_parents() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/logs3")) + .respond_with(ResponseTemplate::new(200).set_body_string("{}")) + .mount(&server) + .await; + + let client = BraintrustClient::builder() + .skip_login(true) + .api_url(server.uri()) + .app_url(server.uri()) + .build() + .await + .expect("client"); + + let exported = SpanComponents { + object_type: SpanObjectType::ProjectLogs, + object_id: Some("proj-id".to_string()), + compute_object_metadata_args: None, + row_id: Some("row-id".to_string()), + span_id: Some("span-id".to_string()), + root_span_id: Some("root-id".to_string()), + span_parents: Some(vec!["parent-id".to_string()]), + propagated_event: None, + } + .to_str(); + + client + .update_span_with_credentials( + "token", + "org-id", + &exported, + SpanLog::builder() + .output(json!({"status": "updated"})) + .build() + .expect("build"), + ) + .expect("update"); + client.flush().await.expect("flush"); + + let logs_requests: Vec<_> = server + .received_requests() + .await + .unwrap() + .into_iter() + .filter(|request| request.url.path() == "/logs3") + .collect(); + assert_eq!(logs_requests.len(), 1); + + let body: Value = serde_json::from_slice(&logs_requests[0].body).expect("json body"); + let row = body["rows"] + .as_array() + .and_then(|rows| rows.first()) + .expect("row"); + assert_eq!(row["span_parents"], json!(["parent-id"])); } #[tokio::test] @@ -232,6 +355,7 @@ async fn client_update_span_resolves_project_name_from_exported_compute_metadata row_id: Some("row-id".to_string()), span_id: Some("span-id".to_string()), root_span_id: Some("root-id".to_string()), + span_parents: None, propagated_event: None, } .to_str();