Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,180 changes: 1,165 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions crates/pctx_callback_registry/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "pctx_callback_registry"
version = "0.1.0"
edition.workspace = true
rust-version.workspace = true
license.workspace = true
repository.workspace = true
description = "Shared callback registry for pctx runtimes"
keywords = ["pctx", "callback", "registry"]
categories = ["development-tools"]

[dependencies]
serde_json = { workspace = true }
tracing = { workspace = true }

[lints]
workspace = true
120 changes: 120 additions & 0 deletions crates/pctx_callback_registry/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, RwLock},
};

use tracing::instrument;

/// An async callback function that can be registered and invoked from sandboxed code.
///
/// Both the JavaScript (`pctx_code_execution_runtime`) and Python (`pctx_python_runtime`)
/// runtimes share this type so the same closures can be registered with either.
pub type CallbackFn = Arc<
dyn Fn(
Option<serde_json::Value>,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send>>
+ Send
+ Sync,
>;

/// Registry mapping callback names to their implementations.
///
/// Clone is cheap — the inner map is reference-counted.
#[derive(Clone, Default)]
pub struct CallbackRegistry {
callbacks: Arc<RwLock<HashMap<String, CallbackFn>>>,
}

impl std::fmt::Debug for CallbackRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CallbackRegistry")
.field("ids", &self.ids())
.finish()
}
}

impl CallbackRegistry {
/// Returns the ids registered in this [`CallbackRegistry`].
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
pub fn ids(&self) -> Vec<String> {
self.callbacks
.read()
.unwrap()
.keys()
.map(String::from)
.collect()
}

/// Register a callback under `id`.
///
/// # Errors
///
/// Returns an error if a callback with the same `id` is already registered.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
pub fn add(&self, id: &str, callback: CallbackFn) -> Result<(), String> {
let mut callbacks = self
.callbacks
.write()
.map_err(|e| format!("Failed to acquire write lock: {e}"))?;

if callbacks.contains_key(id) {
return Err(format!("Callback \"{id}\" is already registered"));
}

callbacks.insert(id.to_owned(), callback);
Ok(())
}

/// Remove a callback from the registry.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
pub fn remove(&self, id: &str) -> Option<CallbackFn> {
self.callbacks.write().unwrap().remove(id)
}

/// Look up a callback by id.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
pub fn get(&self, id: &str) -> Option<CallbackFn> {
self.callbacks.read().unwrap().get(id).cloned()
}

/// Returns `true` if a callback with the given `id` is registered.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
pub fn has(&self, id: &str) -> bool {
self.callbacks.read().unwrap().contains_key(id)
}

/// Invoke a callback by id with the given JSON arguments.
///
/// # Errors
///
/// Returns an error string if the id is not registered or if the callback fails.
#[instrument(name = "invoke_callback", skip_all, fields(id = id))]
pub async fn invoke(
&self,
id: &str,
args: Option<serde_json::Value>,
) -> Result<serde_json::Value, String> {
let callback = self
.get(id)
.ok_or_else(|| format!("Callback \"{id}\" is not registered"))?;

callback(args).await
}
}
1 change: 1 addition & 0 deletions crates/pctx_code_execution_runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ path = "src/lib.rs"

[dependencies]
pctx_config = { version = "^0.1.3", path = "../pctx_config" }
pctx_callback_registry = { path = "../pctx_callback_registry" }
deno_core = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
5 changes: 4 additions & 1 deletion crates/pctx_code_execution_runtime/src/callback_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ pub(crate) async fn op_invoke_callback(
borrowed.borrow::<CallbackRegistry>().clone()
};

registry.invoke(&id, arguments).await
registry
.invoke(&id, arguments)
.await
.map_err(McpError::ExecutionError)
}
129 changes: 1 addition & 128 deletions crates/pctx_code_execution_runtime/src/callback_registry.rs
Original file line number Diff line number Diff line change
@@ -1,128 +1 @@
use serde_json::json;
use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, RwLock},
};
use tracing::instrument;

use crate::error::McpError;

pub type CallbackFn = Arc<
dyn Fn(
Option<serde_json::Value>,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send>>
+ Send
+ Sync,
>;

/// Singleton registry for callbacks
#[derive(Clone, Default)]
pub struct CallbackRegistry {
callbacks: Arc<RwLock<HashMap<String, CallbackFn>>>,
}

impl CallbackRegistry {
/// Returns the ids of this [`CallbackRegistry`].
///
/// # Panics
///
/// Panics if it fails acquiring the lock
pub fn ids(&self) -> Vec<String> {
self.callbacks
.read()
.unwrap()
.keys()
.map(String::from)
.collect()
}

/// Adds callback to registry
///
/// # Panics
///
/// Panics if cannot obtain lock
///
/// # Errors
///
/// This function will return an error if a callback already exists with the same ID
pub fn add(
&self,
id: &str, // namespace.name
callback: CallbackFn,
) -> Result<(), McpError> {
let mut callbacks = self.callbacks.write().map_err(|e| {
McpError::Config(format!(
"Failed obtaining write lock on callback registry: {e}"
))
})?;

if callbacks.contains_key(id) {
return Err(McpError::Config(format!(
"Callback with id \"{id}\" is already registered"
)));
}

callbacks.insert(id.into(), callback);

Ok(())
}

/// Remove a callback from the registry by id
///
/// # Panics
///
/// Panics if cannot obtain lock
pub fn remove(&self, id: &str) -> Option<CallbackFn> {
let mut callbacks = self.callbacks.write().unwrap();
callbacks.remove(id)
}

/// Get a Callback from the registry by id
///
/// # Panics
///
/// Panics if the internal lock is poisoned (i.e., a thread panicked while holding the lock)
pub fn get(&self, id: &str) -> Option<CallbackFn> {
let callbacks = self.callbacks.read().unwrap();
callbacks.get(id).cloned()
}

/// Confirms the callback registry contains a given id
///
/// # Panics
///
/// Panics if the internal lock is poisoned (i.e., a thread panicked while holding the lock)
pub fn has(&self, id: &str) -> bool {
let callbacks = self.callbacks.read().unwrap();
callbacks.contains_key(id)
}

/// invokes the callback with the provided args
///
/// # Errors
///
/// This function will return an error if a callback by the provided id doesn't exist
/// or if the callback itself fails
#[instrument(
name = "invoke_callback_tool",
skip_all,
fields(id=id, args = json!(args).to_string()),
ret(Display),
err
)]
pub async fn invoke(
&self,
id: &str,
args: Option<serde_json::Value>,
) -> Result<serde_json::Value, McpError> {
let callback = self.get(id).ok_or_else(|| {
McpError::ToolCall(format!("Callback with id \"{id}\" does not exist"))
})?;

callback(args).await.map_err(|e| {
McpError::ExecutionError(format!("Failed calling callback with id \"{id}\": {e}",))
})
}
}
pub use pctx_callback_registry::{CallbackFn, CallbackRegistry};
1 change: 1 addition & 0 deletions crates/pctx_code_mode/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pctx_config = { version = "^0.1.3", path = "../pctx_config" }
pctx_codegen = { version = "^0.2.0", path = "../pctx_codegen" }
pctx_executor = { version = "^0.1.3", path = "../pctx_executor" }
pctx_code_execution_runtime = { version = "^0.1.3", path = "../pctx_code_execution_runtime" }
pctx_python_runtime = { path = "../pctx_python_runtime" }

# general
thiserror = { workspace = true }
Expand Down
Loading
Loading