From cf36f93f67b17fc81d11295fd9eb3c212b96e93e Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 26 Apr 2026 03:48:19 +0800 Subject: [PATCH 1/2] feat: add bub-extism package Adds bub-extism adapter that bridges selected Bub hooks to Extism WebAssembly plugins. Extensions can be authored in any Extism PDK language; the package ships: - ExtismPlugin entry point exposing run_model / run_model_stream, provide_channels, provide_tape_store, system_prompt, etc. - ExtismBridge invoking guest hooks and proxying channels/tape stores. - bub-extism CLI for managing extism.json under Bub home. - Rust and Go example plugins with build-and-run integration tests. Amp-Thread-ID: https://ampcode.com/threads/T-019dc60d-af3b-71eb-8e4e-1f342014a2e1 Co-authored-by: Amp --- README.md | 1 + packages/bub-extism/README.md | 252 +++++++++++ packages/bub-extism/examples/.gitignore | 4 + packages/bub-extism/examples/README.md | 64 +++ .../bub-extism/examples/go-channel/go.mod | 5 + .../bub-extism/examples/go-channel/go.sum | 2 + .../bub-extism/examples/go-channel/main.go | 71 +++ .../examples/rust-model-stream/Cargo.lock | 380 ++++++++++++++++ .../examples/rust-model-stream/Cargo.toml | 12 + .../examples/rust-model-stream/src/lib.rs | 48 ++ packages/bub-extism/pyproject.toml | 28 ++ .../bub-extism/src/bub_extism/__init__.py | 7 + packages/bub-extism/src/bub_extism/bridge.py | 90 ++++ packages/bub-extism/src/bub_extism/channel.py | 108 +++++ packages/bub-extism/src/bub_extism/cli.py | 59 +++ packages/bub-extism/src/bub_extism/codec.py | 137 ++++++ packages/bub-extism/src/bub_extism/config.py | 85 ++++ packages/bub-extism/src/bub_extism/plugin.py | 276 ++++++++++++ packages/bub-extism/src/bub_extism/py.typed | 1 + packages/bub-extism/src/bub_extism/stream.py | 42 ++ .../bub-extism/src/bub_extism/tape_store.py | 84 ++++ packages/bub-extism/tests/test_bridge.py | 410 ++++++++++++++++++ packages/bub-extism/tests/test_examples.py | 118 +++++ pyproject.toml | 2 + uv.lock | 34 ++ 25 files changed, 2320 insertions(+) create mode 100644 packages/bub-extism/README.md create mode 100644 packages/bub-extism/examples/.gitignore create mode 100644 packages/bub-extism/examples/README.md create mode 100644 packages/bub-extism/examples/go-channel/go.mod create mode 100644 packages/bub-extism/examples/go-channel/go.sum create mode 100644 packages/bub-extism/examples/go-channel/main.go create mode 100644 packages/bub-extism/examples/rust-model-stream/Cargo.lock create mode 100644 packages/bub-extism/examples/rust-model-stream/Cargo.toml create mode 100644 packages/bub-extism/examples/rust-model-stream/src/lib.rs create mode 100644 packages/bub-extism/pyproject.toml create mode 100644 packages/bub-extism/src/bub_extism/__init__.py create mode 100644 packages/bub-extism/src/bub_extism/bridge.py create mode 100644 packages/bub-extism/src/bub_extism/channel.py create mode 100644 packages/bub-extism/src/bub_extism/cli.py create mode 100644 packages/bub-extism/src/bub_extism/codec.py create mode 100644 packages/bub-extism/src/bub_extism/config.py create mode 100644 packages/bub-extism/src/bub_extism/plugin.py create mode 100644 packages/bub-extism/src/bub_extism/py.typed create mode 100644 packages/bub-extism/src/bub_extism/stream.py create mode 100644 packages/bub-extism/src/bub_extism/tape_store.py create mode 100644 packages/bub-extism/tests/test_bridge.py create mode 100644 packages/bub-extism/tests/test_examples.py diff --git a/README.md b/README.md index bf5a9bb..20d0d5b 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Below is the list of packages currently included in this repository. | [`packages/bub-tapestore-sqlite`](./packages/bub-tapestore-sqlite/README.md) | `tapestore-sqlite` | Provides a SQLite-backed tape store for Bub conversation history. | | [`packages/bub-discord`](./packages/bub-discord/README.md) | `discord` | Provides a Discord channel adapter for Bub message IO. | | [`packages/bub-dingtalk`](./packages/bub-dingtalk/README.md) | `dingtalk` | Provides a DingTalk Stream Mode channel adapter for Bub message IO. | +| [`packages/bub-extism`](./packages/bub-extism/README.md) | `extism` | Bridges selected Bub hooks to Extism WebAssembly plugins so extensions can be written in any Extism PDK language. | | [`packages/bub-github-copilot`](./packages/bub-github-copilot/README.md) | `github-copilot` | Provides a `run_model` hook backed by the GitHub Copilot SDK, plus `bub login github` device-flow login commands. | | [`packages/bub-kimi`](./packages/bub-kimi/README.md) | `kimi` | Provides a `run_model` hook backed by the Kimi CLI, including persisted session resume support and temporary Bub skill wiring. | | [`packages/bub-mcp`](./packages/bub-mcp/README.md) | `mcp` | Exposes configured MCP servers as Bub tools, with `bub mcp` CLI commands to list, add, and remove server configs. | diff --git a/packages/bub-extism/README.md b/packages/bub-extism/README.md new file mode 100644 index 0000000..99120f5 --- /dev/null +++ b/packages/bub-extism/README.md @@ -0,0 +1,252 @@ +# bub-extism + +`bub-extism` lets a Bub workspace run selected Bub hooks through an +[Extism](https://extism.org/) WebAssembly plug-in. + +The package is intentionally a bridge, not a replacement for Bub's pluggy +extension model. Bub still discovers `bub-extism` as a normal Python plugin, +then `bub-extism` delegates configured hook calls to a `.wasm` module written +with any Extism PDK language. + +The Extism Python runtime is installed with this package. The dependency is +kept on the verified `extism` `1.1.x` and `extism-sys` `1.12.x` lines because +the Python package includes native runtime wheels. + +## Supported Hooks + +The bridge exposes the current Bub standard hook surface: + +- `resolve_session` +- `build_prompt` +- `run_model` +- `run_model_stream` +- `load_state` +- `save_state` +- `render_outbound` +- `dispatch_outbound` +- `register_cli_commands` +- `onboard_config` +- `on_error` +- `system_prompt` +- `provide_tape_store` +- `provide_channels` +- `build_tape_context` + +Pure value hooks map directly to WebAssembly calls. Hooks that return Python +runtime objects use Python-side proxies: + +- `run_model_stream` accepts a returned list of stream events and wraps it as + `AsyncStreamEvents`. +- `provide_channels` accepts channel descriptors and creates `ExtismChannel` + proxies. +- `provide_tape_store` accepts a tape store descriptor and creates an + `ExtismTapeStore` proxy. +- `register_cli_commands` accepts command descriptors and registers them under + `bub extism`. +- `build_tape_context` accepts a declarative context object; arbitrary Python + selector callbacks are not part of the WASM ABI. + +## Configuration + +Create `~/.bub/extism.json`: + +```json +{ + "defaultPlugin": "echo", + "plugins": { + "echo": { + "wasmPath": "/absolute/path/to/plugin.wasm", + "wasi": false, + "config": { + "model": "demo" + }, + "hooks": { + "resolve_session": "resolve_session", + "build_prompt": "build_prompt", + "run_model": "run_model", + "run_model_stream": "run_model_stream", + "load_state": "load_state", + "save_state": "save_state", + "render_outbound": "render_outbound", + "dispatch_outbound": "dispatch_outbound", + "register_cli_commands": "register_cli_commands", + "onboard_config": "onboard_config", + "on_error": "on_error", + "system_prompt": "system_prompt" + } + } + } +} +``` + +You can also load a URL or a full Extism manifest: + +```json +{ + "defaultPlugin": "remote", + "plugins": { + "remote": { + "wasmUrl": "https://example.com/plugin.wasm", + "hooks": { + "run_model": "run_model" + } + } + } +} +``` + +```json +{ + "defaultPlugin": "manifest", + "plugins": { + "manifest": { + "manifest": { + "wasm": [ + { + "url": "https://example.com/plugin.wasm", + "hash": "sha256..." + } + ], + "allowed_hosts": ["api.example.com"] + }, + "hooks": { + "run_model": "run_model" + } + } + } +} +``` + +Use `BUB_EXTISM_CONFIG_PATH=/path/to/extism.json` to override the config path. + +## Hook ABI + +Each exported hook function receives one UTF-8 JSON object. + +For `run_model`: + +```json +{ + "abi_version": "bub.extism.v1", + "hook": "run_model", + "args": { + "prompt": "hello", + "session_id": "cli:local", + "state": {} + } +} +``` + +For `system_prompt`: + +```json +{ + "abi_version": "bub.extism.v1", + "hook": "system_prompt", + "args": { + "prompt": "hello", + "state": {} + } +} +``` + +The bridge removes Bub runtime internals such as `_runtime_agent` and +non-JSON-serializable values from `state` before calling WebAssembly. + +The wasm function may return plain text: + +```text +hello from wasm +``` + +Or a JSON object: + +```json +{ + "value": "hello from wasm" +} +``` + +It can skip the hook: + +```json +{ + "skip": true +} +``` + +Or return an error: + +```json +{ + "error": { + "message": "missing api key" + } +} +``` + +For compatibility with early demos, `{"run_model": "..."}` and +`{"system_prompt": "..."}` are still accepted. + +## Proxy Descriptors + +`provide_channels` returns channel descriptors: + +```json +{ + "value": [ + { + "name": "wasm", + "pollIntervalSeconds": 1, + "functions": { + "start": "channel_start", + "poll": "channel_poll", + "send": "channel_send", + "stop": "channel_stop" + } + } + ] +} +``` + +`provide_tape_store` returns a tape store descriptor: + +```json +{ + "value": { + "functions": { + "list_tapes": "tape_list_tapes", + "fetch_all": "tape_fetch_all", + "append": "tape_append", + "reset": "tape_reset" + } + } +} +``` + +`register_cli_commands` returns command descriptors: + +```json +{ + "value": [ + { + "name": "hello", + "help": "Run the hello command.", + "function": "cli_hello" + } + ] +} +``` + +The command is exposed as `bub extism hello '{"name":"Bub"}'`. + +## Development + +From the repository root: + +```bash +uv run --directory packages/bub-extism --with ../bub --with pytest --with pytest-asyncio pytest +``` + +These tests use a fake Extism module and do not require a local WebAssembly +runtime. diff --git a/packages/bub-extism/examples/.gitignore b/packages/bub-extism/examples/.gitignore new file mode 100644 index 0000000..a4676e3 --- /dev/null +++ b/packages/bub-extism/examples/.gitignore @@ -0,0 +1,4 @@ +target/ +*.wasm +__pycache__/ +.pytest_cache/ diff --git a/packages/bub-extism/examples/README.md b/packages/bub-extism/examples/README.md new file mode 100644 index 0000000..6e89818 --- /dev/null +++ b/packages/bub-extism/examples/README.md @@ -0,0 +1,64 @@ +# bub-extism examples + +These examples show how to implement Bub extensions in languages other than +Python while still using Bub's pluggy-based extension surface through +`bub-extism`. + +## Rust model stream + +`rust-model-stream` mirrors model-provider plugins such as `bub-kimi` and +`bub-codex`. It implements `run_model_stream` and returns Republic stream +events. + +Build: + +```bash +cd packages/bub-extism/examples/rust-model-stream +cargo build --release --target wasm32-unknown-unknown +``` + +Configure: + +```json +{ + "defaultPlugin": "rust-model-stream", + "plugins": { + "rust-model-stream": { + "wasmPath": "packages/bub-extism/examples/rust-model-stream/target/wasm32-unknown-unknown/release/bub_extism_rust_model_stream.wasm", + "hooks": { + "run_model_stream": "run_model_stream" + } + } + } +} +``` + +## Go channel + +`go-channel` mirrors channel plugins such as `bub-discord`, `bub-feishu`, and +`bub-wecom`. It implements `provide_channels` and a `send` function used by the +Python `ExtismChannel` proxy. + +Build: + +```bash +cd packages/bub-extism/examples/go-channel +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o go-channel.wasm . +``` + +Configure: + +```json +{ + "defaultPlugin": "go-channel", + "plugins": { + "go-channel": { + "wasmPath": "packages/bub-extism/examples/go-channel/go-channel.wasm", + "wasi": true, + "hooks": { + "provide_channels": "provide_channels" + } + } + } +} +``` diff --git a/packages/bub-extism/examples/go-channel/go.mod b/packages/bub-extism/examples/go-channel/go.mod new file mode 100644 index 0000000..631c731 --- /dev/null +++ b/packages/bub-extism/examples/go-channel/go.mod @@ -0,0 +1,5 @@ +module github.com/bubbuild/bub-extism/examples/go-channel + +go 1.26 + +require github.com/extism/go-pdk v1.1.3 diff --git a/packages/bub-extism/examples/go-channel/go.sum b/packages/bub-extism/examples/go-channel/go.sum new file mode 100644 index 0000000..c15d382 --- /dev/null +++ b/packages/bub-extism/examples/go-channel/go.sum @@ -0,0 +1,2 @@ +github.com/extism/go-pdk v1.1.3 h1:hfViMPWrqjN6u67cIYRALZTZLk/enSPpNKa+rZ9X2SQ= +github.com/extism/go-pdk v1.1.3/go.mod h1:Gz+LIU/YCKnKXhgge8yo5Yu1F/lbv7KtKFkiCSzW/P4= diff --git a/packages/bub-extism/examples/go-channel/main.go b/packages/bub-extism/examples/go-channel/main.go new file mode 100644 index 0000000..3cd0d28 --- /dev/null +++ b/packages/bub-extism/examples/go-channel/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "encoding/json" + "fmt" + + "github.com/extism/go-pdk" +) + +type request struct { + Hook string `json:"hook"` + Args map[string]any `json:"args"` +} + +type response struct { + Value any `json:"value,omitempty"` + Skip bool `json:"skip,omitempty"` + Error any `json:"error,omitempty"` +} + +//go:wasmexport provide_channels +func provideChannels() int32 { + return outputJSON(response{ + Value: []map[string]any{ + { + "name": "go-echo", + "pollIntervalSeconds": 1, + "functions": map[string]string{ + "send": "channel_send", + }, + }, + }, + }) +} + +//go:wasmexport channel_send +func channelSend() int32 { + var req request + if err := pdk.InputJSON(&req); err != nil { + return outputError(err) + } + message, _ := req.Args["message"].(map[string]any) + content, _ := message["content"].(string) + return outputJSON(response{ + Value: map[string]any{ + "ok": true, + "channel": "go-echo", + "sent": content, + }, + }) +} + +func outputJSON(value any) int32 { + if err := pdk.OutputJSON(value); err != nil { + return outputError(err) + } + return 0 +} + +func outputError(err error) int32 { + pdk.SetErrorString(fmt.Sprintf("go-channel: %v", err)) + encoded, _ := json.Marshal(response{ + Error: map[string]string{ + "message": err.Error(), + }, + }) + pdk.Output(encoded) + return 1 +} + +func main() {} diff --git a/packages/bub-extism/examples/rust-model-stream/Cargo.lock b/packages/bub-extism/examples/rust-model-stream/Cargo.lock new file mode 100644 index 0000000..c15473b --- /dev/null +++ b/packages/bub-extism/examples/rust-model-stream/Cargo.lock @@ -0,0 +1,380 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bub-extism-rust-model-stream" +version = "0.1.0" +dependencies = [ + "extism-pdk", + "serde", + "serde_json", +] + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "extism-convert" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec1a8eac059a1730a21aa47f99a0c2075ba0ab88fd0c4e52e35027cf99cdf3e7" +dependencies = [ + "anyhow", + "base64", + "bytemuck", + "extism-convert-macros", + "prost", + "rmp-serde", + "serde", + "serde_json", +] + +[[package]] +name = "extism-convert-macros" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "848f105dd6e1af2ea4bb4a76447658e8587167df3c4e4658c4258e5b14a5b051" +dependencies = [ + "manyhow", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "extism-manifest" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953a22ad322939ae4567ec73a34913a3a43dcbdfa648b8307d38fe56bb3a0acd" +dependencies = [ + "base64", + "serde", + "serde_json", +] + +[[package]] +name = "extism-pdk" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "352fcb5a66eb74145a1c4a01f2bd15d59c62c85be73aac8471880c65b26b798f" +dependencies = [ + "anyhow", + "base64", + "extism-convert", + "extism-manifest", + "extism-pdk-derive", + "serde", + "serde_json", +] + +[[package]] +name = "extism-pdk-derive" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d086daea5fd844e3c5ac69ddfe36df4a9a43e7218cf7d1f888182b089b09806c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "manyhow" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b33efb3ca6d3b07393750d4030418d594ab1139cee518f0dc88db70fec873587" +dependencies = [ + "manyhow-macros", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "manyhow-macros" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46fce34d199b78b6e6073abf984c9cf5fd3e9330145a93ee0738a7443e371495" +dependencies = [ + "proc-macro-utils", + "proc-macro2", + "quote", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro-utils" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eeaf08a13de400bc215877b5bdc088f241b12eb42f0a548d3390dc1c56bb7071" +dependencies = [ + "proc-macro2", + "quote", + "smallvec", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.25.11+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" +dependencies = [ + "indexmap", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "winnow" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +dependencies = [ + "memchr", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/packages/bub-extism/examples/rust-model-stream/Cargo.toml b/packages/bub-extism/examples/rust-model-stream/Cargo.toml new file mode 100644 index 0000000..60ce08e --- /dev/null +++ b/packages/bub-extism/examples/rust-model-stream/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "bub-extism-rust-model-stream" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +extism-pdk = "1.4.1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/packages/bub-extism/examples/rust-model-stream/src/lib.rs b/packages/bub-extism/examples/rust-model-stream/src/lib.rs new file mode 100644 index 0000000..2540dbe --- /dev/null +++ b/packages/bub-extism/examples/rust-model-stream/src/lib.rs @@ -0,0 +1,48 @@ +use extism_pdk::{plugin_fn, FnResult}; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Deserialize)] +struct Request { + hook: String, + args: Args, +} + +#[derive(Deserialize)] +struct Args { + prompt: Value, + session_id: String, +} + +#[plugin_fn] +pub fn run_model_stream(input: String) -> FnResult { + let request: Request = serde_json::from_str(&input)?; + if request.hook != "run_model_stream" { + return Ok(json!({ "skip": true }).to_string()); + } + + let prompt = match request.args.prompt { + Value::String(value) => value, + other => other.to_string(), + }; + let text = format!("[rust-model-stream:{}] {}", request.args.session_id, prompt); + + Ok(json!({ + "value": { + "events": [ + { + "kind": "text", + "data": { "delta": text } + }, + { + "kind": "final", + "data": { "text": text } + } + ], + "usage": { + "output_tokens": text.split_whitespace().count() + } + } + }) + .to_string()) +} diff --git a/packages/bub-extism/pyproject.toml b/packages/bub-extism/pyproject.toml new file mode 100644 index 0000000..2349c0c --- /dev/null +++ b/packages/bub-extism/pyproject.toml @@ -0,0 +1,28 @@ +[project] +name = "bub-extism" +version = "0.1.0" +description = "Extism WebAssembly bridge plugin for Bub" +readme = "README.md" +authors = [ + { name = "Bub Build contributors" } +] +requires-python = ">=3.12" +dependencies = [ + "extism>=1.1.1,<1.2.0", + "extism-sys>=1.12.0,<1.13.0", + "pydantic>=2.10.0", + "pydantic-settings>=2.10.1", +] + +[project.entry-points.bub] +extism = "bub_extism.plugin:ExtismPlugin" + +[build-system] +requires = ["uv_build>=0.10.4,<0.11.0"] +build-backend = "uv_build" + +[dependency-groups] +dev = [ + "pytest>=9.0.3", + "pytest-asyncio>=1.3.0", +] diff --git a/packages/bub-extism/src/bub_extism/__init__.py b/packages/bub-extism/src/bub_extism/__init__.py new file mode 100644 index 0000000..53e5f08 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/__init__.py @@ -0,0 +1,7 @@ +"""Extism bridge plugin for Bub.""" + +from bub_extism.bridge import ExtismBridge +from bub_extism.config import ExtismPluginConfig, ExtismSettings +from bub_extism.plugin import ExtismPlugin + +__all__ = ["ExtismBridge", "ExtismPlugin", "ExtismPluginConfig", "ExtismSettings"] diff --git a/packages/bub-extism/src/bub_extism/bridge.py b/packages/bub-extism/src/bub_extism/bridge.py new file mode 100644 index 0000000..2c64c88 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/bridge.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any + +from bub_extism.codec import ExtismHookSkip, build_request, decode_response +from bub_extism.config import ExtismPluginConfig, ExtismSettings + + +class ExtismBridge: + def __init__(self, settings: ExtismSettings) -> None: + self.settings = settings + + def selected_config(self) -> ExtismPluginConfig | None: + return self.settings.read_config().selected_plugin() + + def function_name(self, hook_name: str) -> str | None: + config = self.selected_config() + if config is None: + return None + return getattr(config.hooks, hook_name) + + def call_hook_sync( + self, + hook_name: str, + args: dict[str, Any], + *, + config: ExtismPluginConfig | None = None, + function_name: str | None = None, + ) -> Any: + selected = config or self.selected_config() + if selected is None: + return None + + export_name = function_name or getattr(selected.hooks, hook_name) + if export_name is None: + return None + + try: + return self._call_export(selected, export_name, hook_name, args) + except ExtismHookSkip: + return None + + async def call_hook( + self, + hook_name: str, + args: dict[str, Any], + *, + config: ExtismPluginConfig | None = None, + function_name: str | None = None, + ) -> Any: + return await asyncio.to_thread( + self.call_hook_sync, + hook_name, + args, + config=config, + function_name=function_name, + ) + + def _call_export( + self, + config: ExtismPluginConfig, + function_name: str, + hook_name: str, + args: dict[str, Any], + ) -> Any: + extism = _import_extism() + request = build_request(hook_name, args) + with extism.Plugin( + config.plugin_input(), + wasi=config.wasi, + config=config.config or None, + ) as plugin: + if hasattr(plugin, "function_exists") and not plugin.function_exists(function_name): + raise ExtismHookSkip + + raw_result = plugin.call( + function_name, + json.dumps(request, ensure_ascii=False), + ) + return decode_response(raw_result, hook_name=hook_name) + + +def _import_extism() -> Any: + try: + import extism + except ImportError as exc: + raise RuntimeError("bub-extism requires the 'extism' runtime package") from exc + return extism diff --git a/packages/bub-extism/src/bub_extism/channel.py b/packages/bub-extism/src/bub_extism/channel.py new file mode 100644 index 0000000..858150f --- /dev/null +++ b/packages/bub-extism/src/bub_extism/channel.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable +from typing import Any + +from bub.channels import Channel +from bub.types import Envelope, MessageHandler +from republic import StreamEvent + +from bub_extism.bridge import ExtismBridge +from bub_extism.codec import to_json_value +from bub_extism.config import ExtismPluginConfig + + +class ExtismChannel(Channel): + def __init__( + self, + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptor: dict[str, Any], + message_handler: MessageHandler, + ) -> None: + self.bridge = bridge + self.config = config + self.descriptor = descriptor + self.name = str(descriptor["name"]) + self._message_handler = message_handler + self._functions = dict(descriptor.get("functions") or {}) + self._poll_interval_seconds = float(descriptor.get("pollIntervalSeconds", 1.0)) + + @property + def enabled(self) -> bool: + return bool(self.descriptor.get("enabled", True)) + + @property + def needs_debounce(self) -> bool: + return bool(self.descriptor.get("needsDebounce", False)) + + async def start(self, stop_event: asyncio.Event) -> None: + await self._call("start", {}) + if "poll" not in self._functions: + await stop_event.wait() + return + + while not stop_event.is_set(): + messages = await self._call("poll", {}) + for message in _messages_from_value(messages): + await self._message_handler(message) + try: + await asyncio.wait_for(stop_event.wait(), timeout=self._poll_interval_seconds) + except TimeoutError: + continue + + async def stop(self) -> None: + await self._call("stop", {}) + + async def send(self, message: Envelope) -> None: + await self._call("send", {"message": to_json_value(message)}) + + def stream_events( + self, + message: Envelope, + stream: AsyncIterable[StreamEvent], + ) -> AsyncIterable[StreamEvent]: + return stream + + async def _call(self, operation: str, args: dict[str, Any]) -> Any: + function_name = self._functions.get(operation) + if not isinstance(function_name, str) or not function_name: + return None + return await self.bridge.call_hook( + f"channel.{operation}", + {"channel": self.name, **args}, + config=self.config, + function_name=function_name, + ) + + +def channels_from_descriptors( + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptors: Any, + message_handler: MessageHandler, +) -> list[ExtismChannel]: + if descriptors is None: + return [] + if isinstance(descriptors, dict): + descriptors = descriptors.get("channels", []) + if not isinstance(descriptors, list): + raise RuntimeError("Extism provide_channels must return a list of channel descriptors") + + channels: list[ExtismChannel] = [] + for descriptor in descriptors: + if not isinstance(descriptor, dict) or not descriptor.get("name"): + raise RuntimeError("Extism channel descriptor must include a name") + channels.append(ExtismChannel(bridge, config, descriptor, message_handler)) + return channels + + +def _messages_from_value(value: Any) -> list[Envelope]: + if value is None: + return [] + if isinstance(value, dict): + value = value.get("messages", [value]) + if not isinstance(value, list): + raise RuntimeError("Extism channel poll must return a message or message list") + return value diff --git a/packages/bub-extism/src/bub_extism/cli.py b/packages/bub-extism/src/bub_extism/cli.py new file mode 100644 index 0000000..8612b2d --- /dev/null +++ b/packages/bub-extism/src/bub_extism/cli.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import json +from typing import Any + +import typer + +from bub_extism.bridge import ExtismBridge +from bub_extism.config import ExtismPluginConfig + + +def register_cli_commands( + app: typer.Typer, + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptors: Any, +) -> None: + if descriptors is None: + return + if isinstance(descriptors, dict): + descriptors = descriptors.get("commands", []) + if not isinstance(descriptors, list): + raise RuntimeError("Extism register_cli_commands must return a list") + + group = typer.Typer(help="Commands provided by Extism WebAssembly plugins.") + for descriptor in descriptors: + if not isinstance(descriptor, dict): + raise RuntimeError("Extism CLI command descriptor must be an object") + name = str(descriptor.get("name", "")).strip() + function_name = str(descriptor.get("function", "")).strip() + if not name or not function_name: + raise RuntimeError("Extism CLI command descriptor requires name and function") + help_text = str(descriptor.get("help", "Run an Extism command.")) + group.command(name, help=help_text)(_make_command(bridge, config, name, function_name)) + + app.add_typer(group, name="extism") + + +def _make_command( + bridge: ExtismBridge, + config: ExtismPluginConfig, + command_name: str, + function_name: str, +): + def command(payload: str = typer.Argument("{}", help="JSON payload for the command.")) -> None: + try: + args = json.loads(payload) + except json.JSONDecodeError as exc: + raise typer.BadParameter("payload must be valid JSON") from exc + result = bridge.call_hook_sync( + "cli_command", + {"command": command_name, "payload": args}, + config=config, + function_name=function_name, + ) + if result is not None: + typer.echo(json.dumps(result, ensure_ascii=False, indent=2)) + + return command diff --git a/packages/bub-extism/src/bub_extism/codec.py b/packages/bub-extism/src/bub_extism/codec.py new file mode 100644 index 0000000..f53cfc6 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/codec.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from typing import Any + +from bub.envelope import normalize_envelope +from republic import StreamEvent, TapeEntry +from republic.tape.entries import utc_now + +BUB_EXTISM_ABI_VERSION = "bub.extism.v1" + + +class ExtismHookError(RuntimeError): + pass + + +class ExtismHookSkip(Exception): + pass + + +def build_request(hook_name: str, args: dict[str, Any]) -> dict[str, Any]: + return { + "abi_version": BUB_EXTISM_ABI_VERSION, + "hook": hook_name, + "args": to_json_value(args), + } + + +def decode_response(raw_result: Any, *, hook_name: str) -> Any: + if raw_result is None: + raise ExtismHookSkip + + text = result_to_text(raw_result) + if not text: + raise ExtismHookSkip + + try: + parsed = json.loads(text) + except json.JSONDecodeError: + return text + + if parsed is None: + raise ExtismHookSkip + if not isinstance(parsed, dict): + return parsed + + if parsed.get("skip") is True: + raise ExtismHookSkip + if error := parsed.get("error"): + if isinstance(error, dict): + message = error.get("message", "Extism hook returned an error") + else: + message = str(error) + raise ExtismHookError(str(message)) + if "value" in parsed: + return parsed["value"] + if hook_name in parsed: + return parsed[hook_name] + if "text" in parsed: + return parsed["text"] + return parsed + + +def result_to_text(raw_result: Any) -> str: + if isinstance(raw_result, str): + return raw_result + if isinstance(raw_result, bytes): + return raw_result.decode("utf-8") + if isinstance(raw_result, bytearray): + return bytes(raw_result).decode("utf-8") + if isinstance(raw_result, memoryview): + return raw_result.tobytes().decode("utf-8") + return bytes(raw_result).decode("utf-8") + + +def to_json_value(value: Any) -> Any: + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, dict): + return { + str(key): to_json_value(item) + for key, item in value.items() + if is_json_safe(item) + } + if isinstance(value, list | tuple): + return [to_json_value(item) for item in value if is_json_safe(item)] + if isinstance(value, StreamEvent): + return {"kind": value.kind, "data": to_json_value(value.data)} + if isinstance(value, TapeEntry): + return tape_entry_to_dict(value) + if is_dataclass(value): + return to_json_value(asdict(value)) + if hasattr(value, "__dict__"): + return to_json_value(normalize_envelope(value)) + return str(value) + + +def is_json_safe(value: Any) -> bool: + try: + json.dumps(to_json_value(value)) + except (TypeError, ValueError, RecursionError): + return False + return True + + +def state_to_json(state: dict[str, Any]) -> dict[str, Any]: + safe_state: dict[str, Any] = {} + for key, value in state.items(): + if str(key).startswith("_runtime_"): + continue + try: + json.dumps(value) + except (TypeError, ValueError): + continue + safe_state[str(key)] = to_json_value(value) + return safe_state + + +def tape_entry_to_dict(entry: TapeEntry) -> dict[str, Any]: + return { + "id": entry.id, + "kind": entry.kind, + "payload": to_json_value(entry.payload), + "meta": to_json_value(entry.meta), + "date": entry.date, + } + + +def tape_entry_from_dict(value: dict[str, Any]) -> TapeEntry: + return TapeEntry( + id=int(value.get("id", 0)), + kind=str(value.get("kind", "event")), + payload=dict(value.get("payload") or {}), + meta=dict(value.get("meta") or {}), + date=str(value.get("date", "")) or utc_now(), + ) diff --git a/packages/bub-extism/src/bub_extism/config.py b/packages/bub-extism/src/bub_extism/config.py new file mode 100644 index 0000000..38af296 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/config.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +def default_config_path() -> Path: + from bub.builtin.settings import load_settings + + return load_settings().home / "extism.json" + + +class ExtismHookMap(BaseModel): + resolve_session: str | None = None + build_prompt: str | None = None + run_model: str | None = None + run_model_stream: str | None = None + load_state: str | None = None + save_state: str | None = None + render_outbound: str | None = None + dispatch_outbound: str | None = None + register_cli_commands: str | None = None + onboard_config: str | None = None + on_error: str | None = None + system_prompt: str | None = None + provide_tape_store: str | None = None + provide_channels: str | None = None + build_tape_context: str | None = None + + +class ExtismPluginConfig(BaseModel): + manifest: dict[str, Any] | None = None + wasm_path: Path | None = Field(default=None, alias="wasmPath") + wasm_url: str | None = Field(default=None, alias="wasmUrl") + hooks: ExtismHookMap = Field(default_factory=ExtismHookMap) + config: dict[str, str] = Field(default_factory=dict) + wasi: bool = False + + @model_validator(mode="after") + def validate_wasm_source(self) -> ExtismPluginConfig: + sources = [ + self.manifest is not None, + self.wasm_path is not None, + self.wasm_url is not None, + ] + if sum(sources) != 1: + raise ValueError("exactly one of manifest, wasmPath, or wasmUrl is required") + return self + + def plugin_input(self) -> dict[str, Any] | bytes: + if self.manifest is not None: + return self.manifest + if self.wasm_url is not None: + return {"wasm": [{"url": self.wasm_url}]} + if self.wasm_path is None: + raise RuntimeError("wasmPath is required") + return self.wasm_path.expanduser().read_bytes() + + +class ExtismConfig(BaseModel): + default_plugin: str | None = Field(default=None, alias="defaultPlugin") + plugins: dict[str, ExtismPluginConfig] = Field(default_factory=dict) + + def selected_plugin(self) -> ExtismPluginConfig | None: + if self.default_plugin is None: + return None + return self.plugins.get(self.default_plugin) + + +class ExtismSettings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="BUB_EXTISM_", extra="ignore") + + config_path: Path = Field(default_factory=default_config_path) + + def read_config(self) -> ExtismConfig: + if not self.config_path.exists(): + return ExtismConfig() + raw = json.loads(self.config_path.read_text(encoding="utf-8")) + if not isinstance(raw, dict): + raise RuntimeError("Extism config file must contain a top-level mapping") + return ExtismConfig.model_validate(raw) diff --git a/packages/bub-extism/src/bub_extism/plugin.py b/packages/bub-extism/src/bub_extism/plugin.py new file mode 100644 index 0000000..00c4882 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/plugin.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +from bub import hookimpl +from bub_extism.bridge import ExtismBridge +from bub_extism.channel import channels_from_descriptors +from bub_extism.cli import register_cli_commands +from bub_extism.codec import state_to_json, to_json_value +from bub_extism.config import ExtismSettings +from bub_extism.stream import stream_events_from_value +from bub_extism.tape_store import tape_store_from_descriptor +from republic import AsyncStreamEvents, TapeContext +from republic.tape.context import LAST_ANCHOR + +if TYPE_CHECKING: + import typer + from bub.channels import Channel + from bub.framework import BubFramework + from bub.types import Envelope, MessageHandler, State + from republic.tape import TapeStore + + +class ExtismPlugin: + def __init__(self, framework: BubFramework) -> None: + self.framework = framework + self.settings = ExtismSettings() + self.bridge = ExtismBridge(self.settings) + self._register_model_hook_adapter() + + def _register_model_hook_adapter(self) -> None: + config = self.bridge.selected_config() + if config is None: + return + + plugin_manager = getattr(self.framework, "_plugin_manager", None) + if plugin_manager is None: + return + + if config.hooks.run_model_stream is not None: + plugin_manager.register( + _ExtismRunModelStreamPlugin(self.bridge), + name="extism-run-model-stream", + ) + return + + if config.hooks.run_model is not None: + plugin_manager.register( + _ExtismRunModelPlugin(self.bridge), + name="extism-run-model", + ) + + @hookimpl + def resolve_session(self, message: Envelope) -> str | None: + value = self.bridge.call_hook_sync( + "resolve_session", + {"message": to_json_value(message)}, + ) + if value is None: + return None + return str(value) + + @hookimpl + async def build_prompt( + self, + message: Envelope, + session_id: str, + state: State, + ) -> str | list[dict[str, Any]] | None: + value = await self.bridge.call_hook( + "build_prompt", + { + "message": to_json_value(message), + "session_id": session_id, + "state": state_to_json(state), + }, + ) + if value is None: + return None + if isinstance(value, str | list): + return cast(str | list[dict[str, Any]], value) + raise RuntimeError("Extism build_prompt must return a string or content-part list") + + @hookimpl + async def load_state(self, message: Envelope, session_id: str) -> State | None: + value = await self.bridge.call_hook( + "load_state", + {"message": to_json_value(message), "session_id": session_id}, + ) + if value is None: + return None + if not isinstance(value, dict): + raise RuntimeError("Extism load_state must return an object") + return value + + @hookimpl + async def save_state( + self, + session_id: str, + state: State, + message: Envelope, + model_output: str, + ) -> None: + await self.bridge.call_hook( + "save_state", + { + "session_id": session_id, + "state": state_to_json(state), + "message": to_json_value(message), + "model_output": model_output, + }, + ) + + @hookimpl + def render_outbound( + self, + message: Envelope, + session_id: str, + state: State, + model_output: str, + ) -> list[Envelope]: + value = self.bridge.call_hook_sync( + "render_outbound", + { + "message": to_json_value(message), + "session_id": session_id, + "state": state_to_json(state), + "model_output": model_output, + }, + ) + if value is None: + return [] + if isinstance(value, dict): + return [value] + if isinstance(value, list): + return value + raise RuntimeError("Extism render_outbound must return an envelope or envelope list") + + @hookimpl + async def dispatch_outbound(self, message: Envelope) -> bool: + value = await self.bridge.call_hook( + "dispatch_outbound", + {"message": to_json_value(message)}, + ) + return bool(value) + + @hookimpl + def register_cli_commands(self, app: typer.Typer) -> None: + config = self.bridge.selected_config() + if config is None or config.hooks.register_cli_commands is None: + return + descriptors = self.bridge.call_hook_sync("register_cli_commands", {"commands": []}, config=config) + register_cli_commands(app, self.bridge, config, descriptors) + + @hookimpl + def onboard_config(self, current_config: dict[str, Any]) -> dict[str, Any] | None: + value = self.bridge.call_hook_sync( + "onboard_config", + {"current_config": to_json_value(current_config)}, + ) + if value is None: + return None + if not isinstance(value, dict): + raise RuntimeError("Extism onboard_config must return an object") + return value + + @hookimpl + async def on_error(self, stage: str, error: Exception, message: Envelope | None) -> None: + await self.bridge.call_hook( + "on_error", + { + "stage": stage, + "error": { + "type": type(error).__name__, + "message": str(error), + }, + "message": to_json_value(message), + }, + ) + + @hookimpl + def system_prompt(self, prompt: str | list[dict[str, Any]], state: State) -> str | None: + value = self.bridge.call_hook_sync( + "system_prompt", + {"prompt": prompt, "state": state_to_json(state)}, + ) + if value is None: + return None + if not isinstance(value, str): + raise RuntimeError("Extism system_prompt must return a string") + return value + + @hookimpl + def provide_tape_store(self) -> TapeStore | None: + config = self.bridge.selected_config() + if config is None or config.hooks.provide_tape_store is None: + return None + descriptor = self.bridge.call_hook_sync("provide_tape_store", {}, config=config) + return tape_store_from_descriptor(self.bridge, config, descriptor) + + @hookimpl + def provide_channels(self, message_handler: MessageHandler) -> list[Channel]: + config = self.bridge.selected_config() + if config is None or config.hooks.provide_channels is None: + return [] + descriptors = self.bridge.call_hook_sync("provide_channels", {}, config=config) + return channels_from_descriptors(self.bridge, config, descriptors, message_handler) + + @hookimpl + def build_tape_context(self) -> TapeContext | None: + value = self.bridge.call_hook_sync("build_tape_context", {}) + if value is None: + return None + if not isinstance(value, dict): + raise RuntimeError("Extism build_tape_context must return an object") + + anchor_value = value.get("anchor", "last") + if anchor_value is None: + anchor = None + elif str(anchor_value).lower() in {"last", "last_anchor"}: + anchor = LAST_ANCHOR + else: + anchor = str(anchor_value) + + state = value.get("state", {}) + if not isinstance(state, dict): + raise RuntimeError("Extism build_tape_context state must be an object") + return TapeContext(anchor=anchor, state=state) + + +class _ExtismRunModelPlugin: + def __init__(self, bridge: ExtismBridge) -> None: + self.bridge = bridge + + @hookimpl + async def run_model( + self, + prompt: str | list[dict[str, Any]], + session_id: str, + state: State, + ) -> str | None: + value = await self.bridge.call_hook( + "run_model", + { + "prompt": prompt, + "session_id": session_id, + "state": state_to_json(state), + }, + ) + if value is None: + return None + if not isinstance(value, str): + raise RuntimeError("Extism run_model must return a string") + return value + + +class _ExtismRunModelStreamPlugin: + def __init__(self, bridge: ExtismBridge) -> None: + self.bridge = bridge + + @hookimpl + async def run_model_stream( + self, + prompt: str | list[dict[str, Any]], + session_id: str, + state: State, + ) -> AsyncStreamEvents | None: + value = await self.bridge.call_hook( + "run_model_stream", + { + "prompt": prompt, + "session_id": session_id, + "state": state_to_json(state), + }, + ) + return stream_events_from_value(value) diff --git a/packages/bub-extism/src/bub_extism/py.typed b/packages/bub-extism/src/bub_extism/py.typed new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/py.typed @@ -0,0 +1 @@ + diff --git a/packages/bub-extism/src/bub_extism/stream.py b/packages/bub-extism/src/bub_extism/stream.py new file mode 100644 index 0000000..f39e52e --- /dev/null +++ b/packages/bub-extism/src/bub_extism/stream.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from republic import AsyncStreamEvents, StreamEvent, StreamState + + +def stream_events_from_value(value: Any) -> AsyncStreamEvents | None: + if value is None: + return None + + events_value = value + state = StreamState() + if isinstance(value, dict): + events_value = value.get("events", []) + usage = value.get("usage") + if isinstance(usage, dict): + state.usage = usage + + if not isinstance(events_value, list): + raise RuntimeError("Extism run_model_stream must return a list of stream events") + + events = [_stream_event_from_dict(item) for item in events_value] + + async def iterator() -> AsyncIterator[StreamEvent]: + for event in events: + yield event + + return AsyncStreamEvents(iterator(), state=state) + + +def _stream_event_from_dict(value: Any) -> StreamEvent: + if not isinstance(value, dict): + raise RuntimeError("Extism stream event must be a JSON object") + kind = value.get("kind") + data = value.get("data", {}) + if not isinstance(kind, str): + raise RuntimeError("Extism stream event must include a string kind") + if not isinstance(data, dict): + raise RuntimeError("Extism stream event data must be a JSON object") + return StreamEvent(kind, data) diff --git a/packages/bub-extism/src/bub_extism/tape_store.py b/packages/bub-extism/src/bub_extism/tape_store.py new file mode 100644 index 0000000..bc6d22d --- /dev/null +++ b/packages/bub-extism/src/bub_extism/tape_store.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +from republic import TapeEntry + +from bub_extism.bridge import ExtismBridge +from bub_extism.codec import tape_entry_from_dict, tape_entry_to_dict, to_json_value +from bub_extism.config import ExtismPluginConfig + +if TYPE_CHECKING: + from republic import TapeQuery + + +class ExtismTapeStore: + def __init__( + self, + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptor: dict[str, Any], + ) -> None: + self.bridge = bridge + self.config = config + self.descriptor = descriptor + self.functions = dict(descriptor.get("functions") or {}) + + def list_tapes(self) -> list[str]: + value = self._call("list_tapes", {}) + if value is None: + return [] + if not isinstance(value, list): + raise RuntimeError("Extism tape list_tapes must return a list") + return [str(item) for item in value] + + def reset(self, tape: str) -> None: + self._call("reset", {"tape": tape}) + + def fetch_all(self, query: TapeQuery) -> Iterable[TapeEntry]: + value = self._call("fetch_all", {"query": _query_to_dict(query)}) + if value is None: + return [] + if not isinstance(value, list): + raise RuntimeError("Extism tape fetch_all must return a list") + return [tape_entry_from_dict(item) for item in value if isinstance(item, dict)] + + def append(self, tape: str, entry: TapeEntry) -> None: + self._call("append", {"tape": tape, "entry": tape_entry_to_dict(entry)}) + + def _call(self, operation: str, args: dict[str, Any]) -> Any: + function_name = self.functions.get(operation) + if not isinstance(function_name, str) or not function_name: + raise RuntimeError(f"Extism tape store does not define '{operation}'") + return self.bridge.call_hook_sync( + f"tape_store.{operation}", + args, + config=self.config, + function_name=function_name, + ) + + +def tape_store_from_descriptor( + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptor: Any, +) -> ExtismTapeStore | None: + if descriptor is None: + return None + if not isinstance(descriptor, dict): + raise RuntimeError("Extism provide_tape_store must return a descriptor object") + return ExtismTapeStore(bridge, config, descriptor) + + +def _query_to_dict(query: TapeQuery) -> dict[str, Any]: + return { + "tape": query.tape, + "query": query._query, + "after_anchor": query._after_anchor, + "after_last": query._after_last, + "between_anchors": to_json_value(query._between_anchors), + "between_dates": to_json_value(query._between_dates), + "kinds": list(query._kinds), + "limit": query._limit, + } diff --git a/packages/bub-extism/tests/test_bridge.py b/packages/bub-extism/tests/test_bridge.py new file mode 100644 index 0000000..6eca477 --- /dev/null +++ b/packages/bub-extism/tests/test_bridge.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest +import pluggy +from republic import TapeEntry, TapeQuery + +from bub.hook_runtime import HookRuntime +from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs +from bub_extism.bridge import ExtismBridge +from bub_extism.config import ExtismSettings +from bub_extism.plugin import ExtismPlugin + + +class FakePlugin: + calls: list[dict[str, Any]] = [] + exports: dict[str, Any] = {} + + def __init__( + self, + plugin_input: dict[str, Any] | bytes, + *, + wasi: bool = False, + config: dict[str, str] | None = None, + ) -> None: + self.plugin_input = plugin_input + self.wasi = wasi + self.config = config + + def __enter__(self) -> FakePlugin: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def function_exists(self, name: str) -> bool: + return name in self.exports + + def call(self, function_name: str, data: str) -> Any: + payload = json.loads(data) + self.calls.append( + { + "function_name": function_name, + "payload": payload, + "plugin_input": self.plugin_input, + "wasi": self.wasi, + "config": self.config, + } + ) + result = self.exports[function_name] + if callable(result): + return result(payload) + return result + + +@pytest.fixture(autouse=True) +def fake_extism(monkeypatch): + FakePlugin.calls = [] + FakePlugin.exports = {} + monkeypatch.setitem(sys.modules, "extism", SimpleNamespace(Plugin=FakePlugin)) + + +def _write_config(tmp_path: Path, body: dict[str, Any]) -> Path: + config_path = tmp_path / "extism.json" + config_path.write_text(json.dumps(body), encoding="utf-8") + return config_path + + +def _bridge(config_path: Path) -> ExtismBridge: + return ExtismBridge(ExtismSettings(config_path=config_path)) + + +def _plugin(config_path: Path) -> ExtismPlugin: + plugin = ExtismPlugin(SimpleNamespace()) + plugin.bridge = _bridge(config_path) + return plugin + + +def _runtime(config_path: Path, monkeypatch: pytest.MonkeyPatch) -> HookRuntime: + monkeypatch.setenv("BUB_EXTISM_CONFIG_PATH", str(config_path)) + plugin_manager = pluggy.PluginManager(BUB_HOOK_NAMESPACE) + plugin_manager.add_hookspecs(BubHookSpecs) + framework = SimpleNamespace(_plugin_manager=plugin_manager) + plugin = ExtismPlugin(framework) + plugin_manager.register(plugin, name="extism") + return HookRuntime(plugin_manager) + + +def test_plugin_exposes_all_non_model_standard_bub_hooks() -> None: + expected_hooks = { + "resolve_session", + "build_prompt", + "load_state", + "save_state", + "render_outbound", + "dispatch_outbound", + "register_cli_commands", + "onboard_config", + "on_error", + "system_prompt", + "provide_tape_store", + "provide_channels", + "build_tape_context", + } + + assert expected_hooks <= set(dir(ExtismPlugin)) + + +def test_model_hook_adapter_registers_only_one_model_surface( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "model", + "plugins": { + "model": { + "wasmUrl": "https://example.com/model.wasm", + "hooks": { + "run_model": "run_model", + "run_model_stream": "run_model_stream", + }, + } + }, + }, + ) + + runtime = _runtime(config_path, monkeypatch) + + report = runtime.hook_report() + assert report["run_model_stream"] == ["extism-run-model-stream"] + assert "run_model" not in report + + +def test_call_hook_returns_none_without_selected_plugin(tmp_path: Path) -> None: + config_path = _write_config(tmp_path, {"plugins": {}}) + + result = _bridge(config_path).call_hook_sync("run_model", {"prompt": "hello"}) + + assert result is None + assert FakePlugin.calls == [] + + +def test_run_model_calls_configured_export_with_unified_request( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + wasm_path = tmp_path / "plugin.wasm" + wasm_path.write_bytes(b"\0asm") + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "echo", + "plugins": { + "echo": { + "wasmPath": str(wasm_path), + "wasi": True, + "config": {"model": "demo"}, + "hooks": {"run_model": "bub_run_model"}, + } + }, + }, + ) + FakePlugin.exports = { + "bub_run_model": lambda request: json.dumps( + { + "value": ( + f"echo:{request['args']['session_id']}:" + f"{request['args']['prompt']}:" + f"{sorted(request['args']['state'])}" + ) + } + ) + } + + result = asyncio.run( + _runtime(config_path, monkeypatch).run_model( + prompt="hello", + session_id="s1", + state={ + "visible": {"ok": True}, + "_runtime_agent": object(), + "not_json": object(), + }, + ) + ) + + assert result == "echo:s1:hello:['visible']" + assert FakePlugin.calls == [ + { + "function_name": "bub_run_model", + "payload": { + "abi_version": "bub.extism.v1", + "hook": "run_model", + "args": { + "prompt": "hello", + "session_id": "s1", + "state": {"visible": {"ok": True}}, + }, + }, + "plugin_input": b"\0asm", + "wasi": True, + "config": {"model": "demo"}, + } + ] + + +def test_system_prompt_accepts_plain_text_result(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "prompt", + "plugins": { + "prompt": { + "wasmUrl": "https://example.com/prompt.wasm", + "hooks": {"system_prompt": "system_prompt"}, + } + }, + }, + ) + FakePlugin.exports = {"system_prompt": b"from wasm"} + + result = _plugin(config_path).system_prompt("hello", {"session_id": "s1"}) + + assert result == "from wasm" + assert FakePlugin.calls[0]["plugin_input"] == { + "wasm": [{"url": "https://example.com/prompt.wasm"}] + } + + +def test_missing_export_skips_hook(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "missing", + "plugins": { + "missing": { + "manifest": {"wasm": [{"url": "https://example.com/plugin.wasm"}]}, + "hooks": {"run_model": "missing_run_model"}, + } + }, + }, + ) + + result = asyncio.run( + _runtime(config_path, monkeypatch).run_model( + prompt="hello", + session_id="s1", + state={}, + ) + ) + + assert result is None + assert FakePlugin.calls == [] + + +def test_run_model_stream_wraps_returned_events( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "stream", + "plugins": { + "stream": { + "wasmUrl": "https://example.com/stream.wasm", + "hooks": {"run_model_stream": "run_model_stream"}, + } + }, + }, + ) + FakePlugin.exports = { + "run_model_stream": json.dumps( + { + "value": { + "events": [ + {"kind": "text", "data": {"delta": "hello"}}, + {"kind": "final", "data": {"text": "hello"}}, + ], + "usage": {"output_tokens": 1}, + } + } + ) + } + + stream = asyncio.run( + _runtime(config_path, monkeypatch).run_model_stream( + prompt="hello", + session_id="s1", + state={}, + ) + ) + assert stream is not None + events = asyncio.run(_collect_stream(stream)) + assert [(event.kind, event.data) for event in events] == [ + ("text", {"delta": "hello"}), + ("final", {"text": "hello"}), + ] + assert stream.usage == {"output_tokens": 1} + + +async def _collect_stream(stream): + return [event async for event in stream] + + +def test_tape_store_proxy_forwards_operations(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "tape", + "plugins": { + "tape": { + "wasmUrl": "https://example.com/tape.wasm", + "hooks": {"provide_tape_store": "provide_tape_store"}, + } + }, + }, + ) + FakePlugin.exports = { + "provide_tape_store": json.dumps( + { + "value": { + "functions": { + "list_tapes": "list_tapes", + "fetch_all": "fetch_all", + "append": "append", + "reset": "reset", + } + } + } + ), + "list_tapes": json.dumps({"value": ["main"]}), + "fetch_all": json.dumps( + { + "value": [ + { + "id": 1, + "kind": "message", + "payload": {"role": "user", "content": "hello"}, + "meta": {}, + "date": "2026-04-26T00:00:00+00:00", + } + ] + } + ), + "append": json.dumps({"skip": True}), + "reset": json.dumps({"skip": True}), + } + + store = _plugin(config_path).provide_tape_store() + + assert store is not None + assert store.list_tapes() == ["main"] + entries = list(store.fetch_all(TapeQuery("main", store))) + assert entries == [ + TapeEntry( + id=1, + kind="message", + payload={"role": "user", "content": "hello"}, + meta={}, + date="2026-04-26T00:00:00+00:00", + ) + ] + store.append("main", TapeEntry.message({"role": "assistant", "content": "ok"})) + store.reset("main") + + +def test_channel_proxy_forwards_send(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "channel", + "plugins": { + "channel": { + "wasmUrl": "https://example.com/channel.wasm", + "hooks": {"provide_channels": "provide_channels"}, + } + }, + }, + ) + FakePlugin.exports = { + "provide_channels": json.dumps( + { + "value": [ + { + "name": "wasm", + "functions": { + "send": "channel_send", + }, + } + ] + } + ), + "channel_send": json.dumps({"value": True}), + } + + async def handler(message: dict[str, Any]) -> None: + del message + + channels = _plugin(config_path).provide_channels(handler) + + assert [channel.name for channel in channels] == ["wasm"] + asyncio.run(channels[0].send({"content": "hello"})) + assert FakePlugin.calls[-1]["function_name"] == "channel_send" + assert FakePlugin.calls[-1]["payload"]["hook"] == "channel.send" diff --git a/packages/bub-extism/tests/test_examples.py b/packages/bub-extism/tests/test_examples.py new file mode 100644 index 0000000..3f7f6eb --- /dev/null +++ b/packages/bub-extism/tests/test_examples.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import asyncio +import json +import os +import shutil +import subprocess +from pathlib import Path +from typing import Any + +import pytest + +from bub_extism.bridge import ExtismBridge +from bub_extism.config import ExtismSettings +from bub_extism.plugin import ExtismPlugin + +PACKAGE_ROOT = Path(__file__).resolve().parents[1] +RUST_EXAMPLE = PACKAGE_ROOT / "examples" / "rust-model-stream" +GO_EXAMPLE = PACKAGE_ROOT / "examples" / "go-channel" + + +def _write_config(tmp_path: Path, body: dict[str, Any]) -> Path: + config_path = tmp_path / "extism.json" + config_path.write_text(json.dumps(body), encoding="utf-8") + return config_path + + +def _plugin(config_path: Path) -> ExtismPlugin: + plugin = ExtismPlugin(type("Framework", (), {})()) + plugin.bridge = ExtismBridge(ExtismSettings(config_path=config_path)) + return plugin + + +@pytest.mark.skipif(shutil.which("cargo") is None, reason="cargo is not installed") +def test_rust_model_stream_example_builds_and_runs(tmp_path: Path) -> None: + subprocess.run( + ["cargo", "build", "--release", "--target", "wasm32-unknown-unknown"], + cwd=RUST_EXAMPLE, + check=True, + ) + wasm_path = ( + RUST_EXAMPLE + / "target" + / "wasm32-unknown-unknown" + / "release" + / "bub_extism_rust_model_stream.wasm" + ) + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "rust", + "plugins": { + "rust": { + "wasmPath": str(wasm_path), + "hooks": {"run_model_stream": "run_model_stream"}, + } + }, + }, + ) + + async def run_example() -> list[tuple[str, dict[str, Any]]]: + stream = await _plugin(config_path).bridge.call_hook( + "run_model_stream", + { + "prompt": "hello from bub", + "session_id": "example", + "state": {}, + }, + ) + from bub_extism.stream import stream_events_from_value + + events = stream_events_from_value(stream) + assert events is not None + return [(event.kind, event.data) async for event in events] + + assert asyncio.run(run_example()) == [ + ("text", {"delta": "[rust-model-stream:example] hello from bub"}), + ("final", {"text": "[rust-model-stream:example] hello from bub"}), + ] + + +@pytest.mark.skipif(shutil.which("go") is None, reason="go is not installed") +def test_go_channel_example_builds_and_runs(tmp_path: Path) -> None: + subprocess.run(["go", "mod", "tidy"], cwd=GO_EXAMPLE, check=True) + wasm_path = tmp_path / "go-channel.wasm" + subprocess.run( + [ + "go", + "build", + "-buildmode=c-shared", + "-o", + str(wasm_path), + ".", + ], + cwd=GO_EXAMPLE, + check=True, + env={**dict(os.environ), "GOOS": "wasip1", "GOARCH": "wasm"}, + ) + config_path = _write_config( + tmp_path, + { + "defaultPlugin": "go", + "plugins": { + "go": { + "wasmPath": str(wasm_path), + "wasi": True, + "hooks": {"provide_channels": "provide_channels"}, + } + }, + }, + ) + + async def handler(message: dict[str, Any]) -> None: + del message + + channels = _plugin(config_path).provide_channels(handler) + assert [channel.name for channel in channels] == ["go-echo"] + asyncio.run(channels[0].send({"content": "hello from bub"})) diff --git a/pyproject.toml b/pyproject.toml index 9bc1b3a..a781c9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "bub-codex", "bub-discord", "bub-dingtalk", + "bub-extism", "bub-github-copilot", "bub-kimi", "bub-mcp", @@ -37,6 +38,7 @@ bub-tg-feed = { workspace = true } bub-codex = { workspace = true } bub-discord = { workspace = true } bub-dingtalk = { workspace = true } +bub-extism = { workspace = true } bub-github-copilot = { workspace = true } bub-kimi = { workspace = true } bub-mcp = { workspace = true } diff --git a/uv.lock b/uv.lock index fb5a682..59f3011 100644 --- a/uv.lock +++ b/uv.lock @@ -13,6 +13,7 @@ members = [ "bub-contrib", "bub-dingtalk", "bub-discord", + "bub-extism", "bub-feishu", "bub-github-copilot", "bub-kimi", @@ -389,6 +390,7 @@ dependencies = [ { name = "bub-codex" }, { name = "bub-dingtalk" }, { name = "bub-discord" }, + { name = "bub-extism" }, { name = "bub-feishu" }, { name = "bub-github-copilot" }, { name = "bub-kimi" }, @@ -421,6 +423,7 @@ requires-dist = [ { name = "bub-codex", editable = "packages/bub-codex" }, { name = "bub-dingtalk", editable = "packages/bub-dingtalk" }, { name = "bub-discord", editable = "packages/bub-discord" }, + { name = "bub-extism", editable = "packages/bub-extism" }, { name = "bub-feishu", editable = "packages/bub-feishu" }, { name = "bub-github-copilot", editable = "packages/bub-github-copilot" }, { name = "bub-kimi", editable = "packages/bub-kimi" }, @@ -473,6 +476,37 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "discord-py", specifier = ">=2.7.1" }] +[[package]] +name = "bub-extism" +version = "0.1.0" +source = { editable = "packages/bub-extism" } +dependencies = [ + { name = "extism" }, + { name = "extism-sys" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + +[package.metadata] +requires-dist = [ + { name = "extism", specifier = ">=1.1.1,<1.2.0" }, + { name = "extism-sys", specifier = ">=1.12.0,<1.13.0" }, + { name = "pydantic", specifier = ">=2.10.0" }, + { name = "pydantic-settings", specifier = ">=2.10.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=9.0.3" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, +] + [[package]] name = "bub-feishu" version = "0.1.0" From b3d5f548b67b66b8f485c3e431977c9b9cd20bc4 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Wed, 20 May 2026 19:23:28 +0000 Subject: [PATCH 2/2] refactor: almost rewrite Signed-off-by: Chojan Shang --- packages/bub-extism/README.md | 270 ++++++----- packages/bub-extism/examples/README.md | 198 ++++++-- .../examples/go-build-prompt/go.mod | 5 + .../{go-channel => go-build-prompt}/go.sum | 0 .../examples/go-build-prompt/main.go | 59 +++ .../bub-extism/examples/go-channel/go.mod | 5 - .../bub-extism/examples/go-channel/main.go | 71 --- .../examples/rust-model-stream/src/lib.rs | 48 -- .../Cargo.lock | 2 +- .../Cargo.toml | 2 +- .../examples/rust-run-model/src/lib.rs | 31 ++ packages/bub-extism/src/bub_extism/bridge.py | 40 +- packages/bub-extism/src/bub_extism/channel.py | 74 ++- packages/bub-extism/src/bub_extism/cli.py | 235 +++++++++- packages/bub-extism/src/bub_extism/codec.py | 119 ++--- packages/bub-extism/src/bub_extism/config.py | 93 ++-- .../bub-extism/src/bub_extism/descriptors.py | 42 ++ packages/bub-extism/src/bub_extism/plugin.py | 422 ++++++++++-------- packages/bub-extism/src/bub_extism/stream.py | 25 +- .../bub-extism/src/bub_extism/tape_store.py | 46 +- packages/bub-extism/tests/test_bridge.py | 344 ++++++++++---- packages/bub-extism/tests/test_cli.py | 216 +++++++++ packages/bub-extism/tests/test_examples.py | 175 +++++--- 23 files changed, 1721 insertions(+), 801 deletions(-) create mode 100644 packages/bub-extism/examples/go-build-prompt/go.mod rename packages/bub-extism/examples/{go-channel => go-build-prompt}/go.sum (100%) create mode 100644 packages/bub-extism/examples/go-build-prompt/main.go delete mode 100644 packages/bub-extism/examples/go-channel/go.mod delete mode 100644 packages/bub-extism/examples/go-channel/main.go delete mode 100644 packages/bub-extism/examples/rust-model-stream/src/lib.rs rename packages/bub-extism/examples/{rust-model-stream => rust-run-model}/Cargo.lock (99%) rename packages/bub-extism/examples/{rust-model-stream => rust-run-model}/Cargo.toml (82%) create mode 100644 packages/bub-extism/examples/rust-run-model/src/lib.rs create mode 100644 packages/bub-extism/src/bub_extism/descriptors.py create mode 100644 packages/bub-extism/tests/test_cli.py diff --git a/packages/bub-extism/README.md b/packages/bub-extism/README.md index 99120f5..9fdf488 100644 --- a/packages/bub-extism/README.md +++ b/packages/bub-extism/README.md @@ -1,114 +1,89 @@ # bub-extism -`bub-extism` lets a Bub workspace run selected Bub hooks through an -[Extism](https://extism.org/) WebAssembly plug-in. +Extism WebAssembly bridge plugin for `bub`. -The package is intentionally a bridge, not a replacement for Bub's pluggy -extension model. Bub still discovers `bub-extism` as a normal Python plugin, -then `bub-extism` delegates configured hook calls to a `.wasm` module written -with any Extism PDK language. +## What It Provides -The Extism Python runtime is installed with this package. The dependency is -kept on the verified `extism` `1.1.x` and `extism-sys` `1.12.x` lines because -the Python package includes native runtime wheels. +- Bub plugin entry point: `extism` +- One Bub hook adapter per configured Extism plug-in +- Standard Extism manifest support +- `bub extism` management commands: + - `list` + - `show` + - `add` + - `remove` +- Python-side proxies for hook surfaces that need Bub runtime objects: + - `run_model_stream` + - `provide_channels` + - `provide_tape_store` + - `register_cli_commands` -## Supported Hooks +`bub-extism` does not replace Bub's pluggy model. It loads as a normal Bub +plugin, then registers one hook adapter for each configured wasm plug-in. -The bridge exposes the current Bub standard hook surface: +## Installation -- `resolve_session` -- `build_prompt` -- `run_model` -- `run_model_stream` -- `load_state` -- `save_state` -- `render_outbound` -- `dispatch_outbound` -- `register_cli_commands` -- `onboard_config` -- `on_error` -- `system_prompt` -- `provide_tape_store` -- `provide_channels` -- `build_tape_context` +```bash +uv pip install "git+https://github.com/bubbuild/bub-contrib.git#subdirectory=packages/bub-extism" +``` + +You can also install it with Bub: + +```bash +bub install bub-extism@main +``` + +## Prerequisites + +- Python 3.12+ +- The `extism` Python runtime package is installed with this package +- WASI-enabled modules must set `"wasi": true` in Bub config -Pure value hooks map directly to WebAssembly calls. Hooks that return Python -runtime objects use Python-side proxies: +For example builds: -- `run_model_stream` accepts a returned list of stream events and wraps it as - `AsyncStreamEvents`. -- `provide_channels` accepts channel descriptors and creates `ExtismChannel` - proxies. -- `provide_tape_store` accepts a tape store descriptor and creates an - `ExtismTapeStore` proxy. -- `register_cli_commands` accepts command descriptors and registers them under - `bub extism`. -- `build_tape_context` accepts a declarative context object; arbitrary Python - selector callbacks are not part of the WASM ABI. +- Rust example: + - `cargo` + - `rustup` + - `wasm32-unknown-unknown` target +- Go example: + - Go with `GOOS=wasip1 GOARCH=wasm` support ## Configuration -Create `~/.bub/extism.json`: +By default, `bub-extism` reads `~/.bub/extism.json`. -```json -{ - "defaultPlugin": "echo", - "plugins": { - "echo": { - "wasmPath": "/absolute/path/to/plugin.wasm", - "wasi": false, - "config": { - "model": "demo" - }, - "hooks": { - "resolve_session": "resolve_session", - "build_prompt": "build_prompt", - "run_model": "run_model", - "run_model_stream": "run_model_stream", - "load_state": "load_state", - "save_state": "save_state", - "render_outbound": "render_outbound", - "dispatch_outbound": "dispatch_outbound", - "register_cli_commands": "register_cli_commands", - "onboard_config": "onboard_config", - "on_error": "on_error", - "system_prompt": "system_prompt" - } - } - } -} -``` +Use `BUB_EXTISM_CONFIG_PATH=/path/to/extism.json` to override the config path. -You can also load a URL or a full Extism manifest: +Example: ```json { - "defaultPlugin": "remote", "plugins": { - "remote": { - "wasmUrl": "https://example.com/plugin.wasm", + "prompt": { + "manifest": { + "wasm": [ + { + "path": "/absolute/path/to/prompt.wasm" + } + ] + }, "hooks": { - "run_model": "run_model" + "build_prompt": "build_prompt" } - } - } -} -``` - -```json -{ - "defaultPlugin": "manifest", - "plugins": { - "manifest": { + }, + "model": { "manifest": { "wasm": [ { - "url": "https://example.com/plugin.wasm", - "hash": "sha256..." + "path": "/absolute/path/to/model.wasm" } ], - "allowed_hosts": ["api.example.com"] + "allowed_hosts": ["api.example.com"], + "config": { + "provider": "demo" + } }, + "wasi": true, "hooks": { "run_model": "run_model" } @@ -117,13 +92,67 @@ You can also load a URL or a full Extism manifest: } ``` -Use `BUB_EXTISM_CONFIG_PATH=/path/to/extism.json` to override the config path. +Configuration rules: + +- Each entry under `plugins` is one Bub hook adapter backed by one Extism plug-in. +- `manifest` is a standard Extism manifest object. +- `wasi` stays on the Bub side because WASI enablement is a host/runtime decision. +- `hooks` maps Bub hook names to exported wasm functions. + +## Runtime Model + +- Bub still owns hook dispatch and precedence. +- `bub-extism` registers one Python adapter per configured entry. +- You can split hooks across multiple plug-ins or keep them in one module. + +Typical layouts: -## Hook ABI +- one plug-in for `build_prompt` +- one plug-in for `run_model` +- one combined plug-in exporting both + +## Supported Hooks + +- `resolve_session` +- `build_prompt` +- `run_model` +- `run_model_stream` +- `load_state` +- `save_state` +- `render_outbound` +- `dispatch_outbound` +- `register_cli_commands` +- `onboard_config` +- `on_error` +- `system_prompt` +- `provide_tape_store` +- `provide_channels` +- `build_tape_context` + +## CLI + +`bub-extism` adds a management group similar to `bub mcp`: + +```bash +bub extism list +bub extism show prompt +bub extism add prompt ./prompt.manifest.json --hook build_prompt=build_prompt +bub extism remove prompt +``` + +`bub extism add` expects: + +- one standard Extism manifest JSON file +- one or more `--hook HOOK=EXPORT` bindings + +If a wasm plug-in exposes `register_cli_commands`, its commands are registered +into the same `bub extism` group. + +## Hook ABI Reference Each exported hook function receives one UTF-8 JSON object. -For `run_model`: +`run_model` request: ```json { @@ -137,29 +166,36 @@ For `run_model`: } ``` -For `system_prompt`: +`build_prompt` request: ```json { "abi_version": "bub.extism.v1", - "hook": "system_prompt", + "hook": "build_prompt", "args": { - "prompt": "hello", + "message": { + "content": "hello" + }, + "session_id": "cli:local", "state": {} } } ``` -The bridge removes Bub runtime internals such as `_runtime_agent` and -non-JSON-serializable values from `state` before calling WebAssembly. +Bridge behavior: + +- Bub runtime internals such as `_runtime_*` fields are removed from `state` +- Non-JSON-serializable values are skipped before the wasm call -The wasm function may return plain text: +Valid return shapes: + +Plain text: ```text hello from wasm ``` -Or a JSON object: +Wrapped value: ```json { @@ -167,7 +203,7 @@ Or a JSON object: } ``` -It can skip the hook: +Skip current hook: ```json { @@ -175,7 +211,7 @@ It can skip the hook: } ``` -Or return an error: +Return an error: ```json { @@ -185,10 +221,7 @@ Or return an error: } ``` -For compatibility with early demos, `{"run_model": "..."}` and -`{"system_prompt": "..."}` are still accepted. - -## Proxy Descriptors +## Descriptor Reference `provide_channels` returns channel descriptors: @@ -238,15 +271,40 @@ For compatibility with early demos, `{"run_model": "..."}` and } ``` -The command is exposed as `bub extism hello '{"name":"Bub"}'`. +That command is exposed as: + +```bash +bub extism hello '{"name":"Bub"}' +``` + +## Examples -## Development +See [examples/README.md](./examples/README.md) for three verified paths: + +- Rust `run_model` on its own +- Go `build_prompt` on its own +- Go `build_prompt` plus Rust `run_model` together + +## Verification From the repository root: ```bash -uv run --directory packages/bub-extism --with ../bub --with pytest --with pytest-asyncio pytest +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests -q ``` -These tests use a fake Extism module and do not require a local WebAssembly -runtime. +To verify example builds and composition only: + +```bash +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests/test_examples.py -q +``` diff --git a/packages/bub-extism/examples/README.md b/packages/bub-extism/examples/README.md index 6e89818..3291d37 100644 --- a/packages/bub-extism/examples/README.md +++ b/packages/bub-extism/examples/README.md @@ -1,64 +1,210 @@ # bub-extism examples -These examples show how to implement Bub extensions in languages other than -Python while still using Bub's pluggy-based extension surface through -`bub-extism`. +This directory contains two verified example modules: -## Rust model stream +- `go-build-prompt` +- `rust-run-model` -`rust-model-stream` mirrors model-provider plugins such as `bub-kimi` and -`bub-codex`. It implements `run_model_stream` and returns Republic stream -events. +They are intended to demonstrate three cases: -Build: +1. a single `build_prompt` wasm adapter +2. a single `run_model` wasm adapter +3. separate prompt and model adapters composed in one `extism.json` + +## Prerequisites + +Run these commands from the repository root. + +Build prerequisites: + +- Rust example: + - `cargo` + - `rustup` + - `wasm32-unknown-unknown` +- Go example: + - `go` + - `GOOS=wasip1 GOARCH=wasm` support + +The examples themselves do not require a real model backend. + +## Build Artifacts + +Build the Rust example: ```bash -cd packages/bub-extism/examples/rust-model-stream +cd bub-contrib/packages/bub-extism/examples/rust-run-model cargo build --release --target wasm32-unknown-unknown ``` -Configure: +Expected artifact: + +```text +bub-contrib/packages/bub-extism/examples/rust-run-model/target/wasm32-unknown-unknown/release/bub_extism_rust_run_model.wasm +``` + +Build the Go example: + +```bash +cd bub-contrib/packages/bub-extism/examples/go-build-prompt +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o go-build-prompt.wasm . +``` + +Expected artifact: + +```text +bub-contrib/packages/bub-extism/examples/go-build-prompt/go-build-prompt.wasm +``` + +## Run the Rust `run_model` Example + +This example exports `run_model` and returns: + +```text +[rust-run-model:] +``` + +Example config: ```json { - "defaultPlugin": "rust-model-stream", "plugins": { - "rust-model-stream": { - "wasmPath": "packages/bub-extism/examples/rust-model-stream/target/wasm32-unknown-unknown/release/bub_extism_rust_model_stream.wasm", + "model": { + "manifest": { + "wasm": [ + { + "path": "bub-contrib/packages/bub-extism/examples/rust-run-model/target/wasm32-unknown-unknown/release/bub_extism_rust_run_model.wasm" + } + ] + }, "hooks": { - "run_model_stream": "run_model_stream" + "run_model": "run_model" } } } } ``` -## Go channel +Verification: -`go-channel` mirrors channel plugins such as `bub-discord`, `bub-feishu`, and -`bub-wecom`. It implements `provide_channels` and a `send` function used by the -Python `ExtismChannel` proxy. +```bash +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests/test_examples.py \ + -k rust_run_model_example_builds_and_runs -q +``` -Build: +## Run the Go `build_prompt` Example + +This example exports `build_prompt` and returns: + +```text +[go-build-prompt:] +``` + +Example config: + +```json +{ + "plugins": { + "prompt": { + "manifest": { + "wasm": [ + { + "path": "bub-contrib/packages/bub-extism/examples/go-build-prompt/go-build-prompt.wasm" + } + ] + }, + "wasi": true, + "hooks": { + "build_prompt": "build_prompt" + } + } + } +} +``` + +Verification: ```bash -cd packages/bub-extism/examples/go-channel -GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o go-channel.wasm . +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests/test_examples.py \ + -k go_build_prompt_example_builds_and_runs -q ``` -Configure: +## Run Both Examples Together + +This is the composition case: + +- Go handles `build_prompt` +- Rust handles `run_model` + +Combined config: ```json { - "defaultPlugin": "go-channel", "plugins": { - "go-channel": { - "wasmPath": "packages/bub-extism/examples/go-channel/go-channel.wasm", + "prompt": { + "manifest": { + "wasm": [ + { + "path": "bub-contrib/packages/bub-extism/examples/go-build-prompt/go-build-prompt.wasm" + } + ] + }, "wasi": true, "hooks": { - "provide_channels": "provide_channels" + "build_prompt": "build_prompt" + } + }, + "model": { + "manifest": { + "wasm": [ + { + "path": "bub-contrib/packages/bub-extism/examples/rust-run-model/target/wasm32-unknown-unknown/release/bub_extism_rust_run_model.wasm" + } + ] + }, + "hooks": { + "run_model": "run_model" } } } } ``` + +Expected flow: + +1. `build_prompt` returns `[go-build-prompt:example] hello from bub` +2. `run_model` receives that prompt and returns `[rust-run-model:example] [go-build-prompt:example] hello from bub` + +Verification: + +```bash +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests/test_examples.py \ + -k go_and_rust_examples_can_be_combined -q +``` + +## Full Example Verification + +Run all three verified paths: + +```bash +uv run --python 3.12 --no-project \ + --with-editable ./bub \ + --with-editable ./bub-contrib/packages/bub-extism \ + --with pytest \ + --with pytest-asyncio \ + -m pytest bub-contrib/packages/bub-extism/tests/test_examples.py -q +``` diff --git a/packages/bub-extism/examples/go-build-prompt/go.mod b/packages/bub-extism/examples/go-build-prompt/go.mod new file mode 100644 index 0000000..3078124 --- /dev/null +++ b/packages/bub-extism/examples/go-build-prompt/go.mod @@ -0,0 +1,5 @@ +module github.com/bubbuild/bub-extism/examples/go-build-prompt + +go 1.26 + +require github.com/extism/go-pdk v1.1.3 diff --git a/packages/bub-extism/examples/go-channel/go.sum b/packages/bub-extism/examples/go-build-prompt/go.sum similarity index 100% rename from packages/bub-extism/examples/go-channel/go.sum rename to packages/bub-extism/examples/go-build-prompt/go.sum diff --git a/packages/bub-extism/examples/go-build-prompt/main.go b/packages/bub-extism/examples/go-build-prompt/main.go new file mode 100644 index 0000000..c9e89d7 --- /dev/null +++ b/packages/bub-extism/examples/go-build-prompt/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "encoding/json" + "fmt" + + "github.com/extism/go-pdk" +) + +type request struct { + Hook string `json:"hook"` + Args requestArgs `json:"args"` +} + +type requestArgs struct { + Message map[string]any `json:"message"` + SessionID string `json:"session_id"` +} + +type response struct { + Value any `json:"value,omitempty"` + Skip bool `json:"skip,omitempty"` +} + +//go:wasmexport build_prompt +func buildPrompt() int32 { + var req request + if err := pdk.InputJSON(&req); err != nil { + return outputError(err) + } + if req.Hook != "build_prompt" { + return outputJSON(response{Skip: true}) + } + + content, _ := req.Args.Message["content"].(string) + prompt := fmt.Sprintf("[go-build-prompt:%s] %s", req.Args.SessionID, content) + return outputJSON(response{Value: prompt}) +} + +func outputJSON(value any) int32 { + if err := pdk.OutputJSON(value); err != nil { + return outputError(err) + } + return 0 +} + +func outputError(err error) int32 { + encoded, _ := json.Marshal( + map[string]any{ + "error": map[string]string{ + "message": fmt.Sprintf("go-build-prompt: %v", err), + }, + }, + ) + pdk.Output(encoded) + return 1 +} + +func main() {} diff --git a/packages/bub-extism/examples/go-channel/go.mod b/packages/bub-extism/examples/go-channel/go.mod deleted file mode 100644 index 631c731..0000000 --- a/packages/bub-extism/examples/go-channel/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/bubbuild/bub-extism/examples/go-channel - -go 1.26 - -require github.com/extism/go-pdk v1.1.3 diff --git a/packages/bub-extism/examples/go-channel/main.go b/packages/bub-extism/examples/go-channel/main.go deleted file mode 100644 index 3cd0d28..0000000 --- a/packages/bub-extism/examples/go-channel/main.go +++ /dev/null @@ -1,71 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - - "github.com/extism/go-pdk" -) - -type request struct { - Hook string `json:"hook"` - Args map[string]any `json:"args"` -} - -type response struct { - Value any `json:"value,omitempty"` - Skip bool `json:"skip,omitempty"` - Error any `json:"error,omitempty"` -} - -//go:wasmexport provide_channels -func provideChannels() int32 { - return outputJSON(response{ - Value: []map[string]any{ - { - "name": "go-echo", - "pollIntervalSeconds": 1, - "functions": map[string]string{ - "send": "channel_send", - }, - }, - }, - }) -} - -//go:wasmexport channel_send -func channelSend() int32 { - var req request - if err := pdk.InputJSON(&req); err != nil { - return outputError(err) - } - message, _ := req.Args["message"].(map[string]any) - content, _ := message["content"].(string) - return outputJSON(response{ - Value: map[string]any{ - "ok": true, - "channel": "go-echo", - "sent": content, - }, - }) -} - -func outputJSON(value any) int32 { - if err := pdk.OutputJSON(value); err != nil { - return outputError(err) - } - return 0 -} - -func outputError(err error) int32 { - pdk.SetErrorString(fmt.Sprintf("go-channel: %v", err)) - encoded, _ := json.Marshal(response{ - Error: map[string]string{ - "message": err.Error(), - }, - }) - pdk.Output(encoded) - return 1 -} - -func main() {} diff --git a/packages/bub-extism/examples/rust-model-stream/src/lib.rs b/packages/bub-extism/examples/rust-model-stream/src/lib.rs deleted file mode 100644 index 2540dbe..0000000 --- a/packages/bub-extism/examples/rust-model-stream/src/lib.rs +++ /dev/null @@ -1,48 +0,0 @@ -use extism_pdk::{plugin_fn, FnResult}; -use serde::Deserialize; -use serde_json::{json, Value}; - -#[derive(Deserialize)] -struct Request { - hook: String, - args: Args, -} - -#[derive(Deserialize)] -struct Args { - prompt: Value, - session_id: String, -} - -#[plugin_fn] -pub fn run_model_stream(input: String) -> FnResult { - let request: Request = serde_json::from_str(&input)?; - if request.hook != "run_model_stream" { - return Ok(json!({ "skip": true }).to_string()); - } - - let prompt = match request.args.prompt { - Value::String(value) => value, - other => other.to_string(), - }; - let text = format!("[rust-model-stream:{}] {}", request.args.session_id, prompt); - - Ok(json!({ - "value": { - "events": [ - { - "kind": "text", - "data": { "delta": text } - }, - { - "kind": "final", - "data": { "text": text } - } - ], - "usage": { - "output_tokens": text.split_whitespace().count() - } - } - }) - .to_string()) -} diff --git a/packages/bub-extism/examples/rust-model-stream/Cargo.lock b/packages/bub-extism/examples/rust-run-model/Cargo.lock similarity index 99% rename from packages/bub-extism/examples/rust-model-stream/Cargo.lock rename to packages/bub-extism/examples/rust-run-model/Cargo.lock index c15473b..2e92681 100644 --- a/packages/bub-extism/examples/rust-model-stream/Cargo.lock +++ b/packages/bub-extism/examples/rust-run-model/Cargo.lock @@ -21,7 +21,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "bub-extism-rust-model-stream" +name = "bub-extism-rust-run-model" version = "0.1.0" dependencies = [ "extism-pdk", diff --git a/packages/bub-extism/examples/rust-model-stream/Cargo.toml b/packages/bub-extism/examples/rust-run-model/Cargo.toml similarity index 82% rename from packages/bub-extism/examples/rust-model-stream/Cargo.toml rename to packages/bub-extism/examples/rust-run-model/Cargo.toml index 60ce08e..1e35b2a 100644 --- a/packages/bub-extism/examples/rust-model-stream/Cargo.toml +++ b/packages/bub-extism/examples/rust-run-model/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "bub-extism-rust-model-stream" +name = "bub-extism-rust-run-model" version = "0.1.0" edition = "2021" diff --git a/packages/bub-extism/examples/rust-run-model/src/lib.rs b/packages/bub-extism/examples/rust-run-model/src/lib.rs new file mode 100644 index 0000000..fce1c10 --- /dev/null +++ b/packages/bub-extism/examples/rust-run-model/src/lib.rs @@ -0,0 +1,31 @@ +use extism_pdk::{plugin_fn, FnResult}; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Deserialize)] +struct Request { + hook: String, + args: Args, +} + +#[derive(Deserialize)] +struct Args { + prompt: Value, + session_id: String, +} + +#[plugin_fn] +pub fn run_model(input: String) -> FnResult { + let request: Request = serde_json::from_str(&input)?; + if request.hook != "run_model" { + return Ok(json!({ "skip": true }).to_string()); + } + + let prompt = match request.args.prompt { + Value::String(value) => value, + other => other.to_string(), + }; + let text = format!("[rust-run-model:{}] {}", request.args.session_id, prompt); + + Ok(json!({ "value": text }).to_string()) +} diff --git a/packages/bub-extism/src/bub_extism/bridge.py b/packages/bub-extism/src/bub_extism/bridge.py index 2c64c88..c67c532 100644 --- a/packages/bub-extism/src/bub_extism/bridge.py +++ b/packages/bub-extism/src/bub_extism/bridge.py @@ -5,40 +5,24 @@ from typing import Any from bub_extism.codec import ExtismHookSkip, build_request, decode_response -from bub_extism.config import ExtismPluginConfig, ExtismSettings +from bub_extism.config import ExtismPluginConfig class ExtismBridge: - def __init__(self, settings: ExtismSettings) -> None: - self.settings = settings - - def selected_config(self) -> ExtismPluginConfig | None: - return self.settings.read_config().selected_plugin() - - def function_name(self, hook_name: str) -> str | None: - config = self.selected_config() - if config is None: - return None - return getattr(config.hooks, hook_name) - def call_hook_sync( self, hook_name: str, args: dict[str, Any], *, - config: ExtismPluginConfig | None = None, + config: ExtismPluginConfig, function_name: str | None = None, ) -> Any: - selected = config or self.selected_config() - if selected is None: - return None - - export_name = function_name or getattr(selected.hooks, hook_name) + export_name = function_name or config.hooks.get(hook_name) if export_name is None: return None try: - return self._call_export(selected, export_name, hook_name, args) + return self._call_export(config, export_name, hook_name, args) except ExtismHookSkip: return None @@ -47,7 +31,7 @@ async def call_hook( hook_name: str, args: dict[str, Any], *, - config: ExtismPluginConfig | None = None, + config: ExtismPluginConfig, function_name: str | None = None, ) -> Any: return await asyncio.to_thread( @@ -67,19 +51,11 @@ def _call_export( ) -> Any: extism = _import_extism() request = build_request(hook_name, args) - with extism.Plugin( - config.plugin_input(), - wasi=config.wasi, - config=config.config or None, - ) as plugin: + with extism.Plugin(config.manifest, wasi=config.wasi, config=None) as plugin: if hasattr(plugin, "function_exists") and not plugin.function_exists(function_name): raise ExtismHookSkip - - raw_result = plugin.call( - function_name, - json.dumps(request, ensure_ascii=False), - ) - return decode_response(raw_result, hook_name=hook_name) + raw_result = plugin.call(function_name, json.dumps(request, ensure_ascii=False)) + return decode_response(raw_result) def _import_extism() -> Any: diff --git a/packages/bub-extism/src/bub_extism/channel.py b/packages/bub-extism/src/bub_extism/channel.py index 858150f..d4d4d1e 100644 --- a/packages/bub-extism/src/bub_extism/channel.py +++ b/packages/bub-extism/src/bub_extism/channel.py @@ -9,8 +9,9 @@ from republic import StreamEvent from bub_extism.bridge import ExtismBridge -from bub_extism.codec import to_json_value +from bub_extism.codec import message_to_json from bub_extism.config import ExtismPluginConfig +from bub_extism.descriptors import normalize_function_bindings, require_mapping, required_text class ExtismChannel(Channel): @@ -18,24 +19,51 @@ def __init__( self, bridge: ExtismBridge, config: ExtismPluginConfig, - descriptor: dict[str, Any], + *, + name: str, + enabled: bool, + needs_debounce: bool, + poll_interval_seconds: float, + functions: dict[str, str], message_handler: MessageHandler, ) -> None: self.bridge = bridge self.config = config - self.descriptor = descriptor - self.name = str(descriptor["name"]) + self.name = name + self._enabled = enabled + self._needs_debounce = needs_debounce self._message_handler = message_handler - self._functions = dict(descriptor.get("functions") or {}) - self._poll_interval_seconds = float(descriptor.get("pollIntervalSeconds", 1.0)) + self._functions = functions + self._poll_interval_seconds = poll_interval_seconds + + @classmethod + def from_descriptor( + cls, + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptor: Any, + message_handler: MessageHandler, + ) -> ExtismChannel: + data = require_mapping(descriptor, message="Extism channel descriptor must be an object") + name = required_text(data.get("name"), message="Extism channel descriptor must include a name") + return cls( + bridge, + config, + name=name, + enabled=bool(data.get("enabled", True)), + needs_debounce=bool(data.get("needsDebounce", False)), + poll_interval_seconds=float(data.get("pollIntervalSeconds", 1.0)), + functions=_functions_from_descriptor(data), + message_handler=message_handler, + ) @property def enabled(self) -> bool: - return bool(self.descriptor.get("enabled", True)) + return self._enabled @property def needs_debounce(self) -> bool: - return bool(self.descriptor.get("needsDebounce", False)) + return self._needs_debounce async def start(self, stop_event: asyncio.Event) -> None: await self._call("start", {}) @@ -56,7 +84,7 @@ async def stop(self) -> None: await self._call("stop", {}) async def send(self, message: Envelope) -> None: - await self._call("send", {"message": to_json_value(message)}) + await self._call("send", {"message": message_to_json(message)}) def stream_events( self, @@ -77,25 +105,19 @@ async def _call(self, operation: str, args: dict[str, Any]) -> Any: ) -def channels_from_descriptors( +def channels_from_value( bridge: ExtismBridge, config: ExtismPluginConfig, - descriptors: Any, + value: Any, message_handler: MessageHandler, ) -> list[ExtismChannel]: - if descriptors is None: + if value is None: return [] - if isinstance(descriptors, dict): - descriptors = descriptors.get("channels", []) - if not isinstance(descriptors, list): + if isinstance(value, dict): + value = value.get("channels", []) + if not isinstance(value, list): raise RuntimeError("Extism provide_channels must return a list of channel descriptors") - - channels: list[ExtismChannel] = [] - for descriptor in descriptors: - if not isinstance(descriptor, dict) or not descriptor.get("name"): - raise RuntimeError("Extism channel descriptor must include a name") - channels.append(ExtismChannel(bridge, config, descriptor, message_handler)) - return channels + return [ExtismChannel.from_descriptor(bridge, config, descriptor, message_handler) for descriptor in value] def _messages_from_value(value: Any) -> list[Envelope]: @@ -106,3 +128,11 @@ def _messages_from_value(value: Any) -> list[Envelope]: if not isinstance(value, list): raise RuntimeError("Extism channel poll must return a message or message list") return value + + +def _functions_from_descriptor(descriptor: dict[str, Any]) -> dict[str, str]: + return normalize_function_bindings( + descriptor.get("functions"), + message="Extism channel functions must be an object", + missing_ok=True, + ) diff --git a/packages/bub-extism/src/bub_extism/cli.py b/packages/bub-extism/src/bub_extism/cli.py index 8612b2d..0b2a555 100644 --- a/packages/bub-extism/src/bub_extism/cli.py +++ b/packages/bub-extism/src/bub_extism/cli.py @@ -1,43 +1,172 @@ from __future__ import annotations import json -from typing import Any +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any import typer from bub_extism.bridge import ExtismBridge -from bub_extism.config import ExtismPluginConfig +from bub_extism.config import ( + CLI_HOOK_NAME, + ExtismConfig, + ExtismPluginConfig, + ExtismSettings, + normalize_hook_bindings, +) +from bub_extism.descriptors import require_mapping, required_text + +if TYPE_CHECKING: + from collections.abc import Iterable + + +RESERVED_COMMAND_NAMES = {"add", "list", "remove", "show"} + + +@dataclass(frozen=True) +class CommandDescriptor: + name: str + function: str + help_text: str | None = None + + @classmethod + def from_descriptor(cls, descriptor: Any) -> CommandDescriptor: + data = require_mapping(descriptor, message="Extism CLI command descriptor must be an object") + name = required_text(data.get("name"), message="Extism CLI command descriptor requires name and function") + function = required_text( + data.get("function"), + message="Extism CLI command descriptor requires name and function", + ) + help_value = data.get("help") + help_text = None if help_value is None else str(help_value) + return cls(name=name, function=function, help_text=help_text) def register_cli_commands( app: typer.Typer, + settings: ExtismSettings, bridge: ExtismBridge, - config: ExtismPluginConfig, - descriptors: Any, ) -> None: - if descriptors is None: - return - if isinstance(descriptors, dict): - descriptors = descriptors.get("commands", []) - if not isinstance(descriptors, list): - raise RuntimeError("Extism register_cli_commands must return a list") + app.add_typer(make_extism_command(settings, bridge), name="extism") + + +def make_extism_command(settings: ExtismSettings, bridge: ExtismBridge) -> typer.Typer: + app = typer.Typer(help="Manage and inspect Extism-backed Bub adapters.") + + @app.command("list") + def list_plugins() -> None: + """List configured Extism adapters.""" + config = settings.read_config() + if not config.plugins: + typer.echo("No Extism plugins configured.") + typer.echo(f"Config: {settings.config_path}") + return + typer.echo(_format_plugin_list(config)) + typer.echo(f"Config: {settings.config_path}") + + @app.command("show") + def show_plugin(name: str = typer.Argument(..., help="Configured adapter name.")) -> None: + """Show one adapter configuration.""" + plugin = settings.read_config().plugins.get(name) + if plugin is None: + typer.echo(f"Extism plugin '{name}' does not exist.", err=True) + raise typer.Exit(code=1) + typer.echo(json.dumps(plugin.model_dump(mode="json"), ensure_ascii=False, indent=2)) + + @app.command("add") + def add_plugin( + name: str = typer.Argument(..., help="Adapter name."), + manifest_path: Path = typer.Argument( + ..., + exists=True, + dir_okay=False, + help="Path to an Extism manifest JSON file.", + ), + hook: list[str] | None = typer.Option( + None, + "--hook", + help="Hook binding in HOOK=EXPORT format. Repeat to bind multiple Bub hooks.", + ), + wasi: bool = typer.Option(False, "--wasi", help="Enable WASI for this adapter."), + replace: bool = typer.Option(False, "--replace", help="Replace an existing adapter with the same name."), + ) -> None: + """Add one Extism adapter.""" + config = settings.read_config() + if name in config.plugins and not replace: + raise typer.BadParameter(f"Extism plugin '{name}' already exists. Use --replace to overwrite it.") + + config.plugins[name] = ExtismPluginConfig( + manifest=_load_manifest(manifest_path), + hooks=_parse_hook_bindings(hook or []), + wasi=wasi, + ) + settings.write_config(config) + + typer.echo(f"Added Extism plugin '{name}'.") + typer.echo(f"Config: {settings.config_path}") + typer.echo(_format_single_plugin(name, config.plugins[name])) + + @app.command("remove") + def remove_plugin(name: str = typer.Argument(..., help="Adapter name.")) -> None: + """Remove one Extism adapter.""" + config = settings.read_config() + if name not in config.plugins: + typer.echo(f"Extism plugin '{name}' does not exist.", err=True) + raise typer.Exit(code=1) - group = typer.Typer(help="Commands provided by Extism WebAssembly plugins.") - for descriptor in descriptors: - if not isinstance(descriptor, dict): - raise RuntimeError("Extism CLI command descriptor must be an object") - name = str(descriptor.get("name", "")).strip() - function_name = str(descriptor.get("function", "")).strip() - if not name or not function_name: - raise RuntimeError("Extism CLI command descriptor requires name and function") - help_text = str(descriptor.get("help", "Run an Extism command.")) - group.command(name, help=help_text)(_make_command(bridge, config, name, function_name)) + del config.plugins[name] + settings.write_config(config) + typer.echo(f"Removed Extism plugin '{name}'.") + typer.echo(f"Config: {settings.config_path}") - app.add_typer(group, name="extism") + _register_plugin_commands(app, settings.read_config(), bridge) + return app -def _make_command( +def _register_plugin_commands(app: typer.Typer, config: ExtismConfig, bridge: ExtismBridge) -> None: + registered_names = set(RESERVED_COMMAND_NAMES) + for plugin_name, plugin_config in config.plugins.items(): + if CLI_HOOK_NAME not in plugin_config.hooks: + continue + for descriptor in commands_from_value( + bridge.call_hook_sync( + "register_cli_commands", + {"commands": []}, + config=plugin_config, + ) + ): + if descriptor.name in registered_names: + raise RuntimeError( + f"Extism CLI command '{descriptor.name}' conflicts with an existing command" + ) + registered_names.add(descriptor.name) + + help_text = descriptor.help_text or f"Run the '{descriptor.name}' command from Extism plugin '{plugin_name}'." + app.command(descriptor.name, help=help_text)( + _make_plugin_command( + bridge, + plugin_name, + plugin_config, + descriptor.name, + descriptor.function, + ) + ) + + +def commands_from_value(value: Any) -> list[CommandDescriptor]: + if value is None: + return [] + if isinstance(value, dict): + value = value.get("commands", []) + if not isinstance(value, list): + raise RuntimeError("Extism register_cli_commands must return a list") + return [CommandDescriptor.from_descriptor(item) for item in value] + + +def _make_plugin_command( bridge: ExtismBridge, + plugin_name: str, config: ExtismPluginConfig, command_name: str, function_name: str, @@ -49,7 +178,11 @@ def command(payload: str = typer.Argument("{}", help="JSON payload for the comma raise typer.BadParameter("payload must be valid JSON") from exc result = bridge.call_hook_sync( "cli_command", - {"command": command_name, "payload": args}, + { + "plugin": plugin_name, + "command": command_name, + "payload": args, + }, config=config, function_name=function_name, ) @@ -57,3 +190,59 @@ def command(payload: str = typer.Argument("{}", help="JSON payload for the comma typer.echo(json.dumps(result, ensure_ascii=False, indent=2)) return command + + +def _format_plugin_list(config: ExtismConfig) -> str: + lines = [typer.style("Extism Plugins", bold=True)] + for name, plugin in config.plugins.items(): + lines.append(_format_single_plugin(name, plugin)) + return "\n".join(lines) + + +def _format_single_plugin(name: str, plugin: ExtismPluginConfig) -> str: + hooks = plugin.hooks + hook_text = ", ".join(f"{hook}->{export}" for hook, export in hooks.items()) if hooks else "No hooks" + wasi_text = "enabled" if plugin.wasi else "disabled" + return f"- {name}\n WASI: {wasi_text}\n Source: {_manifest_source(plugin.manifest)}\n Hooks: {hook_text}" + + +def _manifest_source(manifest: dict[str, Any]) -> str: + wasm_entries = manifest.get("wasm") + if not isinstance(wasm_entries, list) or not wasm_entries: + return "manifest" + + first_entry = wasm_entries[0] + if not isinstance(first_entry, dict): + return "manifest" + for key in ("path", "url", "name"): + value = first_entry.get(key) + if isinstance(value, str) and value.strip(): + return value + return "manifest" + + +def _load_manifest(manifest_path: Path) -> dict[str, Any]: + try: + raw = json.loads(manifest_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise typer.BadParameter("manifest file must contain valid JSON") from exc + if not isinstance(raw, dict): + raise typer.BadParameter("manifest file must contain a top-level object") + return raw + + +def _parse_hook_bindings(bindings: Iterable[str]) -> dict[str, str]: + payload: dict[str, str] = {} + for item in bindings: + if "=" not in item: + raise typer.BadParameter("--hook must be in HOOK=EXPORT format") + hook_name, export_name = item.split("=", 1) + hook_name = hook_name.strip() + export_name = export_name.strip() + if not hook_name or not export_name: + raise typer.BadParameter("--hook requires both hook and export names") + payload[hook_name] = export_name + try: + return normalize_hook_bindings(payload) + except ValueError as exc: + raise typer.BadParameter(str(exc)) from exc diff --git a/packages/bub-extism/src/bub_extism/codec.py b/packages/bub-extism/src/bub_extism/codec.py index f53cfc6..897cb1c 100644 --- a/packages/bub-extism/src/bub_extism/codec.py +++ b/packages/bub-extism/src/bub_extism/codec.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from collections.abc import Mapping, Sequence from dataclasses import asdict, is_dataclass from typing import Any @@ -9,6 +10,7 @@ from republic.tape.entries import utc_now BUB_EXTISM_ABI_VERSION = "bub.extism.v1" +_SKIP_JSON_VALUE = object() class ExtismHookError(RuntimeError): @@ -23,11 +25,11 @@ def build_request(hook_name: str, args: dict[str, Any]) -> dict[str, Any]: return { "abi_version": BUB_EXTISM_ABI_VERSION, "hook": hook_name, - "args": to_json_value(args), + "args": mapping_to_json(args), } -def decode_response(raw_result: Any, *, hook_name: str) -> Any: +def decode_response(raw_result: Any) -> Any: if raw_result is None: raise ExtismHookSkip @@ -44,19 +46,12 @@ def decode_response(raw_result: Any, *, hook_name: str) -> Any: raise ExtismHookSkip if not isinstance(parsed, dict): return parsed - if parsed.get("skip") is True: raise ExtismHookSkip if error := parsed.get("error"): - if isinstance(error, dict): - message = error.get("message", "Extism hook returned an error") - else: - message = str(error) - raise ExtismHookError(str(message)) + raise ExtismHookError(_error_message(error)) if "value" in parsed: return parsed["value"] - if hook_name in parsed: - return parsed[hook_name] if "text" in parsed: return parsed["text"] return parsed @@ -65,64 +60,45 @@ def decode_response(raw_result: Any, *, hook_name: str) -> Any: def result_to_text(raw_result: Any) -> str: if isinstance(raw_result, str): return raw_result - if isinstance(raw_result, bytes): - return raw_result.decode("utf-8") - if isinstance(raw_result, bytearray): + if isinstance(raw_result, bytes | bytearray | memoryview): return bytes(raw_result).decode("utf-8") - if isinstance(raw_result, memoryview): - return raw_result.tobytes().decode("utf-8") return bytes(raw_result).decode("utf-8") -def to_json_value(value: Any) -> Any: - if value is None or isinstance(value, str | int | float | bool): - return value - if isinstance(value, dict): - return { - str(key): to_json_value(item) - for key, item in value.items() - if is_json_safe(item) - } - if isinstance(value, list | tuple): - return [to_json_value(item) for item in value if is_json_safe(item)] - if isinstance(value, StreamEvent): - return {"kind": value.kind, "data": to_json_value(value.data)} - if isinstance(value, TapeEntry): - return tape_entry_to_dict(value) - if is_dataclass(value): - return to_json_value(asdict(value)) - if hasattr(value, "__dict__"): - return to_json_value(normalize_envelope(value)) - return str(value) +def message_to_json(message: Any) -> dict[str, Any]: + return mapping_to_json(normalize_envelope(message)) -def is_json_safe(value: Any) -> bool: - try: - json.dumps(to_json_value(value)) - except (TypeError, ValueError, RecursionError): - return False - return True +def error_to_json(error: Exception) -> dict[str, str]: + return { + "type": type(error).__name__, + "message": str(error), + } + + +def mapping_to_json(mapping: Mapping[str, Any]) -> dict[str, Any]: + return { + str(key): encoded + for key, value in mapping.items() + if (encoded := _encode_or_skip(value)) is not _SKIP_JSON_VALUE + } def state_to_json(state: dict[str, Any]) -> dict[str, Any]: - safe_state: dict[str, Any] = {} - for key, value in state.items(): - if str(key).startswith("_runtime_"): - continue - try: - json.dumps(value) - except (TypeError, ValueError): - continue - safe_state[str(key)] = to_json_value(value) - return safe_state + return { + str(key): encoded + for key, value in state.items() + if not str(key).startswith("_runtime_") + and (encoded := _encode_or_skip(value)) is not _SKIP_JSON_VALUE + } def tape_entry_to_dict(entry: TapeEntry) -> dict[str, Any]: return { "id": entry.id, "kind": entry.kind, - "payload": to_json_value(entry.payload), - "meta": to_json_value(entry.meta), + "payload": mapping_to_json(entry.payload), + "meta": mapping_to_json(entry.meta), "date": entry.date, } @@ -135,3 +111,40 @@ def tape_entry_from_dict(value: dict[str, Any]) -> TapeEntry: meta=dict(value.get("meta") or {}), date=str(value.get("date", "")) or utc_now(), ) + + +def _error_message(error: Any) -> str: + if isinstance(error, dict): + return str(error.get("message", "Extism hook returned an error")) + return str(error) + + +def _encode_or_skip(value: Any) -> Any: + try: + return _encode_json_value(value) + except (TypeError, ValueError, RecursionError): + return _SKIP_JSON_VALUE + + +def _encode_json_value(value: Any) -> Any: + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, Mapping): + return mapping_to_json({str(key): item for key, item in value.items()}) + if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray | memoryview): + return [encoded for item in value if (encoded := _encode_or_skip(item)) is not _SKIP_JSON_VALUE] + if isinstance(value, StreamEvent): + return { + "kind": value.kind, + "data": mapping_to_json(value.data), + } + if isinstance(value, TapeEntry): + return tape_entry_to_dict(value) + if is_dataclass(value): + dataclass_value = asdict(value) + if not isinstance(dataclass_value, Mapping): + raise TypeError("Dataclass value must encode to a mapping") + return mapping_to_json(dataclass_value) + if hasattr(value, "__dict__"): + return message_to_json(value) + raise TypeError(f"Unsupported Extism JSON value: {type(value).__name__}") diff --git a/packages/bub-extism/src/bub_extism/config.py b/packages/bub-extism/src/bub_extism/config.py index 38af296..a1022eb 100644 --- a/packages/bub-extism/src/bub_extism/config.py +++ b/packages/bub-extism/src/bub_extism/config.py @@ -4,9 +4,28 @@ from pathlib import Path from typing import Any -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict +PLUGIN_HOOK_NAMES = ( + "resolve_session", + "build_prompt", + "run_model", + "run_model_stream", + "load_state", + "save_state", + "render_outbound", + "dispatch_outbound", + "onboard_config", + "on_error", + "system_prompt", + "provide_tape_store", + "provide_channels", + "build_tape_context", +) +CLI_HOOK_NAME = "register_cli_commands" +ALLOWED_HOOK_NAMES = frozenset((*PLUGIN_HOOK_NAMES, CLI_HOOK_NAME)) + def default_config_path() -> Path: from bub.builtin.settings import load_settings @@ -14,62 +33,34 @@ def default_config_path() -> Path: return load_settings().home / "extism.json" -class ExtismHookMap(BaseModel): - resolve_session: str | None = None - build_prompt: str | None = None - run_model: str | None = None - run_model_stream: str | None = None - load_state: str | None = None - save_state: str | None = None - render_outbound: str | None = None - dispatch_outbound: str | None = None - register_cli_commands: str | None = None - onboard_config: str | None = None - on_error: str | None = None - system_prompt: str | None = None - provide_tape_store: str | None = None - provide_channels: str | None = None - build_tape_context: str | None = None +def normalize_hook_bindings(hooks: dict[str, str]) -> dict[str, str]: + normalized: dict[str, str] = {} + for hook_name, export_name in hooks.items(): + hook_text = str(hook_name).strip() + export_text = str(export_name).strip() + if hook_text not in ALLOWED_HOOK_NAMES: + supported = ", ".join(sorted(ALLOWED_HOOK_NAMES)) + raise ValueError(f"unsupported hook '{hook_text}'; expected one of: {supported}") + if not export_text: + raise ValueError(f"hook '{hook_text}' requires a non-empty export name") + normalized[hook_text] = export_text + return normalized class ExtismPluginConfig(BaseModel): - manifest: dict[str, Any] | None = None - wasm_path: Path | None = Field(default=None, alias="wasmPath") - wasm_url: str | None = Field(default=None, alias="wasmUrl") - hooks: ExtismHookMap = Field(default_factory=ExtismHookMap) - config: dict[str, str] = Field(default_factory=dict) + manifest: dict[str, Any] + hooks: dict[str, str] = Field(default_factory=dict) wasi: bool = False - @model_validator(mode="after") - def validate_wasm_source(self) -> ExtismPluginConfig: - sources = [ - self.manifest is not None, - self.wasm_path is not None, - self.wasm_url is not None, - ] - if sum(sources) != 1: - raise ValueError("exactly one of manifest, wasmPath, or wasmUrl is required") - return self - - def plugin_input(self) -> dict[str, Any] | bytes: - if self.manifest is not None: - return self.manifest - if self.wasm_url is not None: - return {"wasm": [{"url": self.wasm_url}]} - if self.wasm_path is None: - raise RuntimeError("wasmPath is required") - return self.wasm_path.expanduser().read_bytes() + @field_validator("hooks") + @classmethod + def validate_hooks(cls, hooks: dict[str, str]) -> dict[str, str]: + return normalize_hook_bindings(hooks) class ExtismConfig(BaseModel): - default_plugin: str | None = Field(default=None, alias="defaultPlugin") plugins: dict[str, ExtismPluginConfig] = Field(default_factory=dict) - def selected_plugin(self) -> ExtismPluginConfig | None: - if self.default_plugin is None: - return None - return self.plugins.get(self.default_plugin) - class ExtismSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="BUB_EXTISM_", extra="ignore") @@ -83,3 +74,11 @@ def read_config(self) -> ExtismConfig: if not isinstance(raw, dict): raise RuntimeError("Extism config file must contain a top-level mapping") return ExtismConfig.model_validate(raw) + + def write_config(self, config: ExtismConfig) -> None: + self.config_path.parent.mkdir(parents=True, exist_ok=True) + payload = config.model_dump(mode="json") + self.config_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) diff --git a/packages/bub-extism/src/bub_extism/descriptors.py b/packages/bub-extism/src/bub_extism/descriptors.py new file mode 100644 index 0000000..4e93a69 --- /dev/null +++ b/packages/bub-extism/src/bub_extism/descriptors.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import Any + + +def require_mapping(value: Any, *, message: str) -> dict[str, Any]: + if not isinstance(value, dict): + raise RuntimeError(message) + return value + + +def required_text(value: Any, *, message: str) -> str: + text = str(value or "").strip() + if not text: + raise RuntimeError(message) + return text + + +def normalize_function_bindings( + value: Any, + *, + message: str, + missing_ok: bool, +) -> dict[str, str]: + if value is None: + if missing_ok: + return {} + raise RuntimeError(message) + + data = require_mapping(value, message=message) + bindings: dict[str, str] = {} + for operation, function_name in data.items(): + operation_text = required_text( + operation, + message="Extism functions must map operation names to export names", + ) + function_text = required_text( + function_name, + message="Extism functions must map operation names to export names", + ) + bindings[operation_text] = function_text + return bindings diff --git a/packages/bub-extism/src/bub_extism/plugin.py b/packages/bub-extism/src/bub_extism/plugin.py index 00c4882..85fcb83 100644 --- a/packages/bub-extism/src/bub_extism/plugin.py +++ b/packages/bub-extism/src/bub_extism/plugin.py @@ -1,15 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from bub import hookimpl from bub_extism.bridge import ExtismBridge -from bub_extism.channel import channels_from_descriptors +from bub_extism.channel import channels_from_value from bub_extism.cli import register_cli_commands -from bub_extism.codec import state_to_json, to_json_value -from bub_extism.config import ExtismSettings +from bub_extism.codec import error_to_json, mapping_to_json, message_to_json, state_to_json +from bub_extism.config import ExtismPluginConfig, ExtismSettings, PLUGIN_HOOK_NAMES from bub_extism.stream import stream_events_from_value -from bub_extism.tape_store import tape_store_from_descriptor +from bub_extism.tape_store import tape_store_from_value from republic import AsyncStreamEvents, TapeContext from republic.tape.context import LAST_ANCHOR @@ -21,256 +21,286 @@ from republic.tape import TapeStore +def _message_args(message: Envelope) -> dict[str, Any]: + return {"message": message_to_json(message)} + + +def _message_session_args(message: Envelope, session_id: str) -> dict[str, Any]: + return { + **_message_args(message), + "session_id": session_id, + } + + +def _state_args(state: State) -> dict[str, Any]: + return {"state": state_to_json(state)} + + +def _message_session_state_args(message: Envelope, session_id: str, state: State) -> dict[str, Any]: + return { + **_message_session_args(message, session_id), + **_state_args(state), + } + + +def _prompt_session_state_args( + prompt: str | list[dict[str, Any]], + session_id: str, + state: State, +) -> dict[str, Any]: + return { + "prompt": prompt, + "session_id": session_id, + **_state_args(state), + } + + +def _require_mapping(value: Any, *, hook_name: str) -> dict[str, Any]: + if not isinstance(value, dict): + raise RuntimeError(f"Extism {hook_name} must return an object") + return value + + +def _require_string(value: Any, *, hook_name: str) -> str: + if not isinstance(value, str): + raise RuntimeError(f"Extism {hook_name} must return a string") + return value + + +def _optional_mapping(value: Any, *, hook_name: str) -> dict[str, Any] | None: + if value is None: + return None + return _require_mapping(value, hook_name=hook_name) + + +def _optional_string(value: Any, *, hook_name: str) -> str | None: + if value is None: + return None + return _require_string(value, hook_name=hook_name) + + +def _prompt_value(value: Any) -> str | list[dict[str, Any]] | None: + if value is None or isinstance(value, str): + return value + if isinstance(value, list): + return value + raise RuntimeError("Extism build_prompt must return a string or content-part list") + + +def _outbound_messages(value: Any) -> list[Envelope]: + if value is None: + return [] + if isinstance(value, dict): + return [value] + if isinstance(value, list): + return value + raise RuntimeError("Extism render_outbound must return an envelope or envelope list") + + +def _tape_context(value: Any) -> TapeContext | None: + if value is None: + return None + + data = _require_mapping(value, hook_name="build_tape_context") + anchor_value = data.get("anchor", "last") + if anchor_value is None: + anchor = None + elif str(anchor_value).lower() in {"last", "last_anchor"}: + anchor = LAST_ANCHOR + else: + anchor = str(anchor_value) + + state = _require_mapping(data.get("state", {}), hook_name="build_tape_context state") + return TapeContext(anchor=anchor, state=state) + + class ExtismPlugin: - def __init__(self, framework: BubFramework) -> None: + def __init__( + self, + framework: BubFramework, + *, + settings: ExtismSettings | None = None, + ) -> None: self.framework = framework - self.settings = ExtismSettings() - self.bridge = ExtismBridge(self.settings) - self._register_model_hook_adapter() - - def _register_model_hook_adapter(self) -> None: - config = self.bridge.selected_config() - if config is None: - return + self.settings = settings or ExtismSettings() + self.bridge = ExtismBridge() + self._register_hook_adapters() + def _register_hook_adapters(self) -> None: plugin_manager = getattr(self.framework, "_plugin_manager", None) if plugin_manager is None: return - if config.hooks.run_model_stream is not None: - plugin_manager.register( - _ExtismRunModelStreamPlugin(self.bridge), - name="extism-run-model-stream", - ) - return - - if config.hooks.run_model is not None: - plugin_manager.register( - _ExtismRunModelPlugin(self.bridge), - name="extism-run-model", - ) + for plugin_name, config in self.settings.read_config().plugins.items(): + adapter = build_hook_adapter(plugin_name, self.bridge, config) + if adapter is not None: + plugin_manager.register(adapter, name=f"extism:{plugin_name}") @hookimpl - def resolve_session(self, message: Envelope) -> str | None: - value = self.bridge.call_hook_sync( - "resolve_session", - {"message": to_json_value(message)}, - ) - if value is None: - return None - return str(value) + def register_cli_commands(self, app: typer.Typer) -> None: + register_cli_commands(app, self.settings, self.bridge) - @hookimpl - async def build_prompt( + +class ExtismHookAdapter: + def __init__(self, bridge: ExtismBridge, config: ExtismPluginConfig) -> None: + self.bridge = bridge + self.config = config + + def _call_sync(self, hook_name: str, **args: Any) -> Any: + return self.bridge.call_hook_sync(hook_name, args, config=self.config) + + async def _call(self, hook_name: str, **args: Any) -> Any: + return await self.bridge.call_hook(hook_name, args, config=self.config) + + def hook_resolve_session(self, message: Envelope) -> str | None: + value = self._call_sync("resolve_session", **_message_args(message)) + return None if value is None else str(value) + + async def hook_build_prompt( self, message: Envelope, session_id: str, state: State, ) -> str | list[dict[str, Any]] | None: - value = await self.bridge.call_hook( - "build_prompt", - { - "message": to_json_value(message), - "session_id": session_id, - "state": state_to_json(state), - }, + return _prompt_value( + await self._call( + "build_prompt", + **_message_session_state_args(message, session_id, state), + ) ) - if value is None: - return None - if isinstance(value, str | list): - return cast(str | list[dict[str, Any]], value) - raise RuntimeError("Extism build_prompt must return a string or content-part list") - @hookimpl - async def load_state(self, message: Envelope, session_id: str) -> State | None: - value = await self.bridge.call_hook( - "load_state", - {"message": to_json_value(message), "session_id": session_id}, + async def hook_load_state(self, message: Envelope, session_id: str) -> State | None: + return _optional_mapping( + await self._call("load_state", **_message_session_args(message, session_id)), + hook_name="load_state", ) - if value is None: - return None - if not isinstance(value, dict): - raise RuntimeError("Extism load_state must return an object") - return value - @hookimpl - async def save_state( + async def hook_save_state( self, session_id: str, state: State, message: Envelope, model_output: str, ) -> None: - await self.bridge.call_hook( + await self._call( "save_state", - { - "session_id": session_id, - "state": state_to_json(state), - "message": to_json_value(message), - "model_output": model_output, - }, + **_message_session_state_args(message, session_id, state), + model_output=model_output, ) - @hookimpl - def render_outbound( + def hook_render_outbound( self, message: Envelope, session_id: str, state: State, model_output: str, ) -> list[Envelope]: - value = self.bridge.call_hook_sync( - "render_outbound", - { - "message": to_json_value(message), - "session_id": session_id, - "state": state_to_json(state), - "model_output": model_output, - }, - ) - if value is None: - return [] - if isinstance(value, dict): - return [value] - if isinstance(value, list): - return value - raise RuntimeError("Extism render_outbound must return an envelope or envelope list") - - @hookimpl - async def dispatch_outbound(self, message: Envelope) -> bool: - value = await self.bridge.call_hook( - "dispatch_outbound", - {"message": to_json_value(message)}, + return _outbound_messages( + self._call_sync( + "render_outbound", + **_message_session_state_args(message, session_id, state), + model_output=model_output, + ) ) - return bool(value) - @hookimpl - def register_cli_commands(self, app: typer.Typer) -> None: - config = self.bridge.selected_config() - if config is None or config.hooks.register_cli_commands is None: - return - descriptors = self.bridge.call_hook_sync("register_cli_commands", {"commands": []}, config=config) - register_cli_commands(app, self.bridge, config, descriptors) + async def hook_dispatch_outbound(self, message: Envelope) -> bool: + return bool(await self._call("dispatch_outbound", **_message_args(message))) - @hookimpl - def onboard_config(self, current_config: dict[str, Any]) -> dict[str, Any] | None: - value = self.bridge.call_hook_sync( - "onboard_config", - {"current_config": to_json_value(current_config)}, + def hook_onboard_config(self, current_config: dict[str, Any]) -> dict[str, Any] | None: + return _optional_mapping( + self._call_sync("onboard_config", current_config=mapping_to_json(current_config)), + hook_name="onboard_config", ) - if value is None: - return None - if not isinstance(value, dict): - raise RuntimeError("Extism onboard_config must return an object") - return value - @hookimpl - async def on_error(self, stage: str, error: Exception, message: Envelope | None) -> None: - await self.bridge.call_hook( + async def hook_on_error( + self, + stage: str, + error: Exception, + message: Envelope | None, + ) -> None: + await self._call( "on_error", - { - "stage": stage, - "error": { - "type": type(error).__name__, - "message": str(error), - }, - "message": to_json_value(message), - }, + stage=stage, + error=error_to_json(error), + message=None if message is None else message_to_json(message), ) - @hookimpl - def system_prompt(self, prompt: str | list[dict[str, Any]], state: State) -> str | None: - value = self.bridge.call_hook_sync( - "system_prompt", - {"prompt": prompt, "state": state_to_json(state)}, + def hook_system_prompt(self, prompt: str | list[dict[str, Any]], state: State) -> str | None: + return _optional_string( + self._call_sync("system_prompt", prompt=prompt, **_state_args(state)), + hook_name="system_prompt", ) - if value is None: - return None - if not isinstance(value, str): - raise RuntimeError("Extism system_prompt must return a string") - return value - @hookimpl - def provide_tape_store(self) -> TapeStore | None: - config = self.bridge.selected_config() - if config is None or config.hooks.provide_tape_store is None: - return None - descriptor = self.bridge.call_hook_sync("provide_tape_store", {}, config=config) - return tape_store_from_descriptor(self.bridge, config, descriptor) + def hook_provide_tape_store(self) -> TapeStore | None: + return tape_store_from_value( + self.bridge, + self.config, + self._call_sync("provide_tape_store"), + ) - @hookimpl - def provide_channels(self, message_handler: MessageHandler) -> list[Channel]: - config = self.bridge.selected_config() - if config is None or config.hooks.provide_channels is None: - return [] - descriptors = self.bridge.call_hook_sync("provide_channels", {}, config=config) - return channels_from_descriptors(self.bridge, config, descriptors, message_handler) + def hook_provide_channels(self, message_handler: MessageHandler) -> list[Channel]: + return channels_from_value( + self.bridge, + self.config, + self._call_sync("provide_channels"), + message_handler, + ) - @hookimpl - def build_tape_context(self) -> TapeContext | None: - value = self.bridge.call_hook_sync("build_tape_context", {}) - if value is None: - return None - if not isinstance(value, dict): - raise RuntimeError("Extism build_tape_context must return an object") - - anchor_value = value.get("anchor", "last") - if anchor_value is None: - anchor = None - elif str(anchor_value).lower() in {"last", "last_anchor"}: - anchor = LAST_ANCHOR - else: - anchor = str(anchor_value) - - state = value.get("state", {}) - if not isinstance(state, dict): - raise RuntimeError("Extism build_tape_context state must be an object") - return TapeContext(anchor=anchor, state=state) - - -class _ExtismRunModelPlugin: - def __init__(self, bridge: ExtismBridge) -> None: - self.bridge = bridge + def hook_build_tape_context(self) -> TapeContext | None: + return _tape_context(self._call_sync("build_tape_context")) - @hookimpl - async def run_model( + async def hook_run_model( self, prompt: str | list[dict[str, Any]], session_id: str, state: State, ) -> str | None: - value = await self.bridge.call_hook( - "run_model", - { - "prompt": prompt, - "session_id": session_id, - "state": state_to_json(state), - }, + return _optional_string( + await self._call( + "run_model", + **_prompt_session_state_args(prompt, session_id, state), + ), + hook_name="run_model", ) - if value is None: - return None - if not isinstance(value, str): - raise RuntimeError("Extism run_model must return a string") - return value - -class _ExtismRunModelStreamPlugin: - def __init__(self, bridge: ExtismBridge) -> None: - self.bridge = bridge - - @hookimpl - async def run_model_stream( + async def hook_run_model_stream( self, prompt: str | list[dict[str, Any]], session_id: str, state: State, ) -> AsyncStreamEvents | None: - value = await self.bridge.call_hook( + value = await self._call( "run_model_stream", - { - "prompt": prompt, - "session_id": session_id, - "state": state_to_json(state), - }, + **_prompt_session_state_args(prompt, session_id, state), ) return stream_events_from_value(value) + + +def build_hook_adapter( + plugin_name: str, + bridge: ExtismBridge, + config: ExtismPluginConfig, +) -> ExtismHookAdapter | None: + enabled_hook_names = tuple( + hook_name + for hook_name in PLUGIN_HOOK_NAMES + if hook_name in config.hooks + ) + if not enabled_hook_names: + return None + + attrs = { + hook_name: hookimpl(getattr(ExtismHookAdapter, f"hook_{hook_name}")) + for hook_name in enabled_hook_names + } + adapter_type = type( + f"ExtismHookAdapter_{_class_name_fragment(plugin_name)}", + (ExtismHookAdapter,), + attrs, + ) + return adapter_type(bridge, config) + + +def _class_name_fragment(plugin_name: str) -> str: + text = "".join(char if char.isalnum() else "_" for char in plugin_name.strip()) + return text or "Plugin" diff --git a/packages/bub-extism/src/bub_extism/stream.py b/packages/bub-extism/src/bub_extism/stream.py index f39e52e..b7113b7 100644 --- a/packages/bub-extism/src/bub_extism/stream.py +++ b/packages/bub-extism/src/bub_extism/stream.py @@ -10,24 +10,29 @@ def stream_events_from_value(value: Any) -> AsyncStreamEvents | None: if value is None: return None - events_value = value + events, state = _stream_payload(value) + + async def iterator() -> AsyncIterator[StreamEvent]: + for event in events: + yield event + + return AsyncStreamEvents(iterator(), state=state) + + +def _stream_payload(value: Any) -> tuple[list[StreamEvent], StreamState]: state = StreamState() + events_value = value if isinstance(value, dict): events_value = value.get("events", []) usage = value.get("usage") - if isinstance(usage, dict): + if usage is not None: + if not isinstance(usage, dict): + raise RuntimeError("Extism run_model_stream usage must be a JSON object") state.usage = usage if not isinstance(events_value, list): raise RuntimeError("Extism run_model_stream must return a list of stream events") - - events = [_stream_event_from_dict(item) for item in events_value] - - async def iterator() -> AsyncIterator[StreamEvent]: - for event in events: - yield event - - return AsyncStreamEvents(iterator(), state=state) + return ([_stream_event_from_dict(item) for item in events_value], state) def _stream_event_from_dict(value: Any) -> StreamEvent: diff --git a/packages/bub-extism/src/bub_extism/tape_store.py b/packages/bub-extism/src/bub_extism/tape_store.py index bc6d22d..5f18b70 100644 --- a/packages/bub-extism/src/bub_extism/tape_store.py +++ b/packages/bub-extism/src/bub_extism/tape_store.py @@ -6,8 +6,9 @@ from republic import TapeEntry from bub_extism.bridge import ExtismBridge -from bub_extism.codec import tape_entry_from_dict, tape_entry_to_dict, to_json_value +from bub_extism.codec import tape_entry_from_dict, tape_entry_to_dict from bub_extism.config import ExtismPluginConfig +from bub_extism.descriptors import normalize_function_bindings, require_mapping if TYPE_CHECKING: from republic import TapeQuery @@ -18,12 +19,25 @@ def __init__( self, bridge: ExtismBridge, config: ExtismPluginConfig, - descriptor: dict[str, Any], + *, + functions: dict[str, str], ) -> None: self.bridge = bridge self.config = config - self.descriptor = descriptor - self.functions = dict(descriptor.get("functions") or {}) + self.functions = functions + + @classmethod + def from_descriptor( + cls, + bridge: ExtismBridge, + config: ExtismPluginConfig, + descriptor: Any, + ) -> ExtismTapeStore: + data = require_mapping( + descriptor, + message="Extism provide_tape_store must return a descriptor object", + ) + return cls(bridge, config, functions=_functions_from_descriptor(data)) def list_tapes(self) -> list[str]: value = self._call("list_tapes", {}) @@ -42,7 +56,7 @@ def fetch_all(self, query: TapeQuery) -> Iterable[TapeEntry]: return [] if not isinstance(value, list): raise RuntimeError("Extism tape fetch_all must return a list") - return [tape_entry_from_dict(item) for item in value if isinstance(item, dict)] + return [tape_entry_from_dict(require_mapping(item, message="Extism tape entry must be an object")) for item in value] def append(self, tape: str, entry: TapeEntry) -> None: self._call("append", {"tape": tape, "entry": tape_entry_to_dict(entry)}) @@ -59,16 +73,22 @@ def _call(self, operation: str, args: dict[str, Any]) -> Any: ) -def tape_store_from_descriptor( +def tape_store_from_value( bridge: ExtismBridge, config: ExtismPluginConfig, - descriptor: Any, + value: Any, ) -> ExtismTapeStore | None: - if descriptor is None: + if value is None: return None - if not isinstance(descriptor, dict): - raise RuntimeError("Extism provide_tape_store must return a descriptor object") - return ExtismTapeStore(bridge, config, descriptor) + return ExtismTapeStore.from_descriptor(bridge, config, value) + + +def _functions_from_descriptor(descriptor: dict[str, Any]) -> dict[str, str]: + return normalize_function_bindings( + descriptor.get("functions"), + message="Extism tape store descriptor must include a functions object", + missing_ok=False, + ) def _query_to_dict(query: TapeQuery) -> dict[str, Any]: @@ -77,8 +97,8 @@ def _query_to_dict(query: TapeQuery) -> dict[str, Any]: "query": query._query, "after_anchor": query._after_anchor, "after_last": query._after_last, - "between_anchors": to_json_value(query._between_anchors), - "between_dates": to_json_value(query._between_dates), + "between_anchors": list(query._between_anchors) if query._between_anchors is not None else None, + "between_dates": list(query._between_dates) if query._between_dates is not None else None, "kinds": list(query._kinds), "limit": query._limit, } diff --git a/packages/bub-extism/tests/test_bridge.py b/packages/bub-extism/tests/test_bridge.py index 6eca477..079cd5d 100644 --- a/packages/bub-extism/tests/test_bridge.py +++ b/packages/bub-extism/tests/test_bridge.py @@ -7,13 +7,12 @@ from types import SimpleNamespace from typing import Any -import pytest import pluggy +import pytest from republic import TapeEntry, TapeQuery from bub.hook_runtime import HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs -from bub_extism.bridge import ExtismBridge from bub_extism.config import ExtismSettings from bub_extism.plugin import ExtismPlugin @@ -60,7 +59,7 @@ def call(self, function_name: str, data: str) -> Any: @pytest.fixture(autouse=True) -def fake_extism(monkeypatch): +def fake_extism(monkeypatch: pytest.MonkeyPatch) -> None: FakePlugin.calls = [] FakePlugin.exports = {} monkeypatch.setitem(sys.modules, "extism", SimpleNamespace(Plugin=FakePlugin)) @@ -72,98 +71,61 @@ def _write_config(tmp_path: Path, body: dict[str, Any]) -> Path: return config_path -def _bridge(config_path: Path) -> ExtismBridge: - return ExtismBridge(ExtismSettings(config_path=config_path)) - - -def _plugin(config_path: Path) -> ExtismPlugin: - plugin = ExtismPlugin(SimpleNamespace()) - plugin.bridge = _bridge(config_path) - return plugin - - -def _runtime(config_path: Path, monkeypatch: pytest.MonkeyPatch) -> HookRuntime: - monkeypatch.setenv("BUB_EXTISM_CONFIG_PATH", str(config_path)) +def _runtime(config_path: Path) -> HookRuntime: + settings = ExtismSettings(config_path=config_path) plugin_manager = pluggy.PluginManager(BUB_HOOK_NAMESPACE) plugin_manager.add_hookspecs(BubHookSpecs) framework = SimpleNamespace(_plugin_manager=plugin_manager) - plugin = ExtismPlugin(framework) + plugin = ExtismPlugin(framework, settings=settings) plugin_manager.register(plugin, name="extism") return HookRuntime(plugin_manager) -def test_plugin_exposes_all_non_model_standard_bub_hooks() -> None: - expected_hooks = { - "resolve_session", - "build_prompt", - "load_state", - "save_state", - "render_outbound", - "dispatch_outbound", - "register_cli_commands", - "onboard_config", - "on_error", - "system_prompt", - "provide_tape_store", - "provide_channels", - "build_tape_context", - } - - assert expected_hooks <= set(dir(ExtismPlugin)) +def _flatten_channel_results(results: list[list[Any]]) -> list[Any]: + channels: list[Any] = [] + for batch in results: + channels.extend(batch) + return channels -def test_model_hook_adapter_registers_only_one_model_surface( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch -) -> None: +def test_runtime_registers_configured_hook_adapters(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "model", "plugins": { + "prompt": { + "manifest": {"wasm": [{"path": "./prompt.wasm"}]}, + "hooks": {"build_prompt": "build_prompt"}, + }, "model": { - "wasmUrl": "https://example.com/model.wasm", - "hooks": { - "run_model": "run_model", - "run_model_stream": "run_model_stream", - }, - } - }, + "manifest": {"wasm": [{"path": "./model.wasm"}]}, + "hooks": {"run_model": "run_model"}, + }, + } }, ) - runtime = _runtime(config_path, monkeypatch) - - report = runtime.hook_report() - assert report["run_model_stream"] == ["extism-run-model-stream"] - assert "run_model" not in report - - -def test_call_hook_returns_none_without_selected_plugin(tmp_path: Path) -> None: - config_path = _write_config(tmp_path, {"plugins": {}}) - - result = _bridge(config_path).call_hook_sync("run_model", {"prompt": "hello"}) - - assert result is None - assert FakePlugin.calls == [] + report = _runtime(config_path).hook_report() + assert report["build_prompt"] == ["extism:prompt"] + assert report["run_model"] == ["extism:model"] def test_run_model_calls_configured_export_with_unified_request( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - wasm_path = tmp_path / "plugin.wasm" - wasm_path.write_bytes(b"\0asm") config_path = _write_config( tmp_path, { - "defaultPlugin": "echo", "plugins": { "echo": { - "wasmPath": str(wasm_path), + "manifest": { + "wasm": [{"path": "./plugin.wasm", "hash": "demo"}], + "config": {"model": "demo"}, + }, "wasi": True, - "config": {"model": "demo"}, "hooks": {"run_model": "bub_run_model"}, } - }, + } }, ) FakePlugin.exports = { @@ -179,7 +141,7 @@ def test_run_model_calls_configured_export_with_unified_request( } result = asyncio.run( - _runtime(config_path, monkeypatch).run_model( + _runtime(config_path).run_model( prompt="hello", session_id="s1", state={ @@ -203,9 +165,12 @@ def test_run_model_calls_configured_export_with_unified_request( "state": {"visible": {"ok": True}}, }, }, - "plugin_input": b"\0asm", + "plugin_input": { + "wasm": [{"path": "./plugin.wasm", "hash": "demo"}], + "config": {"model": "demo"}, + }, "wasi": True, - "config": {"model": "demo"}, + "config": None, } ] @@ -214,18 +179,21 @@ def test_system_prompt_accepts_plain_text_result(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "prompt", "plugins": { "prompt": { - "wasmUrl": "https://example.com/prompt.wasm", + "manifest": {"wasm": [{"url": "https://example.com/prompt.wasm"}]}, "hooks": {"system_prompt": "system_prompt"}, } - }, + } }, ) FakePlugin.exports = {"system_prompt": b"from wasm"} - result = _plugin(config_path).system_prompt("hello", {"session_id": "s1"}) + result = _runtime(config_path).call_first_sync( + "system_prompt", + prompt="hello", + state={"session_id": "s1"}, + ) assert result == "from wasm" assert FakePlugin.calls[0]["plugin_input"] == { @@ -233,22 +201,21 @@ def test_system_prompt_accepts_plain_text_result(tmp_path: Path) -> None: } -def test_missing_export_skips_hook(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: +def test_missing_export_skips_hook(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "missing", "plugins": { "missing": { "manifest": {"wasm": [{"url": "https://example.com/plugin.wasm"}]}, "hooks": {"run_model": "missing_run_model"}, } - }, + } }, ) result = asyncio.run( - _runtime(config_path, monkeypatch).run_model( + _runtime(config_path).run_model( prompt="hello", session_id="s1", state={}, @@ -259,19 +226,16 @@ def test_missing_export_skips_hook(tmp_path: Path, monkeypatch: pytest.MonkeyPat assert FakePlugin.calls == [] -def test_run_model_stream_wraps_returned_events( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch -) -> None: +def test_run_model_stream_wraps_returned_events(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "stream", "plugins": { "stream": { - "wasmUrl": "https://example.com/stream.wasm", + "manifest": {"wasm": [{"url": "https://example.com/stream.wasm"}]}, "hooks": {"run_model_stream": "run_model_stream"}, } - }, + } }, ) FakePlugin.exports = { @@ -289,7 +253,7 @@ def test_run_model_stream_wraps_returned_events( } stream = asyncio.run( - _runtime(config_path, monkeypatch).run_model_stream( + _runtime(config_path).run_model_stream( prompt="hello", session_id="s1", state={}, @@ -304,6 +268,99 @@ def test_run_model_stream_wraps_returned_events( assert stream.usage == {"output_tokens": 1} +def test_run_model_stream_rejects_invalid_usage_shape(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "stream": { + "manifest": {"wasm": [{"url": "https://example.com/stream.wasm"}]}, + "hooks": {"run_model_stream": "run_model_stream"}, + } + } + }, + ) + FakePlugin.exports = { + "run_model_stream": json.dumps( + { + "value": { + "events": [], + "usage": "invalid", + } + } + ) + } + + with pytest.raises(RuntimeError, match="usage must be a JSON object"): + asyncio.run( + _runtime(config_path).run_model_stream( + prompt="hello", + session_id="s1", + state={}, + ) + ) + + +def test_build_prompt_and_run_model_can_be_split_across_plugins(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "prompt": { + "manifest": {"wasm": [{"path": "./prompt.wasm"}]}, + "hooks": {"build_prompt": "build_prompt"}, + }, + "model": { + "manifest": {"wasm": [{"path": "./model.wasm"}]}, + "hooks": {"run_model": "run_model"}, + }, + } + }, + ) + FakePlugin.exports = { + "build_prompt": lambda request: json.dumps( + { + "value": ( + f"[prompt:{request['args']['session_id']}] " + f"{request['args']['message']['content']}" + ) + } + ), + "run_model": lambda request: json.dumps( + { + "value": ( + f"[model:{request['args']['session_id']}] " + f"{request['args']['prompt']}" + ) + } + ), + } + + runtime = _runtime(config_path) + prompt = asyncio.run( + runtime.call_first( + "build_prompt", + message={"content": "hello from bub"}, + session_id="example", + state={}, + ) + ) + output = asyncio.run( + runtime.run_model( + prompt=prompt, + session_id="example", + state={}, + ) + ) + + assert prompt == "[prompt:example] hello from bub" + assert output == "[model:example] [prompt:example] hello from bub" + assert [call["function_name"] for call in FakePlugin.calls] == [ + "build_prompt", + "run_model", + ] + + async def _collect_stream(stream): return [event async for event in stream] @@ -312,13 +369,12 @@ def test_tape_store_proxy_forwards_operations(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "tape", "plugins": { "tape": { - "wasmUrl": "https://example.com/tape.wasm", + "manifest": {"wasm": [{"url": "https://example.com/tape.wasm"}]}, "hooks": {"provide_tape_store": "provide_tape_store"}, } - }, + } }, ) FakePlugin.exports = { @@ -352,7 +408,7 @@ def test_tape_store_proxy_forwards_operations(tmp_path: Path) -> None: "reset": json.dumps({"skip": True}), } - store = _plugin(config_path).provide_tape_store() + store = _runtime(config_path).call_first_sync("provide_tape_store") assert store is not None assert store.list_tapes() == ["main"] @@ -370,17 +426,94 @@ def test_tape_store_proxy_forwards_operations(tmp_path: Path) -> None: store.reset("main") +def test_tape_store_rejects_invalid_entry_shape(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "tape": { + "manifest": {"wasm": [{"url": "https://example.com/tape.wasm"}]}, + "hooks": {"provide_tape_store": "provide_tape_store"}, + } + } + }, + ) + FakePlugin.exports = { + "provide_tape_store": json.dumps( + { + "value": { + "functions": { + "fetch_all": "fetch_all", + } + } + } + ), + "fetch_all": json.dumps({"value": ["bad-entry"]}), + } + + store = _runtime(config_path).call_first_sync("provide_tape_store") + + assert store is not None + with pytest.raises(RuntimeError, match="tape entry must be an object"): + list(store.fetch_all(TapeQuery("main", store))) + + +def test_tape_store_rejects_invalid_functions_shape(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "tape": { + "manifest": {"wasm": [{"url": "https://example.com/tape.wasm"}]}, + "hooks": {"provide_tape_store": "provide_tape_store"}, + } + } + }, + ) + FakePlugin.exports = { + "provide_tape_store": json.dumps( + { + "value": { + "functions": ["fetch_all"], + } + } + ) + } + + with pytest.raises(RuntimeError, match="functions object"): + _runtime(config_path).call_first_sync("provide_tape_store") + + +def test_tape_store_requires_functions_object(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "tape": { + "manifest": {"wasm": [{"url": "https://example.com/tape.wasm"}]}, + "hooks": {"provide_tape_store": "provide_tape_store"}, + } + } + }, + ) + FakePlugin.exports = { + "provide_tape_store": json.dumps({"value": {}}), + } + + with pytest.raises(RuntimeError, match="functions object"): + _runtime(config_path).call_first_sync("provide_tape_store") + + def test_channel_proxy_forwards_send(tmp_path: Path) -> None: config_path = _write_config( tmp_path, { - "defaultPlugin": "channel", "plugins": { "channel": { - "wasmUrl": "https://example.com/channel.wasm", + "manifest": {"wasm": [{"url": "https://example.com/channel.wasm"}]}, "hooks": {"provide_channels": "provide_channels"}, } - }, + } }, ) FakePlugin.exports = { @@ -402,9 +535,36 @@ def test_channel_proxy_forwards_send(tmp_path: Path) -> None: async def handler(message: dict[str, Any]) -> None: del message - channels = _plugin(config_path).provide_channels(handler) + channel_batches = _runtime(config_path).call_many_sync( + "provide_channels", + message_handler=handler, + ) + channels = _flatten_channel_results(channel_batches) assert [channel.name for channel in channels] == ["wasm"] asyncio.run(channels[0].send({"content": "hello"})) assert FakePlugin.calls[-1]["function_name"] == "channel_send" assert FakePlugin.calls[-1]["payload"]["hook"] == "channel.send" + + +def test_channel_proxy_rejects_invalid_wrapper_shape(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "plugins": { + "channel": { + "manifest": {"wasm": [{"url": "https://example.com/channel.wasm"}]}, + "hooks": {"provide_channels": "provide_channels"}, + } + } + }, + ) + FakePlugin.exports = { + "provide_channels": json.dumps({"value": {"channels": "bad"}}), + } + + async def handler(message: dict[str, Any]) -> None: + del message + + with pytest.raises(RuntimeError, match="list of channel descriptors"): + _runtime(config_path).call_many_sync("provide_channels", message_handler=handler) diff --git a/packages/bub-extism/tests/test_cli.py b/packages/bub-extism/tests/test_cli.py new file mode 100644 index 0000000..8515dd2 --- /dev/null +++ b/packages/bub-extism/tests/test_cli.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest +import typer +from typer.testing import CliRunner + +from bub_extism.bridge import ExtismBridge +from bub_extism.cli import register_cli_commands +from bub_extism.config import ExtismSettings + +runner = CliRunner() + + +class FakePlugin: + exports: dict[str, Any] = {} + + def __init__( + self, + plugin_input: dict[str, Any] | bytes, + *, + wasi: bool = False, + config: dict[str, str] | None = None, + ) -> None: + del plugin_input, wasi, config + + def __enter__(self) -> FakePlugin: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def function_exists(self, name: str) -> bool: + return name in self.exports + + def call(self, function_name: str, data: str) -> Any: + payload = json.loads(data) + result = self.exports[function_name] + if callable(result): + return result(payload) + return result + + +@pytest.fixture(autouse=True) +def fake_extism(monkeypatch: pytest.MonkeyPatch) -> None: + FakePlugin.exports = {} + monkeypatch.setitem(sys.modules, "extism", SimpleNamespace(Plugin=FakePlugin)) + + +def _make_app(tmp_path: Path) -> tuple[typer.Typer, ExtismSettings]: + settings = ExtismSettings(config_path=tmp_path / "extism.json") + app = typer.Typer() + register_cli_commands(app, settings, ExtismBridge()) + return app, settings + + +def _write_manifest(path: Path, *, wasm_path: str) -> None: + path.write_text( + json.dumps({"wasm": [{"path": wasm_path}], "allowed_hosts": ["example.com"]}), + encoding="utf-8", + ) + + +def test_add_list_show_and_remove_plugin(tmp_path: Path) -> None: + app, settings = _make_app(tmp_path) + manifest_path = tmp_path / "plugin.manifest.json" + _write_manifest(manifest_path, wasm_path="./demo.wasm") + + add_result = runner.invoke( + app, + [ + "extism", + "add", + "demo", + str(manifest_path), + "--hook", + "build_prompt=build_prompt", + "--hook", + "run_model=run_model", + "--wasi", + ], + ) + + assert add_result.exit_code == 0 + assert "Added Extism plugin 'demo'." in add_result.stdout + assert settings.read_config().model_dump(mode="json") == { + "plugins": { + "demo": { + "manifest": { + "wasm": [{"path": "./demo.wasm"}], + "allowed_hosts": ["example.com"], + }, + "hooks": { + "build_prompt": "build_prompt", + "run_model": "run_model", + }, + "wasi": True, + } + } + } + + list_result = runner.invoke(app, ["extism", "list"]) + assert list_result.exit_code == 0 + assert "demo" in list_result.stdout + assert "build_prompt->build_prompt" in list_result.stdout + assert "run_model->run_model" in list_result.stdout + + show_result = runner.invoke(app, ["extism", "show", "demo"]) + assert show_result.exit_code == 0 + assert '"path": "./demo.wasm"' in show_result.stdout + + remove_result = runner.invoke(app, ["extism", "remove", "demo"]) + assert remove_result.exit_code == 0 + assert "Removed Extism plugin 'demo'." in remove_result.stdout + assert settings.read_config().model_dump(mode="json") == {"plugins": {}} + + +def test_plugin_defined_commands_share_extism_group(tmp_path: Path) -> None: + app, _settings = _make_app(tmp_path) + manifest_path = tmp_path / "plugin.manifest.json" + _write_manifest(manifest_path, wasm_path="./cli.wasm") + tmp_path.joinpath("extism.json").write_text( + json.dumps( + { + "plugins": { + "cli": { + "manifest": {"wasm": [{"path": "./cli.wasm"}]}, + "hooks": {"register_cli_commands": "register_cli_commands"}, + } + } + } + ), + encoding="utf-8", + ) + FakePlugin.exports = { + "register_cli_commands": json.dumps( + { + "value": [ + { + "name": "hello", + "function": "cli_hello", + } + ] + } + ), + "cli_hello": lambda request: json.dumps( + { + "value": { + "plugin": request["args"]["plugin"], + "command": request["args"]["command"], + "payload": request["args"]["payload"], + } + } + ), + } + + app, _settings = _make_app(tmp_path) + result = runner.invoke(app, ["extism", "hello", '{"name":"Bub"}']) + + assert result.exit_code == 0 + assert '"plugin": "cli"' in result.stdout + assert '"command": "hello"' in result.stdout + assert '"name": "Bub"' in result.stdout + + +def test_plugin_defined_commands_reject_invalid_wrapper_shape(tmp_path: Path) -> None: + manifest_path = tmp_path / "plugin.manifest.json" + _write_manifest(manifest_path, wasm_path="./cli.wasm") + tmp_path.joinpath("extism.json").write_text( + json.dumps( + { + "plugins": { + "cli": { + "manifest": {"wasm": [{"path": "./cli.wasm"}]}, + "hooks": {"register_cli_commands": "register_cli_commands"}, + } + } + } + ), + encoding="utf-8", + ) + FakePlugin.exports = { + "register_cli_commands": json.dumps({"value": {"commands": "bad"}}), + } + + with pytest.raises(RuntimeError, match="must return a list"): + _make_app(tmp_path) + + +def test_plugin_defined_commands_reject_invalid_descriptor_shape(tmp_path: Path) -> None: + manifest_path = tmp_path / "plugin.manifest.json" + _write_manifest(manifest_path, wasm_path="./cli.wasm") + tmp_path.joinpath("extism.json").write_text( + json.dumps( + { + "plugins": { + "cli": { + "manifest": {"wasm": [{"path": "./cli.wasm"}]}, + "hooks": {"register_cli_commands": "register_cli_commands"}, + } + } + } + ), + encoding="utf-8", + ) + FakePlugin.exports = { + "register_cli_commands": json.dumps({"value": [{"name": "hello"}]}), + } + + with pytest.raises(RuntimeError, match="requires name and function"): + _make_app(tmp_path) diff --git a/packages/bub-extism/tests/test_examples.py b/packages/bub-extism/tests/test_examples.py index 3f7f6eb..ea0ce87 100644 --- a/packages/bub-extism/tests/test_examples.py +++ b/packages/bub-extism/tests/test_examples.py @@ -6,17 +6,32 @@ import shutil import subprocess from pathlib import Path +from types import SimpleNamespace from typing import Any +import pluggy import pytest -from bub_extism.bridge import ExtismBridge +from bub.hook_runtime import HookRuntime +from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs from bub_extism.config import ExtismSettings from bub_extism.plugin import ExtismPlugin PACKAGE_ROOT = Path(__file__).resolve().parents[1] -RUST_EXAMPLE = PACKAGE_ROOT / "examples" / "rust-model-stream" -GO_EXAMPLE = PACKAGE_ROOT / "examples" / "go-channel" +RUST_EXAMPLE = PACKAGE_ROOT / "examples" / "rust-run-model" +GO_EXAMPLE = PACKAGE_ROOT / "examples" / "go-build-prompt" + + +def _has_rust_wasm_target() -> bool: + if shutil.which("cargo") is None or shutil.which("rustup") is None: + return False + result = subprocess.run( + ["rustup", "target", "list", "--installed"], + check=False, + capture_output=True, + text=True, + ) + return "wasm32-unknown-unknown" in result.stdout.split() def _write_config(tmp_path: Path, body: dict[str, Any]) -> Path: @@ -25,64 +40,34 @@ def _write_config(tmp_path: Path, body: dict[str, Any]) -> Path: return config_path -def _plugin(config_path: Path) -> ExtismPlugin: - plugin = ExtismPlugin(type("Framework", (), {})()) - plugin.bridge = ExtismBridge(ExtismSettings(config_path=config_path)) - return plugin +def _runtime(config_path: Path) -> HookRuntime: + settings = ExtismSettings(config_path=config_path) + plugin_manager = pluggy.PluginManager(BUB_HOOK_NAMESPACE) + plugin_manager.add_hookspecs(BubHookSpecs) + framework = SimpleNamespace(_plugin_manager=plugin_manager) + plugin = ExtismPlugin(framework, settings=settings) + plugin_manager.register(plugin, name="extism") + return HookRuntime(plugin_manager) -@pytest.mark.skipif(shutil.which("cargo") is None, reason="cargo is not installed") -def test_rust_model_stream_example_builds_and_runs(tmp_path: Path) -> None: +def _build_rust_example() -> Path: subprocess.run( ["cargo", "build", "--release", "--target", "wasm32-unknown-unknown"], cwd=RUST_EXAMPLE, check=True, ) - wasm_path = ( + return ( RUST_EXAMPLE / "target" / "wasm32-unknown-unknown" / "release" - / "bub_extism_rust_model_stream.wasm" - ) - config_path = _write_config( - tmp_path, - { - "defaultPlugin": "rust", - "plugins": { - "rust": { - "wasmPath": str(wasm_path), - "hooks": {"run_model_stream": "run_model_stream"}, - } - }, - }, + / "bub_extism_rust_run_model.wasm" ) - async def run_example() -> list[tuple[str, dict[str, Any]]]: - stream = await _plugin(config_path).bridge.call_hook( - "run_model_stream", - { - "prompt": "hello from bub", - "session_id": "example", - "state": {}, - }, - ) - from bub_extism.stream import stream_events_from_value - - events = stream_events_from_value(stream) - assert events is not None - return [(event.kind, event.data) async for event in events] - - assert asyncio.run(run_example()) == [ - ("text", {"delta": "[rust-model-stream:example] hello from bub"}), - ("final", {"text": "[rust-model-stream:example] hello from bub"}), - ] - -@pytest.mark.skipif(shutil.which("go") is None, reason="go is not installed") -def test_go_channel_example_builds_and_runs(tmp_path: Path) -> None: +def _build_go_example(tmp_path: Path) -> Path: subprocess.run(["go", "mod", "tidy"], cwd=GO_EXAMPLE, check=True) - wasm_path = tmp_path / "go-channel.wasm" + wasm_path = tmp_path / "go-build-prompt.wasm" subprocess.run( [ "go", @@ -96,23 +81,103 @@ def test_go_channel_example_builds_and_runs(tmp_path: Path) -> None: check=True, env={**dict(os.environ), "GOOS": "wasip1", "GOARCH": "wasm"}, ) + return wasm_path + + +@pytest.mark.skipif(not _has_rust_wasm_target(), reason="cargo or wasm32-unknown-unknown target is not installed") +def test_rust_run_model_example_builds_and_runs(tmp_path: Path) -> None: + wasm_path = _build_rust_example() + config_path = _write_config( + tmp_path, + { + "plugins": { + "rust": { + "manifest": {"wasm": [{"path": str(wasm_path)}]}, + "hooks": {"run_model": "run_model"}, + } + } + }, + ) + + result = asyncio.run( + _runtime(config_path).run_model( + prompt="hello from bub", + session_id="example", + state={}, + ) + ) + + assert result == "[rust-run-model:example] hello from bub" + + +@pytest.mark.skipif(shutil.which("go") is None, reason="go is not installed") +def test_go_build_prompt_example_builds_and_runs(tmp_path: Path) -> None: + wasm_path = _build_go_example(tmp_path) config_path = _write_config( tmp_path, { - "defaultPlugin": "go", "plugins": { "go": { - "wasmPath": str(wasm_path), + "manifest": {"wasm": [{"path": str(wasm_path)}]}, "wasi": True, - "hooks": {"provide_channels": "provide_channels"}, + "hooks": {"build_prompt": "build_prompt"}, } - }, + } + }, + ) + + prompt = asyncio.run( + _runtime(config_path).call_first( + "build_prompt", + message={"content": "hello from bub"}, + session_id="example", + state={}, + ) + ) + + assert prompt == "[go-build-prompt:example] hello from bub" + + +@pytest.mark.skipif( + not _has_rust_wasm_target() or shutil.which("go") is None, + reason="cargo target or go is not installed", +) +def test_go_and_rust_examples_can_be_combined(tmp_path: Path) -> None: + rust_wasm_path = _build_rust_example() + go_wasm_path = _build_go_example(tmp_path) + config_path = _write_config( + tmp_path, + { + "plugins": { + "prompt": { + "manifest": {"wasm": [{"path": str(go_wasm_path)}]}, + "wasi": True, + "hooks": {"build_prompt": "build_prompt"}, + }, + "model": { + "manifest": {"wasm": [{"path": str(rust_wasm_path)}]}, + "hooks": {"run_model": "run_model"}, + }, + } }, ) - async def handler(message: dict[str, Any]) -> None: - del message + runtime = _runtime(config_path) + prompt = asyncio.run( + runtime.call_first( + "build_prompt", + message={"content": "hello from bub"}, + session_id="example", + state={}, + ) + ) + result = asyncio.run( + runtime.run_model( + prompt=prompt, + session_id="example", + state={}, + ) + ) - channels = _plugin(config_path).provide_channels(handler) - assert [channel.name for channel in channels] == ["go-echo"] - asyncio.run(channels[0].send({"content": "hello from bub"})) + assert prompt == "[go-build-prompt:example] hello from bub" + assert result == "[rust-run-model:example] [go-build-prompt:example] hello from bub"