diff --git a/.claude/.gitignore b/.claude/.gitignore index 0726a34d..588d6863 100644 --- a/.claude/.gitignore +++ b/.claude/.gitignore @@ -1,2 +1,3 @@ *.lock +/plans/ /worktrees/ diff --git a/.claude/rules/paddler-agent.md b/.claude/rules/paddler-agent.md new file mode 100644 index 00000000..e87cf016 --- /dev/null +++ b/.claude/rules/paddler-agent.md @@ -0,0 +1,11 @@ +--- +paths: + - "paddler_agent/**" +--- + +# Paddler Agent Context + +- `paddler_agent` is the only crate that can rely on `llama-cpp-bindings` +- agent is the only crate responsible for instantiating llama.cpp back-end, and communicating with it +- no crate can depend directly on `paddler_agent` (besides `paddler_bootstrap`, and other test related crates), they need to use `paddler_messaging` instead + diff --git a/.claude/rules/paddler-balancer.md b/.claude/rules/paddler-balancer.md new file mode 100644 index 00000000..1e73b023 --- /dev/null +++ b/.claude/rules/paddler-balancer.md @@ -0,0 +1,10 @@ +--- +paths: + - "paddler_balancer/**" +--- + +# Paddler Balancer Context + +- `paddler_balancer` crate is responsible for starting inference, and management servers +- Paddler Agents connect to the balancer in order to handle the requests that the balancer dispatches +- `paddler_balancer` provides compatibility services that expose vendor-compatible APIs (for example OpenAI compatibility server) diff --git a/.claude/rules/paddler-bootstrap.md b/.claude/rules/paddler-bootstrap.md new file mode 100644 index 00000000..9c2088ad --- /dev/null +++ b/.claude/rules/paddler-bootstrap.md @@ -0,0 +1,11 @@ +--- +paths: + - "paddler_bootstrap/**" +--- + +# Paddler Bootstrap Context + +- `paddler_bootstrap` is used both by `paddler_cli`, `paddler_tests`, and `paddler_gui` +- `paddler_bootstrap` combines both `paddler_agent`, and `paddler_balancer`, and provides a unified entry point +- `paddler_bootstrap` is the canonical way to start both core paddler services (balancer, and agent) +- `paddler_bootstrap` is the source of truth on how to start Paddler services diff --git a/.claude/rules/paddler-cache-dir.md b/.claude/rules/paddler-cache-dir.md new file mode 100644 index 00000000..780ff411 --- /dev/null +++ b/.claude/rules/paddler-cache-dir.md @@ -0,0 +1,11 @@ +--- +paths: + - "paddler_cache_dir/**" +--- + +# Paddler Cache Directory Context + +- `paddler_cache_dir` is a root Paddler crate, it must not depend on any other Paddler crate +- `paddler_cache_dir` manages Paddler's global cache directory, and all its nuances +- `paddler_cache_dir` resolves OS-related differences internally +- `paddler_cache_dir` must use cache directory patterns idiomatic to the specific operating system diff --git a/.claude/rules/paddler-cli.md b/.claude/rules/paddler-cli.md new file mode 100644 index 00000000..f44aae2c --- /dev/null +++ b/.claude/rules/paddler-cli.md @@ -0,0 +1,10 @@ +--- +paths: + - "paddler_cli/**" +--- + +# Paddler CLI Context + +- `paddler_cli` is intended to combine both `paddler_agent`, and `paddler_balancer` through `paddler_bootstrap` +- `paddler_cli` is intended to provide a command line Paddler binary, to be used in server infrastructure +- `paddler_cli` is intended to encapsulate both balancer, and agent in a single binary, for the sake of ease of deployment diff --git a/.claude/rules/paddler-client-javascript.md b/.claude/rules/paddler-client-javascript.md new file mode 100644 index 00000000..28b7316a --- /dev/null +++ b/.claude/rules/paddler-client-javascript.md @@ -0,0 +1,12 @@ +--- +paths: + - "paddler_client_javascript/**" +--- + +# Paddler Client Context + +- `paddler_client_javascript` provides JavaScript client that connects to `paddler_balancer` +- It must provide a way to connect to only, specifically balancer's inference address (without the need to connect to management service at the same time) +- It must provide a way to connect to only, specifically balancer's management address (without the need to connect to inference service at the same time) +- Paddler client must support all Paddler's native endpoints +- It must not implement OpenAI compatibility client diff --git a/.claude/rules/paddler-client-python.md b/.claude/rules/paddler-client-python.md new file mode 100644 index 00000000..076b143d --- /dev/null +++ b/.claude/rules/paddler-client-python.md @@ -0,0 +1,12 @@ +--- +paths: + - "paddler_client_python/**" +--- + +# Paddler Client Context + +- `paddler_client_python` provides JavaScript client that connects to `paddler_balancer` +- It must provide a way to connect to only, specifically balancer's inference address (without the need to connect to management service at the same time) +- It must provide a way to connect to only, specifically balancer's management address (without the need to connect to inference service at the same time) +- Paddler client must support all Paddler's native endpoints +- It must not implement OpenAI compatibility client diff --git a/.claude/rules/paddler-client.md b/.claude/rules/paddler-client.md new file mode 100644 index 00000000..1f93341b --- /dev/null +++ b/.claude/rules/paddler-client.md @@ -0,0 +1,12 @@ +--- +paths: + - "paddler_client/**" +--- + +# Paddler Client Context + +- `paddler_client` provides Rust client that connects to `paddler_balancer` +- It must provide a way to connect to only, specifically balancer's inference address (without the need to connect to management service at the same time) +- It must provide a way to connect to only, specifically balancer's management address (without the need to connect to inference service at the same time) +- Paddler client must support all Paddler's native endpoints +- It must not implement OpenAI compatibility client diff --git a/.claude/rules/paddler-download-manager.md b/.claude/rules/paddler-download-manager.md new file mode 100644 index 00000000..7eb1bd2c --- /dev/null +++ b/.claude/rules/paddler-download-manager.md @@ -0,0 +1,12 @@ +--- +paths: + - "paddler_download_manager/**" +--- + +# Paddler Download Manager Context + +- `paddler_download_manager` is a root Paddler crate, it must not depend on any other Paddler crate +- `paddler_download_manager` is responsible for downloading GGUF models from HTTP URLs +- `paddler_download_manager` must be resilient, it must support resumes, handle cache corruptions +- `paddler_download_manager` must not do retries, because it is intended to be used by `paddler_agent`, and `paddler_agent` already has a built-in retry mechanism +- `paddler_download_manager` must focus only on Paddler internal use-cases related to downloading models diff --git a/.claude/rules/paddler-messaging.md b/.claude/rules/paddler-messaging.md new file mode 100644 index 00000000..29195130 --- /dev/null +++ b/.claude/rules/paddler-messaging.md @@ -0,0 +1,10 @@ +--- +paths: + - "paddler_messaging/**" +--- + +# Paddler Messaging Context + +- `paddler_messaging` must contain only messaing protocol between `paddler_agent`, and `paddler_balancer` +- `paddler_messaging` must only be a thin messaging, and validation layer between `paddler_agent`, and `paddler_balancer` +- `paddler_messaging` is intended to be used in `paddler_client`, and must not pull heavy dependencies like `llama-cpp-bindings` diff --git a/.claude/rules/paddler-openai-response-format-validator.md b/.claude/rules/paddler-openai-response-format-validator.md new file mode 100644 index 00000000..443fed51 --- /dev/null +++ b/.claude/rules/paddler-openai-response-format-validator.md @@ -0,0 +1,13 @@ +--- +paths: + - "paddler_openai_response_format_validator/**" +--- + +# Paddler OpenAI Response Format Validator Context + +- `paddler_openai_response_format_validator` is intended to ONLY be used in test crates, unit tests, and such +- `paddler_openai_response_format_validator` is only intended to be used to validate Paddler's OpenAI compatibility endpoints +- `paddler_openai_response_format_validator` must NOT be used on runtime; it must ONLY be used in tests, unit tests, integration tests +- `paddler_openai_response_format_validator` must directly use vendored, official OpenAI schema to build its validation setup +- `paddler_openai_response_format_validator` must make the official OpenAI schema stricture, to make sure Paddler does not introduce extra fields to the requests +- `paddler_openai_response_format_validator` must make the official OpenAI schema stricture, to make sure Paddler does not accept unsupported fields diff --git a/.claude/rules/paddler-state-conversion.md b/.claude/rules/paddler-state-conversion.md new file mode 100644 index 00000000..c60583c5 --- /dev/null +++ b/.claude/rules/paddler-state-conversion.md @@ -0,0 +1,9 @@ +--- +paths: + - "paddler_state_conversion/**" +--- + +# Paddler State Conversion Context + +- `paddler_state_conversion` provides traits and utilities responsible for converting `paddler_agent`'s desired state to applicable state +- `paddler_state_conversion` provides traits and utilities responsible for converting `paddler_agent`'s applicable state to desired state diff --git a/.claude/rules/paddler-test-cluster-harness.md b/.claude/rules/paddler-test-cluster-harness.md new file mode 100644 index 00000000..4158f1c0 --- /dev/null +++ b/.claude/rules/paddler-test-cluster-harness.md @@ -0,0 +1,12 @@ +--- +paths: + - "paddler_test_cluster_harness/**" +--- + +# Paddler Test Cluster Harness Context + +- `paddler_test_cluster_harness` provides common test harness to be used with `paddler_cli_tests`, and `paddler_tests` + +# OpenAI Compatibility Testing + +- To stay objective, we must not implement our own OpenAI client, instead we need to use a vetted 3rd party (preferably official OpenAI API client) diff --git a/.claude/rules/paddler-tests.md b/.claude/rules/paddler-tests.md new file mode 100644 index 00000000..3bca9060 --- /dev/null +++ b/.claude/rules/paddler-tests.md @@ -0,0 +1,9 @@ +--- +paths: + - "paddler_tests/**" +--- + +# Paddler Tests Context + +- `paddler_tests` contains Paddler integration tests +- `paddler_tests` contain utilities related to the integration tests diff --git a/.claude/rules/python-on-nixos.md b/.claude/rules/python-on-nixos.md deleted file mode 100644 index 04329068..00000000 --- a/.claude/rules/python-on-nixos.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -paths: - - "paddler_client_python/**/*" ---- - -# Running Python tooling on NixOS - -To run any Python tool that may have ELF / dynamic-linker issues on NixOS — `ruff`, `mypy`, `pyright`, `pytest`, anything installed from a pip wheel with native -bits — first enter `paddler_client_python/shell.nix`, then drive everything through `poetry` from inside that shell. - -**Why:** -pip wheels like `ruff` ship a generic-linux binary, which NixOS does not provide. -Running them directly fails with `Could not start dynamically linked executable: ... NixOS cannot run dynamically linked executables intended for generic linux environments`. -`shell.nix` provides the Nix-built loader / replacement tools that make those binaries (or their Nix equivalents) actually launch. -`poetry` is just the dispatcher you use *inside* that prepared shell — never the entry point on its own. - -**How to apply:** -- Never invoke `ruff`, `poetry run ...`, `python`, `pytest`, etc. from outside `nix-shell`. If a command starts with one of those, it must be inside `nix-shell --run "..."`. -- If `paddler_client_python/shell.nix` is missing, stop and ask. Adding a `shell.nix` is the fix; running tooling unwrapped is not. diff --git a/.claude/skills/running-all-tests/SKILL.md b/.claude/skills/running-all-tests/SKILL.md index 7af693e7..32bc6812 100644 --- a/.claude/skills/running-all-tests/SKILL.md +++ b/.claude/skills/running-all-tests/SKILL.md @@ -24,7 +24,7 @@ echo "Device: $DEVICE" `$DEVICE` selects the Rust integration suite variant in Step 2. The other four suites don't take a device feature. -## Step 2: run the five suites +## Step 2: run the suites Copy this checklist and tick each item as the suite completes: @@ -33,16 +33,14 @@ Copy this checklist and tick each item as the suite completes: - [ ] Python client - [ ] Rust unit - [ ] Rust integration -- [ ] paddler_gui ``` | # | Suite | Inner command | Working dir | |---|------------------|-----------------------------------------------------------------------------------------------------------------------------------|--------------------------| | 1 | JS client | `make test.client.js` | repo root | | 2 | Python client | NixOS: `poetry run pytest`, `ruff`, `poetry run mypy"`. Every other OS: `poetry run pytest`, `poetry run ruff`, `poetry run mypy` | `paddler_client_python/` | -| 3 | Rust unit | `make test.unit` | repo root | -| 4 | Rust integration | `make test.integration` (cpu) / `make test.integration.cuda` / `make test.integration.metal` — pick by `$DEVICE` | repo root | -| 5 | paddler_gui | `cargo test -p paddler_gui --features web_admin_panel` | `paddler_gui/` | +| 3 | Rust unit | `TEST_DEVICE=$DEVICE make test.unit` | repo root | +| 4 | Rust integration | `TEST_DEVICE=$DEVICE make test.integration` | repo root | Run them in this order. Cheap suites (1, 3, 4) surface bugs quickly; the heavy GPU-bound suites (2, 5) come last. diff --git a/.claude/skills/running-coverage/SKILL.md b/.claude/skills/running-coverage/SKILL.md new file mode 100644 index 00000000..b1cb8e50 --- /dev/null +++ b/.claude/skills/running-coverage/SKILL.md @@ -0,0 +1,40 @@ +--- +name: running-coverage +description: Runs every test suite in the paddler workspace on the fastest available device, and produces code coverage report. Use when the user asks to run the code coverage, or to check the coverage. +--- + +# Running the code coverage + +Run every test suite in the workspace, picking the fastest compiled device backend for the host, then report the workspace code coverage. + +## Step 1: detect the device + +Run this once at the start and echo the chosen device: + +```bash +if [[ "$OSTYPE" == "darwin"* ]]; then + DEVICE=metal +elif command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi >/dev/null 2>&1; then + DEVICE=cuda +else + DEVICE=cpu +fi +echo "Device: $DEVICE" +``` + +`$DEVICE` selects the Rust integration suite variant in Step 2. The other four suites don't take a device feature. + +## Step 2: run the code coverage + +Copy this checklist and tick each item as the suite completes: + +`TEST_DEVICE=$DEVICE make test.coverage` + +## Step 3: rules during the run + +- **Serialize GPU suites.** When `$DEVICE` is `cuda` or `metal`, run test suites sequentially to avoid device contention. +- **Per-test 30 s budget.** Flag any individual test that exceeds 30 s wall-clock. That is a real bug — production or test — not flakiness. + +## Step 4: report + +After the coverage suite finishes, sum up the results in an actionable report. diff --git a/.github/workflows/clippy.yml b/.github/workflows/clippy.yml index 23e8411a..be038dde 100644 --- a/.github/workflows/clippy.yml +++ b/.github/workflows/clippy.yml @@ -24,6 +24,8 @@ jobs: steps: - name: checkout code uses: actions/checkout@v4 + with: + submodules: true - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0776fb4b..dc0eabca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,6 +24,8 @@ jobs: steps: - name: checkout code uses: actions/checkout@v4 + with: + submodules: true - uses: Swatinem/rust-cache@v2 diff --git a/.gitmodules b/.gitmodules index e69de29b..794b4954 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "vendor/openai/openai-openapi"] + path = vendor/openai/openai-openapi + url = https://github.com/openai/openai-openapi.git diff --git a/Cargo.lock b/Cargo.lock index 724be68c..ed2c0a81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,31 +18,6 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "366ffbaa4442f4684d91e2cd7c5ea7c4ed8add41959a31447066e279e432b618" -[[package]] -name = "actix" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de7fa236829ba0841304542f7614c42b80fca007455315c45c785ccfa873a85b" -dependencies = [ - "actix-macros", - "actix-rt", - "actix_derive", - "bitflags 2.11.0", - "bytes 1.11.1", - "crossbeam-channel", - "futures-core", - "futures-sink", - "futures-task", - "futures-util", - "log", - "once_cell", - "parking_lot", - "pin-project-lite", - "smallvec", - "tokio", - "tokio-util", -] - [[package]] name = "actix-codec" version = "0.5.2" @@ -121,7 +96,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -239,7 +214,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -287,7 +262,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ac4d6e04e97fe707286509b4f338e99c5fb7249c770e1da074af5e27faa96b3" dependencies = [ "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -304,17 +279,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "actix_derive" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6ac1e58cded18cb28ddc17143c4dea5345b3ad575e14f32f66e4054a56eb271" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "adler2" version = "2.0.1" @@ -505,9 +469,15 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] +[[package]] +name = "arraydeque" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" + [[package]] name = "arrayref" version = "0.3.9" @@ -571,7 +541,7 @@ dependencies = [ "rustc-hash 2.1.2", "serde", "serde_derive", - "syn 2.0.117", + "syn", ] [[package]] @@ -653,6 +623,45 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-openai" +version = "0.40.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "278af84d3d19995d440ea7b401a4b121c780c8d37f5eb6e883d987c17b523aa4" +dependencies = [ + "async-openai-macros", + "base64", + "bytes 1.11.1", + "derive_builder", + "eventsource-stream", + "futures 0.3.32", + "getrandom 0.3.4", + "rand 0.9.4", + "reqwest 0.13.4", + "secrecy", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-util", + "tower", + "tracing", + "url", +] + +[[package]] +name = "async-openai-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "492a944774207eed3acf425214eadbd6ce84a2b89331164ff1c11bae92b26302" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-process" version = "2.5.0" @@ -679,7 +688,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -719,7 +728,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -736,16 +745,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", -] - -[[package]] -name = "atomic" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a89cbf775b137e9b968e67227ef7f775587cde3fd31b0d8599dbd0f598a48340" -dependencies = [ - "bytemuck", + "syn", ] [[package]] @@ -803,6 +803,28 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "aws-lc-rs" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ec2f1fc3ec205783a5da9a7e6c1509cc69dedf09a1949e412c1e18469326d00" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a2f9779ce85b93ab6170dd940ad0169b5766ff848247aff13bb788b832fe3f4" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "base64" version = "0.22.1" @@ -835,16 +857,7 @@ dependencies = [ "regex", "rustc-hash 2.1.2", "shlex", - "syn 2.0.117", -] - -[[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec 0.6.3", + "syn", ] [[package]] @@ -853,15 +866,9 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ - "bit-vec 0.8.0", + "bit-vec", ] -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - [[package]] name = "bit-vec" version = "0.8.0" @@ -1012,7 +1019,7 @@ checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1112,15 +1119,6 @@ dependencies = [ "wayland-client", ] -[[package]] -name = "castaway" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" -dependencies = [ - "rustversion", -] - [[package]] name = "cc" version = "1.2.58" @@ -1213,7 +1211,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1303,20 +1301,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "compact_str" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "rustversion", - "ryu", - "static_assertions", -] - [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1529,34 +1513,6 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" -[[package]] -name = "crossterm" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" -dependencies = [ - "bitflags 2.11.0", - "crossterm_winapi", - "derive_more", - "document-features", - "futures-core", - "mio", - "parking_lot", - "rustix 1.1.4", - "signal-hook", - "signal-hook-mio", - "winapi", -] - -[[package]] -name = "crossterm_winapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" -dependencies = [ - "winapi", -] - [[package]] name = "crunchy" version = "0.2.4" @@ -1595,16 +1551,6 @@ dependencies = [ "hybrid-array", ] -[[package]] -name = "csscolorparser" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb2a7d3066da2de787b7f032c736763eb7ae5d355f81a68bab2675a96008b0bf" -dependencies = [ - "lab", - "phf", -] - [[package]] name = "csv" version = "1.4.0" @@ -1640,9 +1586,9 @@ checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" [[package]] name = "darling" -version = "0.23.0" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1650,26 +1596,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.23.0" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ + "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.117", + "syn", ] [[package]] name = "darling_macro" -version = "0.23.0" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1698,12 +1645,6 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be1e0bca6c3637f992fc1cc7cbc52a78c1ef6db076dbf1059c4323d6a2048376" -[[package]] -name = "deltae" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5729f5117e208430e437df2f4843f5e5952997175992d1414f94c57d61e270b4" - [[package]] name = "deranged" version = "0.5.8" @@ -1714,6 +1655,37 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + [[package]] name = "derive_more" version = "2.1.1" @@ -1733,7 +1705,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.117", + "syn", "unicode-xid", ] @@ -1824,7 +1796,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1857,6 +1829,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b14ccef22fc6f5a8f4d7d768562a182c04ce9a3b3157b91390b52ddfdf1a76" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "either" version = "1.15.0" @@ -1912,7 +1890,7 @@ checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1955,7 +1933,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2037,6 +2015,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom 7.1.3", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.74.0" @@ -2052,23 +2041,13 @@ dependencies = [ "zune-inflate", ] -[[package]] -name = "fancy-regex" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" -dependencies = [ - "bit-set 0.5.3", - "regex", -] - [[package]] name = "fancy-regex" version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "998b056554fbe42e03ae0e152895cd1a7e1002aec800fdc6635d20270260c46f" dependencies = [ - "bit-set 0.8.0", + "bit-set", "regex-automata", "regex-syntax", ] @@ -2096,7 +2075,7 @@ checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2108,17 +2087,6 @@ dependencies = [ "simd-adler32", ] -[[package]] -name = "filedescriptor" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d" -dependencies = [ - "libc", - "thiserror 1.0.69", - "winapi", -] - [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -2134,18 +2102,6 @@ dependencies = [ "glob", ] -[[package]] -name = "finl_unicode" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9844ddc3a6e533d62bba727eb6c28b5d360921d5175e9ff0f1e621a5c590a4d5" - -[[package]] -name = "fixedbitset" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" - [[package]] name = "flate2" version = "1.1.9" @@ -2270,7 +2226,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2304,6 +2260,12 @@ dependencies = [ "num", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fslock" version = "0.2.1" @@ -2389,7 +2351,7 @@ checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2449,8 +2411,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -2695,6 +2659,15 @@ dependencies = [ "foldhash 0.2.0", ] +[[package]] +name = "hashlink" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" +dependencies = [ + "hashbrown 0.16.1", +] + [[package]] name = "headers" version = "0.4.1" @@ -2758,7 +2731,7 @@ dependencies = [ "native-tls", "num_cpus", "rand 0.9.4", - "reqwest", + "reqwest 0.12.28", "serde", "serde_json", "thiserror 2.0.18", @@ -3323,19 +3296,6 @@ dependencies = [ "rustversion", ] -[[package]] -name = "instability" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eb2d60ef19920a3a9193c3e371f726ec1dafc045dac788d0fb3704272458971" -dependencies = [ - "darling", - "indoc", - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "interpolate_name" version = "0.2.4" @@ -3344,7 +3304,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3442,7 +3402,7 @@ checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3488,7 +3448,7 @@ dependencies = [ "quote", "rustc_version", "simd_cesu8", - "syn 2.0.117", + "syn", ] [[package]] @@ -3516,7 +3476,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" dependencies = [ "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3551,7 +3511,7 @@ dependencies = [ "bytecount", "data-encoding", "email_address", - "fancy-regex 0.16.2", + "fancy-regex", "fraction", "getrandom 0.3.4", "idna", @@ -3577,17 +3537,6 @@ dependencies = [ "mutate_once", ] -[[package]] -name = "kasuari" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde5057d6143cc94e861d90f591b9303d6716c6b9602309150bd068853c10899" -dependencies = [ - "hashbrown 0.16.1", - "portable-atomic", - "thiserror 2.0.18", -] - [[package]] name = "khronos-egl" version = "6.0.0" @@ -3637,12 +3586,6 @@ dependencies = [ "smallvec", ] -[[package]] -name = "lab" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf36173d4167ed999940f804952e6b08197cae5ad5d572eb4db150ce8ad5d58f" - [[package]] name = "language-tags" version = "0.3.2" @@ -3720,15 +3663,6 @@ dependencies = [ "web-time", ] -[[package]] -name = "line-clipping" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f50e8f47623268b5407192d26876c4d7f89d686ca130fdc53bced4814cd29f8" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "linebender_resource_handle" version = "0.1.1" @@ -3752,7 +3686,7 @@ checksum = "e5cec0ec4228b4853bb129c84dbf093a27e6c7a20526da046defc334a1b017f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3902,19 +3836,12 @@ name = "lru" version = "0.16.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f66e8d5d03f609abc3a39e6f08e4164ebf1447a732906d39eb9b99b7919ef39" -dependencies = [ - "hashbrown 0.16.1", -] [[package]] -name = "mac_address" -version = "1.1.8" +name = "lru-slab" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" -dependencies = [ - "nix 0.29.0", - "winapi", -] +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "macro_registry" @@ -3923,7 +3850,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28c03fc749d06e1000766283015673e91aea121f30c60b7445681f2248e4994c" dependencies = [ "module_path_extractor", - "syn 2.0.117", + "syn", ] [[package]] @@ -3960,12 +3887,6 @@ dependencies = [ "libc", ] -[[package]] -name = "memmem" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" - [[package]] name = "memo-map" version = "0.3.3" @@ -4124,7 +4045,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "066cf25f0e8b11ee0df221219010f213ad429855f57c494f995590c861a9a7d8" dependencies = [ "arrayvec", - "bit-set 0.8.0", + "bit-set", "bitflags 2.11.0", "cfg-if", "cfg_aliases", @@ -4205,19 +4126,6 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" -[[package]] -name = "nix" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" -dependencies = [ - "bitflags 2.11.0", - "cfg-if", - "cfg_aliases", - "libc", - "memoffset", -] - [[package]] name = "nix" version = "0.30.1" @@ -4308,7 +4216,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -4381,16 +4289,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.117", -] - -[[package]] -name = "num_threads" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" -dependencies = [ - "libc", + "syn", ] [[package]] @@ -4804,7 +4703,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -4841,15 +4740,6 @@ dependencies = [ "libredox", ] -[[package]] -name = "ordered-float" -version = "4.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" -dependencies = [ - "num-traits", -] - [[package]] name = "ordered-float" version = "5.3.0" @@ -4885,27 +4775,16 @@ dependencies = [ ] [[package]] -name = "paddler" +name = "paddler_agent" version = "4.0.0" dependencies = [ - "actix", - "actix-cors", - "actix-rt", - "actix-web", - "actix-web-lab", - "actix-ws", "anyhow", - "askama", "async-stream", "async-trait", "base64", "bytes 1.11.1", - "cadence", - "clap", "dashmap", "encoding_rs", - "env_logger", - "esbuild-metafile", "futures 0.3.32", "futures-util", "hf-hub", @@ -4914,18 +4793,19 @@ dependencies = [ "jsonschema", "llama-cpp-bindings", "llama-cpp-bindings-sys", + "llama-cpp-bindings-types", "log", - "mime_guess", "minijinja", "minijinja-contrib", "nanoid", "paddler_cache_dir", "paddler_download_manager", - "paddler_types", + "paddler_messaging", + "paddler_state_conversion", + "parking_lot", "rand 0.9.4", - "reqwest", + "reqwest 0.12.28", "resvg 0.46.0", - "rust-embed", "serde", "serde_json", "shellexpand", @@ -4940,6 +4820,47 @@ dependencies = [ "url", ] +[[package]] +name = "paddler_balancer" +version = "4.0.0" +dependencies = [ + "actix-cors", + "actix-web", + "actix-web-lab", + "actix-ws", + "anyhow", + "askama", + "async-stream", + "async-trait", + "bytes 1.11.1", + "cadence", + "dashmap", + "esbuild-metafile", + "futures 0.3.32", + "futures-util", + "indoc", + "llama-cpp-bindings-types", + "log", + "mime_guess", + "nanoid", + "paddler_messaging", + "paddler_openai_response_format_validator", + "paddler_state_conversion", + "parking_lot", + "rust-embed", + "serde", + "serde_json", + "shellexpand", + "tempfile", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-test", + "tokio-util", + "trzcina", + "url", +] + [[package]] name = "paddler_bootstrap" version = "4.0.0" @@ -4949,8 +4870,11 @@ dependencies = [ "async-trait", "log", "nanoid", - "paddler", - "paddler_types", + "nix", + "paddler_agent", + "paddler_balancer", + "paddler_messaging", + "reqwest 0.12.28", "tempfile", "tokio", "tokio-util", @@ -4982,54 +4906,52 @@ dependencies = [ "esbuild-metafile", "log", "nanoid", - "paddler", + "paddler_balancer", "paddler_bootstrap", - "paddler_types", "tokio", "tokio-util", "trzcina", ] [[package]] -name = "paddler_client" +name = "paddler_cli_tests" version = "4.0.0" dependencies = [ "anyhow", - "dashmap", + "async-trait", + "base64", "futures-util", "log", - "nanoid", - "paddler_types", - "reqwest", - "serde", + "nix", + "paddler_cli", + "paddler_client", + "paddler_messaging", + "paddler_test_cluster_harness", + "reqwest 0.12.28", "serde_json", - "thiserror 2.0.18", + "serial_test", "tokio", - "tokio-stream", - "tokio-tungstenite", "url", ] [[package]] -name = "paddler_client_cli" +name = "paddler_client" version = "4.0.0" dependencies = [ "anyhow", - "async-trait", - "clap", - "crossterm", - "env_logger", + "dashmap", "futures-util", - "llama-cpp-bindings-types", + "http 1.4.0", "log", - "paddler_bootstrap", - "paddler_client", - "paddler_types", - "ratatui", - "reqwest", + "nanoid", + "paddler_messaging", + "reqwest 0.12.28", + "serde", "serde_json", + "thiserror 2.0.18", "tokio", - "tokio-util", + "tokio-stream", + "tokio-tungstenite", "url", ] @@ -5041,7 +4963,7 @@ dependencies = [ "bytes 1.11.1", "futures-util", "headers", - "reqwest", + "reqwest 0.12.28", "tempfile", "thiserror 2.0.18", "tokio", @@ -5062,51 +4984,104 @@ dependencies = [ "if-addrs", "log", "open", - "paddler", + "paddler_balancer", "paddler_bootstrap", - "paddler_types", + "paddler_messaging", + "parking_lot", "statum", "tokio", "tokio-util", + "trzcina", ] [[package]] -name = "paddler_tests" +name = "paddler_messaging" +version = "4.0.0" +dependencies = [ + "anyhow", + "base64", + "encoding_rs", + "llama-cpp-bindings-types", + "log", + "nanoid", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "url", +] + +[[package]] +name = "paddler_openai_response_format_validator" +version = "4.0.0" +dependencies = [ + "anyhow", + "jsonschema", + "serde_json", + "thiserror 2.0.18", + "yaml-rust2", +] + +[[package]] +name = "paddler_state_conversion" +version = "4.0.0" +dependencies = [ + "anyhow", + "async-trait", +] + +[[package]] +name = "paddler_test_cluster_harness" version = "4.0.0" dependencies = [ "anyhow", + "async-openai", "async-stream", + "async-trait", "base64", "futures-util", - "hf-hub", - "llama-cpp-bindings", - "log", - "nix 0.30.1", - "paddler", - "paddler_bootstrap", + "http 1.4.0", "paddler_client", - "paddler_types", - "reqwest", + "paddler_messaging", + "reqwest 0.12.28", "serde", "serde_json", - "serial_test", "tempfile", "tokio", - "tokio-tungstenite", - "tokio-util", "url", ] [[package]] -name = "paddler_types" +name = "paddler_tests" version = "4.0.0" dependencies = [ "anyhow", - "jsonschema", - "llama-cpp-bindings-types", + "async-stream", + "async-trait", + "base64", + "futures-util", + "hf-hub", + "llama-cpp-bindings", + "log", + "nix", + "paddler_agent", + "paddler_balancer", + "paddler_bootstrap", + "paddler_client", + "paddler_messaging", + "paddler_openai_response_format_validator", + "paddler_test_cluster_harness", + "parking_lot", + "reqwest 0.12.28", "serde", "serde_json", - "thiserror 2.0.18", + "serial_test", + "tempfile", + "tokio", + "tokio-tungstenite", + "tokio-util", + "trzcina", + "url", ] [[package]] @@ -5162,101 +5137,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "pest" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" -dependencies = [ - "memchr", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" -dependencies = [ - "pest", - "pest_generator", -] - -[[package]] -name = "pest_generator" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] -name = "pest_meta" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" -dependencies = [ - "pest", - "sha2", -] - -[[package]] -name = "phf" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" -dependencies = [ - "phf_macros", - "phf_shared", -] - -[[package]] -name = "phf_codegen" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" -dependencies = [ - "phf_generator", - "phf_shared", -] - -[[package]] -name = "phf_generator" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" -dependencies = [ - "phf_shared", - "rand 0.8.6", -] - -[[package]] -name = "phf_macros" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" -dependencies = [ - "phf_generator", - "phf_shared", - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] -name = "phf_shared" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" -dependencies = [ - "siphasher", -] - [[package]] name = "pico-args" version = "0.5.0" @@ -5280,7 +5160,7 @@ checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -5420,7 +5300,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.117", + "syn", ] [[package]] @@ -5457,7 +5337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" dependencies = [ "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -5490,6 +5370,62 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes 1.11.1", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.2", + "rustls", + "socket2 0.6.3", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "aws-lc-rs", + "bytes 1.11.1", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.4", + "ring", + "rustc-hash 2.1.2", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2 0.6.3", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.45" @@ -5588,101 +5524,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" [[package]] -name = "range-alloc" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08" - -[[package]] -name = "rangemap" -version = "1.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" - -[[package]] -name = "ratatui" -version = "0.30.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce67fb8ba4446454d1c8dbaeda0557ff5e94d39d5e5ed7f10a65eb4c8266bc" -dependencies = [ - "instability", - "ratatui-core", - "ratatui-crossterm", - "ratatui-macros", - "ratatui-termwiz", - "ratatui-widgets", -] - -[[package]] -name = "ratatui-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef8dea09a92caaf73bff7adb70b76162e5937524058a7e5bff37869cbbec293" -dependencies = [ - "bitflags 2.11.0", - "compact_str", - "hashbrown 0.16.1", - "indoc", - "itertools 0.14.0", - "kasuari", - "lru", - "strum", - "thiserror 2.0.18", - "unicode-segmentation", - "unicode-truncate", - "unicode-width", -] - -[[package]] -name = "ratatui-crossterm" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "577c9b9f652b4c121fb25c6a391dd06406d3b092ba68827e6d2f09550edc54b3" -dependencies = [ - "cfg-if", - "crossterm", - "instability", - "ratatui-core", -] - -[[package]] -name = "ratatui-macros" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7f1342a13e83e4bb9d0b793d0ea762be633f9582048c892ae9041ef39c936f4" -dependencies = [ - "ratatui-core", - "ratatui-widgets", -] - -[[package]] -name = "ratatui-termwiz" -version = "0.1.0" +name = "range-alloc" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f76fe0bd0ed4295f0321b1676732e2454024c15a35d01904ddb315afd3d545c" -dependencies = [ - "ratatui-core", - "termwiz", -] +checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08" [[package]] -name = "ratatui-widgets" -version = "0.3.0" +name = "rangemap" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7dbfa023cd4e604c2553483820c5fe8aa9d71a42eea5aa77c6e7f35756612db" -dependencies = [ - "bitflags 2.11.0", - "hashbrown 0.16.1", - "indoc", - "instability", - "itertools 0.14.0", - "line-clipping", - "ratatui-core", - "strum", - "time", - "unicode-segmentation", - "unicode-width", -] +checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" [[package]] name = "rav1e" @@ -5836,7 +5687,7 @@ checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -5948,7 +5799,49 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", + "wasm-streams 0.4.2", + "web-sys", +] + +[[package]] +name = "reqwest" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" +dependencies = [ + "base64", + "bytes 1.11.1", + "futures-core", + "futures-util", + "http 1.4.0", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "mime_guess", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "rustls-platform-verifier", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams 0.5.0", "web-sys", ] @@ -6045,7 +5938,7 @@ dependencies = [ "quote", "rust-embed-utils", "shellexpand", - "syn 2.0.117", + "syn", "walkdir", ] @@ -6112,6 +6005,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -6121,21 +6015,62 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dab5152771c58876a2146916e53e35057e1a4dfa2b9df0f0305b07f611fdea4d" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d1e2536ce4f35f4846aa13bff16bd0ff40157cdb14cc056c7b14ba41233ba0" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "jni 0.22.4", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -6229,6 +6164,16 @@ version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -6291,7 +6236,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6340,7 +6285,7 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6379,7 +6324,7 @@ checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6430,27 +6375,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" -dependencies = [ - "libc", - "signal-hook-registry", -] - -[[package]] -name = "signal-hook-mio" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" -dependencies = [ - "libc", - "mio", - "signal-hook", -] - [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -6739,7 +6663,7 @@ dependencies = [ "moddef", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6775,7 +6699,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6821,17 +6745,6 @@ dependencies = [ "zeno", ] -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.117" @@ -6860,7 +6773,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -6915,69 +6828,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "terminfo" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4ea810f0692f9f51b382fff5893887bb4580f5fa246fde546e0b13e7fcee662" -dependencies = [ - "fnv", - "nom 7.1.3", - "phf", - "phf_codegen", -] - -[[package]] -name = "termios" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "411c5bf740737c7918b8b1fe232dca4dc9f8e754b8ad5e20966814001ed0ac6b" -dependencies = [ - "libc", -] - -[[package]] -name = "termwiz" -version = "0.23.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4676b37242ccbd1aabf56edb093a4827dc49086c0ffd764a5705899e0f35f8f7" -dependencies = [ - "anyhow", - "base64", - "bitflags 2.11.0", - "fancy-regex 0.11.0", - "filedescriptor", - "finl_unicode", - "fixedbitset", - "hex", - "lazy_static", - "libc", - "log", - "memmem", - "nix 0.29.0", - "num-derive", - "num-traits", - "ordered-float 4.6.0", - "pest", - "pest_derive", - "phf", - "sha2", - "signal-hook", - "siphasher", - "terminfo", - "termios", - "thiserror 1.0.69", - "ucd-trie", - "unicode-segmentation", - "vtparse", - "wezterm-bidi", - "wezterm-blob-leases", - "wezterm-color-types", - "wezterm-dynamic", - "wezterm-input-types", - "winapi", -] - [[package]] name = "textwrap" version = "0.16.2" @@ -7013,7 +6863,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -7024,7 +6874,7 @@ checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -7049,9 +6899,7 @@ checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", - "libc", "num-conv", - "num_threads", "powerfmt", "serde_core", "time-core", @@ -7174,7 +7022,7 @@ checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -7299,8 +7147,10 @@ dependencies = [ "pin-project-lite", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -7353,7 +7203,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -7417,12 +7267,6 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - [[package]] name = "uds_windows" version = "1.2.1" @@ -7494,17 +7338,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" -[[package]] -name = "unicode-truncate" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b380a1238663e5f8a691f9039c73e1cdae598a30e9855f541d29b08b53e9a5" -dependencies = [ - "itertools 0.14.0", - "unicode-segmentation", - "unicode-width", -] - [[package]] name = "unicode-vo" version = "0.1.0" @@ -7646,8 +7479,6 @@ version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ - "atomic", - "getrandom 0.4.2", "js-sys", "serde_core", "wasm-bindgen", @@ -7692,15 +7523,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" -[[package]] -name = "vtparse" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d9b2acfb050df409c972a37d3b8e08cdea3bddb0c09db9d53137e504cfabed0" -dependencies = [ - "utf8parse", -] - [[package]] name = "walkdir" version = "2.5.0" @@ -7786,7 +7608,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wasm-bindgen-shared", ] @@ -7834,6 +7656,19 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasm-streams" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1ec4f6517c9e11ae630e200b2b65d193279042e28edd4a2cda233e46670bbb" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" @@ -8015,6 +7850,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31141ce3fc3e300ae89b78c0dd67f9708061d1d2eda54b8209346fd6be9a92c" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -8039,78 +7883,6 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" -[[package]] -name = "wezterm-bidi" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c0a6e355560527dd2d1cf7890652f4f09bb3433b6aadade4c9b5ed76de5f3ec" -dependencies = [ - "log", - "wezterm-dynamic", -] - -[[package]] -name = "wezterm-blob-leases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "692daff6d93d94e29e4114544ef6d5c942a7ed998b37abdc19b17136ea428eb7" -dependencies = [ - "getrandom 0.3.4", - "mac_address", - "sha2", - "thiserror 1.0.69", - "uuid", -] - -[[package]] -name = "wezterm-color-types" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de81ef35c9010270d63772bebef2f2d6d1f2d20a983d27505ac850b8c4b4296" -dependencies = [ - "csscolorparser", - "deltae", - "lazy_static", - "wezterm-dynamic", -] - -[[package]] -name = "wezterm-dynamic" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f2ab60e120fd6eaa68d9567f3226e876684639d22a4219b313ff69ec0ccd5ac" -dependencies = [ - "log", - "ordered-float 4.6.0", - "strsim", - "thiserror 1.0.69", - "wezterm-dynamic-derive", -] - -[[package]] -name = "wezterm-dynamic-derive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c0cf2d539c645b448eaffec9ec494b8b19bd5077d9e58cb1ae7efece8d575b" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "wezterm-input-types" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7012add459f951456ec9d6c7e6fc340b1ce15d6fc9629f8c42853412c029e57e" -dependencies = [ - "bitflags 1.3.2", - "euclid", - "lazy_static", - "serde", - "wezterm-dynamic", -] - [[package]] name = "wgpu" version = "27.0.1" @@ -8147,8 +7919,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27a75de515543b1897b26119f93731b385a19aea165a1ec5f0e3acecc229cae7" dependencies = [ "arrayvec", - "bit-set 0.8.0", - "bit-vec 0.8.0", + "bit-set", + "bit-vec", "bitflags 2.11.0", "bytemuck", "cfg_aliases", @@ -8208,7 +7980,7 @@ dependencies = [ "android_system_properties", "arrayvec", "ash", - "bit-set 0.8.0", + "bit-set", "bitflags 2.11.0", "block", "bytemuck", @@ -8231,7 +8003,7 @@ dependencies = [ "ndk-sys", "objc", "once_cell", - "ordered-float 5.3.0", + "ordered-float", "parking_lot", "portable-atomic", "portable-atomic-util", @@ -8383,7 +8155,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -8394,7 +8166,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -8405,7 +8177,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -8416,7 +8188,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -8823,7 +8595,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn 2.0.117", + "syn", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -8839,7 +8611,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -8962,6 +8734,17 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448" +[[package]] +name = "yaml-rust2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "631a50d867fafb7093e709d75aaee9e0e0d5deb934021fcea25ac2fe09edc51e" +dependencies = [ + "arraydeque", + "encoding_rs", + "hashlink", +] + [[package]] name = "yansi" version = "1.0.1" @@ -8993,7 +8776,7 @@ checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", "synstructure", ] @@ -9041,7 +8824,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "zbus_names", "zvariant", "zvariant_utils", @@ -9081,7 +8864,7 @@ checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -9101,7 +8884,7 @@ checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", "synstructure", ] @@ -9141,7 +8924,7 @@ checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -9240,7 +9023,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "zvariant_utils", ] @@ -9253,6 +9036,6 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.117", + "syn", "winnow 0.7.15", ] diff --git a/Cargo.toml b/Cargo.toml index 6f73ced9..668adba1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,20 @@ [workspace] -members = ["paddler", "paddler_bootstrap", "paddler_cache_dir", "paddler_cli", "paddler_client", "paddler_client_cli", "paddler_download_manager", "paddler_gui", "paddler_tests", "paddler_types"] +members = [ + "paddler_agent", + "paddler_balancer", + "paddler_bootstrap", + "paddler_cache_dir", + "paddler_cli", + "paddler_cli_tests", + "paddler_client", + "paddler_download_manager", + "paddler_gui", + "paddler_messaging", + "paddler_openai_response_format_validator", + "paddler_state_conversion", + "paddler_test_cluster_harness", + "paddler_tests", +] resolver = "2" [workspace.package] @@ -20,6 +35,7 @@ actix-web-lab = "0.26" actix-ws = "0.3" anyhow = { version = "1", features = ["backtrace"] } askama = "0.14" +async-openai = { version = "=0.40.3", features = ["byot", "chat-completion", "responses"] } async-stream = "0.3" async-trait = "0.1" bytes = "1.11" @@ -35,6 +51,7 @@ futures = "0.3" futures-util = { version = "0.3", features = ["tokio-io"] } headers = "=0.4.1" hf-hub = { version = "0.4", features = ["tokio"] } +http = "1" image = "0.25" indoc = "2" jsonschema = { version = "0.37", default-features = false } @@ -49,6 +66,7 @@ minijinja-contrib = { version = "2.12", features = ["datetime", "pycompat", "wor nanoid = "0.4" nix = { version = "0.30", features = ["signal"] } open = "5.3.4" +parking_lot = "=0.12.5" pastey = "0.2" rand = "0.9" ratatui = "=0.30.0" @@ -72,13 +90,19 @@ tokio-util = "0.7" thiserror = "2" trzcina = "=0.3.0" url = { version = "2.5", features = ["serde"] } -paddler = { version = "4.0.0", path = "paddler" } +yaml-rust2 = "0.11" +paddler_agent = { version = "4.0.0", path = "paddler_agent" } +paddler_balancer = { version = "4.0.0", path = "paddler_balancer" } paddler_bootstrap = { version = "4.0.0", path = "paddler_bootstrap" } paddler_cache_dir = { version = "4.0.0", path = "paddler_cache_dir" } +paddler_cli = { version = "4.0.0", path = "paddler_cli" } paddler_client = { version = "4.0.0", path = "paddler_client" } paddler_download_manager = { version = "4.0.0", path = "paddler_download_manager" } +paddler_messaging = { version = "4.0.0", path = "paddler_messaging" } +paddler_openai_response_format_validator = { version = "4.0.0", path = "paddler_openai_response_format_validator" } +paddler_state_conversion = { version = "4.0.0", path = "paddler_state_conversion" } +paddler_test_cluster_harness = { version = "4.0.0", path = "paddler_test_cluster_harness" } paddler_tests = { version = "4.0.0", path = "paddler_tests" } -paddler_types = { version = "4.0.0", path = "paddler_types" } [profile.release] lto = true diff --git a/Makefile b/Makefile index 29782f21..e0ed7ab7 100644 --- a/Makefile +++ b/Makefile @@ -2,10 +2,19 @@ RUST_LOG ?= debug -COVERAGE_PACKAGES := -p paddler_cache_dir -p paddler_download_manager -PADDLER_SOURCES := $(shell find paddler/src paddler_bootstrap/src paddler_cache_dir/src paddler_cli/src paddler_client/src paddler_download_manager/src paddler_gui/src paddler_types/src -name '*.rs') +PADDLER_SOURCES := $(shell find paddler_agent/src paddler_balancer/src paddler_bootstrap/src paddler_cache_dir/src paddler_cli/src paddler_client/src paddler_download_manager/src paddler_gui/src paddler_messaging/src paddler_state_conversion/src -name '*.rs') FRONTEND_SOURCES := $(shell find resources -type f) $(wildcard jarmuz/*.mjs) +TEST_DEVICE ?= cpu + +ifeq ($(TEST_DEVICE),cpu) +TEST_DEVICE_FEATURE_SUFFIX := +TEST_DEVICE_TARGET_DIR := +else +TEST_DEVICE_FEATURE_SUFFIX := ,$(TEST_DEVICE) +TEST_DEVICE_TARGET_DIR := --target-dir target/$(TEST_DEVICE) +endif + # ----------------------------------------------------------------------------- # Real targets # ----------------------------------------------------------------------------- @@ -67,29 +76,7 @@ clean: .PHONY: clippy clippy: esbuild-meta.json - cargo clippy --workspace --all-targets --features web_admin_panel,tests_that_use_llms,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster - -.PHONY: coverage -coverage: node_modules - cargo llvm-cov clean --workspace - cargo llvm-cov $(COVERAGE_PACKAGES) --no-report - cargo llvm-cov report --json --output-path target/llvm-cov.json - cargo llvm-cov report --lcov --output-path target/lcov.info - cargo llvm-cov report - npx rust-coverage-check target/llvm-cov.json \ - --workspace-root $(CURDIR) \ - --gated paddler_cache_dir=100 \ - --gated paddler_download_manager=99 - -.PHONY: coverage-clean -coverage-clean: - cargo llvm-cov clean --workspace - rm -rf target/llvm-cov-target - rm -f target/llvm-cov.json target/lcov.info - -.PHONY: coverage-report -coverage-report: - cargo llvm-cov $(COVERAGE_PACKAGES) --html + cargo clippy --workspace --all-targets --features web_admin_panel,tests_that_use_llms .PHONY: fmt fmt: node_modules @@ -102,21 +89,42 @@ test: test.client.js test.unit test.integration test.client.js: node_modules npm --workspace @intentee/paddler-client test -.PHONY: test.integration -test.integration: target/debug/paddler - cargo test -p paddler_tests --features tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms - -.PHONY: test.integration.cuda -test.integration.cuda: target/cuda/debug/paddler - PADDLER_BINARY_PATH=../target/cuda/debug/paddler PADDLER_TEST_DEVICE=cuda cargo test --target-dir target/cuda -p paddler_tests --features cuda,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms +.PHONY: test.coverage +test.coverage: esbuild-meta.json node_modules + cargo llvm-cov clean --profraw-only + cargo llvm-cov --features tests_that_use_llms,web_admin_panel$(TEST_DEVICE_FEATURE_SUFFIX) --no-report --workspace + cargo llvm-cov report --json --output-path target/llvm-cov.json + cargo llvm-cov report --lcov --output-path target/lcov.info + cargo llvm-cov report + npx rust-coverage-check target/llvm-cov.json \ + --workspace-root $(CURDIR) \ + --gated paddler_agent=96 \ + --gated paddler_balancer=84 \ + --gated paddler_bootstrap=100 \ + --gated paddler_cache_dir=100 \ + --gated paddler_cli=83 \ + --gated paddler_cli_tests=87 \ + --gated paddler_client=41 \ + --gated paddler_download_manager=99 \ + --gated paddler_gui=13 \ + --gated paddler_messaging=100 \ + --gated paddler_openai_response_format_validator=99 \ + --gated paddler_test_cluster_harness=67 \ + --gated paddler_tests=80 + +.PHONY: test.coverage-clean +test.coverage-clean: + cargo llvm-cov clean --workspace + rm -rf target/llvm-cov-target + rm -f target/llvm-cov.json target/lcov.info -.PHONY: test.integration.metal -test.integration.metal: target/metal/debug/paddler - PADDLER_BINARY_PATH=../target/metal/debug/paddler PADDLER_TEST_DEVICE=metal cargo test --target-dir target/metal -p paddler_tests --features metal,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms +.PHONY: test.integration +test.integration: + cargo test -p paddler_tests -p paddler_cli_tests --features tests_that_use_llms$(TEST_DEVICE_FEATURE_SUFFIX) $(TEST_DEVICE_TARGET_DIR) .PHONY: test.unit test.unit: esbuild-meta.json - cargo test --features web_admin_panel + cargo test --features web_admin_panel$(TEST_DEVICE_FEATURE_SUFFIX) $(TEST_DEVICE_TARGET_DIR) .PHONY: watch watch: node_modules diff --git a/clippy.toml b/clippy.toml index 59cb72d1..c64d65f8 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1,2 +1,3 @@ allow-expect-in-tests = true +allow-panic-in-tests = true allow-unwrap-in-tests = true diff --git a/jarmuz-fmt.mjs b/jarmuz-fmt.mjs index 510bddfe..c8202f47 100755 --- a/jarmuz-fmt.mjs +++ b/jarmuz-fmt.mjs @@ -7,12 +7,21 @@ jarmuz({ pipeline: ["cargo-fmt", "prettier"], watch: [ "jarmuz", - "paddler", + "paddler_agent", + "paddler_balancer", + "paddler_bootstrap", + "paddler_cache_dir", + "paddler_cli", + "paddler_cli_tests", "paddler_client", - "paddler_types", + "paddler_download_manager", + "paddler_gui", + "paddler_messaging", + "paddler_openai_response_format_validator", + "paddler_state_conversion", + "paddler_test_cluster_harness", + "paddler_tests", "resources", - "templates", - "*.mjs", ], }).decide(function ({ matches, schedule }) { switch (true) { diff --git a/jarmuz/run-website.mjs b/jarmuz/run-website.mjs index ee8453e5..c32909f1 100644 --- a/jarmuz/run-website.mjs +++ b/jarmuz/run-website.mjs @@ -7,10 +7,16 @@ export function run({ development, once = false, rustJobs }) { once, pipeline: ["stylelint", "tcm", "tsc", "eslint", esbuildJob, ...rustJobs], watch: [ - "paddler", + "paddler_agent", + "paddler_balancer", + "paddler_bootstrap", + "paddler_cache_dir", + "paddler_cli", "paddler_client", "paddler_client_javascript", - "paddler_types", + "paddler_download_manager", + "paddler_messaging", + "paddler_state_conversion", "resources", ], }).decide(function ({ matches, schedule }) { @@ -28,7 +34,7 @@ export function run({ development, once = false, rustJobs }) { schedule("tcm"); schedule(esbuildJob); return; - case matches("paddler/templates/**/*.html"): + case matches("paddler_balancer/templates/**/*.html"): case matches("**/*.rs"): for (const job of rustJobs) { schedule(job); diff --git a/paddler/src/agent/continuous_batch_active_request.rs b/paddler/src/agent/continuous_batch_active_request.rs deleted file mode 100644 index 13d239ef..00000000 --- a/paddler/src/agent/continuous_batch_active_request.rs +++ /dev/null @@ -1,60 +0,0 @@ -use llama_cpp_bindings::SampledToken; -use llama_cpp_bindings::SampledTokenClassifier; -use llama_cpp_bindings::sampling::LlamaSampler; -use llama_cpp_bindings::token::LlamaToken; -use log::warn; -use paddler_types::generated_token_result::GeneratedTokenResult; -use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TryRecvError; - -use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; -use crate::agent::slot_guard::SlotGuard; -use crate::tool_call_pipeline::ToolCallPipeline; - -pub struct ContinuousBatchActiveRequest { - pub chain: LlamaSampler, - pub token_classifier: SampledTokenClassifier<'static>, - pub current_token_position: i32, - pub grammar_sampler: Option, - pub generated_tokens_tx: mpsc::UnboundedSender, - pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, - pub i_batch: Option, - pub max_tokens: i32, - pub pending_sampled_token: Option, - pub phase: ContinuousBatchRequestPhase, - pub prompt_tokens: Vec, - pub prompt_tokens_ingested: usize, - pub sequence_id: i32, - pub slot_guard: SlotGuard, - pub tool_call_pipeline: Option, -} - -impl ContinuousBatchActiveRequest { - pub fn complete_with_outcome( - &mut self, - agent_name: &Option, - outcome: GeneratedTokenResult, - ) { - if self.generated_tokens_tx.send(outcome).is_err() { - warn!( - "{agent_name:?}: sequence {} failed to send result to client (receiver dropped)", - self.sequence_id - ); - } - - self.i_batch = None; - self.phase = ContinuousBatchRequestPhase::Completed; - } - - pub fn is_stop_requested(&mut self) -> bool { - match self.generate_tokens_stop_rx.try_recv() { - Ok(()) | Err(TryRecvError::Disconnected) => true, - Err(TryRecvError::Empty) => false, - } - } - - #[must_use] - pub fn remaining_prompt_tokens(&self) -> &[LlamaToken] { - &self.prompt_tokens[self.prompt_tokens_ingested..] - } -} diff --git a/paddler/src/agent/continuous_batch_arbiter_handle.rs b/paddler/src/agent/continuous_batch_arbiter_handle.rs deleted file mode 100644 index b89df927..00000000 --- a/paddler/src/agent/continuous_batch_arbiter_handle.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::sync::mpsc::SendError; -use std::sync::mpsc::Sender; -use std::thread; - -use anyhow::Result; -use anyhow::anyhow; - -use crate::agent::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; - -pub struct ContinuousBatchArbiterHandle { - pub command_tx: Sender, - pub scheduler_thread_handle: thread::JoinHandle>, -} - -impl ContinuousBatchArbiterHandle { - pub fn shutdown(self) -> Result<()> { - if let Err(SendError(_unsent_command)) = self - .command_tx - .send(ContinuousBatchSchedulerCommand::Shutdown) - { - // Scheduler thread already dropped its receiver; join below is authoritative. - } - - self.scheduler_thread_handle - .join() - .map_err(|err| anyhow!("Failed to join scheduler thread: {err:?}"))? - } -} diff --git a/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs b/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs deleted file mode 100644 index 174d4b7e..00000000 --- a/paddler/src/agent/continuous_batch_scheduler/advance_outcome.rs +++ /dev/null @@ -1,34 +0,0 @@ -use llama_cpp_bindings::SampledToken; -use paddler_types::generated_token_result::GeneratedTokenResult; - -pub enum AdvanceOutcome { - SampledAndStored(SampledToken), - Completed(GeneratedTokenResult), - ChannelDropped, -} - -#[cfg(test)] -mod tests { - use paddler_types::generated_token_result::GeneratedTokenResult; - use paddler_types::generation_summary::GenerationSummary; - - use super::AdvanceOutcome; - - #[test] - fn completed_carries_event_through_into_inner() { - let outcome = - AdvanceOutcome::Completed(GeneratedTokenResult::Done(GenerationSummary::default())); - - assert!(matches!( - outcome, - AdvanceOutcome::Completed(GeneratedTokenResult::Done(_)) - )); - } - - #[test] - fn channel_dropped_is_distinct_variant() { - let outcome = AdvanceOutcome::ChannelDropped; - - assert!(matches!(outcome, AdvanceOutcome::ChannelDropped)); - } -} diff --git a/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs b/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs deleted file mode 100644 index 9e58ca7b..00000000 --- a/paddler/src/agent/continuous_batch_scheduler/commit_phase.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; -use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass; - -#[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "chunk sizes fit in i32 for llama.cpp position arithmetic" -)] -pub fn run(pass: BatchPass, requests: &mut [ContinuousBatchActiveRequest]) { - for contribution in pass.contributions.generating { - let request = &mut requests[contribution.request_index]; - - request.pending_sampled_token = None; - request.i_batch = Some(contribution.batch_position); - request.current_token_position += 1; - } - - for contribution in pass.contributions.ingesting { - let request = &mut requests[contribution.request_index]; - - request.prompt_tokens_ingested += contribution.chunk_size; - request.current_token_position += contribution.chunk_size as i32; - - if contribution.is_last_chunk { - request.i_batch = Some(contribution.last_batch_position); - request.phase = ContinuousBatchRequestPhase::Generating; - } - } -} diff --git a/paddler/src/agent/continuous_batch_scheduler_command.rs b/paddler/src/agent/continuous_batch_scheduler_command.rs deleted file mode 100644 index a8517936..00000000 --- a/paddler/src/agent/continuous_batch_scheduler_command.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; - -pub enum ContinuousBatchSchedulerCommand { - ContinueFromConversationHistory(ContinueFromConversationHistoryRequest), - ContinueFromRawPrompt(ContinueFromRawPromptRequest), - GenerateEmbeddingBatch(GenerateEmbeddingBatchRequest), - Shutdown, -} diff --git a/paddler/src/agent/jsonrpc/mod.rs b/paddler/src/agent/jsonrpc/mod.rs deleted file mode 100644 index 627d6035..00000000 --- a/paddler/src/agent/jsonrpc/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod message; -mod notification; -pub mod notification_params; -mod request; -pub mod response; - -pub use self::message::Message; -pub use self::notification::Notification; -pub use self::request::Request; -pub use self::response::Response; diff --git a/paddler/src/agent/jsonrpc/notification_params/mod.rs b/paddler/src/agent/jsonrpc/notification_params/mod.rs deleted file mode 100644 index d0e18301..00000000 --- a/paddler/src/agent/jsonrpc/notification_params/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod set_state_params; -mod version_params; - -pub use self::set_state_params::SetStateParams; -pub use self::version_params::VersionParams; diff --git a/paddler/src/agent/jsonrpc/response.rs b/paddler/src/agent/jsonrpc/response.rs deleted file mode 100644 index c15ab81b..00000000 --- a/paddler/src/agent/jsonrpc/response.rs +++ /dev/null @@ -1,39 +0,0 @@ -use paddler_types::chat_template::ChatTemplate; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::model_metadata::ModelMetadata; -use serde::Deserialize; -use serde::Serialize; - -#[derive(Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -pub enum Response { - ChatTemplateOverride(Option), - Embedding(EmbeddingResult), - GeneratedToken(GeneratedTokenResult), - ModelMetadata(Option), -} - -impl From> for Response { - fn from(chat_template: Option) -> Self { - Self::ChatTemplateOverride(chat_template) - } -} - -impl From for Response { - fn from(embedding_result: EmbeddingResult) -> Self { - Self::Embedding(embedding_result) - } -} - -impl From for Response { - fn from(generated_token_result: GeneratedTokenResult) -> Self { - Self::GeneratedToken(generated_token_result) - } -} - -impl From> for Response { - fn from(model_metadata: Option) -> Self { - Self::ModelMetadata(model_metadata) - } -} diff --git a/paddler/src/agent/management_socket_client_service.rs b/paddler/src/agent/management_socket_client_service.rs deleted file mode 100644 index 771f5836..00000000 --- a/paddler/src/agent/management_socket_client_service.rs +++ /dev/null @@ -1,521 +0,0 @@ -use std::sync::Arc; - -use actix_web::web::Bytes; -use anyhow::Context; -use anyhow::Result; -use async_trait::async_trait; -use futures_util::SinkExt as _; -use futures_util::StreamExt; -use log::debug; -use log::error; -use log::info; -use log::warn; -use tokio::sync::mpsc; -use tokio::time::Duration; -use tokio::time::MissedTickBehavior; -use tokio::time::interval; -use tokio_tungstenite::connect_async; -use tokio_tungstenite::tungstenite::protocol::Message; -use tokio_util::sync::CancellationToken; -use trzcina::Service; - -use paddler_types::agent_desired_state::AgentDesiredState; -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::jsonrpc::ResponseEnvelope; - -use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use crate::agent::from_request_params::FromRequestParams; -use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use crate::agent::jsonrpc::Message as JsonRpcMessage; -use crate::agent::jsonrpc::Notification as JsonRpcNotification; -use crate::agent::jsonrpc::Request as JsonRpcRequest; -use crate::agent::jsonrpc::Response as JsonRpcResponse; -use crate::agent::jsonrpc::notification_params::VersionParams; -use crate::agent::model_metadata_holder::ModelMetadataHolder; -use crate::agent::receive_stream_stopper_collection::ReceiveStreamStopperCollection; -use crate::agent_applicable_state_holder::AgentApplicableStateHolder; -use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Message as ManagementJsonRpcMessage; -use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Notification as ManagementJsonRpcNotification; -use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::RegisterAgentParams; -use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::UpdateAgentStatusParams; -use crate::produces_snapshot::ProducesSnapshot; -use crate::slot_aggregated_status::SlotAggregatedStatus; -use crate::subscribes_to_updates::SubscribesToUpdates as _; - -struct IncomingMessageContext { - agent_applicable_state_holder: Arc, - agent_desired_state_tx: mpsc::UnboundedSender, - connection_close: CancellationToken, - continue_from_conversation_history_request_tx: - mpsc::UnboundedSender, - continue_from_raw_prompt_request_tx: mpsc::UnboundedSender, - generate_embedding_batch_request_tx: mpsc::UnboundedSender, - model_metadata_holder: Arc, - receive_stream_stopper_collection: Arc, - message_tx: mpsc::UnboundedSender, - slot_aggregated_status: Arc, -} - -pub struct ManagementSocketClientService { - pub agent_applicable_state_holder: Arc, - pub agent_desired_state_tx: mpsc::UnboundedSender, - pub continue_from_conversation_history_request_tx: - mpsc::UnboundedSender, - pub continue_from_raw_prompt_request_tx: mpsc::UnboundedSender, - pub generate_embedding_batch_request_tx: mpsc::UnboundedSender, - pub model_metadata_holder: Arc, - pub name: Option, - pub receive_stream_stopper_collection: Arc, - pub slot_aggregated_status: Arc, - pub socket_url: String, -} - -impl ManagementSocketClientService { - async fn generate_responses( - connection_close: CancellationToken, - id: String, - message_tx: mpsc::UnboundedSender, - request_params: TRequest::RequestParams, - receive_stream_stopper_collection: Arc, - request_tx: mpsc::UnboundedSender, - slot_aggregated_status: Arc, - ) -> Result<()> { - let (response_tx, mut response_rx) = mpsc::unbounded_channel::(); - let (stop_tx, stop_rx) = mpsc::unbounded_channel::<()>(); - - let _guard = receive_stream_stopper_collection - .register_stopper_with_guard(id.clone(), stop_tx) - .context(format!("Failed to register stopper for request: {id}"))?; - - request_tx.send(TRequest::from_request_params( - request_params, - response_tx, - stop_rx, - slot_aggregated_status, - ))?; - - loop { - tokio::select! { - () = connection_close.cancelled() => break, - response = response_rx.recv() => { - match response { - Some(response) => { - message_tx.send( - ManagementJsonRpcMessage::Response( - ResponseEnvelope { - generated_by: None, - request_id: id.clone(), - response: response.into(), - } - ), - )?; - } - None => break, - } - } - } - } - - Ok(()) - } - - async fn handle_deserialized_message( - IncomingMessageContext { - agent_applicable_state_holder, - agent_desired_state_tx, - connection_close, - continue_from_conversation_history_request_tx, - continue_from_raw_prompt_request_tx, - generate_embedding_batch_request_tx, - message_tx, - model_metadata_holder, - receive_stream_stopper_collection, - slot_aggregated_status, - }: IncomingMessageContext, - deserialized_message: JsonRpcMessage, - ) -> Result<()> { - match deserialized_message { - JsonRpcMessage::Error(ErrorEnvelope { - request_id, - error: JsonRpcError { code, description }, - }) => { - error!( - "Received error from server: code: {code}, description: {description:?}, request_id: {request_id:?}" - ); - - Ok(()) - } - JsonRpcMessage::Notification(JsonRpcNotification::SetState(set_state_params)) => { - agent_desired_state_tx.send(set_state_params.desired_state)?; - - Ok(()) - } - JsonRpcMessage::Notification(JsonRpcNotification::StopRespondingTo(request_id)) => { - debug!("Received StopGeneratingTokens notification for request ID: {request_id:?}"); - receive_stream_stopper_collection - .stop(&request_id) - .context(format!( - "Failed to stop generating tokens for request ID: {request_id}" - ))?; - - Ok(()) - } - JsonRpcMessage::Notification(JsonRpcNotification::Version(VersionParams { - version, - })) => { - if version != env!("CARGO_PKG_VERSION") { - warn!( - "Version mismatch: server version is {version}, client version is {}", - env!("CARGO_PKG_VERSION") - ); - } - - Ok(()) - } - JsonRpcMessage::Request(RequestEnvelope { - id, - request: - JsonRpcRequest::ContinueFromConversationHistory( - continue_from_conversation_history_params, - ), - }) => { - Self::generate_responses( - connection_close, - id, - message_tx, - continue_from_conversation_history_params, - receive_stream_stopper_collection, - continue_from_conversation_history_request_tx, - slot_aggregated_status, - ) - .await - } - JsonRpcMessage::Request(RequestEnvelope { - id, - request: JsonRpcRequest::ContinueFromRawPrompt(generate_tokens_params), - }) => { - Self::generate_responses( - connection_close, - id, - message_tx, - generate_tokens_params, - receive_stream_stopper_collection, - continue_from_raw_prompt_request_tx, - slot_aggregated_status, - ) - .await - } - JsonRpcMessage::Request(RequestEnvelope { - id, - request: JsonRpcRequest::GenerateEmbeddingBatch(generate_embedding_batch_params), - }) => { - Self::generate_responses( - connection_close, - id, - message_tx, - generate_embedding_batch_params, - receive_stream_stopper_collection, - generate_embedding_batch_request_tx, - slot_aggregated_status, - ) - .await - } - JsonRpcMessage::Request(RequestEnvelope { - id, - request: JsonRpcRequest::GetChatTemplateOverride, - }) => Ok( - message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { - generated_by: None, - request_id: id, - response: JsonRpcResponse::ChatTemplateOverride( - if let Some(agent_applicable_state) = - agent_applicable_state_holder.get_agent_applicable_state() - { - agent_applicable_state.chat_template_override - } else { - None - }, - ), - }))?, - ), - JsonRpcMessage::Request(RequestEnvelope { - id, - request: JsonRpcRequest::GetModelMetadata, - }) => Ok( - message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { - generated_by: None, - request_id: id, - response: JsonRpcResponse::ModelMetadata( - model_metadata_holder.get_model_metadata(), - ), - }))?, - ), - } - } - - fn handle_incoming_message( - incoming_message_context: IncomingMessageContext, - msg: Message, - pong_tx: &mpsc::UnboundedSender, - ) -> Result<()> { - match msg { - Message::Text(text) => { - let connection_close = incoming_message_context.connection_close.clone(); - - tokio::spawn(async move { - tokio::select! { - () = connection_close.cancelled() => { - info!("Connection close signal received, shutting down"); - } - result = Self::handle_deserialized_message( - incoming_message_context, - match serde_json::from_str::(&text).context(format!("Failed to parse JSON-RPC message: {text}")) { - Ok(message) => message, - Err(err) => { - error!("Failed to deserialize message: {err}"); - - return; - } - }, - ) => if let Err(err) = result { - error!("Error handling incoming message: {err}"); - } - } - }); - - Ok(()) - } - Message::Binary(_) => { - error!("Received binary message, which is not expected"); - - Ok(()) - } - Message::Close(_) => { - info!("Connection closed by server"); - - Ok(()) - } - Message::Frame(_) => { - error!("Received a frame message, which is not expected"); - - Ok(()) - } - Message::Ping(payload) => Ok(pong_tx.send(payload)?), - Message::Pong(_) => { - // Pong received, no action needed - Ok(()) - } - } - } - - async fn keep_connection_alive(&self, shutdown: CancellationToken) -> Result<()> { - info!("Connecting to management server at {}", self.socket_url); - - let (ws_stream, _response) = connect_async(self.socket_url.clone()).await?; - - info!("Connected to management server"); - - let connection_close = CancellationToken::new(); - let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); - let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); - let (mut write, mut read) = ws_stream.split(); - - let forward_connection_close = connection_close.clone(); - let forward_shutdown = shutdown.clone(); - - let message_forward_handle = tokio::spawn(async move { - loop { - tokio::select! { - () = forward_connection_close.cancelled() => { - break; - } - () = forward_shutdown.cancelled() => { - info!("Shutdown signal received, deregistering agent"); - - write.send(Message::Text(match serde_json::to_string( - &ManagementJsonRpcMessage::Notification( - ManagementJsonRpcNotification::DeregisterAgent, - ) - ) { - Ok(serialized_message) => serialized_message.into(), - Err(err) => { - error!("Failed to serialize deregister agent notification: {err}"); - return; - } - })).await.unwrap_or_else(|err| { - error!("Failed to send deregister agent notification: {err}"); - }); - - break; - } - message = message_rx.recv() => { - match message { - Some(msg) => { - match serde_json::to_string(&msg) { - Ok(serialized_message) => { - let message = Message::Text(serialized_message.into()); - - if let Err(err) = write.send(message).await { - error!("Failed to send message: {err}"); - break; - } - }, - Err(err) => { - error!("Failed to serialize message: {err}"); - } - } - } - None => break, - } - } - payload = pong_rx.recv() => { - match payload { - Some(payload) => { - write.send(Message::Pong(payload)).await.unwrap_or_else(|err| { - error!("Failed to send pong message: {err}"); - }); - } - None => break, - } - } - } - } - }); - - match self.slot_aggregated_status.make_snapshot() { - Ok(slot_aggregated_status_snapshot) => { - message_tx - .send(ManagementJsonRpcMessage::Notification( - ManagementJsonRpcNotification::RegisterAgent(RegisterAgentParams { - name: self.name.clone(), - slot_aggregated_status_snapshot, - }), - )) - .unwrap_or_else(|err| { - error!("Failed to send register agent notification: {err}"); - }); - } - Err(err) => { - error!("Failed to create slot aggregated status snapshot: {err}"); - - return Err(err); - } - } - - let do_send_status_update = || match self.slot_aggregated_status.make_snapshot() { - Ok(slot_aggregated_status_snapshot) => { - message_tx - .send(ManagementJsonRpcMessage::Notification( - ManagementJsonRpcNotification::UpdateAgentStatus(UpdateAgentStatusParams { - slot_aggregated_status_snapshot, - }), - )) - .unwrap_or_else(|err| { - error!("Failed to send status update notification: {err}"); - }); - } - Err(err) => error!("Failed to create slot aggregated status snapshot: {err}"), - }; - - let mut ticker = interval(Duration::from_secs(1)); - let mut update_rx = self.slot_aggregated_status.subscribe_to_updates(); - - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - () = connection_close.cancelled() => { - info!("Connection close signal received, shutting down"); - - break; - } - () = shutdown.cancelled() => break, - changed = update_rx.changed() => { - if changed.is_err() { - break; - } - do_send_status_update(); - } - _ = ticker.tick() => do_send_status_update(), - msg = read.next() => { - let should_close = match msg { - Some(Ok(msg)) => { - if let Err(err) = Self::handle_incoming_message( - IncomingMessageContext { - agent_applicable_state_holder: self.agent_applicable_state_holder.clone(), - agent_desired_state_tx: self.agent_desired_state_tx.clone(), - connection_close: connection_close.clone(), - continue_from_conversation_history_request_tx: self.continue_from_conversation_history_request_tx.clone(), - continue_from_raw_prompt_request_tx: self.continue_from_raw_prompt_request_tx.clone(), - generate_embedding_batch_request_tx: self.generate_embedding_batch_request_tx.clone(), - model_metadata_holder: self.model_metadata_holder.clone(), - receive_stream_stopper_collection: self.receive_stream_stopper_collection.clone(), - message_tx: message_tx.clone(), - slot_aggregated_status: self.slot_aggregated_status.clone(), - }, - msg, - &pong_tx, - ) - .context("Failed to handle incoming message") - { - error!("Error handling incoming message: {err}"); - } - - false - } - Some(Err(err)) => { - error!("Error reading message: {err}"); - - true - } - None => true, - }; - - if should_close { - connection_close.cancel(); - - break; - } - } - } - } - - message_forward_handle - .await - .context("Failed to join message forwarding task")?; - - Ok(()) - } -} - -#[async_trait] -impl Service for ManagementSocketClientService { - fn name(&self) -> &'static str { - "agent::management_socket_client_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - let mut ticker = interval(Duration::from_secs(1)); - - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - () = shutdown.cancelled() => break Ok(()), - _ = ticker.tick() => { - match self.keep_connection_alive(shutdown.clone()).await { - Err(err) => { - error!("Failed to keep the connection alive: {err:?}"); - } - Ok(()) => { - info!("Gracefully closed connection to management server"); - } - } - } - } - } - } -} diff --git a/paddler/src/agent/model_metadata_holder.rs b/paddler/src/agent/model_metadata_holder.rs deleted file mode 100644 index e3bac8d2..00000000 --- a/paddler/src/agent/model_metadata_holder.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::sync::RwLock; - -use paddler_types::model_metadata::ModelMetadata; - -pub struct ModelMetadataHolder { - model_metadata: RwLock>, -} - -impl ModelMetadataHolder { - #[must_use] - pub fn new() -> Self { - Self::default() - } - - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] - pub fn set_model_metadata(&self, metadata: ModelMetadata) { - let mut lock = self - .model_metadata - .write() - .expect("Failed to acquire write lock on model metadata"); - - *lock = Some(metadata); - } - - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] - pub fn get_model_metadata(&self) -> Option { - let lock = self - .model_metadata - .read() - .expect("Failed to acquire read lock on model metadata"); - - lock.clone() - } -} - -impl Default for ModelMetadataHolder { - fn default() -> Self { - Self { - model_metadata: RwLock::new(None), - } - } -} diff --git a/paddler/src/agent/receive_stream_stopper_drop_guard.rs b/paddler/src/agent/receive_stream_stopper_drop_guard.rs deleted file mode 100644 index 99e25752..00000000 --- a/paddler/src/agent/receive_stream_stopper_drop_guard.rs +++ /dev/null @@ -1,24 +0,0 @@ -use std::sync::Arc; - -use log::error; - -use crate::agent::receive_stream_stopper_collection::ReceiveStreamStopperCollection; - -pub struct ReceiveStreamStopperDropGuard { - pub receive_stream_stopper_collection: Arc, - pub request_id: String, -} - -impl Drop for ReceiveStreamStopperDropGuard { - fn drop(&mut self) { - if let Err(err) = self - .receive_stream_stopper_collection - .deregister_stopper(&self.request_id) - { - error!( - "Failed to deregister stopper for request_id {}: {}", - self.request_id, err - ); - } - } -} diff --git a/paddler/src/agent/resolve_grammar.rs b/paddler/src/agent/resolve_grammar.rs deleted file mode 100644 index aa057baf..00000000 --- a/paddler/src/agent/resolve_grammar.rs +++ /dev/null @@ -1,42 +0,0 @@ -use anyhow::Result; -use anyhow::anyhow; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::grammar_constraint::GrammarConstraint; -use tokio::sync::mpsc; - -use crate::agent::grammar_sampler::GrammarSampler; - -pub fn resolve_grammar( - grammar: Option<&GrammarConstraint>, - enable_thinking: bool, - generated_tokens_tx: &mpsc::UnboundedSender, -) -> Result> { - let Some(grammar_constraint) = grammar else { - return Ok(None); - }; - - if enable_thinking { - let message = "Grammar constraints are incompatible with thinking mode".to_owned(); - - generated_tokens_tx - .send(GeneratedTokenResult::GrammarIncompatibleWithThinking( - message.clone(), - )) - .map_err(|err| anyhow!("Failed to send grammar incompatibility error: {err}"))?; - - return Err(anyhow!(message)); - } - - match GrammarSampler::new(grammar_constraint) { - Ok(sampler) => Ok(Some(sampler)), - Err(err) => { - let message = format!("Failed to create grammar sampler: {err}"); - - generated_tokens_tx - .send(GeneratedTokenResult::GrammarSyntaxError(message.clone())) - .map_err(|send_err| anyhow!("Failed to send grammar syntax error: {send_err}"))?; - - Err(anyhow!(message)) - } - } -} diff --git a/paddler/src/balancer/compatibility/openai_service/app_data.rs b/paddler/src/balancer/compatibility/openai_service/app_data.rs deleted file mode 100644 index 9f989cb1..00000000 --- a/paddler/src/balancer/compatibility/openai_service/app_data.rs +++ /dev/null @@ -1,9 +0,0 @@ -use std::sync::Arc; - -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::inference_service::configuration::Configuration; - -pub struct AppData { - pub buffered_request_manager: Arc, - pub inference_service_configuration: Configuration, -} diff --git a/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs b/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs deleted file mode 100644 index 1abdca71..00000000 --- a/paddler/src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs +++ /dev/null @@ -1,1613 +0,0 @@ -use std::sync::Arc; -use std::sync::Mutex; -use std::time::SystemTime; -use std::time::UNIX_EPOCH; - -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::post; -use actix_web::web; -use anyhow::Context as _; -use anyhow::Result; -use anyhow::anyhow; -use async_trait::async_trait; -use nanoid::nanoid; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::generation_summary::GenerationSummary; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::inference_client::Response as OutgoingResponse; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::ResponseEnvelope; -use llama_cpp_bindings::ParsedToolCall; -use llama_cpp_bindings::TokenUsage; -use llama_cpp_bindings::ToolCallArguments; -use paddler_types::oversized_image_details::OversizedImageDetails; -use paddler_types::raw_tool_call_tokens::RawToolCallTokens; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; -use paddler_types::validates::Validates; -use serde::Deserialize; -use serde_json::json; -use tokio_stream::StreamExt as _; - -use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; -use crate::balancer::compatibility::openai_service::app_data::AppData; -use crate::balancer::http_stream_from_agent::http_stream_from_agent; -use crate::balancer::unbounded_stream_from_agent::unbounded_stream_from_agent; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -fn openai_error_json(error_type: &str, message: &str) -> serde_json::Value { - json!({ - "error": { - "message": message, - "type": error_type, - "param": null, - "code": null - } - }) -} - -fn openai_usage_json(usage: &TokenUsage) -> serde_json::Value { - json!({ - "prompt_tokens": usage.prompt_tokens, - "completion_tokens": usage.completion_tokens(), - "total_tokens": usage.total_tokens(), - "prompt_tokens_details": { - "cached_tokens": usage.cached_prompt_tokens, - "audio_tokens": usage.input_audio_tokens, - "image_tokens": usage.input_image_tokens, - }, - "completion_tokens_details": { - "reasoning_tokens": usage.reasoning_tokens, - } - }) -} - -#[expect( - clippy::expect_used, - reason = "system time before UNIX_EPOCH means we are moving back in time" -)] -fn current_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time went backwards") - .as_secs() -} - -fn validation_failure_message(errors: &[String]) -> String { - errors - .first() - .cloned() - .unwrap_or_else(|| "tool call failed validation".to_owned()) -} - -fn unrecognized_tool_call_format_message(raw: &RawToolCallTokens) -> String { - format!( - "model produced output the parser did not recognise as any registered tool-call format; \ - FFI error: {}; raw text: {}", - raw.ffi_error_message, raw.text, - ) -} - -fn image_exceeds_batch_size_message(details: &OversizedImageDetails) -> String { - format!( - "image required {} tokens but agent n_batch is {}; rerun with a larger n_batch", - details.image_tokens, details.n_batch, - ) -} - -fn arguments_to_openai_string(arguments: &ToolCallArguments) -> Result { - match arguments { - ToolCallArguments::ValidJson(value) => { - serde_json::to_string(value).context("serializing tool-call arguments to OpenAI string") - } - ToolCallArguments::InvalidJson(raw) => Ok(raw.clone()), - } -} - -fn server_error_chunk(description: &str) -> TransformResult { - TransformResult::Error(openai_error_json("server_error", description).to_string()) -} - -fn timeout_response_chunk() -> TransformResult { - TransformResult::Error(openai_error_json("timeout", "request timed out").to_string()) -} - -fn rate_limit_response_chunk() -> TransformResult { - TransformResult::Error( - openai_error_json("rate_limit_error", "too many buffered requests").to_string(), - ) -} - -fn unexpected_embedding_response_chunk() -> TransformResult { - TransformResult::Error( - openai_error_json( - "invalid_request_error", - "unexpected embedding response in chat completions", - ) - .to_string(), - ) -} - -fn description_from_error_token(token: &GeneratedTokenResult) -> Option<&str> { - match token { - GeneratedTokenResult::ChatTemplateError(description) - | GeneratedTokenResult::GrammarIncompatibleWithThinking(description) - | GeneratedTokenResult::GrammarRejectedModelOutput(description) - | GeneratedTokenResult::GrammarInitializationFailed(description) - | GeneratedTokenResult::GrammarSyntaxError(description) - | GeneratedTokenResult::ImageDecodingFailed(description) - | GeneratedTokenResult::MultimodalNotSupported(description) - | GeneratedTokenResult::SamplerError(description) - | GeneratedTokenResult::ToolCallParseFailed(description) - | GeneratedTokenResult::ToolSchemaInvalid(description) => Some(description), - _ => None, - } -} - -fn try_universal_error_chunk(message: &OutgoingMessage) -> Option { - match message { - OutgoingMessage::Error(ErrorEnvelope { - error: paddler_types::jsonrpc::Error { description, .. }, - .. - }) => Some(server_error_chunk(description)), - OutgoingMessage::Response(ResponseEnvelope { response, .. }) => match response { - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ImageExceedsBatchSize( - details, - )) => Some(server_error_chunk(&image_exceeds_batch_size_message( - details, - ))), - OutgoingResponse::GeneratedToken(token) => { - description_from_error_token(token).map(server_error_chunk) - } - OutgoingResponse::Timeout => Some(timeout_response_chunk()), - OutgoingResponse::TooManyBufferedRequests => Some(rate_limit_response_chunk()), - OutgoingResponse::Embedding(_) => Some(unexpected_embedding_response_chunk()), - }, - } -} - -#[derive(Deserialize)] -struct OpenAIMessage { - content: ConversationMessageContent, - role: String, -} - -impl From<&OpenAIMessage> for ConversationMessage { - fn from(openai_message: &OpenAIMessage) -> Self { - Self { - content: openai_message.content.clone(), - role: openai_message.role.clone(), - } - } -} - -#[derive(Default, Deserialize)] -#[serde(deny_unknown_fields)] -struct StreamOptions { - #[serde(default)] - include_usage: bool, -} - -#[derive(Deserialize)] -struct OpenAICompletionRequestParams { - max_completion_tokens: Option, - messages: Vec, - /// This parameter is ignored here, but is required by the `OpenAI` API. - model: String, - stream: Option, - stream_options: Option, - #[serde(default)] - tools: Vec>, -} - -#[derive(Default)] -struct OpenAIStreamingState { - saw_tool_call: bool, -} - -#[derive(Clone)] -struct OpenAIStreamingResponseTransformer { - include_usage: bool, - model: String, - state: Arc>, - system_fingerprint: String, -} - -impl OpenAIStreamingResponseTransformer { - fn content_chunk(&self, request_id: &str, text: &str) -> Result { - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "content": text, - }, - "logprobs": null, - "finish_reason": null - } - ] - }))?) - } - - fn reasoning_chunk(&self, request_id: &str, text: &str) -> Result { - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "reasoning_content": text, - }, - "logprobs": null, - "finish_reason": null - } - ] - }))?) - } - - fn tool_calls_chunk( - &self, - request_id: &str, - parsed_calls: &[ParsedToolCall], - ) -> Result { - let tool_calls = parsed_calls - .iter() - .enumerate() - .map(|(index, call)| -> Result { - let arguments = arguments_to_openai_string(&call.arguments)?; - Ok(json!({ - "index": index, - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": arguments, - } - })) - }) - .collect::>>()?; - - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "tool_calls": tool_calls, - }, - "logprobs": null, - "finish_reason": null - } - ] - }))?) - } - - fn finish_chunk(&self, request_id: &str, finish_reason: &str) -> Result { - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [ - { - "index": 0, - "delta": {}, - "logprobs": null, - "finish_reason": finish_reason - } - ] - }))?) - } - - fn usage_chunk(&self, request_id: &str, usage: &TokenUsage) -> Result { - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion.chunk", - "created": current_timestamp(), - "model": self.model, - "system_fingerprint": self.system_fingerprint, - "choices": [], - "usage": openai_usage_json(usage), - }))?) - } - - fn handle_content(&self, request_id: &str, text: &str) -> Result> { - Ok(vec![TransformResult::Chunk( - self.content_chunk(request_id, text)?, - )]) - } - - fn handle_reasoning(&self, request_id: &str, text: &str) -> Result> { - Ok(vec![TransformResult::Chunk( - self.reasoning_chunk(request_id, text)?, - )]) - } - - fn handle_tool_call_parsed( - &self, - request_id: &str, - parsed_calls: &[ParsedToolCall], - ) -> Result> { - if parsed_calls.is_empty() { - return Ok(vec![]); - } - - self.state - .lock() - .map_err(|err| anyhow!("streaming state mutex poisoned: {err}"))? - .saw_tool_call = true; - - Ok(vec![TransformResult::Chunk( - self.tool_calls_chunk(request_id, parsed_calls)?, - )]) - } - - fn handle_done( - &self, - request_id: &str, - summary: &GenerationSummary, - ) -> Result> { - let saw_tool_call = self - .state - .lock() - .map_err(|err| anyhow!("streaming state mutex poisoned: {err}"))? - .saw_tool_call; - - let finish_reason = if saw_tool_call { "tool_calls" } else { "stop" }; - let finish = TransformResult::Chunk(self.finish_chunk(request_id, finish_reason)?); - - if self.include_usage { - let usage = TransformResult::Chunk(self.usage_chunk(request_id, &summary.usage)?); - Ok(vec![finish, usage]) - } else { - Ok(vec![finish]) - } - } -} - -#[async_trait] -impl TransformsOutgoingMessage for OpenAIStreamingResponseTransformer { - async fn transform(&self, message: OutgoingMessage) -> Result> { - if let Some(error_chunk) = try_universal_error_chunk(&message) { - return Ok(vec![error_chunk]); - } - - match message { - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: - OutgoingResponse::GeneratedToken( - GeneratedTokenResult::ContentToken(text) - | GeneratedTokenResult::UndeterminableToken(text), - ), - .. - }) => self.handle_content(&request_id, &text), - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ReasoningToken(text)), - .. - }) => self.handle_reasoning(&request_id, &text), - OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallToken(_)), - .. - }) => Ok(vec![]), - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), - .. - }) => self.handle_tool_call_parsed(&request_id, &parsed_calls), - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallValidationFailed( - errors, - )), - .. - }) => Ok(vec![server_error_chunk(&validation_failure_message( - &errors, - ))]), - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::UnrecognizedToolCallFormat( - raw, - )), - .. - }) => Ok(vec![server_error_chunk( - &unrecognized_tool_call_format_message(&raw), - )]), - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), - .. - }) => self.handle_done(&request_id, &summary), - other => Err(anyhow!( - "OpenAIStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" - )), - } - } -} - -#[derive(Clone, Default)] -struct OpenAINonStreamingState { - content: String, - reasoning: String, - tool_calls: Vec, -} - -#[derive(Clone)] -struct OpenAINonStreamingResponseTransformer { - model: String, - state: Arc>, -} - -impl OpenAINonStreamingResponseTransformer { - fn append_content(&self, text: &str) -> Result<()> { - self.state - .lock() - .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? - .content - .push_str(text); - Ok(()) - } - - fn append_reasoning(&self, text: &str) -> Result<()> { - self.state - .lock() - .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? - .reasoning - .push_str(text); - Ok(()) - } - - fn append_tool_calls(&self, parsed_calls: Vec) -> Result<()> { - self.state - .lock() - .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))? - .tool_calls - .extend(parsed_calls); - Ok(()) - } - - fn build_done_chunk(&self, request_id: &str, summary: &GenerationSummary) -> Result { - let snapshot = self.snapshot_state()?; - - let has_tool_calls = !snapshot.tool_calls.is_empty(); - let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" }; - - let mut message_obj = json!({ - "role": "assistant", - "content": if snapshot.content.is_empty() && has_tool_calls { - serde_json::Value::Null - } else { - json!(snapshot.content) - }, - "refusal": null, - "annotations": [] - }); - - if !snapshot.reasoning.is_empty() - && let Some(map) = message_obj.as_object_mut() - { - map.insert("reasoning_content".to_owned(), json!(snapshot.reasoning)); - } - - if has_tool_calls && let Some(map) = message_obj.as_object_mut() { - let tool_calls_json = snapshot - .tool_calls - .iter() - .map(|call| -> Result { - let arguments = arguments_to_openai_string(&call.arguments)?; - Ok(json!({ - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": arguments, - } - })) - }) - .collect::>>()?; - map.insert("tool_calls".to_owned(), json!(tool_calls_json)); - } - - Ok(serde_json::to_string(&json!({ - "id": request_id, - "object": "chat.completion", - "created": current_timestamp(), - "model": self.model, - "choices": [ - { - "index": 0, - "message": message_obj, - "logprobs": null, - "finish_reason": finish_reason - } - ], - "usage": openai_usage_json(&summary.usage), - "service_tier": "default" - }))?) - } - - fn snapshot_state(&self) -> Result { - let state = self - .state - .lock() - .map_err(|err| anyhow!("non-streaming state mutex poisoned: {err}"))?; - Ok(state.clone()) - } -} - -#[async_trait] -impl TransformsOutgoingMessage for OpenAINonStreamingResponseTransformer { - async fn transform(&self, message: OutgoingMessage) -> Result> { - if let Some(error_chunk) = try_universal_error_chunk(&message) { - return Ok(vec![error_chunk]); - } - - match message { - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken( - GeneratedTokenResult::ContentToken(text) - | GeneratedTokenResult::UndeterminableToken(text), - ), - .. - }) => { - self.append_content(&text)?; - Ok(vec![]) - } - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ReasoningToken(text)), - .. - }) => { - self.append_reasoning(&text)?; - Ok(vec![]) - } - OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallToken(_)), - .. - }) => Ok(vec![]), - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), - .. - }) => { - self.append_tool_calls(parsed_calls)?; - Ok(vec![]) - } - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallValidationFailed( - errors, - )), - .. - }) => Ok(vec![server_error_chunk(&validation_failure_message( - &errors, - ))]), - OutgoingMessage::Response(ResponseEnvelope { - response: - OutgoingResponse::GeneratedToken(GeneratedTokenResult::UnrecognizedToolCallFormat( - raw, - )), - .. - }) => Ok(vec![server_error_chunk( - &unrecognized_tool_call_format_message(&raw), - )]), - OutgoingMessage::Response(ResponseEnvelope { - request_id, - response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), - .. - }) => Ok(vec![TransformResult::Chunk( - self.build_done_chunk(&request_id, &summary)?, - )]), - other => Err(anyhow!( - "OpenAINonStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" - )), - } - } -} - -#[post("/v1/chat/completions")] -async fn respond( - app_data: web::Data, - openai_params: web::Json, -) -> Result { - let openai_params = openai_params.into_inner(); - - let validated_tools = match openai_params - .tools - .into_iter() - .map(Validates::validate) - .collect::, _>>() - { - Ok(tools) => tools, - Err(err) => { - return Ok(HttpResponse::BadRequest() - .content_type("application/json") - .body(openai_error_json("invalid_request_error", &err.to_string()).to_string())); - } - }; - - let parse_tool_calls = !validated_tools.is_empty(); - let paddler_params = ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new( - openai_params - .messages - .iter() - .map(ConversationMessage::from) - .collect(), - ), - enable_thinking: true, - grammar: None, - max_tokens: openai_params.max_completion_tokens.unwrap_or(2000), - parse_tool_calls, - tools: validated_tools, - }; - - if openai_params.stream.unwrap_or(false) { - let include_usage = openai_params - .stream_options - .as_ref() - .is_some_and(|options| options.include_usage); - - Ok(http_stream_from_agent( - app_data.buffered_request_manager.clone(), - app_data.inference_service_configuration.clone(), - paddler_params, - OpenAIStreamingResponseTransformer { - include_usage, - model: openai_params.model.clone(), - state: Arc::new(Mutex::new(OpenAIStreamingState::default())), - system_fingerprint: nanoid!(), - }, - )) - } else { - let results: Vec = unbounded_stream_from_agent( - app_data.buffered_request_manager.clone(), - app_data.inference_service_configuration.clone(), - paddler_params, - OpenAINonStreamingResponseTransformer { - model: openai_params.model.clone(), - state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), - }, - ) - .collect() - .await; - - if let Some(TransformResult::Error(error_json)) = results - .iter() - .find(|result| matches!(result, TransformResult::Error(_))) - { - return Ok(HttpResponse::InternalServerError() - .content_type("application/json") - .body(error_json.clone())); - } - - let body = results.into_iter().find_map(|result| match result { - TransformResult::Chunk(content) => Some(content), - TransformResult::Discard | TransformResult::Error(_) => None, - }); - - Ok(body.map_or_else( - || { - HttpResponse::InternalServerError() - .content_type("application/json") - .body(openai_error_json("server_error", "no completion produced").to_string()) - }, - |json_body| { - HttpResponse::Ok() - .content_type("application/json") - .body(json_body) - }, - )) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - use std::sync::Mutex; - - use anyhow::Result; - use llama_cpp_bindings::ParsedToolCall; - use llama_cpp_bindings::TokenUsage; - use llama_cpp_bindings::ToolCallArguments; - use paddler_types::generated_token_result::GeneratedTokenResult; - use paddler_types::generation_summary::GenerationSummary; - use paddler_types::inference_client::Message as OutgoingMessage; - use paddler_types::inference_client::Response as OutgoingResponse; - use paddler_types::jsonrpc::ErrorEnvelope; - use paddler_types::jsonrpc::ResponseEnvelope; - - use super::OpenAINonStreamingResponseTransformer; - use super::OpenAINonStreamingState; - use super::OpenAIStreamingResponseTransformer; - use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; - use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; - - fn make_token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { - OutgoingMessage::Response(ResponseEnvelope { - generated_by: None, - request_id: "test-request".to_owned(), - response: OutgoingResponse::GeneratedToken(token_result), - }) - } - - fn make_error_message(code: i32, description: &str) -> OutgoingMessage { - OutgoingMessage::Error(ErrorEnvelope { - request_id: "test-request".to_owned(), - error: paddler_types::jsonrpc::Error { - code, - description: description.to_owned(), - }, - }) - } - - fn make_response_message(response: OutgoingResponse) -> OutgoingMessage { - OutgoingMessage::Response(ResponseEnvelope { - generated_by: None, - request_id: "test-request".to_owned(), - response, - }) - } - - fn streaming_transformer(include_usage: bool) -> OpenAIStreamingResponseTransformer { - OpenAIStreamingResponseTransformer { - include_usage, - model: "test-model".to_owned(), - state: Arc::new(Mutex::new(super::OpenAIStreamingState::default())), - system_fingerprint: "test-fingerprint".to_owned(), - } - } - - fn non_streaming_transformer() -> OpenAINonStreamingResponseTransformer { - OpenAINonStreamingResponseTransformer { - model: "test-model".to_owned(), - state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), - } - } - - fn assert_chunk_contains(result: &TransformResult, expected: &str) -> Result<()> { - let TransformResult::Chunk(content) = result else { - anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); - }; - - assert!( - content.contains(expected), - "chunk does not contain '{expected}': {content}" - ); - - Ok(()) - } - - fn assert_chunk_does_not_contain(result: &TransformResult, expected: &str) -> Result<()> { - let TransformResult::Chunk(content) = result else { - anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); - }; - - assert!( - !content.contains(expected), - "chunk unexpectedly contains '{expected}': {content}" - ); - - Ok(()) - } - - fn assert_error_contains(result: &TransformResult, expected: &str) -> Result<()> { - let TransformResult::Error(content) = result else { - anyhow::bail!("expected TransformResult::Error, got TransformResult::Chunk"); - }; - - assert!( - content.contains(expected), - "error does not contain '{expected}': {content}" - ); - - Ok(()) - } - - fn summary_with_counts( - prompt_tokens: u64, - content_tokens: u64, - reasoning_tokens: u64, - ) -> GenerationSummary { - GenerationSummary { - usage: TokenUsage { - prompt_tokens, - content_tokens, - reasoning_tokens, - ..TokenUsage::default() - }, - } - } - - fn weather_call() -> ParsedToolCall { - ParsedToolCall::new( - "call_x".to_owned(), - "get_weather".to_owned(), - ToolCallArguments::ValidJson(serde_json::json!({"location": "Paris"})), - ) - } - - #[actix_web::test] - async fn streaming_content_token_emits_content_delta() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::ContentToken("hello".to_owned())); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"content\":\"hello\"")?; - assert_chunk_contains(&chunks[0], "\"role\":\"assistant\"")?; - assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_reasoning_token_emits_reasoning_content_delta() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = - make_token_message(GeneratedTokenResult::ReasoningToken("thought".to_owned())); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"reasoning_content\":\"thought\"")?; - assert_chunk_contains(&chunks[0], "\"role\":\"assistant\"")?; - assert_chunk_does_not_contain(&chunks[0], "\"content\":")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_undeterminable_token_emits_content_delta() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::UndeterminableToken( - "ambig".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"content\":\"ambig\"")?; - assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_tool_call_token_is_silently_dropped() -> Result<()> { - let transformer = streaming_transformer(false); - - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::ToolCallToken( - "{".to_owned(), - ))) - .await?; - - assert_eq!(chunks.len(), 0); - - Ok(()) - } - - #[actix_web::test] - async fn streaming_tool_call_parsed_emits_structured_tool_calls_chunk() -> Result<()> { - let transformer = streaming_transformer(false); - - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( - vec![weather_call()], - ))) - .await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"tool_calls\"")?; - assert_chunk_contains(&chunks[0], "\"id\":\"call_x\"")?; - assert_chunk_contains(&chunks[0], "\"name\":\"get_weather\"")?; - assert_chunk_contains( - &chunks[0], - "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", - )?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_done_after_tool_call_uses_tool_calls_finish_reason() -> Result<()> { - let transformer = streaming_transformer(false); - - transformer - .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( - vec![weather_call()], - ))) - .await?; - - let summary = summary_with_counts(2, 0, 0); - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"finish_reason\":\"tool_calls\"")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_done_without_tool_call_uses_stop_finish_reason() -> Result<()> { - let transformer = streaming_transformer(false); - - transformer - .transform(make_token_message(GeneratedTokenResult::ContentToken( - "hi".to_owned(), - ))) - .await?; - - let summary = summary_with_counts(2, 1, 0); - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_done_with_include_usage_emits_finish_then_usage_chunk() -> Result<()> { - let transformer = streaming_transformer(true); - let summary = summary_with_counts(7, 4, 1); - - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(chunks.len(), 2); - assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; - assert_chunk_does_not_contain(&chunks[0], "usage")?; - assert_chunk_contains(&chunks[1], "\"prompt_tokens\":7")?; - assert_chunk_contains(&chunks[1], "\"completion_tokens\":5")?; - assert_chunk_contains(&chunks[1], "\"total_tokens\":12")?; - assert_chunk_contains(&chunks[1], "\"choices\":[]")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_done_without_include_usage_emits_only_finish_chunk() -> Result<()> { - let transformer = streaming_transformer(false); - let summary = summary_with_counts(5, 3, 2); - - let chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(chunks.len(), 1); - assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; - assert_chunk_does_not_contain(&chunks[0], "usage")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_tool_call_parse_failed_emits_server_error() -> Result<()> { - let transformer = streaming_transformer(false); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::ToolCallParseFailed("bad payload".to_owned()), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "bad payload")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_tool_call_validation_failed_emits_server_error() -> Result<()> { - let transformer = streaming_transformer(false); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::ToolCallValidationFailed(vec!["missing field x".to_owned()]), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "missing field x")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { - let transformer = streaming_transformer(false); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::UnrecognizedToolCallFormat( - paddler_types::raw_tool_call_tokens::RawToolCallTokens { - text: "blah".to_owned(), - ffi_error_message: "common_chat_parse failed: no parser".to_owned(), - }, - ), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; - assert_error_contains(&chunks[0], "blah")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_error_message_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_error_message(500, "internal server error"); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "internal server error")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_chat_template_error_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::ChatTemplateError( - "bad template".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "bad template")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_timeout_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_response_message(OutgoingResponse::Timeout); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "request timed out")?; - assert_error_contains(&chunks[0], "timeout")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "too many buffered requests")?; - assert_error_contains(&chunks[0], "rate_limit_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_image_decoding_failed_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::ImageDecodingFailed( - "unsupported format".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "unsupported format")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::MultimodalNotSupported( - "model does not support images".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "model does not support images")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { - let transformer = streaming_transformer(false); - - let message = make_token_message(GeneratedTokenResult::ImageExceedsBatchSize( - paddler_types::oversized_image_details::OversizedImageDetails { - image_tokens: 368, - n_batch: 100, - }, - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "368")?; - assert_error_contains(&chunks[0], "100")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_aggregates_content_only_when_no_reasoning() -> Result<()> { - let transformer = non_streaming_transformer(); - - transformer - .transform(make_token_message(GeneratedTokenResult::ContentToken( - "hel".to_owned(), - ))) - .await?; - transformer - .transform(make_token_message(GeneratedTokenResult::ContentToken( - "lo".to_owned(), - ))) - .await?; - - let summary = summary_with_counts(4, 2, 0); - let final_chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(final_chunks.len(), 1); - assert_chunk_contains(&final_chunks[0], "\"content\":\"hello\"")?; - assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; - assert_chunk_contains(&final_chunks[0], "\"prompt_tokens\":4")?; - assert_chunk_contains(&final_chunks[0], "\"completion_tokens\":2")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_separates_reasoning_from_content() -> Result<()> { - let transformer = non_streaming_transformer(); - - transformer - .transform(make_token_message(GeneratedTokenResult::ReasoningToken( - "think".to_owned(), - ))) - .await?; - transformer - .transform(make_token_message(GeneratedTokenResult::ContentToken( - "answer".to_owned(), - ))) - .await?; - - let summary = summary_with_counts(3, 1, 1); - let final_chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(final_chunks.len(), 1); - assert_chunk_contains(&final_chunks[0], "\"content\":\"answer\"")?; - assert_chunk_contains(&final_chunks[0], "\"reasoning_content\":\"think\"")?; - assert_chunk_contains(&final_chunks[0], "\"reasoning_tokens\":1")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_undeterminable_routes_to_content() -> Result<()> { - let transformer = non_streaming_transformer(); - - transformer - .transform(make_token_message( - GeneratedTokenResult::UndeterminableToken("amb".to_owned()), - )) - .await?; - - let summary = summary_with_counts(2, 0, 0); - let final_chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(final_chunks.len(), 1); - assert_chunk_contains(&final_chunks[0], "\"content\":\"amb\"")?; - assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_tool_call_parsed_populates_message_tool_calls() -> Result<()> { - let transformer = non_streaming_transformer(); - - transformer - .transform(make_token_message(GeneratedTokenResult::ToolCallParsed( - vec![weather_call()], - ))) - .await?; - - let summary = summary_with_counts(4, 0, 0); - let final_chunks = transformer - .transform(make_token_message(GeneratedTokenResult::Done(summary))) - .await?; - - assert_eq!(final_chunks.len(), 1); - assert_chunk_contains(&final_chunks[0], "\"tool_calls\":")?; - assert_chunk_contains(&final_chunks[0], "\"name\":\"get_weather\"")?; - assert_chunk_contains( - &final_chunks[0], - "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", - )?; - assert_chunk_contains(&final_chunks[0], "\"finish_reason\":\"tool_calls\"")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_tool_call_parse_failed_emits_error() -> Result<()> { - let transformer = non_streaming_transformer(); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::ToolCallParseFailed("bad payload".to_owned()), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "bad payload")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_tool_call_validation_failed_emits_error() -> Result<()> { - let transformer = non_streaming_transformer(); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::ToolCallValidationFailed(vec!["bad shape".to_owned()]), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "bad shape")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { - let transformer = non_streaming_transformer(); - - let chunks = transformer - .transform(make_token_message( - GeneratedTokenResult::UnrecognizedToolCallFormat( - paddler_types::raw_tool_call_tokens::RawToolCallTokens { - text: "blah".to_owned(), - ffi_error_message: "common_chat_parse failed: no parser".to_owned(), - }, - ), - )) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; - assert_error_contains(&chunks[0], "blah")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_error_message_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let chunks = transformer - .transform(make_error_message(500, "internal server error")) - .await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "internal server error")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_chat_template_error_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_token_message(GeneratedTokenResult::ChatTemplateError( - "bad template".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "bad template")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_image_decoding_failed_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_token_message(GeneratedTokenResult::ImageDecodingFailed( - "unsupported format".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "unsupported format")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_token_message(GeneratedTokenResult::MultimodalNotSupported( - "model does not support images".to_owned(), - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "model does not support images")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_token_message(GeneratedTokenResult::ImageExceedsBatchSize( - paddler_types::oversized_image_details::OversizedImageDetails { - image_tokens: 368, - n_batch: 100, - }, - )); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "368")?; - assert_error_contains(&chunks[0], "100")?; - assert_error_contains(&chunks[0], "server_error")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_timeout_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_response_message(OutgoingResponse::Timeout); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "request timed out")?; - assert_error_contains(&chunks[0], "timeout")?; - - Ok(()) - } - - #[actix_web::test] - async fn non_streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { - let transformer = non_streaming_transformer(); - - let message = make_response_message(OutgoingResponse::TooManyBufferedRequests); - let chunks = transformer.transform(message).await?; - - assert_eq!(chunks.len(), 1); - assert_error_contains(&chunks[0], "too many buffered requests")?; - assert_error_contains(&chunks[0], "rate_limit_error")?; - - Ok(()) - } - - #[test] - fn deserialize_text_only_request() -> Result<()> { - let input = serde_json::json!({ - "model": "test-model", - "messages": [ - {"role": "user", "content": "hello"} - ] - }); - - let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; - - assert_eq!(params.model, "test-model"); - assert_eq!(params.messages.len(), 1); - assert_eq!(params.messages[0].role, "user"); - assert_eq!(params.messages[0].content.text_content(), "hello"); - - Ok(()) - } - - #[test] - fn deserialize_request_with_stream_options_include_usage_true() -> Result<()> { - let input = serde_json::json!({ - "model": "test-model", - "messages": [{"role": "user", "content": "hi"}], - "stream": true, - "stream_options": {"include_usage": true} - }); - - let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; - - let stream_options = params - .stream_options - .ok_or_else(|| anyhow::anyhow!("expected stream_options"))?; - - assert!(stream_options.include_usage); - - Ok(()) - } - - #[test] - fn deserialize_request_without_stream_options_defaults_to_none() -> Result<()> { - let input = serde_json::json!({ - "model": "test-model", - "messages": [{"role": "user", "content": "hi"}], - "stream": true - }); - - let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; - - assert!(params.stream_options.is_none()); - - Ok(()) - } - - #[test] - fn deserialize_multimodal_request_with_image() -> Result<()> { - let input = serde_json::json!({ - "model": "vision-model", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "describe this image"}, - {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ"}} - ] - } - ] - }); - - let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; - - assert_eq!(params.messages.len(), 1); - assert_eq!( - params.messages[0].content.text_content(), - "describe this image" - ); - - let image_urls = params.messages[0].content.image_urls(); - - assert_eq!(image_urls.len(), 1); - assert_eq!(image_urls[0].url, "data:image/jpeg;base64,/9j/4AAQ"); - - Ok(()) - } - - #[test] - fn deserialize_multi_turn_conversation() -> Result<()> { - let input = serde_json::json!({ - "model": "test-model", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - {"role": "user", "content": "And 3+3?"} - ] - }); - - let params: super::OpenAICompletionRequestParams = serde_json::from_value(input)?; - - assert_eq!(params.messages.len(), 4); - - Ok(()) - } - - #[test] - fn openai_message_converts_to_conversation_message() -> Result<()> { - use paddler_types::conversation_message::ConversationMessage; - - let input = serde_json::json!({ - "role": "user", - "content": [ - {"type": "text", "text": "OCR this"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}} - ] - }); - - let openai_message: super::OpenAIMessage = serde_json::from_value(input)?; - let conversation_message = ConversationMessage::from(&openai_message); - - assert_eq!(conversation_message.role, "user"); - assert_eq!(conversation_message.content.text_content(), "OCR this"); - assert_eq!(conversation_message.content.image_urls().len(), 1); - - Ok(()) - } - - #[test] - fn openai_error_json_has_correct_structure() { - let error = super::openai_error_json("server_error", "something went wrong"); - - assert_eq!(error["error"]["type"], "server_error"); - assert_eq!(error["error"]["message"], "something went wrong"); - assert!(error["error"]["param"].is_null()); - assert!(error["error"]["code"].is_null()); - } - - #[test] - fn validation_failure_message_returns_first_error() { - let message = super::validation_failure_message(&[ - "first issue".to_owned(), - "second issue".to_owned(), - ]); - - assert_eq!(message, "first issue"); - } - - #[test] - fn validation_failure_message_falls_back_when_no_errors() { - let message = super::validation_failure_message(&[]); - - assert!(message.contains("validation")); - } -} diff --git a/paddler/src/balancer/compatibility/openai_service/mod.rs b/paddler/src/balancer/compatibility/openai_service/mod.rs deleted file mode 100644 index 55b8fd9f..00000000 --- a/paddler/src/balancer/compatibility/openai_service/mod.rs +++ /dev/null @@ -1,68 +0,0 @@ -pub mod app_data; -pub mod configuration; -pub mod http_route; - -use std::sync::Arc; - -use actix_web::App; -use actix_web::HttpServer; -use actix_web::web::Data; -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; -use trzcina::Service; -use trzcina::ServiceShutdownOptions; - -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::compatibility::openai_service::app_data::AppData; -use crate::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use crate::balancer::http_route as common_http_route; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use crate::create_cors_middleware::create_cors_middleware; - -pub struct OpenAIService { - pub buffered_request_manager: Arc, - pub inference_service_configuration: InferenceServiceConfiguration, - pub openai_service_configuration: OpenAIServiceConfiguration, - pub shutdown_options: ServiceShutdownOptions, -} - -#[async_trait] -impl Service for OpenAIService { - fn name(&self) -> &'static str { - "balancer::compatibility::openai_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - let cors_allowed_hosts = self - .inference_service_configuration - .cors_allowed_hosts - .clone(); - let cors_allowed_hosts_arc = Arc::new(cors_allowed_hosts); - - let app_data = Data::new(AppData { - buffered_request_manager: self.buffered_request_manager.clone(), - inference_service_configuration: self.inference_service_configuration.clone(), - }); - - #[expect(clippy::expect_used, reason = "server bind failure is unrecoverable")] - HttpServer::new(move || { - App::new() - .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) - .app_data(app_data.clone()) - .configure(common_http_route::get_health::register) - .configure(http_route::post_chat_completions::register) - }) - .shutdown_signal(async move { - shutdown.cancelled().await; - }) - .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) - .disable_signals() - .bind(self.openai_service_configuration.addr) - .expect("Unable to bind server to address") - .run() - .await?; - - Ok(()) - } -} diff --git a/paddler/src/balancer/controls_manages_senders_endpoint.rs b/paddler/src/balancer/controls_manages_senders_endpoint.rs deleted file mode 100644 index b3c83a22..00000000 --- a/paddler/src/balancer/controls_manages_senders_endpoint.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::sync::Arc; - -use actix_web::Error; -use actix_web::HttpResponse; -use async_trait::async_trait; -use tokio::time::Duration; -use tokio::time::sleep; - -use crate::balancer::agent_controller::AgentController; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::manages_senders_controller::ManagesSendersController; - -const TIMEOUT: Duration = Duration::from_secs(3); - -#[async_trait] -pub trait ControlsManagesSendersEndpoint { - type SenderCollection: ManagesSenders + Send + Sync + 'static; - - fn get_agent_controller_pool(&self) -> Arc; - - fn get_agent_id(&self) -> String; - - async fn get_manages_senders_controller( - &self, - agent_controller: Arc, - ) -> anyhow::Result>; - - async fn respond(&self) -> Result { - let agent_controller_pool = self.get_agent_controller_pool(); - let agent_id = self.get_agent_id(); - let Some(agent_controller) = agent_controller_pool.get_agent_controller(&agent_id) else { - return Ok(HttpResponse::NotFound().finish()); - }; - - let connection_close = agent_controller.connection_close.clone(); - - match self.get_manages_senders_controller(agent_controller).await { - Ok(mut receive_response_controller) => { - tokio::select! { - () = connection_close.cancelled() => Ok(HttpResponse::BadGateway().finish()), - () = sleep(TIMEOUT) => Ok(HttpResponse::GatewayTimeout().finish()), - response = receive_response_controller.response_rx.recv() => response.map_or_else( - || Ok(HttpResponse::NotFound().finish()), - |existing_response| Ok(HttpResponse::Ok().json(existing_response)), - ), - } - } - Err(err) => Ok(HttpResponse::InternalServerError().body(format!("{err}"))), - } - } -} diff --git a/paddler/src/balancer/http_stream_from_agent.rs b/paddler/src/balancer/http_stream_from_agent.rs deleted file mode 100644 index 7fad8d2e..00000000 --- a/paddler/src/balancer/http_stream_from_agent.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::fmt::Debug; -use std::sync::Arc; - -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::http::header; -use bytes::Bytes; -use futures::stream::StreamExt; -use paddler_types::inference_client::Response as OutgoingResponse; -use paddler_types::streamable_result::StreamableResult; - -use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; -use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::unbounded_stream_from_agent::unbounded_stream_from_agent; - -pub fn http_stream_from_agent( - buffered_request_manager: Arc, - inference_service_configuration: InferenceServiceConfiguration, - params: TParams, - transformer: TTransformsOutgoingMessage, -) -> HttpResponse -where - TParams: Debug + Into + Send + 'static, - AgentController: HandlesAgentStreamingResponse, - <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, - TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, -{ - let stream = unbounded_stream_from_agent( - buffered_request_manager, - inference_service_configuration, - params, - transformer, - ) - .filter_map(|transform_result| async move { - match transform_result { - TransformResult::Chunk(chunk) => { - Some(Ok::<_, Error>(Bytes::from(format!("{chunk}\n")))) - } - TransformResult::Error(error) => { - Some(Ok::<_, Error>(Bytes::from(format!("{error}\n")))) - } - TransformResult::Discard => None, - } - }); - - HttpResponse::Ok() - .insert_header(header::ContentType::json()) - .insert_header((header::CACHE_CONTROL, "no-cache")) - .streaming(stream) -} diff --git a/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs b/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs deleted file mode 100644 index efd59ecc..00000000 --- a/paddler/src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs +++ /dev/null @@ -1,180 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::error::ErrorInternalServerError; -use actix_web::error::ErrorNotImplemented; -use actix_web::error::ErrorServiceUnavailable; -use actix_web::http::header; -use actix_web::post; -use actix_web::rt; -use actix_web::web; -use anyhow::Result; -use async_trait::async_trait; -use bytes::Bytes; -use futures::stream::StreamExt; -use log::error; -use nanoid::nanoid; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::inference_client::Response as OutgoingResponse; -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::ResponseEnvelope; -use paddler_types::request_params::ChunkEvenlyWithCapError; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use tokio::sync::mpsc; -use tokio::task::JoinSet; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_util::sync::CancellationToken; - -use crate::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; -use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; -use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; -use crate::balancer::inference_service::app_data::AppData; -use crate::balancer::request_from_agent::request_from_agent; -use crate::cancellation_token_stream_guard::CancellationTokenStreamGuard; -use crate::controls_session::ControlsSession as _; - -#[derive(Clone)] -struct EmbeddingChunkBodyTransformer; - -#[async_trait] -impl TransformsOutgoingMessage for EmbeddingChunkBodyTransformer { - async fn transform(&self, message: OutgoingMessage) -> Result> { - if let OutgoingMessage::Response(ResponseEnvelope { - response: OutgoingResponse::Embedding(EmbeddingResult::Done), - .. - }) = &message - { - return Ok(vec![TransformResult::Discard]); - } - - let serialized = serde_json::to_string(&message)?; - - Ok(vec![TransformResult::Chunk(serialized)]) - } -} - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[post("/api/v1/generate_embedding_batch")] -async fn respond( - app_data: web::Data, - params: web::Json, -) -> Result { - let balancer_applicable_state_holder = app_data.balancer_applicable_state_holder.clone(); - let Some(agent_desired_state) = balancer_applicable_state_holder.get_agent_desired_state() - else { - return Err(ErrorServiceUnavailable( - "Balancer applicable state is not yet set", - )); - }; - - if !agent_desired_state.inference_parameters.enable_embeddings { - return Err(ErrorNotImplemented( - "Embedding generation is not enabled in the inference parameters", - )); - } - - let agent_count = app_data.agent_controller_pool.agents.len(); - let embedding_batch_size = agent_desired_state - .inference_parameters - .embedding_batch_size; - - let connection_close = CancellationToken::new(); - let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); - - let mut chunk_tasks: JoinSet<()> = JoinSet::new(); - - let batches = match params - .into_inner() - .chunk_evenly_with_cap(agent_count, embedding_batch_size) - { - Ok(batches) => batches, - Err(ChunkEvenlyWithCapError::ZeroAgentCount) => { - return Err(ErrorServiceUnavailable("No agents are currently connected")); - } - Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) => { - return Err(ErrorInternalServerError( - "embedding_batch_size is zero despite validation", - )); - } - }; - - for batch in batches { - let buffered_request_manager_clone = app_data.buffered_request_manager.clone(); - let chunk_tx_clone = chunk_tx.clone(); - let connection_close_clone = connection_close.clone(); - let inference_service_configuration_clone = - app_data.inference_service_configuration.clone(); - - chunk_tasks.spawn(async move { - let request_id: String = nanoid!(); - let mut session_controller = ChunkForwardingSessionController::new( - chunk_tx_clone, - EmbeddingChunkBodyTransformer, - ); - - if let Err(err) = request_from_agent( - buffered_request_manager_clone, - connection_close_clone, - inference_service_configuration_clone, - batch, - request_id.clone(), - session_controller.clone(), - ) - .await - { - error!("Failed to handle request: {err}"); - session_controller - .send_response_safe(OutgoingMessage::Error(ErrorEnvelope { - request_id: request_id.clone(), - error: JsonRpcError { - code: 500, - description: format!("Request {request_id} failed: {err}"), - }, - })) - .await; - } - }); - } - - let final_done_chunk_tx = chunk_tx.clone(); - - rt::spawn(async move { - while chunk_tasks.join_next().await.is_some() {} - - let final_request_id: String = nanoid!(); - let mut final_session = - ChunkForwardingSessionController::new(final_done_chunk_tx, IdentityTransformer::new()); - - final_session - .send_response_safe(OutgoingMessage::Response(ResponseEnvelope { - generated_by: None, - request_id: final_request_id, - response: OutgoingResponse::Embedding(EmbeddingResult::Done), - })) - .await; - }); - - drop(chunk_tx); - - let stream = - CancellationTokenStreamGuard::new(UnboundedReceiverStream::new(chunk_rx), connection_close) - .filter_map(|transform_result| async move { - match transform_result { - TransformResult::Chunk(content) | TransformResult::Error(content) => { - Some(Ok::<_, Error>(Bytes::from(format!("{content}\n")))) - } - TransformResult::Discard => None, - } - }); - - Ok(HttpResponse::Ok() - .insert_header(header::ContentType::json()) - .insert_header((header::CACHE_CONTROL, "no-cache")) - .streaming(stream)) -} diff --git a/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs b/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs deleted file mode 100644 index 70ae91c9..00000000 --- a/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs +++ /dev/null @@ -1,9 +0,0 @@ -use std::sync::Arc; - -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; - -pub struct InferenceSocketControllerContext { - pub buffered_request_manager: Arc, - pub inference_service_configuration: InferenceServiceConfiguration, -} diff --git a/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/mod.rs b/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/mod.rs deleted file mode 100644 index 3ee10550..00000000 --- a/paddler/src/balancer/inference_service/http_route/api/ws_inference_socket/mod.rs +++ /dev/null @@ -1,144 +0,0 @@ -mod inference_socket_controller_context; - -use std::sync::Arc; - -use actix_web::rt; -use actix_web::Error; -use actix_web::HttpRequest; -use actix_web::HttpResponse; -use actix_web::get; -use actix_web::web::Data; -use actix_web::web::Payload; -use actix_web::web::ServiceConfig; -use anyhow::Result; -use async_trait::async_trait; -use log::error; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::inference_server::Message as InferenceServerMessage; -use paddler_types::inference_server::Request as InferenceServerRequest; -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; -use paddler_types::validates::Validates as _; -use tokio_util::sync::CancellationToken; - -use self::inference_socket_controller_context::InferenceSocketControllerContext; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::inference_service::app_data::AppData; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use crate::balancer::request_from_agent::request_from_agent; -use crate::continuation_decision::ContinuationDecision; -use crate::controls_websocket_endpoint::ControlsWebSocketEndpoint; -use crate::websocket_session_controller::WebSocketSessionController; - -type InferenceJsonRpcMessage = InferenceServerMessage; -type InferenceJsonRpcRequest = InferenceServerRequest; - -struct InferenceSocketController { - buffered_request_manager: Arc, - inference_service_configuration: InferenceServiceConfiguration, -} - -#[async_trait] -impl ControlsWebSocketEndpoint for InferenceSocketController { - type Context = InferenceSocketControllerContext; - type IncomingMessage = InferenceJsonRpcMessage; - type OutgoingMessage = OutgoingMessage; - - fn create_context(&self) -> Self::Context { - InferenceSocketControllerContext { - buffered_request_manager: self.buffered_request_manager.clone(), - inference_service_configuration: self.inference_service_configuration.clone(), - } - } - - async fn handle_deserialized_message( - connection_close: CancellationToken, - context: Arc, - deserialized_message: Self::IncomingMessage, - websocket_session_controller: WebSocketSessionController, - ) -> Result { - match deserialized_message { - InferenceJsonRpcMessage::Error(ErrorEnvelope { - request_id, - error: JsonRpcError { code, description }, - }) => { - error!( - "Received error from client: code: {code}, description: {description:?}, request_id: {request_id:?}" - ); - - return Ok(ContinuationDecision::Continue); - } - InferenceJsonRpcMessage::Request(RequestEnvelope { - id: request_id, - request: - InferenceJsonRpcRequest::ContinueFromConversationHistory( - conversation_history_params, - ), - }) => { - let validated_params = conversation_history_params.validate()?; - - rt::spawn(async move { - if let Err(err) = request_from_agent( - context.buffered_request_manager.clone(), - connection_close, - context.inference_service_configuration.clone(), - validated_params, - request_id.clone(), - websocket_session_controller, - ) - .await - { - error!("Request {request_id:?} failed: {err}"); - } - }); - - Ok(ContinuationDecision::Continue) - } - InferenceJsonRpcMessage::Request(RequestEnvelope { - id: request_id, - request: InferenceJsonRpcRequest::ContinueFromRawPrompt(raw_prompt_params), - }) => { - rt::spawn(async move { - if let Err(err) = request_from_agent( - context.buffered_request_manager.clone(), - connection_close, - context.inference_service_configuration.clone(), - raw_prompt_params, - request_id.clone(), - websocket_session_controller, - ) - .await - { - error!("Request {request_id:?} failed: {err}"); - } - }); - - Ok(ContinuationDecision::Continue) - } - } - } -} - -#[get("/api/v1/inference_socket")] -#[expect( - clippy::future_not_send, - reason = "actix-web handlers run on a single-threaded runtime" -)] -async fn respond( - app_data: Data, - payload: Payload, - http_request: HttpRequest, -) -> Result { - let inference_socket_controller = InferenceSocketController { - buffered_request_manager: app_data.buffered_request_manager.clone(), - inference_service_configuration: app_data.inference_service_configuration.clone(), - }; - - inference_socket_controller.respond(payload, http_request, app_data.shutdown.clone()) -} - -pub fn register(service_config: &mut ServiceConfig) { - service_config.service(respond); -} diff --git a/paddler/src/balancer/inference_service/mod.rs b/paddler/src/balancer/inference_service/mod.rs deleted file mode 100644 index 3deaccf4..00000000 --- a/paddler/src/balancer/inference_service/mod.rs +++ /dev/null @@ -1,84 +0,0 @@ -pub mod app_data; -pub mod configuration; -pub mod http_route; - -use std::sync::Arc; - -use actix_web::App; -use actix_web::HttpServer; -use actix_web::web::Data; -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; -use trzcina::Service; -use trzcina::ServiceShutdownOptions; - -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::http_route as common_http_route; -use crate::balancer::inference_service::app_data::AppData; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -#[cfg(feature = "web_admin_panel")] -use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use crate::create_cors_middleware::create_cors_middleware; - -pub struct InferenceService { - pub agent_controller_pool: Arc, - pub balancer_applicable_state_holder: Arc, - pub buffered_request_manager: Arc, - pub configuration: InferenceServiceConfiguration, - pub shutdown_options: ServiceShutdownOptions, - #[cfg(feature = "web_admin_panel")] - pub web_admin_panel_service_configuration: Option, -} - -#[async_trait] -impl Service for InferenceService { - fn name(&self) -> &'static str { - "balancer::inference_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - #[cfg_attr(not(feature = "web_admin_panel"), expect(unused_mut))] - let mut cors_allowed_hosts = self.configuration.cors_allowed_hosts.clone(); - - #[cfg(feature = "web_admin_panel")] - if let Some(web_admin_panel_config) = &self.web_admin_panel_service_configuration { - cors_allowed_hosts.push(format!("http://{}", web_admin_panel_config.addr)); - } - - let cors_allowed_hosts_arc = Arc::new(cors_allowed_hosts); - - let app_data = Data::new(AppData { - agent_controller_pool: self.agent_controller_pool.clone(), - balancer_applicable_state_holder: self.balancer_applicable_state_holder.clone(), - buffered_request_manager: self.buffered_request_manager.clone(), - inference_service_configuration: self.configuration.clone(), - shutdown: shutdown.clone(), - }); - - #[expect(clippy::expect_used, reason = "server bind failure is unrecoverable")] - HttpServer::new(move || { - App::new() - .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) - .app_data(app_data.clone()) - .configure(common_http_route::get_health::register) - .configure(http_route::api::post_continue_from_conversation_history::register) - .configure(http_route::api::post_continue_from_raw_prompt::register) - .configure(http_route::api::post_generate_embedding_batch::register) - .configure(http_route::api::ws_inference_socket::register) - }) - .shutdown_signal(async move { - shutdown.cancelled().await; - }) - .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) - .disable_signals() - .bind(self.configuration.addr) - .expect("Unable to bind server to address") - .run() - .await?; - - Ok(()) - } -} diff --git a/paddler/src/balancer/management_service/http_route/api/get_agents.rs b/paddler/src/balancer/management_service/http_route/api/get_agents.rs deleted file mode 100644 index 76a29286..00000000 --- a/paddler/src/balancer/management_service/http_route/api/get_agents.rs +++ /dev/null @@ -1,22 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::error::ErrorInternalServerError; -use actix_web::get; -use actix_web::web; - -use crate::balancer::management_service::app_data::AppData; -use crate::produces_snapshot::ProducesSnapshot as _; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/api/v1/agents")] -async fn respond(app_data: web::Data) -> Result { - Ok(HttpResponse::Ok().json( - app_data - .agent_controller_pool - .make_snapshot() - .map_err(ErrorInternalServerError)?, - )) -} diff --git a/paddler/src/balancer/management_service/http_route/api/get_agents_stream.rs b/paddler/src/balancer/management_service/http_route/api/get_agents_stream.rs deleted file mode 100644 index 900e331e..00000000 --- a/paddler/src/balancer/management_service/http_route/api/get_agents_stream.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::convert::Infallible; -use std::time::Duration; - -use actix_web::Error; -use actix_web::Responder; -use actix_web::get; -use actix_web::web; -use actix_web_lab::sse; -use futures::StreamExt as _; -use log::error; - -use crate::balancer::management_service::app_data::AppData; -use crate::snapshots_stream::snapshots_stream; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/api/v1/agents/stream")] -async fn respond(app_data: web::Data) -> Result { - let event_stream = snapshots_stream( - app_data.agent_controller_pool.clone(), - app_data.shutdown.clone(), - ) - .filter_map(|snapshot| async move { - match serde_json::to_string(&snapshot) { - Ok(json) => Some(Ok::<_, Infallible>(sse::Event::Data(sse::Data::new(json)))), - Err(err) => { - error!("Failed to serialize agent controller pool snapshot: {err}"); - None - } - } - }); - - Ok(sse::Sse::from_stream(event_stream).with_keep_alive(Duration::from_secs(10))) -} diff --git a/paddler/src/balancer/management_service/http_route/api/get_balancer_applicable_state.rs b/paddler/src/balancer/management_service/http_route/api/get_balancer_applicable_state.rs deleted file mode 100644 index 8a0fff6b..00000000 --- a/paddler/src/balancer/management_service/http_route/api/get_balancer_applicable_state.rs +++ /dev/null @@ -1,20 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::get; -use actix_web::web; - -use crate::balancer::management_service::app_data::AppData; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/api/v1/balancer_applicable_state")] -async fn respond(app_data: web::Data) -> Result { - let applicable_state = app_data - .balancer_applicable_state_holder - .get_agent_desired_state(); - - Ok(HttpResponse::Ok().json(applicable_state)) -} diff --git a/paddler/src/balancer/management_service/http_route/api/get_balancer_desired_state.rs b/paddler/src/balancer/management_service/http_route/api/get_balancer_desired_state.rs deleted file mode 100644 index 727034fa..00000000 --- a/paddler/src/balancer/management_service/http_route/api/get_balancer_desired_state.rs +++ /dev/null @@ -1,23 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::error::ErrorInternalServerError; -use actix_web::get; -use actix_web::web; - -use crate::balancer::management_service::app_data::AppData; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/api/v1/balancer_desired_state")] -async fn respond(app_data: web::Data) -> Result { - let desired_state = app_data - .state_database - .read_balancer_desired_state() - .await - .map_err(ErrorInternalServerError)?; - - Ok(HttpResponse::Ok().json(desired_state)) -} diff --git a/paddler/src/balancer/management_service/http_route/api/get_buffered_requests.rs b/paddler/src/balancer/management_service/http_route/api/get_buffered_requests.rs deleted file mode 100644 index b30c3e80..00000000 --- a/paddler/src/balancer/management_service/http_route/api/get_buffered_requests.rs +++ /dev/null @@ -1,22 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::error::ErrorInternalServerError; -use actix_web::get; -use actix_web::web; - -use crate::balancer::management_service::app_data::AppData; -use crate::produces_snapshot::ProducesSnapshot as _; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/api/v1/buffered_requests")] -async fn respond(app_data: web::Data) -> Result { - Ok(HttpResponse::Ok().json( - app_data - .buffered_request_manager - .make_snapshot() - .map_err(ErrorInternalServerError)?, - )) -} diff --git a/paddler/src/balancer/management_service/http_route/api/put_balancer_desired_state.rs b/paddler/src/balancer/management_service/http_route/api/put_balancer_desired_state.rs deleted file mode 100644 index 07f20d84..00000000 --- a/paddler/src/balancer/management_service/http_route/api/put_balancer_desired_state.rs +++ /dev/null @@ -1,37 +0,0 @@ -use actix_web::Error; -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::error::ErrorBadRequest; -use actix_web::error::ErrorInternalServerError; -use actix_web::put; -use actix_web::web; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::validates::Validates; - -use crate::balancer::management_service::app_data::AppData; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[put("/api/v1/balancer_desired_state")] -async fn respond( - app_data: web::Data, - balancer_desired_state: web::Json, -) -> Result { - let balancer_desired_state_inner = balancer_desired_state.into_inner(); - - balancer_desired_state_inner - .inference_parameters - .clone() - .validate() - .map_err(ErrorBadRequest)?; - - app_data - .state_database - .store_balancer_desired_state(&balancer_desired_state_inner) - .await - .map_err(ErrorInternalServerError)?; - - Ok(HttpResponse::NoContent().finish()) -} diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs deleted file mode 100644 index 884e3c15..00000000 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::sync::Arc; - -use log::error; -use log::info; - -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; - -pub struct AgentSocketControllerContext { - pub agent_controller_pool: Arc, - pub agent_id: String, - pub balancer_applicable_state_holder: Arc, - pub chat_template_override_sender_collection: Arc, - pub embedding_sender_collection: Arc, - pub generate_tokens_sender_collection: Arc, - pub model_metadata_sender_collection: Arc, -} - -impl Drop for AgentSocketControllerContext { - fn drop(&mut self) { - if let Err(err) = self - .agent_controller_pool - .remove_agent_controller(&self.agent_id) - { - error!("Failed to remove agent: {err}"); - } - - info!("Removed agent: {}", self.agent_id); - } -} diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/message.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/message.rs deleted file mode 100644 index 5bea32ed..00000000 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/message.rs +++ /dev/null @@ -1,19 +0,0 @@ -use paddler_types::jsonrpc::Error; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::ResponseEnvelope; -use paddler_types::rpc_message::RpcMessage; -use serde::Deserialize; -use serde::Serialize; - -use super::Notification; -use crate::agent::jsonrpc::Response; - -#[derive(Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -pub enum Message { - Error(ErrorEnvelope), - Notification(Notification), - Response(ResponseEnvelope), -} - -impl RpcMessage for Message {} diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/mod.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/mod.rs deleted file mode 100644 index c75df1d3..00000000 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod message; -mod notification; -pub mod notification_params; - -pub use self::message::Message; -pub use self::notification::Notification; diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/mod.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/mod.rs deleted file mode 100644 index 9c97c44b..00000000 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod register_agent_params; -mod update_agent_status_params; - -pub use self::register_agent_params::RegisterAgentParams; -pub use self::update_agent_status_params::UpdateAgentStatusParams; diff --git a/paddler/src/balancer/management_service/mod.rs b/paddler/src/balancer/management_service/mod.rs deleted file mode 100644 index 211211af..00000000 --- a/paddler/src/balancer/management_service/mod.rs +++ /dev/null @@ -1,109 +0,0 @@ -pub mod app_data; -pub mod configuration; -pub mod http_route; - -use std::sync::Arc; - -use actix_web::App; -use actix_web::HttpServer; -use actix_web::web::Data; -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; -use trzcina::Service; -use trzcina::ServiceShutdownOptions; - -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use crate::balancer::http_route as common_http_route; -use crate::balancer::management_service::app_data::AppData; -use crate::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use crate::balancer::state_database::StateDatabase; -#[cfg(feature = "web_admin_panel")] -use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use crate::create_cors_middleware::create_cors_middleware; - -pub struct ManagementService { - pub agent_controller_pool: Arc, - pub balancer_applicable_state_holder: Arc, - pub buffered_request_manager: Arc, - pub chat_template_override_sender_collection: Arc, - pub configuration: ManagementServiceConfiguration, - pub embedding_sender_collection: Arc, - pub generate_tokens_sender_collection: Arc, - pub model_metadata_sender_collection: Arc, - pub shutdown_options: ServiceShutdownOptions, - pub state_database: Arc, - pub statsd_prefix: String, - #[cfg(feature = "web_admin_panel")] - pub web_admin_panel_service_configuration: Option, -} - -#[async_trait] -impl Service for ManagementService { - fn name(&self) -> &'static str { - "balancer::management_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - #[cfg_attr(not(feature = "web_admin_panel"), expect(unused_mut))] - let mut cors_allowed_hosts = self.configuration.cors_allowed_hosts.clone(); - - #[cfg(feature = "web_admin_panel")] - if let Some(web_admin_panel_config) = &self.web_admin_panel_service_configuration { - cors_allowed_hosts.push(format!("http://{}", web_admin_panel_config.addr)); - } - - let cors_allowed_hosts_arc = Arc::new(cors_allowed_hosts); - - let app_data = Data::new(AppData { - agent_controller_pool: self.agent_controller_pool.clone(), - balancer_applicable_state_holder: self.balancer_applicable_state_holder.clone(), - buffered_request_manager: self.buffered_request_manager.clone(), - chat_template_override_sender_collection: self - .chat_template_override_sender_collection - .clone(), - embedding_sender_collection: self.embedding_sender_collection.clone(), - generate_tokens_sender_collection: self.generate_tokens_sender_collection.clone(), - model_metadata_sender_collection: self.model_metadata_sender_collection.clone(), - shutdown: shutdown.clone(), - state_database: self.state_database.clone(), - statsd_prefix: self.statsd_prefix.clone(), - }); - - #[expect(clippy::expect_used, reason = "server bind failure is unrecoverable")] - HttpServer::new(move || { - App::new() - .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) - .app_data(app_data.clone()) - .configure(common_http_route::get_health::register) - .configure(http_route::api::get_agents::register) - .configure(http_route::api::get_agents_stream::register) - .configure(http_route::api::get_balancer_applicable_state::register) - .configure(http_route::api::get_balancer_desired_state::register) - .configure(http_route::api::get_buffered_requests::register) - .configure(http_route::api::get_buffered_requests_stream::register) - .configure(http_route::api::get_chat_template_override::register) - .configure(http_route::api::get_model_metadata::register) - .configure(http_route::api::put_balancer_desired_state::register) - .configure(http_route::api::ws_agent_socket::register) - .configure(http_route::get_metrics::register) - }) - .shutdown_signal(async move { - shutdown.cancelled().await; - }) - .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) - .disable_signals() - .bind(self.configuration.addr) - .expect("Unable to bind server to address") - .run() - .await?; - - Ok(()) - } -} diff --git a/paddler/src/balancer/manages_senders_controller.rs b/paddler/src/balancer/manages_senders_controller.rs deleted file mode 100644 index b85211be..00000000 --- a/paddler/src/balancer/manages_senders_controller.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::sync::Arc; - -use anyhow::Result; -use log::error; -use tokio::sync::mpsc; - -use crate::balancer::manages_senders::ManagesSenders; - -pub struct ManagesSendersController -where - TSenderCollection: ManagesSenders, -{ - pub request_id: String, - pub response_rx: mpsc::UnboundedReceiver, - pub response_sender_collection: Arc, -} - -impl ManagesSendersController -where - TSenderCollection: ManagesSenders, -{ - pub fn from_request_id( - request_id: String, - response_sender_collection: Arc, - ) -> Result { - let (response_tx, response_rx) = mpsc::unbounded_channel(); - - response_sender_collection.register_sender(request_id.clone(), response_tx)?; - - Ok(Self { - request_id, - response_rx, - response_sender_collection, - }) - } -} - -impl Drop for ManagesSendersController -where - TSenderCollection: ManagesSenders, -{ - fn drop(&mut self) { - self.response_sender_collection - .deregister_sender(self.request_id.clone()) - .unwrap_or_else(|err| { - error!( - "Failed to deregister sender for request_id {}: {err}", - self.request_id - ); - }); - } -} diff --git a/paddler/src/balancer/reconciliation_service.rs b/paddler/src/balancer/reconciliation_service.rs deleted file mode 100644 index 21aec788..00000000 --- a/paddler/src/balancer/reconciliation_service.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; -use log::error; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use tokio::sync::broadcast; -use tokio::time::Duration; -use tokio::time::MissedTickBehavior; -use tokio::time::interval; -use tokio_util::sync::CancellationToken; -use trzcina::Service; - -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use crate::converts_to_applicable_state::ConvertsToApplicableState as _; -use crate::sets_desired_state::SetsDesiredState as _; - -async fn convert_to_applicable_state( - balancer_desired_state: &BalancerDesiredState, - agent_controller_pool: &AgentControllerPool, - balancer_applicable_state_holder: &BalancerApplicableStateHolder, - is_converted_to_applicable_state: &mut bool, -) -> Result<()> { - let balancer_applicable_state = balancer_desired_state.to_applicable_state(()).await?; - - agent_controller_pool - .set_desired_state(balancer_applicable_state.agent_desired_state.clone()) - .await?; - balancer_applicable_state_holder - .set_balancer_applicable_state(Some(balancer_applicable_state)); - - *is_converted_to_applicable_state = true; - - Ok(()) -} - -async fn try_convert_to_applicable_state( - balancer_desired_state: &BalancerDesiredState, - agent_controller_pool: &AgentControllerPool, - balancer_applicable_state_holder: &BalancerApplicableStateHolder, - is_converted_to_applicable_state: &mut bool, -) { - if let Err(err) = convert_to_applicable_state( - balancer_desired_state, - agent_controller_pool, - balancer_applicable_state_holder, - is_converted_to_applicable_state, - ) - .await - { - error!("Failed to convert to applicable state: {err}"); - } -} - -pub struct ReconciliationService { - pub agent_controller_pool: Arc, - pub balancer_applicable_state_holder: Arc, - pub balancer_desired_state: BalancerDesiredState, - pub balancer_desired_state_rx: broadcast::Receiver, - pub is_converted_to_applicable_state: bool, -} - -#[async_trait] -impl Service for ReconciliationService { - fn name(&self) -> &'static str { - "balancer::reconciliation_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - let Self { - agent_controller_pool, - balancer_applicable_state_holder, - mut balancer_desired_state, - mut balancer_desired_state_rx, - mut is_converted_to_applicable_state, - } = *self; - - let mut ticker = interval(Duration::from_secs(1)); - - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - () = shutdown.cancelled() => break Ok(()), - _ = ticker.tick() => { - if !is_converted_to_applicable_state { - try_convert_to_applicable_state( - &balancer_desired_state, - &agent_controller_pool, - &balancer_applicable_state_holder, - &mut is_converted_to_applicable_state, - ).await; - } - }, - received_balancer_desired_state = balancer_desired_state_rx.recv() => { - is_converted_to_applicable_state = false; - balancer_desired_state = received_balancer_desired_state?; - try_convert_to_applicable_state( - &balancer_desired_state, - &agent_controller_pool, - &balancer_applicable_state_holder, - &mut is_converted_to_applicable_state, - ).await; - } - } - } - } -} diff --git a/paddler/src/balancer/request_from_agent.rs b/paddler/src/balancer/request_from_agent.rs deleted file mode 100644 index d66bad35..00000000 --- a/paddler/src/balancer/request_from_agent.rs +++ /dev/null @@ -1,310 +0,0 @@ -use std::fmt::Debug; -use std::sync::Arc; - -use anyhow::Result; -use log::debug; -use log::error; -use log::warn; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::inference_client::Response as OutgoingResponse; -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::ResponseEnvelope; -use paddler_types::streamable_result::StreamableResult; -use tokio::time::sleep; -use tokio_util::sync::CancellationToken; - -use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::dispatched_agent::DispatchedAgent; -use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::manages_senders_controller::ManagesSendersController; -use crate::controls_session::ControlsSession; - -pub async fn request_from_agent( - buffered_request_manager: Arc, - connection_close: CancellationToken, - inference_service_configuration: InferenceServiceConfiguration, - params: TParams, - request_id: String, - mut session_controller: TControlsSession, -) -> Result<()> -where - TControlsSession: ControlsSession, - TParams: Debug + Into + Send, - AgentController: HandlesAgentStreamingResponse, - <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, -{ - match wait_for_agent_controller( - buffered_request_manager.clone(), - connection_close.clone(), - request_id.clone(), - &mut session_controller, - ) - .await? - { - Some(dispatched_agent) => { - let receive_response_controller = match dispatched_agent - .agent_controller - .handle_streaming_response(request_id.clone(), params) - .await - { - Ok(receive_response_controller) => receive_response_controller, - Err(err) => { - error!("Failed to handle request {request_id:?}: {err}"); - - respond_with_error( - JsonRpcError { - code: 500, - description: "Failed to generate response".to_owned(), - }, - request_id.clone(), - &mut session_controller, - ) - .await; - - return Ok(()); - } - }; - - forward_responses_stream( - dispatched_agent.agent_controller.clone(), - connection_close, - inference_service_configuration, - receive_response_controller, - request_id, - session_controller, - ) - .await?; - - Ok(()) - } - None => Ok(()), - } -} - -pub async fn forward_responses_stream( - agent_controller: Arc, - connection_close: CancellationToken, - inference_service_configuration: InferenceServiceConfiguration, - mut receive_response_controller: ManagesSendersController, - request_id: String, - mut session_controller: TControlsSession, -) -> Result<()> -where - TControlsSession: ControlsSession, - TManagesSenders: ManagesSenders + Send + Sync, - TManagesSenders::Value: Debug + Into + Send + StreamableResult, -{ - debug!("Found available agent controller for request: {request_id:?}"); - - let agent_connection_close = agent_controller.connection_close.clone(); - - loop { - tokio::select! { - () = agent_connection_close.cancelled() => { - error!("Agent controller connection closed"); - - respond_with_error( - JsonRpcError { - code: 502, - description: "Agent controller connection closed".to_owned(), - }, - request_id, - &mut session_controller, - ).await; - - break; - } - () = connection_close.cancelled() => { - agent_controller.stop_responding_to(request_id.clone()).await.unwrap_or_else(|err| { - error!("Failed to stop request {request_id:?}: {err}"); - }); - - break; - } - () = sleep(inference_service_configuration.inference_item_timeout) => { - let timeout_ms = inference_service_configuration.inference_item_timeout.as_millis(); - - warn!( - "Timed out after {timeout_ms}ms waiting for next token for request {request_id:?}. \ - Consider increasing --inference-item-timeout if the model needs more time to process the prompt." - ); - - respond_with_error( - JsonRpcError { - code: 504, - description: format!( - "Inference timed out after {timeout_ms}ms waiting for next token. \ - Increase --inference-item-timeout if the prompt requires longer processing." - ), - }, - request_id.clone(), - &mut session_controller, - ).await; - - agent_controller.stop_responding_to(request_id.clone()).await.unwrap_or_else(|err| { - error!("Failed to stop responding to request {request_id:?}: {err}"); - }); - - break; - } - response = receive_response_controller.response_rx.recv() => { - if let Some(response) = response { - let is_done = response.is_done(); - - let send_succeeded = send_response_to_client( - agent_controller.clone(), - response, - request_id.clone(), - &mut session_controller, - ).await; - - if is_done || !send_succeeded { - break; - } - } else { - error!( - "Response channel closed before terminator for request {request_id:?}" - ); - - respond_with_error( - JsonRpcError { - code: 502, - description: - "Response channel closed before terminator".to_owned(), - }, - request_id, - &mut session_controller, - ).await; - - break; - } - } - } - } - - Ok(()) -} - -pub async fn respond_with_error( - error: JsonRpcError, - request_id: String, - session_controller: &mut TControlsSession, -) where - TControlsSession: ControlsSession, -{ - session_controller - .send_response(OutgoingMessage::Error(ErrorEnvelope { - request_id: request_id.clone(), - error, - })) - .await - .unwrap_or_else(|err| { - error!("Failed to send response for request {request_id:?}: {err}"); - }); -} - -async fn send_response_to_client( - agent_controller: Arc, - response: TResponse, - request_id: String, - session_controller: &mut TControlsSession, -) -> bool -where - TControlsSession: ControlsSession, - TResponse: Into + Send, -{ - if let Err(err) = session_controller - .send_response(OutgoingMessage::Response(ResponseEnvelope { - generated_by: agent_controller.name.clone(), - request_id: request_id.clone(), - response: response.into(), - })) - .await - { - error!("Failed to send response for request {request_id:?}: {err}"); - - agent_controller - .stop_responding_to(request_id.clone()) - .await - .unwrap_or_else(|err| { - error!("Failed to stop responding to request {request_id:?}: {err}"); - }); - - return false; - } - - true -} - -async fn wait_for_agent_controller( - buffered_request_manager: Arc, - connection_close: CancellationToken, - request_id: String, - session_controller: &mut TControlsSession, -) -> Result> -where - TControlsSession: ControlsSession, -{ - let buffered_request_manager = buffered_request_manager.clone(); - - tokio::select! { - () = connection_close.cancelled() => { - debug!("Connection close signal received, stopping GenerateTokens loop."); - - Ok(None) - }, - buffered_request_agent_wait_result = buffered_request_manager.wait_for_available_agent() => { - match buffered_request_agent_wait_result { - Ok(BufferedRequestAgentWaitResult::Found(dispatched_agent)) => Ok(Some(dispatched_agent)), - Ok(BufferedRequestAgentWaitResult::BufferOverflow) => { - warn!("Too many buffered requests, dropping request: {request_id:?}"); - - respond_with_error( - JsonRpcError { - code: 503, - description: "Buffered requests overflow".to_owned(), - }, - request_id.clone(), - session_controller, - ).await; - - Ok(None) - } - Ok(BufferedRequestAgentWaitResult::Timeout(err)) => { - warn!("Buffered request {request_id:?} timed out: {err:?}"); - - respond_with_error( - JsonRpcError { - code: 504, - description: "Waiting for available slot timed out".to_owned(), - }, - request_id.clone(), - session_controller, - ).await; - - Ok(None) - } - Err(err) => { - error!("Error while waiting for available agent controller for GenerateTokens request: {err}"); - - respond_with_error( - JsonRpcError { - code: 500, - description: "Internal server error".to_owned(), - }, - request_id.clone(), - session_controller, - ).await; - - Ok(None) - } - } - } - } -} diff --git a/paddler/src/balancer/response/view.rs b/paddler/src/balancer/response/view.rs deleted file mode 100644 index 4b327cdc..00000000 --- a/paddler/src/balancer/response/view.rs +++ /dev/null @@ -1,9 +0,0 @@ -use actix_web::HttpResponse; -use actix_web::Result; -use askama::Template; - -use super::view_from_http_response_builder::view_from_http_response_builder; - -pub fn view(template: TTemplate) -> Result { - view_from_http_response_builder(HttpResponse::Ok(), template) -} diff --git a/paddler/src/balancer/response/view_from_http_response_builder.rs b/paddler/src/balancer/response/view_from_http_response_builder.rs deleted file mode 100644 index 45d9a53d..00000000 --- a/paddler/src/balancer/response/view_from_http_response_builder.rs +++ /dev/null @@ -1,16 +0,0 @@ -use actix_web::HttpResponse; -use actix_web::HttpResponseBuilder; -use actix_web::Result; -use actix_web::error::ErrorInternalServerError; -use askama::Template; - -pub fn view_from_http_response_builder( - mut http_response_builder: HttpResponseBuilder, - template: TTemplate, -) -> Result { - let rendered = template.render().map_err(ErrorInternalServerError)?; - - Ok(http_response_builder - .content_type("text/html; charset=utf-8") - .body(rendered)) -} diff --git a/paddler/src/balancer/state_database/file/mod.rs b/paddler/src/balancer/state_database/file/mod.rs deleted file mode 100644 index b4272b81..00000000 --- a/paddler/src/balancer/state_database/file/mod.rs +++ /dev/null @@ -1,120 +0,0 @@ -mod schema; - -use std::path::PathBuf; - -use anyhow::Context; -use anyhow::Result; -use async_trait::async_trait; -use log::warn; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use tokio::fs; -use tokio::io::AsyncWriteExt; -use tokio::sync::RwLock; -use tokio::sync::broadcast; - -use self::schema::Schema; -use super::StateDatabase; - -pub struct File { - balancer_desired_state_notify_tx: broadcast::Sender, - path: PathBuf, - write_lock: RwLock<()>, -} - -impl File { - #[must_use] - pub fn new( - balancer_desired_state_notify_tx: broadcast::Sender, - path: PathBuf, - ) -> Self { - Self { - balancer_desired_state_notify_tx, - path, - write_lock: RwLock::new(()), - } - } - - async fn read_schema_from_file(&self) -> Result { - match fs::read_to_string(&self.path).await { - Ok(content) => { - if content.is_empty() { - return self.store_default_schema().await; - } - - let schema: Schema = serde_json::from_str(&content).context(format!("Unable to parse database file contents: '{}'. Either that is not a valid database file, or this version of Paddler is incompatible with it.", self.path.display()))?; - - Ok(schema) - } - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - warn!( - "State database file not found; trying to store the default state: '{}'", - self.path.display() - ); - - self.store_default_schema().await - } - Err(err) => Err(err.into()), - } - } - - async fn store_default_schema(&self) -> Result { - let schema = Schema::default(); - - self.store_schema(&schema) - .await - .context("Failed to store default state")?; - - Ok(schema) - } - - async fn store_schema(&self, schema: &Schema) -> Result<()> { - let balancer_desired_state = schema.balancer_desired_state.clone(); - let _lock = self.write_lock.write().await; - - let serialized_schema = serde_json::to_string_pretty(schema)?; - let mut file = fs::File::create(&self.path).await?; - - file.write_all(serialized_schema.as_bytes()).await?; - file.sync_all().await?; - - self.balancer_desired_state_notify_tx - .send(balancer_desired_state)?; - - Ok(()) - } - - async fn update_schema(&self, modifier: TModifier) -> Result<()> - where - TModifier: FnOnce(&mut Schema), - { - let mut schema = self - .read_schema_from_file() - .await - .context("Unable to read current state from file")?; - - modifier(&mut schema); - - self.store_schema(&schema).await - } -} - -#[async_trait] -impl StateDatabase for File { - async fn read_balancer_desired_state(&self) -> Result { - Ok(self - .read_schema_from_file() - .await - .context("Unable to read state from file")? - .balancer_desired_state) - } - - async fn store_balancer_desired_state( - &self, - balancer_desired_state: &BalancerDesiredState, - ) -> Result<()> { - self.update_schema(|schema| { - schema.balancer_desired_state = balancer_desired_state.clone(); - }) - .await - } -} diff --git a/paddler/src/balancer/statsd_service/mod.rs b/paddler/src/balancer/statsd_service/mod.rs deleted file mode 100644 index f0e3715d..00000000 --- a/paddler/src/balancer/statsd_service/mod.rs +++ /dev/null @@ -1,75 +0,0 @@ -pub mod configuration; - -use std::net::UdpSocket; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; -use cadence::Gauged; -use cadence::StatsdClient; -use cadence::UdpMetricSink; -use log::error; -use tokio::time::MissedTickBehavior; -use tokio::time::interval; -use tokio_util::sync::CancellationToken; -use trzcina::Service; - -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; - -pub struct StatsdService { - pub agent_controller_pool: Arc, - pub buffered_request_manager: Arc, - pub configuration: StatsdServiceConfiguration, -} - -impl StatsdService { - #[expect(clippy::cast_sign_loss, reason = "slot counts are always non-negative")] - fn report_metrics(&self, client: &StatsdClient) -> Result<()> { - let AgentControllerPoolTotalSlots { - slots_processing, - slots_total, - } = self.agent_controller_pool.total_slots(); - let requests_buffered = self.buffered_request_manager.buffered_request_counter.get(); - - client.gauge("slots_processing", slots_processing as u64)?; - client.gauge("slots_total", slots_total as u64)?; - client.gauge("requests_buffered", requests_buffered as u64)?; - client.flush()?; - - Ok(()) - } -} - -#[async_trait] -impl Service for StatsdService { - fn name(&self) -> &'static str { - "balancer::statsd_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - let statsd_sink_socket = UdpSocket::bind("0.0.0.0:0")?; - let statsd_sink = UdpMetricSink::from(self.configuration.statsd_addr, statsd_sink_socket)?; - - let client = StatsdClient::builder(&self.configuration.statsd_prefix.clone(), statsd_sink) - .with_error_handler(|err| error!("Statsd error: {err}")) - .build(); - - let mut ticker = interval(self.configuration.statsd_reporting_interval); - - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - () = shutdown.cancelled() => break Ok(()), - _ = ticker.tick() => { - if let Err(err) = self.report_metrics(&client) { - error!("Failed to report metrics: {err}"); - } - } - } - } - } -} diff --git a/paddler/src/balancer/unbounded_stream_from_agent.rs b/paddler/src/balancer/unbounded_stream_from_agent.rs deleted file mode 100644 index 89497c0c..00000000 --- a/paddler/src/balancer/unbounded_stream_from_agent.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::fmt::Debug; -use std::sync::Arc; - -use actix_web::rt; -use futures_util::Stream; -use log::error; -use nanoid::nanoid; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::inference_client::Response as OutgoingResponse; -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::streamable_result::StreamableResult; -use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_util::sync::CancellationToken; - -use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; -use crate::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use crate::balancer::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; -use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse; -use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::request_from_agent::request_from_agent; -use crate::cancellation_token_stream_guard::CancellationTokenStreamGuard; -use crate::controls_session::ControlsSession as _; - -pub fn unbounded_stream_from_agent( - buffered_request_manager: Arc, - inference_service_configuration: InferenceServiceConfiguration, - params: TParams, - transformer: TTransformsOutgoingMessage, -) -> impl Stream -where - TParams: Debug + Into + Send + 'static, - AgentController: HandlesAgentStreamingResponse, - <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, - TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, -{ - let request_id: String = nanoid!(); - let connection_close = CancellationToken::new(); - let (chunk_tx, chunk_rx) = mpsc::unbounded_channel::(); - - rt::spawn({ - let connection_close = connection_close.clone(); - - async move { - let mut session_controller = - ChunkForwardingSessionController::new(chunk_tx, transformer); - - if let Err(err) = request_from_agent( - buffered_request_manager.clone(), - connection_close, - inference_service_configuration.clone(), - params, - request_id.clone(), - session_controller.clone(), - ) - .await - { - error!("Failed to handle request: {err}"); - - session_controller - .send_response_safe(OutgoingMessage::Error(ErrorEnvelope { - request_id: request_id.clone(), - error: JsonRpcError { - code: 500, - description: format!("Request {request_id} failed: {err}"), - }, - })) - .await; - } - } - }); - - CancellationTokenStreamGuard::new(UnboundedReceiverStream::new(chunk_rx), connection_close) -} diff --git a/paddler/src/balancer/web_admin_panel_service/app_data.rs b/paddler/src/balancer/web_admin_panel_service/app_data.rs deleted file mode 100644 index 29522a84..00000000 --- a/paddler/src/balancer/web_admin_panel_service/app_data.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::balancer::web_admin_panel_service::template_data::TemplateData; - -pub struct AppData { - pub template_data: TemplateData, -} diff --git a/paddler/src/balancer/web_admin_panel_service/http_route/favicon.rs b/paddler/src/balancer/web_admin_panel_service/http_route/favicon.rs deleted file mode 100644 index 74a6746a..00000000 --- a/paddler/src/balancer/web_admin_panel_service/http_route/favicon.rs +++ /dev/null @@ -1,17 +0,0 @@ -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::get; -use actix_web::web; - -const FAVICON: &[u8] = include_bytes!("../../../../../resources/images/favicon.svg"); - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/favicon.ico")] -async fn respond() -> impl Responder { - HttpResponse::Ok() - .content_type("image/svg+xml") - .body(FAVICON) -} diff --git a/paddler/src/balancer/web_admin_panel_service/http_route/home.rs b/paddler/src/balancer/web_admin_panel_service/http_route/home.rs deleted file mode 100644 index 4717cc62..00000000 --- a/paddler/src/balancer/web_admin_panel_service/http_route/home.rs +++ /dev/null @@ -1,54 +0,0 @@ -use actix_web::Responder; -use actix_web::get; -use actix_web::web; -use askama::Template; -use esbuild_metafile::HttpPreloader; -use esbuild_metafile::filters; - -use crate::balancer::response::view; -use crate::balancer::web_admin_panel_service::app_data::AppData; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[derive(Template)] -#[template(path = "web_admin_panel.html")] -struct WebAdminPanelTemplate { - buffered_request_timeout_millis: u128, - compat_openai_addr: String, - inference_addr: String, - management_addr: String, - max_buffered_requests: i32, - preloads: HttpPreloader, - statsd_addr: String, - statsd_prefix: String, - statsd_reporting_interval_millis: u128, -} - -#[get("/{_:.*}")] -async fn respond(preloads: HttpPreloader, app_data: web::Data) -> impl Responder { - view(WebAdminPanelTemplate { - buffered_request_timeout_millis: app_data - .template_data - .buffered_request_timeout - .as_millis(), - compat_openai_addr: match app_data.template_data.compat_openai_addr.clone() { - Some(addr) => addr.input_addr, - None => String::new(), - }, - inference_addr: app_data.template_data.inference_addr.input_addr.clone(), - management_addr: app_data.template_data.management_addr.input_addr.clone(), - max_buffered_requests: app_data.template_data.max_buffered_requests, - preloads, - statsd_addr: match app_data.template_data.statsd_addr.clone() { - Some(addr) => addr.input_addr, - None => String::new(), - }, - statsd_prefix: app_data.template_data.statsd_prefix.clone(), - statsd_reporting_interval_millis: app_data - .template_data - .statsd_reporting_interval - .as_millis(), - }) -} diff --git a/paddler/src/balancer/web_admin_panel_service/http_route/static_files.rs b/paddler/src/balancer/web_admin_panel_service/http_route/static_files.rs deleted file mode 100644 index 34ac4d65..00000000 --- a/paddler/src/balancer/web_admin_panel_service/http_route/static_files.rs +++ /dev/null @@ -1,23 +0,0 @@ -use actix_web::HttpResponse; -use actix_web::Responder; -use actix_web::get; -use actix_web::web; -use mime_guess::from_path; - -use crate::static_files::StaticFiles; - -pub fn register(cfg: &mut web::ServiceConfig) { - cfg.service(respond); -} - -#[get("/static/{path:.*}")] -async fn respond(path: web::Path) -> impl Responder { - let path = path.into_inner(); - - match StaticFiles::get(path.as_str()) { - Some(content) => HttpResponse::Ok() - .content_type(from_path(path).first_or_octet_stream().as_ref()) - .body(content.data.into_owned()), - None => HttpResponse::NotFound().body("File not found"), - } -} diff --git a/paddler/src/balancer/web_admin_panel_service/mod.rs b/paddler/src/balancer/web_admin_panel_service/mod.rs deleted file mode 100644 index 0701e072..00000000 --- a/paddler/src/balancer/web_admin_panel_service/mod.rs +++ /dev/null @@ -1,54 +0,0 @@ -pub mod app_data; -pub mod configuration; -pub mod http_route; -pub mod template_data; - -use actix_web::App; -use actix_web::HttpServer; -use actix_web::web::Data; -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; -use trzcina::Service; -use trzcina::ServiceShutdownOptions; - -use crate::balancer::web_admin_panel_service::app_data::AppData; -use crate::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; - -pub struct WebAdminPanelService { - pub configuration: WebAdminPanelServiceConfiguration, - pub shutdown_options: ServiceShutdownOptions, -} - -#[async_trait] -impl Service for WebAdminPanelService { - fn name(&self) -> &'static str { - "balancer::web_admin_panel_service" - } - - async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { - let app_data: Data = Data::new(AppData { - template_data: self.configuration.template_data.clone(), - }); - - #[expect(clippy::expect_used, reason = "server bind failure is unrecoverable")] - HttpServer::new(move || { - App::new() - .app_data(app_data.clone()) - .configure(http_route::favicon::register) - .configure(http_route::static_files::register) - .configure(http_route::home::register) - }) - .shutdown_signal(async move { - shutdown.cancelled().await; - }) - .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) - .disable_signals() - .bind(self.configuration.addr) - .expect("Unable to bind server to address") - .run() - .await?; - - Ok(()) - } -} diff --git a/paddler/src/balancer_desired_state.rs b/paddler/src/balancer_desired_state.rs deleted file mode 100644 index d714027e..00000000 --- a/paddler/src/balancer_desired_state.rs +++ /dev/null @@ -1,18 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::balancer_applicable_state::BalancerApplicableState; -use crate::converts_to_applicable_state::ConvertsToApplicableState; - -#[async_trait] -impl ConvertsToApplicableState for BalancerDesiredState { - type ApplicableState = BalancerApplicableState; - type Context = (); - - async fn to_applicable_state(&self, _context: Self::Context) -> Result { - Ok(BalancerApplicableState { - agent_desired_state: self.to_agent_desired_state(), - }) - } -} diff --git a/paddler/src/chat_template_renderer/raise_exception.rs b/paddler/src/chat_template_renderer/raise_exception.rs deleted file mode 100644 index 30a4efe4..00000000 --- a/paddler/src/chat_template_renderer/raise_exception.rs +++ /dev/null @@ -1,46 +0,0 @@ -use minijinja::Error; -use minijinja::ErrorKind; - -// Surfaces errors raised explicitly inside a chat template. Known uses: -// https://huggingface.co/bartowski/Mistral-7B-Instruct-v0.3-GGUF -pub fn raise_exception(message: &str) -> Result { - Err(Error::new::( - ErrorKind::InvalidOperation, - format!("Model's chat template raised an exception: '{message}'"), - )) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use anyhow::anyhow; - - use super::raise_exception; - - #[test] - fn returns_err_with_supplied_message_quoted() -> Result<()> { - let err = raise_exception("template is invalid") - .err() - .ok_or_else(|| anyhow!("expected Err, got Ok"))?; - let rendered = err.to_string(); - - if !rendered.contains("template is invalid") { - return Err(anyhow!( - "error must include the supplied message; got: {rendered}" - )); - } - - Ok(()) - } - - #[test] - fn returns_err_with_invalid_operation_kind() -> Result<()> { - let err = raise_exception("anything") - .err() - .ok_or_else(|| anyhow!("expected Err, got Ok"))?; - - assert_eq!(err.kind(), minijinja::ErrorKind::InvalidOperation); - - Ok(()) - } -} diff --git a/paddler/src/controls_websocket_endpoint.rs b/paddler/src/controls_websocket_endpoint.rs deleted file mode 100644 index 8185053d..00000000 --- a/paddler/src/controls_websocket_endpoint.rs +++ /dev/null @@ -1,312 +0,0 @@ -use std::sync::Arc; - -use actix_web::Error; -use actix_web::HttpRequest; -use actix_web::HttpResponse; -use actix_web::rt; -use actix_web::web::Payload; -use actix_ws::AggregatedMessage; -use actix_ws::CloseCode; -use actix_ws::CloseReason; -use actix_ws::ProtocolError; -use actix_ws::Session; -use anyhow::Context as _; -use anyhow::Result; -use async_trait::async_trait; -use futures_util::StreamExt as _; -use log::debug; -use log::error; -use log::warn; -use paddler_types::rpc_message::RpcMessage; -use serde::de::DeserializeOwned; -use tokio::time::Duration; -use tokio::time::MissedTickBehavior; -use tokio::time::interval; -use tokio_util::sync::CancellationToken; - -use crate::continuation_decision::ContinuationDecision; -use crate::continuation_stop_parameters::ContinuationStopParameters; -use crate::websocket_session_controller::WebSocketSessionController; - -const MAX_FRAME_SIZE: usize = 50 * 1024 * 1024; -const MAX_CONTINUATION_SIZE: usize = 50 * 1024 * 1024; -const PING_INTERVAL: Duration = Duration::from_secs(3); - -#[async_trait] -pub trait ControlsWebSocketEndpoint: Send + Sync + 'static { - type Context: Send + Sync + 'static; - type IncomingMessage: DeserializeOwned + RpcMessage + Sync + 'static; - type OutgoingMessage: RpcMessage + Sync + 'static; - - fn create_context(&self) -> Self::Context; - - async fn handle_deserialized_message( - connection_close: CancellationToken, - context: Arc, - deserialized_message: Self::IncomingMessage, - websocket_session_controller: WebSocketSessionController, - ) -> Result; - - async fn handle_aggregated_message( - connection_close: CancellationToken, - context: Arc, - msg: Option>, - session: &mut Session, - ) -> Result { - match msg { - Some(Ok(AggregatedMessage::Binary(_))) => { - debug!("Received binary message, but only text messages are supported"); - - Ok(ContinuationDecision::Continue) - } - Some(Ok(AggregatedMessage::Close(_))) | None => { - return Ok(ContinuationDecision::Stop(ContinuationStopParameters { - close_reason: None, - })); - } - Some(Ok(AggregatedMessage::Ping(msg))) => { - if session.pong(&msg).await.is_err() { - return Ok(ContinuationDecision::Stop(ContinuationStopParameters { - close_reason: None, - })); - } - - Ok(ContinuationDecision::Continue) - } - Some(Ok(AggregatedMessage::Pong(_))) => { - // ignore pong messages - Ok(ContinuationDecision::Continue) - } - Some(Ok(AggregatedMessage::Text(text))) => { - match Self::handle_text_message( - connection_close, - context.clone(), - &text, - WebSocketSessionController::::new(session.clone()), - ) - .await - .context(format!("Text message: {text}")) - { - Ok(continuation_decision) => return Ok(continuation_decision), - Err(err) => { - error!("Error handling text message: {err:?}"); - - Ok(ContinuationDecision::Continue) - } - } - } - Some(Err(ProtocolError::Overflow)) => { - error!("Message exceeded the maximum allowed frame size of {MAX_FRAME_SIZE} bytes"); - - return Ok(ContinuationDecision::Stop(ContinuationStopParameters { - close_reason: Some(CloseReason { - code: CloseCode::Size, - description: Some(format!( - "Message exceeded the maximum allowed frame size of {MAX_FRAME_SIZE} bytes" - )), - }), - })); - } - Some(Err(ProtocolError::Io(ref io_err))) - if io_err - .to_string() - .contains("Exceeded maximum continuation size") => - { - error!( - "Message exceeded the maximum allowed continuation size of {MAX_CONTINUATION_SIZE} bytes" - ); - - return Ok(ContinuationDecision::Stop(ContinuationStopParameters { - close_reason: Some(CloseReason { - code: CloseCode::Size, - description: Some(format!( - "Message exceeded the maximum allowed continuation size of {MAX_CONTINUATION_SIZE} bytes" - )), - }), - })); - } - Some(Err(err)) => { - error!("Error receiving message: {err:?}"); - - return Ok(ContinuationDecision::Stop(ContinuationStopParameters { - close_reason: None, - })); - } - } - } - - async fn handle_serialization_error( - _connection_close: CancellationToken, - _context: Arc, - error: serde_json::Error, - _websocket_session_controller: WebSocketSessionController, - ) -> Result { - error!("Paddler-RPC serialization error: {error}"); - - Ok(ContinuationDecision::Continue) - } - - async fn handle_text_message( - connection_close: CancellationToken, - context: Arc, - text: &str, - websocket_session_controller: WebSocketSessionController, - ) -> Result { - match serde_json::from_str::(text) { - Ok(deserialized_message) => { - rt::spawn(async move { - match Self::handle_deserialized_message( - connection_close.clone(), - context, - deserialized_message, - websocket_session_controller, - ) - .await - { - Ok(ContinuationDecision::Continue) => { - // Continue processing messages - } - Ok(ContinuationDecision::Stop(_)) => connection_close.cancel(), - Err(err) => { - error!("Error handling deserialized message: {err:?}"); - - connection_close.cancel(); - } - } - }); - - Ok(ContinuationDecision::Continue) - } - Err(err @ serde_json::Error { .. }) if err.is_data() || err.is_syntax() => { - error!("JSON-RPC syntax error: {err:?}"); - - Self::handle_serialization_error( - connection_close, - context, - err, - websocket_session_controller, - ) - .await - } - Err(err) => { - error!("Error handling JSON-RPC request: {err:?}"); - - Self::handle_serialization_error( - connection_close, - context, - err, - websocket_session_controller, - ) - .await - } - } - } - - async fn on_connection_start( - _context: Arc, - _session: &mut Session, - ) -> Result { - Ok(ContinuationDecision::Continue) - } - - fn respond( - &self, - payload: Payload, - req: HttpRequest, - shutdown: CancellationToken, - ) -> Result { - let connection_close = CancellationToken::new(); - let context = Arc::new(self.create_context()); - let (res, mut session, msg_stream) = actix_ws::handle(&req, payload)?; - - let mut aggregated_msg_stream = msg_stream - .max_frame_size(MAX_FRAME_SIZE) - .aggregate_continuations() - .max_continuation_size(MAX_CONTINUATION_SIZE); - - rt::spawn(async move { - let mut close_reason: Option = None; - - match Self::on_connection_start(context.clone(), &mut session).await { - Ok(ContinuationDecision::Continue) => {} - Ok(ContinuationDecision::Stop(stop_parameters)) => { - close_reason = stop_parameters.close_reason; - - if let Err(close_err) = session.close(close_reason).await { - warn!( - "WebSocket session close failed after Stop decision (peer likely already disconnected): {close_err:?}" - ); - } - - return; - } - Err(err) => { - error!("Error in connection start handler: {err:?}"); - - if let Err(close_err) = session.close(close_reason).await { - warn!( - "WebSocket session close failed after start-handler error (peer likely already disconnected): {close_err:?}" - ); - } - - return; - } - } - let mut ping_ticker = interval(PING_INTERVAL); - - ping_ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - msg = aggregated_msg_stream.next() => { - match Self::handle_aggregated_message( - connection_close.clone(), - context.clone(), - msg, - &mut session, - ).await { - Ok(ContinuationDecision::Continue) => { - // continue processing messages - } - Ok(ContinuationDecision::Stop(stop_parameters)) => { - close_reason = stop_parameters.close_reason; - - break; - } - Err(err) => { - error!("Error handling aggregated message: {err:?}"); - - break; - }, - } - } - _ = ping_ticker.tick() => { - if session.ping(b"").await.is_err() { - break; - } - } - () = connection_close.cancelled() => { - break; - } - () = shutdown.cancelled() => { - close_reason = Some(CloseReason { - code: CloseCode::Away, - description: Some("Server shutting down".to_owned()), - }); - break; - } - } - } - - connection_close.cancel(); - - if let Err(close_err) = session.close(close_reason).await { - warn!( - "WebSocket session close failed at end of message loop (peer likely already disconnected): {close_err:?}" - ); - } - }); - - Ok(res) - } -} diff --git a/paddler/src/converts_to_applicable_state.rs b/paddler/src/converts_to_applicable_state.rs deleted file mode 100644 index 4e8a065d..00000000 --- a/paddler/src/converts_to_applicable_state.rs +++ /dev/null @@ -1,10 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; - -#[async_trait] -pub trait ConvertsToApplicableState { - type ApplicableState; - type Context; - - async fn to_applicable_state(&self, context: Self::Context) -> Result; -} diff --git a/paddler/src/converts_to_llama_kv_cache_dtype.rs b/paddler/src/converts_to_llama_kv_cache_dtype.rs deleted file mode 100644 index 8f42d8b5..00000000 --- a/paddler/src/converts_to_llama_kv_cache_dtype.rs +++ /dev/null @@ -1,22 +0,0 @@ -use llama_cpp_bindings::context::params::KvCacheType as LlamaKvCacheDtype; -use paddler_types::kv_cache_dtype::KvCacheDtype; - -pub trait ConvertsToLlamaKvCacheDtype { - fn to_llama_kv_cache_dtype(self) -> LlamaKvCacheDtype; -} - -impl ConvertsToLlamaKvCacheDtype for KvCacheDtype { - fn to_llama_kv_cache_dtype(self) -> LlamaKvCacheDtype { - match self { - Self::F32 => LlamaKvCacheDtype::F32, - Self::F16 => LlamaKvCacheDtype::F16, - Self::BF16 => LlamaKvCacheDtype::BF16, - Self::Q8_0 => LlamaKvCacheDtype::Q8_0, - Self::Q4_0 => LlamaKvCacheDtype::Q4_0, - Self::Q4_1 => LlamaKvCacheDtype::Q4_1, - Self::IQ4_NL => LlamaKvCacheDtype::IQ4_NL, - Self::Q5_0 => LlamaKvCacheDtype::Q5_0, - Self::Q5_1 => LlamaKvCacheDtype::Q5_1, - } - } -} diff --git a/paddler/src/converts_to_llama_pooling_type.rs b/paddler/src/converts_to_llama_pooling_type.rs deleted file mode 100644 index 2a8572c9..00000000 --- a/paddler/src/converts_to_llama_pooling_type.rs +++ /dev/null @@ -1,19 +0,0 @@ -use llama_cpp_bindings::context::params::LlamaPoolingType; -use paddler_types::pooling_type::PoolingType; - -pub trait ConvertsToLlamaPoolingType { - fn to_llama_pooling_type(self) -> LlamaPoolingType; -} - -impl ConvertsToLlamaPoolingType for PoolingType { - fn to_llama_pooling_type(self) -> LlamaPoolingType { - match self { - Self::Unspecified => LlamaPoolingType::Unspecified, - Self::None => LlamaPoolingType::None, - Self::Mean => LlamaPoolingType::Mean, - Self::Cls => LlamaPoolingType::Cls, - Self::Last => LlamaPoolingType::Last, - Self::Rank => LlamaPoolingType::Rank, - } - } -} diff --git a/paddler/src/lib.rs b/paddler/src/lib.rs deleted file mode 100644 index 5f97aef1..00000000 --- a/paddler/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -pub mod agent; -pub mod agent_applicable_state; -pub mod agent_applicable_state_holder; -pub mod agent_desired_state; -pub mod agent_issue_fix; -pub mod atomic_value; -pub mod balancer; -pub mod balancer_applicable_state; -pub mod balancer_applicable_state_holder; -pub mod balancer_desired_state; -pub mod cancellation_token_stream_guard; -pub mod chat_template_renderer; -pub mod continuation_decision; -pub mod continuation_stop_parameters; -pub mod controls_session; -pub mod controls_websocket_endpoint; -pub mod converts_to_applicable_state; -pub mod converts_to_llama_kv_cache_dtype; -pub mod converts_to_llama_pooling_type; -pub mod create_cors_middleware; -pub mod decoded_image; -pub mod decoded_image_error; -pub mod desired_model_resolution; -pub mod dispenses_slots; -pub mod embedding_input_tokenized; -pub mod model_source; -pub mod produces_snapshot; -pub mod resolve_desired_model; -pub mod resolved_socket_addr; -pub mod resolves_model_source; -pub mod sends_rpc_message; -pub mod sets_desired_state; -pub mod slot_aggregated_status; -pub mod slot_aggregated_status_download_progress; -pub mod slot_aggregated_status_manager; -pub mod snapshots_stream; -#[cfg(feature = "web_admin_panel")] -pub mod static_files; -pub mod subscribes_to_updates; -pub mod tool_call_buffer; -pub mod tool_call_event; -pub mod tool_call_pipeline; -pub mod tool_call_pipeline_error; -pub mod tool_call_validation_error; -pub mod tool_call_validator; -pub mod websocket_session_controller; diff --git a/paddler/src/static_files.rs b/paddler/src/static_files.rs deleted file mode 100644 index 9674e710..00000000 --- a/paddler/src/static_files.rs +++ /dev/null @@ -1,5 +0,0 @@ -use rust_embed::Embed; - -#[derive(Embed)] -#[folder = "../static"] -pub struct StaticFiles; diff --git a/paddler/Cargo.toml b/paddler_agent/Cargo.toml similarity index 70% rename from paddler/Cargo.toml rename to paddler_agent/Cargo.toml index a6277801..3301c6fb 100644 --- a/paddler/Cargo.toml +++ b/paddler_agent/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "paddler" +name = "paddler_agent" authors.workspace = true description.workspace = true edition.workspace = true @@ -9,44 +9,37 @@ repository.workspace = true version.workspace = true [dependencies] -actix = { workspace = true } -actix-cors = { workspace = true } -actix-rt = { workspace = true } -actix-web = { workspace = true } -actix-web-lab = { workspace = true } -actix-ws = { workspace = true } anyhow = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } +base64 = { workspace = true } bytes = { workspace = true } -cadence = { workspace = true } -clap = { workspace = true } dashmap = { workspace = true } encoding_rs = { workspace = true } -env_logger = { workspace = true } futures = { workspace = true } futures-util = { workspace = true } hf-hub = { workspace = true } image = { workspace = true } -indoc = { workspace = true } jsonschema = { workspace = true } llama-cpp-bindings = { workspace = true } llama-cpp-bindings-sys = { workspace = true } -base64 = { workspace = true } +llama-cpp-bindings-types = { workspace = true } log = { workspace = true } minijinja = { workspace = true } minijinja-contrib = { workspace = true } nanoid = { workspace = true } paddler_cache_dir = { workspace = true } paddler_download_manager = { workspace = true } -paddler_types = { workspace = true } -thiserror = { workspace = true } +paddler_messaging = { workspace = true } +paddler_state_conversion = { workspace = true } +parking_lot = { workspace = true } rand = { workspace = true } reqwest = { workspace = true } resvg = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } shellexpand = { workspace = true } +thiserror = { workspace = true } tokio = { workspace = true } tokio-stream = { workspace = true } tokio-tungstenite = { workspace = true } @@ -54,25 +47,14 @@ tokio-util = { workspace = true } trzcina = { workspace = true } url = { workspace = true } -# web dashboard deps -askama = { workspace = true, optional = true } -esbuild-metafile = { workspace = true, optional = true } -mime_guess = { workspace = true, optional = true } -rust-embed = { workspace = true, optional = true } - [features] default = [] cuda = ["llama-cpp-bindings/cuda"] metal = ["llama-cpp-bindings/metal"] -web_admin_panel = [ - "dep:askama", - "dep:esbuild-metafile", - "dep:mime_guess", - "dep:rust-embed", -] vulkan = ["llama-cpp-bindings/vulkan"] [dev-dependencies] +indoc = { workspace = true } tempfile = { workspace = true } tokio-test = { workspace = true } diff --git a/paddler/src/agent_applicable_state.rs b/paddler_agent/src/agent_applicable_state.rs similarity index 70% rename from paddler/src/agent_applicable_state.rs rename to paddler_agent/src/agent_applicable_state.rs index b6bb9db4..a91d5a74 100644 --- a/paddler/src/agent_applicable_state.rs +++ b/paddler_agent/src/agent_applicable_state.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; #[derive(Clone, Debug)] pub struct AgentApplicableState { diff --git a/paddler/src/agent_applicable_state_holder.rs b/paddler_agent/src/agent_applicable_state_holder.rs similarity index 69% rename from paddler/src/agent_applicable_state_holder.rs rename to paddler_agent/src/agent_applicable_state_holder.rs index 629ffbc4..9d33c329 100644 --- a/paddler/src/agent_applicable_state_holder.rs +++ b/paddler_agent/src/agent_applicable_state_holder.rs @@ -1,6 +1,5 @@ -use std::sync::RwLock; - use anyhow::Result; +use parking_lot::RwLock; use tokio::sync::watch; use crate::agent_applicable_state::AgentApplicableState; @@ -11,24 +10,16 @@ pub struct AgentApplicableStateHolder { } impl AgentApplicableStateHolder { - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_agent_applicable_state(&self) -> Option { - self.agent_applicable_state - .read() - .expect("Failed to acquire read lock") - .clone() + self.agent_applicable_state.read().clone() } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_agent_applicable_state( &self, agent_applicable_state: Option, ) -> Result<()> { { - let mut state = self - .agent_applicable_state - .write() - .expect("Failed to acquire write lock"); + let mut state = self.agent_applicable_state.write(); (*state).clone_from(&agent_applicable_state); } diff --git a/paddler/src/agent_desired_state.rs b/paddler_agent/src/agent_desired_state_converter.rs similarity index 61% rename from paddler/src/agent_desired_state.rs rename to paddler_agent/src/agent_desired_state_converter.rs index 5313dd59..b4796ab5 100644 --- a/paddler/src/agent_desired_state.rs +++ b/paddler_agent/src/agent_desired_state_converter.rs @@ -4,13 +4,14 @@ use std::sync::Arc; use anyhow::Result; use anyhow::anyhow; use async_trait::async_trait; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_desired_state::AgentDesiredState; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ModelPath; + +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_state_conversion::converts_to_applicable_state::ConvertsToApplicableState; use crate::agent_applicable_state::AgentApplicableState; -use crate::converts_to_applicable_state::ConvertsToApplicableState; use crate::desired_model_resolution::DesiredModelResolution; use crate::resolve_desired_model::resolve_desired_model; use crate::slot_aggregated_status::SlotAggregatedStatus; @@ -38,32 +39,36 @@ where } } +pub struct AgentDesiredStateConverter { + pub slot_aggregated_status: Arc, +} + #[async_trait] -impl ConvertsToApplicableState for AgentDesiredState { +impl ConvertsToApplicableState for AgentDesiredStateConverter { type ApplicableState = AgentApplicableState; - type Context = Arc; + type DesiredState = AgentDesiredState; async fn to_applicable_state( &self, - slot_aggregated_status: Self::Context, - ) -> Result { + desired_state: AgentDesiredState, + ) -> Result { let model_path = resolve_into_optional_path( - &self.model, - &slot_aggregated_status, + &desired_state.model, + &self.slot_aggregated_status, AgentIssue::ModelFileDoesNotExist, ) .await?; let multimodal_projection_path = resolve_into_optional_path( - &self.multimodal_projection, - &slot_aggregated_status, + &desired_state.multimodal_projection, + &self.slot_aggregated_status, AgentIssue::MultimodalProjectionCannotBeLoaded, ) .await?; Ok(AgentApplicableState { - chat_template_override: self.chat_template_override.clone(), - inference_parameters: self.inference_parameters.clone(), + chat_template_override: desired_state.chat_template_override, + inference_parameters: desired_state.inference_parameters, model_path, multimodal_projection_path, }) @@ -75,26 +80,35 @@ mod tests { use std::path::PathBuf; use std::sync::Arc; - use anyhow::Result; - use paddler_types::agent_desired_model::AgentDesiredModel; - use paddler_types::agent_desired_state::AgentDesiredState; - use paddler_types::agent_issue::AgentIssue; - use paddler_types::agent_issue_params::ModelPath; - use paddler_types::inference_parameters::InferenceParameters; use tempfile::TempDir; - use crate::converts_to_applicable_state::ConvertsToApplicableState; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::agent_desired_state::AgentDesiredState; + use paddler_messaging::agent_issue::AgentIssue; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::inference_parameters::InferenceParameters; + use paddler_state_conversion::converts_to_applicable_state::ConvertsToApplicableState as _; + + use crate::agent_desired_state_converter::AgentDesiredStateConverter; use crate::slot_aggregated_status::SlotAggregatedStatus; + struct MissingLocalModel { + _dir_guard: TempDir, + path: PathBuf, + } + fn fresh_status() -> Arc { Arc::new(SlotAggregatedStatus::new(1)) } - fn nonexistent_path_in_temp_dir(label: &str) -> Result<(TempDir, PathBuf)> { - let dir = tempfile::tempdir()?; - let path = dir.path().join(format!("missing-{label}.gguf")); + fn nonexistent_path_in_temp_dir(label: &str) -> MissingLocalModel { + let dir_guard = tempfile::tempdir().unwrap(); + let path = dir_guard.path().join(format!("missing-{label}.gguf")); - Ok((dir, path)) + MissingLocalModel { + _dir_guard: dir_guard, + path, + } } fn desired_state( @@ -110,19 +124,25 @@ mod tests { } #[tokio::test] - async fn local_missing_model_registers_model_file_does_not_exist_and_errs() -> Result<()> { + async fn local_missing_model_registers_model_file_does_not_exist_and_errs() { let status = fresh_status(); - let (_dir_guard, missing_path) = nonexistent_path_in_temp_dir("model")?; + let MissingLocalModel { + _dir_guard, + path: missing_path, + } = nonexistent_path_in_temp_dir("model"); let desired = desired_state( AgentDesiredModel::LocalToAgent(missing_path.display().to_string()), AgentDesiredModel::None, ); + let converter = AgentDesiredStateConverter { + slot_aggregated_status: status.clone(), + }; - let outcome = desired.to_applicable_state(status.clone()).await; + let outcome = converter.to_applicable_state(desired).await; assert!( outcome.is_err(), - "AgentDesiredState::to_applicable_state must Err when the model's local path is missing" + "AgentDesiredStateConverter must Err when the model's local path is missing" ); assert!( status.has_issue(&AgentIssue::ModelFileDoesNotExist(ModelPath { @@ -136,25 +156,29 @@ mod tests { })), "MultimodalProjectionCannotBeLoaded must NOT be registered for a missing model" ); - - Ok(()) } #[tokio::test] async fn local_missing_multimodal_projection_registers_multimodal_projection_cannot_be_loaded_and_errs() - -> Result<()> { + { let status = fresh_status(); - let (_dir_guard, missing_path) = nonexistent_path_in_temp_dir("projection")?; + let MissingLocalModel { + _dir_guard, + path: missing_path, + } = nonexistent_path_in_temp_dir("projection"); let desired = desired_state( AgentDesiredModel::None, AgentDesiredModel::LocalToAgent(missing_path.display().to_string()), ); + let converter = AgentDesiredStateConverter { + slot_aggregated_status: status.clone(), + }; - let outcome = desired.to_applicable_state(status.clone()).await; + let outcome = converter.to_applicable_state(desired).await; assert!( outcome.is_err(), - "AgentDesiredState::to_applicable_state must Err when the projection's local path is missing" + "AgentDesiredStateConverter must Err when the projection's local path is missing" ); assert!( status.has_issue(&AgentIssue::MultimodalProjectionCannotBeLoaded(ModelPath { @@ -168,7 +192,5 @@ mod tests { })), "ModelFileDoesNotExist must NOT be registered for a missing projection" ); - - Ok(()) } } diff --git a/paddler/src/agent_issue_fix.rs b/paddler_agent/src/agent_issue_fix.rs similarity index 68% rename from paddler/src/agent_issue_fix.rs rename to paddler_agent/src/agent_issue_fix.rs index ab4e2ec5..cf483d7e 100644 --- a/paddler/src/agent_issue_fix.rs +++ b/paddler_agent/src/agent_issue_fix.rs @@ -1,6 +1,6 @@ -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::agent_issue_params::SlotCannotStartParams; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_messaging::agent_issue_params::slot_cannot_start_params::SlotCannotStartParams; #[derive(Debug)] pub enum AgentIssueFix { @@ -86,9 +86,7 @@ impl AgentIssueFix { | AgentIssue::ModelCacheIsCorrupted(issue_model_path) | AgentIssue::ModelDoesNotExistAtUrl(issue_model_path) => match self { Self::ModelDownloadCompleted(fix_model_path) - | Self::ModelDownloadStarted(fix_model_path) => { - issue_model_path.eq(fix_model_path) - } + | Self::ModelDownloadStarted(fix_model_path) => issue_model_path.eq(fix_model_path), Self::ModelStateIsReconciled => true, _ => false, }, @@ -98,8 +96,9 @@ impl AgentIssueFix { #[cfg(test)] mod tests { - use paddler_types::agent_issue_params::ChatTemplateDoesNotCompileParams; - use paddler_types::agent_issue_params::SlotCannotStartParams; + use paddler_messaging::agent_issue_params::chat_template_does_not_compile_params::ChatTemplateDoesNotCompileParams; + use paddler_messaging::agent_issue_params::hugging_face_download_lock::HuggingFaceDownloadLock; + use paddler_messaging::agent_issue_params::slot_cannot_start_params::SlotCannotStartParams; use super::*; @@ -204,7 +203,8 @@ mod tests { #[test] fn model_download_started_fixes_download_server_denied_access() { let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); - let issue = AgentIssue::DownloadServerDeniedAccess(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::DownloadServerDeniedAccess(model_path("https://example.com/m.gguf")); assert!(fix.can_fix(&issue)); } @@ -262,8 +262,7 @@ mod tests { #[test] fn model_download_started_fixes_download_interrupted() { let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); - let issue = - AgentIssue::DownloadInterrupted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::DownloadInterrupted(model_path("https://example.com/m.gguf")); assert!(fix.can_fix(&issue)); } @@ -303,7 +302,116 @@ mod tests { #[test] fn model_download_started_does_not_fix_huggingface_issues() { let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); - let issue = AgentIssue::HuggingFaceModelDoesNotExist(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::HuggingFaceModelDoesNotExist(model_path("https://example.com/m.gguf")); + + assert!(!fix.can_fix(&issue)); + } + + #[test] + fn chat_template_does_not_compile_not_fixed_by_unrelated_fix() { + let fix = AgentIssueFix::ModelIsLoaded(model_path("model_a")); + let issue = AgentIssue::ChatTemplateDoesNotCompile(ChatTemplateDoesNotCompileParams { + error: "error".to_owned(), + model_path: model_path("model_a"), + template_content: "template".to_owned(), + }); + + assert!(!fix.can_fix(&issue)); + } + + #[test] + fn hugging_face_cannot_acquire_lock_fixes() { + let issue = AgentIssue::HuggingFaceCannotAcquireLock(HuggingFaceDownloadLock { + lock_path: "/tmp/lock".to_owned(), + model_path: model_path("model_a"), + }); + + assert!(AgentIssueFix::HuggingFaceDownloadedModel(model_path("model_a")).can_fix(&issue)); + assert!( + AgentIssueFix::HuggingFaceStartedDownloading(model_path("model_a")).can_fix(&issue) + ); + assert!(AgentIssueFix::ModelStateIsReconciled.can_fix(&issue)); + assert!( + !AgentIssueFix::HuggingFaceStartedDownloading(model_path("model_b")).can_fix(&issue) + ); + assert!(!AgentIssueFix::ModelIsLoaded(model_path("model_a")).can_fix(&issue)); + } + + #[test] + fn hugging_face_model_does_not_exist_fixes() { + let issue = AgentIssue::HuggingFaceModelDoesNotExist(model_path("model_a")); + + assert!(AgentIssueFix::HuggingFaceDownloadedModel(model_path("model_a")).can_fix(&issue)); + assert!(AgentIssueFix::MultimodalProjectionIsLoaded(model_path("model_a")).can_fix(&issue)); + assert!(AgentIssueFix::ModelStateIsReconciled.can_fix(&issue)); + assert!(!AgentIssueFix::ModelIsLoaded(model_path("model_a")).can_fix(&issue)); + } + + #[test] + fn hugging_face_permissions_fixed_by_started_downloading() { + let fix = AgentIssueFix::HuggingFaceStartedDownloading(model_path("model_a")); + let issue = AgentIssue::HuggingFacePermissions(model_path("model_a")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_cannot_be_loaded_not_fixed_by_unrelated_fix() { + let fix = AgentIssueFix::ModelFileExists(model_path("model_a")); + let issue = AgentIssue::ModelCannotBeLoaded(model_path("model_a")); + + assert!(!fix.can_fix(&issue)); + } + + #[test] + fn model_file_does_not_exist_fixed_by_multimodal_projection_and_not_others() { + let issue = AgentIssue::ModelFileDoesNotExist(model_path("model_a")); + + assert!(AgentIssueFix::MultimodalProjectionIsLoaded(model_path("model_a")).can_fix(&issue)); + assert!(!AgentIssueFix::ModelIsLoaded(model_path("model_a")).can_fix(&issue)); + } + + #[test] + fn multimodal_projection_cannot_be_loaded_fixed_only_by_multimodal_projection_loaded() { + let issue = AgentIssue::MultimodalProjectionCannotBeLoaded(model_path("model_a")); + + assert!(AgentIssueFix::MultimodalProjectionIsLoaded(model_path("model_a")).can_fix(&issue)); + assert!(!AgentIssueFix::ModelIsLoaded(model_path("model_a")).can_fix(&issue)); + } + + #[test] + fn slot_cannot_start_not_fixed_by_unrelated_fix() { + let fix = AgentIssueFix::ModelIsLoaded(model_path("model_a")); + let issue = AgentIssue::SlotCannotStart(SlotCannotStartParams { + error: "failed".to_owned(), + slot_index: 1, + }); + + assert!(!fix.can_fix(&issue)); + } + + #[test] + fn unable_to_find_chat_template_fixed_by_model_chat_template_loaded() { + let issue = AgentIssue::UnableToFindChatTemplate(model_path("model_a")); + + assert!(AgentIssueFix::ModelChatTemplateIsLoaded(model_path("model_a")).can_fix(&issue)); + assert!(!AgentIssueFix::ModelChatTemplateIsLoaded(model_path("model_b")).can_fix(&issue)); + } + + #[test] + fn download_server_rejected_request_fixed_by_model_download_started() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::DownloadServerRejectedRequest(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn download_issue_not_fixed_by_unrelated_fix() { + let fix = AgentIssueFix::ModelIsLoaded(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::DownloadInterrupted(model_path("https://example.com/m.gguf")); assert!(!fix.can_fix(&issue)); } diff --git a/paddler_agent/src/agent_kv_cache_dtype.rs b/paddler_agent/src/agent_kv_cache_dtype.rs new file mode 100644 index 00000000..3279b465 --- /dev/null +++ b/paddler_agent/src/agent_kv_cache_dtype.rs @@ -0,0 +1,71 @@ +use llama_cpp_bindings::context::params::KvCacheType as LlamaKvCacheDtype; +use paddler_messaging::kv_cache_dtype::KvCacheDtype; + +use crate::converts_to_llama_kv_cache_dtype::ConvertsToLlamaKvCacheDtype; + +pub struct AgentKvCacheDtype(pub KvCacheDtype); + +impl ConvertsToLlamaKvCacheDtype for AgentKvCacheDtype { + fn to_llama_kv_cache_dtype(self) -> LlamaKvCacheDtype { + match self.0 { + KvCacheDtype::F32 => LlamaKvCacheDtype::F32, + KvCacheDtype::F16 => LlamaKvCacheDtype::F16, + KvCacheDtype::Bf16 => LlamaKvCacheDtype::BF16, + KvCacheDtype::Q80 => LlamaKvCacheDtype::Q8_0, + KvCacheDtype::Q40 => LlamaKvCacheDtype::Q4_0, + KvCacheDtype::Q41 => LlamaKvCacheDtype::Q4_1, + KvCacheDtype::Iq4Nl => LlamaKvCacheDtype::IQ4_NL, + KvCacheDtype::Q50 => LlamaKvCacheDtype::Q5_0, + KvCacheDtype::Q51 => LlamaKvCacheDtype::Q5_1, + } + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings::context::params::KvCacheType as LlamaKvCacheDtype; + use paddler_messaging::kv_cache_dtype::KvCacheDtype; + + use super::AgentKvCacheDtype; + use crate::converts_to_llama_kv_cache_dtype::ConvertsToLlamaKvCacheDtype; + + #[test] + fn maps_each_kv_cache_dtype_to_its_llama_counterpart() { + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::F32).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::F32 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::F16).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::F16 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Bf16).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::BF16 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Q80).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::Q8_0 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Q40).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::Q4_0 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Q41).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::Q4_1 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Iq4Nl).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::IQ4_NL + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Q50).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::Q5_0 + ); + assert_eq!( + AgentKvCacheDtype(KvCacheDtype::Q51).to_llama_kv_cache_dtype(), + LlamaKvCacheDtype::Q5_1 + ); + } +} diff --git a/paddler_agent/src/agent_pooling_type.rs b/paddler_agent/src/agent_pooling_type.rs new file mode 100644 index 00000000..7a011a9b --- /dev/null +++ b/paddler_agent/src/agent_pooling_type.rs @@ -0,0 +1,76 @@ +use llama_cpp_bindings::context::params::LlamaPoolingType; +use paddler_messaging::pooling_type::PoolingType; + +use crate::converts_to_llama_pooling_type::ConvertsToLlamaPoolingType; + +pub struct AgentPoolingType(pub PoolingType); + +impl ConvertsToLlamaPoolingType for AgentPoolingType { + fn to_llama_pooling_type(self) -> LlamaPoolingType { + match self.0 { + PoolingType::Unspecified => LlamaPoolingType::Unspecified, + PoolingType::None => LlamaPoolingType::None, + PoolingType::Mean => LlamaPoolingType::Mean, + PoolingType::Cls => LlamaPoolingType::Cls, + PoolingType::Last => LlamaPoolingType::Last, + PoolingType::Rank => LlamaPoolingType::Rank, + } + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings::context::params::LlamaPoolingType; + use paddler_messaging::pooling_type::PoolingType; + + use super::AgentPoolingType; + use crate::converts_to_llama_pooling_type::ConvertsToLlamaPoolingType; + + #[test] + fn converts_unspecified() { + assert_eq!( + AgentPoolingType(PoolingType::Unspecified).to_llama_pooling_type(), + LlamaPoolingType::Unspecified + ); + } + + #[test] + fn converts_none() { + assert_eq!( + AgentPoolingType(PoolingType::None).to_llama_pooling_type(), + LlamaPoolingType::None + ); + } + + #[test] + fn converts_mean() { + assert_eq!( + AgentPoolingType(PoolingType::Mean).to_llama_pooling_type(), + LlamaPoolingType::Mean + ); + } + + #[test] + fn converts_cls() { + assert_eq!( + AgentPoolingType(PoolingType::Cls).to_llama_pooling_type(), + LlamaPoolingType::Cls + ); + } + + #[test] + fn converts_last() { + assert_eq!( + AgentPoolingType(PoolingType::Last).to_llama_pooling_type(), + LlamaPoolingType::Last + ); + } + + #[test] + fn converts_rank() { + assert_eq!( + AgentPoolingType(PoolingType::Rank).to_llama_pooling_type(), + LlamaPoolingType::Rank + ); + } +} diff --git a/paddler/src/chat_template_renderer/mod.rs b/paddler_agent/src/chat_template_renderer/mod.rs similarity index 59% rename from paddler/src/chat_template_renderer/mod.rs rename to paddler_agent/src/chat_template_renderer/mod.rs index 94f5445e..420b5512 100644 --- a/paddler/src/chat_template_renderer/mod.rs +++ b/paddler_agent/src/chat_template_renderer/mod.rs @@ -1,10 +1,11 @@ pub mod pyjinja_tojson; pub mod raise_exception; +use anyhow::Context as _; use anyhow::Result; use minijinja::Environment; use minijinja_contrib::pycompat::unknown_method_callback; -use paddler_types::chat_template::ChatTemplate; +use paddler_messaging::chat_template::ChatTemplate; use serde::ser::Serialize; use self::pyjinja_tojson::pyjinja_tojson; @@ -33,7 +34,8 @@ impl ChatTemplateRenderer { pub fn render(&self, context: TContext) -> Result { Ok(self .minijinja_env - .get_template(CHAT_TEMPLATE_NAME)? + .get_template(CHAT_TEMPLATE_NAME) + .context("chat template is not registered in the rendering environment")? .render(context)?) } } @@ -42,11 +44,10 @@ impl ChatTemplateRenderer { mod tests { use std::collections::HashMap; - use anyhow::Result; use minijinja::context; - use paddler_types::chat_template::ChatTemplate; - use paddler_types::chat_template_message::ChatTemplateMessage; - use paddler_types::chat_template_message_content::ChatTemplateMessageContent; + use paddler_messaging::chat_template::ChatTemplate; + use paddler_messaging::chat_template_message::ChatTemplateMessage; + use paddler_messaging::chat_template_message_content::ChatTemplateMessageContent; use crate::chat_template_renderer::ChatTemplateRenderer; @@ -69,27 +70,25 @@ mod tests { } #[test] - fn render_produces_expected_output() -> Result<()> { + fn render_produces_expected_output() { let template = ChatTemplate { content: "Hello {{ name }}!".to_owned(), }; - let renderer = ChatTemplateRenderer::new(template)?; + let renderer = ChatTemplateRenderer::new(template).unwrap(); let mut context = HashMap::new(); context.insert("name", "world"); - let result = renderer.render(context)?; + let result = renderer.render(context).unwrap(); assert_eq!(result, "Hello world!"); - - Ok(()) } #[test] - fn renders_messages_loop_with_roles() -> Result<()> { + fn renders_messages_loop_with_roles() { let template = ChatTemplate { content: "{% for message in messages %}{{ message.role }}:{{ message.content }}\n{% endfor %}".to_owned(), }; - let renderer = ChatTemplateRenderer::new(template)?; + let renderer = ChatTemplateRenderer::new(template).unwrap(); let messages = vec![ ChatTemplateMessage { content: ChatTemplateMessageContent::Text("hi".to_owned()), @@ -101,62 +100,78 @@ mod tests { }, ]; - let result = renderer.render(context! { messages => messages })?; + let result = renderer.render(context! { messages => messages }).unwrap(); assert_eq!(result, "user:hi\nassistant:hello\n"); - - Ok(()) } #[test] - fn add_generation_prompt_branch_changes_output() -> Result<()> { + fn add_generation_prompt_branch_changes_output() { let template = ChatTemplate { content: "A{% if add_generation_prompt %}B{% endif %}".to_owned(), }; - let renderer = ChatTemplateRenderer::new(template)?; + let renderer = ChatTemplateRenderer::new(template).unwrap(); - let with_prompt = renderer.render(context! { add_generation_prompt => true })?; - let without_prompt = renderer.render(context! { add_generation_prompt => false })?; + let with_prompt = renderer + .render(context! { add_generation_prompt => true }) + .unwrap(); + let without_prompt = renderer + .render(context! { add_generation_prompt => false }) + .unwrap(); assert_eq!(with_prompt, "AB"); assert_eq!(without_prompt, "A"); - - Ok(()) } #[test] - fn registers_pyjinja_tojson_filter() -> Result<()> { + fn registers_pyjinja_tojson_filter() { let template = ChatTemplate { content: "{{ value | tojson(ensure_ascii=False) }}".to_owned(), }; - let renderer = ChatTemplateRenderer::new(template)?; + let renderer = ChatTemplateRenderer::new(template).unwrap(); - let result = renderer.render(context! { value => "café" })?; + let result = renderer.render(context! { value => "café" }).unwrap(); assert_eq!(result, "\"café\""); - - Ok(()) } #[test] - fn registers_raise_exception_function() -> Result<()> { + fn registers_raise_exception_function() { let template = ChatTemplate { content: "{{ raise_exception('boom') }}".to_owned(), }; - let template_renderer = ChatTemplateRenderer::new(template)?; + let template_renderer = ChatTemplateRenderer::new(template).unwrap(); - let err = template_renderer + let render_error = template_renderer .render(context! {}) - .err() - .ok_or_else(|| anyhow::anyhow!("expected Err, got Ok"))?; - let error_message = err.to_string(); + .expect_err("raise_exception must turn rendering into an error"); + let error_message = render_error.to_string(); + + assert!( + error_message.contains("boom"), + "raise_exception must surface its message; got: {error_message}" + ); + } + + #[test] + fn render_fails_when_template_is_not_registered() { + let template = ChatTemplate { + content: "Hello {{ name }}!".to_owned(), + }; + let mut renderer = ChatTemplateRenderer::new(template).unwrap(); - if !error_message.contains("boom") { - return Err(anyhow::anyhow!( - "raise_exception must surface its message; got: {error_message}" - )); - } + renderer + .minijinja_env + .remove_template(super::CHAT_TEMPLATE_NAME); + + let render_error = renderer + .render(context! {}) + .expect_err("rendering must fail when the template is missing"); + let error_message = render_error.to_string(); - Ok(()) + assert_eq!( + error_message, + "chat template is not registered in the rendering environment" + ); } } diff --git a/paddler/src/chat_template_renderer/pyjinja_tojson.rs b/paddler_agent/src/chat_template_renderer/pyjinja_tojson.rs similarity index 52% rename from paddler/src/chat_template_renderer/pyjinja_tojson.rs rename to paddler_agent/src/chat_template_renderer/pyjinja_tojson.rs index cfcc8db6..12c0f04f 100644 --- a/paddler/src/chat_template_renderer/pyjinja_tojson.rs +++ b/paddler_agent/src/chat_template_renderer/pyjinja_tojson.rs @@ -31,8 +31,7 @@ pub fn pyjinja_tojson(value: &Value, kwargs: Kwargs) -> Result { )); } - let separators: Option = kwargs.get("separators")?; - if separators.is_some() { + if kwargs.has("separators") { return Err(Error::new( ErrorKind::InvalidOperation, "tojson(separators=...) is not supported by minijinja: separator strings are fixed.", @@ -48,163 +47,169 @@ pub fn pyjinja_tojson(value: &Value, kwargs: Kwargs) -> Result { #[cfg(test)] mod tests { - use anyhow::Result; - use anyhow::anyhow; use minijinja::Environment; use minijinja::context; use super::pyjinja_tojson; - fn render(template_source: &str, scope: minijinja::Value) -> Result { - let mut env = Environment::new(); - env.add_filter("tojson", pyjinja_tojson); - env.add_template_owned("t", template_source.to_owned())?; - Ok(env.get_template("t")?.render(scope)?) + fn render(template_source: &str, scope: minijinja::Value) -> String { + let mut environment = Environment::new(); + environment.add_filter("tojson", pyjinja_tojson); + environment + .add_template_owned("t", template_source.to_owned()) + .unwrap(); + + environment + .get_template("t") + .unwrap() + .render(scope) + .unwrap() } - fn render_expecting_error( - template_source: &str, - scope: minijinja::Value, - ) -> Result { - let mut env = Environment::new(); - env.add_filter("tojson", pyjinja_tojson); - env.add_template_owned("t", template_source.to_owned())?; - let outcome = env.get_template("t")?.render(scope); - - outcome.err().ok_or_else(|| anyhow!("expected Err, got Ok")) + fn render_error_message(template_source: &str, scope: minijinja::Value) -> String { + let mut environment = Environment::new(); + environment.add_filter("tojson", pyjinja_tojson); + environment + .add_template_owned("t", template_source.to_owned()) + .unwrap(); + + environment + .get_template("t") + .unwrap() + .render(scope) + .unwrap_err() + .to_string() } #[test] - fn no_kwargs_emits_quoted_json_string() -> Result<()> { - let result = render("{{ value | tojson }}", context! { value => "hello" })?; + fn no_kwargs_emits_quoted_json_string() { + let result = render("{{ value | tojson }}", context! { value => "hello" }); assert_eq!(result, "\"hello\""); - - Ok(()) } #[test] - fn ensure_ascii_false_matches_default_output() -> Result<()> { + fn ensure_ascii_false_matches_default_output() { let with_kwarg = render( "{{ value | tojson(ensure_ascii=False) }}", context! { value => "café" }, - )?; - let without_kwarg = render("{{ value | tojson }}", context! { value => "café" })?; + ); + let without_kwarg = render("{{ value | tojson }}", context! { value => "café" }); assert_eq!(with_kwarg, without_kwarg); assert_eq!(with_kwarg, "\"café\""); - - Ok(()) } #[test] - fn ensure_ascii_true_returns_error_naming_the_kwarg() -> Result<()> { - let err = render_expecting_error( + fn ensure_ascii_true_returns_error_naming_the_kwarg() { + let rendered = render_error_message( "{{ value | tojson(ensure_ascii=True) }}", context! { value => "x" }, - )?; - let rendered = err.to_string(); + ); - if !rendered.contains("ensure_ascii=True") { - return Err(anyhow!( - "error must name the rejected kwarg; got: {rendered}" - )); - } + assert!( + rendered.contains("ensure_ascii=True"), + "error must name the rejected kwarg; got: {rendered}" + ); + } - Ok(()) + #[test] + fn ensure_ascii_non_bool_propagates_kwargs_get_error() { + let rendered = render_error_message( + "{{ value | tojson(ensure_ascii='nope') }}", + context! { value => "x" }, + ); + + assert!( + !rendered.is_empty(), + "a type-mismatched ensure_ascii kwarg must surface an error" + ); } #[test] - fn sort_keys_false_matches_default_output() -> Result<()> { + fn sort_keys_false_matches_default_output() { let with_kwarg = render( "{{ value | tojson(sort_keys=False) }}", context! { value => "x" }, - )?; + ); assert_eq!(with_kwarg, "\"x\""); - - Ok(()) } #[test] - fn sort_keys_true_returns_error_naming_the_kwarg() -> Result<()> { - let err = render_expecting_error( + fn sort_keys_true_returns_error_naming_the_kwarg() { + let rendered = render_error_message( "{{ value | tojson(sort_keys=True) }}", context! { value => "x" }, - )?; - let rendered = err.to_string(); + ); + + assert!( + rendered.contains("sort_keys=True"), + "error must name the rejected kwarg; got: {rendered}" + ); + } - if !rendered.contains("sort_keys=True") { - return Err(anyhow!( - "error must name the rejected kwarg; got: {rendered}" - )); - } + #[test] + fn sort_keys_non_bool_propagates_kwargs_get_error() { + let rendered = render_error_message( + "{{ value | tojson(sort_keys='nope') }}", + context! { value => "x" }, + ); - Ok(()) + assert!( + !rendered.is_empty(), + "a type-mismatched sort_keys kwarg must surface an error" + ); } #[test] - fn separators_returns_error_naming_the_kwarg() -> Result<()> { - let err = render_expecting_error( + fn separators_returns_error_naming_the_kwarg() { + let rendered = render_error_message( "{{ value | tojson(separators=[',', ':']) }}", context! { value => "x" }, - )?; - let rendered = err.to_string(); + ); - if !rendered.contains("separators") { - return Err(anyhow!( - "error must name the rejected kwarg; got: {rendered}" - )); - } - - Ok(()) + assert!( + rendered.contains("separators"), + "error must name the rejected kwarg; got: {rendered}" + ); } #[test] - fn indent_kwarg_emits_pretty_printed_json() -> Result<()> { + fn indent_kwarg_emits_pretty_printed_json() { let result = render( "{{ value | tojson(indent=2) }}", context! { value => context! { k => "v" } }, - )?; + ); assert_eq!(result, "{\n \"k\": \"v\"\n}"); - - Ok(()) } #[test] - fn indent_kwarg_combines_with_ensure_ascii_false() -> Result<()> { + fn indent_kwarg_combines_with_ensure_ascii_false() { let result = render( "{{ value | tojson(ensure_ascii=False, indent=2) }}", context! { value => context! { k => "café" } }, - )?; + ); assert_eq!(result, "{\n \"k\": \"café\"\n}"); - - Ok(()) } #[test] - fn unknown_kwarg_returns_error() -> Result<()> { - let err = - render_expecting_error("{{ value | tojson(bogus=42) }}", context! { value => "x" })?; - let rendered = err.to_string(); - - if !rendered.contains("bogus") { - return Err(anyhow!( - "error must name the unknown kwarg; got: {rendered}" - )); - } - - Ok(()) + fn unknown_kwarg_returns_error() { + let rendered = + render_error_message("{{ value | tojson(bogus=42) }}", context! { value => "x" }); + + assert!( + rendered.contains("bogus"), + "error must name the unknown kwarg; got: {rendered}" + ); } #[test] - fn non_ascii_codepoints_emitted_unescaped() -> Result<()> { - let result = render("{{ value | tojson }}", context! { value => "日本語" })?; + fn non_ascii_codepoints_emitted_unescaped() { + let result = render("{{ value | tojson }}", context! { value => "日本語" }); assert_eq!(result, "\"日本語\""); - - Ok(()) } } diff --git a/paddler_agent/src/chat_template_renderer/raise_exception.rs b/paddler_agent/src/chat_template_renderer/raise_exception.rs new file mode 100644 index 00000000..ecf8c326 --- /dev/null +++ b/paddler_agent/src/chat_template_renderer/raise_exception.rs @@ -0,0 +1,38 @@ +use minijinja::Error; +use minijinja::ErrorKind; + +// Surfaces errors raised explicitly inside a chat template. Known uses: +// https://huggingface.co/bartowski/Mistral-7B-Instruct-v0.3-GGUF +pub fn raise_exception(message: &str) -> Result { + Err(Error::new::( + ErrorKind::InvalidOperation, + format!("Model's chat template raised an exception: '{message}'"), + )) +} + +#[cfg(test)] +mod tests { + use minijinja::ErrorKind; + + use super::raise_exception; + + #[test] + fn returns_err_with_supplied_message_quoted() { + let error = raise_exception("template is invalid") + .expect_err("raise_exception must always return Err"); + let rendered = error.to_string(); + + assert!( + rendered.contains("template is invalid"), + "error must include the supplied message; got: {rendered}" + ); + } + + #[test] + fn returns_err_with_invalid_operation_kind() { + let error = + raise_exception("anything").expect_err("raise_exception must always return Err"); + + assert_eq!(error.kind(), ErrorKind::InvalidOperation); + } +} diff --git a/paddler/src/agent/continue_from_conversation_history_request.rs b/paddler_agent/src/continue_from_conversation_history_request.rs similarity index 69% rename from paddler/src/agent/continue_from_conversation_history_request.rs rename to paddler_agent/src/continue_from_conversation_history_request.rs index 1b16aa6b..0d54bf9e 100644 --- a/paddler/src/agent/continue_from_conversation_history_request.rs +++ b/paddler_agent/src/continue_from_conversation_history_request.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use tokio::sync::mpsc; -use crate::agent::from_request_params::FromRequestParams; -use crate::agent::slot_guard::SlotGuard; +use crate::from_request_params::FromRequestParams; use crate::slot_aggregated_status::SlotAggregatedStatus; +use crate::slot_guard::SlotGuard; pub struct ContinueFromConversationHistoryRequest { pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, diff --git a/paddler/src/agent/continue_from_raw_prompt_request.rs b/paddler_agent/src/continue_from_raw_prompt_request.rs similarity index 79% rename from paddler/src/agent/continue_from_raw_prompt_request.rs rename to paddler_agent/src/continue_from_raw_prompt_request.rs index 9f573e74..2446f098 100644 --- a/paddler/src/agent/continue_from_raw_prompt_request.rs +++ b/paddler_agent/src/continue_from_raw_prompt_request.rs @@ -1,12 +1,12 @@ use std::sync::Arc; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; use tokio::sync::mpsc; -use crate::agent::from_request_params::FromRequestParams; -use crate::agent::slot_guard::SlotGuard; +use crate::from_request_params::FromRequestParams; use crate::slot_aggregated_status::SlotAggregatedStatus; +use crate::slot_guard::SlotGuard; pub struct ContinueFromRawPromptRequest { pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, diff --git a/paddler_agent/src/continuous_batch_active_request.rs b/paddler_agent/src/continuous_batch_active_request.rs new file mode 100644 index 00000000..0c5a4b1a --- /dev/null +++ b/paddler_agent/src/continuous_batch_active_request.rs @@ -0,0 +1,102 @@ +use llama_cpp_bindings::SampledTokenClassifier; +use llama_cpp_bindings::sampling::LlamaSampler; +use log::warn; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TryRecvError; + +use crate::continuous_batch_request_state::ContinuousBatchRequestState; +use crate::slot_guard::SlotGuard; +use crate::tool_call_pipeline::ToolCallPipeline; + +fn send_outcome_or_warn( + agent_name: Option<&str>, + sequence_id: i32, + generated_tokens_tx: &mpsc::UnboundedSender, + outcome: GeneratedTokenResult, +) { + if generated_tokens_tx.send(outcome).is_err() { + warn!( + "{agent_name:?}: sequence {sequence_id} failed to send result to client (receiver dropped)" + ); + } +} + +pub struct ContinuousBatchActiveRequest { + pub state: ContinuousBatchRequestState, + pub chain: LlamaSampler, + pub token_classifier: SampledTokenClassifier<'static>, + pub grammar_sampler: Option, + pub generated_tokens_tx: mpsc::UnboundedSender, + pub generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>, + pub slot_guard: SlotGuard, + pub tool_call_pipeline: Option, +} + +impl ContinuousBatchActiveRequest { + pub fn complete_with_outcome( + &mut self, + agent_name: Option<&str>, + outcome: GeneratedTokenResult, + ) { + send_outcome_or_warn( + agent_name, + self.state.sequence_id, + &self.generated_tokens_tx, + outcome, + ); + + self.state.mark_completed(); + } + + pub fn is_stop_requested(&mut self) -> bool { + match self.generate_tokens_stop_rx.try_recv() { + Ok(()) | Err(TryRecvError::Disconnected) => true, + Err(TryRecvError::Empty) => false, + } + } +} + +#[cfg(test)] +mod tests { + use log::LevelFilter; + use tokio::sync::mpsc; + + use super::send_outcome_or_warn; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + + #[test] + fn delivers_outcome_to_a_live_receiver() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + + send_outcome_or_warn( + Some("agent"), + 7, + &generated_tokens_tx, + GeneratedTokenResult::ContentToken("hello".to_owned()), + ); + + assert!(matches!( + generated_tokens_rx.try_recv(), + Ok(GeneratedTokenResult::ContentToken(token)) if token == "hello" + )); + } + + #[test] + fn warns_without_panicking_when_the_receiver_was_dropped() { + log::set_max_level(LevelFilter::Trace); + + let (generated_tokens_tx, generated_tokens_rx) = mpsc::unbounded_channel(); + + drop(generated_tokens_rx); + + send_outcome_or_warn( + None, + 42, + &generated_tokens_tx, + GeneratedTokenResult::ContentToken("dropped".to_owned()), + ); + + assert!(generated_tokens_tx.is_closed()); + } +} diff --git a/paddler/src/agent/continuous_batch_arbiter.rs b/paddler_agent/src/continuous_batch_arbiter.rs similarity index 89% rename from paddler/src/agent/continuous_batch_arbiter.rs rename to paddler_agent/src/continuous_batch_arbiter.rs index d662d0ec..23186baf 100644 --- a/paddler/src/agent/continuous_batch_arbiter.rs +++ b/paddler_agent/src/continuous_batch_arbiter.rs @@ -21,25 +21,27 @@ use llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_AUTO; use log::error; use log::info; use log::warn; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ChatTemplateDoesNotCompileParams; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::agent_issue_params::SlotCannotStartParams; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::model_metadata::ModelMetadata; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::chat_template_does_not_compile_params::ChatTemplateDoesNotCompileParams; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_messaging::agent_issue_params::slot_cannot_start_params::SlotCannotStartParams; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::model_metadata::ModelMetadata; use tokio::sync::oneshot; -use crate::agent::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome; -use crate::agent::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle; -use crate::agent::continuous_batch_scheduler::ContinuousBatchScheduler; -use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; -use crate::agent::model_metadata_holder::ModelMetadataHolder; use crate::agent_applicable_state::AgentApplicableState; use crate::agent_issue_fix::AgentIssueFix; +use crate::agent_kv_cache_dtype::AgentKvCacheDtype; +use crate::agent_pooling_type::AgentPoolingType; use crate::chat_template_renderer::ChatTemplateRenderer; +use crate::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome; +use crate::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle; +use crate::continuous_batch_scheduler::ContinuousBatchScheduler; +use crate::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; use crate::converts_to_llama_kv_cache_dtype::ConvertsToLlamaKvCacheDtype; use crate::converts_to_llama_pooling_type::ConvertsToLlamaPoolingType; +use crate::model_metadata_holder::ModelMetadataHolder; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; pub struct ContinuousBatchArbiter { @@ -111,17 +113,11 @@ impl ContinuousBatchArbiter { let llama_backend = Arc::new(LlamaBackend::init().context("Unable to initialize llama.cpp backend")?); - #[expect( - clippy::cast_sign_loss, - reason = "desired_slots_total is always positive" - )] - let n_seq_max = desired_slots_total as u32; + let n_seq_max = u32::try_from(desired_slots_total) + .context("desired_slots_total does not fit in u32")?; - #[expect( - clippy::cast_possible_truncation, - reason = "n_batch fits in u32 for llama.cpp FFI; usize is the internal type" - )] - let inference_parameters_n_batch_u32 = inference_parameters.n_batch as u32; + let inference_parameters_n_batch_u32 = u32::try_from(inference_parameters.n_batch) + .context("n_batch does not fit in u32")?; let context_params = LlamaContextParams::default() .with_embeddings(inference_parameters.enable_embeddings) @@ -132,21 +128,15 @@ impl ContinuousBatchArbiter { .with_n_threads(n_threads) .with_n_threads_batch(n_threads_batch) .with_pooling_type( - inference_parameters - .pooling_type - .clone() + AgentPoolingType(inference_parameters.pooling_type.clone()) .to_llama_pooling_type(), ) .with_type_k( - inference_parameters - .k_cache_dtype - .clone() + AgentKvCacheDtype(inference_parameters.k_cache_dtype.clone()) .to_llama_kv_cache_dtype(), ) .with_type_v( - inference_parameters - .v_cache_dtype - .clone() + AgentKvCacheDtype(inference_parameters.v_cache_dtype.clone()) .to_llama_kv_cache_dtype(), ); @@ -323,17 +313,13 @@ impl ContinuousBatchArbiter { { Ok(context) => context, Err(err) => { - for slot_index in 0..desired_slots_total { - #[expect( - clippy::cast_sign_loss, - reason = "slot_index is always non-negative" - )] + for slot_index in 0..n_seq_max { slot_aggregated_status_manager .slot_aggregated_status .register_issue(AgentIssue::SlotCannotStart( SlotCannotStartParams { error: format!("{err:#}"), - slot_index: slot_index as u32, + slot_index, }, )); } @@ -425,15 +411,17 @@ impl ContinuousBatchArbiter { "Scheduler thread did not signal agent-warm-and-scheduler-running before exiting", )?; - for slot_index in 0..self.desired_slots_total { + let desired_slots_total_u32 = u32::try_from(self.desired_slots_total) + .context("desired_slots_total does not fit in u32")?; + + for slot_index in 0..desired_slots_total_u32 { self.slot_aggregated_status_manager .slot_aggregated_status .increment_total_slots(); - #[expect(clippy::cast_sign_loss, reason = "slot_index is always non-negative")] self.slot_aggregated_status_manager .slot_aggregated_status - .register_fix(&AgentIssueFix::SlotStarted(slot_index as u32)); + .register_fix(&AgentIssueFix::SlotStarted(slot_index)); } Ok(ContinuousBatchArbiterHandle { diff --git a/paddler/src/agent/continuous_batch_arbiter_build_outcome.rs b/paddler_agent/src/continuous_batch_arbiter_build_outcome.rs similarity index 63% rename from paddler/src/agent/continuous_batch_arbiter_build_outcome.rs rename to paddler_agent/src/continuous_batch_arbiter_build_outcome.rs index 1abd91ba..8e32f4be 100644 --- a/paddler/src/agent/continuous_batch_arbiter_build_outcome.rs +++ b/paddler_agent/src/continuous_batch_arbiter_build_outcome.rs @@ -1,4 +1,4 @@ -use crate::agent::continuous_batch_arbiter::ContinuousBatchArbiter; +use crate::continuous_batch_arbiter::ContinuousBatchArbiter; pub enum ContinuousBatchArbiterBuildOutcome { NoModelConfigured, diff --git a/paddler_agent/src/continuous_batch_arbiter_handle.rs b/paddler_agent/src/continuous_batch_arbiter_handle.rs new file mode 100644 index 00000000..9ea9a994 --- /dev/null +++ b/paddler_agent/src/continuous_batch_arbiter_handle.rs @@ -0,0 +1,95 @@ +use std::sync::mpsc::SendError; +use std::sync::mpsc::Sender; +use std::thread; + +use anyhow::Result; +use anyhow::anyhow; + +use crate::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; + +pub struct ContinuousBatchArbiterHandle { + pub command_tx: Sender, + pub scheduler_thread_handle: thread::JoinHandle>, +} + +impl ContinuousBatchArbiterHandle { + pub fn shutdown(self) -> Result<()> { + if let Err(SendError(_unsent_command)) = self + .command_tx + .send(ContinuousBatchSchedulerCommand::Shutdown) + { + // Scheduler thread already dropped its receiver; join below is authoritative. + } + + self.scheduler_thread_handle + .join() + .map_err(|err| anyhow!("Failed to join scheduler thread: {err:?}"))? + } +} + +#[cfg(test)] +mod tests { + use std::mem::discriminant; + use std::sync::mpsc::channel; + use std::thread; + + use super::ContinuousBatchArbiterHandle; + use crate::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; + + #[test] + fn shutdown_sends_command_and_joins_successful_thread() { + let (command_tx, command_rx) = channel::(); + let scheduler_thread_handle = thread::spawn(move || { + let received_command = command_rx.recv().unwrap(); + + assert_eq!( + discriminant(&received_command), + discriminant(&ContinuousBatchSchedulerCommand::Shutdown) + ); + + Ok(()) + }); + let handle = ContinuousBatchArbiterHandle { + command_tx, + scheduler_thread_handle, + }; + + handle.shutdown().unwrap(); + } + + #[test] + fn shutdown_tolerates_dropped_receiver_and_joins() { + let (command_tx, command_rx) = channel::(); + let scheduler_thread_handle = thread::spawn(|| Ok(())); + + drop(command_rx); + + let handle = ContinuousBatchArbiterHandle { + command_tx, + scheduler_thread_handle, + }; + + handle.shutdown().unwrap(); + } + + #[test] + fn shutdown_reports_error_when_thread_panics() { + let (command_tx, command_rx) = channel::(); + let scheduler_thread_handle = thread::spawn(move || { + drop(command_rx.recv()); + + panic!("scheduler thread crashed"); + }); + let handle = ContinuousBatchArbiterHandle { + command_tx, + scheduler_thread_handle, + }; + + let shutdown_error = handle.shutdown().err().unwrap(); + + assert_eq!( + shutdown_error.to_string(), + "Failed to join scheduler thread: Any { .. }" + ); + } +} diff --git a/paddler/src/agent/continuous_batch_embedding_processor.rs b/paddler_agent/src/continuous_batch_embedding_processor.rs similarity index 76% rename from paddler/src/agent/continuous_batch_embedding_processor.rs rename to paddler_agent/src/continuous_batch_embedding_processor.rs index 9952047d..8c152903 100644 --- a/paddler/src/agent/continuous_batch_embedding_processor.rs +++ b/paddler_agent/src/continuous_batch_embedding_processor.rs @@ -7,17 +7,18 @@ use llama_cpp_bindings::context::LlamaContext; use llama_cpp_bindings::llama_batch::LlamaBatch; use llama_cpp_bindings::model::AddBos; use log::warn; -use paddler_types::embedding::Embedding; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; -use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_messaging::embedding::Embedding; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::embedding_result::EmbeddingResult; +use paddler_messaging::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; use tokio::sync::mpsc; -use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; -use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use crate::agent::plan_embedding_batches::plan_embedding_batches; +use crate::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; use crate::embedding_input_tokenized::EmbeddingInputTokenized; +use crate::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use crate::normalization::normalize_embedding::normalize_embedding; +use crate::plan_embedding_batches::plan_embedding_batches; pub struct ContinuousBatchEmbeddingProcessor<'context> { llama_context: &'context mut LlamaContext<'static>, @@ -48,11 +49,8 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { slot_guard, }: GenerateEmbeddingBatchRequest, ) -> Result<()> { - #[expect( - unused_variables, - reason = "slot_guard is held until function returns to release the slot via Drop" - )] - let slot_guard = slot_guard; + // Held until this function returns so the slot is released via `Drop`. + let _slot_guard = slot_guard; if !self .scheduler_context @@ -88,13 +86,10 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { let mut tokens_lines_list_within_batch: Vec = Vec::new(); for input in tokens_lines_list { if input.tokens.len() > n_batch { - #[expect( - clippy::cast_possible_truncation, - reason = "document token counts and n_batch are model-bounded and fit in u32" - )] let details = OversizedEmbeddingDocumentDetails { - document_tokens: input.tokens.len() as u32, - n_batch: n_batch as u32, + document_tokens: u32::try_from(input.tokens.len()) + .context("document token count does not fit in u32")?, + n_batch: u32::try_from(n_batch).context("n_batch does not fit in u32")?, source_document_id: input.id.clone(), }; @@ -122,11 +117,6 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { let mut embeddings_emitted: usize = 0; - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "sequence index within a planned batch is bounded by max_sequences_per_batch which fits in i32" - )] for planned_batch in planned_batches { if generate_embedding_stop_rx.try_recv().is_ok() { break; @@ -138,7 +128,11 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { .collect(); for (sequence_index, input) in batch_inputs.iter().enumerate() { - batch.add_sequence(&input.tokens, sequence_index as i32, true)?; + batch.add_sequence( + &input.tokens, + i32::try_from(sequence_index).context("sequence index does not fit in i32")?, + true, + )?; } self.embedding_batch_decode( @@ -170,18 +164,15 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { self.llama_context.clear_kv_cache(); self.llama_context.decode(batch)?; - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "embedding sequence index fits in i32 for llama.cpp FFI" - )] for (index, embedding_input_tokenized) in current_batch_embeddings.iter().enumerate() { let embedding = self .llama_context - .embeddings_seq_ith(index as i32) + .embeddings_seq_ith( + i32::try_from(index).context("embedding sequence index does not fit in i32")?, + ) .context("Failed to get embeddings")?; - generated_embedding_tx.send(EmbeddingResult::Embedding( + generated_embedding_tx.send(EmbeddingResult::Embedding(normalize_embedding( Embedding { embedding: embedding.to_vec(), normalization_method: EmbeddingNormalizationMethod::None, @@ -191,9 +182,9 @@ impl<'context> ContinuousBatchEmbeddingProcessor<'context> { .pooling_type .clone(), source_document_id: embedding_input_tokenized.id.clone(), - } - .normalize(normalization_method)?, - ))?; + }, + normalization_method, + )?))?; } batch.clear(); diff --git a/paddler/src/agent/continuous_batch_request_phase.rs b/paddler_agent/src/continuous_batch_request_phase.rs similarity index 100% rename from paddler/src/agent/continuous_batch_request_phase.rs rename to paddler_agent/src/continuous_batch_request_phase.rs diff --git a/paddler_agent/src/continuous_batch_request_state.rs b/paddler_agent/src/continuous_batch_request_state.rs new file mode 100644 index 00000000..81d08c52 --- /dev/null +++ b/paddler_agent/src/continuous_batch_request_state.rs @@ -0,0 +1,168 @@ +use anyhow::Context as _; +use anyhow::Result; +use llama_cpp_bindings::SampledToken; +use llama_cpp_bindings::token::LlamaToken; + +use crate::continuous_batch_request_phase::ContinuousBatchRequestPhase; + +pub struct ContinuousBatchRequestState { + pub current_token_position: i32, + pub i_batch: Option, + pub max_tokens: i32, + pub pending_sampled_token: Option, + pub phase: ContinuousBatchRequestPhase, + pub prompt_tokens: Vec, + pub prompt_tokens_ingested: usize, + pub sequence_id: i32, +} + +impl ContinuousBatchRequestState { + #[must_use] + pub fn remaining_prompt_tokens(&self) -> &[LlamaToken] { + &self.prompt_tokens[self.prompt_tokens_ingested..] + } + + pub const fn apply_generating_contribution(&mut self, batch_position: i32) { + self.pending_sampled_token = None; + self.i_batch = Some(batch_position); + self.current_token_position += 1; + } + + pub fn apply_ingesting_contribution( + &mut self, + chunk_size: usize, + is_last_chunk: bool, + last_batch_position: i32, + ) -> Result<()> { + self.prompt_tokens_ingested += chunk_size; + self.current_token_position += + i32::try_from(chunk_size).context("chunk size does not fit in i32")?; + + if is_last_chunk { + self.i_batch = Some(last_batch_position); + self.phase = ContinuousBatchRequestPhase::Generating; + } + + Ok(()) + } + + pub const fn store_pending_token(&mut self, token: SampledToken) { + self.pending_sampled_token = Some(token); + } + + pub const fn mark_completed(&mut self) { + self.i_batch = None; + self.phase = ContinuousBatchRequestPhase::Completed; + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings::SampledToken; + use llama_cpp_bindings::token::LlamaToken; + + use super::ContinuousBatchRequestState; + use crate::continuous_batch_request_phase::ContinuousBatchRequestPhase; + + fn ingesting_state(prompt_token_count: usize) -> ContinuousBatchRequestState { + ContinuousBatchRequestState { + current_token_position: 0, + i_batch: None, + max_tokens: 64, + pending_sampled_token: None, + phase: ContinuousBatchRequestPhase::Ingesting, + prompt_tokens: vec![LlamaToken::new(1); prompt_token_count], + prompt_tokens_ingested: 0, + sequence_id: 0, + } + } + + #[test] + fn remaining_prompt_tokens_skips_already_ingested_tokens() { + let mut state = ingesting_state(5); + state.prompt_tokens_ingested = 2; + + assert_eq!(state.remaining_prompt_tokens().len(), 3); + } + + #[test] + fn applying_a_generating_contribution_clears_pending_and_advances_position() { + let mut state = ingesting_state(0); + state.current_token_position = 7; + state.pending_sampled_token = Some(SampledToken::Content(LlamaToken::new(9))); + + state.apply_generating_contribution(3); + + assert!(state.pending_sampled_token.is_none()); + assert_eq!(state.i_batch, Some(3)); + assert_eq!(state.current_token_position, 8); + } + + #[test] + fn applying_a_non_final_ingesting_chunk_advances_without_transitioning() { + let mut state = ingesting_state(10); + + state.apply_ingesting_contribution(4, false, 99).unwrap(); + + assert_eq!(state.prompt_tokens_ingested, 4); + assert_eq!(state.current_token_position, 4); + assert_eq!(state.i_batch, None); + assert!(matches!( + state.phase, + ContinuousBatchRequestPhase::Ingesting + )); + } + + #[test] + fn applying_the_final_ingesting_chunk_transitions_to_generating() { + let mut state = ingesting_state(6); + state.prompt_tokens_ingested = 4; + state.current_token_position = 4; + + state.apply_ingesting_contribution(2, true, 41).unwrap(); + + assert_eq!(state.prompt_tokens_ingested, 6); + assert_eq!(state.current_token_position, 6); + assert_eq!(state.i_batch, Some(41)); + assert!(matches!( + state.phase, + ContinuousBatchRequestPhase::Generating + )); + } + + #[test] + fn applying_an_ingesting_chunk_too_large_for_i32_is_an_error() { + let mut state = ingesting_state(0); + + let result = state.apply_ingesting_contribution(usize::MAX, false, 0); + + assert!(result.is_err()); + } + + #[test] + fn storing_a_pending_token_records_it() { + let mut state = ingesting_state(0); + + state.store_pending_token(SampledToken::Content(LlamaToken::new(5))); + + assert!(matches!( + state.pending_sampled_token, + Some(SampledToken::Content(token)) if token == LlamaToken::new(5) + )); + } + + #[test] + fn marking_completed_clears_batch_index_and_sets_completed_phase() { + let mut state = ingesting_state(0); + state.i_batch = Some(2); + state.phase = ContinuousBatchRequestPhase::Generating; + + state.mark_completed(); + + assert_eq!(state.i_batch, None); + assert!(matches!( + state.phase, + ContinuousBatchRequestPhase::Completed + )); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs b/paddler_agent/src/continuous_batch_scheduler/advance_generating_phase.rs similarity index 78% rename from paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/advance_generating_phase.rs index 32c46840..a67b9aa3 100644 --- a/paddler/src/agent/continuous_batch_scheduler/advance_generating_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/advance_generating_phase.rs @@ -2,21 +2,21 @@ use llama_cpp_bindings::SampledToken; use llama_cpp_bindings::context::LlamaContext; use log::error; use log::warn; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::generation_summary::GenerationSummary; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; -use crate::agent::continuous_batch_scheduler::advance_outcome::AdvanceOutcome; -use crate::agent::continuous_batch_scheduler::classify_token_phase; -use crate::agent::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome; -use crate::agent::continuous_batch_scheduler::completion_check_phase::CompletionCheckPhase; -use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; -use crate::agent::continuous_batch_scheduler::emit_token_phase; -use crate::agent::continuous_batch_scheduler::sample_outcome::SampleOutcome; -use crate::agent::continuous_batch_scheduler::sample_token_phase::SampleTokenPhase; -use crate::agent::continuous_batch_scheduler::tool_call_pass; -use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_request_phase::ContinuousBatchRequestPhase; +use crate::continuous_batch_scheduler::advance_outcome::AdvanceOutcome; +use crate::continuous_batch_scheduler::classify_token_phase; +use crate::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome; +use crate::continuous_batch_scheduler::completion_check_phase::CompletionCheckPhase; +use crate::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; +use crate::continuous_batch_scheduler::emit_token_phase; +use crate::continuous_batch_scheduler::sample_outcome::SampleOutcome; +use crate::continuous_batch_scheduler::sample_token_phase::SampleTokenPhase; +use crate::continuous_batch_scheduler::tool_call_pass; +use crate::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; pub struct AdvanceGeneratingPhase<'context> { pub scheduler_context: &'context ContinuousBatchSchedulerContext, @@ -32,15 +32,15 @@ impl AdvanceGeneratingPhase<'_> { } fn advance_one(&self, request: &mut ContinuousBatchActiveRequest) -> Option { - if !matches!(request.phase, ContinuousBatchRequestPhase::Generating) { + if !matches!(request.state.phase, ContinuousBatchRequestPhase::Generating) { return None; } - if request.pending_sampled_token.is_some() { + if request.state.pending_sampled_token.is_some() { return None; } - let batch_index = request.i_batch?; + let batch_index = request.state.i_batch?; let raw_token = match (SampleTokenPhase { context: self.llama_context, @@ -51,7 +51,7 @@ impl AdvanceGeneratingPhase<'_> { SampleOutcome::AllCandidatesEliminated => { error!( "{:?}: sequence {} sampling exhausted candidates", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::Completed( GeneratedTokenResult::SamplerError( @@ -62,7 +62,7 @@ impl AdvanceGeneratingPhase<'_> { SampleOutcome::GrammarRejected(message) => { error!( "{:?}: sequence {} grammar rejected sampled token: {message}", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::Completed( GeneratedTokenResult::GrammarRejectedModelOutput(message), @@ -71,7 +71,7 @@ impl AdvanceGeneratingPhase<'_> { SampleOutcome::Failed(message) => { error!( "{:?}: sequence {} sampling error: {message}", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::Completed( GeneratedTokenResult::SamplerError(message), @@ -97,7 +97,7 @@ impl AdvanceGeneratingPhase<'_> { { warn!( "{:?}: sequence {} client disconnected (receiver dropped) during EOG tool-call flush", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::ChannelDropped); } @@ -114,7 +114,7 @@ impl AdvanceGeneratingPhase<'_> { EmitTokenOutcome::ChannelDropped => { warn!( "{:?}: sequence {} client disconnected (receiver dropped)", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::ChannelDropped); } @@ -126,7 +126,7 @@ impl AdvanceGeneratingPhase<'_> { { warn!( "{:?}: sequence {} client disconnected (receiver dropped)", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::ChannelDropped); } @@ -141,7 +141,7 @@ impl AdvanceGeneratingPhase<'_> { { warn!( "{:?}: sequence {} client disconnected (receiver dropped) during tool-call EOG flush", - self.scheduler_context.agent_name, request.sequence_id + self.scheduler_context.agent_name, request.state.sequence_id ); return Some(AdvanceOutcome::ChannelDropped); } @@ -165,14 +165,13 @@ impl AdvanceGeneratingPhase<'_> { match outcome { None => {} Some(AdvanceOutcome::SampledAndStored(token)) => { - request.pending_sampled_token = Some(token); + request.state.store_pending_token(token); } Some(AdvanceOutcome::Completed(event)) => { - request.complete_with_outcome(&self.scheduler_context.agent_name, event); + request.complete_with_outcome(self.scheduler_context.agent_name.as_deref(), event); } Some(AdvanceOutcome::ChannelDropped) => { - request.i_batch = None; - request.phase = ContinuousBatchRequestPhase::Completed; + request.state.mark_completed(); } } } diff --git a/paddler_agent/src/continuous_batch_scheduler/advance_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/advance_outcome.rs new file mode 100644 index 00000000..88775049 --- /dev/null +++ b/paddler_agent/src/continuous_batch_scheduler/advance_outcome.rs @@ -0,0 +1,77 @@ +use llama_cpp_bindings::SampledToken; +use paddler_messaging::generated_token_result::GeneratedTokenResult; + +pub enum AdvanceOutcome { + SampledAndStored(SampledToken), + Completed(GeneratedTokenResult), + ChannelDropped, +} + +#[cfg(test)] +mod tests { + use std::mem::discriminant; + + use llama_cpp_bindings::SampledToken; + use llama_cpp_bindings::token::LlamaToken; + + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::generation_summary::GenerationSummary; + + use super::AdvanceOutcome; + + #[test] + fn sampled_and_stored_is_distinct_from_the_other_variants() { + let sampled_and_stored = + AdvanceOutcome::SampledAndStored(SampledToken::Content(LlamaToken::new(7))); + + assert_eq!( + discriminant(&sampled_and_stored), + discriminant(&AdvanceOutcome::SampledAndStored(SampledToken::Reasoning( + LlamaToken::new(0) + ))) + ); + assert_ne!( + discriminant(&sampled_and_stored), + discriminant(&AdvanceOutcome::Completed(GeneratedTokenResult::Done( + GenerationSummary::default() + ))) + ); + assert_ne!( + discriminant(&sampled_and_stored), + discriminant(&AdvanceOutcome::ChannelDropped) + ); + } + + #[test] + fn completed_is_distinct_from_the_other_variants() { + let completed = + AdvanceOutcome::Completed(GeneratedTokenResult::Done(GenerationSummary::default())); + + assert_eq!( + discriminant(&completed), + discriminant(&AdvanceOutcome::Completed( + GeneratedTokenResult::ContentToken("next".to_owned()) + )) + ); + assert_ne!( + discriminant(&completed), + discriminant(&AdvanceOutcome::ChannelDropped) + ); + } + + #[test] + fn channel_dropped_is_distinct_from_the_other_variants() { + let channel_dropped = AdvanceOutcome::ChannelDropped; + + assert_eq!( + discriminant(&channel_dropped), + discriminant(&AdvanceOutcome::ChannelDropped) + ); + assert_ne!( + discriminant(&channel_dropped), + discriminant(&AdvanceOutcome::SampledAndStored(SampledToken::ToolCall( + LlamaToken::new(0) + ))) + ); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs b/paddler_agent/src/continuous_batch_scheduler/assemble_batch_phase.rs similarity index 65% rename from paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/assemble_batch_phase.rs index 3255d18e..33c70b5b 100644 --- a/paddler/src/agent/continuous_batch_scheduler/assemble_batch_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/assemble_batch_phase.rs @@ -1,11 +1,12 @@ +use anyhow::Context as _; use anyhow::Result; use llama_cpp_bindings::SampledToken; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; -use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass; -use crate::agent::continuous_batch_scheduler::generating_contribution::GeneratingContribution; -use crate::agent::continuous_batch_scheduler::ingesting_contribution::IngestingContribution; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_request_phase::ContinuousBatchRequestPhase; +use crate::continuous_batch_scheduler::batch_pass::BatchPass; +use crate::continuous_batch_scheduler::generating_contribution::GeneratingContribution; +use crate::continuous_batch_scheduler::ingesting_contribution::IngestingContribution; pub struct AssembleBatchPhase { pub n_batch: usize, @@ -33,11 +34,11 @@ impl AssembleBatchPhase { let mut tokens_added: usize = 0; for (request_index, request) in requests.iter().enumerate() { - if !matches!(request.phase, ContinuousBatchRequestPhase::Generating) { + if !matches!(request.state.phase, ContinuousBatchRequestPhase::Generating) { continue; } - let Some(pending_token) = request.pending_sampled_token else { + let Some(pending_token) = request.state.pending_sampled_token else { continue; }; @@ -49,8 +50,8 @@ impl AssembleBatchPhase { pass.batch.add( &pending_token, - request.current_token_position, - &[request.sequence_id], + request.state.current_token_position, + &[request.state.sequence_id], true, )?; @@ -65,22 +66,17 @@ impl AssembleBatchPhase { Ok(tokens_added) } - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "token counts and positions fit in i32 for llama.cpp FFI" - )] fn fill_ingesting( &self, pass: &mut BatchPass, requests: &[ContinuousBatchActiveRequest], ) -> Result<()> { for (request_index, request) in requests.iter().enumerate() { - if !matches!(request.phase, ContinuousBatchRequestPhase::Ingesting) { + if !matches!(request.state.phase, ContinuousBatchRequestPhase::Ingesting) { continue; } - let remaining = request.remaining_prompt_tokens(); + let remaining = request.state.remaining_prompt_tokens(); let chunk_size = compute_ingesting_chunk_size( remaining.len(), self.n_batch, @@ -91,19 +87,20 @@ impl AssembleBatchPhase { continue; } - let chunk = &request.prompt_tokens - [request.prompt_tokens_ingested..request.prompt_tokens_ingested + chunk_size]; - let is_last_chunk = - request.prompt_tokens_ingested + chunk_size >= request.prompt_tokens.len(); + let chunk = &request.state.prompt_tokens[request.state.prompt_tokens_ingested + ..request.state.prompt_tokens_ingested + chunk_size]; + let is_last_chunk = request.state.prompt_tokens_ingested + chunk_size + >= request.state.prompt_tokens.len(); for (offset, token) in chunk.iter().enumerate() { - let position = request.current_token_position + offset as i32; + let position = request.state.current_token_position + + i32::try_from(offset).context("token offset does not fit in i32")?; let is_last_token_of_prompt = is_last_chunk && offset == chunk_size - 1; pass.batch.add( &SampledToken::Content(*token), position, - &[request.sequence_id], + &[request.state.sequence_id], is_last_token_of_prompt, )?; } @@ -133,7 +130,23 @@ fn compute_ingesting_chunk_size( #[cfg(test)] mod tests { + use super::AssembleBatchPhase; use super::compute_ingesting_chunk_size; + use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; + use crate::continuous_batch_scheduler::batch_pass::BatchPass; + + #[test] + fn run_over_empty_requests_leaves_batch_untouched() { + let assemble_phase = AssembleBatchPhase { n_batch: 16 }; + let mut pass = BatchPass::new(16, 1).unwrap(); + let mut requests: [ContinuousBatchActiveRequest; 0] = []; + + assemble_phase.run(&mut pass, &mut requests).unwrap(); + + assert_eq!(pass.contributions.current_batch_token_count, 0); + assert_eq!(pass.batch.n_tokens(), 0); + assert!(pass.is_empty()); + } #[test] fn chunk_size_is_min_of_remaining_and_available_space() { diff --git a/paddler/src/agent/continuous_batch_scheduler/batch_pass.rs b/paddler_agent/src/continuous_batch_scheduler/batch_pass.rs similarity index 51% rename from paddler/src/agent/continuous_batch_scheduler/batch_pass.rs rename to paddler_agent/src/continuous_batch_scheduler/batch_pass.rs index e5bfd0cb..31fe50c1 100644 --- a/paddler/src/agent/continuous_batch_scheduler/batch_pass.rs +++ b/paddler_agent/src/continuous_batch_scheduler/batch_pass.rs @@ -1,7 +1,7 @@ use anyhow::Result; use llama_cpp_bindings::llama_batch::LlamaBatch; -use crate::agent::continuous_batch_scheduler::contributions::Contributions; +use crate::continuous_batch_scheduler::contributions::Contributions; pub struct BatchPass<'tokens> { pub batch: LlamaBatch<'tokens>, @@ -23,3 +23,25 @@ impl BatchPass<'_> { self.contributions.is_empty() } } + +#[cfg(test)] +mod tests { + use super::BatchPass; + + #[test] + fn new_creates_empty_batch_pass() { + let batch_pass = BatchPass::new(16, 1).unwrap(); + + assert_eq!(batch_pass.batch.n_tokens(), 0); + assert!(batch_pass.is_empty()); + } + + #[test] + fn new_forwards_llama_batch_error_for_oversized_n_batch() { + let result = BatchPass::new(usize::MAX, 1); + + let error = result.err().unwrap(); + + assert!(error.to_string().contains("overflow")); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/classified_token.rs b/paddler_agent/src/continuous_batch_scheduler/classified_token.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/classified_token.rs rename to paddler_agent/src/continuous_batch_scheduler/classified_token.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs b/paddler_agent/src/continuous_batch_scheduler/classify_token_phase.rs similarity index 97% rename from paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/classify_token_phase.rs index 39bc589f..148b112b 100644 --- a/paddler/src/agent/continuous_batch_scheduler/classify_token_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/classify_token_phase.rs @@ -3,8 +3,8 @@ use llama_cpp_bindings::sampled_token_classifier::IngestOutcome; use llama_cpp_bindings::sampled_token_classifier::SampledTokenSection; use llama_cpp_bindings::token::LlamaToken; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_scheduler::classified_token::ClassifiedToken; pub fn run( request: &mut ContinuousBatchActiveRequest, diff --git a/paddler_agent/src/continuous_batch_scheduler/commit_phase.rs b/paddler_agent/src/continuous_batch_scheduler/commit_phase.rs new file mode 100644 index 00000000..4d2d1e7a --- /dev/null +++ b/paddler_agent/src/continuous_batch_scheduler/commit_phase.rs @@ -0,0 +1,24 @@ +use anyhow::Result; + +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_scheduler::batch_pass::BatchPass; + +pub fn run(pass: BatchPass, requests: &mut [ContinuousBatchActiveRequest]) -> Result<()> { + for contribution in pass.contributions.generating { + requests[contribution.request_index] + .state + .apply_generating_contribution(contribution.batch_position); + } + + for contribution in pass.contributions.ingesting { + requests[contribution.request_index] + .state + .apply_ingesting_contribution( + contribution.chunk_size, + contribution.is_last_chunk, + contribution.last_batch_position, + )?; + } + + Ok(()) +} diff --git a/paddler/src/agent/continuous_batch_scheduler/completion_check_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/completion_check_outcome.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/completion_check_outcome.rs rename to paddler_agent/src/continuous_batch_scheduler/completion_check_outcome.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs b/paddler_agent/src/continuous_batch_scheduler/completion_check_phase.rs similarity index 56% rename from paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/completion_check_phase.rs index e2855e50..645acc7b 100644 --- a/paddler/src/agent/continuous_batch_scheduler/completion_check_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/completion_check_phase.rs @@ -2,8 +2,8 @@ use llama_cpp_bindings::SampledToken; use llama_cpp_bindings::TokenUsage; use llama_cpp_bindings::model::LlamaModel; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome; pub struct CompletionCheckPhase<'model> { pub model: &'model LlamaModel, @@ -20,17 +20,19 @@ impl CompletionCheckPhase<'_> { return CompletionCheckOutcome::ReachedEog; } - #[expect( - clippy::cast_sign_loss, - reason = "max_tokens is non-negative by API contract" - )] - let max_tokens_u64 = request.max_tokens as u64; + max_tokens_outcome(request.state.max_tokens, request.token_classifier.usage()) + } +} - if completion_token_count(request.token_classifier.usage()) >= max_tokens_u64 { - CompletionCheckOutcome::ReachedMaxTokens - } else { - CompletionCheckOutcome::Continue - } +fn max_tokens_outcome(max_tokens: i32, usage: &TokenUsage) -> CompletionCheckOutcome { + let Ok(max_tokens_u64) = u64::try_from(max_tokens) else { + return CompletionCheckOutcome::ReachedMaxTokens; + }; + + if completion_token_count(usage) >= max_tokens_u64 { + CompletionCheckOutcome::ReachedMaxTokens + } else { + CompletionCheckOutcome::Continue } } @@ -43,6 +45,44 @@ mod tests { use llama_cpp_bindings::TokenUsage; use super::completion_token_count; + use super::max_tokens_outcome; + use crate::continuous_batch_scheduler::completion_check_outcome::CompletionCheckOutcome; + + #[test] + fn negative_max_tokens_reaches_max_tokens() { + let usage = TokenUsage::new(); + + assert!(matches!( + max_tokens_outcome(-1, &usage), + CompletionCheckOutcome::ReachedMaxTokens + )); + } + + #[test] + fn reaching_the_max_token_budget_reports_reached_max_tokens() { + let usage = TokenUsage { + content_tokens: 4, + ..TokenUsage::new() + }; + + assert!(matches!( + max_tokens_outcome(4, &usage), + CompletionCheckOutcome::ReachedMaxTokens + )); + } + + #[test] + fn staying_under_the_max_token_budget_continues() { + let usage = TokenUsage { + content_tokens: 3, + ..TokenUsage::new() + }; + + assert!(matches!( + max_tokens_outcome(4, &usage), + CompletionCheckOutcome::Continue + )); + } #[test] fn completion_token_count_sums_content_reasoning_and_undeterminable() { diff --git a/paddler/src/agent/continuous_batch_scheduler/contributions.rs b/paddler_agent/src/continuous_batch_scheduler/contributions.rs similarity index 88% rename from paddler/src/agent/continuous_batch_scheduler/contributions.rs rename to paddler_agent/src/continuous_batch_scheduler/contributions.rs index 4ce8805b..bb9f6067 100644 --- a/paddler/src/agent/continuous_batch_scheduler/contributions.rs +++ b/paddler_agent/src/continuous_batch_scheduler/contributions.rs @@ -1,5 +1,5 @@ -use crate::agent::continuous_batch_scheduler::generating_contribution::GeneratingContribution; -use crate::agent::continuous_batch_scheduler::ingesting_contribution::IngestingContribution; +use crate::continuous_batch_scheduler::generating_contribution::GeneratingContribution; +use crate::continuous_batch_scheduler::ingesting_contribution::IngestingContribution; #[derive(Default)] pub struct Contributions { diff --git a/paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs b/paddler_agent/src/continuous_batch_scheduler/decode_batch_phase.rs similarity index 58% rename from paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/decode_batch_phase.rs index 10214fe8..65c7ee73 100644 --- a/paddler/src/agent/continuous_batch_scheduler/decode_batch_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/decode_batch_phase.rs @@ -1,7 +1,7 @@ use llama_cpp_bindings::context::LlamaContext; -use crate::agent::continuous_batch_scheduler::batch_pass::BatchPass; -use crate::agent::continuous_batch_scheduler::decode_outcome::DecodeOutcome; +use crate::continuous_batch_scheduler::batch_pass::BatchPass; +use crate::continuous_batch_scheduler::decode_outcome::DecodeOutcome; pub fn run(pass: &mut BatchPass, context: &mut LlamaContext) -> DecodeOutcome { DecodeOutcome::from_decode_result(context.decode(&mut pass.batch)) diff --git a/paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/decode_outcome.rs similarity index 60% rename from paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs rename to paddler_agent/src/continuous_batch_scheduler/decode_outcome.rs index edb40528..d414e4d8 100644 --- a/paddler/src/agent/continuous_batch_scheduler/decode_outcome.rs +++ b/paddler_agent/src/continuous_batch_scheduler/decode_outcome.rs @@ -22,40 +22,48 @@ impl DecodeOutcome { #[cfg(test)] mod tests { + use std::mem::discriminant; + use llama_cpp_bindings::error::DecodeError; use super::DecodeOutcome; #[test] fn ok_maps_to_decoded() { - assert!(matches!( - DecodeOutcome::from_decode_result(Ok(())), - DecodeOutcome::Decoded - )); + assert_eq!( + discriminant(&DecodeOutcome::from_decode_result(Ok(()))), + discriminant(&DecodeOutcome::Decoded) + ); } #[test] fn no_kv_cache_slot_maps_to_needs_eviction() { - assert!(matches!( - DecodeOutcome::from_decode_result(Err(DecodeError::NoKvCacheSlot)), - DecodeOutcome::NeedsEviction - )); + assert_eq!( + discriminant(&DecodeOutcome::from_decode_result(Err( + DecodeError::NoKvCacheSlot + ))), + discriminant(&DecodeOutcome::NeedsEviction) + ); } #[test] fn aborted_maps_to_aborted() { - assert!(matches!( - DecodeOutcome::from_decode_result(Err(DecodeError::Aborted)), - DecodeOutcome::Aborted - )); + assert_eq!( + discriminant(&DecodeOutcome::from_decode_result(Err( + DecodeError::Aborted + ))), + discriminant(&DecodeOutcome::Aborted) + ); } #[test] fn batch_invalid_maps_to_aborted() { - assert!(matches!( - DecodeOutcome::from_decode_result(Err(DecodeError::BatchInvalid)), - DecodeOutcome::Aborted - )); + assert_eq!( + discriminant(&DecodeOutcome::from_decode_result(Err( + DecodeError::BatchInvalid + ))), + discriminant(&DecodeOutcome::Aborted) + ); } #[test] diff --git a/paddler/src/agent/continuous_batch_scheduler/emit_token_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/emit_token_outcome.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/emit_token_outcome.rs rename to paddler_agent/src/continuous_batch_scheduler/emit_token_outcome.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs b/paddler_agent/src/continuous_batch_scheduler/emit_token_phase.rs similarity index 55% rename from paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/emit_token_phase.rs index a1af128d..601ac3be 100644 --- a/paddler/src/agent/continuous_batch_scheduler/emit_token_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/emit_token_phase.rs @@ -1,10 +1,10 @@ use llama_cpp_bindings::SampledToken; -use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generated_token_result::GeneratedTokenResult; use tokio::sync::mpsc; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; -use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_scheduler::classified_token::ClassifiedToken; +use crate::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; pub fn run( request: &mut ContinuousBatchActiveRequest, @@ -42,16 +42,16 @@ const fn token_to_event(sampled_token: SampledToken, piece: String) -> Generated #[cfg(test)] mod tests { - use anyhow::Result; - use anyhow::bail; + use std::mem::discriminant; + use llama_cpp_bindings::SampledToken; use llama_cpp_bindings::token::LlamaToken; - use paddler_types::generated_token_result::GeneratedTokenResult; use tokio::sync::mpsc; use super::emit_classified; - use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; - use crate::agent::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; + use crate::continuous_batch_scheduler::classified_token::ClassifiedToken; + use crate::continuous_batch_scheduler::emit_token_outcome::EmitTokenOutcome; + use paddler_messaging::generated_token_result::GeneratedTokenResult; fn classified_with_piece(sampled: SampledToken, piece: &str) -> ClassifiedToken { ClassifiedToken { @@ -64,86 +64,107 @@ mod tests { } #[test] - fn empty_visible_piece_emits_empty_string_without_sending() -> Result<()> { + fn empty_visible_piece_emits_empty_string_without_sending() { let (tx, mut rx) = mpsc::unbounded_channel::(); let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(1)), ""); - match emit_classified(&classified, &tx) { - EmitTokenOutcome::Emitted(piece) if piece.is_empty() => {} - other => bail!("expected Emitted(\"\"), got {other:?}"), - } + let outcome = emit_classified(&classified, &tx); - match rx.try_recv() { - Err(mpsc::error::TryRecvError::Empty) => Ok(()), - other => bail!("expected empty channel, got {other:?}"), - } + assert_eq!( + discriminant(&outcome), + discriminant(&EmitTokenOutcome::Emitted(String::new())), + ); + + let receive_error = rx.try_recv().err().unwrap(); + + assert_eq!( + discriminant(&receive_error), + discriminant(&mpsc::error::TryRecvError::Empty), + ); } #[test] - fn content_token_emits_content_event() -> Result<()> { + fn content_token_emits_content_event() { let (tx, mut rx) = mpsc::unbounded_channel::(); let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(2)), "hi"); - emit_classified(&classified, &tx); + let outcome = emit_classified(&classified, &tx); - match rx.try_recv() { - Ok(GeneratedTokenResult::ContentToken(text)) if text == "hi" => Ok(()), - other => bail!("expected ContentToken(\"hi\"), got {other:?}"), - } + assert_eq!( + discriminant(&outcome), + discriminant(&EmitTokenOutcome::Emitted(String::new())), + ); + + let event = rx.try_recv().unwrap(); + + assert_eq!( + discriminant(&event), + discriminant(&GeneratedTokenResult::ContentToken(String::new())), + ); + assert_eq!(event.token_text().unwrap(), "hi"); } #[test] - fn reasoning_token_emits_reasoning_event() -> Result<()> { + fn reasoning_token_emits_reasoning_event() { let (tx, mut rx) = mpsc::unbounded_channel::(); let classified = classified_with_piece(SampledToken::Reasoning(LlamaToken::new(3)), "think"); emit_classified(&classified, &tx); - match rx.try_recv() { - Ok(GeneratedTokenResult::ReasoningToken(text)) if text == "think" => Ok(()), - other => bail!("expected ReasoningToken(\"think\"), got {other:?}"), - } + let event = rx.try_recv().unwrap(); + + assert_eq!( + discriminant(&event), + discriminant(&GeneratedTokenResult::ReasoningToken(String::new())), + ); + assert_eq!(event.token_text().unwrap(), "think"); } #[test] - fn tool_call_token_emits_tool_call_event() -> Result<()> { + fn tool_call_token_emits_tool_call_event() { let (tx, mut rx) = mpsc::unbounded_channel::(); let classified = classified_with_piece(SampledToken::ToolCall(LlamaToken::new(4)), "{"); emit_classified(&classified, &tx); - match rx.try_recv() { - Ok(GeneratedTokenResult::ToolCallToken(text)) if text == "{" => Ok(()), - other => bail!("expected ToolCallToken(\"{{\"), got {other:?}"), - } + let event = rx.try_recv().unwrap(); + + assert_eq!( + discriminant(&event), + discriminant(&GeneratedTokenResult::ToolCallToken(String::new())), + ); + assert_eq!(event.token_text().unwrap(), "{"); } #[test] - fn undeterminable_token_emits_undeterminable_event() -> Result<()> { + fn undeterminable_token_emits_undeterminable_event() { let (tx, mut rx) = mpsc::unbounded_channel::(); let classified = classified_with_piece(SampledToken::Undeterminable(LlamaToken::new(5)), "?"); emit_classified(&classified, &tx); - match rx.try_recv() { - Ok(GeneratedTokenResult::UndeterminableToken(text)) if text == "?" => Ok(()), - other => bail!("expected UndeterminableToken(\"?\"), got {other:?}"), - } + let event = rx.try_recv().unwrap(); + + assert_eq!( + discriminant(&event), + discriminant(&GeneratedTokenResult::UndeterminableToken(String::new())), + ); + assert_eq!(event.token_text().unwrap(), "?"); } #[test] - fn dropped_receiver_returns_channel_dropped() -> Result<()> { + fn dropped_receiver_returns_channel_dropped() { let (tx, rx) = mpsc::unbounded_channel::(); drop(rx); let classified = classified_with_piece(SampledToken::Content(LlamaToken::new(6)), "hi"); - match emit_classified(&classified, &tx) { - EmitTokenOutcome::ChannelDropped => Ok(()), - EmitTokenOutcome::Emitted(piece) => { - bail!("expected ChannelDropped on dropped receiver, got Emitted({piece:?})") - } - } + let outcome = emit_classified(&classified, &tx); + + assert_eq!( + discriminant(&outcome), + discriminant(&EmitTokenOutcome::ChannelDropped), + ); } } diff --git a/paddler/src/agent/continuous_batch_scheduler/generating_contribution.rs b/paddler_agent/src/continuous_batch_scheduler/generating_contribution.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/generating_contribution.rs rename to paddler_agent/src/continuous_batch_scheduler/generating_contribution.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/ingesting_contribution.rs b/paddler_agent/src/continuous_batch_scheduler/ingesting_contribution.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/ingesting_contribution.rs rename to paddler_agent/src/continuous_batch_scheduler/ingesting_contribution.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/mod.rs b/paddler_agent/src/continuous_batch_scheduler/mod.rs similarity index 75% rename from paddler/src/agent/continuous_batch_scheduler/mod.rs rename to paddler_agent/src/continuous_batch_scheduler/mod.rs index 28809909..9d89dcbe 100644 --- a/paddler/src/agent/continuous_batch_scheduler/mod.rs +++ b/paddler_agent/src/continuous_batch_scheduler/mod.rs @@ -40,13 +40,13 @@ use log::debug; use log::error; use log::info; use log::warn; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::generation_summary::GenerationSummary; -use paddler_types::oversized_image_details::OversizedImageDetails; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::embedding_result::EmbeddingResult; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::oversized_image_details::OversizedImageDetails; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use rand::Rng as _; use rand::rngs::ThreadRng; use tokio::sync::mpsc; @@ -56,26 +56,37 @@ use self::assemble_batch_phase::AssembleBatchPhase; use self::batch_pass::BatchPass; use self::decode_outcome::DecodeOutcome; use self::tool_call_pipeline_build_outcome::ToolCallPipelineBuildOutcome; -use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_embedding_processor::ContinuousBatchEmbeddingProcessor; -use crate::agent::continuous_batch_request_phase::ContinuousBatchRequestPhase; -use crate::agent::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; -use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; -use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use crate::agent::grammar_sampler::GrammarSampler; -use crate::agent::prepare_conversation_history_request::prepare_conversation_history_request; -use crate::agent::prepared_conversation_history_request::PreparedConversationHistoryRequest; -use crate::agent::resolve_grammar::resolve_grammar; -use crate::agent::sample_token_at_batch_index::sample_token_at_batch_index; -use crate::agent::sampling_outcome::SamplingOutcome; -use crate::agent::sequence_id_pool::SequenceIdPool; -use crate::agent::slot_guard::SlotGuard; +use crate::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use crate::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_embedding_processor::ContinuousBatchEmbeddingProcessor; +use crate::continuous_batch_request_phase::ContinuousBatchRequestPhase; +use crate::continuous_batch_request_state::ContinuousBatchRequestState; +use crate::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; +use crate::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; use crate::decoded_image::DecodedImage; +use crate::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use crate::grammar_sampler::GrammarSampler; +use crate::prepare_conversation_history_request::prepare_conversation_history_request; +use crate::prepared_conversation_history_request::PreparedConversationHistoryRequest; +use crate::resolve_grammar::resolve_grammar; +use crate::sample_token_at_batch_index::sample_token_at_batch_index; +use crate::sampling_outcome::SamplingOutcome; +use crate::sequence_id_pool::SequenceIdPool; +use crate::slot_guard::SlotGuard; use crate::tool_call_pipeline::ToolCallPipeline; use crate::tool_call_validator::ToolCallValidator; -use crate::tool_call_validator::ValidatorBuildError; +use crate::validator_build_error::ValidatorBuildError; + +fn send_generated_token_result_or_warn( + agent_name: Option<&str>, + generated_tokens_tx: &mpsc::UnboundedSender, + result: GeneratedTokenResult, +) { + if generated_tokens_tx.send(result).is_err() { + warn!("{agent_name:?}: failed to send result to client (receiver dropped)"); + } +} pub struct ContinuousBatchScheduler { active_requests: Vec, @@ -356,17 +367,11 @@ impl ContinuousBatchScheduler { error!("{message}"); - if generated_tokens_tx - .send(GeneratedTokenResult::GrammarInitializationFailed( - message.clone(), - )) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + generated_tokens_tx, + GeneratedTokenResult::GrammarInitializationFailed(message.clone()), + ); Err(anyhow!(message)) } @@ -428,7 +433,7 @@ impl ContinuousBatchScheduler { #[expect( clippy::too_many_arguments, - reason = "text prompt acceptance genuinely needs all these parameters from the caller" + reason = "these are distinct concerns (the prompt, the generation config, the output channel, the stop signal, the slot guard) that do not form a cohesive value object; bundling them would violate single-responsibility grouping" )] fn accept_text_prompt( &mut self, @@ -453,15 +458,11 @@ impl ContinuousBatchScheduler { self.scheduler_context.agent_name ); - if generated_tokens_tx - .send(GeneratedTokenResult::ToolSchemaInvalid(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::ToolSchemaInvalid(message), + ); return Ok(()); } @@ -482,15 +483,11 @@ impl ContinuousBatchScheduler { error!("{message}"); - if generated_tokens_tx - .send(GeneratedTokenResult::SamplerError(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError(message), + ); return Ok(()); }; @@ -510,15 +507,11 @@ impl ContinuousBatchScheduler { error!("{message}"); self.sequence_id_pool.release(sequence_id); - if generated_tokens_tx - .send(GeneratedTokenResult::SamplerError(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError(message), + ); return Ok(()); } @@ -539,19 +532,7 @@ impl ContinuousBatchScheduler { token_classifier.record_prompt_tokens(prompt_tokens.len() as u64); token_classifier.ingest_prompt_tokens(&prompt_tokens); - #[expect( - clippy::cast_sign_loss, - reason = "sequence IDs are always non-negative" - )] - if let Err(err) = - self.llama_context - .clear_kv_cache_seq(Some(sequence_id as u32), None, None) - { - error!( - "{:?}: failed to clear KV cache for sequence {sequence_id}: {err}", - self.scheduler_context.agent_name - ); - } + self.clear_kv_cache_for_sequence(sequence_id); debug!( "{:?}: accepted text prompt request on sequence {sequence_id} ({} tokens)", @@ -560,19 +541,21 @@ impl ContinuousBatchScheduler { ); self.active_requests.push(ContinuousBatchActiveRequest { + state: ContinuousBatchRequestState { + current_token_position: 0, + i_batch: None, + max_tokens, + pending_sampled_token: None, + phase: ContinuousBatchRequestPhase::Ingesting, + prompt_tokens, + prompt_tokens_ingested: 0, + sequence_id, + }, chain, token_classifier, - current_token_position: 0, grammar_sampler: llama_grammar_sampler, generated_tokens_tx, generate_tokens_stop_rx, - i_batch: None, - max_tokens, - pending_sampled_token: None, - phase: ContinuousBatchRequestPhase::Ingesting, - prompt_tokens, - prompt_tokens_ingested: 0, - sequence_id, slot_guard, tool_call_pipeline, }); @@ -582,7 +565,7 @@ impl ContinuousBatchScheduler { #[expect( clippy::too_many_arguments, - reason = "multimodal request handling genuinely requires all these parameters from the caller" + reason = "these are distinct concerns (the multimodal context, prompt, images, generation config, the output channel, the stop signal, the slot guard) that do not form a cohesive value object; bundling them would violate single-responsibility grouping" )] fn accept_multimodal_request( &mut self, @@ -609,15 +592,11 @@ impl ContinuousBatchScheduler { self.scheduler_context.agent_name ); - if generated_tokens_tx - .send(GeneratedTokenResult::ToolSchemaInvalid(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::ToolSchemaInvalid(message), + ); return Ok(()); } @@ -631,15 +610,11 @@ impl ContinuousBatchScheduler { error!("{message}"); - if generated_tokens_tx - .send(GeneratedTokenResult::SamplerError(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError(message), + ); return Ok(()); }; @@ -662,15 +637,11 @@ impl ContinuousBatchScheduler { error!("{message}"); self.sequence_id_pool.release(sequence_id); - if generated_tokens_tx - .send(GeneratedTokenResult::ImageDecodingFailed(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::ImageDecodingFailed(message), + ); return Ok(()); } @@ -698,15 +669,11 @@ impl ContinuousBatchScheduler { error!("{message}"); self.sequence_id_pool.release(sequence_id); - if generated_tokens_tx - .send(GeneratedTokenResult::SamplerError(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError(message), + ); return Ok(()); } @@ -714,36 +681,21 @@ impl ContinuousBatchScheduler { let batch_size = self.scheduler_context.inference_parameters.n_batch; - #[expect( - clippy::cast_sign_loss, - reason = "sequence IDs are always non-negative" - )] - if let Err(err) = - self.llama_context - .clear_kv_cache_seq(Some(sequence_id as u32), None, None) - { - error!( - "{:?}: failed to clear KV cache for sequence {sequence_id}: {err}", - self.scheduler_context.agent_name - ); - } + self.clear_kv_cache_for_sequence(sequence_id); self.harvest_pending_samples_before_external_decode(); let mut token_classifier = self.build_token_classifier_for_active_request(); - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "batch_size fits in i32 for llama.cpp FFI" - )] + let batch_size_i32 = i32::try_from(batch_size).context("batch_size does not fit in i32")?; + let eval_outcome = token_classifier.eval_multimodal_chunks( &input_chunks, multimodal_context, &self.llama_context, 0, sequence_id, - batch_size as i32, + batch_size_i32, true, ); @@ -759,20 +711,14 @@ impl ContinuousBatchScheduler { self.sequence_id_pool.release(sequence_id); - if generated_tokens_tx - .send(GeneratedTokenResult::ImageExceedsBatchSize( - OversizedImageDetails { - image_tokens: mismatch.image_tokens, - n_batch: mismatch.n_batch, - }, - )) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::ImageExceedsBatchSize(OversizedImageDetails { + image_tokens: mismatch.image_tokens, + n_batch: mismatch.n_batch, + }), + ); return Ok(()); } @@ -785,15 +731,11 @@ impl ContinuousBatchScheduler { error!("{message}"); self.sequence_id_pool.release(sequence_id); - if generated_tokens_tx - .send(GeneratedTokenResult::SamplerError(message)) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError(message), + ); return Ok(()); } @@ -817,19 +759,21 @@ impl ContinuousBatchScheduler { ); self.active_requests.push(ContinuousBatchActiveRequest { + state: ContinuousBatchRequestState { + current_token_position: tokens_ingested, + i_batch: Some(-1), + max_tokens, + pending_sampled_token: None, + phase: ContinuousBatchRequestPhase::Generating, + prompt_tokens: Vec::new(), + prompt_tokens_ingested: 0, + sequence_id, + }, chain, token_classifier, - current_token_position: tokens_ingested, grammar_sampler: llama_grammar_sampler, generated_tokens_tx, generate_tokens_stop_rx, - i_batch: Some(-1), - max_tokens, - pending_sampled_token: None, - phase: ContinuousBatchRequestPhase::Generating, - prompt_tokens: Vec::new(), - prompt_tokens_ingested: 0, - sequence_id, slot_guard, tool_call_pipeline, }); @@ -840,17 +784,17 @@ impl ContinuousBatchScheduler { fn harvest_pending_samples_before_external_decode(&mut self) { for active_request in &mut self.active_requests { if !matches!( - active_request.phase, + active_request.state.phase, ContinuousBatchRequestPhase::Generating ) { continue; } - if active_request.pending_sampled_token.is_some() { + if active_request.state.pending_sampled_token.is_some() { continue; } - let Some(batch_index) = active_request.i_batch else { + let Some(batch_index) = active_request.state.i_batch else { continue; }; @@ -867,17 +811,17 @@ impl ContinuousBatchScheduler { // happens in `advance_generating_phase` after the next decode, // not here. let _ = active_request.token_classifier.ingest(raw_token); - active_request.pending_sampled_token = + active_request.state.pending_sampled_token = Some(llama_cpp_bindings::SampledToken::Content(raw_token)); - active_request.i_batch = None; + active_request.state.i_batch = None; } Ok(SamplingOutcome::AllCandidatesEliminated) => { error!( "{:?}: sequence {} pre-eval harvest exhausted candidates", - self.scheduler_context.agent_name, active_request.sequence_id + self.scheduler_context.agent_name, active_request.state.sequence_id ); active_request.complete_with_outcome( - &self.scheduler_context.agent_name, + self.scheduler_context.agent_name.as_deref(), GeneratedTokenResult::SamplerError( "all token candidates were eliminated during sampling".to_owned(), ), @@ -886,20 +830,20 @@ impl ContinuousBatchScheduler { Ok(SamplingOutcome::GrammarRejectedModelOutput(message)) => { error!( "{:?}: sequence {} pre-eval harvest grammar rejected: {message}", - self.scheduler_context.agent_name, active_request.sequence_id + self.scheduler_context.agent_name, active_request.state.sequence_id ); active_request.complete_with_outcome( - &self.scheduler_context.agent_name, + self.scheduler_context.agent_name.as_deref(), GeneratedTokenResult::GrammarRejectedModelOutput(message), ); } Err(err) => { error!( "{:?}: sequence {} pre-eval harvest sampling error: {err:#}", - self.scheduler_context.agent_name, active_request.sequence_id + self.scheduler_context.agent_name, active_request.state.sequence_id ); active_request.complete_with_outcome( - &self.scheduler_context.agent_name, + self.scheduler_context.agent_name.as_deref(), GeneratedTokenResult::SamplerError(err.to_string()), ); } @@ -915,7 +859,7 @@ impl ContinuousBatchScheduler { }; active_request.complete_with_outcome( - &self.scheduler_context.agent_name, + self.scheduler_context.agent_name.as_deref(), GeneratedTokenResult::Done(summary), ); } @@ -958,7 +902,7 @@ impl ContinuousBatchScheduler { fn has_active_requests(&self) -> bool { self.active_requests .iter() - .any(|request| !matches!(request.phase, ContinuousBatchRequestPhase::Completed)) + .any(|request| !matches!(request.state.phase, ContinuousBatchRequestPhase::Completed)) } fn execute_one_iteration(&mut self) -> Result<()> { @@ -970,12 +914,9 @@ impl ContinuousBatchScheduler { loop { let max_sequences = self.active_requests.len(); - #[expect( - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - reason = "token counts and positions fit in i32 for llama.cpp FFI" - )] - let mut pass = BatchPass::new(n_batch, max_sequences.max(1) as i32)?; + let max_sequences_i32 = i32::try_from(max_sequences.max(1)) + .context("max sequence count does not fit in i32")?; + let mut pass = BatchPass::new(n_batch, max_sequences_i32)?; assemble_phase.run(&mut pass, &mut self.active_requests)?; @@ -992,7 +933,7 @@ impl ContinuousBatchScheduler { match decode_batch_phase::run(&mut pass, &mut self.llama_context) { DecodeOutcome::Decoded => { - commit_phase::run(pass, &mut self.active_requests); + commit_phase::run(pass, &mut self.active_requests)?; return Ok(()); } @@ -1026,12 +967,15 @@ impl ContinuousBatchScheduler { let mut largest_position: i32 = -1; for (index, active_request) in self.active_requests.iter().enumerate() { - if matches!(active_request.phase, ContinuousBatchRequestPhase::Completed) { + if matches!( + active_request.state.phase, + ContinuousBatchRequestPhase::Completed + ) { continue; } - if active_request.current_token_position > largest_position { - largest_position = active_request.current_token_position; + if active_request.state.current_token_position > largest_position { + largest_position = active_request.state.current_token_position; largest_seq_index = Some(index); } } @@ -1042,24 +986,19 @@ impl ContinuousBatchScheduler { warn!( "{:?}: evicting sequence {} (position {}) due to KV cache pressure", self.scheduler_context.agent_name, - evicted_request.sequence_id, - evicted_request.current_token_position + evicted_request.state.sequence_id, + evicted_request.state.current_token_position ); - if evicted_request - .generated_tokens_tx - .send(GeneratedTokenResult::SamplerError( + send_generated_token_result_or_warn( + self.scheduler_context.agent_name.as_deref(), + &evicted_request.generated_tokens_tx, + GeneratedTokenResult::SamplerError( "Request evicted due to KV cache pressure".to_owned(), - )) - .is_err() - { - warn!( - "{:?}: failed to send result to client (receiver dropped)", - self.scheduler_context.agent_name - ); - } + ), + ); - evicted_request.phase = ContinuousBatchRequestPhase::Completed; + evicted_request.state.phase = ContinuousBatchRequestPhase::Completed; self.cleanup_completed_request(eviction_index); } @@ -1070,7 +1009,7 @@ impl ContinuousBatchScheduler { while removal_index < self.active_requests.len() { if matches!( - self.active_requests[removal_index].phase, + self.active_requests[removal_index].state.phase, ContinuousBatchRequestPhase::Completed ) { self.cleanup_completed_request(removal_index); @@ -1080,33 +1019,87 @@ impl ContinuousBatchScheduler { } } - fn cleanup_completed_request(&mut self, index: usize) { - let removed_request = self.active_requests.swap_remove(index); + fn clear_kv_cache_for_sequence(&mut self, sequence_id: i32) { + let sequence_id_u32 = match u32::try_from(sequence_id) { + Ok(sequence_id_u32) => sequence_id_u32, + Err(err) => { + error!( + "{:?}: sequence id {sequence_id} does not fit in u32: {err}", + self.scheduler_context.agent_name + ); - #[expect( - clippy::cast_sign_loss, - reason = "sequence IDs are always non-negative" - )] - if let Err(err) = self.llama_context.clear_kv_cache_seq( - Some(removed_request.sequence_id as u32), - None, - None, - ) { + return; + } + }; + + if let Err(err) = self + .llama_context + .clear_kv_cache_seq(Some(sequence_id_u32), None, None) + { error!( - "{:?}: failed to clear KV cache for sequence {}: {err}", - self.scheduler_context.agent_name, removed_request.sequence_id + "{:?}: failed to clear KV cache for sequence {sequence_id}: {err}", + self.scheduler_context.agent_name ); } + } + + fn cleanup_completed_request(&mut self, index: usize) { + let removed_request = self.active_requests.swap_remove(index); + + self.clear_kv_cache_for_sequence(removed_request.state.sequence_id); - self.sequence_id_pool.release(removed_request.sequence_id); + self.sequence_id_pool + .release(removed_request.state.sequence_id); let usage = removed_request.token_classifier.usage(); debug!( "{:?}: cleaned up sequence {} ({} completion tokens generated)", self.scheduler_context.agent_name, - removed_request.sequence_id, + removed_request.state.sequence_id, usage.content_tokens + usage.reasoning_tokens + usage.undeterminable_tokens, ); } } + +#[cfg(test)] +mod tests { + use log::LevelFilter; + use tokio::sync::mpsc; + + use super::send_generated_token_result_or_warn; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + + #[test] + fn delivers_result_to_a_live_receiver() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + + send_generated_token_result_or_warn( + Some("agent"), + &generated_tokens_tx, + GeneratedTokenResult::SamplerError("boom".to_owned()), + ); + + assert!(matches!( + generated_tokens_rx.try_recv(), + Ok(GeneratedTokenResult::SamplerError(message)) if message == "boom" + )); + } + + #[test] + fn warns_without_panicking_when_the_receiver_was_dropped() { + log::set_max_level(LevelFilter::Trace); + + let (generated_tokens_tx, generated_tokens_rx) = mpsc::unbounded_channel(); + + drop(generated_tokens_rx); + + send_generated_token_result_or_warn( + None, + &generated_tokens_tx, + GeneratedTokenResult::SamplerError("boom".to_owned()), + ); + + assert!(generated_tokens_tx.is_closed()); + } +} diff --git a/paddler/src/agent/continuous_batch_scheduler/sample_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/sample_outcome.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/sample_outcome.rs rename to paddler_agent/src/continuous_batch_scheduler/sample_outcome.rs diff --git a/paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs b/paddler_agent/src/continuous_batch_scheduler/sample_token_phase.rs similarity index 75% rename from paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs rename to paddler_agent/src/continuous_batch_scheduler/sample_token_phase.rs index 14f9c921..fe896c9f 100644 --- a/paddler/src/agent/continuous_batch_scheduler/sample_token_phase.rs +++ b/paddler_agent/src/continuous_batch_scheduler/sample_token_phase.rs @@ -1,9 +1,9 @@ use llama_cpp_bindings::context::LlamaContext; -use crate::agent::continuous_batch_active_request::ContinuousBatchActiveRequest; -use crate::agent::continuous_batch_scheduler::sample_outcome::SampleOutcome; -use crate::agent::sample_token_at_batch_index::sample_token_at_batch_index; -use crate::agent::sampling_outcome::SamplingOutcome; +use crate::continuous_batch_active_request::ContinuousBatchActiveRequest; +use crate::continuous_batch_scheduler::sample_outcome::SampleOutcome; +use crate::sample_token_at_batch_index::sample_token_at_batch_index; +use crate::sampling_outcome::SamplingOutcome; pub struct SampleTokenPhase<'context> { pub context: &'context LlamaContext<'context>, diff --git a/paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs b/paddler_agent/src/continuous_batch_scheduler/tool_call_pass.rs similarity index 88% rename from paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs rename to paddler_agent/src/continuous_batch_scheduler/tool_call_pass.rs index 4373c1c7..c7e52e38 100644 --- a/paddler/src/agent/continuous_batch_scheduler/tool_call_pass.rs +++ b/paddler_agent/src/continuous_batch_scheduler/tool_call_pass.rs @@ -1,7 +1,7 @@ use llama_cpp_bindings::SampledToken; -use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generated_token_result::GeneratedTokenResult; -use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; +use crate::continuous_batch_scheduler::classified_token::ClassifiedToken; use crate::tool_call_pipeline::ToolCallPipeline; #[must_use] @@ -28,7 +28,7 @@ mod tests { use llama_cpp_bindings::token::LlamaToken; use super::run; - use crate::agent::continuous_batch_scheduler::classified_token::ClassifiedToken; + use crate::continuous_batch_scheduler::classified_token::ClassifiedToken; fn classified(was: bool, is: bool, sampled: SampledToken) -> ClassifiedToken { ClassifiedToken { diff --git a/paddler/src/agent/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs b/paddler_agent/src/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs similarity index 100% rename from paddler/src/agent/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs rename to paddler_agent/src/continuous_batch_scheduler/tool_call_pipeline_build_outcome.rs diff --git a/paddler_agent/src/continuous_batch_scheduler_command.rs b/paddler_agent/src/continuous_batch_scheduler_command.rs new file mode 100644 index 00000000..71001153 --- /dev/null +++ b/paddler_agent/src/continuous_batch_scheduler_command.rs @@ -0,0 +1,10 @@ +use crate::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use crate::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use crate::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; + +pub enum ContinuousBatchSchedulerCommand { + ContinueFromConversationHistory(ContinueFromConversationHistoryRequest), + ContinueFromRawPrompt(ContinueFromRawPromptRequest), + GenerateEmbeddingBatch(GenerateEmbeddingBatchRequest), + Shutdown, +} diff --git a/paddler/src/agent/continuous_batch_scheduler_context.rs b/paddler_agent/src/continuous_batch_scheduler_context.rs similarity index 90% rename from paddler/src/agent/continuous_batch_scheduler_context.rs rename to paddler_agent/src/continuous_batch_scheduler_context.rs index a73aad87..535223ff 100644 --- a/paddler/src/agent/continuous_batch_scheduler_context.rs +++ b/paddler_agent/src/continuous_batch_scheduler_context.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use llama_cpp_bindings::model::LlamaModel; use llama_cpp_bindings::mtmd::MtmdContext; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::inference_parameters::InferenceParameters; use crate::chat_template_renderer::ChatTemplateRenderer; diff --git a/paddler_agent/src/converts_to_llama_kv_cache_dtype.rs b/paddler_agent/src/converts_to_llama_kv_cache_dtype.rs new file mode 100644 index 00000000..71211c01 --- /dev/null +++ b/paddler_agent/src/converts_to_llama_kv_cache_dtype.rs @@ -0,0 +1,5 @@ +use llama_cpp_bindings::context::params::KvCacheType as LlamaKvCacheDtype; + +pub trait ConvertsToLlamaKvCacheDtype { + fn to_llama_kv_cache_dtype(self) -> LlamaKvCacheDtype; +} diff --git a/paddler_agent/src/converts_to_llama_pooling_type.rs b/paddler_agent/src/converts_to_llama_pooling_type.rs new file mode 100644 index 00000000..83b260c7 --- /dev/null +++ b/paddler_agent/src/converts_to_llama_pooling_type.rs @@ -0,0 +1,5 @@ +use llama_cpp_bindings::context::params::LlamaPoolingType; + +pub trait ConvertsToLlamaPoolingType { + fn to_llama_pooling_type(self) -> LlamaPoolingType; +} diff --git a/paddler/src/decoded_image.rs b/paddler_agent/src/decoded_image.rs similarity index 51% rename from paddler/src/decoded_image.rs rename to paddler_agent/src/decoded_image.rs index 1d1d5b74..3122ada6 100644 --- a/paddler/src/decoded_image.rs +++ b/paddler_agent/src/decoded_image.rs @@ -6,7 +6,7 @@ use image::DynamicImage; use image::ImageFormat; use image::imageops::FilterType; use log::info; -use paddler_types::image_url::ImageUrl; +use paddler_messaging::image_url::ImageUrl; use resvg::tiny_skia::Pixmap; use resvg::usvg::Options; use resvg::usvg::Tree as SvgTree; @@ -79,7 +79,7 @@ fn rasterize_svg_to_dynamic_image( let rgba = image::RgbaImage::from_raw(target_width, target_height, pixmap.data().to_vec()) .ok_or_else(|| DecodedImageError::ConversionFailed { - message: "rasterized SVG buffer did not match target dimensions".to_owned(), + message: "rasterized pixmap buffer length did not match target dimensions".to_owned(), })?; Ok(DynamicImage::ImageRgba8(rgba)) @@ -190,6 +190,7 @@ impl DecodedImage { if needs_resize { let resized = image.resize(max_dimension, max_dimension, FilterType::Lanczos3); + return Ok(Self { data: encode_jpeg(&resized)?, }); @@ -207,98 +208,115 @@ impl DecodedImage { #[cfg(test)] mod tests { use std::io::Cursor; + use std::mem::discriminant; - use anyhow::Result; use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use image::ImageFormat; - use paddler_types::image_url::ImageUrl; + use paddler_messaging::image_url::ImageUrl; use crate::decoded_image::DecodedImage; + use crate::decoded_image::compute_target_dimension; use crate::decoded_image_error::DecodedImageError; - fn create_test_jpeg(width: u32, height: u32) -> Result> { + fn create_test_jpeg(width: u32, height: u32) -> Vec { use image::RgbImage; let image_buffer = RgbImage::new(width, height); let mut output_buffer = Cursor::new(Vec::new()); image::DynamicImage::ImageRgb8(image_buffer) - .write_to(&mut output_buffer, ImageFormat::Jpeg)?; + .write_to(&mut output_buffer, ImageFormat::Jpeg) + .unwrap(); - Ok(output_buffer.into_inner()) + output_buffer.into_inner() } - fn create_test_tiff(width: u32, height: u32) -> Result> { + fn create_test_tiff(width: u32, height: u32) -> Vec { use image::RgbImage; let image_buffer = RgbImage::new(width, height); let mut output_buffer = Cursor::new(Vec::new()); image::DynamicImage::ImageRgb8(image_buffer) - .write_to(&mut output_buffer, ImageFormat::Tiff)?; + .write_to(&mut output_buffer, ImageFormat::Tiff) + .unwrap(); - Ok(output_buffer.into_inner()) + output_buffer.into_inner() } - fn create_test_png(width: u32, height: u32) -> Result> { + fn create_test_png(width: u32, height: u32) -> Vec { use image::RgbImage; let image_buffer = RgbImage::new(width, height); let mut output_buffer = Cursor::new(Vec::new()); image::DynamicImage::ImageRgb8(image_buffer) - .write_to(&mut output_buffer, ImageFormat::Png)?; + .write_to(&mut output_buffer, ImageFormat::Png) + .unwrap(); - Ok(output_buffer.into_inner()) + output_buffer.into_inner() } - fn create_test_gif(width: u32, height: u32) -> Result> { + fn create_test_gif(width: u32, height: u32) -> Vec { use image::RgbaImage; let image_buffer = RgbaImage::new(width, height); let mut output_buffer = Cursor::new(Vec::new()); image::DynamicImage::ImageRgba8(image_buffer) - .write_to(&mut output_buffer, ImageFormat::Gif)?; + .write_to(&mut output_buffer, ImageFormat::Gif) + .unwrap(); - Ok(output_buffer.into_inner()) + output_buffer.into_inner() } - fn create_test_bmp(width: u32, height: u32) -> Result> { + fn create_test_bmp(width: u32, height: u32) -> Vec { use image::RgbImage; let image_buffer = RgbImage::new(width, height); let mut output_buffer = Cursor::new(Vec::new()); image::DynamicImage::ImageRgb8(image_buffer) - .write_to(&mut output_buffer, ImageFormat::Bmp)?; + .write_to(&mut output_buffer, ImageFormat::Bmp) + .unwrap(); + + output_buffer.into_inner() + } + + fn create_test_openexr(width: u32, height: u32) -> Vec { + use image::Rgb; + use image::Rgb32FImage; + + let image_buffer = Rgb32FImage::from_pixel(width, height, Rgb([0.25f32, 0.5f32, 0.75f32])); + let mut output_buffer = Cursor::new(Vec::new()); - Ok(output_buffer.into_inner()) + image::DynamicImage::ImageRgb32F(image_buffer) + .write_to(&mut output_buffer, ImageFormat::OpenExr) + .unwrap(); + + output_buffer.into_inner() } - fn load_fixture(filename: &str) -> Result> { - let data = std::fs::read(format!( + fn load_fixture(filename: &str) -> Vec { + std::fs::read(format!( "{}/../fixtures/{filename}", env!("CARGO_MANIFEST_DIR"), - ))?; - - Ok(data) + )) + .unwrap() } #[test] - fn test_decodes_valid_png_data_uri() -> Result<()> { + fn test_decodes_valid_png_data_uri() { let png_bytes: Vec = vec![0x89, 0x50, 0x4E, 0x47]; let encoded = BASE64_STANDARD.encode(&png_bytes); let image_url = ImageUrl { url: format!("data:image/png;base64,{encoded}"), }; - let result = DecodedImage::from_data_uri(&image_url)?; + let result = DecodedImage::from_data_uri(&image_url).unwrap(); assert_eq!(result.data, png_bytes); - - Ok(()) } #[test] @@ -307,12 +325,12 @@ mod tests { url: "https://example.com/image.png".to_owned(), }; - let result = DecodedImage::from_data_uri(&image_url); + let error = DecodedImage::from_data_uri(&image_url).err().unwrap(); - assert!(matches!( - result, - Err(DecodedImageError::RemoteUrlNotSupported) - )); + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::RemoteUrlNotSupported), + ); } #[test] @@ -321,12 +339,12 @@ mod tests { url: "data:image/png;base64".to_owned(), }; - let result = DecodedImage::from_data_uri(&image_url); + let error = DecodedImage::from_data_uri(&image_url).err().unwrap(); - assert!(matches!( - result, - Err(DecodedImageError::MissingCommaSeparator) - )); + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::MissingCommaSeparator), + ); } #[test] @@ -335,92 +353,88 @@ mod tests { url: "data:image/png;base64,!!!not-valid-base64!!!".to_owned(), }; - let result = DecodedImage::from_data_uri(&image_url); + let error = DecodedImage::from_data_uri(&image_url).err().unwrap(); - assert!(matches!( - result, - Err(DecodedImageError::InvalidBase64Payload { .. }) - )); + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::InvalidBase64Payload { + message: String::new(), + }), + ); } #[test] - fn test_prepared_passes_through_small_jpeg() -> Result<()> { - let jpeg_data = create_test_jpeg(100, 100)?; + fn test_prepared_passes_through_small_jpeg() { + let jpeg_data = create_test_jpeg(100, 100); let original_len = jpeg_data.len(); let decoded_image = DecodedImage { data: jpeg_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_prepared_passes_through_small_png() -> Result<()> { - let png_data = create_test_png(100, 100)?; + fn test_prepared_passes_through_small_png() { + let png_data = create_test_png(100, 100); let original_len = png_data.len(); let decoded_image = DecodedImage { data: png_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_prepared_passes_through_small_gif() -> Result<()> { - let gif_data = create_test_gif(100, 100)?; + fn test_prepared_passes_through_small_gif() { + let gif_data = create_test_gif(100, 100); let original_len = gif_data.len(); let decoded_image = DecodedImage { data: gif_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_prepared_passes_through_small_bmp() -> Result<()> { - let bmp_data = create_test_bmp(100, 100)?; + fn test_prepared_passes_through_small_bmp() { + let bmp_data = create_test_bmp(100, 100); let original_len = bmp_data.len(); let decoded_image = DecodedImage { data: bmp_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); assert_eq!(result.data.len(), original_len); - Ok(()) } #[test] - fn test_prepared_converts_small_tiff_to_png() -> Result<()> { - let tiff_data = create_test_tiff(100, 100)?; + fn test_prepared_converts_small_tiff_to_png() { + let tiff_data = create_test_tiff(100, 100); let decoded_image = DecodedImage { data: tiff_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); - let result_format = image::guess_format(&result.data)?; + let result_format = image::guess_format(&result.data).unwrap(); assert_eq!(result_format, ImageFormat::Png); - Ok(()) } #[test] - fn test_prepared_converts_small_webp_fixture_to_png() -> Result<()> { - let webp_data = load_fixture("llamas.webp")?; + fn test_prepared_converts_small_webp_fixture_to_png() { + let webp_data = load_fixture("llamas.webp"); let decoded_image = DecodedImage { data: webp_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); - let result_format = image::guess_format(&result.data)?; + let result_format = image::guess_format(&result.data).unwrap(); assert_eq!(result_format, ImageFormat::Png); - let result_image = image::load_from_memory(&result.data)?; + let result_image = image::load_from_memory(&result.data).unwrap(); assert_eq!(result_image.width(), 640); assert_eq!(result_image.height(), 427); - Ok(()) } #[test] - fn test_prepared_rasterizes_small_svg() -> Result<()> { + fn test_prepared_rasterizes_small_svg() { let svg_data = br#" "#; @@ -428,97 +442,85 @@ mod tests { data: svg_data.to_vec(), }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); - let result_format = image::guess_format(&result.data)?; + let result_format = image::guess_format(&result.data).unwrap(); assert_eq!(result_format, ImageFormat::Png); - let result_image = image::load_from_memory(&result.data)?; + let result_image = image::load_from_memory(&result.data).unwrap(); assert_eq!(result_image.width(), 50); assert_eq!(result_image.height(), 50); - Ok(()) } #[test] - fn test_prepared_rasterizes_svg_fixture_within_bound() -> Result<()> { - let svg_data = load_fixture("llamas.svg")?; + fn test_prepared_rasterizes_svg_fixture_within_bound() { + let svg_data = load_fixture("llamas.svg"); let decoded_image = DecodedImage { data: svg_data }; - let result = decoded_image.prepared_for_inference(320)?; + let result = decoded_image.prepared_for_inference(320).unwrap(); - let result_format = image::guess_format(&result.data)?; - let result_image = image::load_from_memory(&result.data)?; + let result_format = image::guess_format(&result.data).unwrap(); + let result_image = image::load_from_memory(&result.data).unwrap(); assert!(result_image.width() <= 320); assert!(result_image.height() <= 320); - assert!(matches!( - result_format, - ImageFormat::Png | ImageFormat::Jpeg - )); - Ok(()) + assert_eq!(result_format, ImageFormat::Png); } #[test] - fn test_prepared_resizes_oversized_jpeg_to_jpeg() -> Result<()> { - let jpeg_data = create_test_jpeg(2000, 1500)?; + fn test_prepared_resizes_oversized_jpeg_to_jpeg() { + let jpeg_data = create_test_jpeg(2000, 1500); let decoded_image = DecodedImage { data: jpeg_data }; - let result = decoded_image.prepared_for_inference(1024)?; + let result = decoded_image.prepared_for_inference(1024).unwrap(); - let result_format = image::guess_format(&result.data)?; + let result_format = image::guess_format(&result.data).unwrap(); assert_eq!(result_format, ImageFormat::Jpeg); - let result_image = image::load_from_memory(&result.data)?; + let result_image = image::load_from_memory(&result.data).unwrap(); assert!(result_image.width() <= 1024); assert!(result_image.height() <= 1024); - Ok(()) } #[test] - fn test_prepared_preserves_aspect_ratio_on_resize() -> Result<()> { - let jpeg_data = create_test_jpeg(2000, 1000)?; + fn test_prepared_preserves_aspect_ratio_on_resize() { + let jpeg_data = create_test_jpeg(2000, 1000); let decoded_image = DecodedImage { data: jpeg_data }; - let result = decoded_image.prepared_for_inference(1000)?; + let result = decoded_image.prepared_for_inference(1000).unwrap(); - let result_image = image::load_from_memory(&result.data)?; + let result_image = image::load_from_memory(&result.data).unwrap(); assert_eq!(result_image.width(), 1000); assert_eq!(result_image.height(), 500); - Ok(()) } #[test] - fn test_prepared_with_jpg_fixture_within_bound() -> Result<()> { - let fixture_data = std::fs::read(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../fixtures/llamas.jpg" - ))?; + fn test_prepared_with_jpg_fixture_within_bound() { + let fixture_data = load_fixture("llamas.jpg"); - let original_image = image::load_from_memory(&fixture_data)?; + let original_image = image::load_from_memory(&fixture_data).unwrap(); assert_eq!(original_image.width(), 640); assert_eq!(original_image.height(), 427); let decoded_image = DecodedImage { data: fixture_data }; - let result = decoded_image.prepared_for_inference(320)?; + let result = decoded_image.prepared_for_inference(320).unwrap(); - let result_image = image::load_from_memory(&result.data)?; + let result_image = image::load_from_memory(&result.data).unwrap(); assert_eq!(result_image.width(), 320); assert_eq!(result_image.height(), 214); - Ok(()) } #[test] - fn test_prepared_rejects_zero_max_dimension() -> Result<()> { - let png_data = create_test_png(50, 50)?; + fn test_prepared_rejects_zero_max_dimension() { + let png_data = create_test_png(50, 50); let decoded_image = DecodedImage { data: png_data }; - let result = decoded_image.prepared_for_inference(0); + let error = decoded_image.prepared_for_inference(0).err().unwrap(); - assert!(matches!( - result, - Err(DecodedImageError::InvalidMaxDimension) - )); - Ok(()) + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::InvalidMaxDimension), + ); } #[test] @@ -530,11 +532,198 @@ mod tests { data: svg_data.to_vec(), }; - let result = decoded_image.prepared_for_inference(1024); + let error = decoded_image.prepared_for_inference(1024).err().unwrap(); - assert!(matches!( - result, - Err(DecodedImageError::ConversionFailed { .. }) - )); + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_format_without_reading_support() { + let dds_header: Vec = b"DDS \x00\x00\x00\x00".to_vec(); + let decoded_image = DecodedImage { data: dds_header }; + + let error = decoded_image.prepared_for_inference(1024).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::UnsupportedFormat { + format: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_unrecognized_format_bytes() { + let unrecognized_bytes: Vec = b"this is plain text and not any image format".to_vec(); + let decoded_image = DecodedImage { + data: unrecognized_bytes, + }; + + let error = decoded_image.prepared_for_inference(1024).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_corrupt_png_body() { + let mut corrupt_png: Vec = vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]; + corrupt_png.extend_from_slice(b"not a real PNG chunk stream"); + let decoded_image = DecodedImage { data: corrupt_png }; + + let error = decoded_image.prepared_for_inference(1024).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_compute_target_dimension_rounds_up_within_range() { + let target = compute_target_dimension(49.2, 1.0).unwrap(); + + assert_eq!(target, 50); + } + + #[test] + fn test_compute_target_dimension_rejects_below_one() { + let error = compute_target_dimension(0.0, 1.0).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_compute_target_dimension_rejects_non_finite() { + let error = compute_target_dimension(f64::INFINITY, 1.0).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_compute_target_dimension_rejects_above_u32_max() { + let above_u32_max = f64::from(u32::MAX) + 1.0; + + let error = compute_target_dimension(above_u32_max, 1.0).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_svg_whose_scaled_width_exceeds_u32_max() { + let svg_data = + br#""#; + let decoded_image = DecodedImage { + data: svg_data.to_vec(), + }; + + let error = decoded_image + .prepared_for_inference(u32::MAX) + .err() + .unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_svg_whose_scaled_height_exceeds_u32_max() { + let svg_data = + br#""#; + let decoded_image = DecodedImage { + data: svg_data.to_vec(), + }; + + let error = decoded_image + .prepared_for_inference(u32::MAX) + .err() + .unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_svg_whose_target_pixmap_overflows() { + let svg_data = br#""#; + let decoded_image = DecodedImage { + data: svg_data.to_vec(), + }; + + let error = decoded_image + .prepared_for_inference(600_000_000) + .err() + .unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_rejects_float_pixel_format_when_reencoding_to_png() { + let openexr_data = create_test_openexr(4, 4); + let decoded_image = DecodedImage { data: openexr_data }; + + let error = decoded_image.prepared_for_inference(1024).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ConversionFailed { + message: String::new(), + }), + ); + } + + #[test] + fn test_prepared_fails_when_resized_dimension_exceeds_jpeg_limit() { + let png_data = create_test_png(70_001, 2); + let decoded_image = DecodedImage { data: png_data }; + + let error = decoded_image.prepared_for_inference(70_000).err().unwrap(); + + assert_eq!( + discriminant(&error), + discriminant(&DecodedImageError::ResizeFailed { + message: String::new(), + }), + ); } } diff --git a/paddler/src/decoded_image_error.rs b/paddler_agent/src/decoded_image_error.rs similarity index 100% rename from paddler/src/decoded_image_error.rs rename to paddler_agent/src/decoded_image_error.rs diff --git a/paddler/src/desired_model_resolution.rs b/paddler_agent/src/desired_model_resolution.rs similarity index 100% rename from paddler/src/desired_model_resolution.rs rename to paddler_agent/src/desired_model_resolution.rs diff --git a/paddler/src/dispenses_slots.rs b/paddler_agent/src/dispenses_slots.rs similarity index 100% rename from paddler/src/dispenses_slots.rs rename to paddler_agent/src/dispenses_slots.rs diff --git a/paddler/src/agent/drain_in_flight_requests.rs b/paddler_agent/src/drain_in_flight_requests.rs similarity index 86% rename from paddler/src/agent/drain_in_flight_requests.rs rename to paddler_agent/src/drain_in_flight_requests.rs index bd7aa5f2..649fa8ae 100644 --- a/paddler/src/agent/drain_in_flight_requests.rs +++ b/paddler_agent/src/drain_in_flight_requests.rs @@ -6,7 +6,7 @@ use log::info; use tokio_util::sync::CancellationToken; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; -use crate::subscribes_to_updates::SubscribesToUpdates as _; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates as _; pub async fn drain_in_flight_requests( slot_aggregated_status_manager: &Arc, @@ -53,17 +53,17 @@ mod tests { } #[tokio::test] - async fn returns_immediately_when_no_slots_processing() -> anyhow::Result<()> { + async fn returns_immediately_when_no_slots_processing() { let slot_aggregated_status_manager = create_status_manager(4); let shutdown = CancellationToken::new(); - drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown).await?; - - Ok(()) + drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown) + .await + .unwrap(); } #[tokio::test] - async fn waits_for_processing_slots_to_reach_zero() -> anyhow::Result<()> { + async fn waits_for_processing_slots_to_reach_zero() { let slot_aggregated_status_manager = create_status_manager(4); let shutdown = CancellationToken::new(); @@ -79,7 +79,9 @@ mod tests { status.release_slot(); }); - drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown).await?; + drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown) + .await + .unwrap(); assert_eq!( slot_aggregated_status_manager @@ -88,13 +90,11 @@ mod tests { 0 ); - release_handle.await?; - - Ok(()) + release_handle.await.unwrap(); } #[tokio::test] - async fn aborts_on_shutdown_signal() -> anyhow::Result<()> { + async fn aborts_on_shutdown_signal() { let slot_aggregated_status_manager = create_status_manager(4); let shutdown = CancellationToken::new(); @@ -108,7 +108,9 @@ mod tests { shutdown_trigger.cancel(); }); - drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown).await?; + drain_in_flight_requests(&slot_aggregated_status_manager, &shutdown) + .await + .unwrap(); assert_eq!( slot_aggregated_status_manager @@ -117,8 +119,6 @@ mod tests { 1, ); - shutdown_handle.await?; - - Ok(()) + shutdown_handle.await.unwrap(); } } diff --git a/paddler/src/embedding_input_tokenized.rs b/paddler_agent/src/embedding_input_tokenized.rs similarity index 100% rename from paddler/src/embedding_input_tokenized.rs rename to paddler_agent/src/embedding_input_tokenized.rs diff --git a/paddler/src/agent/from_request_params.rs b/paddler_agent/src/from_request_params.rs similarity index 86% rename from paddler/src/agent/from_request_params.rs rename to paddler_agent/src/from_request_params.rs index db9371d2..6af18eba 100644 --- a/paddler/src/agent/from_request_params.rs +++ b/paddler_agent/src/from_request_params.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use tokio::sync::mpsc; -use crate::agent::jsonrpc::response::Response; use crate::slot_aggregated_status::SlotAggregatedStatus; +use paddler_messaging::management_socket::agent::response::Response; pub trait FromRequestParams: Send + Sync { type RequestParams; diff --git a/paddler/src/agent/generate_embedding_batch_request.rs b/paddler_agent/src/generate_embedding_batch_request.rs similarity index 80% rename from paddler/src/agent/generate_embedding_batch_request.rs rename to paddler_agent/src/generate_embedding_batch_request.rs index f7139022..660dd495 100644 --- a/paddler/src/agent/generate_embedding_batch_request.rs +++ b/paddler_agent/src/generate_embedding_batch_request.rs @@ -1,12 +1,12 @@ use std::sync::Arc; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_messaging::embedding_result::EmbeddingResult; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; use tokio::sync::mpsc; -use crate::agent::from_request_params::FromRequestParams; -use crate::agent::slot_guard::SlotGuard; +use crate::from_request_params::FromRequestParams; use crate::slot_aggregated_status::SlotAggregatedStatus; +use crate::slot_guard::SlotGuard; pub struct GenerateEmbeddingBatchRequest { pub generate_embedding_stop_rx: mpsc::UnboundedReceiver<()>, diff --git a/paddler/src/agent/grammar_sampler.rs b/paddler_agent/src/grammar_sampler.rs similarity index 85% rename from paddler/src/agent/grammar_sampler.rs rename to paddler_agent/src/grammar_sampler.rs index 1df52018..e0d0a7cb 100644 --- a/paddler/src/agent/grammar_sampler.rs +++ b/paddler_agent/src/grammar_sampler.rs @@ -2,9 +2,9 @@ use anyhow::Result; use anyhow::anyhow; use llama_cpp_bindings::model::LlamaModel; use llama_cpp_bindings::sampling::LlamaSampler; -use paddler_types::grammar_constraint::GrammarConstraint; +use paddler_messaging::grammar_constraint::GrammarConstraint; -use crate::agent::resolve_grammar_to_gbnf::resolve_grammar_to_gbnf; +use crate::resolve_grammar_to_gbnf::resolve_grammar_to_gbnf; pub struct GrammarSampler { grammar_string: String, diff --git a/paddler/src/agent/mod.rs b/paddler_agent/src/lib.rs similarity index 55% rename from paddler/src/agent/mod.rs rename to paddler_agent/src/lib.rs index 44bcaa90..9d8ca6ca 100644 --- a/paddler/src/agent/mod.rs +++ b/paddler_agent/src/lib.rs @@ -1,3 +1,10 @@ +pub mod agent_applicable_state; +pub mod agent_applicable_state_holder; +pub mod agent_desired_state_converter; +pub mod agent_issue_fix; +pub mod agent_kv_cache_dtype; +pub mod agent_pooling_type; +pub mod chat_template_renderer; pub mod continue_from_conversation_history_request; pub mod continue_from_raw_prompt_request; pub mod continuous_batch_active_request; @@ -6,27 +13,47 @@ pub mod continuous_batch_arbiter_build_outcome; pub mod continuous_batch_arbiter_handle; pub mod continuous_batch_embedding_processor; pub mod continuous_batch_request_phase; +pub mod continuous_batch_request_state; pub mod continuous_batch_scheduler; pub mod continuous_batch_scheduler_command; pub mod continuous_batch_scheduler_context; +pub mod converts_to_llama_kv_cache_dtype; +pub mod converts_to_llama_pooling_type; +pub mod decoded_image; +pub mod decoded_image_error; +pub mod desired_model_resolution; +pub mod dispenses_slots; pub mod drain_in_flight_requests; +pub mod embedding_input_tokenized; mod from_request_params; pub mod generate_embedding_batch_request; pub mod grammar_sampler; -pub mod jsonrpc; pub mod llamacpp_arbiter_service; pub mod management_socket_client_service; pub mod model_metadata_holder; +pub mod model_source; +pub mod normalization; pub mod plan_embedding_batches; pub mod prepare_conversation_history_request; pub mod prepared_conversation_history_request; pub mod receive_stream_stopper_collection; mod receive_stream_stopper_drop_guard; pub mod reconciliation_service; +pub mod resolve_desired_model; pub mod resolve_grammar; pub mod resolve_grammar_to_gbnf; pub mod resolved_grammar; +pub mod resolves_model_source; pub mod sample_token_at_batch_index; pub mod sampling_outcome; pub mod sequence_id_pool; +pub mod slot_aggregated_status; +pub mod slot_aggregated_status_download_progress; +pub mod slot_aggregated_status_manager; pub mod slot_guard; +pub mod tool_call_buffer; +pub mod tool_call_event; +pub mod tool_call_pipeline; +pub mod tool_call_pipeline_error; +pub mod tool_call_validator; +pub mod validator_build_error; diff --git a/paddler/src/agent/llamacpp_arbiter_service.rs b/paddler_agent/src/llamacpp_arbiter_service.rs similarity index 68% rename from paddler/src/agent/llamacpp_arbiter_service.rs rename to paddler_agent/src/llamacpp_arbiter_service.rs index c867d966..e1b6e3bf 100644 --- a/paddler/src/agent/llamacpp_arbiter_service.rs +++ b/paddler_agent/src/llamacpp_arbiter_service.rs @@ -6,7 +6,7 @@ use async_trait::async_trait; use log::error; use log::info; use log::warn; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::MissedTickBehavior; @@ -14,17 +14,17 @@ use tokio::time::interval; use tokio_util::sync::CancellationToken; use trzcina::Service; -use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use crate::agent::continuous_batch_arbiter::ContinuousBatchArbiter; -use crate::agent::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome; -use crate::agent::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle; -use crate::agent::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; -use crate::agent::drain_in_flight_requests::drain_in_flight_requests; -use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use crate::agent::model_metadata_holder::ModelMetadataHolder; use crate::agent_applicable_state::AgentApplicableState; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; +use crate::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use crate::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use crate::continuous_batch_arbiter::ContinuousBatchArbiter; +use crate::continuous_batch_arbiter_build_outcome::ContinuousBatchArbiterBuildOutcome; +use crate::continuous_batch_arbiter_handle::ContinuousBatchArbiterHandle; +use crate::continuous_batch_scheduler_command::ContinuousBatchSchedulerCommand; +use crate::drain_in_flight_requests::drain_in_flight_requests; +use crate::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use crate::model_metadata_holder::ModelMetadataHolder; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; async fn apply_state( @@ -59,9 +59,7 @@ async fn apply_state( info!("Reconciled state change applied successfully"); } ContinuousBatchArbiterBuildOutcome::NoModelConfigured => { - warn!( - "No model configured in applicable state; skipping llama.cpp initialization" - ); + warn!("No model configured in applicable state; skipping llama.cpp initialization"); } } } @@ -250,10 +248,149 @@ impl Service for LlamaCppArbiterService { #[cfg(test)] mod tests { + use std::mem::discriminant; + use std::sync::mpsc::channel as std_channel; + use std::thread; + use anyhow::bail; use super::*; + fn spawn_arbiter_handle_with_live_receiver() -> ( + ContinuousBatchArbiterHandle, + std::sync::mpsc::Receiver, + ) { + let (command_tx, command_rx) = std_channel(); + let scheduler_thread_handle = thread::spawn(|| Ok(())); + + ( + ContinuousBatchArbiterHandle { + command_tx, + scheduler_thread_handle, + }, + command_rx, + ) + } + + #[test] + fn forward_command_delivers_command_when_handle_present() { + let (arbiter_handle, command_rx) = spawn_arbiter_handle_with_live_receiver(); + + forward_command( + Some(&arbiter_handle), + ContinuousBatchSchedulerCommand::Shutdown, + ); + + let delivered = command_rx.recv().unwrap(); + + assert_eq!( + discriminant(&delivered), + discriminant(&ContinuousBatchSchedulerCommand::Shutdown), + ); + } + + #[test] + fn forward_command_logs_error_when_receiver_dropped() { + let (arbiter_handle, command_rx) = spawn_arbiter_handle_with_live_receiver(); + + drop(command_rx); + + forward_command( + Some(&arbiter_handle), + ContinuousBatchSchedulerCommand::Shutdown, + ); + } + + #[test] + fn forward_command_logs_error_when_handle_absent() { + forward_command(None, ContinuousBatchSchedulerCommand::Shutdown); + } + + #[tokio::test] + async fn wait_for_in_flight_requests_drains_when_handle_present() { + let (arbiter_handle, _command_rx) = spawn_arbiter_handle_with_live_receiver(); + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(1)); + let shutdown = CancellationToken::new(); + + wait_for_in_flight_requests_to_finish( + &shutdown, + Some(&arbiter_handle), + &slot_aggregated_status_manager, + ) + .await + .unwrap(); + + arbiter_handle.shutdown().unwrap(); + } + + #[tokio::test] + async fn wait_for_in_flight_requests_returns_immediately_without_handle() { + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(1)); + let shutdown = CancellationToken::new(); + + wait_for_in_flight_requests_to_finish(&shutdown, None, &slot_aggregated_status_manager) + .await + .unwrap(); + } + + #[tokio::test] + async fn apply_state_without_model_marks_status_applied() { + let model_metadata_holder = Arc::new(ModelMetadataHolder::default()); + let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(1)); + let shutdown = CancellationToken::new(); + let mut continuous_batch_arbiter_handle: Option = None; + + apply_state( + &shutdown, + None, + None, + 1, + &model_metadata_holder, + &slot_aggregated_status_manager, + &mut continuous_batch_arbiter_handle, + ) + .await + .unwrap(); + + assert_eq!( + slot_aggregated_status_manager + .slot_aggregated_status + .get_state_application_status() + .unwrap(), + AgentStateApplicationStatus::Applied, + ); + } + + #[tokio::test] + async fn shutdown_arbiter_handle_returns_ok_when_handle_absent() { + let mut continuous_batch_arbiter_handle: Option = None; + + shutdown_arbiter_handle(&mut continuous_batch_arbiter_handle) + .await + .unwrap(); + + assert!(continuous_batch_arbiter_handle.is_none()); + } + + #[tokio::test] + async fn shutdown_arbiter_handle_joins_and_clears_present_handle() { + let (arbiter_handle, command_rx) = spawn_arbiter_handle_with_live_receiver(); + let mut continuous_batch_arbiter_handle = Some(arbiter_handle); + + shutdown_arbiter_handle(&mut continuous_batch_arbiter_handle) + .await + .unwrap(); + + assert!(continuous_batch_arbiter_handle.is_none()); + + let delivered = command_rx.recv().unwrap(); + + assert_eq!( + discriminant(&delivered), + discriminant(&ContinuousBatchSchedulerCommand::Shutdown), + ); + } + #[tokio::test(flavor = "multi_thread")] async fn does_not_exit_when_request_channels_close_without_shutdown() -> Result<()> { let observation_window = Duration::from_millis(500); @@ -284,8 +421,7 @@ mod tests { let shutdown = CancellationToken::new(); let task_token = shutdown.clone(); - let mut join_handle = - tokio::spawn(async move { Box::new(service).run(task_token).await }); + let mut join_handle = tokio::spawn(async move { Box::new(service).run(task_token).await }); drop(continue_from_conversation_history_request_tx); drop(continue_from_raw_prompt_request_tx); diff --git a/paddler_agent/src/management_socket_client_service.rs b/paddler_agent/src/management_socket_client_service.rs new file mode 100644 index 00000000..0171ced1 --- /dev/null +++ b/paddler_agent/src/management_socket_client_service.rs @@ -0,0 +1,1188 @@ +use std::sync::Arc; + +use anyhow::Context; +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::SinkExt as _; +use futures_util::StreamExt; +use log::debug; +use log::error; +use log::info; +use log::warn; +use tokio::sync::mpsc; +use tokio::time::Duration; +use tokio::time::MissedTickBehavior; +use tokio::time::interval; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::protocol::Message; +use tokio_util::sync::CancellationToken; +use trzcina::Service; + +use paddler_messaging::agent_desired_state::AgentDesiredState; +use paddler_messaging::jsonrpc::error::Error as JsonRpcError; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; +use paddler_messaging::jsonrpc::request_envelope::RequestEnvelope; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + +use crate::agent_applicable_state_holder::AgentApplicableStateHolder; +use crate::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use crate::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use crate::from_request_params::FromRequestParams; +use crate::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use crate::model_metadata_holder::ModelMetadataHolder; +use crate::receive_stream_stopper_collection::ReceiveStreamStopperCollection; +use crate::slot_aggregated_status::SlotAggregatedStatus; +use paddler_messaging::management_socket::agent::message::Message as JsonRpcMessage; +use paddler_messaging::management_socket::agent::notification::Notification as JsonRpcNotification; +use paddler_messaging::management_socket::agent::request::Request as JsonRpcRequest; +use paddler_messaging::management_socket::agent::response::Response as JsonRpcResponse; +use paddler_messaging::management_socket::agent::notification_params::version_params::VersionParams; +use paddler_messaging::management_socket::balancer::message::Message as ManagementJsonRpcMessage; +use paddler_messaging::management_socket::balancer::notification::Notification as ManagementJsonRpcNotification; +use paddler_messaging::management_socket::balancer::notification_params::register_agent_params::RegisterAgentParams; +use paddler_messaging::management_socket::balancer::notification_params::update_agent_status_params::UpdateAgentStatusParams; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates as _; + +struct IncomingMessageContext { + agent_applicable_state_holder: Arc, + agent_desired_state_tx: mpsc::UnboundedSender, + connection_close: CancellationToken, + continue_from_conversation_history_request_tx: + mpsc::UnboundedSender, + continue_from_raw_prompt_request_tx: mpsc::UnboundedSender, + generate_embedding_batch_request_tx: mpsc::UnboundedSender, + model_metadata_holder: Arc, + receive_stream_stopper_collection: Arc, + message_tx: mpsc::UnboundedSender, + slot_aggregated_status: Arc, +} + +pub struct ManagementSocketClientService { + pub agent_applicable_state_holder: Arc, + pub agent_desired_state_tx: mpsc::UnboundedSender, + pub continue_from_conversation_history_request_tx: + mpsc::UnboundedSender, + pub continue_from_raw_prompt_request_tx: mpsc::UnboundedSender, + pub generate_embedding_batch_request_tx: mpsc::UnboundedSender, + pub model_metadata_holder: Arc, + pub name: Option, + pub receive_stream_stopper_collection: Arc, + pub slot_aggregated_status: Arc, + pub socket_url: String, +} + +impl ManagementSocketClientService { + async fn generate_responses( + connection_close: CancellationToken, + id: String, + message_tx: mpsc::UnboundedSender, + request_params: TRequest::RequestParams, + receive_stream_stopper_collection: Arc, + request_tx: mpsc::UnboundedSender, + slot_aggregated_status: Arc, + ) -> Result<()> { + let (response_tx, mut response_rx) = mpsc::unbounded_channel::(); + let (stop_tx, stop_rx) = mpsc::unbounded_channel::<()>(); + + let _guard = receive_stream_stopper_collection + .register_stopper_with_guard(id.clone(), stop_tx) + .context(format!("Failed to register stopper for request: {id}"))?; + + request_tx.send(TRequest::from_request_params( + request_params, + response_tx, + stop_rx, + slot_aggregated_status, + ))?; + + loop { + tokio::select! { + () = connection_close.cancelled() => break, + response = response_rx.recv() => { + match response { + Some(response) => { + message_tx.send( + ManagementJsonRpcMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: id.clone(), + response: response.into(), + } + ), + )?; + } + None => break, + } + } + } + } + + Ok(()) + } + + async fn handle_deserialized_message( + IncomingMessageContext { + agent_applicable_state_holder, + agent_desired_state_tx, + connection_close, + continue_from_conversation_history_request_tx, + continue_from_raw_prompt_request_tx, + generate_embedding_batch_request_tx, + message_tx, + model_metadata_holder, + receive_stream_stopper_collection, + slot_aggregated_status, + }: IncomingMessageContext, + deserialized_message: JsonRpcMessage, + ) -> Result<()> { + match deserialized_message { + JsonRpcMessage::Error(ErrorEnvelope { + request_id, + error: JsonRpcError { code, description }, + }) => { + error!( + "Received error from server: code: {code}, description: {description:?}, request_id: {request_id:?}" + ); + + Ok(()) + } + JsonRpcMessage::Notification(JsonRpcNotification::SetState(set_state_params)) => { + agent_desired_state_tx.send(set_state_params.desired_state)?; + + Ok(()) + } + JsonRpcMessage::Notification(JsonRpcNotification::StopRespondingTo(request_id)) => { + debug!("Received StopGeneratingTokens notification for request ID: {request_id:?}"); + receive_stream_stopper_collection + .stop(&request_id) + .context(format!( + "Failed to stop generating tokens for request ID: {request_id}" + ))?; + + Ok(()) + } + JsonRpcMessage::Notification(JsonRpcNotification::Version(VersionParams { + version, + })) => { + if version != env!("CARGO_PKG_VERSION") { + warn!( + "Version mismatch: server version is {version}, client version is {}", + env!("CARGO_PKG_VERSION") + ); + } + + Ok(()) + } + JsonRpcMessage::Request(RequestEnvelope { + id, + request: + JsonRpcRequest::ContinueFromConversationHistory( + continue_from_conversation_history_params, + ), + }) => { + Self::generate_responses( + connection_close, + id, + message_tx, + continue_from_conversation_history_params, + receive_stream_stopper_collection, + continue_from_conversation_history_request_tx, + slot_aggregated_status, + ) + .await + } + JsonRpcMessage::Request(RequestEnvelope { + id, + request: JsonRpcRequest::ContinueFromRawPrompt(generate_tokens_params), + }) => { + Self::generate_responses( + connection_close, + id, + message_tx, + generate_tokens_params, + receive_stream_stopper_collection, + continue_from_raw_prompt_request_tx, + slot_aggregated_status, + ) + .await + } + JsonRpcMessage::Request(RequestEnvelope { + id, + request: JsonRpcRequest::GenerateEmbeddingBatch(generate_embedding_batch_params), + }) => { + Self::generate_responses( + connection_close, + id, + message_tx, + generate_embedding_batch_params, + receive_stream_stopper_collection, + generate_embedding_batch_request_tx, + slot_aggregated_status, + ) + .await + } + JsonRpcMessage::Request(RequestEnvelope { + id, + request: JsonRpcRequest::GetChatTemplateOverride, + }) => Ok( + message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: id, + response: JsonRpcResponse::ChatTemplateOverride( + if let Some(agent_applicable_state) = + agent_applicable_state_holder.get_agent_applicable_state() + { + agent_applicable_state.chat_template_override + } else { + None + }, + ), + }))?, + ), + JsonRpcMessage::Request(RequestEnvelope { + id, + request: JsonRpcRequest::GetModelMetadata, + }) => Ok( + message_tx.send(ManagementJsonRpcMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: id, + response: JsonRpcResponse::ModelMetadata( + model_metadata_holder.get_model_metadata(), + ), + }))?, + ), + } + } + + fn handle_incoming_message( + incoming_message_context: IncomingMessageContext, + msg: Message, + pong_tx: &mpsc::UnboundedSender, + ) -> Result<()> { + match msg { + Message::Text(text) => { + let connection_close = incoming_message_context.connection_close.clone(); + + tokio::spawn(async move { + tokio::select! { + () = connection_close.cancelled() => { + info!("Connection close signal received, shutting down"); + } + result = Self::handle_deserialized_message( + incoming_message_context, + match serde_json::from_str::(&text).context(format!("Failed to parse JSON-RPC message: {text}")) { + Ok(message) => message, + Err(err) => { + error!("Failed to deserialize message: {err}"); + + return; + } + }, + ) => if let Err(err) = result { + error!("Error handling incoming message: {err}"); + } + } + }); + + Ok(()) + } + Message::Binary(_) => { + error!("Received binary message, which is not expected"); + + Ok(()) + } + Message::Close(_) => { + info!("Connection closed by server"); + + Ok(()) + } + Message::Frame(_) => { + error!("Received a frame message, which is not expected"); + + Ok(()) + } + Message::Ping(payload) => Ok(pong_tx.send(payload)?), + Message::Pong(_) => { + // Pong received, no action needed + Ok(()) + } + } + } + + async fn keep_connection_alive(&self, shutdown: CancellationToken) -> Result<()> { + info!("Connecting to management server at {}", self.socket_url); + + let (ws_stream, _response) = connect_async(self.socket_url.clone()).await?; + + info!("Connected to management server"); + + let connection_close = CancellationToken::new(); + let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (mut write, mut read) = ws_stream.split(); + + let forward_connection_close = connection_close.clone(); + let forward_shutdown = shutdown.clone(); + + let message_forward_handle = tokio::spawn(async move { + loop { + tokio::select! { + () = forward_connection_close.cancelled() => { + break; + } + () = forward_shutdown.cancelled() => { + info!("Shutdown signal received, deregistering agent"); + + write.send(Message::Text(match serde_json::to_string( + &ManagementJsonRpcMessage::Notification( + ManagementJsonRpcNotification::DeregisterAgent, + ) + ) { + Ok(serialized_message) => serialized_message.into(), + Err(err) => { + error!("Failed to serialize deregister agent notification: {err}"); + return; + } + })).await.unwrap_or_else(|err| { + error!("Failed to send deregister agent notification: {err}"); + }); + + break; + } + message = message_rx.recv() => { + match message { + Some(msg) => { + match serde_json::to_string(&msg) { + Ok(serialized_message) => { + let message = Message::Text(serialized_message.into()); + + if let Err(err) = write.send(message).await { + error!("Failed to send message: {err}"); + break; + } + }, + Err(err) => { + error!("Failed to serialize message: {err}"); + } + } + } + None => break, + } + } + payload = pong_rx.recv() => { + match payload { + Some(payload) => { + write.send(Message::Pong(payload)).await.unwrap_or_else(|err| { + error!("Failed to send pong message: {err}"); + }); + } + None => break, + } + } + } + } + }); + + match self.slot_aggregated_status.make_snapshot() { + Ok(slot_aggregated_status_snapshot) => { + message_tx + .send(ManagementJsonRpcMessage::Notification( + ManagementJsonRpcNotification::RegisterAgent(RegisterAgentParams { + name: self.name.clone(), + slot_aggregated_status_snapshot, + }), + )) + .unwrap_or_else(|err| { + error!("Failed to send register agent notification: {err}"); + }); + } + Err(err) => { + error!("Failed to create slot aggregated status snapshot: {err}"); + + return Err(err); + } + } + + let do_send_status_update = || match self.slot_aggregated_status.make_snapshot() { + Ok(slot_aggregated_status_snapshot) => { + message_tx + .send(ManagementJsonRpcMessage::Notification( + ManagementJsonRpcNotification::UpdateAgentStatus(UpdateAgentStatusParams { + slot_aggregated_status_snapshot, + }), + )) + .unwrap_or_else(|err| { + error!("Failed to send status update notification: {err}"); + }); + } + Err(err) => error!("Failed to create slot aggregated status snapshot: {err}"), + }; + + let mut ticker = interval(Duration::from_secs(1)); + let mut update_rx = self.slot_aggregated_status.subscribe_to_updates(); + + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + () = connection_close.cancelled() => { + info!("Connection close signal received, shutting down"); + + break; + } + () = shutdown.cancelled() => break, + changed = update_rx.changed() => { + if changed.is_err() { + break; + } + do_send_status_update(); + } + _ = ticker.tick() => do_send_status_update(), + msg = read.next() => { + let should_close = match msg { + Some(Ok(msg)) => { + if let Err(err) = Self::handle_incoming_message( + IncomingMessageContext { + agent_applicable_state_holder: self.agent_applicable_state_holder.clone(), + agent_desired_state_tx: self.agent_desired_state_tx.clone(), + connection_close: connection_close.clone(), + continue_from_conversation_history_request_tx: self.continue_from_conversation_history_request_tx.clone(), + continue_from_raw_prompt_request_tx: self.continue_from_raw_prompt_request_tx.clone(), + generate_embedding_batch_request_tx: self.generate_embedding_batch_request_tx.clone(), + model_metadata_holder: self.model_metadata_holder.clone(), + receive_stream_stopper_collection: self.receive_stream_stopper_collection.clone(), + message_tx: message_tx.clone(), + slot_aggregated_status: self.slot_aggregated_status.clone(), + }, + msg, + &pong_tx, + ) + .context("Failed to handle incoming message") + { + error!("Error handling incoming message: {err}"); + } + + false + } + Some(Err(err)) => { + error!("Error reading message: {err}"); + + true + } + None => true, + }; + + if should_close { + connection_close.cancel(); + + break; + } + } + } + } + + message_forward_handle + .await + .context("Failed to join message forwarding task")?; + + Ok(()) + } +} + +#[async_trait] +impl Service for ManagementSocketClientService { + fn name(&self) -> &'static str { + "agent::management_socket_client_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let mut ticker = interval(Duration::from_secs(1)); + + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + () = shutdown.cancelled() => break Ok(()), + _ = ticker.tick() => { + match self.keep_connection_alive(shutdown.clone()).await { + Err(err) => { + error!("Failed to keep the connection alive: {err:?}"); + } + Ok(()) => { + info!("Gracefully closed connection to management server"); + } + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use tokio_tungstenite::tungstenite::protocol::frame::Frame; + use tokio_tungstenite::tungstenite::protocol::frame::coding::Data; + use tokio_tungstenite::tungstenite::protocol::frame::coding::OpCode; + + use paddler_messaging::management_socket::agent::notification_params::set_state_params::SetStateParams; + use paddler_messaging::model_metadata::ModelMetadata; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + use super::*; + + fn build_incoming_message_context( + agent_applicable_state_holder: Arc, + agent_desired_state_tx: mpsc::UnboundedSender, + connection_close: CancellationToken, + model_metadata_holder: Arc, + receive_stream_stopper_collection: Arc, + message_tx: mpsc::UnboundedSender, + slot_aggregated_status: Arc, + ) -> IncomingMessageContext { + let (continue_from_conversation_history_request_tx, _continue_history_rx) = + mpsc::unbounded_channel::(); + let (continue_from_raw_prompt_request_tx, _continue_raw_rx) = + mpsc::unbounded_channel::(); + let (generate_embedding_batch_request_tx, _embedding_rx) = + mpsc::unbounded_channel::(); + + IncomingMessageContext { + agent_applicable_state_holder, + agent_desired_state_tx, + connection_close, + continue_from_conversation_history_request_tx, + continue_from_raw_prompt_request_tx, + generate_embedding_batch_request_tx, + model_metadata_holder, + receive_stream_stopper_collection, + message_tx, + slot_aggregated_status, + } + } + + #[tokio::test] + async fn error_message_is_acknowledged_without_side_effects() { + let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, mut agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Error(ErrorEnvelope { + request_id: "req_error".to_owned(), + error: JsonRpcError { + code: -32_600, + description: "Invalid Request".to_owned(), + }, + }), + ) + .await; + + assert!(result.is_ok()); + assert!(message_rx.try_recv().is_err()); + assert!(agent_desired_state_rx.try_recv().is_err()); + } + + #[tokio::test] + async fn set_state_notification_forwards_desired_state() { + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, mut agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Notification(JsonRpcNotification::SetState(Box::new(SetStateParams { + desired_state: AgentDesiredState::default(), + }))), + ) + .await; + + assert!(result.is_ok()); + assert_eq!( + agent_desired_state_rx.try_recv().unwrap(), + AgentDesiredState::default() + ); + } + + #[tokio::test] + async fn set_state_notification_errors_when_receiver_dropped() { + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, agent_desired_state_rx) = + mpsc::unbounded_channel::(); + + drop(agent_desired_state_rx); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Notification(JsonRpcNotification::SetState(Box::new(SetStateParams { + desired_state: AgentDesiredState::default(), + }))), + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn stop_responding_to_unknown_request_returns_error() { + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Notification(JsonRpcNotification::StopRespondingTo( + "missing_request".to_owned(), + )), + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn stop_responding_to_registered_request_signals_stopper() { + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (stop_tx, mut stop_rx) = mpsc::unbounded_channel::<()>(); + + receive_stream_stopper_collection + .register_stopper("active_request".to_owned(), stop_tx) + .unwrap(); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + receive_stream_stopper_collection, + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Notification(JsonRpcNotification::StopRespondingTo( + "active_request".to_owned(), + )), + ) + .await; + + assert!(result.is_ok()); + assert!(stop_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn mismatched_version_notification_is_acknowledged() { + let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Notification(JsonRpcNotification::Version(VersionParams { + version: "0.0.0-mismatch".to_owned(), + })), + ) + .await; + + assert!(result.is_ok()); + assert!(message_rx.try_recv().is_err()); + } + + #[tokio::test] + async fn get_chat_template_override_without_applicable_state_responds_with_none() { + let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Request(RequestEnvelope { + id: "req_template".to_owned(), + request: JsonRpcRequest::GetChatTemplateOverride, + }), + ) + .await; + + assert!(result.is_ok()); + + let sent_message = message_rx.try_recv().unwrap(); + + assert!(matches!( + sent_message, + ManagementJsonRpcMessage::Response(ResponseEnvelope { + request_id, + response: JsonRpcResponse::ChatTemplateOverride(None), + .. + }) if request_id == "req_template" + )); + } + + #[tokio::test] + async fn get_chat_template_override_errors_when_message_receiver_dropped() { + let (message_tx, message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + + drop(message_rx); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Request(RequestEnvelope { + id: "req_template".to_owned(), + request: JsonRpcRequest::GetChatTemplateOverride, + }), + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn get_model_metadata_responds_with_stored_metadata() { + let (message_tx, mut message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let model_metadata_holder = Arc::new(ModelMetadataHolder::new()); + let mut metadata = BTreeMap::new(); + + metadata.insert("architecture".to_owned(), "llama".to_owned()); + model_metadata_holder.set_model_metadata(ModelMetadata { + metadata: metadata.clone(), + }); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + model_metadata_holder, + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Request(RequestEnvelope { + id: "req_metadata".to_owned(), + request: JsonRpcRequest::GetModelMetadata, + }), + ) + .await; + + assert!(result.is_ok()); + + let sent_message = message_rx.try_recv().unwrap(); + + assert!(matches!( + sent_message, + ManagementJsonRpcMessage::Response(ResponseEnvelope { + response: JsonRpcResponse::ModelMetadata(Some(returned_metadata)), + .. + }) if returned_metadata.metadata == metadata + )); + } + + #[tokio::test] + async fn get_model_metadata_errors_when_message_receiver_dropped() { + let (message_tx, message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + + drop(message_rx); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_deserialized_message( + context, + JsonRpcMessage::Request(RequestEnvelope { + id: "req_metadata".to_owned(), + request: JsonRpcRequest::GetModelMetadata, + }), + ) + .await; + + assert!(result.is_err()); + } + + #[test] + fn binary_message_is_acknowledged_without_pong() { + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Binary(Bytes::from_static(b"unexpected")), + &pong_tx, + ); + + assert!(result.is_ok()); + assert!(pong_rx.try_recv().is_err()); + } + + #[test] + fn close_message_is_acknowledged() { + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Close(None), + &pong_tx, + ); + + assert!(result.is_ok()); + assert!(pong_rx.try_recv().is_err()); + } + + #[test] + fn frame_message_is_acknowledged_without_pong() { + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Frame(Frame::message( + Bytes::from_static(b"frame"), + OpCode::Data(Data::Text), + true, + )), + &pong_tx, + ); + + assert!(result.is_ok()); + assert!(pong_rx.try_recv().is_err()); + } + + #[test] + fn ping_message_forwards_payload_to_pong_channel() { + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Ping(Bytes::from_static(b"ping_payload")), + &pong_tx, + ); + + assert!(result.is_ok()); + assert_eq!( + pong_rx.try_recv().unwrap(), + Bytes::from_static(b"ping_payload") + ); + } + + #[test] + fn ping_message_errors_when_pong_receiver_dropped() { + let (pong_tx, pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + + drop(pong_rx); + + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Ping(Bytes::from_static(b"ping_payload")), + &pong_tx, + ); + + assert!(result.is_err()); + } + + #[test] + fn pong_message_is_acknowledged_without_forwarding() { + let (pong_tx, mut pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, _agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Pong(Bytes::from_static(b"pong_payload")), + &pong_tx, + ); + + assert!(result.is_ok()); + assert!(pong_rx.try_recv().is_err()); + } + + #[tokio::test] + async fn text_message_dispatches_deserialized_set_state() { + let (pong_tx, _pong_rx) = mpsc::unbounded_channel::(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let (agent_desired_state_tx, mut agent_desired_state_rx) = + mpsc::unbounded_channel::(); + let context = build_incoming_message_context( + Arc::new(AgentApplicableStateHolder::default()), + agent_desired_state_tx, + CancellationToken::new(), + Arc::new(ModelMetadataHolder::new()), + Arc::new(ReceiveStreamStopperCollection::default()), + message_tx, + Arc::new(SlotAggregatedStatus::new(2)), + ); + + let serialized_set_state = serde_json::to_string(&JsonRpcMessage::Notification( + JsonRpcNotification::SetState(Box::new(SetStateParams { + desired_state: AgentDesiredState::default(), + })), + )) + .unwrap(); + + let result = ManagementSocketClientService::handle_incoming_message( + context, + Message::Text(serialized_set_state.into()), + &pong_tx, + ); + + assert!(result.is_ok()); + assert_eq!( + agent_desired_state_rx.recv().await.unwrap(), + AgentDesiredState::default() + ); + } + + #[tokio::test] + async fn generate_responses_breaks_when_connection_closes() { + let connection_close = CancellationToken::new(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (request_tx, mut request_rx) = + mpsc::unbounded_channel::(); + let slot_aggregated_status = Arc::new(SlotAggregatedStatus::new(2)); + + connection_close.cancel(); + + let result = + ManagementSocketClientService::generate_responses::( + connection_close, + "req_generate".to_owned(), + message_tx, + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 8, + raw_prompt: "hello".to_owned(), + }, + receive_stream_stopper_collection.clone(), + request_tx, + slot_aggregated_status, + ) + .await; + + assert!(result.is_ok()); + + let dispatched_request = request_rx.try_recv().unwrap(); + + assert_eq!(dispatched_request.params.raw_prompt, "hello"); + assert!( + receive_stream_stopper_collection + .deregister_stopper("req_generate") + .is_err() + ); + } + + #[tokio::test] + async fn generate_responses_errors_when_request_receiver_dropped() { + let connection_close = CancellationToken::new(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (request_tx, request_rx) = mpsc::unbounded_channel::(); + let slot_aggregated_status = Arc::new(SlotAggregatedStatus::new(2)); + + drop(request_rx); + + let result = + ManagementSocketClientService::generate_responses::( + connection_close, + "req_generate".to_owned(), + message_tx, + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 8, + raw_prompt: "hello".to_owned(), + }, + receive_stream_stopper_collection, + request_tx, + slot_aggregated_status, + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn generate_responses_errors_when_stopper_already_registered() { + let connection_close = CancellationToken::new(); + let (message_tx, _message_rx) = mpsc::unbounded_channel::(); + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (existing_stop_tx, _existing_stop_rx) = mpsc::unbounded_channel::<()>(); + let (request_tx, _request_rx) = mpsc::unbounded_channel::(); + let slot_aggregated_status = Arc::new(SlotAggregatedStatus::new(2)); + + receive_stream_stopper_collection + .register_stopper("req_generate".to_owned(), existing_stop_tx) + .unwrap(); + + let result = + ManagementSocketClientService::generate_responses::( + connection_close, + "req_generate".to_owned(), + message_tx, + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 8, + raw_prompt: "hello".to_owned(), + }, + receive_stream_stopper_collection, + request_tx, + slot_aggregated_status, + ) + .await; + + assert!(result.is_err()); + } +} diff --git a/paddler_agent/src/model_metadata_holder.rs b/paddler_agent/src/model_metadata_holder.rs new file mode 100644 index 00000000..a72cc52d --- /dev/null +++ b/paddler_agent/src/model_metadata_holder.rs @@ -0,0 +1,61 @@ +use parking_lot::RwLock; + +use paddler_messaging::model_metadata::ModelMetadata; + +pub struct ModelMetadataHolder { + model_metadata: RwLock>, +} + +impl ModelMetadataHolder { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn set_model_metadata(&self, metadata: ModelMetadata) { + let mut lock = self.model_metadata.write(); + + *lock = Some(metadata); + } + + pub fn get_model_metadata(&self) -> Option { + let lock = self.model_metadata.read(); + + lock.clone() + } +} + +impl Default for ModelMetadataHolder { + fn default() -> Self { + Self { + model_metadata: RwLock::new(None), + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use super::*; + + #[test] + fn new_holder_starts_empty() { + let holder = ModelMetadataHolder::new(); + + assert!(holder.get_model_metadata().is_none()); + } + + #[test] + fn stored_metadata_is_returned() { + let holder = ModelMetadataHolder::new(); + let mut metadata = BTreeMap::new(); + metadata.insert("architecture".to_owned(), "llama".to_owned()); + + holder.set_model_metadata(ModelMetadata { metadata }); + + let stored = holder.get_model_metadata().unwrap(); + + assert_eq!(stored.metadata.get("architecture").unwrap(), "llama"); + } +} diff --git a/paddler/src/model_source/huggingface.rs b/paddler_agent/src/model_source/huggingface.rs similarity index 90% rename from paddler/src/model_source/huggingface.rs rename to paddler_agent/src/model_source/huggingface.rs index 5abdcf4f..f9a8ffb5 100644 --- a/paddler/src/model_source/huggingface.rs +++ b/paddler_agent/src/model_source/huggingface.rs @@ -12,10 +12,10 @@ use log::warn; use tokio::time::Duration; use tokio::time::sleep; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::HuggingFaceDownloadLock; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::hugging_face_download_lock::HuggingFaceDownloadLock; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::agent_issue_fix::AgentIssueFix; use crate::desired_model_resolution::DesiredModelResolution; @@ -25,17 +25,19 @@ use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloa const LOCK_RETRY_TIMEOUT: Duration = Duration::from_secs(10); +pub struct HuggingFaceModelSource(pub HuggingFaceModelReference); + #[async_trait] -impl ResolvesModelSource for HuggingFaceModelReference { +impl ResolvesModelSource for HuggingFaceModelSource { async fn resolve( &self, slot_aggregated_status: Arc, ) -> Result { - let Self { + let HuggingFaceModelReference { filename, repo_id, revision, - } = self; + } = &self.0; let model_path = format!("{repo_id}/{revision}/{filename}"); if slot_aggregated_status.has_issue(&AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { diff --git a/paddler/src/model_source/local.rs b/paddler_agent/src/model_source/local.rs similarity index 100% rename from paddler/src/model_source/local.rs rename to paddler_agent/src/model_source/local.rs diff --git a/paddler/src/model_source/mod.rs b/paddler_agent/src/model_source/mod.rs similarity index 100% rename from paddler/src/model_source/mod.rs rename to paddler_agent/src/model_source/mod.rs diff --git a/paddler/src/model_source/url.rs b/paddler_agent/src/model_source/url.rs similarity index 57% rename from paddler/src/model_source/url.rs rename to paddler_agent/src/model_source/url.rs index c5bacc07..68d234fa 100644 --- a/paddler/src/model_source/url.rs +++ b/paddler_agent/src/model_source/url.rs @@ -6,15 +6,15 @@ use anyhow::anyhow; use async_trait::async_trait; use url::Url; -use paddler_cache_dir::CacheDir; -use paddler_cache_dir::CachedDownloadedModel; -use paddler_cache_dir::DownloadLockAcquisitionError; +use paddler_cache_dir::cache_dir::CacheDir; +use paddler_cache_dir::cached_downloaded_model::CachedDownloadedModel; +use paddler_cache_dir::download_lock_acquisition_error::DownloadLockAcquisitionError; use paddler_download_manager::download_error::DownloadError; use paddler_download_manager::download_manager::DownloadManager; use paddler_download_manager::progress_sink::ProgressSink; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_messaging::url_model_reference::UrlModelReference; use crate::agent_issue_fix::AgentIssueFix; use crate::desired_model_resolution::DesiredModelResolution; @@ -124,8 +124,9 @@ async fn resolve_url_into_cache( model_path: url_string.to_owned(), })); - return Err(anyhow::Error::new(parse_error) - .context(format!("Invalid URL '{url_string}'"))); + return Err( + anyhow::Error::new(parse_error).context(format!("Invalid URL '{url_string}'")) + ); } }; @@ -214,15 +215,17 @@ async fn resolve_url_into_cache( } } +pub struct UrlModelSource(pub UrlModelReference); + #[async_trait] -impl ResolvesModelSource for UrlModelReference { +impl ResolvesModelSource for UrlModelSource { async fn resolve( &self, slot_aggregated_status: Arc, ) -> Result { let cache_dir = CacheDir::from_process_env(); - resolve_url_into_cache(&self.url, &cache_dir, slot_aggregated_status).await + resolve_url_into_cache(&self.0.url, &cache_dir, slot_aggregated_status).await } } @@ -232,22 +235,28 @@ mod tests { use std::path::PathBuf; use std::sync::Arc; - use anyhow::Context as _; - use anyhow::Result; use anyhow::anyhow; - use paddler_cache_dir::CacheDir; - use paddler_cache_dir::CachedDownloadedModel; + use paddler_cache_dir::cache_dir::CacheDir; + use paddler_cache_dir::cached_downloaded_model::CachedDownloadedModel; use paddler_download_manager::download_error::DownloadError; - use paddler_types::agent_issue::AgentIssue; + use paddler_messaging::agent_issue::AgentIssue; use reqwest::StatusCode; use tempfile::TempDir; + use tokio::io::AsyncBufReadExt as _; + use tokio::io::AsyncWriteExt as _; + use tokio::io::BufReader; + use tokio::net::TcpListener; use url::Url; use crate::desired_model_resolution::DesiredModelResolution; + use crate::model_source::url::SlotAggregatedStatusSink; use crate::model_source::url::agent_issue_for; use crate::model_source::url::classify_cache_io_error; use crate::model_source::url::resolve_url_into_cache; use crate::slot_aggregated_status::SlotAggregatedStatus; + use paddler_download_manager::progress_sink::ProgressSink; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::produces_snapshot::ProducesSnapshot; const TEST_URL: &str = "https://example.com/m.gguf"; @@ -275,30 +284,29 @@ mod tests { } #[tokio::test] - async fn cache_hit_returns_path_without_calling_download_manager() -> Result<()> { - let directory = TempDir::new()?; + async fn cache_hit_returns_path_without_calling_download_manager() { + let directory = TempDir::new().unwrap(); let cache_dir = cache_dir_at(directory.path()); let url_string = "https://host.example/cached.gguf"; - let cached = CachedDownloadedModel::new(&cache_dir, url_string)?; - cached.ensure_cache_subdir_exists().await?; - tokio::fs::write(&cached.cache_file_path, b"cached content").await?; - - let resolution = - resolve_url_into_cache(url_string, &cache_dir, fresh_status()).await?; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::write(&cached.cache_file_path, b"cached content") + .await + .unwrap(); - match resolution { - DesiredModelResolution::Resolved(path) => { - assert_eq!(path, cached.cache_file_path); - } - other => return Err(anyhow!("expected Resolved, got {other:?}")), - } + let resolution = resolve_url_into_cache(url_string, &cache_dir, fresh_status()) + .await + .unwrap(); - Ok(()) + assert!(matches!( + resolution, + DesiredModelResolution::Resolved(resolved_path) if resolved_path == cached.cache_file_path + )); } #[tokio::test] - async fn malformed_url_registers_download_url_is_malformed() -> Result<()> { - let directory = TempDir::new()?; + async fn malformed_url_registers_download_url_is_malformed() { + let directory = TempDir::new().unwrap(); let cache_dir = cache_dir_at(directory.path()); let url_string = "not a url"; @@ -306,19 +314,16 @@ mod tests { let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; assert!(result.is_err(), "malformed URL must produce an Err"); - assert!(status.has_issue(&AgentIssue::DownloadUrlIsMalformed( - paddler_types::agent_issue_params::ModelPath { + assert!( + status.has_issue(&AgentIssue::DownloadUrlIsMalformed(ModelPath { model_path: url_string.to_owned(), - }, - ))); - - Ok(()) + })) + ); } #[tokio::test] - async fn unsupported_scheme_registers_download_url_is_malformed_without_creating_cache_state() - -> Result<()> { - let directory = TempDir::new()?; + async fn unsupported_scheme_registers_download_url_is_malformed_without_creating_cache_state() { + let directory = TempDir::new().unwrap(); let cache_dir = cache_dir_at(directory.path()); let url_string = "ftp://example.invalid/m.gguf"; @@ -326,58 +331,101 @@ mod tests { let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; assert!(result.is_err(), "unsupported scheme must produce an Err"); - assert!(status.has_issue(&AgentIssue::DownloadUrlIsMalformed( - paddler_types::agent_issue_params::ModelPath { + assert!( + status.has_issue(&AgentIssue::DownloadUrlIsMalformed(ModelPath { model_path: url_string.to_owned(), - }, - ))); + })) + ); assert!( !directory.path().join("downloaded-models").exists(), "no cache subdirectory must be created for an unsupported scheme" ); - - Ok(()) } #[tokio::test] - async fn lock_contention_registers_cache_cannot_acquire_lock() -> Result<()> { - let directory = TempDir::new()?; + async fn lock_contention_registers_cache_cannot_acquire_lock() { + let directory = TempDir::new().unwrap(); let cache_dir = cache_dir_at(directory.path()); let url_string = "https://host.example/contended.gguf"; - let cached = CachedDownloadedModel::new(&cache_dir, url_string)?; - cached.ensure_cache_subdir_exists().await?; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); - let _blocker = cached.try_acquire_download_lock()?; + let _blocker = cached.try_acquire_download_lock().unwrap(); let status = fresh_status(); let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; assert!(result.is_err(), "lock contention must produce an Err"); - assert!(status.has_issue(&AgentIssue::CacheCannotAcquireLock( - paddler_types::agent_issue_params::ModelPath { + assert!( + status.has_issue(&AgentIssue::CacheCannotAcquireLock(ModelPath { model_path: url_string.to_owned(), - }, - ))); + })) + ); + } + + #[cfg(unix)] + #[tokio::test] + async fn cache_subdir_creation_failure_registers_model_cache_is_corrupted() { + use std::os::unix::fs::symlink; + + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/subdir-blocked.gguf"; + let subdir_path = directory.path().join("downloaded-models"); + symlink(directory.path().join("missing-target"), &subdir_path).unwrap(); - Ok(()) + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!( + result.is_err(), + "a non-directory at the cache subdir path must produce an Err" + ); + assert!( + status.has_issue_like(|issue| matches!(issue, AgentIssue::ModelCacheIsCorrupted(_))) + ); + } + + #[cfg(unix)] + #[tokio::test] + async fn lock_open_io_error_registers_model_cache_is_corrupted() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/lock-as-directory.gguf"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::create_dir(&cached.lock_file_path).await.unwrap(); + + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!( + result.is_err(), + "an unopenable lock path must produce an Err" + ); + assert!( + status.has_issue_like(|issue| matches!(issue, AgentIssue::ModelCacheIsCorrupted(_))) + ); + } + + fn test_model_path() -> ModelPath { + ModelPath { + model_path: TEST_URL.to_owned(), + } } #[test] - fn invalid_url_maps_to_download_url_is_malformed() -> Result<()> { - let parse_error = Url::parse("not a url") - .err() - .context("'not a url' should not parse")?; + fn invalid_url_maps_to_download_url_is_malformed() { + let parse_error = Url::parse("not a url").err().unwrap(); let error = DownloadError::InvalidUrl { url: "not a url".to_owned(), source: parse_error, }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadUrlIsMalformed(_) - )); - - Ok(()) + AgentIssue::DownloadUrlIsMalformed(test_model_path()) + ); } #[test] @@ -387,10 +435,10 @@ mod tests { scheme: "ftp".to_owned(), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadUrlIsMalformed(_) - )); + AgentIssue::DownloadUrlIsMalformed(test_model_path()) + ); } #[test] @@ -399,10 +447,10 @@ mod tests { url: TEST_URL.to_owned(), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::ModelDoesNotExistAtUrl(_) - )); + AgentIssue::ModelDoesNotExistAtUrl(test_model_path()) + ); } #[test] @@ -412,10 +460,10 @@ mod tests { status: StatusCode::FORBIDDEN, }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadServerDeniedAccess(_) - )); + AgentIssue::DownloadServerDeniedAccess(test_model_path()) + ); } #[test] @@ -425,10 +473,10 @@ mod tests { partial_path: PathBuf::from("/tmp/stale.partial"), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::ModelCacheIsCorrupted(_) - )); + AgentIssue::ModelCacheIsCorrupted(test_model_path()) + ); } #[test] @@ -438,10 +486,10 @@ mod tests { source: anyhow!("connection refused"), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadServerIsUnreachable(_) - )); + AgentIssue::DownloadServerIsUnreachable(test_model_path()) + ); } #[test] @@ -451,10 +499,10 @@ mod tests { status: StatusCode::INTERNAL_SERVER_ERROR, }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadServerErrored(_) - )); + AgentIssue::DownloadServerErrored(test_model_path()) + ); } #[test] @@ -464,10 +512,10 @@ mod tests { status: StatusCode::BAD_REQUEST, }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadServerRejectedRequest(_) - )); + AgentIssue::DownloadServerRejectedRequest(test_model_path()) + ); } #[test] @@ -477,10 +525,10 @@ mod tests { source: anyhow!("stream dropped"), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::DownloadInterrupted(_) - )); + AgentIssue::DownloadInterrupted(test_model_path()) + ); } #[test] @@ -490,10 +538,10 @@ mod tests { source: io::Error::from(io::ErrorKind::PermissionDenied), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::CacheDirectoryIsNotWritable(_) - )); + AgentIssue::CacheDirectoryIsNotWritable(test_model_path()) + ); } #[test] @@ -503,10 +551,10 @@ mod tests { source: io::Error::from_raw_os_error(28), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::CacheStorageIsFull(_) - )); + AgentIssue::CacheStorageIsFull(test_model_path()) + ); } #[test] @@ -516,20 +564,20 @@ mod tests { source: io::Error::from(io::ErrorKind::NotFound), }; - assert!(matches!( + assert_eq!( agent_issue_for(&error, TEST_URL), - AgentIssue::ModelCacheIsCorrupted(_) - )); + AgentIssue::ModelCacheIsCorrupted(test_model_path()) + ); } #[test] fn classify_cache_io_error_maps_permission_denied_to_cache_directory_is_not_writable() { let error = io::Error::from(io::ErrorKind::PermissionDenied); - assert!(matches!( + assert_eq!( classify_cache_io_error(TEST_URL, &error), - AgentIssue::CacheDirectoryIsNotWritable(_) - )); + AgentIssue::CacheDirectoryIsNotWritable(test_model_path()) + ); } #[test] @@ -541,26 +589,28 @@ mod tests { let error = io::Error::from_raw_os_error(DISK_FULL_ERRNO); - assert!(matches!( + assert_eq!( classify_cache_io_error(TEST_URL, &error), - AgentIssue::CacheStorageIsFull(_) - )); + AgentIssue::CacheStorageIsFull(test_model_path()) + ); } #[test] fn classify_cache_io_error_falls_back_to_model_cache_is_corrupted() { let error = io::Error::from(io::ErrorKind::NotFound); - assert!(matches!( + assert_eq!( classify_cache_io_error(TEST_URL, &error), - AgentIssue::ModelCacheIsCorrupted(_) - )); + AgentIssue::ModelCacheIsCorrupted(test_model_path()) + ); } #[tokio::test] - async fn ensure_cache_subdir_failure_registers_model_cache_is_corrupted() -> Result<()> { - let directory = TempDir::new()?; - tokio::fs::write(directory.path().join("downloaded-models"), b"blocker").await?; + async fn ensure_cache_subdir_failure_registers_model_cache_is_corrupted() { + let directory = TempDir::new().unwrap(); + tokio::fs::write(directory.path().join("downloaded-models"), b"blocker") + .await + .unwrap(); let cache_dir = cache_dir_at(directory.path()); let url_string = "https://host.example/blocked.gguf"; @@ -568,11 +618,161 @@ mod tests { let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; assert!(result.is_err(), "blocked cache subdir must produce an Err"); - assert!(status.has_issue_like(|issue| matches!( - issue, - AgentIssue::ModelCacheIsCorrupted(_) - ))); + assert!( + status.has_issue_like(|issue| matches!(issue, AgentIssue::ModelCacheIsCorrupted(_))) + ); + } + + #[test] + fn sink_on_started_sets_download_status_and_clears_matching_download_issue() { + let status = fresh_status(); + status.register_issue(AgentIssue::DownloadInterrupted(ModelPath { + model_path: TEST_URL.to_owned(), + })); + + let sink = SlotAggregatedStatusSink { + basename: Some("m.gguf".to_owned()), + slot_aggregated_status: status.clone(), + url: TEST_URL.to_owned(), + }; + + sink.on_started(Some(500), 100); + + let snapshot = status.make_snapshot().unwrap(); + assert_eq!(snapshot.download_current, 100); + assert_eq!(snapshot.download_total, 500); + assert!(!snapshot.download_indeterminate); + assert_eq!(snapshot.download_filename, Some("m.gguf".to_owned())); + assert!( + !status.has_issue(&AgentIssue::DownloadInterrupted(ModelPath { + model_path: TEST_URL.to_owned(), + })) + ); + } + + #[test] + fn sink_on_chunk_increments_download_current() { + let status = fresh_status(); + let sink = SlotAggregatedStatusSink { + basename: None, + slot_aggregated_status: status.clone(), + url: TEST_URL.to_owned(), + }; + + sink.on_started(Some(1000), 0); + sink.on_chunk(250); + sink.on_chunk(125); + + let snapshot = status.make_snapshot().unwrap(); + assert_eq!(snapshot.download_current, 375); + } + + #[test] + fn sink_on_finished_resets_download_and_clears_matching_download_issue() { + let status = fresh_status(); + status.register_issue(AgentIssue::DownloadInterrupted(ModelPath { + model_path: TEST_URL.to_owned(), + })); + + let sink = SlotAggregatedStatusSink { + basename: Some("m.gguf".to_owned()), + slot_aggregated_status: status.clone(), + url: TEST_URL.to_owned(), + }; + + sink.on_started(Some(500), 200); + sink.on_finished(); + + let snapshot = status.make_snapshot().unwrap(); + assert_eq!(snapshot.download_current, 0); + assert_eq!(snapshot.download_total, 0); + assert!(snapshot.download_indeterminate); + assert_eq!(snapshot.download_filename, None); + assert!( + !status.has_issue(&AgentIssue::DownloadInterrupted(ModelPath { + model_path: TEST_URL.to_owned(), + })) + ); + } + + fn unresolvable_cache_dir() -> CacheDir { + #[cfg(unix)] + { + CacheDir { + explicit: None, + home: None, + xdg: None, + } + } + #[cfg(windows)] + { + CacheDir { + explicit: None, + localappdata: None, + userprofile: None, + } + } + } + + #[tokio::test] + async fn cache_path_resolution_failure_propagates_error() { + let url_string = "https://host.example/unresolvable.gguf"; + + let result = + resolve_url_into_cache(url_string, &unresolvable_cache_dir(), fresh_status()).await; - Ok(()) + assert!( + result.is_err(), + "an unresolvable cache directory must produce an Err" + ); + } + + async fn serve_single_ok_response(listener: TcpListener, body: Vec) { + let (mut socket, _peer) = listener.accept().await.unwrap(); + let (reader_half, mut writer_half) = socket.split(); + let mut reader = BufReader::new(reader_half); + + loop { + let mut header_line = String::new(); + let bytes_read = reader.read_line(&mut header_line).await.unwrap(); + if bytes_read == 0 || header_line == "\r\n" { + break; + } + } + + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + writer_half.write_all(header.as_bytes()).await.unwrap(); + writer_half.write_all(&body).await.unwrap(); + writer_half.shutdown().await.unwrap(); + } + + #[tokio::test] + async fn successful_download_resolves_to_cache_file_with_downloaded_contents() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let url_string = format!("http://127.0.0.1:{port}/model.gguf"); + let body = b"downloaded model bytes".to_vec(); + let server = tokio::spawn(serve_single_ok_response(listener, body.clone())); + + let cached = CachedDownloadedModel::new(&cache_dir, &url_string).unwrap(); + let expected_path = cached.cache_file_path.clone(); + + let resolution = resolve_url_into_cache(&url_string, &cache_dir, fresh_status()) + .await + .unwrap(); + + server.await.unwrap(); + + assert!(matches!( + resolution, + DesiredModelResolution::Resolved(resolved_path) if resolved_path == expected_path + )); + assert_eq!(tokio::fs::read(&expected_path).await.unwrap(), body); } } diff --git a/paddler_types/src/normalization/l2.rs b/paddler_agent/src/normalization/l2.rs similarity index 100% rename from paddler_types/src/normalization/l2.rs rename to paddler_agent/src/normalization/l2.rs diff --git a/paddler_agent/src/normalization/mod.rs b/paddler_agent/src/normalization/mod.rs new file mode 100644 index 00000000..a9dc3af9 --- /dev/null +++ b/paddler_agent/src/normalization/mod.rs @@ -0,0 +1,3 @@ +pub mod l2; +pub mod normalize_embedding; +pub mod rms_norm; diff --git a/paddler_agent/src/normalization/normalize_embedding.rs b/paddler_agent/src/normalization/normalize_embedding.rs new file mode 100644 index 00000000..9371b786 --- /dev/null +++ b/paddler_agent/src/normalization/normalize_embedding.rs @@ -0,0 +1,158 @@ +use anyhow::Result; +use anyhow::anyhow; + +use paddler_messaging::embedding::Embedding; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; + +use crate::normalization::l2::l2; +use crate::normalization::rms_norm::rms_norm; + +pub fn normalize_embedding( + embedding: Embedding, + normalization_method: &EmbeddingNormalizationMethod, +) -> Result { + if !embedding + .normalization_method + .can_transform_to(normalization_method) + { + return Err(anyhow!( + "Cannot transform from {:?} to {normalization_method:?}", + embedding.normalization_method + )); + } + + if !embedding + .normalization_method + .needs_transformation_to(normalization_method) + { + return Ok(embedding); + } + + let normalized = match normalization_method { + EmbeddingNormalizationMethod::None => embedding.embedding, + EmbeddingNormalizationMethod::L2 => l2(&embedding.embedding), + EmbeddingNormalizationMethod::RmsNorm { epsilon } => { + rms_norm(&embedding.embedding, *epsilon)? + } + }; + + Ok(Embedding { + embedding: normalized, + normalization_method: normalization_method.clone(), + pooling_type: embedding.pooling_type, + source_document_id: embedding.source_document_id, + }) +} + +#[cfg(test)] +mod tests { + use paddler_messaging::embedding::Embedding; + use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; + use paddler_messaging::pooling_type::PoolingType; + + use super::normalize_embedding; + + fn make_embedding(values: Vec, method: EmbeddingNormalizationMethod) -> Embedding { + Embedding { + embedding: values, + normalization_method: method, + pooling_type: PoolingType::Mean, + source_document_id: "test".to_owned(), + } + } + + #[test] + fn normalize_from_none_to_l2() { + let embedding = make_embedding(vec![3.0, 4.0], EmbeddingNormalizationMethod::None); + let result = normalize_embedding(embedding, &EmbeddingNormalizationMethod::L2).unwrap(); + + assert_eq!(result.embedding, vec![0.6, 0.8]); + assert!( + !result + .normalization_method + .needs_transformation_to(&EmbeddingNormalizationMethod::L2) + ); + } + + #[test] + fn normalize_from_none_to_rms_norm() { + let embedding = + make_embedding(vec![2.0, 2.0, 2.0, 2.0], EmbeddingNormalizationMethod::None); + let result = normalize_embedding( + embedding, + &EmbeddingNormalizationMethod::RmsNorm { epsilon: 0.0 }, + ) + .unwrap(); + + for value in &result.embedding { + assert!((value - 1.0).abs() < 1e-6); + } + } + + #[test] + fn normalize_none_to_none_is_noop() { + let embedding = make_embedding(vec![1.0, 2.0, 3.0], EmbeddingNormalizationMethod::None); + let result = normalize_embedding(embedding, &EmbeddingNormalizationMethod::None).unwrap(); + + assert_eq!(result.embedding, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn normalize_rejects_l2_to_rms_norm() { + let embedding = make_embedding(vec![0.6, 0.8], EmbeddingNormalizationMethod::L2); + let result = normalize_embedding( + embedding, + &EmbeddingNormalizationMethod::RmsNorm { epsilon: 1e-6 }, + ); + + assert!(result.is_err()); + } + + #[test] + fn normalize_rejects_l2_to_none() { + let embedding = make_embedding(vec![0.6, 0.8], EmbeddingNormalizationMethod::L2); + let result = normalize_embedding(embedding, &EmbeddingNormalizationMethod::None); + + assert!(result.is_err()); + } + + #[test] + fn normalize_rejects_rms_norm_to_l2() { + let embedding = make_embedding( + vec![1.0, 1.0], + EmbeddingNormalizationMethod::RmsNorm { epsilon: 1e-6 }, + ); + let result = normalize_embedding(embedding, &EmbeddingNormalizationMethod::L2); + + assert!(result.is_err()); + } + + #[test] + fn normalize_to_rms_norm_propagates_oversized_embedding_error() { + let oversized_length = usize::from(u16::MAX) + 1; + let embedding = make_embedding( + vec![1.0; oversized_length], + EmbeddingNormalizationMethod::None, + ); + let result = normalize_embedding( + embedding, + &EmbeddingNormalizationMethod::RmsNorm { epsilon: 0.0 }, + ); + + assert!(result.is_err()); + } + + #[test] + fn normalize_preserves_metadata() { + let embedding = Embedding { + embedding: vec![3.0, 4.0], + normalization_method: EmbeddingNormalizationMethod::None, + pooling_type: PoolingType::Cls, + source_document_id: "doc-42".to_owned(), + }; + let result = normalize_embedding(embedding, &EmbeddingNormalizationMethod::L2).unwrap(); + + assert_eq!(result.pooling_type, PoolingType::Cls); + assert_eq!(result.source_document_id, "doc-42"); + } +} diff --git a/paddler_types/src/normalization/rms_norm.rs b/paddler_agent/src/normalization/rms_norm.rs similarity index 66% rename from paddler_types/src/normalization/rms_norm.rs rename to paddler_agent/src/normalization/rms_norm.rs index 7012cb18..041a08d6 100644 --- a/paddler_types/src/normalization/rms_norm.rs +++ b/paddler_agent/src/normalization/rms_norm.rs @@ -1,25 +1,26 @@ -#[must_use] -#[expect( - clippy::cast_precision_loss, - reason = "embedding length precision loss is acceptable for normalization math" -)] -pub fn rms_norm(embedding: &[f32], eps: f32) -> Vec { +use anyhow::Context as _; +use anyhow::Result; + +pub fn rms_norm(embedding: &[f32], eps: f32) -> Result> { if embedding.is_empty() { - return Vec::new(); + return Ok(Vec::new()); } + let embedding_length = u16::try_from(embedding.len()) + .context("embedding length exceeds the supported maximum for normalization")?; + let mean_square = embedding .iter() .fold(0.0, |acc, &val| val.mul_add(val, acc)) - / embedding.len() as f32; + / f32::from(embedding_length); let rms = (mean_square + eps).sqrt(); if rms == 0.0 { - return vec![0.0; embedding.len()]; + return Ok(vec![0.0; embedding.len()]); } - embedding.iter().map(|&val| val / rms).collect() + Ok(embedding.iter().map(|&val| val / rms).collect()) } #[cfg(test)] @@ -29,7 +30,7 @@ mod tests { #[test] fn test_rms_norm_uniform_values() { let embedding = vec![2.0, 2.0, 2.0, 2.0]; - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); // mean_square = (4+4+4+4)/4 = 4, rms = 2.0 // each value / 2.0 = 1.0 @@ -41,7 +42,7 @@ mod tests { #[test] fn test_rms_norm_mixed_values() { let embedding = vec![1.0, 3.0]; - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); // mean_square = (1+9)/2 = 5, rms = sqrt(5) let expected_rms = 5.0_f32.sqrt(); @@ -53,7 +54,7 @@ mod tests { #[test] fn test_rms_norm_zero_vector_with_zero_epsilon() { let embedding = vec![0.0, 0.0, 0.0]; - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); assert_eq!(result, vec![0.0, 0.0, 0.0]); } @@ -61,7 +62,7 @@ mod tests { #[test] fn test_rms_norm_zero_vector_with_nonzero_epsilon() { let embedding = vec![0.0, 0.0]; - let result = rms_norm(&embedding, 1e-6); + let result = rms_norm(&embedding, 1e-6).unwrap(); // mean_square = 0, rms = sqrt(1e-6), so values = 0 / rms = 0 for val in &result { @@ -72,8 +73,8 @@ mod tests { #[test] fn test_rms_norm_epsilon_prevents_division_instability() { let embedding = vec![1e-10, 1e-10]; - let without_eps = rms_norm(&embedding, 0.0); - let with_eps = rms_norm(&embedding, 1e-6); + let without_eps = rms_norm(&embedding, 0.0).unwrap(); + let with_eps = rms_norm(&embedding, 1e-6).unwrap(); // With epsilon, the denominator is larger, so normalized values are smaller assert!(with_eps[0].abs() < without_eps[0].abs()); @@ -82,7 +83,7 @@ mod tests { #[test] fn test_rms_norm_single_element() { let embedding = vec![5.0]; - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); // mean_square = 25/1 = 25, rms = 5.0, result = 5/5 = 1.0 assert!((result[0] - 1.0).abs() < 1e-6); @@ -91,15 +92,23 @@ mod tests { #[test] fn test_rms_norm_empty_embedding() { let embedding: Vec = Vec::new(); - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); assert!(result.is_empty()); } + #[test] + fn test_rms_norm_length_exceeding_u16_max_returns_error() { + let embedding = vec![1.0_f32; usize::from(u16::MAX) + 1]; + let result = rms_norm(&embedding, 0.0); + + assert!(result.is_err()); + } + #[test] fn test_rms_norm_negative_values() { let embedding = vec![-3.0, 4.0]; - let result = rms_norm(&embedding, 0.0); + let result = rms_norm(&embedding, 0.0).unwrap(); // mean_square = (9+16)/2 = 12.5, rms = sqrt(12.5) let expected_rms = 12.5_f32.sqrt(); diff --git a/paddler/src/agent/plan_embedding_batches.rs b/paddler_agent/src/plan_embedding_batches.rs similarity index 100% rename from paddler/src/agent/plan_embedding_batches.rs rename to paddler_agent/src/plan_embedding_batches.rs diff --git a/paddler/src/agent/prepare_conversation_history_request.rs b/paddler_agent/src/prepare_conversation_history_request.rs similarity index 86% rename from paddler/src/agent/prepare_conversation_history_request.rs rename to paddler_agent/src/prepare_conversation_history_request.rs index 87be81d9..84e2e23e 100644 --- a/paddler/src/agent/prepare_conversation_history_request.rs +++ b/paddler_agent/src/prepare_conversation_history_request.rs @@ -4,17 +4,17 @@ use llama_cpp_bindings::mtmd::mtmd_default_marker; use log::error; use log::warn; use minijinja::context; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::media_marker::MediaMarker; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::media_marker::MediaMarker; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use tokio::sync::mpsc; -use crate::agent::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; -use crate::agent::prepared_conversation_history_request::PreparedConversationHistoryRequest; -use crate::agent::resolve_grammar::resolve_grammar; +use crate::continuous_batch_scheduler_context::ContinuousBatchSchedulerContext; use crate::decoded_image::DecodedImage; use crate::decoded_image_error::DecodedImageError; +use crate::prepared_conversation_history_request::PreparedConversationHistoryRequest; +use crate::resolve_grammar::resolve_grammar; pub fn prepare_conversation_history_request( ContinueFromConversationHistoryParams { diff --git a/paddler/src/agent/prepared_conversation_history_request.rs b/paddler_agent/src/prepared_conversation_history_request.rs similarity index 62% rename from paddler/src/agent/prepared_conversation_history_request.rs rename to paddler_agent/src/prepared_conversation_history_request.rs index 61092da5..4b7f59e4 100644 --- a/paddler/src/agent/prepared_conversation_history_request.rs +++ b/paddler_agent/src/prepared_conversation_history_request.rs @@ -1,8 +1,8 @@ -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use crate::agent::grammar_sampler::GrammarSampler; use crate::decoded_image::DecodedImage; +use crate::grammar_sampler::GrammarSampler; pub enum PreparedConversationHistoryRequest { TextPrompt { diff --git a/paddler/src/agent/receive_stream_stopper_collection.rs b/paddler_agent/src/receive_stream_stopper_collection.rs similarity index 69% rename from paddler/src/agent/receive_stream_stopper_collection.rs rename to paddler_agent/src/receive_stream_stopper_collection.rs index eb16f1d5..3ffb07c8 100644 --- a/paddler/src/agent/receive_stream_stopper_collection.rs +++ b/paddler_agent/src/receive_stream_stopper_collection.rs @@ -5,7 +5,7 @@ use anyhow::anyhow; use dashmap::DashMap; use tokio::sync::mpsc; -use crate::agent::receive_stream_stopper_drop_guard::ReceiveStreamStopperDropGuard; +use crate::receive_stream_stopper_drop_guard::ReceiveStreamStopperDropGuard; pub struct ReceiveStreamStopperCollection { receive_stoppers: DashMap>, @@ -73,8 +73,6 @@ impl Default for ReceiveStreamStopperCollection { #[cfg(test)] mod tests { - use anyhow::Result; - use super::*; #[test] @@ -90,32 +88,32 @@ mod tests { } #[test] - fn register_duplicate_stopper_fails() -> Result<()> { + fn register_duplicate_stopper_fails() { let collection = ReceiveStreamStopperCollection::default(); let (sender_1, _receiver_1) = mpsc::unbounded_channel(); let (sender_2, _receiver_2) = mpsc::unbounded_channel(); - collection.register_stopper("req_1".to_owned(), sender_1)?; + collection + .register_stopper("req_1".to_owned(), sender_1) + .unwrap(); assert!( collection .register_stopper("req_1".to_owned(), sender_2) .is_err() ); - - Ok(()) } #[test] - fn deregister_stopper_succeeds() -> Result<()> { + fn deregister_stopper_succeeds() { let collection = ReceiveStreamStopperCollection::default(); let (sender, _receiver) = mpsc::unbounded_channel(); - collection.register_stopper("req_1".to_owned(), sender)?; + collection + .register_stopper("req_1".to_owned(), sender) + .unwrap(); assert!(collection.deregister_stopper("req_1").is_ok()); - - Ok(()) } #[test] @@ -126,16 +124,16 @@ mod tests { } #[test] - fn stop_sends_signal() -> Result<()> { + fn stop_sends_signal() { let collection = ReceiveStreamStopperCollection::default(); let (sender, mut receiver) = mpsc::unbounded_channel(); - collection.register_stopper("req_1".to_owned(), sender)?; + collection + .register_stopper("req_1".to_owned(), sender) + .unwrap(); assert!(collection.stop("req_1").is_ok()); assert!(receiver.try_recv().is_ok()); - - Ok(()) } #[test] @@ -146,17 +144,47 @@ mod tests { } #[test] - fn register_stopper_with_guard_auto_deregisters_on_drop() -> Result<()> { + fn stop_fails_when_receiver_dropped() { + let collection = ReceiveStreamStopperCollection::default(); + let (sender, receiver) = mpsc::unbounded_channel(); + + collection + .register_stopper("req_1".to_owned(), sender) + .unwrap(); + + drop(receiver); + + assert!(collection.stop("req_1").is_err()); + } + + #[test] + fn register_stopper_with_guard_fails_on_duplicate() { + let collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (sender_1, _receiver_1) = mpsc::unbounded_channel(); + let (sender_2, _receiver_2) = mpsc::unbounded_channel(); + + collection + .register_stopper("req_1".to_owned(), sender_1) + .unwrap(); + + assert!( + collection + .register_stopper_with_guard("req_1".to_owned(), sender_2) + .is_err() + ); + } + + #[test] + fn register_stopper_with_guard_auto_deregisters_on_drop() { let collection = Arc::new(ReceiveStreamStopperCollection::default()); let (sender, _receiver) = mpsc::unbounded_channel(); - let guard = collection.register_stopper_with_guard("req_1".to_owned(), sender)?; + let guard = collection + .register_stopper_with_guard("req_1".to_owned(), sender) + .unwrap(); drop(guard); - // After drop, the stopper should be deregistered assert!(collection.deregister_stopper("req_1").is_err()); - - Ok(()) } } diff --git a/paddler_agent/src/receive_stream_stopper_drop_guard.rs b/paddler_agent/src/receive_stream_stopper_drop_guard.rs new file mode 100644 index 00000000..88f143ef --- /dev/null +++ b/paddler_agent/src/receive_stream_stopper_drop_guard.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use log::error; + +use crate::receive_stream_stopper_collection::ReceiveStreamStopperCollection; + +pub struct ReceiveStreamStopperDropGuard { + pub receive_stream_stopper_collection: Arc, + pub request_id: String, +} + +impl Drop for ReceiveStreamStopperDropGuard { + fn drop(&mut self) { + if let Err(err) = self + .receive_stream_stopper_collection + .deregister_stopper(&self.request_id) + { + error!( + "Failed to deregister stopper for request_id {}: {}", + self.request_id, err + ); + } + } +} + +#[cfg(test)] +mod tests { + use tokio::sync::mpsc; + + use super::*; + + #[test] + fn drop_deregisters_registered_stopper() { + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (sender, _receiver) = mpsc::unbounded_channel(); + let guard = ReceiveStreamStopperDropGuard { + receive_stream_stopper_collection: receive_stream_stopper_collection.clone(), + request_id: "req_1".to_owned(), + }; + + receive_stream_stopper_collection + .register_stopper("req_1".to_owned(), sender) + .unwrap(); + + drop(guard); + + assert!( + receive_stream_stopper_collection + .deregister_stopper("req_1") + .is_err() + ); + } + + #[test] + fn drop_handles_already_deregistered_stopper() { + let receive_stream_stopper_collection = Arc::new(ReceiveStreamStopperCollection::default()); + let (sender, _receiver) = mpsc::unbounded_channel(); + let guard = ReceiveStreamStopperDropGuard { + receive_stream_stopper_collection: receive_stream_stopper_collection.clone(), + request_id: "req_1".to_owned(), + }; + + receive_stream_stopper_collection + .register_stopper("req_1".to_owned(), sender) + .unwrap(); + receive_stream_stopper_collection + .deregister_stopper("req_1") + .unwrap(); + + drop(guard); + + assert!( + receive_stream_stopper_collection + .deregister_stopper("req_1") + .is_err() + ); + } +} diff --git a/paddler/src/agent/reconciliation_service.rs b/paddler_agent/src/reconciliation_service.rs similarity index 91% rename from paddler/src/agent/reconciliation_service.rs rename to paddler_agent/src/reconciliation_service.rs index c4de3101..ecbac9b8 100644 --- a/paddler/src/agent/reconciliation_service.rs +++ b/paddler_agent/src/reconciliation_service.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use log::error; -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_desired_state::AgentDesiredState; use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::MissedTickBehavior; @@ -12,9 +12,10 @@ use tokio_util::sync::CancellationToken; use trzcina::Service; use crate::agent_applicable_state_holder::AgentApplicableStateHolder; +use crate::agent_desired_state_converter::AgentDesiredStateConverter; use crate::agent_issue_fix::AgentIssueFix; -use crate::converts_to_applicable_state::ConvertsToApplicableState as _; use crate::slot_aggregated_status::SlotAggregatedStatus; +use paddler_state_conversion::converts_to_applicable_state::ConvertsToApplicableState as _; async fn convert_to_applicable_state( agent_desired_state: Option<&AgentDesiredState>, @@ -25,9 +26,11 @@ async fn convert_to_applicable_state( let applicable_state = match agent_desired_state { None => None, Some(agent_desired_state) => Some( - agent_desired_state - .to_applicable_state(slot_aggregated_status.clone()) - .await?, + AgentDesiredStateConverter { + slot_aggregated_status: slot_aggregated_status.clone(), + } + .to_applicable_state(agent_desired_state.clone()) + .await?, ), }; diff --git a/paddler/src/resolve_desired_model.rs b/paddler_agent/src/resolve_desired_model.rs similarity index 54% rename from paddler/src/resolve_desired_model.rs rename to paddler_agent/src/resolve_desired_model.rs index 51fde0d7..da87657c 100644 --- a/paddler/src/resolve_desired_model.rs +++ b/paddler_agent/src/resolve_desired_model.rs @@ -1,10 +1,12 @@ use std::sync::Arc; use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_desired_model::AgentDesiredModel; use crate::desired_model_resolution::DesiredModelResolution; +use crate::model_source::huggingface::HuggingFaceModelSource; use crate::model_source::local::LocalModelPath; +use crate::model_source::url::UrlModelSource; use crate::resolves_model_source::ResolvesModelSource; use crate::slot_aggregated_status::SlotAggregatedStatus; @@ -14,85 +16,101 @@ pub async fn resolve_desired_model( ) -> Result { match desired { AgentDesiredModel::HuggingFace(reference) => { - reference.resolve(slot_aggregated_status).await + HuggingFaceModelSource(reference.clone()) + .resolve(slot_aggregated_status) + .await } AgentDesiredModel::LocalToAgent(path) => { LocalModelPath::new(path.clone()) .resolve(slot_aggregated_status) .await } - AgentDesiredModel::Url(reference) => reference.resolve(slot_aggregated_status).await, + AgentDesiredModel::Url(reference) => { + UrlModelSource(reference.clone()) + .resolve(slot_aggregated_status) + .await + } AgentDesiredModel::None => Ok(DesiredModelResolution::NotConfigured), } } #[cfg(test)] mod tests { - use std::path::PathBuf; + use std::mem; use std::sync::Arc; - use anyhow::Result; - use paddler_types::agent_desired_model::AgentDesiredModel; + use paddler_messaging::agent_desired_model::AgentDesiredModel; use tempfile::NamedTempFile; - use tempfile::TempDir; use crate::desired_model_resolution::DesiredModelResolution; use crate::resolve_desired_model::resolve_desired_model; use crate::slot_aggregated_status::SlotAggregatedStatus; + use paddler_messaging::agent_issue::AgentIssue; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; fn fresh_status() -> Arc { Arc::new(SlotAggregatedStatus::new(1)) } - fn nonexistent_path_in_temp_dir(label: &str) -> Result<(TempDir, PathBuf)> { - let dir = tempfile::tempdir()?; - let path = dir.path().join(format!("missing-{label}.gguf")); - - Ok((dir, path)) - } - #[tokio::test] - async fn local_existing_file_resolves_to_resolved_with_that_path() -> Result<()> { + async fn local_existing_file_resolves_to_resolved_with_that_path() { let status = fresh_status(); - let temp_file = NamedTempFile::new()?; + let temp_file = NamedTempFile::new().unwrap(); let path = temp_file.path().to_path_buf(); let desired = AgentDesiredModel::LocalToAgent(path.display().to_string()); - let resolution = resolve_desired_model(&desired, status).await?; + let resolution = resolve_desired_model(&desired, status).await.unwrap(); assert!(matches!( resolution, DesiredModelResolution::Resolved(ref resolved) if *resolved == path )); - - Ok(()) } #[tokio::test] - async fn local_missing_file_resolves_to_local_file_missing_with_that_path() -> Result<()> { + async fn local_missing_file_resolves_to_local_file_missing_with_that_path() { let status = fresh_status(); - let (_dir_guard, path) = nonexistent_path_in_temp_dir("desired")?; + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("missing-desired.gguf"); let desired = AgentDesiredModel::LocalToAgent(path.display().to_string()); - let resolution = resolve_desired_model(&desired, status).await?; + let resolution = resolve_desired_model(&desired, status).await.unwrap(); assert!(matches!( resolution, DesiredModelResolution::LocalFileMissing(ref missing) if *missing == path )); + } - Ok(()) + #[tokio::test] + async fn huggingface_already_marked_missing_resolves_to_error_without_network() { + let status = fresh_status(); + let reference = HuggingFaceModelReference { + filename: "model.gguf".to_owned(), + repo_id: "owner/repo".to_owned(), + revision: "main".to_owned(), + }; + status.register_issue(AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { + model_path: "owner/repo/main/model.gguf".to_owned(), + })); + let desired = AgentDesiredModel::HuggingFace(reference); + + let resolution = resolve_desired_model(&desired, status).await; + + assert!(resolution.is_err()); } #[tokio::test] - async fn none_variant_resolves_to_not_configured() -> Result<()> { + async fn none_variant_resolves_to_not_configured() { let status = fresh_status(); let desired = AgentDesiredModel::None; - let resolution = resolve_desired_model(&desired, status).await?; - - assert!(matches!(resolution, DesiredModelResolution::NotConfigured)); + let resolution = resolve_desired_model(&desired, status).await.unwrap(); - Ok(()) + assert_eq!( + mem::discriminant(&resolution), + mem::discriminant(&DesiredModelResolution::NotConfigured) + ); } } diff --git a/paddler_agent/src/resolve_grammar.rs b/paddler_agent/src/resolve_grammar.rs new file mode 100644 index 00000000..2186c0c7 --- /dev/null +++ b/paddler_agent/src/resolve_grammar.rs @@ -0,0 +1,152 @@ +use anyhow::Result; +use anyhow::anyhow; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use tokio::sync::mpsc; + +use crate::grammar_sampler::GrammarSampler; + +pub fn resolve_grammar( + grammar: Option<&GrammarConstraint>, + enable_thinking: bool, + generated_tokens_tx: &mpsc::UnboundedSender, +) -> Result> { + let Some(grammar_constraint) = grammar else { + return Ok(None); + }; + + if enable_thinking { + let message = "Grammar constraints are incompatible with thinking mode".to_owned(); + + generated_tokens_tx + .send(GeneratedTokenResult::GrammarIncompatibleWithThinking( + message.clone(), + )) + .map_err(|err| anyhow!("Failed to send grammar incompatibility error: {err}"))?; + + return Err(anyhow!(message)); + } + + match GrammarSampler::new(grammar_constraint) { + Ok(sampler) => Ok(Some(sampler)), + Err(err) => { + let message = format!("Failed to create grammar sampler: {err}"); + + generated_tokens_tx + .send(GeneratedTokenResult::GrammarSyntaxError(message.clone())) + .map_err(|send_err| anyhow!("Failed to send grammar syntax error: {send_err}"))?; + + Err(anyhow!(message)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn returns_none_when_grammar_is_absent() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + + let resolved = resolve_grammar(None, false, &generated_tokens_tx).unwrap(); + + assert!(resolved.is_none()); + assert!(generated_tokens_rx.try_recv().is_err()); + } + + #[test] + fn emits_incompatibility_event_and_errors_when_thinking_is_enabled() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + let grammar = GrammarConstraint::Gbnf { + grammar: "root ::= \"yes\" | \"no\"".to_owned(), + root: "root".to_owned(), + }; + + let result = resolve_grammar(Some(&grammar), true, &generated_tokens_tx); + + assert!(result.is_err()); + + let event = generated_tokens_rx.try_recv().unwrap(); + + assert!( + matches!(event, GeneratedTokenResult::GrammarIncompatibleWithThinking(message) if message == "Grammar constraints are incompatible with thinking mode") + ); + } + + #[test] + fn errors_when_incompatibility_event_cannot_be_sent() { + let (generated_tokens_tx, generated_tokens_rx) = mpsc::unbounded_channel(); + + drop(generated_tokens_rx); + + let grammar = GrammarConstraint::Gbnf { + grammar: "root ::= \"yes\" | \"no\"".to_owned(), + root: "root".to_owned(), + }; + + let result = resolve_grammar(Some(&grammar), true, &generated_tokens_tx); + + assert_eq!( + result.err().unwrap().to_string(), + "Failed to send grammar incompatibility error: channel closed" + ); + } + + #[test] + fn returns_sampler_for_valid_grammar() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + let grammar = GrammarConstraint::Gbnf { + grammar: "root ::= \"yes\" | \"no\"".to_owned(), + root: "root".to_owned(), + }; + + let resolved = resolve_grammar(Some(&grammar), false, &generated_tokens_tx).unwrap(); + + assert!(resolved.is_some()); + assert!(generated_tokens_rx.try_recv().is_err()); + } + + #[test] + fn emits_syntax_error_event_and_errors_for_invalid_grammar() { + let (generated_tokens_tx, mut generated_tokens_rx) = mpsc::unbounded_channel(); + let grammar = GrammarConstraint::JsonSchema { + schema: "not valid json at all".to_owned(), + }; + + let result = resolve_grammar(Some(&grammar), false, &generated_tokens_tx); + + assert!(result.is_err()); + assert!( + result + .err() + .unwrap() + .to_string() + .starts_with("Failed to create grammar sampler:") + ); + + let event = generated_tokens_rx.try_recv().unwrap(); + + assert!( + matches!(event, GeneratedTokenResult::GrammarSyntaxError(message) if message.starts_with("Failed to create grammar sampler:")) + ); + } + + #[test] + fn errors_when_syntax_error_event_cannot_be_sent() { + let (generated_tokens_tx, generated_tokens_rx) = mpsc::unbounded_channel(); + + drop(generated_tokens_rx); + + let grammar = GrammarConstraint::JsonSchema { + schema: "not valid json at all".to_owned(), + }; + + let result = resolve_grammar(Some(&grammar), false, &generated_tokens_tx); + + assert_eq!( + result.err().unwrap().to_string(), + "Failed to send grammar syntax error: channel closed" + ); + } +} diff --git a/paddler/src/agent/resolve_grammar_to_gbnf.rs b/paddler_agent/src/resolve_grammar_to_gbnf.rs similarity index 80% rename from paddler/src/agent/resolve_grammar_to_gbnf.rs rename to paddler_agent/src/resolve_grammar_to_gbnf.rs index f957b643..07ba5fa2 100644 --- a/paddler/src/agent/resolve_grammar_to_gbnf.rs +++ b/paddler_agent/src/resolve_grammar_to_gbnf.rs @@ -1,9 +1,9 @@ use anyhow::Result; use anyhow::anyhow; use llama_cpp_bindings::json_schema_to_grammar; -use paddler_types::grammar_constraint::GrammarConstraint; +use paddler_messaging::grammar_constraint::GrammarConstraint; -use crate::agent::resolved_grammar::ResolvedGrammar; +use crate::resolved_grammar::ResolvedGrammar; pub fn resolve_grammar_to_gbnf(grammar_constraint: &GrammarConstraint) -> Result { match grammar_constraint { @@ -25,37 +25,31 @@ pub fn resolve_grammar_to_gbnf(grammar_constraint: &GrammarConstraint) -> Result #[cfg(test)] mod tests { - use anyhow::Result; - use super::*; #[test] - fn resolves_gbnf_variant() -> Result<()> { + fn resolves_gbnf_variant() { let constraint = GrammarConstraint::Gbnf { grammar: "root ::= \"yes\" | \"no\"".to_owned(), root: "root".to_owned(), }; - let resolved = resolve_grammar_to_gbnf(&constraint)?; + let resolved = resolve_grammar_to_gbnf(&constraint).unwrap(); assert_eq!(resolved.grammar_string, "root ::= \"yes\" | \"no\""); assert_eq!(resolved.root_rule, "root"); - - Ok(()) } #[test] - fn resolves_json_schema_variant() -> Result<()> { + fn resolves_json_schema_variant() { let constraint = GrammarConstraint::JsonSchema { schema: r#"{"type": "object", "properties": {"name": {"type": "string"}}}"#.to_owned(), }; - let resolved = resolve_grammar_to_gbnf(&constraint)?; + let resolved = resolve_grammar_to_gbnf(&constraint).unwrap(); assert!(!resolved.grammar_string.is_empty()); assert_eq!(resolved.root_rule, "root"); - - Ok(()) } #[test] diff --git a/paddler/src/agent/resolved_grammar.rs b/paddler_agent/src/resolved_grammar.rs similarity index 100% rename from paddler/src/agent/resolved_grammar.rs rename to paddler_agent/src/resolved_grammar.rs diff --git a/paddler/src/resolves_model_source.rs b/paddler_agent/src/resolves_model_source.rs similarity index 100% rename from paddler/src/resolves_model_source.rs rename to paddler_agent/src/resolves_model_source.rs diff --git a/paddler/src/agent/sample_token_at_batch_index.rs b/paddler_agent/src/sample_token_at_batch_index.rs similarity index 95% rename from paddler/src/agent/sample_token_at_batch_index.rs rename to paddler_agent/src/sample_token_at_batch_index.rs index 0739efd2..69e72dd1 100644 --- a/paddler/src/agent/sample_token_at_batch_index.rs +++ b/paddler_agent/src/sample_token_at_batch_index.rs @@ -3,7 +3,7 @@ use anyhow::Result; use llama_cpp_bindings::context::LlamaContext; use llama_cpp_bindings::sampling::LlamaSampler; -use crate::agent::sampling_outcome::SamplingOutcome; +use crate::sampling_outcome::SamplingOutcome; pub fn sample_token_at_batch_index( llama_context: &LlamaContext, diff --git a/paddler/src/agent/sampling_outcome.rs b/paddler_agent/src/sampling_outcome.rs similarity index 100% rename from paddler/src/agent/sampling_outcome.rs rename to paddler_agent/src/sampling_outcome.rs diff --git a/paddler/src/agent/sequence_id_pool.rs b/paddler_agent/src/sequence_id_pool.rs similarity index 100% rename from paddler/src/agent/sequence_id_pool.rs rename to paddler_agent/src/sequence_id_pool.rs diff --git a/paddler/src/slot_aggregated_status.rs b/paddler_agent/src/slot_aggregated_status.rs similarity index 69% rename from paddler/src/slot_aggregated_status.rs rename to paddler_agent/src/slot_aggregated_status.rs index 0ad9d2ff..a66ce0a1 100644 --- a/paddler/src/slot_aggregated_status.rs +++ b/paddler_agent/src/slot_aggregated_status.rs @@ -1,20 +1,20 @@ -use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; use anyhow::Result; use dashmap::DashSet; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use parking_lot::RwLock; use tokio::sync::watch; use crate::agent_issue_fix::AgentIssueFix; -use crate::atomic_value::AtomicValue; use crate::dispenses_slots::DispensesSlots; -use crate::produces_snapshot::ProducesSnapshot; -use crate::subscribes_to_updates::SubscribesToUpdates; +use paddler_messaging::atomic_value::AtomicValue; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates; pub struct SlotAggregatedStatus { desired_slots_total: i32, @@ -125,12 +125,7 @@ impl SlotAggregatedStatus { self.update_tx.send_replace(()); } - pub fn set_download_status( - &self, - current: u64, - total: Option, - filename: Option, - ) { + pub fn set_download_status(&self, current: u64, total: Option, filename: Option) { self.download_current.set(current); if let Some(value) = total { self.download_total.set(value); @@ -142,13 +137,9 @@ impl SlotAggregatedStatus { self.set_download_filename(filename); } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_download_filename(&self, filename: Option) { { - let mut filename_lock = self - .download_filename - .write() - .expect("Lock poisoned when setting download filename"); + let mut filename_lock = self.download_filename.write(); *filename_lock = filename; } @@ -157,13 +148,9 @@ impl SlotAggregatedStatus { self.update_tx.send_replace(()); } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_model_path(&self, model_path: Option) { { - let mut path_lock = self - .model_path - .write() - .expect("Lock poisoned when setting model path"); + let mut path_lock = self.model_path.write(); *path_lock = model_path; } @@ -212,24 +199,15 @@ impl SubscribesToUpdates for SlotAggregatedStatus { impl ProducesSnapshot for SlotAggregatedStatus { type Snapshot = SlotAggregatedStatusSnapshot; - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] fn make_snapshot(&self) -> Result { Ok(SlotAggregatedStatusSnapshot { issues: self.issues.iter().map(|item| item.clone()).collect(), desired_slots_total: self.desired_slots_total, download_current: self.download_current.get(), - download_filename: self - .download_filename - .read() - .expect("Lock poisoned when getting download filename") - .clone(), + download_filename: self.download_filename.read().clone(), download_indeterminate: self.download_indeterminate.get(), download_total: self.download_total.get(), - model_path: self - .model_path - .read() - .expect("Lock poisoned when getting model path") - .clone(), + model_path: self.model_path.read().clone(), slots_processing: self.slots_processing.get(), slots_total: self.slots_total.get(), state_application_status: self.state_application_status_code.get().try_into()?, @@ -243,15 +221,14 @@ impl ProducesSnapshot for SlotAggregatedStatus { mod tests { use std::time::Duration; - use anyhow::Result; - use paddler_types::agent_issue_params::ModelPath; - use paddler_types::agent_issue_params::SlotCannotStartParams; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::agent_issue_params::slot_cannot_start_params::SlotCannotStartParams; use tokio::time::timeout; use super::*; #[tokio::test] - async fn take_slot_wakes_subscribed_waiter() -> Result<()> { + async fn take_slot_wakes_subscribed_waiter() { let status = SlotAggregatedStatus::new(2); let mut update_rx = status.subscribe_to_updates(); @@ -259,10 +236,8 @@ mod tests { timeout(Duration::from_secs(1), update_rx.changed()) .await - .map_err(|err| anyhow::anyhow!("subscriber did not observe within deadline: {err}"))? - .map_err(|err| anyhow::anyhow!("watch sender dropped: {err}"))?; - - Ok(()) + .unwrap() + .unwrap(); } fn model_path(path: &str) -> ModelPath { @@ -294,66 +269,67 @@ mod tests { assert!(!status.has_issue(&issue)); } + fn is_slot_cannot_start(agent_issue: &AgentIssue) -> bool { + matches!(agent_issue, AgentIssue::SlotCannotStart(_)) + } + #[test] fn has_issue_like_matches_with_predicate() { let status = SlotAggregatedStatus::new(2); - let issue = AgentIssue::SlotCannotStart(SlotCannotStartParams { - error: "failed".to_owned(), - slot_index: 3, - }); - status.register_issue(issue); + status.register_issue(AgentIssue::ModelFileDoesNotExist(model_path("model_test"))); + + assert!(!status.has_issue_like(is_slot_cannot_start)); - assert!(status.has_issue_like(|agent_issue| { - matches!(agent_issue, AgentIssue::SlotCannotStart(_)) + status.register_issue(AgentIssue::SlotCannotStart(SlotCannotStartParams { + error: "failed".to_owned(), + slot_index: 3, })); + assert!(status.has_issue_like(is_slot_cannot_start)); + assert!(!status.has_issue_like(|agent_issue| { matches!(agent_issue, AgentIssue::ModelCannotBeLoaded(_)) })); } #[test] - fn increment_and_decrement_total_slots() -> Result<()> { + fn increment_and_decrement_total_slots() { let status = SlotAggregatedStatus::new(2); status.increment_total_slots(); status.increment_total_slots(); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.slots_total, 2); status.decrement_total_slots(); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.slots_total, 1); - - Ok(()) } #[test] - fn version_increments_on_slot_changes() -> Result<()> { + fn version_increments_on_slot_changes() { let status = SlotAggregatedStatus::new(2); - let initial_version = status.make_snapshot()?.version; + let initial_version = status.make_snapshot().unwrap().version; status.increment_total_slots(); - let updated_version = status.make_snapshot()?.version; + let updated_version = status.make_snapshot().unwrap().version; assert!(updated_version > initial_version); - - Ok(()) } #[test] - fn make_snapshot_returns_correct_values() -> Result<()> { + fn make_snapshot_returns_correct_values() { let status = SlotAggregatedStatus::new(4); status.set_model_path(Some("test_model".to_owned())); status.increment_total_slots(); status.increment_total_slots(); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.desired_slots_total, 4); assert_eq!(snapshot.model_path, Some("test_model".to_owned())); @@ -363,12 +339,84 @@ mod tests { snapshot.state_application_status, AgentStateApplicationStatus::Fresh ); + } - Ok(()) + #[test] + fn get_state_application_status_reflects_set_value() { + let status = SlotAggregatedStatus::new(2); + + assert_eq!( + status.get_state_application_status().unwrap(), + AgentStateApplicationStatus::Fresh + ); + + status.set_state_application_status(AgentStateApplicationStatus::Applied); + + assert_eq!( + status.get_state_application_status().unwrap(), + AgentStateApplicationStatus::Applied + ); + + let snapshot = status.make_snapshot().unwrap(); + assert_eq!( + snapshot.state_application_status, + AgentStateApplicationStatus::Applied + ); } #[test] - fn reset_clears_state() -> Result<()> { + fn make_snapshot_propagates_invalid_state_application_status() { + let status = SlotAggregatedStatus::new(2); + + status + .state_application_status_code + .set(AgentStateApplicationStatus::Stuck as i32 + 1); + + let snapshot_result = status.make_snapshot(); + + assert!(snapshot_result.is_err()); + } + + #[test] + fn register_issue_twice_keeps_single_entry() { + let status = SlotAggregatedStatus::new(2); + let issue = AgentIssue::ModelFileDoesNotExist(model_path("model_test")); + + status.register_issue(issue.clone()); + status.register_issue(issue); + + let snapshot = status.make_snapshot().unwrap(); + assert_eq!(snapshot.issues.len(), 1); + } + + #[test] + fn register_fix_without_matching_issue_keeps_issues() { + let status = SlotAggregatedStatus::new(2); + let issue = AgentIssue::ModelFileDoesNotExist(model_path("model_test")); + + status.register_issue(issue.clone()); + status.register_fix(&AgentIssueFix::ModelFileExists(model_path("other_model"))); + + assert!(status.has_issue(&issue)); + } + + #[test] + fn slots_processing_count_tracks_taken_slots() { + let status = SlotAggregatedStatus::new(2); + + assert_eq!(status.slots_processing_count(), 0); + + status.take_slot(); + + assert_eq!(status.slots_processing_count(), 1); + + status.release_slot(); + + assert_eq!(status.slots_processing_count(), 0); + } + + #[test] + fn reset_clears_state() { let status = SlotAggregatedStatus::new(2); status.set_model_path(Some("test_model".to_owned())); @@ -377,141 +425,123 @@ mod tests { status.reset(); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.slots_total, 0); assert_eq!(snapshot.slots_processing, 0); assert_eq!(snapshot.model_path, None); assert!(snapshot.issues.is_empty()); - - Ok(()) } #[test] - fn take_slot_and_release_slot() -> Result<()> { + fn take_slot_and_release_slot() { let status = SlotAggregatedStatus::new(2); status.take_slot(); - assert_eq!(status.make_snapshot()?.slots_processing, 1); + assert_eq!(status.make_snapshot().unwrap().slots_processing, 1); status.take_slot(); - assert_eq!(status.make_snapshot()?.slots_processing, 2); + assert_eq!(status.make_snapshot().unwrap().slots_processing, 2); status.release_slot(); - assert_eq!(status.make_snapshot()?.slots_processing, 1); - - Ok(()) + assert_eq!(status.make_snapshot().unwrap().slots_processing, 1); } #[test] - fn set_download_status_updates_all_fields() -> Result<()> { + fn set_download_status_updates_all_fields() { let status = SlotAggregatedStatus::new(2); status.set_download_status(100, Some(500), Some("model.gguf".to_owned())); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 100); assert_eq!(snapshot.download_total, 500); assert_eq!(snapshot.download_filename, Some("model.gguf".to_owned())); - - Ok(()) } #[test] - fn set_download_status_with_indeterminate_total_keeps_flag_true() -> Result<()> { + fn set_download_status_with_indeterminate_total_keeps_flag_true() { let status = SlotAggregatedStatus::new(2); status.set_download_status(123, None, Some("model.gguf".to_owned())); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 123); assert_eq!(snapshot.download_total, 0); assert!(snapshot.download_indeterminate); - - Ok(()) } #[test] - fn set_download_status_indeterminate_after_known_total_resets_download_total() -> Result<()> { + fn set_download_status_indeterminate_after_known_total_resets_download_total() { let status = SlotAggregatedStatus::new(2); status.set_download_status(0, Some(5000), Some("model.gguf".to_owned())); status.set_download_status(10, None, Some("model.gguf".to_owned())); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_total, 0); assert!(snapshot.download_indeterminate); - - Ok(()) } #[test] - fn set_download_status_with_known_total_flips_indeterminate_false() -> Result<()> { + fn set_download_status_with_known_total_flips_indeterminate_false() { let status = SlotAggregatedStatus::new(2); status.set_download_status(0, Some(5000), Some("model.gguf".to_owned())); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_total, 5000); assert!(!snapshot.download_indeterminate); - - Ok(()) } #[test] - fn increment_download_current_accumulates() -> Result<()> { + fn increment_download_current_accumulates() { let status = SlotAggregatedStatus::new(2); status.set_download_status(0, Some(1000), Some("model.gguf".to_owned())); status.increment_download_current(100); status.increment_download_current(200); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 300); assert_eq!(snapshot.download_total, 1000); - - Ok(()) } #[test] - fn reset_download_clears_download_fields() -> Result<()> { + fn reset_download_clears_download_fields() { let status = SlotAggregatedStatus::new(2); status.set_download_status(500, Some(1000), Some("model.gguf".to_owned())); status.reset_download(); - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 0); assert_eq!(snapshot.download_total, 0); assert!(snapshot.download_indeterminate); assert_eq!(snapshot.download_filename, None); - - Ok(()) } #[test] - fn set_uses_chat_template_override() -> Result<()> { + fn set_uses_chat_template_override() { let status = SlotAggregatedStatus::new(2); - assert!(!status.make_snapshot()?.uses_chat_template_override); + assert!(!status.make_snapshot().unwrap().uses_chat_template_override); status.set_uses_chat_template_override(true); - assert!(status.make_snapshot()?.uses_chat_template_override); + assert!(status.make_snapshot().unwrap().uses_chat_template_override); status.set_uses_chat_template_override(false); - assert!(!status.make_snapshot()?.uses_chat_template_override); - - Ok(()) + assert!(!status.make_snapshot().unwrap().uses_chat_template_override); } } diff --git a/paddler/src/slot_aggregated_status_download_progress.rs b/paddler_agent/src/slot_aggregated_status_download_progress.rs similarity index 80% rename from paddler/src/slot_aggregated_status_download_progress.rs rename to paddler_agent/src/slot_aggregated_status_download_progress.rs index a97ffe27..9467d07d 100644 --- a/paddler/src/slot_aggregated_status_download_progress.rs +++ b/paddler_agent/src/slot_aggregated_status_download_progress.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use hf_hub::api::tokio::Progress; -use paddler_types::agent_issue_params::ModelPath; +use paddler_messaging::agent_issue_params::model_path::ModelPath; use crate::agent_issue_fix::AgentIssueFix; use crate::slot_aggregated_status::SlotAggregatedStatus; @@ -26,8 +26,11 @@ impl Progress for SlotAggregatedStatusDownloadProgress { model_path: filename.to_owned(), })); - self.slot_aggregated_status - .set_download_status(0, Some(size as u64), Some(filename.to_owned())); + self.slot_aggregated_status.set_download_status( + 0, + Some(size as u64), + Some(filename.to_owned()), + ); } async fn update(&mut self, size: usize) { @@ -44,18 +47,17 @@ impl Progress for SlotAggregatedStatusDownloadProgress { mod tests { use std::sync::Arc; - use anyhow::Result; use hf_hub::api::tokio::Progress; - use paddler_types::agent_issue::AgentIssue; - use paddler_types::agent_issue_params::HuggingFaceDownloadLock; - use paddler_types::agent_issue_params::ModelPath; + use paddler_messaging::agent_issue::AgentIssue; + use paddler_messaging::agent_issue_params::hugging_face_download_lock::HuggingFaceDownloadLock; + use paddler_messaging::agent_issue_params::model_path::ModelPath; - use crate::produces_snapshot::ProducesSnapshot; use crate::slot_aggregated_status::SlotAggregatedStatus; use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloadProgress; + use paddler_messaging::produces_snapshot::ProducesSnapshot; #[tokio::test] - async fn test_init_sets_download_status_and_registers_fix() -> Result<()> { + async fn test_init_sets_download_status_and_registers_fix() { let status = Arc::new(SlotAggregatedStatus::new(2)); status.register_issue(AgentIssue::HuggingFaceCannotAcquireLock( @@ -71,7 +73,7 @@ mod tests { progress.init(1000, "model.gguf").await; - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_total, 1000); assert_eq!(snapshot.download_current, 0); @@ -84,12 +86,10 @@ mod tests { }, }, ))); - - Ok(()) } #[tokio::test] - async fn test_update_increments_download_current() -> Result<()> { + async fn test_update_increments_download_current() { let status = Arc::new(SlotAggregatedStatus::new(2)); let mut progress = SlotAggregatedStatusDownloadProgress::new(Arc::clone(&status)); @@ -97,16 +97,14 @@ mod tests { progress.update(300).await; progress.update(200).await; - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 500); assert_eq!(snapshot.download_total, 1000); - - Ok(()) } #[tokio::test] - async fn test_finish_resets_download() -> Result<()> { + async fn test_finish_resets_download() { let status = Arc::new(SlotAggregatedStatus::new(2)); let mut progress = SlotAggregatedStatusDownloadProgress::new(Arc::clone(&status)); @@ -114,12 +112,10 @@ mod tests { progress.update(1000).await; progress.finish().await; - let snapshot = status.make_snapshot()?; + let snapshot = status.make_snapshot().unwrap(); assert_eq!(snapshot.download_current, 0); assert_eq!(snapshot.download_total, 0); assert_eq!(snapshot.download_filename, None); - - Ok(()) } } diff --git a/paddler/src/slot_aggregated_status_manager.rs b/paddler_agent/src/slot_aggregated_status_manager.rs similarity index 100% rename from paddler/src/slot_aggregated_status_manager.rs rename to paddler_agent/src/slot_aggregated_status_manager.rs diff --git a/paddler/src/agent/slot_guard.rs b/paddler_agent/src/slot_guard.rs similarity index 90% rename from paddler/src/agent/slot_guard.rs rename to paddler_agent/src/slot_guard.rs index 2503739a..b202b0a3 100644 --- a/paddler/src/agent/slot_guard.rs +++ b/paddler_agent/src/slot_guard.rs @@ -29,15 +29,14 @@ mod tests { use std::sync::Arc; use std::time::Duration; - use anyhow::Result; use tokio_util::sync::CancellationToken; - use crate::agent::drain_in_flight_requests::drain_in_flight_requests; - use crate::agent::slot_guard::SlotGuard; + use crate::drain_in_flight_requests::drain_in_flight_requests; use crate::slot_aggregated_status_manager::SlotAggregatedStatusManager; + use crate::slot_guard::SlotGuard; #[tokio::test] - async fn increments_slot_on_construct_and_releases_on_drop() -> Result<()> { + async fn increments_slot_on_construct_and_releases_on_drop() { let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(4)); assert_eq!( @@ -68,12 +67,10 @@ mod tests { .slots_processing_count(), 0 ); - - Ok(()) } #[tokio::test] - async fn drain_in_flight_requests_blocks_until_guard_dropped() -> Result<()> { + async fn drain_in_flight_requests_blocks_until_guard_dropped() { let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(4)); let shutdown = CancellationToken::new(); @@ -99,8 +96,10 @@ mod tests { drop(guard); let unblock_window = Duration::from_millis(500); - tokio::time::timeout(unblock_window, drain_task).await???; - - Ok(()) + tokio::time::timeout(unblock_window, drain_task) + .await + .unwrap() + .unwrap() + .unwrap(); } } diff --git a/paddler/src/tool_call_buffer.rs b/paddler_agent/src/tool_call_buffer.rs similarity index 100% rename from paddler/src/tool_call_buffer.rs rename to paddler_agent/src/tool_call_buffer.rs diff --git a/paddler/src/tool_call_event.rs b/paddler_agent/src/tool_call_event.rs similarity index 70% rename from paddler/src/tool_call_event.rs rename to paddler_agent/src/tool_call_event.rs index 54a692c1..37f39f48 100644 --- a/paddler/src/tool_call_event.rs +++ b/paddler_agent/src/tool_call_event.rs @@ -1,9 +1,9 @@ use llama_cpp_bindings::ParsedToolCall; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::raw_tool_call_tokens::RawToolCallTokens; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::raw_tool_call_tokens::RawToolCallTokens; use crate::tool_call_pipeline_error::ToolCallPipelineError; -use crate::tool_call_validation_error::ToolCallValidationError; +use paddler_messaging::tool_call_validation_error::ToolCallValidationError; #[derive(Debug)] pub enum ToolCallEvent { @@ -50,17 +50,15 @@ impl ToolCallEvent { #[cfg(test)] mod tests { - use anyhow::Result; - use anyhow::bail; use llama_cpp_bindings::ParsedToolCall; use llama_cpp_bindings::ToolCallArguments; - use paddler_types::generated_token_result::GeneratedTokenResult; - use paddler_types::raw_tool_call_tokens::RawToolCallTokens; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::raw_tool_call_tokens::RawToolCallTokens; use serde_json::json; use super::ToolCallEvent; use crate::tool_call_pipeline_error::ToolCallPipelineError; - use crate::tool_call_validation_error::ToolCallValidationError; + use paddler_messaging::tool_call_validation_error::ToolCallValidationError; #[test] fn pending_classifies_as_pending() { @@ -110,7 +108,7 @@ mod tests { } #[test] - fn resolved_converts_to_tool_call_parsed() -> Result<()> { + fn resolved_converts_to_tool_call_parsed() { let parsed = ParsedToolCall::new( "id".to_owned(), "tool".to_owned(), @@ -118,39 +116,46 @@ mod tests { ); let event = ToolCallEvent::Resolved(vec![parsed.clone()]); - match event.into_generated_token_result() { - Some(GeneratedTokenResult::ToolCallParsed(calls)) if calls == vec![parsed] => Ok(()), - other => bail!("expected ToolCallParsed with one call, got {other:?}"), - } + let result = event + .into_generated_token_result() + .expect("Resolved must convert to Some"); + + assert!( + matches!(result, GeneratedTokenResult::ToolCallParsed(calls) if calls == vec![parsed]) + ); } #[test] - fn parse_failed_converts_to_tool_call_parse_failed_with_message() -> Result<()> { + fn parse_failed_converts_to_tool_call_parse_failed_with_message() { let event = ToolCallEvent::ParseFailed(ToolCallPipelineError::EmptyBuffer); - match event.into_generated_token_result() { - Some(GeneratedTokenResult::ToolCallParseFailed(message)) if !message.is_empty() => { - Ok(()) - } - other => bail!("expected ToolCallParseFailed with non-empty message, got {other:?}"), - } + let result = event + .into_generated_token_result() + .expect("ParseFailed must convert to Some"); + + assert!(matches!( + result, + GeneratedTokenResult::ToolCallParseFailed(message) + if message == ToolCallPipelineError::EmptyBuffer.to_string() + )); } #[test] - fn validation_failed_converts_to_tool_call_validation_failed_with_messages() -> Result<()> { + fn validation_failed_converts_to_tool_call_validation_failed_with_messages() { let event = ToolCallEvent::ValidationFailed(vec![ToolCallValidationError::UnknownToolName( "missing".to_owned(), )]); - match event.into_generated_token_result() { - Some(GeneratedTokenResult::ToolCallValidationFailed(messages)) - if messages.len() == 1 && messages[0].contains("missing") => - { - Ok(()) - } - other => bail!("expected ToolCallValidationFailed mentioning 'missing', got {other:?}"), - } + let result = event + .into_generated_token_result() + .expect("ValidationFailed must convert to Some"); + + assert!(matches!( + result, + GeneratedTokenResult::ToolCallValidationFailed(messages) + if messages.len() == 1 && messages[0].contains("missing") + )); } #[test] @@ -166,20 +171,20 @@ mod tests { } #[test] - fn unrecognized_format_converts_to_unrecognized_tool_call_format_preserving_payload() - -> Result<()> { + fn unrecognized_format_converts_to_unrecognized_tool_call_format_preserving_payload() { let event = ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { text: "raw output".to_owned(), ffi_error_message: "parser bailed".to_owned(), }); - match event.into_generated_token_result() { - Some(GeneratedTokenResult::UnrecognizedToolCallFormat(raw)) => { - assert_eq!(raw.text, "raw output"); - assert_eq!(raw.ffi_error_message, "parser bailed"); - Ok(()) - } - other => bail!("expected UnrecognizedToolCallFormat preserving payload, got {other:?}"), - } + let result = event + .into_generated_token_result() + .expect("UnrecognizedFormat must convert to Some"); + + assert!(matches!( + result, + GeneratedTokenResult::UnrecognizedToolCallFormat(raw) + if raw.text == "raw output" && raw.ffi_error_message == "parser bailed" + )); } } diff --git a/paddler/src/tool_call_pipeline.rs b/paddler_agent/src/tool_call_pipeline.rs similarity index 69% rename from paddler/src/tool_call_pipeline.rs rename to paddler_agent/src/tool_call_pipeline.rs index 016de45f..2b3c5ea9 100644 --- a/paddler/src/tool_call_pipeline.rs +++ b/paddler_agent/src/tool_call_pipeline.rs @@ -4,8 +4,8 @@ use llama_cpp_bindings::ChatMessageParseOutcome; use llama_cpp_bindings::ParsedToolCall; use llama_cpp_bindings::RawChatMessage; use llama_cpp_bindings::model::LlamaModel; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::raw_tool_call_tokens::RawToolCallTokens; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::raw_tool_call_tokens::RawToolCallTokens; use crate::tool_call_buffer::ToolCallBuffer; use crate::tool_call_event::ToolCallEvent; @@ -68,36 +68,6 @@ impl ToolCallPipeline { self.finalize().into_generated_token_result() } - #[must_use] - pub fn try_partial(&self) -> ToolCallEvent { - let input = self.buffer.as_str(); - if input.is_empty() { - return ToolCallEvent::Pending; - } - - match self.model.parse_chat_message(&self.tools_json, input, true) { - Ok(ChatMessageParseOutcome::Recognized(parsed)) if parsed.tool_calls.is_empty() => { - ToolCallEvent::Pending - } - Ok(ChatMessageParseOutcome::Recognized(parsed)) => { - self.validate_resolved(parsed.tool_calls) - } - Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage { - text, - ffi_error_message, - .. - })) => ToolCallEvent::UnrecognizedFormat(RawToolCallTokens { - text, - ffi_error_message, - }), - Err(err) => ToolCallEvent::ParseFailed(ToolCallPipelineError::Bindings(err)), - } - } - - pub fn reset(&mut self) { - self.buffer.clear(); - } - #[must_use] pub const fn buffer_is_empty(&self) -> bool { self.buffer.is_empty() diff --git a/paddler/src/tool_call_pipeline_error.rs b/paddler_agent/src/tool_call_pipeline_error.rs similarity index 100% rename from paddler/src/tool_call_pipeline_error.rs rename to paddler_agent/src/tool_call_pipeline_error.rs diff --git a/paddler/src/tool_call_validator.rs b/paddler_agent/src/tool_call_validator.rs similarity index 65% rename from paddler/src/tool_call_validator.rs rename to paddler_agent/src/tool_call_validator.rs index 63f48a80..e065ad94 100644 --- a/paddler/src/tool_call_validator.rs +++ b/paddler_agent/src/tool_call_validator.rs @@ -3,19 +3,13 @@ use std::collections::HashMap; use jsonschema::Validator; use llama_cpp_bindings::ParsedToolCall; use llama_cpp_bindings::ToolCallArguments; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; - -use crate::tool_call_validation_error::ToolCallValidationError; - -#[derive(Debug, thiserror::Error)] -pub enum ValidatorBuildError { - #[error("could not serialize tool {tool_name:?} parameters to JSON: {message}")] - SerializationFailed { tool_name: String, message: String }, - #[error("tool {tool_name:?} parameters are not a valid JSON Schema: {message}")] - InvalidSchema { tool_name: String, message: String }, -} +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + +use paddler_messaging::tool_call_validation_error::ToolCallValidationError; + +use crate::validator_build_error::ValidatorBuildError; enum ValidationStrategy { JsonObjectOnly, @@ -100,21 +94,20 @@ impl ToolCallValidator { #[cfg(test)] mod tests { - use anyhow::Result; - use anyhow::bail; use llama_cpp_bindings::ParsedToolCall; use llama_cpp_bindings::ToolCallArguments; - use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; - use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; - use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; - use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; - use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; use serde_json::json; use super::ToolCallValidator; - use crate::tool_call_validation_error::ToolCallValidationError; + use crate::validator_build_error::ValidatorBuildError; + use paddler_messaging::tool_call_validation_error::ToolCallValidationError; fn valid_json_arguments(value: Value) -> ToolCallArguments { ToolCallArguments::ValidJson(value) @@ -152,126 +145,119 @@ mod tests { } #[test] - fn schema_validator_accepts_matching_arguments() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + fn schema_validator_accepts_matching_arguments() { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "get_weather".to_owned(), valid_json_arguments(json!({"location": "Paris"})), ); - validator.validate(&parsed)?; - - Ok(()) + assert!(validator.validate(&parsed).is_ok()); } #[test] - fn schema_validator_rejects_missing_required_field() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + fn schema_validator_rejects_missing_required_field() { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "get_weather".to_owned(), valid_json_arguments(json!({})), ); - match validator.validate(&parsed) { - Err(ToolCallValidationError::SchemaMismatch { tool_name, .. }) => { - assert_eq!(tool_name, "get_weather"); - Ok(()) - } - other => bail!("expected SchemaMismatch, got {other:?}"), - } + let validation_error = validator.validate(&parsed).err().unwrap(); + + assert!(matches!( + validation_error, + ToolCallValidationError::SchemaMismatch { tool_name, .. } if tool_name == "get_weather" + )); } #[test] - fn schema_validator_rejects_wrong_type() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + fn schema_validator_rejects_wrong_type() { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "get_weather".to_owned(), valid_json_arguments(json!({"location": 42})), ); - match validator.validate(&parsed) { - Err(ToolCallValidationError::SchemaMismatch { .. }) => Ok(()), - other => bail!("expected SchemaMismatch, got {other:?}"), - } + let validation_error = validator.validate(&parsed).err().unwrap(); + + assert!(matches!( + validation_error, + ToolCallValidationError::SchemaMismatch { tool_name, .. } if tool_name == "get_weather" + )); } #[test] - fn unknown_tool_name_returns_error() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + fn unknown_tool_name_returns_error() { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "set_thermostat".to_owned(), valid_json_arguments(json!({"value": 21})), ); - match validator.validate(&parsed) { - Err(ToolCallValidationError::UnknownToolName(name)) => { - assert_eq!(name, "set_thermostat"); - Ok(()) - } - other => bail!("expected UnknownToolName, got {other:?}"), - } + let validation_error = validator.validate(&parsed).err().unwrap(); + + assert!(matches!( + validation_error, + ToolCallValidationError::UnknownToolName(name) if name == "set_thermostat" + )); } #[test] - fn invalid_json_arguments_pass_validation_silently() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()])?; + fn invalid_json_arguments_pass_validation_silently() { + let validator = ToolCallValidator::from_tools(&[weather_tool_with_schema()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "get_weather".to_owned(), ToolCallArguments::InvalidJson("not json".to_owned()), ); - validator.validate(&parsed)?; - - Ok(()) + assert!(validator.validate(&parsed).is_ok()); } #[test] - fn schemaless_tool_accepts_any_object() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[schemaless_tool()])?; + fn schemaless_tool_accepts_any_object() { + let validator = ToolCallValidator::from_tools(&[schemaless_tool()]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "freeform".to_owned(), valid_json_arguments(json!({"x": 1, "y": 2})), ); - validator.validate(&parsed)?; - - Ok(()) + assert!(validator.validate(&parsed).is_ok()); } #[test] - fn known_tool_names_returns_all_registered_names() -> Result<()> { + fn known_tool_names_returns_all_registered_names() { let validator = - ToolCallValidator::from_tools(&[weather_tool_with_schema(), schemaless_tool()])?; + ToolCallValidator::from_tools(&[weather_tool_with_schema(), schemaless_tool()]) + .unwrap(); let mut names = validator.known_tool_names(); names.sort_unstable(); assert_eq!(names, vec!["freeform", "get_weather"]); - - Ok(()) } #[test] - fn empty_tools_yields_validator_that_rejects_any_call() -> Result<()> { - let validator = ToolCallValidator::from_tools(&[])?; + fn empty_tools_yields_validator_that_rejects_any_call() { + let validator = ToolCallValidator::from_tools(&[]).unwrap(); let parsed = ParsedToolCall::new( "id".to_owned(), "anything".to_owned(), valid_json_arguments(json!({})), ); + let validation_error = validator.validate(&parsed).err().unwrap(); + assert!(matches!( - validator.validate(&parsed), - Err(ToolCallValidationError::UnknownToolName(_)) + validation_error, + ToolCallValidationError::UnknownToolName(name) if name == "anything" )); - - Ok(()) } fn tool_with_invalid_property_schema() -> Tool { @@ -293,20 +279,15 @@ mod tests { } #[test] - fn invalid_property_schema_rejects_validator_build() -> Result<()> { - let error = ToolCallValidator::from_tools(&[tool_with_invalid_property_schema()]) + fn invalid_property_schema_rejects_validator_build() { + let build_error = ToolCallValidator::from_tools(&[tool_with_invalid_property_schema()]) .err() - .ok_or_else(|| anyhow::anyhow!("expected ValidatorBuildError, got Ok"))?; + .unwrap(); - match error { - super::ValidatorBuildError::InvalidSchema { tool_name, .. } => { - assert_eq!(tool_name, "broken_tool"); - Ok(()) - } - super::ValidatorBuildError::SerializationFailed { .. } => { - bail!("expected InvalidSchema, got SerializationFailed: {error:?}"); - } - } + assert!(matches!( + build_error, + ValidatorBuildError::InvalidSchema { tool_name, .. } if tool_name == "broken_tool" + )); } fn tool_with_invalid_additional_properties_schema() -> Tool { @@ -325,20 +306,15 @@ mod tests { } #[test] - fn invalid_additional_properties_schema_rejects_validator_build() -> Result<()> { - let error = + fn invalid_additional_properties_schema_rejects_validator_build() { + let build_error = ToolCallValidator::from_tools(&[tool_with_invalid_additional_properties_schema()]) .err() - .ok_or_else(|| anyhow::anyhow!("expected ValidatorBuildError, got Ok"))?; + .unwrap(); - match error { - super::ValidatorBuildError::InvalidSchema { tool_name, .. } => { - assert_eq!(tool_name, "broken_additional"); - Ok(()) - } - super::ValidatorBuildError::SerializationFailed { .. } => { - bail!("expected InvalidSchema, got SerializationFailed: {error:?}"); - } - } + assert!(matches!( + build_error, + ValidatorBuildError::InvalidSchema { tool_name, .. } if tool_name == "broken_additional" + )); } } diff --git a/paddler_agent/src/validator_build_error.rs b/paddler_agent/src/validator_build_error.rs new file mode 100644 index 00000000..23876a43 --- /dev/null +++ b/paddler_agent/src/validator_build_error.rs @@ -0,0 +1,7 @@ +#[derive(Debug, thiserror::Error)] +pub enum ValidatorBuildError { + #[error("could not serialize tool {tool_name:?} parameters to JSON: {message}")] + SerializationFailed { tool_name: String, message: String }, + #[error("tool {tool_name:?} parameters are not a valid JSON Schema: {message}")] + InvalidSchema { tool_name: String, message: String }, +} diff --git a/paddler_balancer/Cargo.toml b/paddler_balancer/Cargo.toml new file mode 100644 index 00000000..0be34360 --- /dev/null +++ b/paddler_balancer/Cargo.toml @@ -0,0 +1,62 @@ +[package] +name = "paddler_balancer" +authors.workspace = true +description.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +actix-cors = { workspace = true } +actix-web = { workspace = true } +actix-web-lab = { workspace = true } +actix-ws = { workspace = true } +anyhow = { workspace = true } +async-stream = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +cadence = { workspace = true } +dashmap = { workspace = true } +futures = { workspace = true } +futures-util = { workspace = true } +indoc = { workspace = true } +llama-cpp-bindings-types = { workspace = true } +log = { workspace = true } +nanoid = { workspace = true } +parking_lot = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +shellexpand = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +tokio-util = { workspace = true } +trzcina = { workspace = true } +url = { workspace = true } +paddler_messaging = { workspace = true } +paddler_state_conversion = { workspace = true } + +# web dashboard deps +askama = { workspace = true, optional = true } +esbuild-metafile = { workspace = true, optional = true } +mime_guess = { workspace = true, optional = true } +rust-embed = { workspace = true, optional = true } + +[features] +default = [] +web_admin_panel = [ + "dep:askama", + "dep:esbuild-metafile", + "dep:mime_guess", + "dep:rust-embed", +] + +[dev-dependencies] +paddler_openai_response_format_validator = { workspace = true } +tempfile = { workspace = true } +tokio-test = { workspace = true } + +[lints] +workspace = true diff --git a/paddler/src/balancer/agent_controller.rs b/paddler_balancer/src/agent_controller.rs similarity index 55% rename from paddler/src/balancer/agent_controller.rs rename to paddler_balancer/src/agent_controller.rs index 27c8e3d7..5771c827 100644 --- a/paddler/src/balancer/agent_controller.rs +++ b/paddler_balancer/src/agent_controller.rs @@ -1,6 +1,5 @@ use std::collections::BTreeSet; use std::sync::Arc; -use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; @@ -9,35 +8,36 @@ use anyhow::Result; use async_trait::async_trait; use log::debug; use nanoid::nanoid; +use parking_lot::RwLock; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::agent_desired_state::AgentDesiredState; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; - -use crate::agent::jsonrpc::Message as AgentJsonRpcMessage; -use crate::agent::jsonrpc::Notification as AgentJsonRpcNotification; -use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; -use crate::agent::jsonrpc::notification_params::SetStateParams; -use crate::atomic_value::AtomicValue; -use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::manages_senders_controller::ManagesSendersController; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use crate::produces_snapshot::ProducesSnapshot; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::jsonrpc::request_envelope::RequestEnvelope; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; + +use crate::agent_controller_update_result::AgentControllerUpdateResult; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use crate::embedding_sender_collection::EmbeddingSenderCollection; +use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::manages_senders::ManagesSenders; +use crate::manages_senders_controller::ManagesSendersController; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; use crate::sends_rpc_message::SendsRpcMessage; use crate::sets_desired_state::SetsDesiredState; +use paddler_messaging::atomic_value::AtomicValue; +use paddler_messaging::management_socket::agent::message::Message as AgentJsonRpcMessage; +use paddler_messaging::management_socket::agent::notification::Notification as AgentJsonRpcNotification; +use paddler_messaging::management_socket::agent::notification_params::set_state_params::SetStateParams; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; +use paddler_messaging::produces_snapshot::ProducesSnapshot; pub struct AgentController { pub agent_message_tx: mpsc::UnboundedSender, @@ -73,17 +73,12 @@ impl AgentController { .await } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_download_filename(&self) -> Option { - self.download_filename - .read() - .expect("Poisoned lock on download filename") - .clone() + self.download_filename.read().clone() } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_issues(&self) -> BTreeSet { - self.issues.read().expect("Poisoned lock on issues").clone() + self.issues.read().clone() } pub async fn get_model_metadata( @@ -96,37 +91,24 @@ impl AgentController { .await } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_model_path(&self) -> Option { - self.model_path - .read() - .expect("Poisoned lock on model path") - .clone() + self.model_path.read().clone() } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_download_filename(&self, filename: Option) { - let mut locked_filename = self - .download_filename - .write() - .expect("Poisoned lock on download filename"); + let mut locked_filename = self.download_filename.write(); *locked_filename = filename; } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_issues(&self, issues: BTreeSet) { - let mut locked_issues = self.issues.write().expect("Poisoned lock on issues"); + let mut locked_issues = self.issues.write(); *locked_issues = issues; } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn set_model_path(&self, model_path: Option) { - let mut locked_path = self - .model_path - .write() - .expect("Poisoned lock on model path"); + let mut locked_path = self.model_path.write(); *locked_path = model_path; } @@ -355,10 +337,15 @@ impl SetsDesiredState for AgentController { #[cfg(test)] mod tests { - use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use super::*; + fn is_updated(result: &AgentControllerUpdateResult) -> bool { + matches!(result, AgentControllerUpdateResult::Updated) + } + fn fresh_agent_controller() -> AgentController { let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); @@ -391,7 +378,7 @@ mod tests { } #[test] - fn multi_field_update_stores_all_changed_atomic_fields() -> Result<()> { + fn multi_field_update_stores_all_changed_atomic_fields() { let agent_controller = fresh_agent_controller(); let snapshot = SlotAggregatedStatusSnapshot { @@ -411,38 +398,220 @@ mod tests { let result = agent_controller.update_from_slot_aggregated_status_snapshot(snapshot); - if !matches!(result, AgentControllerUpdateResult::Updated) { - anyhow::bail!("update with multiple changed fields must return Updated"); - } + assert!(is_updated(&result)); + assert_eq!(agent_controller.desired_slots_total.get(), 4); + assert_eq!(agent_controller.download_current.get(), 10); + assert_eq!(agent_controller.download_total.get(), 100); + assert_eq!(agent_controller.slots_total.get(), 4); + assert!(agent_controller.uses_chat_template_override.get()); + } - if agent_controller.desired_slots_total.get() != 4 { - anyhow::bail!( - "desired_slots_total must be stored: expected 4, got {}", - agent_controller.desired_slots_total.get() - ); - } - if agent_controller.download_current.get() != 10 { - anyhow::bail!( - "download_current must be stored: expected 10, got {}", - agent_controller.download_current.get() - ); - } - if agent_controller.download_total.get() != 100 { - anyhow::bail!( - "download_total must be stored: expected 100, got {}", - agent_controller.download_total.get() - ); - } - if agent_controller.slots_total.get() != 4 { - anyhow::bail!( - "slots_total must be stored: expected 4, got {}", - agent_controller.slots_total.get() - ); - } - if !agent_controller.uses_chat_template_override.get() { - anyhow::bail!("uses_chat_template_override must be stored: expected true, got false"); - } + #[test] + fn update_with_older_version_is_discarded() { + let agent_controller = fresh_agent_controller(); - Ok(()) + agent_controller.newest_update_version.set(5); + + let snapshot = SlotAggregatedStatusSnapshot { + desired_slots_total: 9, + download_current: 0, + download_filename: None, + download_indeterminate: true, + download_total: 0, + issues: BTreeSet::new(), + model_path: None, + slots_processing: 0, + slots_total: 0, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: false, + version: 1, + }; + + let result = agent_controller.update_from_slot_aggregated_status_snapshot(snapshot); + + assert!(!is_updated(&result)); + assert_eq!(agent_controller.desired_slots_total.get(), 0); + } + + #[test] + fn update_stores_new_download_filename_model_path_and_issues() { + let agent_controller = fresh_agent_controller(); + + let mut issues = BTreeSet::new(); + issues.insert(AgentIssue::ModelFileDoesNotExist(ModelPath { + model_path: "/models/test.gguf".to_owned(), + })); + + let snapshot = SlotAggregatedStatusSnapshot { + desired_slots_total: 0, + download_current: 0, + download_filename: Some("weights.gguf".to_owned()), + download_indeterminate: true, + download_total: 0, + issues: issues.clone(), + model_path: Some("/models/test.gguf".to_owned()), + slots_processing: 0, + slots_total: 0, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: false, + version: 1, + }; + + let result = agent_controller.update_from_slot_aggregated_status_snapshot(snapshot); + + assert!(is_updated(&result)); + assert_eq!( + agent_controller.get_download_filename(), + Some("weights.gguf".to_owned()) + ); + assert_eq!( + agent_controller.get_model_path(), + Some("/models/test.gguf".to_owned()) + ); + assert_eq!(agent_controller.get_issues(), issues); + } + + #[test] + fn update_with_identical_values_reports_no_meaningful_changes() { + let agent_controller = fresh_agent_controller(); + + let snapshot = SlotAggregatedStatusSnapshot { + desired_slots_total: 0, + download_current: 0, + download_filename: None, + download_indeterminate: true, + download_total: 0, + issues: BTreeSet::new(), + model_path: None, + slots_processing: 0, + slots_total: 0, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: false, + version: 1, + }; + + let result = agent_controller.update_from_slot_aggregated_status_snapshot(snapshot); + + assert!(!is_updated(&result)); + } + + #[test] + fn make_snapshot_fails_for_invalid_state_application_status() { + let agent_controller = fresh_agent_controller(); + + agent_controller.state_application_status_code.set(99); + + let result = agent_controller.make_snapshot(); + let error = result.err().unwrap(); + + assert!( + error + .to_string() + .contains("Invalid value for AgentStateApplicationStatus") + ); + } + + #[tokio::test] + async fn get_chat_template_override_registers_pending_request() { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent_controller = AgentController { + agent_message_tx, + ..fresh_agent_controller() + }; + + let controller = agent_controller.get_chat_template_override().await.unwrap(); + + assert!( + controller + .response_sender_collection + .get_sender_collection() + .contains_key(&controller.request_id) + ); + } + + #[tokio::test] + async fn handle_raw_prompt_streaming_response_registers_sender() { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent_controller = AgentController { + agent_message_tx, + ..fresh_agent_controller() + }; + + let controller = HandlesAgentStreamingResponse::::handle_streaming_response( + &agent_controller, + "raw-prompt-request".to_owned(), + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: "hello".to_owned(), + }, + ) + .await + .unwrap(); + + assert_eq!(controller.request_id, "raw-prompt-request"); + assert!( + controller + .response_sender_collection + .get_sender_collection() + .contains_key("raw-prompt-request") + ); + } + + #[tokio::test] + async fn handle_streaming_response_fails_when_request_id_already_registered() { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent_controller = AgentController { + agent_message_tx, + ..fresh_agent_controller() + }; + + let _first_controller = + HandlesAgentStreamingResponse::::handle_streaming_response( + &agent_controller, + "duplicate-request".to_owned(), + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: "first".to_owned(), + }, + ) + .await + .unwrap(); + + let result = + HandlesAgentStreamingResponse::::handle_streaming_response( + &agent_controller, + "duplicate-request".to_owned(), + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: "second".to_owned(), + }, + ) + .await; + + let error = result.err().unwrap(); + + assert_eq!( + error.to_string(), + "Sender for request_id duplicate-request already exists" + ); + } + + #[tokio::test] + async fn send_rpc_message_fails_when_agent_message_receiver_dropped() { + let (agent_message_tx, agent_message_rx) = mpsc::unbounded_channel(); + + drop(agent_message_rx); + + let agent_controller = AgentController { + agent_message_tx, + ..fresh_agent_controller() + }; + + let result = agent_controller.get_chat_template_override().await; + + assert!(result.is_err()); } } diff --git a/paddler/src/balancer/agent_controller_pool.rs b/paddler_balancer/src/agent_controller_pool.rs similarity index 51% rename from paddler/src/balancer/agent_controller_pool.rs rename to paddler_balancer/src/agent_controller_pool.rs index 8c4188ba..795175ca 100644 --- a/paddler/src/balancer/agent_controller_pool.rs +++ b/paddler_balancer/src/agent_controller_pool.rs @@ -3,19 +3,19 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use dashmap::DashMap; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::agent_desired_state::AgentDesiredState; use tokio::sync::watch; use super::agent_controller::AgentController; use super::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; -use crate::balancer::agent_controller_slot_guard::AgentControllerSlotGuard; -use crate::balancer::dispatch_candidate::DispatchCandidate; -use crate::balancer::dispatched_agent::DispatchedAgent; -use crate::produces_snapshot::ProducesSnapshot; +use crate::agent_controller_slot_guard::AgentControllerSlotGuard; +use crate::dispatch_candidate::DispatchCandidate; +use crate::dispatched_agent::DispatchedAgent; use crate::sets_desired_state::SetsDesiredState; -use crate::subscribes_to_updates::SubscribesToUpdates; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates; pub struct AgentControllerPool { pub agents: DashMap>, @@ -182,10 +182,62 @@ impl SetsDesiredState for AgentControllerPool { #[cfg(test)] mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; use std::time::Duration; + use tokio::sync::mpsc; use tokio::sync::watch; use tokio::time::timeout; + use tokio_util::sync::CancellationToken; + + use super::AgentControllerPool; + use crate::agent_controller::AgentController; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::produces_snapshot::ProducesSnapshot; + + fn agent_controller_with_slots( + slots_processing: i32, + slots_total: i32, + ) -> Arc { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-test".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(slots_processing), + slots_total: AtomicValue::::new(slots_total), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }) + } #[tokio::test] async fn watch_receiver_observes_send_fired_before_changed_await() { @@ -200,4 +252,61 @@ mod tests { "watch::Receiver must observe a send fired before .changed() is awaited" ); } + + #[test] + fn register_agent_controller_rejects_duplicate_id() { + let pool = AgentControllerPool::default(); + + assert!( + pool.register_agent_controller( + "duplicate".to_owned(), + agent_controller_with_slots(0, 1), + ) + .is_ok() + ); + + let duplicate_result = pool + .register_agent_controller("duplicate".to_owned(), agent_controller_with_slots(0, 1)); + + assert_eq!( + duplicate_result.err().unwrap().to_string(), + "AgentController already registered" + ); + } + + #[test] + fn remove_agent_controller_returns_false_for_unknown_id() { + let pool = AgentControllerPool::default(); + + assert!(!pool.remove_agent_controller("never-registered").unwrap()); + } + + #[test] + fn total_slots_sums_processing_and_total_across_agents() { + let pool = AgentControllerPool::default(); + + pool.register_agent_controller("first".to_owned(), agent_controller_with_slots(1, 4)) + .unwrap(); + pool.register_agent_controller("second".to_owned(), agent_controller_with_slots(2, 8)) + .unwrap(); + + let total_slots = pool.total_slots(); + + assert_eq!(total_slots.slots_processing, 3); + assert_eq!(total_slots.slots_total, 12); + } + + #[test] + fn make_snapshot_includes_each_registered_agent() { + let pool = AgentControllerPool::default(); + + pool.register_agent_controller("only".to_owned(), agent_controller_with_slots(2, 5)) + .unwrap(); + + let snapshot = pool.make_snapshot().unwrap(); + + assert_eq!(snapshot.agents.len(), 1); + assert_eq!(snapshot.agents[0].slots_processing, 2); + assert_eq!(snapshot.agents[0].slots_total, 5); + } } diff --git a/paddler/src/balancer/agent_controller_pool_total_slots.rs b/paddler_balancer/src/agent_controller_pool_total_slots.rs similarity index 100% rename from paddler/src/balancer/agent_controller_pool_total_slots.rs rename to paddler_balancer/src/agent_controller_pool_total_slots.rs diff --git a/paddler/src/balancer/agent_controller_slot_guard.rs b/paddler_balancer/src/agent_controller_slot_guard.rs similarity index 91% rename from paddler/src/balancer/agent_controller_slot_guard.rs rename to paddler_balancer/src/agent_controller_slot_guard.rs index 263b823d..4612db7a 100644 --- a/paddler/src/balancer/agent_controller_slot_guard.rs +++ b/paddler_balancer/src/agent_controller_slot_guard.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::sync::watch; -use crate::balancer::agent_controller::AgentController; +use crate::agent_controller::AgentController; pub struct AgentControllerSlotGuard { agent_controller: Arc, diff --git a/paddler/src/balancer/agent_controller_update_result.rs b/paddler_balancer/src/agent_controller_update_result.rs similarity index 100% rename from paddler/src/balancer/agent_controller_update_result.rs rename to paddler_balancer/src/agent_controller_update_result.rs diff --git a/paddler/src/balancer_applicable_state.rs b/paddler_balancer/src/balancer_applicable_state.rs similarity index 64% rename from paddler/src/balancer_applicable_state.rs rename to paddler_balancer/src/balancer_applicable_state.rs index eafc71ac..6efaa8b0 100644 --- a/paddler/src/balancer_applicable_state.rs +++ b/paddler_balancer/src/balancer_applicable_state.rs @@ -1,4 +1,4 @@ -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_desired_state::AgentDesiredState; #[derive(Clone, Debug)] pub struct BalancerApplicableState { diff --git a/paddler/src/balancer_applicable_state_holder.rs b/paddler_balancer/src/balancer_applicable_state_holder.rs similarity index 57% rename from paddler/src/balancer_applicable_state_holder.rs rename to paddler_balancer/src/balancer_applicable_state_holder.rs index 94902204..332fc22f 100644 --- a/paddler/src/balancer_applicable_state_holder.rs +++ b/paddler_balancer/src/balancer_applicable_state_holder.rs @@ -1,10 +1,9 @@ -use std::sync::RwLock; - -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_desired_state::AgentDesiredState; +use parking_lot::RwLock; use tokio::sync::watch; use crate::balancer_applicable_state::BalancerApplicableState; -use crate::subscribes_to_updates::SubscribesToUpdates; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates; pub struct BalancerApplicableStateHolder { update_tx: watch::Sender<()>, @@ -12,21 +11,15 @@ pub struct BalancerApplicableStateHolder { } impl BalancerApplicableStateHolder { - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_agent_desired_state(&self) -> Option { self.balancer_applicable_state .read() - .expect("Failed to get balancer state lock") .as_ref() .map(|state| state.agent_desired_state.clone()) } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] pub fn get_balancer_applicable_state(&self) -> Option { - self.balancer_applicable_state - .read() - .expect("Failed to get balancer state lock") - .clone() + self.balancer_applicable_state.read().clone() } pub fn set_balancer_applicable_state( @@ -34,11 +27,7 @@ impl BalancerApplicableStateHolder { balancer_applicable_state: Option, ) { { - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] - let mut lock = self - .balancer_applicable_state - .write() - .expect("Failed to get balancer state lock"); + let mut lock = self.balancer_applicable_state.write(); *lock = balancer_applicable_state; } @@ -66,9 +55,8 @@ impl SubscribesToUpdates for BalancerApplicableStateHolder { #[cfg(test)] mod tests { - use anyhow::Result; - use paddler_types::agent_desired_model::AgentDesiredModel; - use paddler_types::inference_parameters::InferenceParameters; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::inference_parameters::InferenceParameters; use tokio::time::Duration; use tokio::time::timeout; @@ -79,14 +67,14 @@ mod tests { agent_desired_state: AgentDesiredState { chat_template_override: None, inference_parameters: InferenceParameters::default(), - model: AgentDesiredModel::None, + model: AgentDesiredModel::LocalToAgent("model.gguf".to_owned()), multimodal_projection: AgentDesiredModel::None, }, } } #[tokio::test] - async fn set_balancer_applicable_state_wakes_subscribed_waiter() -> Result<()> { + async fn set_balancer_applicable_state_wakes_subscribed_waiter() { let holder = BalancerApplicableStateHolder::default(); let mut update_rx = holder.subscribe_to_updates(); @@ -94,14 +82,12 @@ mod tests { timeout(Duration::from_secs(1), update_rx.changed()) .await - .map_err(|error| anyhow::anyhow!("waiter did not awaken: {error}"))? - .map_err(|error| anyhow::anyhow!("watch sender dropped: {error}"))?; - - Ok(()) + .expect("waiter did not awaken before the timeout elapsed") + .expect("watch sender was dropped before the change arrived"); } #[test] - fn get_balancer_applicable_state_returns_stored_value() -> Result<()> { + fn get_balancer_applicable_state_returns_stored_value() { let holder = BalancerApplicableStateHolder::default(); assert!(holder.get_balancer_applicable_state().is_none()); @@ -112,13 +98,43 @@ mod tests { let stored = holder .get_balancer_applicable_state() - .ok_or_else(|| anyhow::anyhow!("state should be present after set"))?; + .expect("state should be present after set"); assert_eq!( stored.agent_desired_state.model, applicable_state.agent_desired_state.model ); + } + + #[test] + fn get_agent_desired_state_returns_none_before_any_state_is_set() { + let holder = BalancerApplicableStateHolder::default(); + + assert!(holder.get_agent_desired_state().is_none()); + } + + #[test] + fn get_agent_desired_state_returns_stored_agent_desired_state() { + let holder = BalancerApplicableStateHolder::default(); + let applicable_state = make_applicable_state(); + + holder.set_balancer_applicable_state(Some(applicable_state.clone())); + + let stored = holder + .get_agent_desired_state() + .expect("agent desired state should be present after set"); + + assert_eq!(stored.model, applicable_state.agent_desired_state.model); + } - Ok(()) + #[test] + fn set_balancer_applicable_state_can_clear_back_to_none() { + let holder = BalancerApplicableStateHolder::default(); + + holder.set_balancer_applicable_state(Some(make_applicable_state())); + holder.set_balancer_applicable_state(None); + + assert!(holder.get_balancer_applicable_state().is_none()); + assert!(holder.get_agent_desired_state().is_none()); } } diff --git a/paddler_balancer/src/balancer_desired_state_converter.rs b/paddler_balancer/src/balancer_desired_state_converter.rs new file mode 100644 index 00000000..9102e978 --- /dev/null +++ b/paddler_balancer/src/balancer_desired_state_converter.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use async_trait::async_trait; + +use paddler_messaging::agent_desired_state::AgentDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_state_conversion::converts_to_applicable_state::ConvertsToApplicableState; +use paddler_state_conversion::converts_to_desired_state::ConvertsToDesiredState; + +use crate::balancer_applicable_state::BalancerApplicableState; + +pub struct BalancerDesiredStateConverter; + +impl ConvertsToDesiredState for BalancerDesiredStateConverter { + type DesiredState = AgentDesiredState; + type Source = BalancerDesiredState; + + fn to_desired_state( + &self, + BalancerDesiredState { + chat_template_override, + inference_parameters, + model, + multimodal_projection, + use_chat_template_override, + }: BalancerDesiredState, + ) -> AgentDesiredState { + AgentDesiredState { + chat_template_override: if use_chat_template_override { + chat_template_override + } else { + None + }, + inference_parameters, + model, + multimodal_projection, + } + } +} + +#[async_trait] +impl ConvertsToApplicableState for BalancerDesiredStateConverter { + type ApplicableState = BalancerApplicableState; + type DesiredState = BalancerDesiredState; + + async fn to_applicable_state( + &self, + desired_state: BalancerDesiredState, + ) -> Result { + Ok(BalancerApplicableState { + agent_desired_state: self.to_desired_state(desired_state), + }) + } +} diff --git a/paddler/src/balancer/buffered_request_agent_wait_result.rs b/paddler_balancer/src/buffered_request_agent_wait_result.rs similarity index 70% rename from paddler/src/balancer/buffered_request_agent_wait_result.rs rename to paddler_balancer/src/buffered_request_agent_wait_result.rs index 6ab213d2..1a16f164 100644 --- a/paddler/src/balancer/buffered_request_agent_wait_result.rs +++ b/paddler_balancer/src/buffered_request_agent_wait_result.rs @@ -1,6 +1,6 @@ use anyhow::Error; -use crate::balancer::dispatched_agent::DispatchedAgent; +use crate::dispatched_agent::DispatchedAgent; pub enum BufferedRequestAgentWaitResult { BufferOverflow, diff --git a/paddler/src/balancer/buffered_request_count_guard.rs b/paddler_balancer/src/buffered_request_count_guard.rs similarity index 86% rename from paddler/src/balancer/buffered_request_count_guard.rs rename to paddler_balancer/src/buffered_request_count_guard.rs index a8215595..cd9e3e3b 100644 --- a/paddler/src/balancer/buffered_request_count_guard.rs +++ b/paddler_balancer/src/buffered_request_count_guard.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::balancer::buffered_request_counter::BufferedRequestCounter; +use crate::buffered_request_counter::BufferedRequestCounter; pub struct BufferedRequestCountGuard { buffered_requests_counter: Arc, diff --git a/paddler/src/balancer/buffered_request_counter.rs b/paddler_balancer/src/buffered_request_counter.rs similarity index 94% rename from paddler/src/balancer/buffered_request_counter.rs rename to paddler_balancer/src/buffered_request_counter.rs index 5d53f9c8..68494702 100644 --- a/paddler/src/balancer/buffered_request_counter.rs +++ b/paddler_balancer/src/buffered_request_counter.rs @@ -3,8 +3,8 @@ use std::sync::atomic::AtomicI32; use tokio::sync::watch; -use crate::atomic_value::AtomicValue; -use crate::balancer::buffered_request_count_guard::BufferedRequestCountGuard; +use crate::buffered_request_count_guard::BufferedRequestCountGuard; +use paddler_messaging::atomic_value::AtomicValue; pub struct BufferedRequestCounter { count: Arc>, diff --git a/paddler/src/balancer/buffered_request_manager.rs b/paddler_balancer/src/buffered_request_manager.rs similarity index 66% rename from paddler/src/balancer/buffered_request_manager.rs rename to paddler_balancer/src/buffered_request_manager.rs index 7fc812f5..85f11740 100644 --- a/paddler/src/balancer/buffered_request_manager.rs +++ b/paddler_balancer/src/buffered_request_manager.rs @@ -2,15 +2,15 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use paddler_types::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; use tokio::sync::watch; use tokio::time::timeout; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; -use crate::balancer::buffered_request_counter::BufferedRequestCounter; -use crate::produces_snapshot::ProducesSnapshot; -use crate::subscribes_to_updates::SubscribesToUpdates; +use crate::agent_controller_pool::AgentControllerPool; +use crate::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; +use crate::buffered_request_counter::BufferedRequestCounter; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates; pub struct BufferedRequestManager { agent_controller_pool: Arc, @@ -96,28 +96,66 @@ impl SubscribesToUpdates for BufferedRequestManager { #[cfg(test)] mod tests { + use parking_lot::RwLock; use std::collections::BTreeSet; - use std::sync::RwLock; + use std::mem::Discriminant; + use std::mem::discriminant; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; - use std::task::Poll; - use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use super::*; - use crate::atomic_value::AtomicValue; - use crate::balancer::agent_controller::AgentController; - use crate::balancer::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; - use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; - use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; - use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; - use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::agent_controller::AgentController; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::atomic_value::AtomicValue; + + fn found_result_discriminant() -> Discriminant { + let pool = AgentControllerPool::default(); + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent = Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-discriminant".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }); + + pool.register_agent_controller("agent-discriminant".to_owned(), agent) + .unwrap(); + + let dispatched_agent = pool.take_least_busy_agent_controller().unwrap(); + + discriminant(&BufferedRequestAgentWaitResult::Found(dispatched_agent)) + } #[tokio::test] - async fn counter_increment_wakes_subscribed_waiter() -> Result<()> { + async fn counter_increment_wakes_subscribed_waiter() { let pool = Arc::new(AgentControllerPool::default()); let manager = Arc::new(BufferedRequestManager::new( pool, @@ -129,20 +167,22 @@ mod tests { manager.buffered_request_counter.increment(); - timeout(Duration::from_secs(1), update_rx.changed()) + let observed_within_deadline = timeout(Duration::from_secs(1), update_rx.changed()) .await - .map_err(|err| anyhow::anyhow!("subscriber did not observe within deadline: {err}"))? - .map_err(|err| anyhow::anyhow!("watch sender dropped: {err}"))?; + .unwrap(); - Ok(()) + assert!( + observed_within_deadline.is_ok(), + "watch sender must stay alive while the manager holds it" + ); } #[tokio::test(flavor = "current_thread")] - async fn waiter_returns_found_after_agent_registration_with_no_initial_agents() -> Result<()> { + async fn waiter_returns_found_after_agent_registration_with_no_initial_agents() { let pool = Arc::new(AgentControllerPool::default()); let manager = Arc::new(BufferedRequestManager::new( pool.clone(), - Duration::from_secs(60), + Duration::from_mins(1), 10, )); @@ -182,26 +222,25 @@ mod tests { uses_chat_template_override: AtomicValue::::new(false), }); - pool.register_agent_controller("agent-1".to_owned(), agent)?; + pool.register_agent_controller("agent-1".to_owned(), agent) + .unwrap(); assert!( waiter.is_woken(), "register_agent_controller must wake the subscribed waiter" ); - let Poll::Ready(result) = waiter.poll() else { - anyhow::bail!("waiter must be Ready after register_agent_controller, got Pending"); - }; - - if !matches!(result?, BufferedRequestAgentWaitResult::Found(_)) { - anyhow::bail!("waiter must return Found after register_agent_controller"); - } + let result = waiter.await.unwrap(); - Ok(()) + assert_eq!( + discriminant(&result), + found_result_discriminant(), + "waiter must return Found after register_agent_controller" + ); } #[tokio::test(flavor = "current_thread")] - async fn waiter_returns_found_when_agent_was_registered_before_call() -> Result<()> { + async fn waiter_returns_found_when_agent_was_registered_before_call() { let pool = Arc::new(AgentControllerPool::default()); let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); @@ -232,27 +271,21 @@ mod tests { uses_chat_template_override: AtomicValue::::new(false), }); - pool.register_agent_controller("agent-pre".to_owned(), agent)?; + pool.register_agent_controller("agent-pre".to_owned(), agent) + .unwrap(); let manager = Arc::new(BufferedRequestManager::new( pool, - Duration::from_secs(60), + Duration::from_mins(1), 10, )); - let mut waiter = - tokio_test::task::spawn(async move { manager.wait_for_available_agent().await }); - - let Poll::Ready(result) = waiter.poll() else { - anyhow::bail!( - "waiter must be Ready on first poll when agent was registered before call" - ); - }; + let result = manager.wait_for_available_agent().await.unwrap(); - if !matches!(result?, BufferedRequestAgentWaitResult::Found(_)) { - anyhow::bail!("waiter must return Found when an agent is already in the pool"); - } - - Ok(()) + assert_eq!( + discriminant(&result), + found_result_discriminant(), + "waiter must return Found when an agent is already in the pool" + ); } } diff --git a/paddler/src/cancellation_token_stream_guard.rs b/paddler_balancer/src/cancellation_token_stream_guard.rs similarity index 100% rename from paddler/src/cancellation_token_stream_guard.rs rename to paddler_balancer/src/cancellation_token_stream_guard.rs diff --git a/paddler/src/balancer/chat_template_override_sender_collection.rs b/paddler_balancer/src/chat_template_override_sender_collection.rs similarity index 85% rename from paddler/src/balancer/chat_template_override_sender_collection.rs rename to paddler_balancer/src/chat_template_override_sender_collection.rs index abd88b7f..abbb7478 100644 --- a/paddler/src/balancer/chat_template_override_sender_collection.rs +++ b/paddler_balancer/src/chat_template_override_sender_collection.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use dashmap::DashMap; -use paddler_types::chat_template::ChatTemplate; +use paddler_messaging::chat_template::ChatTemplate; use tokio::sync::mpsc; -use crate::balancer::manages_senders::ManagesSenders; +use crate::manages_senders::ManagesSenders; pub struct ChatTemplateOverrideSenderCollection { senders: DashMap>>, diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs b/paddler_balancer/src/chunk_forwarding_session_controller/identity_transformer.rs similarity index 84% rename from paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs rename to paddler_balancer/src/chunk_forwarding_session_controller/identity_transformer.rs index f9d964af..ee52faf5 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/identity_transformer.rs +++ b/paddler_balancer/src/chunk_forwarding_session_controller/identity_transformer.rs @@ -1,6 +1,6 @@ use anyhow::Result; use async_trait::async_trait; -use paddler_types::inference_client::Message as OutgoingMessage; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; use super::transform_result::TransformResult; use super::transforms_outgoing_message::TransformsOutgoingMessage; @@ -17,6 +17,8 @@ impl IdentityTransformer { #[async_trait] impl TransformsOutgoingMessage for IdentityTransformer { + type Output = TransformResult; + async fn transform(&self, message: OutgoingMessage) -> Result> { let serialized = serde_json::to_string(&message)?; diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs b/paddler_balancer/src/chunk_forwarding_session_controller/mod.rs similarity index 68% rename from paddler/src/balancer/chunk_forwarding_session_controller/mod.rs rename to paddler_balancer/src/chunk_forwarding_session_controller/mod.rs index 3b147333..0dbfc118 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/mod.rs +++ b/paddler_balancer/src/chunk_forwarding_session_controller/mod.rs @@ -3,10 +3,9 @@ pub mod transform_result; pub mod transforms_outgoing_message; use async_trait::async_trait; -use paddler_types::inference_client::Message as OutgoingMessage; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; use tokio::sync::mpsc; -use self::transform_result::TransformResult; use self::transforms_outgoing_message::TransformsOutgoingMessage; use crate::controls_session::ControlsSession; @@ -15,7 +14,7 @@ pub struct ChunkForwardingSessionController where TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync, { - chunk_tx: mpsc::UnboundedSender, + chunk_tx: mpsc::UnboundedSender, transformer: TTransformsOutgoingMessage, } @@ -24,7 +23,7 @@ where TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync, { pub const fn new( - chunk_tx: mpsc::UnboundedSender, + chunk_tx: mpsc::UnboundedSender, transformer: TTransformsOutgoingMessage, ) -> Self { Self { @@ -41,13 +40,8 @@ where TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync, { async fn send_response(&mut self, message: OutgoingMessage) -> anyhow::Result<()> { - for transform_result in self.transformer.transform(message).await? { - match transform_result { - TransformResult::Discard => {} - forwarded @ (TransformResult::Chunk(_) | TransformResult::Error(_)) => { - self.chunk_tx.send(forwarded)?; - } - } + for output in self.transformer.transform(message).await? { + self.chunk_tx.send(output)?; } Ok(()) diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/transform_result.rs b/paddler_balancer/src/chunk_forwarding_session_controller/transform_result.rs similarity index 100% rename from paddler/src/balancer/chunk_forwarding_session_controller/transform_result.rs rename to paddler_balancer/src/chunk_forwarding_session_controller/transform_result.rs diff --git a/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs b/paddler_balancer/src/chunk_forwarding_session_controller/transforms_outgoing_message.rs similarity index 54% rename from paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs rename to paddler_balancer/src/chunk_forwarding_session_controller/transforms_outgoing_message.rs index ccfc438e..e3cb0be4 100644 --- a/paddler/src/balancer/chunk_forwarding_session_controller/transforms_outgoing_message.rs +++ b/paddler_balancer/src/chunk_forwarding_session_controller/transforms_outgoing_message.rs @@ -1,10 +1,10 @@ use anyhow::Result; use async_trait::async_trait; -use paddler_types::inference_client::Message as OutgoingMessage; - -use super::transform_result::TransformResult; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; #[async_trait] pub trait TransformsOutgoingMessage { - async fn transform(&self, message: OutgoingMessage) -> Result>; + type Output: Send + Sync + 'static; + + async fn transform(&self, message: OutgoingMessage) -> Result>; } diff --git a/paddler/src/balancer/compatibility/mod.rs b/paddler_balancer/src/compatibility/mod.rs similarity index 100% rename from paddler/src/balancer/compatibility/mod.rs rename to paddler_balancer/src/compatibility/mod.rs diff --git a/paddler_balancer/src/compatibility/openai_service/app_data.rs b/paddler_balancer/src/compatibility/openai_service/app_data.rs new file mode 100644 index 00000000..3c992f35 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/app_data.rs @@ -0,0 +1,12 @@ +use std::sync::Arc; + +use tokio_util::sync::CancellationToken; + +use crate::buffered_request_manager::BufferedRequestManager; +use crate::inference_service::configuration::Configuration; + +pub struct AppData { + pub buffered_request_manager: Arc, + pub inference_service_configuration: Configuration, + pub shutdown: CancellationToken, +} diff --git a/paddler_balancer/src/compatibility/openai_service/arguments_to_tool_call_string.rs b/paddler_balancer/src/compatibility/openai_service/arguments_to_tool_call_string.rs new file mode 100644 index 00000000..0a831454 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/arguments_to_tool_call_string.rs @@ -0,0 +1,39 @@ +use anyhow::Context as _; +use anyhow::Result; +use llama_cpp_bindings_types::ToolCallArguments; + +pub fn arguments_to_tool_call_string(arguments: &ToolCallArguments) -> Result { + match arguments { + ToolCallArguments::ValidJson(value) => { + serde_json::to_string(value).context("serializing tool-call arguments to OpenAI string") + } + ToolCallArguments::InvalidJson(raw) => Ok(raw.clone()), + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings_types::ToolCallArguments; + + use super::arguments_to_tool_call_string; + + #[test] + fn serializes_valid_json_arguments() { + let serialized = + arguments_to_tool_call_string(&ToolCallArguments::ValidJson(serde_json::json!({ + "location": "Paris" + }))) + .unwrap(); + + assert_eq!(serialized, "{\"location\":\"Paris\"}"); + } + + #[test] + fn passes_invalid_json_through_verbatim() { + let serialized = + arguments_to_tool_call_string(&ToolCallArguments::InvalidJson("{not valid".to_owned())) + .unwrap(); + + assert_eq!(serialized, "{not valid"); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/chat_completions_sse_response.rs b/paddler_balancer/src/compatibility/openai_service/chat_completions_sse_response.rs new file mode 100644 index 00000000..41ea5dea --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/chat_completions_sse_response.rs @@ -0,0 +1,138 @@ +use std::convert::Infallible; +use std::fmt::Debug; +use std::sync::Arc; + +use actix_web::HttpResponse; +use actix_web::http::header; +use actix_web_lab::sse; +use futures::stream::StreamExt as _; +use futures::stream::once; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; +use paddler_messaging::streamable_result::StreamableResult; +use tokio_util::sync::CancellationToken; + +use crate::agent_controller::AgentController; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::manages_senders::ManagesSenders; +use crate::unbounded_stream_from_agent::unbounded_stream_from_agent; + +pub fn chat_completions_sse_response( + buffered_request_manager: Arc, + inference_service_configuration: InferenceServiceConfiguration, + params: TParams, + transformer: TTransformsOutgoingMessage, + shutdown: CancellationToken, +) -> HttpResponse +where + TParams: Debug + Into + Send + 'static, + AgentController: HandlesAgentStreamingResponse, + <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, + TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, +{ + let event_stream = unbounded_stream_from_agent( + buffered_request_manager, + inference_service_configuration, + params, + transformer, + shutdown, + ) + .filter_map(|transform_result| async move { + match transform_result { + TransformResult::Chunk(chunk) | TransformResult::Error(chunk) => { + Some(Ok::(sse::Event::Data( + sse::Data::new(chunk), + ))) + } + TransformResult::Discard => None, + } + }) + .chain(once(async { + Ok::(sse::Event::Data(sse::Data::new("[DONE]"))) + })); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header((header::CACHE_CONTROL, "no-cache")) + .body(sse::Sse::from_stream(event_stream)) +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + + use actix_web::body; + use tokio_util::sync::CancellationToken; + + use super::chat_completions_sse_response; + use crate::agent_controller_pool::AgentControllerPool; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; + use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + fn empty_pool_manager() -> Arc { + Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )) + } + + fn inference_service_configuration() -> InferenceServiceConfiguration { + InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: Duration::from_secs(1), + } + } + + fn raw_prompt_params() -> ContinueFromRawPromptParams { + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "hello".to_owned(), + } + } + + #[actix_web::test] + async fn frames_each_chunk_as_an_sse_data_event_and_terminates_with_done() { + let shutdown = CancellationToken::new(); + shutdown.cancel(); + + let response = chat_completions_sse_response( + empty_pool_manager(), + inference_service_configuration(), + raw_prompt_params(), + IdentityTransformer::new(), + shutdown, + ); + + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + let body_bytes = body::to_bytes(response.into_body()).await.unwrap(); + let body_text = String::from_utf8(body_bytes.to_vec()).unwrap(); + + assert!( + body_text.contains("data: "), + "each chunk must be framed as an SSE data event: {body_text}" + ); + assert!( + body_text.contains("balancer is shutting down"), + "the shutdown chunk must be framed into the SSE body: {body_text}" + ); + assert!( + body_text.ends_with("data: [DONE]\n\n"), + "the stream must terminate with the OpenAI [DONE] sentinel: {body_text:?}" + ); + } +} diff --git a/paddler/src/balancer/compatibility/openai_service/configuration.rs b/paddler_balancer/src/compatibility/openai_service/configuration.rs similarity index 100% rename from paddler/src/balancer/compatibility/openai_service/configuration.rs rename to paddler_balancer/src/compatibility/openai_service/configuration.rs diff --git a/paddler_balancer/src/compatibility/openai_service/content_part_event.rs b/paddler_balancer/src/compatibility/openai_service/content_part_event.rs new file mode 100644 index 00000000..a5abba17 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/content_part_event.rs @@ -0,0 +1,10 @@ +use serde_json::Value; + +#[derive(Clone, Debug)] +pub struct ContentPartEvent { + pub sequence_number: u64, + pub item_id: String, + pub output_index: usize, + pub content_index: usize, + pub part: Value, +} diff --git a/paddler_balancer/src/compatibility/openai_service/function_call_arguments_delta_event.rs b/paddler_balancer/src/compatibility/openai_service/function_call_arguments_delta_event.rs new file mode 100644 index 00000000..edb2ed2a --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/function_call_arguments_delta_event.rs @@ -0,0 +1,7 @@ +#[derive(Clone, Debug)] +pub struct FunctionCallArgumentsDeltaEvent { + pub sequence_number: u64, + pub item_id: String, + pub output_index: usize, + pub delta: String, +} diff --git a/paddler_balancer/src/compatibility/openai_service/function_call_arguments_done_event.rs b/paddler_balancer/src/compatibility/openai_service/function_call_arguments_done_event.rs new file mode 100644 index 00000000..2eb7f40c --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/function_call_arguments_done_event.rs @@ -0,0 +1,8 @@ +#[derive(Clone, Debug)] +pub struct FunctionCallArgumentsDoneEvent { + pub sequence_number: u64, + pub item_id: String, + pub output_index: usize, + pub name: String, + pub arguments: String, +} diff --git a/paddler_balancer/src/compatibility/openai_service/function_call_item.rs b/paddler_balancer/src/compatibility/openai_service/function_call_item.rs new file mode 100644 index 00000000..3397ceb4 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/function_call_item.rs @@ -0,0 +1,20 @@ +use serde_json::Value; +use serde_json::json; + +#[must_use] +pub fn function_call_item( + item_id: &str, + call_id: &str, + name: &str, + arguments: &str, + status: &str, +) -> Value { + json!({ + "type": "function_call", + "id": item_id, + "call_id": call_id, + "name": name, + "arguments": arguments, + "status": status + }) +} diff --git a/paddler/src/balancer/compatibility/openai_service/http_route/mod.rs b/paddler_balancer/src/compatibility/openai_service/http_route/mod.rs similarity index 56% rename from paddler/src/balancer/compatibility/openai_service/http_route/mod.rs rename to paddler_balancer/src/compatibility/openai_service/http_route/mod.rs index 3d1ca76d..62c52c9f 100644 --- a/paddler/src/balancer/compatibility/openai_service/http_route/mod.rs +++ b/paddler_balancer/src/compatibility/openai_service/http_route/mod.rs @@ -1 +1,2 @@ pub mod post_chat_completions; +pub mod post_responses; diff --git a/paddler_balancer/src/compatibility/openai_service/http_route/post_chat_completions.rs b/paddler_balancer/src/compatibility/openai_service/http_route/post_chat_completions.rs new file mode 100644 index 00000000..1865bdb1 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/http_route/post_chat_completions.rs @@ -0,0 +1,267 @@ +use std::sync::Arc; +use std::time::SystemTime; + +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::post; +use actix_web::web; +use nanoid::nanoid; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::validates::Validates; +use parking_lot::Mutex; +use tokio_stream::StreamExt as _; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::compatibility::openai_service::app_data::AppData; +use crate::compatibility::openai_service::chat_completions_sse_response::chat_completions_sse_response; +use crate::compatibility::openai_service::openai_completion_request_params::OpenAICompletionRequestParams; +use crate::compatibility::openai_service::openai_error::OpenAIError; +use crate::compatibility::openai_service::openai_message::OpenAIMessage; +use crate::compatibility::openai_service::openai_non_streaming_response_transformer::OpenAINonStreamingResponseTransformer; +use crate::compatibility::openai_service::openai_non_streaming_state::OpenAINonStreamingState; +use crate::compatibility::openai_service::openai_streaming_response_transformer::OpenAIStreamingResponseTransformer; +use crate::compatibility::openai_service::openai_streaming_state::OpenAIStreamingState; +use crate::compatibility::openai_service::timestamp_from::timestamp_from; +use crate::unbounded_stream_from_agent::unbounded_stream_from_agent; + +#[post("/v1/chat/completions")] +async fn respond( + app_data: web::Data, + openai_params: web::Json, +) -> Result { + let openai_params = openai_params.into_inner(); + + let validated_tools = match openai_params + .tools + .into_iter() + .map(Validates::validate) + .collect::, _>>() + { + Ok(tools) => tools, + Err(err) => { + return Ok(HttpResponse::BadRequest() + .content_type("application/json") + .body( + OpenAIError { + error_type: "invalid_request_error", + message: err.to_string(), + } + .to_envelope() + .to_string(), + )); + } + }; + + let parse_tool_calls = !validated_tools.is_empty(); + let paddler_params = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new( + openai_params + .messages + .iter() + .map(OpenAIMessage::to_conversation_message) + .collect(), + ), + enable_thinking: true, + grammar: None, + max_tokens: openai_params.max_completion_tokens.unwrap_or(2000), + parse_tool_calls, + tools: validated_tools, + }; + + let created = + timestamp_from(SystemTime::now()).map_err(actix_web::error::ErrorInternalServerError)?; + + if openai_params.stream.unwrap_or(false) { + let include_usage = openai_params + .stream_options + .as_ref() + .is_some_and(|options| options.include_usage); + + Ok(chat_completions_sse_response( + app_data.buffered_request_manager.clone(), + app_data.inference_service_configuration.clone(), + paddler_params, + OpenAIStreamingResponseTransformer { + created, + include_usage, + model: openai_params.model.clone(), + state: Arc::new(Mutex::new(OpenAIStreamingState::default())), + system_fingerprint: nanoid!(), + }, + app_data.shutdown.clone(), + )) + } else { + let results: Vec = unbounded_stream_from_agent( + app_data.buffered_request_manager.clone(), + app_data.inference_service_configuration.clone(), + paddler_params, + OpenAINonStreamingResponseTransformer { + created, + model: openai_params.model.clone(), + state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), + }, + app_data.shutdown.clone(), + ) + .collect() + .await; + + if let Some(TransformResult::Error(error_json)) = results + .iter() + .find(|result| matches!(result, TransformResult::Error(_))) + { + return Ok(HttpResponse::InternalServerError() + .content_type("application/json") + .body(error_json.clone())); + } + + let body = results.into_iter().find_map(|result| match result { + TransformResult::Chunk(content) => Some(content), + TransformResult::Discard | TransformResult::Error(_) => None, + }); + + Ok(body.map_or_else( + || { + HttpResponse::InternalServerError() + .content_type("application/json") + .body( + OpenAIError { + error_type: "server_error", + message: "no completion produced".to_owned(), + } + .to_envelope() + .to_string(), + ) + }, + |json_body| { + HttpResponse::Ok() + .content_type("application/json") + .body(json_body) + }, + )) + } +} + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test::TestRequest; + use actix_web::test::call_service; + use actix_web::test::init_service; + use actix_web::test::read_body; + use actix_web::web::Data; + use tokio_util::sync::CancellationToken; + + use super::AppData; + use super::register; + use crate::agent_controller_pool::AgentControllerPool; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + + fn app_data_without_agents(max_buffered_requests: i32) -> AppData { + AppData { + buffered_request_manager: Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::ZERO, + max_buffered_requests, + )), + inference_service_configuration: InferenceServiceConfiguration { + addr: SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: Duration::ZERO, + }, + shutdown: CancellationToken::new(), + } + } + + #[actix_web::test] + async fn invalid_tool_schema_returns_bad_request() { + let app = init_service( + App::new() + .app_data(Data::new(app_data_without_agents(0))) + .configure(register), + ) + .await; + + let request = TestRequest::post() + .uri("/v1/chat/completions") + .set_json(serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": [ + { + "type": "function", + "function": { + "name": "broken", + "description": "tool with an unsatisfiable required field", + "parameters": { + "type": "object", + "properties": {"present": {"type": "string"}}, + "required": ["absent"] + } + } + } + ] + })) + .to_request(); + + let response = call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = read_body(response).await; + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(parsed["error"]["type"], "invalid_request_error"); + assert!( + parsed["error"]["message"] + .as_str() + .unwrap() + .contains("absent") + ); + } + + #[actix_web::test] + async fn non_streaming_request_without_available_agent_returns_internal_server_error() { + let app = init_service( + App::new() + .app_data(Data::new(app_data_without_agents(0))) + .configure(register), + ) + .await; + + let request = TestRequest::post() + .uri("/v1/chat/completions") + .set_json(serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}] + })) + .to_request(); + + let response = call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let body = read_body(response).await; + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(parsed["error"]["type"], "server_error"); + assert!( + parsed["error"]["message"] + .as_str() + .unwrap() + .contains("Buffered requests overflow") + ); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/http_route/post_responses.rs b/paddler_balancer/src/compatibility/openai_service/http_route/post_responses.rs new file mode 100644 index 00000000..21cb12fc --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/http_route/post_responses.rs @@ -0,0 +1,119 @@ +use std::sync::Arc; +use std::time::SystemTime; + +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::post; +use actix_web::web; +use nanoid::nanoid; +use parking_lot::Mutex; +use tokio_stream::StreamExt as _; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::compatibility::openai_service::app_data::AppData; +use crate::compatibility::openai_service::openai_error::OpenAIError; +use crate::compatibility::openai_service::openai_responses_request_params::OpenAIResponsesRequestParams; +use crate::compatibility::openai_service::responses_non_streaming_response_transformer::ResponsesNonStreamingResponseTransformer; +use crate::compatibility::openai_service::responses_non_streaming_state::ResponsesNonStreamingState; +use crate::compatibility::openai_service::responses_response_builder::ResponsesResponseBuilder; +use crate::compatibility::openai_service::responses_streaming_response_transformer::ResponsesStreamingResponseTransformer; +use crate::compatibility::openai_service::responses_streaming_state::ResponsesStreamingState; +use crate::compatibility::openai_service::sse_response_from_agent::sse_response_from_agent; +use crate::compatibility::openai_service::timestamp_from::timestamp_from; +use crate::unbounded_stream_from_agent::unbounded_stream_from_agent; + +#[post("/v1/responses")] +async fn respond( + app_data: web::Data, + openai_params: web::Json, +) -> Result { + let prepared = match openai_params.into_inner().into_prepared() { + Ok(prepared) => prepared, + Err(err) => { + return Ok(HttpResponse::BadRequest() + .content_type("application/json") + .body( + OpenAIError { + error_type: "invalid_request_error", + message: err.to_string(), + } + .to_envelope() + .to_string(), + )); + } + }; + + let created_at = + timestamp_from(SystemTime::now()).map_err(actix_web::error::ErrorInternalServerError)?; + + let builder = ResponsesResponseBuilder { + id: format!("resp_{}", nanoid!()), + created_at, + model: prepared.model, + instructions: prepared.instructions, + }; + + if prepared.stream { + Ok(sse_response_from_agent( + app_data.buffered_request_manager.clone(), + app_data.inference_service_configuration.clone(), + prepared.paddler_params, + ResponsesStreamingResponseTransformer { + builder, + state: Arc::new(Mutex::new(ResponsesStreamingState::default())), + }, + app_data.shutdown.clone(), + )) + } else { + let results: Vec = unbounded_stream_from_agent( + app_data.buffered_request_manager.clone(), + app_data.inference_service_configuration.clone(), + prepared.paddler_params, + ResponsesNonStreamingResponseTransformer { + builder, + state: Arc::new(Mutex::new(ResponsesNonStreamingState::default())), + }, + app_data.shutdown.clone(), + ) + .collect() + .await; + + if let Some(TransformResult::Error(error_json)) = results + .iter() + .find(|result| matches!(result, TransformResult::Error(_))) + { + return Ok(HttpResponse::InternalServerError() + .content_type("application/json") + .body(error_json.clone())); + } + + let body = results.into_iter().find_map(|result| match result { + TransformResult::Chunk(content) => Some(content), + TransformResult::Discard | TransformResult::Error(_) => None, + }); + + Ok(body.map_or_else( + || { + HttpResponse::InternalServerError() + .content_type("application/json") + .body( + OpenAIError { + error_type: "server_error", + message: "no response produced".to_owned(), + } + .to_envelope() + .to_string(), + ) + }, + |json_body| { + HttpResponse::Ok() + .content_type("application/json") + .body(json_body) + }, + )) + } +} + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} diff --git a/paddler_balancer/src/compatibility/openai_service/message_item_done.rs b/paddler_balancer/src/compatibility/openai_service/message_item_done.rs new file mode 100644 index 00000000..20f96b79 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/message_item_done.rs @@ -0,0 +1,15 @@ +use serde_json::Value; +use serde_json::json; + +use crate::compatibility::openai_service::output_text_part::output_text_part; + +#[must_use] +pub fn message_item_done(item_id: &str, text: &str) -> Value { + json!({ + "id": item_id, + "type": "message", + "role": "assistant", + "status": "completed", + "content": [output_text_part(text)] + }) +} diff --git a/paddler_balancer/src/compatibility/openai_service/mod.rs b/paddler_balancer/src/compatibility/openai_service/mod.rs new file mode 100644 index 00000000..d26fefe1 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/mod.rs @@ -0,0 +1,172 @@ +pub mod app_data; +pub mod arguments_to_tool_call_string; +pub mod chat_completions_sse_response; +pub mod configuration; +pub mod content_part_event; +pub mod function_call_arguments_delta_event; +pub mod function_call_arguments_done_event; +pub mod function_call_item; +pub mod http_route; +pub mod message_item_done; +pub mod open_item; +pub mod openai_completion_request_params; +pub mod openai_error; +pub mod openai_message; +pub mod openai_non_streaming_response_transformer; +pub mod openai_non_streaming_state; +pub mod openai_responses_function_call_item; +pub mod openai_responses_function_call_output_item; +pub mod openai_responses_function_output; +pub mod openai_responses_function_tool; +pub mod openai_responses_input; +pub mod openai_responses_input_content_part; +pub mod openai_responses_input_item; +pub mod openai_responses_message_content; +pub mod openai_responses_message_item; +pub mod openai_responses_reasoning; +pub mod openai_responses_request_params; +pub mod openai_responses_tagged_item; +pub mod openai_responses_text_format; +pub mod openai_responses_text_param; +pub mod openai_responses_tool; +pub mod openai_streaming_response_transformer; +pub mod openai_streaming_state; +pub mod openai_usage_json; +pub mod output_item_event; +pub mod output_text_part; +pub mod reasoning_item_done; +pub mod response_snapshot_event; +pub mod responses_error; +pub mod responses_non_streaming_response_transformer; +pub mod responses_non_streaming_state; +pub mod responses_prepared_request; +pub mod responses_response_builder; +pub mod responses_stream_event; +pub mod responses_streaming_response_transformer; +pub mod responses_streaming_state; +pub mod sse_response_from_agent; +pub mod stream_options; +pub mod text_delta_event; +pub mod text_done_event; +pub mod timestamp_from; +pub mod try_universal_error_chunk; + +use std::sync::Arc; + +use actix_web::App; +use actix_web::HttpServer; +use actix_web::web::Data; +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; + +use crate::buffered_request_manager::BufferedRequestManager; +use crate::compatibility::openai_service::app_data::AppData; +use crate::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use crate::create_cors_middleware::create_cors_middleware; +use crate::http_route as common_http_route; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + +pub struct OpenAIService { + pub buffered_request_manager: Arc, + pub inference_service_configuration: InferenceServiceConfiguration, + pub openai_service_configuration: OpenAIServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, +} + +#[async_trait] +impl Service for OpenAIService { + fn name(&self) -> &'static str { + "balancer::compatibility::openai_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let cors_allowed_hosts = self + .inference_service_configuration + .cors_allowed_hosts + .clone(); + let cors_allowed_hosts_arc = Arc::new(cors_allowed_hosts); + + let app_data = Data::new(AppData { + buffered_request_manager: self.buffered_request_manager.clone(), + inference_service_configuration: self.inference_service_configuration.clone(), + shutdown: shutdown.clone(), + }); + + let bind_addr = self.openai_service_configuration.addr; + + let server = HttpServer::new(move || { + App::new() + .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) + .app_data(app_data.clone()) + .configure(common_http_route::get_health::register) + .configure(http_route::post_chat_completions::register) + .configure(http_route::post_responses::register) + }) + .shutdown_signal(async move { + shutdown.cancelled().await; + }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) + .disable_signals() + .bind(bind_addr) + .with_context(|| format!("Unable to bind balancer OpenAI-compat service to {bind_addr}"))?; + + server.run().await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::net::TcpListener; + use std::sync::Arc; + use std::time::Duration; + + use tokio_util::sync::CancellationToken; + use trzcina::Service as _; + use trzcina::ServiceShutdownOptions; + + use super::OpenAIService; + use crate::agent_controller_pool::AgentControllerPool; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; + use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + + fn build_service(addr: SocketAddr) -> OpenAIService { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + OpenAIService { + buffered_request_manager: Arc::new(BufferedRequestManager::new( + agent_controller_pool, + Duration::from_secs(30), + 32, + )), + inference_service_configuration: InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: vec!["http://127.0.0.1:8080".to_owned()], + inference_item_timeout: Duration::from_secs(30), + }, + openai_service_configuration: OpenAIServiceConfiguration { addr }, + shutdown_options: ServiceShutdownOptions::default(), + } + } + + #[actix_web::test] + async fn run_returns_error_when_address_is_already_in_use() { + let occupied_listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap(); + let occupied_addr = occupied_listener.local_addr().unwrap(); + + let service = Box::new(build_service(occupied_addr)); + let result = service.run(CancellationToken::new()).await; + + let error_message = result.unwrap_err().to_string(); + let expected_addr_fragment = occupied_addr.to_string(); + + assert!(error_message.contains(&expected_addr_fragment)); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/open_item.rs b/paddler_balancer/src/compatibility/openai_service/open_item.rs new file mode 100644 index 00000000..d8b23021 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/open_item.rs @@ -0,0 +1,7 @@ +#[derive(Default, Eq, PartialEq)] +pub enum OpenItem { + #[default] + None, + Reasoning, + Message, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_completion_request_params.rs b/paddler_balancer/src/compatibility/openai_service/openai_completion_request_params.rs new file mode 100644 index 00000000..57d30774 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_completion_request_params.rs @@ -0,0 +1,115 @@ +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_message::OpenAIMessage; +use crate::compatibility::openai_service::stream_options::StreamOptions; + +#[derive(Deserialize)] +pub struct OpenAICompletionRequestParams { + pub max_completion_tokens: Option, + pub messages: Vec, + /// This parameter is ignored here, but is required by the `OpenAI` API. + pub model: String, + pub stream: Option, + pub stream_options: Option, + #[serde(default)] + pub tools: Vec>, +} + +#[cfg(test)] +mod tests { + use super::OpenAICompletionRequestParams; + + #[test] + fn deserialize_text_only_request() { + let input = serde_json::json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "hello"} + ] + }); + + let params: OpenAICompletionRequestParams = serde_json::from_value(input).unwrap(); + + assert_eq!(params.model, "test-model"); + assert_eq!(params.messages.len(), 1); + assert_eq!(params.messages[0].role, "user"); + assert_eq!(params.messages[0].content.text_content(), "hello"); + } + + #[test] + fn deserialize_request_with_stream_options_include_usage_true() { + let input = serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": true, + "stream_options": {"include_usage": true} + }); + + let params: OpenAICompletionRequestParams = serde_json::from_value(input).unwrap(); + + let stream_options = params.stream_options.unwrap(); + + assert!(stream_options.include_usage); + } + + #[test] + fn deserialize_request_without_stream_options_defaults_to_none() { + let input = serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": true + }); + + let params: OpenAICompletionRequestParams = serde_json::from_value(input).unwrap(); + + assert!(params.stream_options.is_none()); + } + + #[test] + fn deserialize_multimodal_request_with_image() { + let input = serde_json::json!({ + "model": "vision-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this image"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ"}} + ] + } + ] + }); + + let params: OpenAICompletionRequestParams = serde_json::from_value(input).unwrap(); + + assert_eq!(params.messages.len(), 1); + assert_eq!( + params.messages[0].content.text_content(), + "describe this image" + ); + + let image_urls = params.messages[0].content.image_urls(); + + assert_eq!(image_urls.len(), 1); + assert_eq!(image_urls[0].url, "data:image/jpeg;base64,/9j/4AAQ"); + } + + #[test] + fn deserialize_multi_turn_conversation() { + let input = serde_json::json!({ + "model": "test-model", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + {"role": "user", "content": "And 3+3?"} + ] + }); + + let params: OpenAICompletionRequestParams = serde_json::from_value(input).unwrap(); + + assert_eq!(params.messages.len(), 4); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_error.rs b/paddler_balancer/src/compatibility/openai_service/openai_error.rs new file mode 100644 index 00000000..0e00a4c2 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_error.rs @@ -0,0 +1,264 @@ +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::error::Error as JsonRpcError; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use paddler_messaging::oversized_image_details::OversizedImageDetails; +use paddler_messaging::raw_tool_call_tokens::RawToolCallTokens; +use serde_json::Value; +use serde_json::json; + +fn validation_failure_message(errors: &[String]) -> String { + errors + .first() + .cloned() + .unwrap_or_else(|| "tool call failed validation".to_owned()) +} + +fn unrecognized_tool_call_format_message(raw: &RawToolCallTokens) -> String { + format!( + "model produced output the parser did not recognise as any registered tool-call format; \ + FFI error: {}; raw text: {}", + raw.ffi_error_message, raw.text, + ) +} + +fn image_exceeds_batch_size_message(details: &OversizedImageDetails) -> String { + format!( + "image required {} tokens but agent n_batch is {}; rerun with a larger n_batch", + details.image_tokens, details.n_batch, + ) +} + +fn description_from_error_token(token: &GeneratedTokenResult) -> Option<&str> { + match token { + GeneratedTokenResult::ChatTemplateError(description) + | GeneratedTokenResult::GrammarIncompatibleWithThinking(description) + | GeneratedTokenResult::GrammarRejectedModelOutput(description) + | GeneratedTokenResult::GrammarInitializationFailed(description) + | GeneratedTokenResult::GrammarSyntaxError(description) + | GeneratedTokenResult::ImageDecodingFailed(description) + | GeneratedTokenResult::MultimodalNotSupported(description) + | GeneratedTokenResult::SamplerError(description) + | GeneratedTokenResult::ToolCallParseFailed(description) + | GeneratedTokenResult::ToolSchemaInvalid(description) => Some(description), + _ => None, + } +} + +fn server_error_from_token(token: &GeneratedTokenResult) -> Option { + match token { + GeneratedTokenResult::ImageExceedsBatchSize(details) => Some(OpenAIError { + error_type: "server_error", + message: image_exceeds_batch_size_message(details), + }), + GeneratedTokenResult::ToolCallValidationFailed(errors) => Some(OpenAIError { + error_type: "server_error", + message: validation_failure_message(errors), + }), + GeneratedTokenResult::UnrecognizedToolCallFormat(raw) => Some(OpenAIError { + error_type: "server_error", + message: unrecognized_tool_call_format_message(raw), + }), + other => description_from_error_token(other).map(|description| OpenAIError { + error_type: "server_error", + message: description.to_owned(), + }), + } +} + +pub struct OpenAIError { + pub error_type: &'static str, + pub message: String, +} + +impl OpenAIError { + #[must_use] + pub fn classify(message: &OutgoingMessage) -> Option { + match message { + OutgoingMessage::Error(ErrorEnvelope { + error: JsonRpcError { description, .. }, + .. + }) => Some(Self { + error_type: "server_error", + message: description.clone(), + }), + OutgoingMessage::Response(ResponseEnvelope { response, .. }) => match response { + OutgoingResponse::GeneratedToken(token) => server_error_from_token(token), + OutgoingResponse::Timeout => Some(Self { + error_type: "timeout", + message: "request timed out".to_owned(), + }), + OutgoingResponse::TooManyBufferedRequests => Some(Self { + error_type: "rate_limit_error", + message: "too many buffered requests".to_owned(), + }), + OutgoingResponse::Embedding(_) => None, + }, + } + } + + #[must_use] + pub fn to_envelope(&self) -> Value { + json!({ + "error": { + "message": self.message, + "type": self.error_type, + "param": null, + "code": null + } + }) + } +} + +#[cfg(test)] +mod tests { + use llama_cpp_bindings_types::ToolCallArguments; + use paddler_messaging::embedding_result::EmbeddingResult; + use paddler_messaging::generation_summary::GenerationSummary; + + use super::OpenAIError; + use super::OutgoingMessage; + use super::OutgoingResponse; + use super::ResponseEnvelope; + use super::validation_failure_message; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::jsonrpc::error::Error as JsonRpcError; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + + fn token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::GeneratedToken(token_result), + }) + } + + #[test] + fn to_envelope_has_the_openai_error_shape() { + let envelope = OpenAIError { + error_type: "server_error", + message: "something went wrong".to_owned(), + } + .to_envelope(); + + assert_eq!(envelope["error"]["type"], "server_error"); + assert_eq!(envelope["error"]["message"], "something went wrong"); + assert!(envelope["error"]["param"].is_null()); + assert!(envelope["error"]["code"].is_null()); + } + + #[test] + fn validation_failure_message_returns_first_error() { + let message = + validation_failure_message(&["first issue".to_owned(), "second issue".to_owned()]); + + assert_eq!(message, "first issue"); + } + + #[test] + fn validation_failure_message_falls_back_when_no_errors() { + let message = validation_failure_message(&[]); + + assert!(message.contains("validation")); + } + + #[test] + fn classifies_jsonrpc_error_as_server_error() { + let message = OutgoingMessage::Error(ErrorEnvelope { + request_id: "test-request".to_owned(), + error: JsonRpcError { + code: 500, + description: "internal failure".to_owned(), + }, + }); + + let classified = OpenAIError::classify(&message).unwrap(); + + assert_eq!(classified.error_type, "server_error"); + assert_eq!(classified.message, "internal failure"); + } + + #[test] + fn classifies_timeout_as_timeout() { + let message = OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::Timeout, + }); + + let classified = OpenAIError::classify(&message).unwrap(); + + assert_eq!(classified.error_type, "timeout"); + } + + #[test] + fn classifies_too_many_buffered_requests_as_rate_limit() { + let message = OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::TooManyBufferedRequests, + }); + + let classified = OpenAIError::classify(&message).unwrap(); + + assert_eq!(classified.error_type, "rate_limit_error"); + } + + #[test] + fn classifies_validation_failure_with_first_message() { + let classified = OpenAIError::classify(&token_message( + GeneratedTokenResult::ToolCallValidationFailed(vec!["missing field x".to_owned()]), + )) + .unwrap(); + + assert_eq!(classified.error_type, "server_error"); + assert_eq!(classified.message, "missing field x"); + } + + #[test] + fn does_not_classify_a_content_token() { + assert!( + OpenAIError::classify(&token_message(GeneratedTokenResult::ContentToken( + "hello".to_owned() + ))) + .is_none() + ); + } + + #[test] + fn does_not_classify_a_done_summary() { + assert!( + OpenAIError::classify(&token_message(GeneratedTokenResult::Done( + GenerationSummary::default() + ))) + .is_none() + ); + } + + #[test] + fn does_not_classify_an_embedding_response() { + let message = OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::Embedding(EmbeddingResult::Done), + }); + + assert!(OpenAIError::classify(&message).is_none()); + } + + #[test] + fn classifies_a_tool_call_with_arguments_is_unrelated_to_errors() { + let parsed = vec![llama_cpp_bindings_types::ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(serde_json::json!({"location": "Paris"})), + )]; + + assert!( + OpenAIError::classify(&token_message(GeneratedTokenResult::ToolCallParsed(parsed))) + .is_none() + ); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_message.rs b/paddler_balancer/src/compatibility/openai_service/openai_message.rs new file mode 100644 index 00000000..a8cab097 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_message.rs @@ -0,0 +1,42 @@ +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct OpenAIMessage { + pub content: ConversationMessageContent, + pub role: String, +} + +impl OpenAIMessage { + #[must_use] + pub fn to_conversation_message(&self) -> ConversationMessage { + ConversationMessage { + content: self.content.clone(), + role: self.role.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::OpenAIMessage; + + #[test] + fn openai_message_converts_to_conversation_message() { + let input = serde_json::json!({ + "role": "user", + "content": [ + {"type": "text", "text": "OCR this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}} + ] + }); + + let openai_message: OpenAIMessage = serde_json::from_value(input).unwrap(); + let conversation_message = openai_message.to_conversation_message(); + + assert_eq!(conversation_message.role, "user"); + assert_eq!(conversation_message.content.text_content(), "OCR this"); + assert_eq!(conversation_message.content.image_urls().len(), 1); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_response_transformer.rs b/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_response_transformer.rs new file mode 100644 index 00000000..a57ea1dc --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_response_transformer.rs @@ -0,0 +1,620 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use llama_cpp_bindings_types::ParsedToolCall; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use parking_lot::Mutex; +use serde_json::json; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::compatibility::openai_service::arguments_to_tool_call_string::arguments_to_tool_call_string; +use crate::compatibility::openai_service::openai_non_streaming_state::OpenAINonStreamingState; +use crate::compatibility::openai_service::openai_usage_json::openai_usage_json; +use crate::compatibility::openai_service::try_universal_error_chunk::try_universal_error_chunk; + +#[derive(Clone)] +pub struct OpenAINonStreamingResponseTransformer { + pub created: u64, + pub model: String, + pub state: Arc>, +} + +impl OpenAINonStreamingResponseTransformer { + fn append_content(&self, text: &str) { + self.state.lock().content.push_str(text); + } + + fn append_tool_calls(&self, parsed_calls: Vec) { + self.state.lock().tool_calls.extend(parsed_calls); + } + + fn build_done_chunk(&self, request_id: &str, summary: &GenerationSummary) -> Result { + let snapshot = self.snapshot_state(); + + let has_tool_calls = !snapshot.tool_calls.is_empty(); + let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" }; + + let tool_calls_json = snapshot + .tool_calls + .iter() + .map(|call| { + arguments_to_tool_call_string(&call.arguments).map(|arguments| { + json!({ + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": arguments, + } + }) + }) + }) + .collect::>>(); + + tool_calls_json.and_then(|tool_calls_json| { + let mut message_obj = json!({ + "role": "assistant", + "content": if snapshot.content.is_empty() && has_tool_calls { + serde_json::Value::Null + } else { + json!(snapshot.content) + }, + "refusal": null, + "annotations": [] + }); + + if has_tool_calls && let Some(map) = message_obj.as_object_mut() { + map.insert("tool_calls".to_owned(), json!(tool_calls_json)); + } + + serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion", + "created": self.created, + "model": self.model, + "choices": [ + { + "index": 0, + "message": message_obj, + "logprobs": null, + "finish_reason": finish_reason + } + ], + "usage": openai_usage_json(&summary.usage), + "service_tier": "default" + })) + .context("serializing non-streaming completion") + }) + } + + fn snapshot_state(&self) -> OpenAINonStreamingState { + self.state.lock().clone() + } +} + +#[async_trait] +impl TransformsOutgoingMessage for OpenAINonStreamingResponseTransformer { + type Output = TransformResult; + + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let Some(error_chunk) = try_universal_error_chunk(&message) { + return Ok(vec![error_chunk]); + } + + match message { + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken( + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text), + ), + .. + }) => { + self.append_content(&text); + Ok(vec![]) + } + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken( + GeneratedTokenResult::ReasoningToken(_) + | GeneratedTokenResult::ToolCallToken(_), + ), + .. + }) => Ok(vec![]), + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), + .. + }) => { + self.append_tool_calls(parsed_calls); + Ok(vec![]) + } + OutgoingMessage::Response(ResponseEnvelope { + request_id, + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), + .. + }) => Ok(vec![TransformResult::Chunk( + self.build_done_chunk(&request_id, &summary)?, + )]), + other => Err(anyhow!( + "OpenAINonStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" + )), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use anyhow::Result; + use llama_cpp_bindings_types::ParsedToolCall; + use llama_cpp_bindings_types::TokenUsage; + use llama_cpp_bindings_types::ToolCallArguments; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::generation_summary::GenerationSummary; + use paddler_messaging::inference_client::message::Message as OutgoingMessage; + use paddler_messaging::inference_client::response::Response as OutgoingResponse; + use paddler_messaging::jsonrpc::error::Error as JsonRpcError; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + use parking_lot::Mutex; + use serde_json::json; + + use super::OpenAINonStreamingResponseTransformer; + use super::OpenAINonStreamingState; + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; + + #[must_use] + pub fn token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::GeneratedToken(token_result), + }) + } + + #[must_use] + pub fn error_message(code: i32, description: &str) -> OutgoingMessage { + OutgoingMessage::Error(ErrorEnvelope { + request_id: "test-request".to_owned(), + error: JsonRpcError { + code, + description: description.to_owned(), + }, + }) + } + + #[must_use] + pub fn response_message(response: OutgoingResponse) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response, + }) + } + + #[must_use] + pub fn summary_with_counts( + prompt_tokens: u64, + content_tokens: u64, + reasoning_tokens: u64, + ) -> GenerationSummary { + GenerationSummary { + usage: TokenUsage { + prompt_tokens, + content_tokens, + reasoning_tokens, + ..TokenUsage::default() + }, + } + } + + #[must_use] + pub fn weather_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(json!({ "location": "Paris" })), + ) + } + + #[must_use] + pub fn invalid_json_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_invalid".to_owned(), + "broken_tool".to_owned(), + ToolCallArguments::InvalidJson("{not valid json".to_owned()), + ) + } + + pub fn assert_chunk_contains(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Chunk(content) = result else { + anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); + }; + + assert!( + content.contains(expected), + "chunk does not contain '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_chunk_does_not_contain(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Chunk(content) = result else { + anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); + }; + + assert!( + !content.contains(expected), + "chunk unexpectedly contains '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_error_contains(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Error(content) = result else { + anyhow::bail!("expected TransformResult::Error, got TransformResult::Chunk"); + }; + + assert!( + content.contains(expected), + "error does not contain '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_chunk_body_contains(result: &TransformResult, expected: &str) { + let TransformResult::Chunk(content) = result else { + panic!("expected a chunk variant"); + }; + + assert!( + content.contains(expected), + "chunk does not contain '{expected}': {content}" + ); + } + + pub fn assert_error_body_contains(result: &TransformResult, expected: &str) { + let TransformResult::Error(content) = result else { + panic!("expected an error variant"); + }; + + assert!( + content.contains(expected), + "error does not contain '{expected}': {content}" + ); + } + + fn non_streaming_transformer() -> OpenAINonStreamingResponseTransformer { + OpenAINonStreamingResponseTransformer { + created: 0, + model: "test-model".to_owned(), + state: Arc::new(Mutex::new(OpenAINonStreamingState::default())), + } + } + + #[tokio::test] + async fn non_streaming_aggregates_content_only_when_no_reasoning() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hel".to_owned(), + ))) + .await?; + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "lo".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(4, 2, 0); + let final_chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"hello\"")?; + assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; + assert_chunk_contains(&final_chunks[0], "\"prompt_tokens\":4")?; + assert_chunk_contains(&final_chunks[0], "\"completion_tokens\":2")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_drops_reasoning_but_keeps_reasoning_token_count() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ReasoningToken( + "think".to_owned(), + ))) + .await?; + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "answer".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(3, 1, 1); + let final_chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"answer\"")?; + assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; + assert_chunk_contains(&final_chunks[0], "\"reasoning_tokens\":1")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_undeterminable_routes_to_content() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::UndeterminableToken( + "amb".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(2, 0, 0); + let final_chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"content\":\"amb\"")?; + assert_chunk_does_not_contain(&final_chunks[0], "reasoning_content")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_tool_call_parsed_populates_message_tool_calls() -> Result<()> { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await?; + + let summary = summary_with_counts(4, 0, 0); + let final_chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(final_chunks.len(), 1); + assert_chunk_contains(&final_chunks[0], "\"tool_calls\":")?; + assert_chunk_contains(&final_chunks[0], "\"name\":\"get_weather\"")?; + assert_chunk_contains( + &final_chunks[0], + "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", + )?; + assert_chunk_contains(&final_chunks[0], "\"finish_reason\":\"tool_calls\"")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_tool_call_parse_failed_emits_error() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParseFailed( + "bad payload".to_owned(), + ))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad payload")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_tool_call_validation_failed_emits_error() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::ToolCallValidationFailed(vec!["bad shape".to_owned()]), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad shape")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::UnrecognizedToolCallFormat( + paddler_messaging::raw_tool_call_tokens::RawToolCallTokens { + text: "blah".to_owned(), + ffi_error_message: "common_chat_parse failed: no parser".to_owned(), + }, + ), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; + assert_error_contains(&chunks[0], "blah")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_error_message_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(error_message(500, "internal server error")) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "internal server error")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_chat_template_error_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = token_message(GeneratedTokenResult::ChatTemplateError( + "bad template".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad template")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_image_decoding_failed_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = token_message(GeneratedTokenResult::ImageDecodingFailed( + "unsupported format".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "unsupported format")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = token_message(GeneratedTokenResult::MultimodalNotSupported( + "model does not support images".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "model does not support images")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = token_message(GeneratedTokenResult::ImageExceedsBatchSize( + paddler_messaging::oversized_image_details::OversizedImageDetails { + image_tokens: 368, + n_batch: 100, + }, + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "368")?; + assert_error_contains(&chunks[0], "100")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_timeout_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = response_message(OutgoingResponse::Timeout); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "request timed out")?; + assert_error_contains(&chunks[0], "timeout")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { + let transformer = non_streaming_transformer(); + + let message = response_message(OutgoingResponse::TooManyBufferedRequests); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "too many buffered requests")?; + assert_error_contains(&chunks[0], "rate_limit_error")?; + + Ok(()) + } + + #[tokio::test] + async fn non_streaming_tool_call_with_invalid_json_arguments_passes_raw_string_through() { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + invalid_json_call(), + ]))) + .await + .unwrap(); + + let final_chunks = transformer + .transform(token_message(GeneratedTokenResult::Done( + summary_with_counts(3, 0, 0), + ))) + .await + .unwrap(); + + assert_eq!(final_chunks.len(), 1); + assert_chunk_body_contains(&final_chunks[0], "{not valid json"); + assert_chunk_body_contains(&final_chunks[0], "\"name\":\"broken_tool\""); + } + + #[tokio::test] + async fn non_streaming_embedding_response_returns_invalid_request_error() { + let transformer = non_streaming_transformer(); + + let message = response_message(OutgoingResponse::Embedding( + paddler_messaging::embedding_result::EmbeddingResult::Done, + )); + let chunks = transformer.transform(message).await.unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "invalid_request_error"); + assert_error_body_contains( + &chunks[0], + "unexpected embedding response in chat completions", + ); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_state.rs b/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_state.rs new file mode 100644 index 00000000..7ee5416a --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_non_streaming_state.rs @@ -0,0 +1,7 @@ +use llama_cpp_bindings_types::ParsedToolCall; + +#[derive(Clone, Default)] +pub struct OpenAINonStreamingState { + pub content: String, + pub tool_calls: Vec, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_item.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_item.rs new file mode 100644 index 00000000..69c8004a --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_item.rs @@ -0,0 +1,8 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct OpenAIResponsesFunctionCallItem { + pub call_id: String, + pub name: String, + pub arguments: String, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_output_item.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_output_item.rs new file mode 100644 index 00000000..2b38ae88 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_call_output_item.rs @@ -0,0 +1,8 @@ +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_function_output::OpenAIResponsesFunctionOutput; + +#[derive(Deserialize)] +pub struct OpenAIResponsesFunctionCallOutputItem { + pub output: OpenAIResponsesFunctionOutput, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_function_output.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_output.rs new file mode 100644 index 00000000..d975ac4d --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_output.rs @@ -0,0 +1,27 @@ +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_input_content_part::OpenAIResponsesInputContentPart; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum OpenAIResponsesFunctionOutput { + Text(String), + Parts(Vec), +} + +impl OpenAIResponsesFunctionOutput { + #[must_use] + pub fn into_text(self) -> String { + match self { + Self::Text(text) => text, + Self::Parts(parts) => parts + .into_iter() + .filter_map(|part| match part { + OpenAIResponsesInputContentPart::InputText { text } => Some(text), + OpenAIResponsesInputContentPart::InputImage { .. } + | OpenAIResponsesInputContentPart::Unsupported => None, + }) + .collect::(), + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_function_tool.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_tool.rs new file mode 100644 index 00000000..42ebfb33 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_function_tool.rs @@ -0,0 +1,11 @@ +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct OpenAIResponsesFunctionTool { + pub name: String, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub parameters: Option, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_input.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_input.rs new file mode 100644 index 00000000..8762ee47 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_input.rs @@ -0,0 +1,33 @@ +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_input_item::OpenAIResponsesInputItem; + +pub enum OpenAIResponsesInput { + Text(String), + Items(Vec), +} + +impl Default for OpenAIResponsesInput { + fn default() -> Self { + Self::Items(Vec::new()) + } +} + +impl<'de> Deserialize<'de> for OpenAIResponsesInput { + fn deserialize(deserializer: TDeserializer) -> Result + where + TDeserializer: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + enum TextOrItems { + Text(String), + Items(Vec), + } + + Ok(match TextOrItems::deserialize(deserializer)? { + TextOrItems::Text(text) => Self::Text(text), + TextOrItems::Items(items) => Self::Items(items), + }) + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_input_content_part.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_input_content_part.rs new file mode 100644 index 00000000..66ce979b --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_input_content_part.rs @@ -0,0 +1,32 @@ +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::image_url::ImageUrl; +use serde::Deserialize; + +#[derive(Deserialize)] +#[serde(tag = "type")] +pub enum OpenAIResponsesInputContentPart { + #[serde(rename = "input_text")] + InputText { text: String }, + #[serde(rename = "input_image")] + InputImage { + #[serde(default)] + image_url: Option, + }, + #[serde(other)] + Unsupported, +} + +impl OpenAIResponsesInputContentPart { + #[must_use] + pub fn into_conversation_part(self) -> Option { + match self { + Self::InputText { text } => Some(ConversationMessageContentPart::Text { text }), + Self::InputImage { + image_url: Some(url), + } => Some(ConversationMessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + }), + Self::InputImage { image_url: None } | Self::Unsupported => None, + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_input_item.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_input_item.rs new file mode 100644 index 00000000..5c5f9cbd --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_input_item.rs @@ -0,0 +1,46 @@ +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use serde::Deserialize; +use serde_json::json; + +use crate::compatibility::openai_service::openai_responses_function_call_item::OpenAIResponsesFunctionCallItem; +use crate::compatibility::openai_service::openai_responses_function_call_output_item::OpenAIResponsesFunctionCallOutputItem; +use crate::compatibility::openai_service::openai_responses_message_item::OpenAIResponsesMessageItem; +use crate::compatibility::openai_service::openai_responses_tagged_item::OpenAIResponsesTaggedItem; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum OpenAIResponsesInputItem { + Tagged(OpenAIResponsesTaggedItem), + Message(OpenAIResponsesMessageItem), +} + +impl OpenAIResponsesInputItem { + #[must_use] + pub fn into_conversation_message(self) -> Option { + match self { + Self::Message(message) | Self::Tagged(OpenAIResponsesTaggedItem::Message(message)) => { + Some(message.into_conversation_message()) + } + Self::Tagged(OpenAIResponsesTaggedItem::FunctionCall( + OpenAIResponsesFunctionCallItem { + call_id, + name, + arguments, + }, + )) => Some(ConversationMessage { + content: ConversationMessageContent::Text( + json!({ "call_id": call_id, "name": name, "arguments": arguments }).to_string(), + ), + role: "assistant".to_owned(), + }), + Self::Tagged(OpenAIResponsesTaggedItem::FunctionCallOutput( + OpenAIResponsesFunctionCallOutputItem { output }, + )) => Some(ConversationMessage { + content: ConversationMessageContent::Text(output.into_text()), + role: "tool".to_owned(), + }), + Self::Tagged(OpenAIResponsesTaggedItem::Unsupported) => None, + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_message_content.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_message_content.rs new file mode 100644 index 00000000..70831c7b --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_message_content.rs @@ -0,0 +1,26 @@ +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_input_content_part::OpenAIResponsesInputContentPart; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum OpenAIResponsesMessageContent { + Text(String), + Parts(Vec), +} + +impl OpenAIResponsesMessageContent { + #[must_use] + pub fn into_conversation_content(self) -> ConversationMessageContent { + match self { + Self::Text(text) => ConversationMessageContent::Text(text), + Self::Parts(parts) => ConversationMessageContent::Parts( + parts + .into_iter() + .filter_map(OpenAIResponsesInputContentPart::into_conversation_part) + .collect(), + ), + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_message_item.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_message_item.rs new file mode 100644 index 00000000..8b366502 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_message_item.rs @@ -0,0 +1,28 @@ +use paddler_messaging::conversation_message::ConversationMessage; +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_message_content::OpenAIResponsesMessageContent; + +fn normalize_role(role: String) -> String { + if role == "developer" { + "system".to_owned() + } else { + role + } +} + +#[derive(Deserialize)] +pub struct OpenAIResponsesMessageItem { + pub role: String, + pub content: OpenAIResponsesMessageContent, +} + +impl OpenAIResponsesMessageItem { + #[must_use] + pub fn into_conversation_message(self) -> ConversationMessage { + ConversationMessage { + content: self.content.into_conversation_content(), + role: normalize_role(self.role), + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_reasoning.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_reasoning.rs new file mode 100644 index 00000000..aa9a77f5 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_reasoning.rs @@ -0,0 +1,14 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct OpenAIResponsesReasoning { + #[serde(default)] + pub effort: Option, +} + +impl OpenAIResponsesReasoning { + #[must_use] + pub fn enables_thinking(&self) -> bool { + self.effort.as_deref() != Some("none") + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_request_params.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_request_params.rs new file mode 100644 index 00000000..b6a121e8 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_request_params.rs @@ -0,0 +1,273 @@ +use anyhow::Result; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::validates::Validates; +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_function_tool::OpenAIResponsesFunctionTool; +use crate::compatibility::openai_service::openai_responses_input::OpenAIResponsesInput; +use crate::compatibility::openai_service::openai_responses_input_item::OpenAIResponsesInputItem; +use crate::compatibility::openai_service::openai_responses_reasoning::OpenAIResponsesReasoning; +use crate::compatibility::openai_service::openai_responses_text_param::OpenAIResponsesTextParam; +use crate::compatibility::openai_service::openai_responses_tool::OpenAIResponsesTool; +use crate::compatibility::openai_service::responses_prepared_request::ResponsesPreparedRequest; + +const DEFAULT_MAX_TOKENS: i32 = 2000; + +#[derive(Deserialize)] +pub struct OpenAIResponsesRequestParams { + /// Echoed back in the response object; not used for routing. + pub model: String, + #[serde(default)] + pub input: OpenAIResponsesInput, + #[serde(default)] + pub instructions: Option, + #[serde(default)] + pub stream: Option, + #[serde(default)] + pub max_output_tokens: Option, + #[serde(default)] + pub tools: Vec, + #[serde(default)] + pub text: Option, + #[serde(default)] + pub reasoning: Option, +} + +impl OpenAIResponsesRequestParams { + pub fn into_prepared(self) -> Result { + let Self { + model, + input, + instructions, + stream, + max_output_tokens, + tools, + text, + reasoning, + } = self; + + let mut messages: Vec = Vec::new(); + + if let Some(instructions) = &instructions + && !instructions.is_empty() + { + messages.push(ConversationMessage { + content: ConversationMessageContent::Text(instructions.clone()), + role: "system".to_owned(), + }); + } + + match input { + OpenAIResponsesInput::Text(text) => messages.push(ConversationMessage { + content: ConversationMessageContent::Text(text), + role: "user".to_owned(), + }), + OpenAIResponsesInput::Items(items) => { + messages.extend( + items + .into_iter() + .filter_map(OpenAIResponsesInputItem::into_conversation_message), + ); + } + } + + let validated_tools = tools + .into_iter() + .filter_map(|tool| match tool { + OpenAIResponsesTool::Function(function_tool) => { + let OpenAIResponsesFunctionTool { + name, + description, + parameters, + } = *function_tool; + + Some(Tool::Function(FunctionCall { + function: Function { + name, + description: description.unwrap_or_default(), + parameters: parameters.map_or(Parameters::Empty, Parameters::Schema), + }, + })) + } + OpenAIResponsesTool::Unsupported => None, + }) + .map(Validates::validate) + .collect::>>()?; + + let parse_tool_calls = !validated_tools.is_empty(); + + Ok(ResponsesPreparedRequest { + paddler_params: ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(messages), + enable_thinking: reasoning + .as_ref() + .is_none_or(OpenAIResponsesReasoning::enables_thinking), + grammar: match text { + Some(text_param) => text_param.into_grammar_constraint()?, + None => None, + }, + max_tokens: max_output_tokens.unwrap_or(DEFAULT_MAX_TOKENS), + parse_tool_calls, + tools: validated_tools, + }, + stream: stream.unwrap_or(false), + model, + instructions, + }) + } +} + +#[cfg(test)] +mod tests { + use paddler_messaging::grammar_constraint::GrammarConstraint; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; + use serde_json::json; + + use super::OpenAIResponsesRequestParams; + use crate::compatibility::openai_service::responses_prepared_request::ResponsesPreparedRequest; + + fn prepared_from(value: serde_json::Value) -> ResponsesPreparedRequest { + let params: OpenAIResponsesRequestParams = serde_json::from_value(value).unwrap(); + + params.into_prepared().unwrap() + } + + #[test] + fn string_input_becomes_a_single_user_message() { + let prepared = prepared_from(json!({ "model": "test", "input": "Say hello" })); + + let messages = &prepared.paddler_params.conversation_history.messages; + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "user"); + assert_eq!(messages[0].content.text_content(), "Say hello"); + } + + #[test] + fn instructions_are_prepended_as_a_system_message() { + let prepared = prepared_from(json!({ + "model": "test", + "instructions": "be terse", + "input": "hi" + })); + + let messages = &prepared.paddler_params.conversation_history.messages; + + assert_eq!(messages[0].role, "system"); + assert_eq!(messages[0].content.text_content(), "be terse"); + assert_eq!(messages[1].role, "user"); + } + + #[test] + fn function_call_output_item_becomes_a_tool_message() { + let prepared = prepared_from(json!({ + "model": "test", + "input": [ + { "type": "function_call_output", "call_id": "call_1", "output": "sunny" } + ] + })); + + let messages = &prepared.paddler_params.conversation_history.messages; + + assert_eq!(messages[0].role, "tool"); + assert_eq!(messages[0].content.text_content(), "sunny"); + } + + #[test] + fn developer_role_is_normalized_to_system() { + let prepared = prepared_from(json!({ + "model": "test", + "input": [ + { "type": "message", "role": "developer", "content": "rules" } + ] + })); + + assert_eq!( + prepared.paddler_params.conversation_history.messages[0].role, + "system" + ); + } + + #[test] + fn flat_function_tool_maps_to_an_internal_tool_with_default_description() { + let prepared = prepared_from(json!({ + "model": "test", + "input": "hi", + "tools": [ + { "type": "function", "name": "get_weather", "parameters": { "type": "object" } } + ] + })); + + assert!(prepared.paddler_params.parse_tool_calls); + + let Tool::Function(function_call) = &prepared.paddler_params.tools[0]; + + assert_eq!(function_call.function.name, "get_weather"); + assert_eq!(function_call.function.description, ""); + } + + #[test] + fn text_format_json_schema_becomes_a_grammar_constraint() { + let prepared = prepared_from(json!({ + "model": "test", + "input": "hi", + "text": { "format": { "type": "json_schema", "name": "out", "schema": { "type": "object" } } } + })); + + let Some(GrammarConstraint::JsonSchema { schema }) = &prepared.paddler_params.grammar + else { + panic!("expected a json schema grammar constraint"); + }; + + assert!(schema.contains("\"type\":\"object\"")); + } + + #[test] + fn reasoning_effort_none_disables_thinking() { + let prepared = prepared_from(json!({ + "model": "test", + "input": "hi", + "reasoning": { "effort": "none" } + })); + + assert!(!prepared.paddler_params.enable_thinking); + } + + #[test] + fn unsupported_tool_is_skipped_and_disables_tool_call_parsing() { + let prepared = prepared_from(json!({ + "model": "test", + "input": "hi", + "tools": [ { "type": "web_search" } ] + })); + + assert!(prepared.paddler_params.tools.is_empty()); + assert!(!prepared.paddler_params.parse_tool_calls); + } + + #[test] + fn unsupported_and_stateful_fields_are_ignored() { + let prepared = prepared_from(json!({ + "model": "test", + "input": "hi", + "store": true, + "previous_response_id": "resp_prev", + "conversation": "conv_1", + "temperature": 0.5, + "tool_choice": "required" + })); + + assert_eq!( + prepared.paddler_params.conversation_history.messages.len(), + 1 + ); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_tagged_item.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_tagged_item.rs new file mode 100644 index 00000000..86035f63 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_tagged_item.rs @@ -0,0 +1,18 @@ +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_function_call_item::OpenAIResponsesFunctionCallItem; +use crate::compatibility::openai_service::openai_responses_function_call_output_item::OpenAIResponsesFunctionCallOutputItem; +use crate::compatibility::openai_service::openai_responses_message_item::OpenAIResponsesMessageItem; + +#[derive(Deserialize)] +#[serde(tag = "type")] +pub enum OpenAIResponsesTaggedItem { + #[serde(rename = "message")] + Message(OpenAIResponsesMessageItem), + #[serde(rename = "function_call")] + FunctionCall(OpenAIResponsesFunctionCallItem), + #[serde(rename = "function_call_output")] + FunctionCallOutput(OpenAIResponsesFunctionCallOutputItem), + #[serde(other)] + Unsupported, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_text_format.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_text_format.rs new file mode 100644 index 00000000..3e28f3c6 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_text_format.rs @@ -0,0 +1,13 @@ +use serde::Deserialize; +use serde_json::Value; + +#[derive(Deserialize)] +#[serde(tag = "type")] +pub enum OpenAIResponsesTextFormat { + #[serde(rename = "text")] + Text, + #[serde(rename = "json_schema")] + JsonSchema { schema: Value }, + #[serde(other)] + Unsupported, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_text_param.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_text_param.rs new file mode 100644 index 00000000..2f9da515 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_text_param.rs @@ -0,0 +1,27 @@ +use anyhow::Context as _; +use anyhow::Result; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_text_format::OpenAIResponsesTextFormat; + +#[derive(Deserialize)] +pub struct OpenAIResponsesTextParam { + #[serde(default)] + pub format: Option, +} + +impl OpenAIResponsesTextParam { + pub fn into_grammar_constraint(self) -> Result> { + match self.format { + Some(OpenAIResponsesTextFormat::JsonSchema { schema }) => { + Ok(Some(GrammarConstraint::JsonSchema { + schema: serde_json::to_string(&schema) + .context("serializing responses text.format json schema")?, + })) + } + Some(OpenAIResponsesTextFormat::Text | OpenAIResponsesTextFormat::Unsupported) + | None => Ok(None), + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_responses_tool.rs b/paddler_balancer/src/compatibility/openai_service/openai_responses_tool.rs new file mode 100644 index 00000000..8cf49741 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_responses_tool.rs @@ -0,0 +1,13 @@ +use serde::Deserialize; + +use crate::compatibility::openai_service::openai_responses_function_tool::OpenAIResponsesFunctionTool; + +#[derive(Deserialize)] +#[serde(tag = "type")] +pub enum OpenAIResponsesTool { + // Boxed because the function-tool payload is far larger than the empty `Unsupported` variant. + #[serde(rename = "function")] + Function(Box), + #[serde(other)] + Unsupported, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_streaming_response_transformer.rs b/paddler_balancer/src/compatibility/openai_service/openai_streaming_response_transformer.rs new file mode 100644 index 00000000..2e40e732 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_streaming_response_transformer.rs @@ -0,0 +1,845 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use llama_cpp_bindings_types::ParsedToolCall; +use llama_cpp_bindings_types::TokenUsage; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use parking_lot::Mutex; +use serde_json::json; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::compatibility::openai_service::arguments_to_tool_call_string::arguments_to_tool_call_string; +use crate::compatibility::openai_service::openai_streaming_state::OpenAIStreamingState; +use crate::compatibility::openai_service::openai_usage_json::openai_usage_json; +use crate::compatibility::openai_service::try_universal_error_chunk::try_universal_error_chunk; + +#[derive(Clone)] +pub struct OpenAIStreamingResponseTransformer { + pub created: u64, + pub include_usage: bool, + pub model: String, + pub state: Arc>, + pub system_fingerprint: String, +} + +impl OpenAIStreamingResponseTransformer { + fn content_chunk(&self, request_id: &str, text: &str) -> Result { + serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": text, + }, + "logprobs": null, + "finish_reason": null + } + ] + })) + .context("serializing content chunk") + } + + fn tool_calls_chunk( + &self, + request_id: &str, + parsed_calls: &[ParsedToolCall], + ) -> Result { + parsed_calls + .iter() + .enumerate() + .map(|(index, call)| { + arguments_to_tool_call_string(&call.arguments).map(|arguments| { + json!({ + "index": index, + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": arguments, + } + }) + }) + }) + .collect::>>() + .and_then(|tool_calls| { + serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": tool_calls, + }, + "logprobs": null, + "finish_reason": null + } + ] + })) + .context("serializing tool-calls chunk") + }) + } + + fn finish_chunk(&self, request_id: &str, finish_reason: &str) -> Result { + serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": null, + "finish_reason": finish_reason + } + ] + })) + .context("serializing finish chunk") + } + + fn usage_chunk(&self, request_id: &str, usage: &TokenUsage) -> Result { + serde_json::to_string(&json!({ + "id": request_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": self.system_fingerprint, + "choices": [], + "usage": openai_usage_json(usage), + })) + .context("serializing usage chunk") + } + + fn handle_content(&self, request_id: &str, text: &str) -> Result> { + self.content_chunk(request_id, text) + .map(|chunk| vec![TransformResult::Chunk(chunk)]) + } + + fn handle_tool_call_parsed( + &self, + request_id: &str, + parsed_calls: &[ParsedToolCall], + ) -> Result> { + if parsed_calls.is_empty() { + return Ok(vec![]); + } + + self.state.lock().saw_tool_call = true; + + self.tool_calls_chunk(request_id, parsed_calls) + .map(|chunk| vec![TransformResult::Chunk(chunk)]) + } + + fn handle_done( + &self, + request_id: &str, + summary: &GenerationSummary, + ) -> Result> { + let saw_tool_call = self.state.lock().saw_tool_call; + let finish_reason = if saw_tool_call { "tool_calls" } else { "stop" }; + + self.finish_chunk(request_id, finish_reason) + .and_then(|finish_chunk| { + let finish = TransformResult::Chunk(finish_chunk); + + if self.include_usage { + self.usage_chunk(request_id, &summary.usage) + .map(|usage_chunk| vec![finish, TransformResult::Chunk(usage_chunk)]) + } else { + Ok(vec![finish]) + } + }) + } +} + +#[async_trait] +impl TransformsOutgoingMessage for OpenAIStreamingResponseTransformer { + type Output = TransformResult; + + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let Some(error_chunk) = try_universal_error_chunk(&message) { + return Ok(vec![error_chunk]); + } + + match message { + OutgoingMessage::Response(ResponseEnvelope { + request_id, + response: + OutgoingResponse::GeneratedToken( + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text), + ), + .. + }) => self.handle_content(&request_id, &text), + OutgoingMessage::Response(ResponseEnvelope { + response: + OutgoingResponse::GeneratedToken( + GeneratedTokenResult::ReasoningToken(_) + | GeneratedTokenResult::ToolCallToken(_), + ), + .. + }) => Ok(vec![]), + OutgoingMessage::Response(ResponseEnvelope { + request_id, + response: + OutgoingResponse::GeneratedToken(GeneratedTokenResult::ToolCallParsed(parsed_calls)), + .. + }) => self.handle_tool_call_parsed(&request_id, &parsed_calls), + OutgoingMessage::Response(ResponseEnvelope { + request_id, + response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done(summary)), + .. + }) => self.handle_done(&request_id, &summary), + other => Err(anyhow!( + "OpenAIStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" + )), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use anyhow::Result; + use llama_cpp_bindings_types::ParsedToolCall; + use llama_cpp_bindings_types::TokenUsage; + use llama_cpp_bindings_types::ToolCallArguments; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::generation_summary::GenerationSummary; + use paddler_messaging::inference_client::message::Message as OutgoingMessage; + use paddler_messaging::inference_client::response::Response as OutgoingResponse; + use paddler_messaging::jsonrpc::error::Error as JsonRpcError; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + use parking_lot::Mutex; + use serde_json::json; + + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; + + use super::OpenAIStreamingResponseTransformer; + use super::OpenAIStreamingState; + + #[must_use] + pub fn token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::GeneratedToken(token_result), + }) + } + + #[must_use] + pub fn error_message(code: i32, description: &str) -> OutgoingMessage { + OutgoingMessage::Error(ErrorEnvelope { + request_id: "test-request".to_owned(), + error: JsonRpcError { + code, + description: description.to_owned(), + }, + }) + } + + #[must_use] + pub fn response_message(response: OutgoingResponse) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response, + }) + } + + #[must_use] + pub fn summary_with_counts( + prompt_tokens: u64, + content_tokens: u64, + reasoning_tokens: u64, + ) -> GenerationSummary { + GenerationSummary { + usage: TokenUsage { + prompt_tokens, + content_tokens, + reasoning_tokens, + ..TokenUsage::default() + }, + } + } + + #[must_use] + pub fn weather_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(json!({ "location": "Paris" })), + ) + } + + #[must_use] + pub fn invalid_json_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_invalid".to_owned(), + "broken_tool".to_owned(), + ToolCallArguments::InvalidJson("{not valid json".to_owned()), + ) + } + + pub fn assert_chunk_contains(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Chunk(content) = result else { + anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); + }; + + assert!( + content.contains(expected), + "chunk does not contain '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_chunk_does_not_contain(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Chunk(content) = result else { + anyhow::bail!("expected TransformResult::Chunk, got TransformResult::Error"); + }; + + assert!( + !content.contains(expected), + "chunk unexpectedly contains '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_error_contains(result: &TransformResult, expected: &str) -> Result<()> { + let TransformResult::Error(content) = result else { + anyhow::bail!("expected TransformResult::Error, got TransformResult::Chunk"); + }; + + assert!( + content.contains(expected), + "error does not contain '{expected}': {content}" + ); + + Ok(()) + } + + pub fn assert_chunk_body_contains(result: &TransformResult, expected: &str) { + let TransformResult::Chunk(content) = result else { + panic!("expected a chunk variant"); + }; + + assert!( + content.contains(expected), + "chunk does not contain '{expected}': {content}" + ); + } + + pub fn assert_error_body_contains(result: &TransformResult, expected: &str) { + let TransformResult::Error(content) = result else { + panic!("expected an error variant"); + }; + + assert!( + content.contains(expected), + "error does not contain '{expected}': {content}" + ); + } + + fn streaming_transformer(include_usage: bool) -> OpenAIStreamingResponseTransformer { + OpenAIStreamingResponseTransformer { + created: 0, + include_usage, + model: "test-model".to_owned(), + state: Arc::new(Mutex::new(OpenAIStreamingState::default())), + system_fingerprint: "test-fingerprint".to_owned(), + } + } + + #[tokio::test] + async fn streaming_content_token_emits_content_delta() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::ContentToken("hello".to_owned())); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"content\":\"hello\"")?; + assert_chunk_contains(&chunks[0], "\"role\":\"assistant\"")?; + assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_reasoning_token_is_dropped() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::ReasoningToken("thought".to_owned())); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn streaming_undeterminable_token_emits_content_delta() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::UndeterminableToken( + "ambig".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"content\":\"ambig\"")?; + assert_chunk_does_not_contain(&chunks[0], "reasoning_content")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_tool_call_token_is_silently_dropped() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallToken( + "{".to_owned(), + ))) + .await?; + + assert_eq!(chunks.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn streaming_tool_call_parsed_emits_structured_tool_calls_chunk() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"tool_calls\"")?; + assert_chunk_contains(&chunks[0], "\"id\":\"call_x\"")?; + assert_chunk_contains(&chunks[0], "\"name\":\"get_weather\"")?; + assert_chunk_contains( + &chunks[0], + "\"arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"", + )?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_done_after_tool_call_uses_tool_calls_finish_reason() -> Result<()> { + let transformer = streaming_transformer(false); + + transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await?; + + let summary = summary_with_counts(2, 0, 0); + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"tool_calls\"")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_done_without_tool_call_uses_stop_finish_reason() -> Result<()> { + let transformer = streaming_transformer(false); + + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hi".to_owned(), + ))) + .await?; + + let summary = summary_with_counts(2, 1, 0); + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_done_with_include_usage_emits_finish_then_usage_chunk() -> Result<()> { + let transformer = streaming_transformer(true); + let summary = summary_with_counts(7, 4, 1); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 2); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + assert_chunk_does_not_contain(&chunks[0], "usage")?; + assert_chunk_contains(&chunks[1], "\"prompt_tokens\":7")?; + assert_chunk_contains(&chunks[1], "\"completion_tokens\":5")?; + assert_chunk_contains(&chunks[1], "\"total_tokens\":12")?; + assert_chunk_contains(&chunks[1], "\"choices\":[]")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_done_without_include_usage_emits_only_finish_chunk() -> Result<()> { + let transformer = streaming_transformer(false); + let summary = summary_with_counts(5, 3, 2); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done(summary))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_chunk_contains(&chunks[0], "\"finish_reason\":\"stop\"")?; + assert_chunk_does_not_contain(&chunks[0], "usage")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_tool_call_parse_failed_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParseFailed( + "bad payload".to_owned(), + ))) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad payload")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_tool_call_validation_failed_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::ToolCallValidationFailed(vec!["missing field x".to_owned()]), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "missing field x")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_unrecognized_tool_call_format_emits_server_error() -> Result<()> { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::UnrecognizedToolCallFormat( + paddler_messaging::raw_tool_call_tokens::RawToolCallTokens { + text: "blah".to_owned(), + ffi_error_message: "common_chat_parse failed: no parser".to_owned(), + }, + ), + )) + .await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "common_chat_parse failed: no parser")?; + assert_error_contains(&chunks[0], "blah")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_error_message_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = error_message(500, "internal server error"); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "internal server error")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_chat_template_error_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::ChatTemplateError( + "bad template".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "bad template")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_timeout_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = response_message(OutgoingResponse::Timeout); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "request timed out")?; + assert_error_contains(&chunks[0], "timeout")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_too_many_buffered_requests_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = response_message(OutgoingResponse::TooManyBufferedRequests); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "too many buffered requests")?; + assert_error_contains(&chunks[0], "rate_limit_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_image_decoding_failed_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::ImageDecodingFailed( + "unsupported format".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "unsupported format")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_multimodal_not_supported_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::MultimodalNotSupported( + "model does not support images".to_owned(), + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "model does not support images")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_image_exceeds_batch_size_returns_error_variant() -> Result<()> { + let transformer = streaming_transformer(false); + + let message = token_message(GeneratedTokenResult::ImageExceedsBatchSize( + paddler_messaging::oversized_image_details::OversizedImageDetails { + image_tokens: 368, + n_batch: 100, + }, + )); + let chunks = transformer.transform(message).await?; + + assert_eq!(chunks.len(), 1); + assert_error_contains(&chunks[0], "368")?; + assert_error_contains(&chunks[0], "100")?; + assert_error_contains(&chunks[0], "server_error")?; + + Ok(()) + } + + #[tokio::test] + async fn streaming_tool_call_with_invalid_json_arguments_passes_raw_string_through() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + invalid_json_call(), + ]))) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_chunk_body_contains(&chunks[0], "{not valid json"); + assert_chunk_body_contains(&chunks[0], "\"name\":\"broken_tool\""); + } + + #[tokio::test] + async fn streaming_empty_parsed_tool_calls_emit_no_chunks() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed( + Vec::new(), + ))) + .await + .unwrap(); + + assert_eq!(chunks.len(), 0); + } + + #[tokio::test] + async fn streaming_embedding_response_returns_invalid_request_error() { + let transformer = streaming_transformer(false); + + let message = response_message(OutgoingResponse::Embedding( + paddler_messaging::embedding_result::EmbeddingResult::Done, + )); + let chunks = transformer.transform(message).await.unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "invalid_request_error"); + assert_error_body_contains( + &chunks[0], + "unexpected embedding response in chat completions", + ); + } + + #[tokio::test] + async fn streaming_grammar_incompatible_with_thinking_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::GrammarIncompatibleWithThinking( + "grammar conflicts with thinking".to_owned(), + ), + )) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "grammar conflicts with thinking"); + assert_error_body_contains(&chunks[0], "server_error"); + } + + #[tokio::test] + async fn streaming_grammar_rejected_model_output_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::GrammarRejectedModelOutput( + "output rejected by grammar".to_owned(), + ), + )) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "output rejected by grammar"); + } + + #[tokio::test] + async fn streaming_grammar_initialization_failed_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message( + GeneratedTokenResult::GrammarInitializationFailed( + "could not build grammar".to_owned(), + ), + )) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "could not build grammar"); + } + + #[tokio::test] + async fn streaming_grammar_syntax_error_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::GrammarSyntaxError( + "bad grammar syntax".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "bad grammar syntax"); + } + + #[tokio::test] + async fn streaming_sampler_error_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::SamplerError( + "sampler blew up".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "sampler blew up"); + } + + #[tokio::test] + async fn streaming_tool_schema_invalid_returns_server_error() { + let transformer = streaming_transformer(false); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::ToolSchemaInvalid( + "schema is not valid".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!(chunks.len(), 1); + assert_error_body_contains(&chunks[0], "schema is not valid"); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_streaming_state.rs b/paddler_balancer/src/compatibility/openai_service/openai_streaming_state.rs new file mode 100644 index 00000000..b57f74f4 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_streaming_state.rs @@ -0,0 +1,4 @@ +#[derive(Default)] +pub struct OpenAIStreamingState { + pub saw_tool_call: bool, +} diff --git a/paddler_balancer/src/compatibility/openai_service/openai_usage_json.rs b/paddler_balancer/src/compatibility/openai_service/openai_usage_json.rs new file mode 100644 index 00000000..e5bc288e --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/openai_usage_json.rs @@ -0,0 +1,19 @@ +use llama_cpp_bindings_types::TokenUsage; +use serde_json::Value; +use serde_json::json; + +#[must_use] +pub fn openai_usage_json(usage: &TokenUsage) -> Value { + json!({ + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens(), + "total_tokens": usage.total_tokens(), + "prompt_tokens_details": { + "cached_tokens": usage.cached_prompt_tokens, + "audio_tokens": usage.input_audio_tokens, + }, + "completion_tokens_details": { + "reasoning_tokens": usage.reasoning_tokens, + } + }) +} diff --git a/paddler_balancer/src/compatibility/openai_service/output_item_event.rs b/paddler_balancer/src/compatibility/openai_service/output_item_event.rs new file mode 100644 index 00000000..fcac11e4 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/output_item_event.rs @@ -0,0 +1,8 @@ +use serde_json::Value; + +#[derive(Clone, Debug)] +pub struct OutputItemEvent { + pub sequence_number: u64, + pub output_index: usize, + pub item: Value, +} diff --git a/paddler_balancer/src/compatibility/openai_service/output_text_part.rs b/paddler_balancer/src/compatibility/openai_service/output_text_part.rs new file mode 100644 index 00000000..95d6fa3d --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/output_text_part.rs @@ -0,0 +1,12 @@ +use serde_json::Value; +use serde_json::json; + +#[must_use] +pub fn output_text_part(text: &str) -> Value { + json!({ + "type": "output_text", + "text": text, + "annotations": [], + "logprobs": [] + }) +} diff --git a/paddler_balancer/src/compatibility/openai_service/reasoning_item_done.rs b/paddler_balancer/src/compatibility/openai_service/reasoning_item_done.rs new file mode 100644 index 00000000..d1e9520f --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/reasoning_item_done.rs @@ -0,0 +1,13 @@ +use serde_json::Value; +use serde_json::json; + +#[must_use] +pub fn reasoning_item_done(item_id: &str, text: &str) -> Value { + json!({ + "type": "reasoning", + "id": item_id, + "summary": [], + "content": [{ "type": "reasoning_text", "text": text }], + "status": "completed" + }) +} diff --git a/paddler_balancer/src/compatibility/openai_service/response_snapshot_event.rs b/paddler_balancer/src/compatibility/openai_service/response_snapshot_event.rs new file mode 100644 index 00000000..2365bf71 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/response_snapshot_event.rs @@ -0,0 +1,7 @@ +use serde_json::Value; + +#[derive(Clone, Debug)] +pub struct ResponseSnapshotEvent { + pub sequence_number: u64, + pub response: Value, +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_error.rs b/paddler_balancer/src/compatibility/openai_service/responses_error.rs new file mode 100644 index 00000000..5f602cf8 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_error.rs @@ -0,0 +1,21 @@ +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + +use crate::compatibility::openai_service::openai_error::OpenAIError; + +#[must_use] +pub fn responses_error(message: &OutgoingMessage) -> Option { + if let OutgoingMessage::Response(ResponseEnvelope { + response: OutgoingResponse::Embedding(_), + .. + }) = message + { + return Some(OpenAIError { + error_type: "invalid_request_error", + message: "unexpected embedding response in responses".to_owned(), + }); + } + + OpenAIError::classify(message) +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_response_transformer.rs b/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_response_transformer.rs new file mode 100644 index 00000000..b5447b38 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_response_transformer.rs @@ -0,0 +1,315 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use parking_lot::Mutex; +use serde_json::Value; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::compatibility::openai_service::arguments_to_tool_call_string::arguments_to_tool_call_string; +use crate::compatibility::openai_service::function_call_item::function_call_item; +use crate::compatibility::openai_service::message_item_done::message_item_done; +use crate::compatibility::openai_service::reasoning_item_done::reasoning_item_done; +use crate::compatibility::openai_service::responses_error::responses_error; +use crate::compatibility::openai_service::responses_non_streaming_state::ResponsesNonStreamingState; +use crate::compatibility::openai_service::responses_response_builder::ResponsesResponseBuilder; + +#[derive(Clone)] +pub struct ResponsesNonStreamingResponseTransformer { + pub builder: ResponsesResponseBuilder, + pub state: Arc>, +} + +impl ResponsesNonStreamingResponseTransformer { + fn build_completed(&self, summary: &GenerationSummary) -> Result { + let snapshot = self.state.lock().clone(); + + let mut output: Vec = Vec::new(); + + if !snapshot.reasoning.is_empty() { + output.push(reasoning_item_done( + &format!("rs_{}", output.len()), + &snapshot.reasoning, + )); + } + + let has_tool_calls = !snapshot.tool_calls.is_empty(); + + if !snapshot.content.is_empty() || !has_tool_calls { + output.push(message_item_done( + &format!("msg_{}", output.len()), + &snapshot.content, + )); + } + + for call in &snapshot.tool_calls { + let arguments = arguments_to_tool_call_string(&call.arguments)?; + + output.push(function_call_item( + &format!("fc_{}", output.len()), + &call.id, + &call.name, + &arguments, + "completed", + )); + } + + serde_json::to_string(&self.builder.completed(output, &summary.usage)) + .context("serializing non-streaming responses completion") + } +} + +#[async_trait] +impl TransformsOutgoingMessage for ResponsesNonStreamingResponseTransformer { + type Output = TransformResult; + + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let Some(error) = responses_error(&message) { + return Ok(vec![TransformResult::Error( + error.to_envelope().to_string(), + )]); + } + + match message { + OutgoingMessage::Response(ResponseEnvelope { + response: OutgoingResponse::GeneratedToken(token), + .. + }) => match token { + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text) => { + self.state.lock().content.push_str(&text); + Ok(vec![]) + } + GeneratedTokenResult::ReasoningToken(text) => { + self.state.lock().reasoning.push_str(&text); + Ok(vec![]) + } + GeneratedTokenResult::ToolCallToken(_) => Ok(vec![]), + GeneratedTokenResult::ToolCallParsed(parsed_calls) => { + self.state.lock().tool_calls.extend(parsed_calls); + Ok(vec![]) + } + GeneratedTokenResult::Done(summary) => Ok(vec![TransformResult::Chunk( + self.build_completed(&summary)?, + )]), + other => Err(anyhow!( + "ResponsesNonStreamingResponseTransformer received a token it does not know how to handle: {other:?}" + )), + }, + other => Err(anyhow!( + "ResponsesNonStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" + )), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use llama_cpp_bindings_types::ParsedToolCall; + use llama_cpp_bindings_types::TokenUsage; + use llama_cpp_bindings_types::ToolCallArguments; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::generation_summary::GenerationSummary; + use paddler_messaging::inference_client::message::Message as OutgoingMessage; + use paddler_messaging::inference_client::response::Response as OutgoingResponse; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; + use parking_lot::Mutex; + use serde_json::json; + + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; + use crate::compatibility::openai_service::responses_response_builder::ResponsesResponseBuilder; + + use super::ResponsesNonStreamingResponseTransformer; + use super::ResponsesNonStreamingState; + + #[must_use] + pub fn token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::GeneratedToken(token_result), + }) + } + + #[must_use] + pub fn summary_with_counts( + prompt_tokens: u64, + content_tokens: u64, + reasoning_tokens: u64, + ) -> GenerationSummary { + GenerationSummary { + usage: TokenUsage { + prompt_tokens, + content_tokens, + reasoning_tokens, + ..TokenUsage::default() + }, + } + } + + #[must_use] + pub fn weather_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(json!({ "location": "Paris" })), + ) + } + + #[must_use] + pub fn builder() -> ResponsesResponseBuilder { + ResponsesResponseBuilder { + id: "resp_test".to_owned(), + created_at: 0, + model: "test-model".to_owned(), + instructions: None, + } + } + + fn non_streaming_transformer() -> ResponsesNonStreamingResponseTransformer { + ResponsesNonStreamingResponseTransformer { + builder: builder(), + state: Arc::new(Mutex::new(ResponsesNonStreamingState::default())), + } + } + + #[tokio::test] + async fn non_streaming_aggregates_content_into_a_message_item() { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hel".to_owned(), + ))) + .await + .unwrap(); + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "lo".to_owned(), + ))) + .await + .unwrap(); + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done( + summary_with_counts(3, 2, 0), + ))) + .await + .unwrap(); + + let TransformResult::Chunk(body) = &chunks[0] else { + panic!("expected a chunk"); + }; + let response: serde_json::Value = serde_json::from_str(body).unwrap(); + + assert_eq!(response["object"], "response"); + assert_eq!(response["status"], "completed"); + assert_eq!(response["output"][0]["type"], "message"); + assert_eq!(response["output"][0]["content"][0]["text"], "hello"); + } + + #[tokio::test] + async fn non_streaming_surfaces_reasoning_and_tool_calls_in_output() { + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ReasoningToken( + "ponder".to_owned(), + ))) + .await + .unwrap(); + transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await + .unwrap(); + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done( + summary_with_counts(3, 0, 1), + ))) + .await + .unwrap(); + + let TransformResult::Chunk(body) = &chunks[0] else { + panic!("expected a chunk"); + }; + let response: serde_json::Value = serde_json::from_str(body).unwrap(); + + assert_eq!(response["output"][0]["type"], "reasoning"); + assert_eq!(response["output"][1]["type"], "function_call"); + assert_eq!(response["output"][1]["name"], "get_weather"); + assert_eq!( + response["usage"]["output_tokens_details"]["reasoning_tokens"], + 1 + ); + } + + #[tokio::test] + async fn non_streaming_error_returns_an_error_envelope() { + let transformer = non_streaming_transformer(); + + let chunks = transformer + .transform(token_message(GeneratedTokenResult::SamplerError( + "sampler blew up".to_owned(), + ))) + .await + .unwrap(); + + let TransformResult::Error(body) = &chunks[0] else { + panic!("expected an error"); + }; + + assert!(body.contains("sampler blew up")); + assert!(body.contains("server_error")); + } + + #[tokio::test] + async fn the_non_streaming_response_conforms_to_the_official_schema() { + let validator = OpenAIValidator::new().unwrap(); + let transformer = non_streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ReasoningToken( + "p".to_owned(), + ))) + .await + .unwrap(); + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hello".to_owned(), + ))) + .await + .unwrap(); + transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await + .unwrap(); + let chunks = transformer + .transform(token_message(GeneratedTokenResult::Done( + summary_with_counts(5, 3, 2), + ))) + .await + .unwrap(); + + let TransformResult::Chunk(body) = &chunks[0] else { + panic!("expected a chunk"); + }; + let response: serde_json::Value = serde_json::from_str(body).unwrap(); + + validator.validate_responses_response(&response).unwrap(); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_state.rs b/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_state.rs new file mode 100644 index 00000000..c7c8e4ac --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_non_streaming_state.rs @@ -0,0 +1,8 @@ +use llama_cpp_bindings_types::ParsedToolCall; + +#[derive(Clone, Default)] +pub struct ResponsesNonStreamingState { + pub content: String, + pub reasoning: String, + pub tool_calls: Vec, +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_prepared_request.rs b/paddler_balancer/src/compatibility/openai_service/responses_prepared_request.rs new file mode 100644 index 00000000..4175ef9c --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_prepared_request.rs @@ -0,0 +1,9 @@ +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + +pub struct ResponsesPreparedRequest { + pub paddler_params: ContinueFromConversationHistoryParams, + pub stream: bool, + pub model: String, + pub instructions: Option, +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_response_builder.rs b/paddler_balancer/src/compatibility/openai_service/responses_response_builder.rs new file mode 100644 index 00000000..431d2b2a --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_response_builder.rs @@ -0,0 +1,78 @@ +use llama_cpp_bindings_types::TokenUsage; +use serde_json::Value; +use serde_json::json; + +use crate::compatibility::openai_service::openai_error::OpenAIError; + +fn responses_usage_json(usage: &TokenUsage) -> Value { + json!({ + "input_tokens": usage.prompt_tokens, + "input_tokens_details": { "cached_tokens": usage.cached_prompt_tokens }, + "output_tokens": usage.completion_tokens(), + "output_tokens_details": { "reasoning_tokens": usage.reasoning_tokens }, + "total_tokens": usage.total_tokens(), + }) +} + +#[derive(Clone)] +pub struct ResponsesResponseBuilder { + pub id: String, + pub created_at: u64, + pub model: String, + pub instructions: Option, +} + +impl ResponsesResponseBuilder { + // `usage` is intentionally absent here: the official `ResponseUsage` reference is not nullable, so the + // in-progress and failed snapshots must omit it rather than emit `null`. Only `completed` adds it. + fn base(&self, status: &str, output: &Value, error: &Value) -> Value { + let instructions = self + .instructions + .as_ref() + .map_or(Value::Null, |instructions| json!(instructions)); + + json!({ + "id": self.id, + "object": "response", + "created_at": self.created_at, + "status": status, + "error": error, + "incomplete_details": null, + "instructions": instructions, + "model": self.model, + "tools": [], + "output": output, + "parallel_tool_calls": true, + "metadata": {}, + "tool_choice": "auto", + "temperature": 1, + "top_p": 1, + "text": { "format": { "type": "text" } } + }) + } + + #[must_use] + pub fn in_progress(&self) -> Value { + self.base("in_progress", &json!([]), &Value::Null) + } + + #[must_use] + pub fn completed(&self, output: Vec, usage: &TokenUsage) -> Value { + let mut response = self.base("completed", &Value::Array(output), &Value::Null); + + if let Some(object) = response.as_object_mut() { + object.insert("usage".to_owned(), responses_usage_json(usage)); + } + + response + } + + #[must_use] + pub fn failed(&self, error: &OpenAIError) -> Value { + self.base( + "failed", + &json!([]), + &json!({ "code": "server_error", "message": error.message }), + ) + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_stream_event.rs b/paddler_balancer/src/compatibility/openai_service/responses_stream_event.rs new file mode 100644 index 00000000..c7a82a74 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_stream_event.rs @@ -0,0 +1,190 @@ +use serde_json::Value; +use serde_json::json; + +use crate::compatibility::openai_service::content_part_event::ContentPartEvent; +use crate::compatibility::openai_service::function_call_arguments_delta_event::FunctionCallArgumentsDeltaEvent; +use crate::compatibility::openai_service::function_call_arguments_done_event::FunctionCallArgumentsDoneEvent; +use crate::compatibility::openai_service::output_item_event::OutputItemEvent; +use crate::compatibility::openai_service::response_snapshot_event::ResponseSnapshotEvent; +use crate::compatibility::openai_service::text_delta_event::TextDeltaEvent; +use crate::compatibility::openai_service::text_done_event::TextDoneEvent; + +#[derive(Clone, Debug)] +pub enum ResponsesStreamEvent { + Created(ResponseSnapshotEvent), + InProgress(ResponseSnapshotEvent), + OutputItemAdded(OutputItemEvent), + OutputItemDone(OutputItemEvent), + ContentPartAdded(ContentPartEvent), + ContentPartDone(ContentPartEvent), + OutputTextDelta(TextDeltaEvent), + OutputTextDone(TextDoneEvent), + ReasoningTextDelta(TextDeltaEvent), + ReasoningTextDone(TextDoneEvent), + FunctionCallArgumentsDelta(FunctionCallArgumentsDeltaEvent), + FunctionCallArgumentsDone(FunctionCallArgumentsDoneEvent), + Completed(ResponseSnapshotEvent), + Failed(ResponseSnapshotEvent), +} + +impl ResponsesStreamEvent { + #[must_use] + pub const fn event_name(&self) -> &'static str { + match self { + Self::Created(_) => "response.created", + Self::InProgress(_) => "response.in_progress", + Self::OutputItemAdded(_) => "response.output_item.added", + Self::OutputItemDone(_) => "response.output_item.done", + Self::ContentPartAdded(_) => "response.content_part.added", + Self::ContentPartDone(_) => "response.content_part.done", + Self::OutputTextDelta(_) => "response.output_text.delta", + Self::OutputTextDone(_) => "response.output_text.done", + Self::ReasoningTextDelta(_) => "response.reasoning_text.delta", + Self::ReasoningTextDone(_) => "response.reasoning_text.done", + Self::FunctionCallArgumentsDelta(_) => "response.function_call_arguments.delta", + Self::FunctionCallArgumentsDone(_) => "response.function_call_arguments.done", + Self::Completed(_) => "response.completed", + Self::Failed(_) => "response.failed", + } + } + + #[must_use] + pub fn to_json(&self) -> Value { + let event_type = self.event_name(); + + match self { + Self::Created(snapshot) + | Self::InProgress(snapshot) + | Self::Completed(snapshot) + | Self::Failed(snapshot) => json!({ + "type": event_type, + "sequence_number": snapshot.sequence_number, + "response": snapshot.response, + }), + Self::OutputItemAdded(item_event) | Self::OutputItemDone(item_event) => json!({ + "type": event_type, + "sequence_number": item_event.sequence_number, + "output_index": item_event.output_index, + "item": item_event.item, + }), + Self::ContentPartAdded(part_event) | Self::ContentPartDone(part_event) => json!({ + "type": event_type, + "sequence_number": part_event.sequence_number, + "item_id": part_event.item_id, + "output_index": part_event.output_index, + "content_index": part_event.content_index, + "part": part_event.part, + }), + Self::OutputTextDelta(delta_event) => json!({ + "type": event_type, + "sequence_number": delta_event.sequence_number, + "item_id": delta_event.item_id, + "output_index": delta_event.output_index, + "content_index": delta_event.content_index, + "delta": delta_event.delta, + "logprobs": [], + }), + // Reasoning text events, unlike output-text events, do not carry a `logprobs` field. + Self::ReasoningTextDelta(delta_event) => json!({ + "type": event_type, + "sequence_number": delta_event.sequence_number, + "item_id": delta_event.item_id, + "output_index": delta_event.output_index, + "content_index": delta_event.content_index, + "delta": delta_event.delta, + }), + Self::OutputTextDone(done_event) => json!({ + "type": event_type, + "sequence_number": done_event.sequence_number, + "item_id": done_event.item_id, + "output_index": done_event.output_index, + "content_index": done_event.content_index, + "text": done_event.text, + "logprobs": [], + }), + Self::ReasoningTextDone(done_event) => json!({ + "type": event_type, + "sequence_number": done_event.sequence_number, + "item_id": done_event.item_id, + "output_index": done_event.output_index, + "content_index": done_event.content_index, + "text": done_event.text, + }), + Self::FunctionCallArgumentsDelta(arguments_event) => json!({ + "type": event_type, + "sequence_number": arguments_event.sequence_number, + "item_id": arguments_event.item_id, + "output_index": arguments_event.output_index, + "delta": arguments_event.delta, + }), + Self::FunctionCallArgumentsDone(arguments_event) => json!({ + "type": event_type, + "sequence_number": arguments_event.sequence_number, + "item_id": arguments_event.item_id, + "output_index": arguments_event.output_index, + "name": arguments_event.name, + "arguments": arguments_event.arguments, + }), + } + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::ResponseSnapshotEvent; + use super::ResponsesStreamEvent; + use super::TextDeltaEvent; + + #[test] + fn reasoning_and_text_delta_carry_their_distinct_event_names_with_the_same_payload_shape() { + let text_delta = ResponsesStreamEvent::OutputTextDelta(TextDeltaEvent { + sequence_number: 4, + item_id: "msg_0".to_owned(), + output_index: 0, + content_index: 0, + delta: "hi".to_owned(), + }); + let reasoning_delta = ResponsesStreamEvent::ReasoningTextDelta(TextDeltaEvent { + sequence_number: 4, + item_id: "rs_0".to_owned(), + output_index: 0, + content_index: 0, + delta: "hmm".to_owned(), + }); + + assert_eq!(text_delta.event_name(), "response.output_text.delta"); + assert_eq!( + reasoning_delta.event_name(), + "response.reasoning_text.delta" + ); + } + + #[test] + fn to_json_type_field_matches_event_name() { + let event = ResponsesStreamEvent::Completed(ResponseSnapshotEvent { + sequence_number: 7, + response: json!({ "id": "resp_0" }), + }); + + let serialized = event.to_json(); + + assert_eq!(serialized["type"], event.event_name()); + assert_eq!(serialized["sequence_number"], 7); + assert_eq!(serialized["response"]["id"], "resp_0"); + } + + #[test] + fn text_delta_includes_the_required_logprobs_array() { + let event = ResponsesStreamEvent::OutputTextDelta(TextDeltaEvent { + sequence_number: 1, + item_id: "msg_0".to_owned(), + output_index: 0, + content_index: 0, + delta: "x".to_owned(), + }); + + assert_eq!(event.to_json()["logprobs"], json!([])); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_streaming_response_transformer.rs b/paddler_balancer/src/compatibility/openai_service/responses_streaming_response_transformer.rs new file mode 100644 index 00000000..f07b3453 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_streaming_response_transformer.rs @@ -0,0 +1,437 @@ +use std::sync::Arc; + +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use parking_lot::Mutex; + +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::compatibility::openai_service::response_snapshot_event::ResponseSnapshotEvent; +use crate::compatibility::openai_service::responses_error::responses_error; +use crate::compatibility::openai_service::responses_response_builder::ResponsesResponseBuilder; +use crate::compatibility::openai_service::responses_stream_event::ResponsesStreamEvent; +use crate::compatibility::openai_service::responses_streaming_state::ResponsesStreamingState; + +#[derive(Clone)] +pub struct ResponsesStreamingResponseTransformer { + pub builder: ResponsesResponseBuilder, + pub state: Arc>, +} + +impl ResponsesStreamingResponseTransformer { + fn ensure_preamble( + &self, + state: &mut ResponsesStreamingState, + events: &mut Vec, + ) { + if state.started { + return; + } + + state.started = true; + + let created_sequence_number = state.next_sequence_number(); + events.push(ResponsesStreamEvent::Created(ResponseSnapshotEvent { + sequence_number: created_sequence_number, + response: self.builder.in_progress(), + })); + + let in_progress_sequence_number = state.next_sequence_number(); + events.push(ResponsesStreamEvent::InProgress(ResponseSnapshotEvent { + sequence_number: in_progress_sequence_number, + response: self.builder.in_progress(), + })); + } + + fn handle_done( + &self, + state: &mut ResponsesStreamingState, + events: &mut Vec, + summary: &GenerationSummary, + ) { + state.close_open_item(events); + + let output = state.finalized_output.clone(); + let completed_sequence_number = state.next_sequence_number(); + events.push(ResponsesStreamEvent::Completed(ResponseSnapshotEvent { + sequence_number: completed_sequence_number, + response: self.builder.completed(output, &summary.usage), + })); + } +} + +#[async_trait] +impl TransformsOutgoingMessage for ResponsesStreamingResponseTransformer { + type Output = ResponsesStreamEvent; + + #[expect( + clippy::significant_drop_tightening, + reason = "one guard must span the whole per-message state transition; calls are serial so there is no contention" + )] + async fn transform(&self, message: OutgoingMessage) -> Result> { + let mut events: Vec = Vec::new(); + let mut state = self.state.lock(); + + if let Some(error) = responses_error(&message) { + self.ensure_preamble(&mut state, &mut events); + + let failed_sequence_number = state.next_sequence_number(); + events.push(ResponsesStreamEvent::Failed(ResponseSnapshotEvent { + sequence_number: failed_sequence_number, + response: self.builder.failed(&error), + })); + + return Ok(events); + } + + match message { + OutgoingMessage::Response(ResponseEnvelope { + response: OutgoingResponse::GeneratedToken(token), + .. + }) => match token { + GeneratedTokenResult::ContentToken(text) + | GeneratedTokenResult::UndeterminableToken(text) => { + self.ensure_preamble(&mut state, &mut events); + state.handle_content(&mut events, &text); + } + GeneratedTokenResult::ReasoningToken(text) => { + self.ensure_preamble(&mut state, &mut events); + state.handle_reasoning(&mut events, &text); + } + GeneratedTokenResult::ToolCallToken(_) => {} + GeneratedTokenResult::ToolCallParsed(parsed_calls) => { + self.ensure_preamble(&mut state, &mut events); + state.handle_tool_calls(&mut events, &parsed_calls)?; + } + GeneratedTokenResult::Done(summary) => { + self.ensure_preamble(&mut state, &mut events); + self.handle_done(&mut state, &mut events, &summary); + } + other => { + return Err(anyhow!( + "ResponsesStreamingResponseTransformer received a token it does not know how to handle: {other:?}" + )); + } + }, + other => { + return Err(anyhow!( + "ResponsesStreamingResponseTransformer received an outgoing message it does not know how to handle: {other:?}" + )); + } + } + + Ok(events) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use llama_cpp_bindings_types::ParsedToolCall; + use llama_cpp_bindings_types::TokenUsage; + use llama_cpp_bindings_types::ToolCallArguments; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::generation_summary::GenerationSummary; + use paddler_messaging::inference_client::message::Message as OutgoingMessage; + use paddler_messaging::inference_client::response::Response as OutgoingResponse; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; + use parking_lot::Mutex; + use serde_json::json; + + use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; + use crate::compatibility::openai_service::responses_response_builder::ResponsesResponseBuilder; + use crate::compatibility::openai_service::responses_stream_event::ResponsesStreamEvent; + + use super::ResponsesStreamingResponseTransformer; + use super::ResponsesStreamingState; + + #[must_use] + pub fn token_message(token_result: GeneratedTokenResult) -> OutgoingMessage { + OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "test-request".to_owned(), + response: OutgoingResponse::GeneratedToken(token_result), + }) + } + + #[must_use] + pub fn summary_with_counts( + prompt_tokens: u64, + content_tokens: u64, + reasoning_tokens: u64, + ) -> GenerationSummary { + GenerationSummary { + usage: TokenUsage { + prompt_tokens, + content_tokens, + reasoning_tokens, + ..TokenUsage::default() + }, + } + } + + #[must_use] + pub fn weather_call() -> ParsedToolCall { + ParsedToolCall::new( + "call_x".to_owned(), + "get_weather".to_owned(), + ToolCallArguments::ValidJson(json!({ "location": "Paris" })), + ) + } + + #[must_use] + pub fn builder() -> ResponsesResponseBuilder { + ResponsesResponseBuilder { + id: "resp_test".to_owned(), + created_at: 0, + model: "test-model".to_owned(), + instructions: None, + } + } + + fn streaming_transformer() -> ResponsesStreamingResponseTransformer { + ResponsesStreamingResponseTransformer { + builder: builder(), + state: Arc::new(Mutex::new(ResponsesStreamingState::default())), + } + } + + fn names(events: &[ResponsesStreamEvent]) -> Vec<&'static str> { + events + .iter() + .map(ResponsesStreamEvent::event_name) + .collect() + } + + #[tokio::test] + async fn streaming_first_content_token_emits_preamble_then_text_delta() { + let transformer = streaming_transformer(); + + let events = transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hi".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!( + names(&events), + vec![ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.content_part.added", + "response.output_text.delta", + ] + ); + assert_eq!(events[0].to_json()["response"]["status"], "in_progress"); + assert_eq!(events[4].to_json()["delta"], "hi"); + } + + #[tokio::test] + async fn streaming_preamble_is_emitted_only_once() { + let transformer = streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "a".to_owned(), + ))) + .await + .unwrap(); + let events = transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "b".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!(names(&events), vec!["response.output_text.delta"]); + } + + #[tokio::test] + async fn streaming_done_finalizes_message_and_emits_completed_with_usage() { + let transformer = streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "hello".to_owned(), + ))) + .await + .unwrap(); + let events = transformer + .transform(token_message(GeneratedTokenResult::Done( + summary_with_counts(7, 4, 1), + ))) + .await + .unwrap(); + + assert_eq!( + names(&events), + vec![ + "response.output_text.done", + "response.content_part.done", + "response.output_item.done", + "response.completed", + ] + ); + + let completed = events[3].to_json(); + + assert_eq!(completed["response"]["status"], "completed"); + assert_eq!(completed["response"]["usage"]["input_tokens"], 7); + assert_eq!(completed["response"]["usage"]["total_tokens"], 12); + assert_eq!( + completed["response"]["output"][0]["content"][0]["text"], + "hello" + ); + assert_eq!( + completed["response"]["output"][0]["content"][0]["logprobs"], + json!([]) + ); + } + + #[tokio::test] + async fn streaming_reasoning_then_content_closes_the_reasoning_item_first() { + let transformer = streaming_transformer(); + + transformer + .transform(token_message(GeneratedTokenResult::ReasoningToken( + "think".to_owned(), + ))) + .await + .unwrap(); + let events = transformer + .transform(token_message(GeneratedTokenResult::ContentToken( + "answer".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!( + names(&events), + vec![ + "response.reasoning_text.done", + "response.output_item.done", + "response.output_item.added", + "response.content_part.added", + "response.output_text.delta", + ] + ); + // reasoning item closed at output_index 0, message opened at output_index 1 + assert_eq!(events[1].to_json()["output_index"], 0); + assert_eq!(events[2].to_json()["output_index"], 1); + assert_eq!(events[1].to_json()["item"]["type"], "reasoning"); + } + + #[tokio::test] + async fn streaming_tool_call_emits_function_call_argument_events_without_content_index() { + let transformer = streaming_transformer(); + + let events = transformer + .transform(token_message(GeneratedTokenResult::ToolCallParsed(vec![ + weather_call(), + ]))) + .await + .unwrap(); + + assert_eq!( + names(&events), + vec![ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.delta", + "response.function_call_arguments.done", + "response.output_item.done", + ] + ); + + let delta_event = events[3].to_json(); + + assert_eq!(delta_event["delta"], "{\"location\":\"Paris\"}"); + assert!( + delta_event.get("content_index").is_none(), + "function_call_arguments events must not carry a content_index" + ); + assert_eq!(events[4].to_json()["name"], "get_weather"); + assert_eq!(events[5].to_json()["item"]["call_id"], "call_x"); + } + + #[tokio::test] + async fn streaming_error_emits_preamble_then_failed() { + let transformer = streaming_transformer(); + + let events = transformer + .transform(token_message(GeneratedTokenResult::ChatTemplateError( + "boom".to_owned(), + ))) + .await + .unwrap(); + + assert_eq!( + names(&events), + vec![ + "response.created", + "response.in_progress", + "response.failed" + ] + ); + + let failed = events[2].to_json(); + + assert_eq!(failed["response"]["status"], "failed"); + assert_eq!(failed["response"]["error"]["code"], "server_error"); + assert_eq!(failed["response"]["error"]["message"], "boom"); + } + + #[tokio::test] + async fn every_emitted_streaming_event_conforms_to_the_official_schema() { + let validator = OpenAIValidator::new().unwrap(); + let transformer = streaming_transformer(); + + let mut emitted: Vec = Vec::new(); + + for token in [ + GeneratedTokenResult::ReasoningToken("ponder".to_owned()), + GeneratedTokenResult::ContentToken("hello".to_owned()), + GeneratedTokenResult::ToolCallParsed(vec![weather_call()]), + GeneratedTokenResult::Done(summary_with_counts(5, 3, 2)), + ] { + emitted.extend(transformer.transform(token_message(token)).await.unwrap()); + } + + assert!(emitted.len() > 10); + + for event in &emitted { + validator + .validate_responses_stream_event(&event.to_json()) + .unwrap(); + } + } + + #[tokio::test] + async fn the_failed_streaming_event_conforms_to_the_official_schema() { + let validator = OpenAIValidator::new().unwrap(); + let transformer = streaming_transformer(); + + let events = transformer + .transform(token_message(GeneratedTokenResult::SamplerError( + "boom".to_owned(), + ))) + .await + .unwrap(); + + for event in &events { + validator + .validate_responses_stream_event(&event.to_json()) + .unwrap(); + } + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/responses_streaming_state.rs b/paddler_balancer/src/compatibility/openai_service/responses_streaming_state.rs new file mode 100644 index 00000000..c6fc70bd --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/responses_streaming_state.rs @@ -0,0 +1,256 @@ +use anyhow::Result; +use llama_cpp_bindings_types::ParsedToolCall; +use serde_json::Value; +use serde_json::json; + +use crate::compatibility::openai_service::arguments_to_tool_call_string::arguments_to_tool_call_string; +use crate::compatibility::openai_service::content_part_event::ContentPartEvent; +use crate::compatibility::openai_service::function_call_arguments_delta_event::FunctionCallArgumentsDeltaEvent; +use crate::compatibility::openai_service::function_call_arguments_done_event::FunctionCallArgumentsDoneEvent; +use crate::compatibility::openai_service::function_call_item::function_call_item; +use crate::compatibility::openai_service::message_item_done::message_item_done; +use crate::compatibility::openai_service::open_item::OpenItem; +use crate::compatibility::openai_service::output_item_event::OutputItemEvent; +use crate::compatibility::openai_service::output_text_part::output_text_part; +use crate::compatibility::openai_service::reasoning_item_done::reasoning_item_done; +use crate::compatibility::openai_service::responses_stream_event::ResponsesStreamEvent; +use crate::compatibility::openai_service::text_delta_event::TextDeltaEvent; +use crate::compatibility::openai_service::text_done_event::TextDoneEvent; + +fn message_item_open(item_id: &str) -> Value { + json!({ + "id": item_id, + "type": "message", + "role": "assistant", + "status": "in_progress", + "content": [] + }) +} + +fn reasoning_item_open(item_id: &str) -> Value { + json!({ + "type": "reasoning", + "id": item_id, + "summary": [], + "status": "in_progress" + }) +} + +#[derive(Default)] +pub struct ResponsesStreamingState { + pub started: bool, + sequence_number: u64, + output_index: usize, + open: OpenItem, + reasoning_id: String, + reasoning_text: String, + message_id: String, + message_text: String, + pub finalized_output: Vec, +} + +impl ResponsesStreamingState { + pub const fn next_sequence_number(&mut self) -> u64 { + let sequence_number = self.sequence_number; + self.sequence_number += 1; + + sequence_number + } + + pub fn close_open_item(&mut self, events: &mut Vec) { + match self.open { + OpenItem::None => {} + OpenItem::Reasoning => { + let item_id = self.reasoning_id.clone(); + let text = self.reasoning_text.clone(); + let output_index = self.output_index; + + let text_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::ReasoningTextDone(TextDoneEvent { + sequence_number: text_done_sequence_number, + item_id: item_id.clone(), + output_index, + content_index: 0, + text: text.clone(), + })); + + let item = reasoning_item_done(&item_id, &text); + let item_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemDone(OutputItemEvent { + sequence_number: item_done_sequence_number, + output_index, + item: item.clone(), + })); + + self.finalized_output.push(item); + self.output_index += 1; + self.reasoning_text.clear(); + self.open = OpenItem::None; + } + OpenItem::Message => { + let item_id = self.message_id.clone(); + let text = self.message_text.clone(); + let output_index = self.output_index; + + let text_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputTextDone(TextDoneEvent { + sequence_number: text_done_sequence_number, + item_id: item_id.clone(), + output_index, + content_index: 0, + text: text.clone(), + })); + + let part_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::ContentPartDone(ContentPartEvent { + sequence_number: part_done_sequence_number, + item_id: item_id.clone(), + output_index, + content_index: 0, + part: output_text_part(&text), + })); + + let item = message_item_done(&item_id, &text); + let item_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemDone(OutputItemEvent { + sequence_number: item_done_sequence_number, + output_index, + item: item.clone(), + })); + + self.finalized_output.push(item); + self.output_index += 1; + self.message_text.clear(); + self.open = OpenItem::None; + } + } + } + + pub fn handle_reasoning(&mut self, events: &mut Vec, text: &str) { + if self.open != OpenItem::Reasoning { + self.close_open_item(events); + + let output_index = self.output_index; + let item_id = format!("rs_{output_index}"); + self.reasoning_id.clone_from(&item_id); + + let added_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemAdded(OutputItemEvent { + sequence_number: added_sequence_number, + output_index, + item: reasoning_item_open(&item_id), + })); + + self.open = OpenItem::Reasoning; + } + + self.reasoning_text.push_str(text); + + let item_id = self.reasoning_id.clone(); + let output_index = self.output_index; + let delta_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::ReasoningTextDelta(TextDeltaEvent { + sequence_number: delta_sequence_number, + item_id, + output_index, + content_index: 0, + delta: text.to_owned(), + })); + } + + pub fn handle_content(&mut self, events: &mut Vec, text: &str) { + if self.open != OpenItem::Message { + self.close_open_item(events); + + let output_index = self.output_index; + let item_id = format!("msg_{output_index}"); + self.message_id.clone_from(&item_id); + + let added_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemAdded(OutputItemEvent { + sequence_number: added_sequence_number, + output_index, + item: message_item_open(&item_id), + })); + + let part_added_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::ContentPartAdded(ContentPartEvent { + sequence_number: part_added_sequence_number, + item_id, + output_index, + content_index: 0, + part: output_text_part(""), + })); + + self.open = OpenItem::Message; + } + + self.message_text.push_str(text); + + let item_id = self.message_id.clone(); + let output_index = self.output_index; + let delta_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputTextDelta(TextDeltaEvent { + sequence_number: delta_sequence_number, + item_id, + output_index, + content_index: 0, + delta: text.to_owned(), + })); + } + + pub fn handle_tool_calls( + &mut self, + events: &mut Vec, + parsed_calls: &[ParsedToolCall], + ) -> Result<()> { + self.close_open_item(events); + + for call in parsed_calls { + let output_index = self.output_index; + let item_id = format!("fc_{output_index}"); + let arguments = arguments_to_tool_call_string(&call.arguments)?; + + let added_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemAdded(OutputItemEvent { + sequence_number: added_sequence_number, + output_index, + item: function_call_item(&item_id, &call.id, &call.name, "", "in_progress"), + })); + + let delta_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::FunctionCallArgumentsDelta( + FunctionCallArgumentsDeltaEvent { + sequence_number: delta_sequence_number, + item_id: item_id.clone(), + output_index, + delta: arguments.clone(), + }, + )); + + let done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::FunctionCallArgumentsDone( + FunctionCallArgumentsDoneEvent { + sequence_number: done_sequence_number, + item_id: item_id.clone(), + output_index, + name: call.name.clone(), + arguments: arguments.clone(), + }, + )); + + let item = function_call_item(&item_id, &call.id, &call.name, &arguments, "completed"); + let item_done_sequence_number = self.next_sequence_number(); + events.push(ResponsesStreamEvent::OutputItemDone(OutputItemEvent { + sequence_number: item_done_sequence_number, + output_index, + item: item.clone(), + })); + + self.finalized_output.push(item); + self.output_index += 1; + } + + Ok(()) + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/sse_response_from_agent.rs b/paddler_balancer/src/compatibility/openai_service/sse_response_from_agent.rs new file mode 100644 index 00000000..8af3fa47 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/sse_response_from_agent.rs @@ -0,0 +1,53 @@ +use std::convert::Infallible; +use std::fmt::Debug; +use std::sync::Arc; + +use actix_web::HttpResponse; +use actix_web::http::header; +use actix_web_lab::sse; +use futures::stream::StreamExt as _; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; +use paddler_messaging::streamable_result::StreamableResult; +use tokio_util::sync::CancellationToken; + +use crate::agent_controller::AgentController; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::compatibility::openai_service::responses_stream_event::ResponsesStreamEvent; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::manages_senders::ManagesSenders; +use crate::unbounded_stream_from_agent::unbounded_stream_from_agent; + +fn event_to_sse_data(event: &ResponsesStreamEvent) -> sse::Data { + sse::Data::new(event.to_json().to_string()).event(event.event_name()) +} + +pub fn sse_response_from_agent( + buffered_request_manager: Arc, + inference_service_configuration: InferenceServiceConfiguration, + params: TParams, + transformer: TTransformsOutgoingMessage, + shutdown: CancellationToken, +) -> HttpResponse +where + TParams: Debug + Into + Send + 'static, + AgentController: HandlesAgentStreamingResponse, + <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, + TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, +{ + let event_stream = unbounded_stream_from_agent( + buffered_request_manager, + inference_service_configuration, + params, + transformer, + shutdown, + ) + .map(|event| Ok::(sse::Event::Data(event_to_sse_data(&event)))); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header((header::CACHE_CONTROL, "no-cache")) + .body(sse::Sse::from_stream(event_stream)) +} diff --git a/paddler_balancer/src/compatibility/openai_service/stream_options.rs b/paddler_balancer/src/compatibility/openai_service/stream_options.rs new file mode 100644 index 00000000..41de73fc --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/stream_options.rs @@ -0,0 +1,8 @@ +use serde::Deserialize; + +#[derive(Default, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct StreamOptions { + #[serde(default)] + pub include_usage: bool, +} diff --git a/paddler_balancer/src/compatibility/openai_service/text_delta_event.rs b/paddler_balancer/src/compatibility/openai_service/text_delta_event.rs new file mode 100644 index 00000000..02429806 --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/text_delta_event.rs @@ -0,0 +1,8 @@ +#[derive(Clone, Debug)] +pub struct TextDeltaEvent { + pub sequence_number: u64, + pub item_id: String, + pub output_index: usize, + pub content_index: usize, + pub delta: String, +} diff --git a/paddler_balancer/src/compatibility/openai_service/text_done_event.rs b/paddler_balancer/src/compatibility/openai_service/text_done_event.rs new file mode 100644 index 00000000..bce6be4b --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/text_done_event.rs @@ -0,0 +1,8 @@ +#[derive(Clone, Debug)] +pub struct TextDoneEvent { + pub sequence_number: u64, + pub item_id: String, + pub output_index: usize, + pub content_index: usize, + pub text: String, +} diff --git a/paddler_balancer/src/compatibility/openai_service/timestamp_from.rs b/paddler_balancer/src/compatibility/openai_service/timestamp_from.rs new file mode 100644 index 00000000..1826d1ed --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/timestamp_from.rs @@ -0,0 +1,35 @@ +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use anyhow::Context as _; +use anyhow::Result; + +pub fn timestamp_from(now: SystemTime) -> Result { + Ok(now + .duration_since(UNIX_EPOCH) + .context("system time is before the Unix epoch")? + .as_secs()) +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + use std::time::SystemTime; + use std::time::UNIX_EPOCH; + + use super::timestamp_from; + + #[test] + fn returns_seconds_since_epoch() { + let timestamp = timestamp_from(SystemTime::now()).unwrap(); + + assert!(timestamp > 0); + } + + #[test] + fn errors_before_the_unix_epoch() { + let before_epoch = UNIX_EPOCH - Duration::from_secs(1); + + assert!(timestamp_from(before_epoch).is_err()); + } +} diff --git a/paddler_balancer/src/compatibility/openai_service/try_universal_error_chunk.rs b/paddler_balancer/src/compatibility/openai_service/try_universal_error_chunk.rs new file mode 100644 index 00000000..9a350cea --- /dev/null +++ b/paddler_balancer/src/compatibility/openai_service/try_universal_error_chunk.rs @@ -0,0 +1,27 @@ +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::compatibility::openai_service::openai_error::OpenAIError; + +#[must_use] +pub fn try_universal_error_chunk(message: &OutgoingMessage) -> Option { + if let OutgoingMessage::Response(ResponseEnvelope { + response: OutgoingResponse::Embedding(_), + .. + }) = message + { + return Some(TransformResult::Error( + OpenAIError { + error_type: "invalid_request_error", + message: "unexpected embedding response in chat completions".to_owned(), + } + .to_envelope() + .to_string(), + )); + } + + OpenAIError::classify(message) + .map(|error| TransformResult::Error(error.to_envelope().to_string())) +} diff --git a/paddler/src/continuation_decision.rs b/paddler_balancer/src/continuation_decision.rs similarity index 100% rename from paddler/src/continuation_decision.rs rename to paddler_balancer/src/continuation_decision.rs diff --git a/paddler/src/continuation_stop_parameters.rs b/paddler_balancer/src/continuation_stop_parameters.rs similarity index 100% rename from paddler/src/continuation_stop_parameters.rs rename to paddler_balancer/src/continuation_stop_parameters.rs diff --git a/paddler_balancer/src/controls_manages_senders_endpoint.rs b/paddler_balancer/src/controls_manages_senders_endpoint.rs new file mode 100644 index 00000000..c4b6cefb --- /dev/null +++ b/paddler_balancer/src/controls_manages_senders_endpoint.rs @@ -0,0 +1,169 @@ +use std::sync::Arc; + +use actix_web::Error; +use actix_web::HttpResponse; +use async_trait::async_trait; +use tokio::time::Duration; +use tokio::time::sleep; + +use crate::agent_controller::AgentController; +use crate::agent_controller_pool::AgentControllerPool; +use crate::manages_senders::ManagesSenders; +use crate::manages_senders_controller::ManagesSendersController; + +const TIMEOUT: Duration = Duration::from_secs(3); + +#[async_trait] +pub trait ControlsManagesSendersEndpoint { + type SenderCollection: ManagesSenders + Send + Sync + 'static; + + fn get_agent_controller_pool(&self) -> Arc; + + fn get_agent_id(&self) -> String; + + async fn get_manages_senders_controller( + &self, + agent_controller: Arc, + ) -> anyhow::Result>; + + async fn respond(&self) -> Result { + let agent_controller_pool = self.get_agent_controller_pool(); + let agent_id = self.get_agent_id(); + let Some(agent_controller) = agent_controller_pool.get_agent_controller(&agent_id) else { + return Ok(HttpResponse::NotFound().finish()); + }; + + let connection_close = agent_controller.connection_close.clone(); + + match self.get_manages_senders_controller(agent_controller).await { + Ok(mut receive_response_controller) => { + tokio::select! { + () = connection_close.cancelled() => Ok(HttpResponse::BadGateway().finish()), + () = sleep(TIMEOUT) => Ok(HttpResponse::GatewayTimeout().finish()), + response = receive_response_controller.response_rx.recv() => response.map_or_else( + || Ok(HttpResponse::NotFound().finish()), + |existing_response| Ok(HttpResponse::Ok().json(existing_response)), + ), + } + } + Err(err) => Ok(HttpResponse::InternalServerError().body(format!("{err}"))), + } + } +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + + use actix_web::http::StatusCode; + use async_trait::async_trait; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::ControlsManagesSendersEndpoint; + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::manages_senders_controller::ManagesSendersController; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + + fn registered_agent_id(pool: &AgentControllerPool) -> String { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + let agent_id = "agent-test".to_owned(); + + pool.register_agent_controller( + agent_id.clone(), + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new( + GenerateTokensSenderCollection::default(), + ), + id: agent_id.clone(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new( + ModelMetadataSenderCollection::default(), + ), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(0), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }), + ) + .unwrap(); + + agent_id + } + + struct EndpointWithClosedResponseChannel { + agent_controller_pool: Arc, + agent_id: String, + } + + #[async_trait] + impl ControlsManagesSendersEndpoint for EndpointWithClosedResponseChannel { + type SenderCollection = ModelMetadataSenderCollection; + + fn get_agent_controller_pool(&self) -> Arc { + self.agent_controller_pool.clone() + } + + fn get_agent_id(&self) -> String { + self.agent_id.clone() + } + + async fn get_manages_senders_controller( + &self, + _agent_controller: Arc, + ) -> anyhow::Result> { + let response_sender_collection = Arc::new(ModelMetadataSenderCollection::default()); + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + drop(response_tx); + + Ok(ManagesSendersController { + request_id: "closed-request".to_owned(), + response_rx, + response_sender_collection, + }) + } + } + + #[actix_web::test] + async fn responds_not_found_when_response_channel_yields_no_value() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let agent_id = registered_agent_id(&agent_controller_pool); + + let endpoint = EndpointWithClosedResponseChannel { + agent_controller_pool, + agent_id, + }; + + let response = endpoint.respond().await.unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } +} diff --git a/paddler/src/controls_session.rs b/paddler_balancer/src/controls_session.rs similarity index 90% rename from paddler/src/controls_session.rs rename to paddler_balancer/src/controls_session.rs index 58d19dec..8debba8a 100644 --- a/paddler/src/controls_session.rs +++ b/paddler_balancer/src/controls_session.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; use log::error; -use paddler_types::rpc_message::RpcMessage; +use paddler_messaging::rpc_message::RpcMessage; #[async_trait] pub trait ControlsSession: Send + Sync diff --git a/paddler_balancer/src/controls_websocket_endpoint.rs b/paddler_balancer/src/controls_websocket_endpoint.rs new file mode 100644 index 00000000..d0ec7e96 --- /dev/null +++ b/paddler_balancer/src/controls_websocket_endpoint.rs @@ -0,0 +1,674 @@ +use std::sync::Arc; + +use actix_web::Error; +use actix_web::HttpRequest; +use actix_web::HttpResponse; +use actix_web::rt; +use actix_web::web::Payload; +use actix_ws::AggregatedMessage; +use actix_ws::CloseCode; +use actix_ws::CloseReason; +use actix_ws::ProtocolError; +use actix_ws::Session; +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use futures_util::StreamExt as _; +use log::debug; +use log::error; +use log::warn; +use paddler_messaging::rpc_message::RpcMessage; +use serde::de::DeserializeOwned; +use tokio::time::Duration; +use tokio::time::MissedTickBehavior; +use tokio::time::interval; +use tokio_util::sync::CancellationToken; + +use crate::continuation_decision::ContinuationDecision; +use crate::continuation_stop_parameters::ContinuationStopParameters; +use crate::websocket_session_controller::WebSocketSessionController; + +const MAX_FRAME_SIZE: usize = 50 * 1024 * 1024; +const MAX_CONTINUATION_SIZE: usize = 50 * 1024 * 1024; +const PING_INTERVAL: Duration = Duration::from_secs(3); + +#[async_trait] +pub trait ControlsWebSocketEndpoint: Send + Sync + 'static { + type Context: Send + Sync + 'static; + type IncomingMessage: DeserializeOwned + RpcMessage + Sync + 'static; + type OutgoingMessage: RpcMessage + Sync + 'static; + + fn create_context(&self) -> Self::Context; + + async fn handle_deserialized_message( + connection_close: CancellationToken, + context: Arc, + deserialized_message: Self::IncomingMessage, + websocket_session_controller: WebSocketSessionController, + ) -> Result; + + async fn handle_aggregated_message( + connection_close: CancellationToken, + context: Arc, + msg: Option>, + session: &mut Session, + ) -> Result { + match msg { + Some(Ok(AggregatedMessage::Binary(_))) => { + debug!("Received binary message, but only text messages are supported"); + + Ok(ContinuationDecision::Continue) + } + Some(Ok(AggregatedMessage::Close(_))) | None => { + return Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: None, + })); + } + Some(Ok(AggregatedMessage::Ping(msg))) => { + if session.pong(&msg).await.is_err() { + return Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: None, + })); + } + + Ok(ContinuationDecision::Continue) + } + Some(Ok(AggregatedMessage::Pong(_))) => { + // ignore pong messages + Ok(ContinuationDecision::Continue) + } + Some(Ok(AggregatedMessage::Text(text))) => { + match Self::handle_text_message( + connection_close, + context.clone(), + &text, + WebSocketSessionController::::new(session.clone()), + ) + .await + .context(format!("Text message: {text}")) + { + Ok(continuation_decision) => return Ok(continuation_decision), + Err(err) => { + error!("Error handling text message: {err:?}"); + + Ok(ContinuationDecision::Continue) + } + } + } + Some(Err(ProtocolError::Overflow)) => { + error!("Message exceeded the maximum allowed frame size of {MAX_FRAME_SIZE} bytes"); + + return Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: Some(CloseReason { + code: CloseCode::Size, + description: Some(format!( + "Message exceeded the maximum allowed frame size of {MAX_FRAME_SIZE} bytes" + )), + }), + })); + } + Some(Err(ProtocolError::Io(ref io_err))) + if io_err + .to_string() + .contains("Exceeded maximum continuation size") => + { + error!( + "Message exceeded the maximum allowed continuation size of {MAX_CONTINUATION_SIZE} bytes" + ); + + return Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: Some(CloseReason { + code: CloseCode::Size, + description: Some(format!( + "Message exceeded the maximum allowed continuation size of {MAX_CONTINUATION_SIZE} bytes" + )), + }), + })); + } + Some(Err(err)) => { + error!("Error receiving message: {err:?}"); + + return Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: None, + })); + } + } + } + + async fn handle_serialization_error( + _connection_close: CancellationToken, + _context: Arc, + error: serde_json::Error, + _websocket_session_controller: WebSocketSessionController, + ) -> Result { + error!("Paddler-RPC serialization error: {error}"); + + Ok(ContinuationDecision::Continue) + } + + async fn handle_text_message( + connection_close: CancellationToken, + context: Arc, + text: &str, + websocket_session_controller: WebSocketSessionController, + ) -> Result { + match serde_json::from_str::(text) { + Ok(deserialized_message) => { + rt::spawn(async move { + match Self::handle_deserialized_message( + connection_close.clone(), + context, + deserialized_message, + websocket_session_controller, + ) + .await + { + Ok(ContinuationDecision::Continue) => { + // Continue processing messages + } + Ok(ContinuationDecision::Stop(_)) => connection_close.cancel(), + Err(err) => { + error!("Error handling deserialized message: {err:?}"); + + connection_close.cancel(); + } + } + }); + + Ok(ContinuationDecision::Continue) + } + Err(err @ serde_json::Error { .. }) if err.is_data() || err.is_syntax() => { + error!("JSON-RPC syntax error: {err:?}"); + + Self::handle_serialization_error( + connection_close, + context, + err, + websocket_session_controller, + ) + .await + } + Err(err) => { + error!("Error handling JSON-RPC request: {err:?}"); + + Self::handle_serialization_error( + connection_close, + context, + err, + websocket_session_controller, + ) + .await + } + } + } + + async fn on_connection_start( + _context: Arc, + _session: &mut Session, + ) -> Result { + Ok(ContinuationDecision::Continue) + } + + fn respond( + &self, + payload: Payload, + req: HttpRequest, + shutdown: CancellationToken, + ) -> Result { + let connection_close = CancellationToken::new(); + let context = Arc::new(self.create_context()); + let (res, mut session, msg_stream) = actix_ws::handle(&req, payload)?; + + let mut aggregated_msg_stream = msg_stream + .max_frame_size(MAX_FRAME_SIZE) + .aggregate_continuations() + .max_continuation_size(MAX_CONTINUATION_SIZE); + + rt::spawn(async move { + let mut close_reason: Option = None; + + match Self::on_connection_start(context.clone(), &mut session).await { + Ok(ContinuationDecision::Continue) => {} + Ok(ContinuationDecision::Stop(stop_parameters)) => { + close_reason = stop_parameters.close_reason; + + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed after Stop decision (peer likely already disconnected): {close_err:?}" + ); + } + + return; + } + Err(err) => { + error!("Error in connection start handler: {err:?}"); + + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed after start-handler error (peer likely already disconnected): {close_err:?}" + ); + } + + return; + } + } + let mut ping_ticker = interval(PING_INTERVAL); + + ping_ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + msg = aggregated_msg_stream.next() => { + match Self::handle_aggregated_message( + connection_close.clone(), + context.clone(), + msg, + &mut session, + ).await { + Ok(ContinuationDecision::Continue) => { + // continue processing messages + } + Ok(ContinuationDecision::Stop(stop_parameters)) => { + close_reason = stop_parameters.close_reason; + + break; + } + Err(err) => { + error!("Error handling aggregated message: {err:?}"); + + break; + }, + } + } + _ = ping_ticker.tick() => { + if session.ping(b"").await.is_err() { + break; + } + } + () = connection_close.cancelled() => { + break; + } + () = shutdown.cancelled() => { + close_reason = Some(CloseReason { + code: CloseCode::Away, + description: Some("Server shutting down".to_owned()), + }); + break; + } + } + } + + connection_close.cancel(); + + if let Err(close_err) = session.close(close_reason).await { + warn!( + "WebSocket session close failed at end of message loop (peer likely already disconnected): {close_err:?}" + ); + } + }); + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use actix_web::FromRequest as _; + use actix_web::body::to_bytes; + use actix_web::http::header; + use actix_web::test::TestRequest; + use actix_web::web::Bytes; + use actix_web::web::Payload; + use serde::Deserialize; + use serde::Serialize; + use std::mem::discriminant; + + use super::ContinuationDecision; + use super::ContinuationStopParameters; + use super::ControlsWebSocketEndpoint; + use super::WebSocketSessionController; + use actix_ws::AggregatedMessage; + use actix_ws::CloseCode; + use actix_ws::CloseReason; + use actix_ws::Session; + use anyhow::Result; + use anyhow::anyhow; + use async_trait::async_trait; + use paddler_messaging::rpc_message::RpcMessage; + use std::sync::Arc; + use tokio_util::sync::CancellationToken; + + #[derive(Deserialize, Serialize)] + struct ProbeIncomingMessage {} + + impl RpcMessage for ProbeIncomingMessage {} + + #[derive(Serialize)] + struct ProbeOutgoingMessage; + + impl RpcMessage for ProbeOutgoingMessage {} + + #[derive(Clone, Copy)] + enum DeserializedMessageOutcome { + Continue, + Stop, + Err, + } + + struct ProbeEndpoint { + deserialized_message_outcome: DeserializedMessageOutcome, + } + + #[async_trait] + impl ControlsWebSocketEndpoint for ProbeEndpoint { + type Context = DeserializedMessageOutcome; + type IncomingMessage = ProbeIncomingMessage; + type OutgoingMessage = ProbeOutgoingMessage; + + fn create_context(&self) -> Self::Context { + self.deserialized_message_outcome + } + + async fn handle_deserialized_message( + _connection_close: CancellationToken, + context: Arc, + _deserialized_message: Self::IncomingMessage, + _websocket_session_controller: WebSocketSessionController, + ) -> Result { + match *context { + DeserializedMessageOutcome::Continue => Ok(ContinuationDecision::Continue), + DeserializedMessageOutcome::Stop => { + Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: None, + })) + } + DeserializedMessageOutcome::Err => { + Err(anyhow!("deserialized message handler failed")) + } + } + } + } + + #[derive(Clone, Copy)] + enum ConnectionStartOutcome { + Stop, + Err, + } + + struct StartOverridingEndpoint { + connection_start_outcome: ConnectionStartOutcome, + } + + #[async_trait] + impl ControlsWebSocketEndpoint for StartOverridingEndpoint { + type Context = ConnectionStartOutcome; + type IncomingMessage = ProbeIncomingMessage; + type OutgoingMessage = ProbeOutgoingMessage; + + fn create_context(&self) -> Self::Context { + self.connection_start_outcome + } + + async fn handle_deserialized_message( + _connection_close: CancellationToken, + _context: Arc, + _deserialized_message: Self::IncomingMessage, + _websocket_session_controller: WebSocketSessionController, + ) -> Result { + Ok(ContinuationDecision::Continue) + } + + async fn on_connection_start( + context: Arc, + _session: &mut Session, + ) -> Result { + match *context { + ConnectionStartOutcome::Stop => { + Ok(ContinuationDecision::Stop(ContinuationStopParameters { + close_reason: Some(CloseReason { + code: CloseCode::Normal, + description: Some("stop on start".to_owned()), + }), + })) + } + ConnectionStartOutcome::Err => Err(anyhow!("connection start handler failed")), + } + } + } + + struct AggregatedErroringEndpoint; + + #[async_trait] + impl ControlsWebSocketEndpoint for AggregatedErroringEndpoint { + type Context = (); + type IncomingMessage = ProbeIncomingMessage; + type OutgoingMessage = ProbeOutgoingMessage; + + fn create_context(&self) -> Self::Context {} + + async fn handle_deserialized_message( + _connection_close: CancellationToken, + _context: Arc, + _deserialized_message: Self::IncomingMessage, + _websocket_session_controller: WebSocketSessionController, + ) -> Result { + Ok(ContinuationDecision::Continue) + } + + async fn handle_aggregated_message( + _connection_close: CancellationToken, + _context: Arc, + _msg: Option>, + _session: &mut Session, + ) -> Result { + Err(anyhow!("aggregated message handler failed")) + } + } + + fn handshake_request() -> TestRequest { + TestRequest::get() + .insert_header((header::CONNECTION, "upgrade")) + .insert_header((header::UPGRADE, "websocket")) + .insert_header((header::SEC_WEBSOCKET_VERSION, "13")) + .insert_header((header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")) + } + + #[expect( + clippy::future_not_send, + reason = "test-only helper; the future is awaited in place, never sent across threads" + )] + async fn open_session() -> Session { + let (request, mut raw_payload) = handshake_request().to_http_parts(); + let payload = Payload::from_request(&request, &mut raw_payload) + .await + .unwrap(); + let (_response, session, _msg_stream) = actix_ws::handle(&request, payload).unwrap(); + + session + } + + #[actix_web::test] + async fn handle_serialization_error_continues() { + let session = open_session().await; + let serialization_error = serde_json::from_str::("not-a-number").err().unwrap(); + let continuation_decision = ProbeEndpoint::handle_serialization_error( + CancellationToken::new(), + Arc::new(DeserializedMessageOutcome::Continue), + serialization_error, + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert!(matches!( + continuation_decision, + ContinuationDecision::Continue + )); + } + + #[actix_web::test] + async fn default_on_connection_start_continues() { + let mut session = open_session().await; + let continuation_decision = ProbeEndpoint::on_connection_start( + Arc::new(DeserializedMessageOutcome::Continue), + &mut session, + ) + .await + .unwrap(); + + assert!(matches!( + continuation_decision, + ContinuationDecision::Continue + )); + } + + #[actix_web::test] + async fn deserialized_message_stop_cancels_connection() { + let session = open_session().await; + let connection_close = CancellationToken::new(); + let continuation_decision = ProbeEndpoint::handle_text_message( + connection_close.clone(), + Arc::new(DeserializedMessageOutcome::Stop), + "{}", + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert!(matches!( + continuation_decision, + ContinuationDecision::Continue + )); + + connection_close.cancelled().await; + + assert!(connection_close.is_cancelled()); + } + + #[actix_web::test] + async fn deserialized_message_error_cancels_connection() { + let session = open_session().await; + let connection_close = CancellationToken::new(); + let continuation_decision = ProbeEndpoint::handle_text_message( + connection_close.clone(), + Arc::new(DeserializedMessageOutcome::Err), + "{}", + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert!(matches!( + continuation_decision, + ContinuationDecision::Continue + )); + + connection_close.cancelled().await; + + assert!(connection_close.is_cancelled()); + } + + #[expect( + clippy::future_not_send, + reason = "test-only helper; the future is awaited in place, never sent across threads" + )] + async fn drain_close_frame(endpoint: &impl ControlsWebSocketEndpoint) -> Bytes { + let (request, mut raw_payload) = handshake_request().to_http_parts(); + let payload = Payload::from_request(&request, &mut raw_payload) + .await + .unwrap(); + let response = endpoint + .respond(payload, request, CancellationToken::new()) + .unwrap(); + + assert_eq!(response.status().as_u16(), 101); + + to_bytes(response.into_body()).await.unwrap() + } + + #[actix_web::test] + async fn respond_runs_message_loop_until_stream_closes() { + let close_frame = drain_close_frame(&ProbeEndpoint { + deserialized_message_outcome: DeserializedMessageOutcome::Continue, + }) + .await; + + assert!(!close_frame.is_empty()); + } + + #[actix_web::test] + async fn respond_closes_when_start_handler_stops() { + let close_frame = drain_close_frame(&StartOverridingEndpoint { + connection_start_outcome: ConnectionStartOutcome::Stop, + }) + .await; + + assert!(!close_frame.is_empty()); + } + + #[actix_web::test] + async fn respond_closes_when_start_handler_errors() { + let close_frame = drain_close_frame(&StartOverridingEndpoint { + connection_start_outcome: ConnectionStartOutcome::Err, + }) + .await; + + assert!(!close_frame.is_empty()); + } + + #[actix_web::test] + async fn respond_breaks_when_aggregated_handler_errors() { + let close_frame = drain_close_frame(&AggregatedErroringEndpoint).await; + + assert!(!close_frame.is_empty()); + } + + #[actix_web::test] + async fn start_overriding_endpoint_deserialized_message_continues() { + let session = open_session().await; + let continuation_decision = StartOverridingEndpoint::handle_deserialized_message( + CancellationToken::new(), + Arc::new(ConnectionStartOutcome::Stop), + ProbeIncomingMessage {}, + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert_eq!( + discriminant(&continuation_decision), + discriminant(&ContinuationDecision::Continue) + ); + } + + #[actix_web::test] + async fn aggregated_erroring_endpoint_deserialized_message_continues() { + let session = open_session().await; + let continuation_decision = AggregatedErroringEndpoint::handle_deserialized_message( + CancellationToken::new(), + Arc::new(()), + ProbeIncomingMessage {}, + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert_eq!( + discriminant(&continuation_decision), + discriminant(&ContinuationDecision::Continue) + ); + } + + #[actix_web::test] + async fn respond_propagates_handshake_error_on_non_websocket_request() { + let (request, mut raw_payload) = TestRequest::get().to_http_parts(); + let payload = Payload::from_request(&request, &mut raw_payload) + .await + .unwrap(); + let respond_result = + AggregatedErroringEndpoint.respond(payload, request, CancellationToken::new()); + let handshake_error = respond_result.err().unwrap(); + + assert_eq!(handshake_error.error_response().status().as_u16(), 400); + } +} diff --git a/paddler/src/create_cors_middleware.rs b/paddler_balancer/src/create_cors_middleware.rs similarity index 100% rename from paddler/src/create_cors_middleware.rs rename to paddler_balancer/src/create_cors_middleware.rs diff --git a/paddler/src/balancer/dispatch_candidate.rs b/paddler_balancer/src/dispatch_candidate.rs similarity index 69% rename from paddler/src/balancer/dispatch_candidate.rs rename to paddler_balancer/src/dispatch_candidate.rs index 7d8785cc..bfd37479 100644 --- a/paddler/src/balancer/dispatch_candidate.rs +++ b/paddler_balancer/src/dispatch_candidate.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::balancer::agent_controller::AgentController; +use crate::agent_controller::AgentController; pub struct DispatchCandidate { pub agent_controller: Arc, diff --git a/paddler/src/balancer/dispatched_agent.rs b/paddler_balancer/src/dispatched_agent.rs similarity index 75% rename from paddler/src/balancer/dispatched_agent.rs rename to paddler_balancer/src/dispatched_agent.rs index 087b811c..42cdd632 100644 --- a/paddler/src/balancer/dispatched_agent.rs +++ b/paddler_balancer/src/dispatched_agent.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::agent_controller_slot_guard::AgentControllerSlotGuard; +use crate::agent_controller::AgentController; +use crate::agent_controller_slot_guard::AgentControllerSlotGuard; pub struct DispatchedAgent { pub agent_controller: Arc, diff --git a/paddler/src/balancer/embedding_sender_collection.rs b/paddler_balancer/src/embedding_sender_collection.rs similarity index 83% rename from paddler/src/balancer/embedding_sender_collection.rs rename to paddler_balancer/src/embedding_sender_collection.rs index fc7ad4cc..0e09a844 100644 --- a/paddler/src/balancer/embedding_sender_collection.rs +++ b/paddler_balancer/src/embedding_sender_collection.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use dashmap::DashMap; -use paddler_types::embedding_result::EmbeddingResult; +use paddler_messaging::embedding_result::EmbeddingResult; use tokio::sync::mpsc; -use crate::balancer::manages_senders::ManagesSenders; +use crate::manages_senders::ManagesSenders; pub struct EmbeddingSenderCollection { senders: DashMap>, diff --git a/paddler/src/balancer/generate_tokens_sender_collection.rs b/paddler_balancer/src/generate_tokens_sender_collection.rs similarity index 83% rename from paddler/src/balancer/generate_tokens_sender_collection.rs rename to paddler_balancer/src/generate_tokens_sender_collection.rs index 291ec2ce..3b20c299 100644 --- a/paddler/src/balancer/generate_tokens_sender_collection.rs +++ b/paddler_balancer/src/generate_tokens_sender_collection.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use dashmap::DashMap; -use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generated_token_result::GeneratedTokenResult; use tokio::sync::mpsc; -use crate::balancer::manages_senders::ManagesSenders; +use crate::manages_senders::ManagesSenders; pub struct GenerateTokensSenderCollection { senders: DashMap>, diff --git a/paddler/src/balancer/handles_agent_streaming_response.rs b/paddler_balancer/src/handles_agent_streaming_response.rs similarity index 66% rename from paddler/src/balancer/handles_agent_streaming_response.rs rename to paddler_balancer/src/handles_agent_streaming_response.rs index 6cd7ac2c..00547270 100644 --- a/paddler/src/balancer/handles_agent_streaming_response.rs +++ b/paddler_balancer/src/handles_agent_streaming_response.rs @@ -1,9 +1,9 @@ use anyhow::Result; use async_trait::async_trait; -use crate::agent::jsonrpc::Request as AgentJsonRpcRequest; -use crate::balancer::manages_senders::ManagesSenders; -use crate::balancer::manages_senders_controller::ManagesSendersController; +use crate::manages_senders::ManagesSenders; +use crate::manages_senders_controller::ManagesSendersController; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; #[async_trait] pub trait HandlesAgentStreamingResponse diff --git a/paddler/src/balancer/http_route/get_health.rs b/paddler_balancer/src/http_route/get_health.rs similarity index 100% rename from paddler/src/balancer/http_route/get_health.rs rename to paddler_balancer/src/http_route/get_health.rs diff --git a/paddler/src/balancer/http_route/mod.rs b/paddler_balancer/src/http_route/mod.rs similarity index 100% rename from paddler/src/balancer/http_route/mod.rs rename to paddler_balancer/src/http_route/mod.rs diff --git a/paddler_balancer/src/http_stream_from_agent.rs b/paddler_balancer/src/http_stream_from_agent.rs new file mode 100644 index 00000000..1f557a30 --- /dev/null +++ b/paddler_balancer/src/http_stream_from_agent.rs @@ -0,0 +1,165 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::http::header; +use bytes::Bytes; +use futures::stream::StreamExt; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::streamable_result::StreamableResult; +use tokio_util::sync::CancellationToken; + +use crate::agent_controller::AgentController; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::manages_senders::ManagesSenders; +use crate::unbounded_stream_from_agent::unbounded_stream_from_agent; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; + +pub fn http_stream_from_agent( + buffered_request_manager: Arc, + inference_service_configuration: InferenceServiceConfiguration, + params: TParams, + transformer: TTransformsOutgoingMessage, + shutdown: CancellationToken, +) -> HttpResponse +where + TParams: Debug + Into + Send + 'static, + AgentController: HandlesAgentStreamingResponse, + <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, + TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, +{ + let stream = unbounded_stream_from_agent( + buffered_request_manager, + inference_service_configuration, + params, + transformer, + shutdown, + ) + .filter_map(|transform_result| async move { + match transform_result { + TransformResult::Chunk(chunk) => { + Some(Ok::<_, Error>(Bytes::from(format!("{chunk}\n")))) + } + TransformResult::Error(error) => { + Some(Ok::<_, Error>(Bytes::from(format!("{error}\n")))) + } + TransformResult::Discard => None, + } + }); + + HttpResponse::Ok() + .insert_header(header::ContentType::json()) + .insert_header((header::CACHE_CONTROL, "no-cache")) + .streaming(stream) +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + + use actix_web::body; + use anyhow::Result; + use async_trait::async_trait; + use tokio_util::sync::CancellationToken; + + use super::http_stream_from_agent; + use crate::agent_controller_pool::AgentControllerPool; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; + use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + use paddler_messaging::inference_client::message::Message as OutgoingMessage; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + #[derive(Clone)] + struct ErrorTransformer; + + #[async_trait] + impl TransformsOutgoingMessage for ErrorTransformer { + type Output = TransformResult; + + async fn transform(&self, _message: OutgoingMessage) -> Result> { + Ok(vec![TransformResult::Error("boom".to_owned())]) + } + } + + fn empty_pool_manager() -> Arc { + Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )) + } + + fn inference_service_configuration() -> InferenceServiceConfiguration { + InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: Duration::from_secs(1), + } + } + + fn raw_prompt_params() -> ContinueFromRawPromptParams { + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "hello".to_owned(), + } + } + + #[actix_web::test] + async fn forwards_transformed_chunk_to_streaming_body() { + let shutdown = CancellationToken::new(); + shutdown.cancel(); + + let response = http_stream_from_agent( + empty_pool_manager(), + inference_service_configuration(), + raw_prompt_params(), + IdentityTransformer::new(), + shutdown, + ); + + let body_bytes = body::to_bytes(response.into_body()).await.unwrap(); + let body_text = String::from_utf8(body_bytes.to_vec()).unwrap(); + + assert!( + body_text.contains("balancer is shutting down"), + "chunk arm must serialize the shutdown error envelope into the streaming body: {body_text}" + ); + assert!( + body_text.ends_with('\n'), + "chunk arm must append a trailing newline: {body_text:?}" + ); + } + + #[actix_web::test] + async fn forwards_transformed_error_to_streaming_body() { + let shutdown = CancellationToken::new(); + shutdown.cancel(); + + let response = http_stream_from_agent( + empty_pool_manager(), + inference_service_configuration(), + raw_prompt_params(), + ErrorTransformer, + shutdown, + ); + + let body_bytes = body::to_bytes(response.into_body()).await.unwrap(); + let body_text = String::from_utf8(body_bytes.to_vec()).unwrap(); + + assert_eq!( + body_text, "boom\n", + "error arm must forward the transformer error string with a trailing newline" + ); + } +} diff --git a/paddler/src/balancer/inference_service/app_data.rs b/paddler_balancer/src/inference_service/app_data.rs similarity index 68% rename from paddler/src/balancer/inference_service/app_data.rs rename to paddler_balancer/src/inference_service/app_data.rs index 26e2dee3..442e7cfb 100644 --- a/paddler/src/balancer/inference_service/app_data.rs +++ b/paddler_balancer/src/inference_service/app_data.rs @@ -2,10 +2,10 @@ use std::sync::Arc; use tokio_util::sync::CancellationToken; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::inference_service::configuration::Configuration; +use crate::agent_controller_pool::AgentControllerPool; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::inference_service::configuration::Configuration; pub struct AppData { pub agent_controller_pool: Arc, diff --git a/paddler/src/balancer/inference_service/configuration.rs b/paddler_balancer/src/inference_service/configuration.rs similarity index 100% rename from paddler/src/balancer/inference_service/configuration.rs rename to paddler_balancer/src/inference_service/configuration.rs diff --git a/paddler/src/balancer/inference_service/http_route/api/mod.rs b/paddler_balancer/src/inference_service/http_route/api/mod.rs similarity index 100% rename from paddler/src/balancer/inference_service/http_route/api/mod.rs rename to paddler_balancer/src/inference_service/http_route/api/mod.rs diff --git a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs b/paddler_balancer/src/inference_service/http_route/api/post_continue_from_conversation_history.rs similarity index 61% rename from paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs rename to paddler_balancer/src/inference_service/http_route/api/post_continue_from_conversation_history.rs index ffe33394..8d0440e9 100644 --- a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_conversation_history.rs +++ b/paddler_balancer/src/inference_service/http_route/api/post_continue_from_conversation_history.rs @@ -4,13 +4,13 @@ use actix_web::error::ErrorBadRequest; use actix_web::post; use actix_web::web; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; -use paddler_types::validates::Validates as _; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use paddler_messaging::validates::Validates as _; -use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; -use crate::balancer::http_stream_from_agent::http_stream_from_agent; -use crate::balancer::inference_service::app_data::AppData; +use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use crate::http_stream_from_agent::http_stream_from_agent; +use crate::inference_service::app_data::AppData; pub fn register(cfg: &mut web::ServiceConfig) { cfg.service(respond); @@ -33,5 +33,6 @@ async fn respond( } }, IdentityTransformer::new(), + app_data.shutdown.clone(), )) } diff --git a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_raw_prompt.rs b/paddler_balancer/src/inference_service/http_route/api/post_continue_from_raw_prompt.rs similarity index 62% rename from paddler/src/balancer/inference_service/http_route/api/post_continue_from_raw_prompt.rs rename to paddler_balancer/src/inference_service/http_route/api/post_continue_from_raw_prompt.rs index db9f9e8e..5f4cb6ee 100644 --- a/paddler/src/balancer/inference_service/http_route/api/post_continue_from_raw_prompt.rs +++ b/paddler_balancer/src/inference_service/http_route/api/post_continue_from_raw_prompt.rs @@ -2,11 +2,11 @@ use actix_web::Error; use actix_web::Responder; use actix_web::post; use actix_web::web; -use paddler_types::request_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; -use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; -use crate::balancer::http_stream_from_agent::http_stream_from_agent; -use crate::balancer::inference_service::app_data::AppData; +use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use crate::http_stream_from_agent::http_stream_from_agent; +use crate::inference_service::app_data::AppData; pub fn register(cfg: &mut web::ServiceConfig) { cfg.service(respond); @@ -22,5 +22,6 @@ async fn respond( app_data.inference_service_configuration.clone(), params.into_inner(), IdentityTransformer::new(), + app_data.shutdown.clone(), )) } diff --git a/paddler_balancer/src/inference_service/http_route/api/post_generate_embedding_batch.rs b/paddler_balancer/src/inference_service/http_route/api/post_generate_embedding_batch.rs new file mode 100644 index 00000000..64576b88 --- /dev/null +++ b/paddler_balancer/src/inference_service/http_route/api/post_generate_embedding_batch.rs @@ -0,0 +1,415 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::error::ErrorInternalServerError; +use actix_web::error::ErrorNotImplemented; +use actix_web::error::ErrorServiceUnavailable; +use actix_web::http::header; +use actix_web::post; +use actix_web::rt; +use actix_web::web; +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use futures::stream::StreamExt; +use nanoid::nanoid; +use paddler_messaging::embedding_result::EmbeddingResult; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use paddler_messaging::request_params::generate_embedding_batch_params::chunk_evenly_with_cap_error::ChunkEvenlyWithCapError; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use tokio::sync::mpsc; +use tokio::task::JoinSet; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_util::sync::CancellationToken; + +use crate::cancellation_token_stream_guard::CancellationTokenStreamGuard; +use crate::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use crate::chunk_forwarding_session_controller::transform_result::TransformResult; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::controls_session::ControlsSession as _; +use crate::inference_service::app_data::AppData; +use crate::request_from_agent::request_from_agent; + +#[derive(Clone)] +struct EmbeddingChunkBodyTransformer; + +#[async_trait] +impl TransformsOutgoingMessage for EmbeddingChunkBodyTransformer { + type Output = TransformResult; + + async fn transform(&self, message: OutgoingMessage) -> Result> { + if let OutgoingMessage::Response(ResponseEnvelope { + response: OutgoingResponse::Embedding(EmbeddingResult::Done), + .. + }) = &message + { + return Ok(vec![TransformResult::Discard]); + } + + let serialized = serde_json::to_string(&message)?; + + Ok(vec![TransformResult::Chunk(serialized)]) + } +} + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[post("/api/v1/generate_embedding_batch")] +async fn respond( + app_data: web::Data, + params: web::Json, +) -> Result { + let balancer_applicable_state_holder = app_data.balancer_applicable_state_holder.clone(); + let Some(agent_desired_state) = balancer_applicable_state_holder.get_agent_desired_state() + else { + return Err(ErrorServiceUnavailable( + "Balancer applicable state is not yet set", + )); + }; + + if !agent_desired_state.inference_parameters.enable_embeddings { + return Err(ErrorNotImplemented( + "Embedding generation is not enabled in the inference parameters", + )); + } + + let agent_count = app_data.agent_controller_pool.agents.len(); + let embedding_batch_size = agent_desired_state + .inference_parameters + .embedding_batch_size; + + let connection_close = CancellationToken::new(); + let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); + + let mut chunk_tasks: JoinSet<()> = JoinSet::new(); + + let batches = match params + .into_inner() + .chunk_evenly_with_cap(agent_count, embedding_batch_size) + { + Ok(batches) => batches, + Err(ChunkEvenlyWithCapError::ZeroAgentCount) => { + return Err(ErrorServiceUnavailable("No agents are currently connected")); + } + Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) => { + return Err(ErrorInternalServerError( + "embedding_batch_size is zero despite validation", + )); + } + }; + + for batch in batches { + let buffered_request_manager_clone = app_data.buffered_request_manager.clone(); + let chunk_tx_clone = chunk_tx.clone(); + let connection_close_clone = connection_close.clone(); + let inference_service_configuration_clone = + app_data.inference_service_configuration.clone(); + let shutdown_clone = app_data.shutdown.clone(); + + chunk_tasks.spawn(async move { + let request_id: String = nanoid!(); + let session_controller = ChunkForwardingSessionController::new( + chunk_tx_clone, + EmbeddingChunkBodyTransformer, + ); + + request_from_agent( + buffered_request_manager_clone, + connection_close_clone, + inference_service_configuration_clone, + batch, + request_id, + session_controller, + shutdown_clone, + ) + .await; + }); + } + + let final_done_chunk_tx = chunk_tx.clone(); + + rt::spawn(async move { + while chunk_tasks.join_next().await.is_some() {} + + let final_request_id: String = nanoid!(); + let mut final_session = + ChunkForwardingSessionController::new(final_done_chunk_tx, IdentityTransformer::new()); + + final_session + .send_response_safe(OutgoingMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: final_request_id, + response: OutgoingResponse::Embedding(EmbeddingResult::Done), + })) + .await; + }); + + drop(chunk_tx); + + let stream = + CancellationTokenStreamGuard::new(UnboundedReceiverStream::new(chunk_rx), connection_close) + .filter_map(|transform_result| async move { + match transform_result { + TransformResult::Chunk(content) | TransformResult::Error(content) => { + Some(Ok::<_, Error>(Bytes::from(format!("{content}\n")))) + } + TransformResult::Discard => None, + } + }); + + Ok(HttpResponse::Ok() + .insert_header(header::ContentType::json()) + .insert_header((header::CACHE_CONTROL, "no-cache")) + .streaming(stream)) +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::net::SocketAddr; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state::BalancerApplicableState; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::inference_service::app_data::AppData; + use crate::inference_service::configuration::Configuration; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::agent_desired_state::AgentDesiredState; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::embedding_input_document::EmbeddingInputDocument; + use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; + use paddler_messaging::inference_parameters::InferenceParameters; + use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; + + fn agent_with_dropped_receiver(agent_id: &str) -> Arc { + let (agent_message_tx, agent_message_rx) = mpsc::unbounded_channel(); + + drop(agent_message_rx); + + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: agent_id.to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }) + } + + fn inference_parameters_with_embeddings( + enable_embeddings: bool, + embedding_batch_size: usize, + ) -> InferenceParameters { + InferenceParameters { + embedding_batch_size, + enable_embeddings, + ..InferenceParameters::default() + } + } + + fn applicable_state(inference_parameters: InferenceParameters) -> BalancerApplicableState { + BalancerApplicableState { + agent_desired_state: AgentDesiredState { + chat_template_override: None, + inference_parameters, + model: AgentDesiredModel::LocalToAgent("model.gguf".to_owned()), + multimodal_projection: AgentDesiredModel::None, + }, + } + } + + fn app_data( + agent_controller_pool: Arc, + balancer_applicable_state: Option, + ) -> AppData { + let balancer_applicable_state_holder = Arc::new(BalancerApplicableStateHolder::default()); + + balancer_applicable_state_holder.set_balancer_applicable_state(balancer_applicable_state); + + AppData { + buffered_request_manager: Arc::new(BufferedRequestManager::new( + agent_controller_pool.clone(), + Duration::from_secs(1), + 10, + )), + agent_controller_pool, + balancer_applicable_state_holder, + inference_service_configuration: Configuration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: Duration::from_secs(1), + }, + shutdown: CancellationToken::new(), + } + } + + fn single_document_params() -> GenerateEmbeddingBatchParams { + GenerateEmbeddingBatchParams { + input_batch: vec![EmbeddingInputDocument { + content: "the quick brown fox".to_owned(), + id: "doc-1".to_owned(), + }], + normalization_method: EmbeddingNormalizationMethod::None, + } + } + + #[actix_web::test] + async fn responds_service_unavailable_when_balancer_state_is_not_set() { + let app = test::init_service( + App::new() + .app_data(web::Data::new(app_data( + Arc::new(AgentControllerPool::default()), + None, + ))) + .configure(register), + ) + .await; + + let request = test::TestRequest::post() + .uri("/api/v1/generate_embedding_batch") + .set_json(single_document_params()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[actix_web::test] + async fn responds_service_unavailable_when_no_agents_are_connected() { + let app = test::init_service( + App::new() + .app_data(web::Data::new(app_data( + Arc::new(AgentControllerPool::default()), + Some(applicable_state(inference_parameters_with_embeddings( + true, 256, + ))), + ))) + .configure(register), + ) + .await; + + let request = test::TestRequest::post() + .uri("/api/v1/generate_embedding_batch") + .set_json(single_document_params()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[actix_web::test] + async fn responds_internal_server_error_when_embedding_batch_size_is_zero() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + agent_controller_pool + .register_agent_controller( + "agent-zero".to_owned(), + agent_with_dropped_receiver("agent-zero"), + ) + .unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(app_data( + agent_controller_pool, + Some(applicable_state(inference_parameters_with_embeddings( + true, 0, + ))), + ))) + .configure(register), + ) + .await; + + let request = test::TestRequest::post() + .uri("/api/v1/generate_embedding_batch") + .set_json(single_document_params()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[actix_web::test] + async fn streams_error_chunk_when_agent_request_fails() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + agent_controller_pool + .register_agent_controller( + "agent-closed".to_owned(), + agent_with_dropped_receiver("agent-closed"), + ) + .unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(app_data( + agent_controller_pool, + Some(applicable_state(inference_parameters_with_embeddings( + true, 256, + ))), + ))) + .configure(register), + ) + .await; + + let request = test::TestRequest::post() + .uri("/api/v1/generate_embedding_batch") + .set_json(single_document_params()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let body = test::read_body(response).await; + let body_text = String::from_utf8(body.to_vec()).unwrap(); + + assert!( + body_text.contains("Failed to generate response"), + "streamed body must carry the forwarded agent error chunk, got: {body_text}" + ); + } +} diff --git a/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs b/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs new file mode 100644 index 00000000..e921fa6e --- /dev/null +++ b/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/inference_socket_controller_context.rs @@ -0,0 +1,12 @@ +use std::sync::Arc; + +use tokio_util::sync::CancellationToken; + +use crate::buffered_request_manager::BufferedRequestManager; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + +pub struct InferenceSocketControllerContext { + pub buffered_request_manager: Arc, + pub inference_service_configuration: InferenceServiceConfiguration, + pub shutdown: CancellationToken, +} diff --git a/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/mod.rs b/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/mod.rs new file mode 100644 index 00000000..fec33e78 --- /dev/null +++ b/paddler_balancer/src/inference_service/http_route/api/ws_inference_socket/mod.rs @@ -0,0 +1,503 @@ +mod inference_socket_controller_context; + +use std::sync::Arc; + +use actix_web::rt; +use actix_web::Error; +use actix_web::HttpRequest; +use actix_web::HttpResponse; +use actix_web::get; +use actix_web::web::Data; +use actix_web::web::Payload; +use actix_web::web::ServiceConfig; +use anyhow::Result; +use async_trait::async_trait; +use log::error; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_server::message::Message as InferenceServerMessage; +use paddler_messaging::inference_server::request::Request as InferenceServerRequest; +use paddler_messaging::jsonrpc::error::Error as JsonRpcError; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; +use paddler_messaging::jsonrpc::request_envelope::RequestEnvelope; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use paddler_messaging::validates::Validates as _; +use tokio_util::sync::CancellationToken; + +use self::inference_socket_controller_context::InferenceSocketControllerContext; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::continuation_decision::ContinuationDecision; +use crate::controls_websocket_endpoint::ControlsWebSocketEndpoint; +use crate::inference_service::app_data::AppData; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::request_from_agent::request_from_agent; +use crate::websocket_session_controller::WebSocketSessionController; + +type InferenceJsonRpcMessage = InferenceServerMessage; +type InferenceJsonRpcRequest = InferenceServerRequest; + +struct InferenceSocketController { + buffered_request_manager: Arc, + inference_service_configuration: InferenceServiceConfiguration, + shutdown: CancellationToken, +} + +#[async_trait] +impl ControlsWebSocketEndpoint for InferenceSocketController { + type Context = InferenceSocketControllerContext; + type IncomingMessage = InferenceJsonRpcMessage; + type OutgoingMessage = OutgoingMessage; + + fn create_context(&self) -> Self::Context { + InferenceSocketControllerContext { + buffered_request_manager: self.buffered_request_manager.clone(), + inference_service_configuration: self.inference_service_configuration.clone(), + shutdown: self.shutdown.clone(), + } + } + + async fn handle_deserialized_message( + connection_close: CancellationToken, + context: Arc, + deserialized_message: Self::IncomingMessage, + websocket_session_controller: WebSocketSessionController, + ) -> Result { + match deserialized_message { + InferenceJsonRpcMessage::Error(ErrorEnvelope { + request_id, + error: JsonRpcError { code, description }, + }) => { + error!( + "Received error from client: code: {code}, description: {description:?}, request_id: {request_id:?}" + ); + + return Ok(ContinuationDecision::Continue); + } + InferenceJsonRpcMessage::Request(RequestEnvelope { + id: request_id, + request: + InferenceJsonRpcRequest::ContinueFromConversationHistory( + conversation_history_params, + ), + }) => { + let validated_params = conversation_history_params.validate()?; + + rt::spawn(async move { + request_from_agent( + context.buffered_request_manager.clone(), + connection_close, + context.inference_service_configuration.clone(), + validated_params, + request_id, + websocket_session_controller, + context.shutdown.clone(), + ) + .await; + }); + + Ok(ContinuationDecision::Continue) + } + InferenceJsonRpcMessage::Request(RequestEnvelope { + id: request_id, + request: InferenceJsonRpcRequest::ContinueFromRawPrompt(raw_prompt_params), + }) => { + rt::spawn(async move { + request_from_agent( + context.buffered_request_manager.clone(), + connection_close, + context.inference_service_configuration.clone(), + raw_prompt_params, + request_id, + websocket_session_controller, + context.shutdown.clone(), + ) + .await; + }); + + Ok(ContinuationDecision::Continue) + } + } + } +} + +#[get("/api/v1/inference_socket")] +#[expect( + clippy::future_not_send, + reason = "actix-web handler futures are inherently !Send: each worker runs them on its own single-threaded runtime and never moves them across threads" +)] +async fn respond( + app_data: Data, + payload: Payload, + http_request: HttpRequest, +) -> Result { + let inference_socket_controller = InferenceSocketController { + buffered_request_manager: app_data.buffered_request_manager.clone(), + inference_service_configuration: app_data.inference_service_configuration.clone(), + shutdown: app_data.shutdown.clone(), + }; + + inference_socket_controller.respond(payload, http_request, app_data.shutdown.clone()) +} + +pub fn register(service_config: &mut ServiceConfig) { + service_config.service(respond); +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::mem::discriminant; + use std::net::SocketAddr; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use actix_web::App; + use actix_web::FromRequest as _; + use actix_web::http::StatusCode; + use actix_web::http::header; + use actix_web::test; + use actix_web::test::TestRequest; + use actix_web::web::Data; + use actix_web::web::Payload; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::continuation_decision::ContinuationDecision; + use crate::controls_websocket_endpoint::ControlsWebSocketEndpoint as _; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::websocket_session_controller::WebSocketSessionController; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::conversation_history::ConversationHistory; + use paddler_messaging::jsonrpc::error::Error as JsonRpcError; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + use paddler_messaging::jsonrpc::request_envelope::RequestEnvelope; + use paddler_messaging::management_socket::agent::message::Message as AgentJsonRpcMessage; + use paddler_messaging::management_socket::agent::notification::Notification as AgentJsonRpcNotification; + use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; + use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + use super::AppData; + use super::InferenceJsonRpcMessage; + use super::InferenceJsonRpcRequest; + use super::InferenceServiceConfiguration; + use super::InferenceSocketController; + use super::InferenceSocketControllerContext; + use super::OutgoingMessage; + use super::register; + + struct RegisteredAgent { + pool: Arc, + agent_message_rx: mpsc::UnboundedReceiver, + } + + fn pool_with_one_free_slot(agent_id: &str) -> RegisteredAgent { + let pool = Arc::new(AgentControllerPool::default()); + let (agent_message_tx, agent_message_rx) = mpsc::unbounded_channel(); + let agent_controller = Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: agent_id.to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }); + + pool.register_agent_controller(agent_id.to_owned(), agent_controller) + .unwrap(); + + RegisteredAgent { + pool, + agent_message_rx, + } + } + + fn context_with_pool(pool: Arc) -> Arc { + Arc::new(InferenceSocketControllerContext { + buffered_request_manager: Arc::new(BufferedRequestManager::new( + pool, + Duration::from_mins(1), + 10, + )), + inference_service_configuration: inference_service_configuration(), + shutdown: CancellationToken::new(), + }) + } + + #[expect( + clippy::future_not_send, + reason = "actix_ws::Session is !Send and the future is awaited in place" + )] + async fn open_session_controller() -> WebSocketSessionController { + let (request, mut raw_payload) = TestRequest::get() + .insert_header((header::CONNECTION, "upgrade")) + .insert_header((header::UPGRADE, "websocket")) + .insert_header((header::SEC_WEBSOCKET_VERSION, "13")) + .insert_header((header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")) + .to_http_parts(); + let payload = Payload::from_request(&request, &mut raw_payload) + .await + .unwrap(); + let (_response, session, _msg_stream) = actix_ws::handle(&request, payload).unwrap(); + + WebSocketSessionController::new(session) + } + + fn inference_service_configuration() -> InferenceServiceConfiguration { + InferenceServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 0)), + cors_allowed_hosts: vec!["http://localhost".to_owned()], + inference_item_timeout: Duration::from_secs(30), + } + } + + #[actix_web::test] + async fn create_context_copies_controller_state() { + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_mins(1), + 10, + )); + let shutdown = CancellationToken::new(); + let controller = InferenceSocketController { + buffered_request_manager: buffered_request_manager.clone(), + inference_service_configuration: inference_service_configuration(), + shutdown: shutdown.clone(), + }; + + let context = controller.create_context(); + + assert!(Arc::ptr_eq( + &context.buffered_request_manager, + &buffered_request_manager + )); + assert_eq!( + context + .inference_service_configuration + .inference_item_timeout, + Duration::from_secs(30) + ); + assert_eq!( + context.inference_service_configuration.cors_allowed_hosts, + vec!["http://localhost".to_owned()] + ); + + shutdown.cancel(); + + assert!(context.shutdown.is_cancelled()); + } + + #[actix_web::test] + async fn respond_upgrades_websocket_handshake() { + let app_data = Data::new(AppData { + agent_controller_pool: Arc::new(AgentControllerPool::default()), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_mins(1), + 10, + )), + inference_service_configuration: inference_service_configuration(), + shutdown: CancellationToken::new(), + }); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + + let request = test::TestRequest::get() + .uri("/api/v1/inference_socket") + .insert_header((header::UPGRADE, "websocket")) + .insert_header((header::CONNECTION, "Upgrade")) + .insert_header((header::SEC_WEBSOCKET_VERSION, "13")) + .insert_header((header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS); + } + + #[actix_web::test] + async fn handle_error_message_continues_without_dispatch() { + let RegisteredAgent { + pool, + mut agent_message_rx, + } = pool_with_one_free_slot("agent-error-arm"); + let session_controller = open_session_controller().await; + + let continuation_decision = InferenceSocketController::handle_deserialized_message( + CancellationToken::new(), + context_with_pool(pool), + InferenceJsonRpcMessage::Error(ErrorEnvelope { + request_id: "request-error".to_owned(), + error: JsonRpcError { + code: -32_600, + description: "client reported error".to_owned(), + }, + }), + session_controller, + ) + .await + .unwrap(); + + assert_eq!( + discriminant(&continuation_decision), + discriminant(&ContinuationDecision::Continue) + ); + assert!(agent_message_rx.try_recv().is_err()); + } + + #[actix_web::test] + async fn handle_raw_prompt_request_dispatches_to_agent() { + let RegisteredAgent { + pool, + mut agent_message_rx, + } = pool_with_one_free_slot("agent-raw-prompt"); + let session_controller = open_session_controller().await; + let connection_close = CancellationToken::new(); + + let continuation_decision = InferenceSocketController::handle_deserialized_message( + connection_close.clone(), + context_with_pool(pool), + InferenceJsonRpcMessage::Request(RequestEnvelope { + id: "request-raw-prompt".to_owned(), + request: InferenceJsonRpcRequest::ContinueFromRawPrompt( + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "fixture prompt".to_owned(), + }, + ), + }), + session_controller, + ) + .await + .unwrap(); + + assert_eq!( + discriminant(&continuation_decision), + discriminant(&ContinuationDecision::Continue) + ); + + let dispatched_message = agent_message_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&dispatched_message), + discriminant(&AgentJsonRpcMessage::Request(RequestEnvelope { + id: "request-raw-prompt".to_owned(), + request: AgentJsonRpcRequest::ContinueFromRawPrompt(ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "fixture prompt".to_owned(), + }), + })), + ); + + connection_close.cancel(); + + let stop_message = agent_message_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&stop_message), + discriminant(&AgentJsonRpcMessage::Notification( + AgentJsonRpcNotification::StopRespondingTo("request-raw-prompt".to_owned()) + )), + ); + } + + #[actix_web::test] + async fn handle_conversation_history_request_validates_and_dispatches_to_agent() { + let RegisteredAgent { + pool, + mut agent_message_rx, + } = pool_with_one_free_slot("agent-conversation-history"); + let session_controller = open_session_controller().await; + let connection_close = CancellationToken::new(); + + let continuation_decision = InferenceSocketController::handle_deserialized_message( + connection_close.clone(), + context_with_pool(pool), + InferenceJsonRpcMessage::Request(RequestEnvelope { + id: "request-conversation-history".to_owned(), + request: InferenceJsonRpcRequest::ContinueFromConversationHistory( + ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(Vec::new()), + enable_thinking: false, + grammar: None, + max_tokens: 1, + parse_tool_calls: false, + tools: Vec::new(), + }, + ), + }), + session_controller, + ) + .await + .unwrap(); + + assert_eq!( + discriminant(&continuation_decision), + discriminant(&ContinuationDecision::Continue) + ); + + let dispatched_message = agent_message_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&dispatched_message), + discriminant(&AgentJsonRpcMessage::Request(RequestEnvelope { + id: "request-conversation-history".to_owned(), + request: AgentJsonRpcRequest::ContinueFromConversationHistory( + ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(Vec::new()), + enable_thinking: false, + grammar: None, + max_tokens: 1, + parse_tool_calls: false, + tools: Vec::new(), + }, + ), + })), + ); + + connection_close.cancel(); + + let stop_message = agent_message_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&stop_message), + discriminant(&AgentJsonRpcMessage::Notification( + AgentJsonRpcNotification::StopRespondingTo( + "request-conversation-history".to_owned() + ) + )), + ); + } +} diff --git a/paddler/src/balancer/inference_service/http_route/mod.rs b/paddler_balancer/src/inference_service/http_route/mod.rs similarity index 100% rename from paddler/src/balancer/inference_service/http_route/mod.rs rename to paddler_balancer/src/inference_service/http_route/mod.rs diff --git a/paddler_balancer/src/inference_service/mod.rs b/paddler_balancer/src/inference_service/mod.rs new file mode 100644 index 00000000..abaa2df5 --- /dev/null +++ b/paddler_balancer/src/inference_service/mod.rs @@ -0,0 +1,178 @@ +pub mod app_data; +pub mod configuration; +pub mod http_route; + +use std::sync::Arc; + +use actix_web::App; +use actix_web::HttpServer; +use actix_web::web::Data; +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; + +use crate::agent_controller_pool::AgentControllerPool; +use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::create_cors_middleware::create_cors_middleware; +use crate::http_route as common_http_route; +use crate::inference_service::app_data::AppData; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +#[cfg(feature = "web_admin_panel")] +use crate::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; + +pub struct InferenceService { + pub agent_controller_pool: Arc, + pub balancer_applicable_state_holder: Arc, + pub buffered_request_manager: Arc, + pub configuration: InferenceServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, + #[cfg(feature = "web_admin_panel")] + pub web_admin_panel_service_configuration: Option, +} + +#[async_trait] +impl Service for InferenceService { + fn name(&self) -> &'static str { + "balancer::inference_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let web_admin_panel_cors_allowed_hosts: Vec = { + #[cfg(feature = "web_admin_panel")] + { + self.web_admin_panel_service_configuration + .as_ref() + .map(|web_admin_panel_config| format!("http://{}", web_admin_panel_config.addr)) + .into_iter() + .collect() + } + #[cfg(not(feature = "web_admin_panel"))] + { + Vec::new() + } + }; + + let cors_allowed_hosts_arc = Arc::new( + self.configuration + .cors_allowed_hosts + .iter() + .cloned() + .chain(web_admin_panel_cors_allowed_hosts) + .collect::>(), + ); + + let app_data = Data::new(AppData { + agent_controller_pool: self.agent_controller_pool.clone(), + balancer_applicable_state_holder: self.balancer_applicable_state_holder.clone(), + buffered_request_manager: self.buffered_request_manager.clone(), + inference_service_configuration: self.configuration.clone(), + shutdown: shutdown.clone(), + }); + + let bind_addr = self.configuration.addr; + + let server = HttpServer::new(move || { + App::new() + .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) + .app_data(app_data.clone()) + .configure(common_http_route::get_health::register) + .configure(http_route::api::post_continue_from_conversation_history::register) + .configure(http_route::api::post_continue_from_raw_prompt::register) + .configure(http_route::api::post_generate_embedding_batch::register) + .configure(http_route::api::ws_inference_socket::register) + }) + .shutdown_signal(async move { + shutdown.cancelled().await; + }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) + .disable_signals() + .bind(bind_addr) + .with_context(|| format!("Unable to bind balancer inference service to {bind_addr}"))?; + + server.run().await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::net::TcpListener; + use std::sync::Arc; + use std::time::Duration; + + use tokio_util::sync::CancellationToken; + use trzcina::Service as _; + use trzcina::ServiceShutdownOptions; + + use super::InferenceService; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; + #[cfg(feature = "web_admin_panel")] + use crate::resolved_socket_addr::ResolvedSocketAddr; + #[cfg(feature = "web_admin_panel")] + use crate::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; + #[cfg(feature = "web_admin_panel")] + use crate::web_admin_panel_service::template_data::TemplateData; + + fn build_service(addr: SocketAddr) -> InferenceService { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + InferenceService { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + agent_controller_pool, + Duration::from_secs(30), + 32, + )), + configuration: InferenceServiceConfiguration { + addr, + cors_allowed_hosts: vec!["http://127.0.0.1:8080".to_owned()], + inference_item_timeout: Duration::from_secs(30), + }, + shutdown_options: ServiceShutdownOptions::default(), + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration: Some(WebAdminPanelServiceConfiguration { + addr: SocketAddr::from(([127, 0, 0, 1], 8081)), + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(30), + compat_openai_addr: None, + inference_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:0".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 0)), + }, + management_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:0".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 0)), + }, + max_buffered_requests: 32, + statsd_addr: None, + statsd_prefix: "paddler".to_owned(), + statsd_reporting_interval: Duration::from_secs(10), + }, + }), + } + } + + #[actix_web::test] + async fn run_returns_error_when_address_is_already_in_use() { + let occupied_listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap(); + let occupied_addr = occupied_listener.local_addr().unwrap(); + + let service = Box::new(build_service(occupied_addr)); + let result = service.run(CancellationToken::new()).await; + + let error_message = result.unwrap_err().to_string(); + let expected_addr_fragment = occupied_addr.to_string(); + + assert!(error_message.contains(&expected_addr_fragment)); + } +} diff --git a/paddler/src/balancer/mod.rs b/paddler_balancer/src/lib.rs similarity index 68% rename from paddler/src/balancer/mod.rs rename to paddler_balancer/src/lib.rs index f0011cbf..3a4f0a03 100644 --- a/paddler/src/balancer/mod.rs +++ b/paddler_balancer/src/lib.rs @@ -3,14 +3,23 @@ pub mod agent_controller_pool; mod agent_controller_pool_total_slots; pub mod agent_controller_slot_guard; pub mod agent_controller_update_result; +pub mod balancer_applicable_state; +pub mod balancer_applicable_state_holder; +pub mod balancer_desired_state_converter; mod buffered_request_agent_wait_result; mod buffered_request_count_guard; mod buffered_request_counter; pub mod buffered_request_manager; +pub mod cancellation_token_stream_guard; pub mod chat_template_override_sender_collection; pub mod chunk_forwarding_session_controller; pub mod compatibility; +pub mod continuation_decision; +pub mod continuation_stop_parameters; mod controls_manages_senders_endpoint; +pub mod controls_session; +pub mod controls_websocket_endpoint; +pub mod create_cors_middleware; pub mod dispatch_candidate; pub mod dispatched_agent; pub mod embedding_sender_collection; @@ -25,11 +34,18 @@ pub mod manages_senders_controller; pub mod model_metadata_sender_collection; pub mod reconciliation_service; pub mod request_from_agent; +pub mod resolved_socket_addr; #[cfg(feature = "web_admin_panel")] mod response; +pub mod sends_rpc_message; +pub mod sets_desired_state; +pub mod snapshots_stream; pub mod state_database; pub mod state_database_type; +#[cfg(feature = "web_admin_panel")] +pub mod static_files; pub mod statsd_service; mod unbounded_stream_from_agent; #[cfg(feature = "web_admin_panel")] pub mod web_admin_panel_service; +pub mod websocket_session_controller; diff --git a/paddler/src/balancer/management_service/app_data.rs b/paddler_balancer/src/management_service/app_data.rs similarity index 59% rename from paddler/src/balancer/management_service/app_data.rs rename to paddler_balancer/src/management_service/app_data.rs index a5e8b276..cce3869e 100644 --- a/paddler/src/balancer/management_service/app_data.rs +++ b/paddler_balancer/src/management_service/app_data.rs @@ -2,14 +2,14 @@ use std::sync::Arc; use tokio_util::sync::CancellationToken; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::buffered_request_manager::BufferedRequestManager; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use crate::balancer::state_database::StateDatabase; +use crate::agent_controller_pool::AgentControllerPool; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use crate::embedding_sender_collection::EmbeddingSenderCollection; +use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; +use crate::state_database::StateDatabase; pub struct AppData { pub agent_controller_pool: Arc, diff --git a/paddler/src/balancer/management_service/configuration.rs b/paddler_balancer/src/management_service/configuration.rs similarity index 100% rename from paddler/src/balancer/management_service/configuration.rs rename to paddler_balancer/src/management_service/configuration.rs diff --git a/paddler_balancer/src/management_service/http_route/api/get_agents.rs b/paddler_balancer/src/management_service/http_route/api/get_agents.rs new file mode 100644 index 00000000..53d83e71 --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/get_agents.rs @@ -0,0 +1,164 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::error::ErrorInternalServerError; +use actix_web::get; +use actix_web::web; + +use crate::management_service::app_data::AppData; +use paddler_messaging::produces_snapshot::ProducesSnapshot as _; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/api/v1/agents")] +async fn respond(app_data: web::Data) -> Result { + Ok(HttpResponse::Ok().json( + app_data + .agent_controller_pool + .make_snapshot() + .map_err(ErrorInternalServerError)?, + )) +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use tokio::sync::broadcast; + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::app_data::AppData; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::state_database::memory::Memory; + use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + + fn agent_controller_with_status_code(status_code: i32) -> Arc { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-test".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(0), + state_application_status_code: AtomicValue::::new(status_code), + uses_chat_template_override: AtomicValue::::new(false), + }) + } + + fn app_data_with_pool(agent_controller_pool: Arc) -> Data { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + + Data::new(AppData { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + agent_controller_pool, + Duration::from_secs(1), + 10, + )), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown: CancellationToken::new(), + state_database: Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )), + statsd_prefix: "paddler".to_owned(), + }) + } + + #[actix_web::test] + async fn responds_with_registered_agents_snapshot() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + agent_controller_pool + .register_agent_controller( + "agent-test".to_owned(), + agent_controller_with_status_code(AgentStateApplicationStatus::Fresh as i32), + ) + .unwrap(); + + let app = test::init_service( + App::new() + .app_data(app_data_with_pool(agent_controller_pool)) + .configure(register), + ) + .await; + let request = test::TestRequest::get().uri("/api/v1/agents").to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let snapshot: AgentControllerPoolSnapshot = test::read_body_json(response).await; + + assert_eq!(snapshot.agents.len(), 1); + assert_eq!(snapshot.agents[0].id, "agent-test"); + } + + #[actix_web::test] + async fn responds_with_internal_server_error_when_snapshot_fails() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + + agent_controller_pool + .register_agent_controller( + "agent-test".to_owned(), + agent_controller_with_status_code(99), + ) + .unwrap(); + + let app = test::init_service( + App::new() + .app_data(app_data_with_pool(agent_controller_pool)) + .configure(register), + ) + .await; + let request = test::TestRequest::get().uri("/api/v1/agents").to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/paddler_balancer/src/management_service/http_route/api/get_agents_stream.rs b/paddler_balancer/src/management_service/http_route/api/get_agents_stream.rs new file mode 100644 index 00000000..7531d132 --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/get_agents_stream.rs @@ -0,0 +1,82 @@ +use std::convert::Infallible; +use std::time::Duration; + +use actix_web::Error; +use actix_web::Responder; +use actix_web::get; +use actix_web::web; +use actix_web_lab::sse; +use futures::StreamExt as _; +use log::error; +use serde::Serialize; + +use crate::management_service::app_data::AppData; +use crate::snapshots_stream::snapshots_stream; + +fn serialize_snapshot_event( + snapshot: &TSnapshot, +) -> Option> +where + TSnapshot: Serialize, +{ + match serde_json::to_string(snapshot) { + Ok(json) => Some(Ok(sse::Event::Data(sse::Data::new(json)))), + Err(err) => { + error!("Failed to serialize agent controller pool snapshot: {err}"); + None + } + } +} + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/api/v1/agents/stream")] +async fn respond(app_data: web::Data) -> Result { + let event_stream = snapshots_stream( + app_data.agent_controller_pool.clone(), + app_data.shutdown.clone(), + ) + .filter_map(|snapshot| async move { serialize_snapshot_event(&snapshot) }); + + Ok(sse::Sse::from_stream(event_stream).with_keep_alive(Duration::from_secs(10))) +} + +#[cfg(test)] +mod tests { + use serde::Serializer; + use serde::ser::Error as _; + + use super::*; + + struct FailingSnapshot; + + impl Serialize for FailingSnapshot { + fn serialize( + &self, + _serializer: TSerializer, + ) -> Result + where + TSerializer: Serializer, + { + Err(TSerializer::Error::custom("snapshot cannot be serialized")) + } + } + + #[test] + fn serialize_snapshot_event_returns_event_for_serializable_snapshot() { + let event = serialize_snapshot_event(&"snapshot"); + + assert!(event.is_some()); + } + + #[test] + fn serialize_snapshot_event_skips_unserializable_snapshot() { + log::set_max_level(log::LevelFilter::Trace); + + let event = serialize_snapshot_event(&FailingSnapshot); + + assert!(event.is_none()); + } +} diff --git a/paddler_balancer/src/management_service/http_route/api/get_balancer_applicable_state.rs b/paddler_balancer/src/management_service/http_route/api/get_balancer_applicable_state.rs new file mode 100644 index 00000000..6140eeb7 --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/get_balancer_applicable_state.rs @@ -0,0 +1,126 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::get; +use actix_web::web; + +use crate::management_service::app_data::AppData; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/api/v1/balancer_applicable_state")] +async fn respond(app_data: web::Data) -> Result { + let applicable_state = app_data + .balancer_applicable_state_holder + .get_agent_desired_state(); + + Ok(HttpResponse::Ok().json(applicable_state)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use tokio::sync::broadcast; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state::BalancerApplicableState; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::app_data::AppData; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::state_database::memory::Memory; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::agent_desired_state::AgentDesiredState; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + use paddler_messaging::inference_parameters::InferenceParameters; + + fn build_app_data( + balancer_applicable_state_holder: Arc, + ) -> Data { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + + Data::new(AppData { + agent_controller_pool: Arc::new(AgentControllerPool::default()), + balancer_applicable_state_holder, + buffered_request_manager: Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown: CancellationToken::new(), + state_database: Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )), + statsd_prefix: "paddler".to_owned(), + }) + } + + #[actix_web::test] + async fn responds_with_stored_agent_desired_state() { + let balancer_applicable_state_holder = Arc::new(BalancerApplicableStateHolder::default()); + + balancer_applicable_state_holder.set_balancer_applicable_state(Some( + BalancerApplicableState { + agent_desired_state: AgentDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::LocalToAgent("model.gguf".to_owned()), + multimodal_projection: AgentDesiredModel::None, + }, + }, + )); + + let app_data = build_app_data(balancer_applicable_state_holder); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get() + .uri("/api/v1/balancer_applicable_state") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let agent_desired_state: AgentDesiredState = test::read_body_json(response).await; + + assert_eq!( + agent_desired_state.model, + AgentDesiredModel::LocalToAgent("model.gguf".to_owned()) + ); + } + + #[actix_web::test] + async fn responds_with_json_null_when_no_state_is_set() { + let app_data = build_app_data(Arc::new(BalancerApplicableStateHolder::default())); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get() + .uri("/api/v1/balancer_applicable_state") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let body = test::read_body(response).await; + + assert_eq!(body.as_ref(), b"null"); + } +} diff --git a/paddler_balancer/src/management_service/http_route/api/get_balancer_desired_state.rs b/paddler_balancer/src/management_service/http_route/api/get_balancer_desired_state.rs new file mode 100644 index 00000000..13c3470e --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/get_balancer_desired_state.rs @@ -0,0 +1,122 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::error::ErrorInternalServerError; +use actix_web::get; +use actix_web::web; + +use crate::management_service::app_data::AppData; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/api/v1/balancer_desired_state")] +async fn respond(app_data: web::Data) -> Result { + let desired_state = app_data + .state_database + .read_balancer_desired_state() + .await + .map_err(ErrorInternalServerError)?; + + Ok(HttpResponse::Ok().json(desired_state)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use tempfile::TempDir; + use tokio::sync::broadcast; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::app_data::AppData; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::state_database::StateDatabase; + use crate::state_database::file::File; + use crate::state_database::memory::Memory; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + use paddler_messaging::inference_parameters::InferenceParameters; + + fn build_app_data(state_database: Arc) -> Data { + Data::new(AppData { + agent_controller_pool: Arc::new(AgentControllerPool::default()), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown: CancellationToken::new(), + state_database, + statsd_prefix: "paddler".to_owned(), + }) + } + + #[actix_web::test] + async fn responds_with_stored_desired_state() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + let stored_state = BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::LocalToAgent("model.gguf".to_owned()), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }; + let state_database = Arc::new(Memory::new( + balancer_desired_state_notify_tx, + stored_state.clone(), + )); + let app_data = build_app_data(state_database); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get() + .uri("/api/v1/balancer_desired_state") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let desired_state: BalancerDesiredState = test::read_body_json(response).await; + + assert_eq!(desired_state, stored_state); + } + + #[actix_web::test] + async fn responds_with_internal_server_error_when_reading_state_fails() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + let temp_dir = TempDir::new().unwrap(); + let state_database = Arc::new(File::new( + balancer_desired_state_notify_tx, + temp_dir.path().to_path_buf(), + )); + let app_data = build_app_data(state_database); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get() + .uri("/api/v1/balancer_desired_state") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/paddler_balancer/src/management_service/http_route/api/get_buffered_requests.rs b/paddler_balancer/src/management_service/http_route/api/get_buffered_requests.rs new file mode 100644 index 00000000..1fd3c55d --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/get_buffered_requests.rs @@ -0,0 +1,97 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::error::ErrorInternalServerError; +use actix_web::get; +use actix_web::web; + +use crate::management_service::app_data::AppData; +use paddler_messaging::produces_snapshot::ProducesSnapshot as _; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/api/v1/buffered_requests")] +async fn respond(app_data: web::Data) -> Result { + Ok(HttpResponse::Ok().json( + app_data + .buffered_request_manager + .make_snapshot() + .map_err(ErrorInternalServerError)?, + )) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use tokio::sync::broadcast; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::app_data::AppData; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::state_database::memory::Memory; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; + + #[actix_web::test] + async fn responds_with_current_buffered_request_count() { + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )); + + buffered_request_manager + .buffered_request_counter + .increment(); + buffered_request_manager + .buffered_request_counter + .increment(); + + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + + let app_data = Data::new(AppData { + agent_controller_pool: Arc::new(AgentControllerPool::default()), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown: CancellationToken::new(), + state_database: Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )), + statsd_prefix: "paddler".to_owned(), + }); + + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get() + .uri("/api/v1/buffered_requests") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let snapshot: BufferedRequestManagerSnapshot = test::read_body_json(response).await; + + assert_eq!(snapshot.buffered_requests_current, 2); + } +} diff --git a/paddler/src/balancer/management_service/http_route/api/get_buffered_requests_stream.rs b/paddler_balancer/src/management_service/http_route/api/get_buffered_requests_stream.rs similarity index 94% rename from paddler/src/balancer/management_service/http_route/api/get_buffered_requests_stream.rs rename to paddler_balancer/src/management_service/http_route/api/get_buffered_requests_stream.rs index 8f9e4f6a..4619d56a 100644 --- a/paddler/src/balancer/management_service/http_route/api/get_buffered_requests_stream.rs +++ b/paddler_balancer/src/management_service/http_route/api/get_buffered_requests_stream.rs @@ -9,7 +9,7 @@ use actix_web_lab::sse; use futures::StreamExt as _; use log::error; -use crate::balancer::management_service::app_data::AppData; +use crate::management_service::app_data::AppData; use crate::snapshots_stream::snapshots_stream; pub fn register(cfg: &mut web::ServiceConfig) { diff --git a/paddler/src/balancer/management_service/http_route/api/get_chat_template_override.rs b/paddler_balancer/src/management_service/http_route/api/get_chat_template_override.rs similarity index 76% rename from paddler/src/balancer/management_service/http_route/api/get_chat_template_override.rs rename to paddler_balancer/src/management_service/http_route/api/get_chat_template_override.rs index 96c6bba9..1654445c 100644 --- a/paddler/src/balancer/management_service/http_route/api/get_chat_template_override.rs +++ b/paddler_balancer/src/management_service/http_route/api/get_chat_template_override.rs @@ -7,12 +7,12 @@ use actix_web::web; use async_trait::async_trait; use serde::Deserialize; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::controls_manages_senders_endpoint::ControlsManagesSendersEndpoint; -use crate::balancer::management_service::app_data::AppData; -use crate::balancer::manages_senders_controller::ManagesSendersController; +use crate::agent_controller::AgentController; +use crate::agent_controller_pool::AgentControllerPool; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use crate::controls_manages_senders_endpoint::ControlsManagesSendersEndpoint; +use crate::management_service::app_data::AppData; +use crate::manages_senders_controller::ManagesSendersController; pub fn register(cfg: &mut web::ServiceConfig) { cfg.service(respond); diff --git a/paddler/src/balancer/management_service/http_route/api/get_model_metadata.rs b/paddler_balancer/src/management_service/http_route/api/get_model_metadata.rs similarity index 76% rename from paddler/src/balancer/management_service/http_route/api/get_model_metadata.rs rename to paddler_balancer/src/management_service/http_route/api/get_model_metadata.rs index ec74c5e9..a91e7bc5 100644 --- a/paddler/src/balancer/management_service/http_route/api/get_model_metadata.rs +++ b/paddler_balancer/src/management_service/http_route/api/get_model_metadata.rs @@ -7,12 +7,12 @@ use actix_web::web; use async_trait::async_trait; use serde::Deserialize; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::controls_manages_senders_endpoint::ControlsManagesSendersEndpoint; -use crate::balancer::management_service::app_data::AppData; -use crate::balancer::manages_senders_controller::ManagesSendersController; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; +use crate::agent_controller::AgentController; +use crate::agent_controller_pool::AgentControllerPool; +use crate::controls_manages_senders_endpoint::ControlsManagesSendersEndpoint; +use crate::management_service::app_data::AppData; +use crate::manages_senders_controller::ManagesSendersController; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; pub fn register(cfg: &mut web::ServiceConfig) { cfg.service(respond); diff --git a/paddler/src/balancer/management_service/http_route/api/mod.rs b/paddler_balancer/src/management_service/http_route/api/mod.rs similarity index 100% rename from paddler/src/balancer/management_service/http_route/api/mod.rs rename to paddler_balancer/src/management_service/http_route/api/mod.rs diff --git a/paddler_balancer/src/management_service/http_route/api/put_balancer_desired_state.rs b/paddler_balancer/src/management_service/http_route/api/put_balancer_desired_state.rs new file mode 100644 index 00000000..0fcc429c --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/put_balancer_desired_state.rs @@ -0,0 +1,155 @@ +use actix_web::Error; +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::error::ErrorBadRequest; +use actix_web::error::ErrorInternalServerError; +use actix_web::put; +use actix_web::web; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::validates::Validates; + +use crate::management_service::app_data::AppData; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[put("/api/v1/balancer_desired_state")] +async fn respond( + app_data: web::Data, + balancer_desired_state: web::Json, +) -> Result { + let balancer_desired_state_inner = balancer_desired_state.into_inner(); + + balancer_desired_state_inner + .inference_parameters + .clone() + .validate() + .map_err(ErrorBadRequest)?; + + app_data + .state_database + .store_balancer_desired_state(&balancer_desired_state_inner) + .await + .map_err(ErrorInternalServerError)?; + + Ok(HttpResponse::NoContent().finish()) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use tokio::sync::broadcast; + use tokio_util::sync::CancellationToken; + + use super::register; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::app_data::AppData; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::state_database::memory::Memory; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + use paddler_messaging::inference_parameters::InferenceParameters; + + fn build_app_data(state_database: Arc) -> Data { + Data::new(AppData { + agent_controller_pool: Arc::new(AgentControllerPool::default()), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + Arc::new(AgentControllerPool::default()), + Duration::from_secs(1), + 10, + )), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown: CancellationToken::new(), + state_database, + statsd_prefix: "paddler".to_owned(), + }) + } + + #[actix_web::test] + async fn stores_desired_state_and_responds_with_no_content() { + let (balancer_desired_state_notify_tx, balancer_desired_state_notify_rx) = + broadcast::channel(1); + let state_database = Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )); + let app_data = build_app_data(state_database.clone()); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::put() + .uri("/api/v1/balancer_desired_state") + .set_json(BalancerDesiredState::default()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + + drop(balancer_desired_state_notify_rx); + } + + #[actix_web::test] + async fn responds_with_bad_request_when_inference_parameters_are_invalid() { + let (balancer_desired_state_notify_tx, balancer_desired_state_notify_rx) = + broadcast::channel(1); + let state_database = Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )); + let app_data = build_app_data(state_database); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let invalid_desired_state = BalancerDesiredState { + inference_parameters: InferenceParameters { + image_resize_to_fit: 0, + ..InferenceParameters::default() + }, + ..BalancerDesiredState::default() + }; + let request = test::TestRequest::put() + .uri("/api/v1/balancer_desired_state") + .set_json(invalid_desired_state) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + drop(balancer_desired_state_notify_rx); + } + + #[actix_web::test] + async fn responds_with_internal_server_error_when_store_fails() { + let (balancer_desired_state_notify_tx, balancer_desired_state_notify_rx) = + broadcast::channel(1); + + drop(balancer_desired_state_notify_rx); + + let state_database = Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )); + let app_data = build_app_data(state_database); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::put() + .uri("/api/v1/balancer_desired_state") + .set_json(BalancerDesiredState::default()) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/paddler_balancer/src/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs b/paddler_balancer/src/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs new file mode 100644 index 00000000..9631b28b --- /dev/null +++ b/paddler_balancer/src/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs @@ -0,0 +1,126 @@ +use std::sync::Arc; + +use log::error; +use log::info; + +use crate::agent_controller_pool::AgentControllerPool; +use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use crate::embedding_sender_collection::EmbeddingSenderCollection; +use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + +pub struct AgentSocketControllerContext { + pub agent_controller_pool: Arc, + pub agent_id: String, + pub balancer_applicable_state_holder: Arc, + pub chat_template_override_sender_collection: Arc, + pub embedding_sender_collection: Arc, + pub generate_tokens_sender_collection: Arc, + pub model_metadata_sender_collection: Arc, +} + +impl Drop for AgentSocketControllerContext { + fn drop(&mut self) { + if let Err(err) = self + .agent_controller_pool + .remove_agent_controller(&self.agent_id) + { + error!("Failed to remove agent: {err}"); + } + + info!("Removed agent: {}", self.agent_id); + } +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::AgentSocketControllerContext; + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + + #[test] + fn drop_removes_registered_agent_from_pool() { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + + agent_controller_pool + .register_agent_controller( + "agent-under-drop".to_owned(), + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new( + GenerateTokensSenderCollection::default(), + ), + id: "agent-under-drop".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new( + ModelMetadataSenderCollection::default(), + ), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }), + ) + .unwrap(); + + let context = AgentSocketControllerContext { + agent_controller_pool: agent_controller_pool.clone(), + agent_id: "agent-under-drop".to_owned(), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + }; + + assert!( + agent_controller_pool + .get_agent_controller("agent-under-drop") + .is_some() + ); + + drop(context); + + assert!( + agent_controller_pool + .get_agent_controller("agent-under-drop") + .is_none() + ); + } +} diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs b/paddler_balancer/src/management_service/http_route/api/ws_agent_socket/mod.rs similarity index 67% rename from paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs rename to paddler_balancer/src/management_service/http_route/api/ws_agent_socket/mod.rs index cd6c09d5..cda3eb7a 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs +++ b/paddler_balancer/src/management_service/http_route/api/ws_agent_socket/mod.rs @@ -1,8 +1,7 @@ mod agent_socket_controller_context; -pub mod jsonrpc; +use parking_lot::RwLock; use std::sync::Arc; -use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; @@ -22,38 +21,38 @@ use anyhow::Result; use async_trait::async_trait; use log::error; use log::info; -use paddler_types::jsonrpc::ResponseEnvelope; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use serde::Deserialize; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use self::agent_socket_controller_context::AgentSocketControllerContext; -use self::jsonrpc::Message as ManagementJsonRpcMessage; -use self::jsonrpc::Notification as ManagementJsonRpcNotification; -use self::jsonrpc::notification_params::RegisterAgentParams; -use self::jsonrpc::notification_params::UpdateAgentStatusParams; -use crate::agent::jsonrpc::Message as AgentJsonRpcMessage; -use crate::agent::jsonrpc::Notification as AgentJsonRpcNotification; -use crate::agent::jsonrpc::Response as AgentJsonRpcResponse; -use crate::agent::jsonrpc::notification_params::VersionParams; -use crate::atomic_value::AtomicValue; -use crate::balancer::agent_controller::AgentController; -use crate::balancer::agent_controller_pool::AgentControllerPool; -use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult; -use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use crate::balancer::management_service::app_data::AppData; -use crate::balancer::manages_senders::ManagesSenders as _; -use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; +use crate::agent_controller::AgentController; +use crate::agent_controller_pool::AgentControllerPool; +use crate::agent_controller_update_result::AgentControllerUpdateResult; use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; use crate::continuation_decision::ContinuationDecision; use crate::continuation_stop_parameters::ContinuationStopParameters; use crate::controls_session::ControlsSession as _; use crate::controls_websocket_endpoint::ControlsWebSocketEndpoint; +use crate::embedding_sender_collection::EmbeddingSenderCollection; +use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use crate::management_service::app_data::AppData; +use crate::manages_senders::ManagesSenders as _; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; use crate::sets_desired_state::SetsDesiredState as _; use crate::websocket_session_controller::WebSocketSessionController; +use paddler_messaging::atomic_value::AtomicValue; +use paddler_messaging::management_socket::agent::message::Message as AgentJsonRpcMessage; +use paddler_messaging::management_socket::agent::notification::Notification as AgentJsonRpcNotification; +use paddler_messaging::management_socket::agent::response::Response as AgentJsonRpcResponse; +use paddler_messaging::management_socket::agent::notification_params::version_params::VersionParams; +use paddler_messaging::management_socket::balancer::message::Message as ManagementJsonRpcMessage; +use paddler_messaging::management_socket::balancer::notification::Notification as ManagementJsonRpcNotification; +use paddler_messaging::management_socket::balancer::notification_params::register_agent_params::RegisterAgentParams; +use paddler_messaging::management_socket::balancer::notification_params::update_agent_status_params::UpdateAgentStatusParams; pub fn register(cfg: &mut ServiceConfig) { cfg.service(respond); @@ -337,3 +336,111 @@ async fn respond( agent_socket_controller.respond(payload, req, app_data.shutdown.clone()) } + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + use std::sync::Arc; + + use actix_web::FromRequest as _; + use actix_web::body::to_bytes; + use actix_web::http::header; + use actix_web::test::TestRequest; + use actix_web::web::Payload; + use tokio_util::sync::CancellationToken; + + use super::AgentSocketController; + use super::AgentSocketControllerContext; + use super::ManagementJsonRpcMessage; + use super::ManagementJsonRpcNotification; + use super::RegisterAgentParams; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::continuation_decision::ContinuationDecision; + use crate::controls_websocket_endpoint::ControlsWebSocketEndpoint as _; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::websocket_session_controller::WebSocketSessionController; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; + + #[actix_web::test] + async fn forwarder_task_stops_when_agent_message_channel_closes() { + log::set_max_level(log::LevelFilter::Trace); + + let agent_id = "agent-forwarder-close".to_owned(); + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let context = Arc::new(AgentSocketControllerContext { + agent_controller_pool: agent_controller_pool.clone(), + agent_id: agent_id.clone(), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + }); + + let (request, mut raw_payload) = TestRequest::get() + .insert_header((header::CONNECTION, "upgrade")) + .insert_header((header::UPGRADE, "websocket")) + .insert_header((header::SEC_WEBSOCKET_VERSION, "13")) + .insert_header((header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")) + .to_http_parts(); + let payload = Payload::from_request(&request, &mut raw_payload) + .await + .unwrap(); + let (response, session, _msg_stream) = actix_ws::handle(&request, payload).unwrap(); + + let continuation_decision = AgentSocketController::handle_deserialized_message( + CancellationToken::new(), + context, + ManagementJsonRpcMessage::Notification(ManagementJsonRpcNotification::RegisterAgent( + RegisterAgentParams { + name: None, + slot_aggregated_status_snapshot: SlotAggregatedStatusSnapshot { + desired_slots_total: 0, + download_current: 0, + download_filename: None, + download_indeterminate: false, + download_total: 0, + issues: BTreeSet::new(), + model_path: None, + slots_processing: 0, + slots_total: 1, + state_application_status: AgentStateApplicationStatus::Fresh, + uses_chat_template_override: false, + version: 0, + }, + }, + )), + WebSocketSessionController::new(session), + ) + .await + .unwrap(); + + assert_eq!( + std::mem::discriminant(&continuation_decision), + std::mem::discriminant(&ContinuationDecision::Continue), + ); + + assert!( + agent_controller_pool + .get_agent_controller(&agent_id) + .is_some() + ); + + assert!( + agent_controller_pool + .remove_agent_controller(&agent_id) + .unwrap() + ); + + let close_frame = to_bytes(response.into_body()).await.unwrap(); + + assert!(close_frame.is_empty()); + } +} diff --git a/paddler/src/balancer/management_service/http_route/get_metrics.rs b/paddler_balancer/src/management_service/http_route/get_metrics.rs similarity index 90% rename from paddler/src/balancer/management_service/http_route/get_metrics.rs rename to paddler_balancer/src/management_service/http_route/get_metrics.rs index 2c50c684..276c0199 100644 --- a/paddler/src/balancer/management_service/http_route/get_metrics.rs +++ b/paddler_balancer/src/management_service/http_route/get_metrics.rs @@ -7,8 +7,8 @@ use actix_web::web::Data; use actix_web::web::ServiceConfig; use indoc::formatdoc; -use crate::balancer::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; -use crate::balancer::management_service::app_data::AppData; +use crate::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; +use crate::management_service::app_data::AppData; pub fn register(cfg: &mut ServiceConfig) { cfg.service(respond); diff --git a/paddler/src/balancer/management_service/http_route/mod.rs b/paddler_balancer/src/management_service/http_route/mod.rs similarity index 100% rename from paddler/src/balancer/management_service/http_route/mod.rs rename to paddler_balancer/src/management_service/http_route/mod.rs diff --git a/paddler_balancer/src/management_service/mod.rs b/paddler_balancer/src/management_service/mod.rs new file mode 100644 index 00000000..7316f005 --- /dev/null +++ b/paddler_balancer/src/management_service/mod.rs @@ -0,0 +1,253 @@ +pub mod app_data; +pub mod configuration; +pub mod http_route; + +use std::sync::Arc; + +use actix_web::App; +use actix_web::HttpServer; +use actix_web::web::Data; +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; + +use crate::agent_controller_pool::AgentControllerPool; +use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use crate::create_cors_middleware::create_cors_middleware; +use crate::embedding_sender_collection::EmbeddingSenderCollection; +use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use crate::http_route as common_http_route; +use crate::management_service::app_data::AppData; +use crate::management_service::configuration::Configuration as ManagementServiceConfiguration; +use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; +use crate::state_database::StateDatabase; +#[cfg(feature = "web_admin_panel")] +use crate::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; + +#[cfg(feature = "web_admin_panel")] +fn collect_web_admin_panel_cors_allowed_hosts( + web_admin_panel_service_configuration: Option<&WebAdminPanelServiceConfiguration>, +) -> Vec { + web_admin_panel_service_configuration + .map(|web_admin_panel_config| format!("http://{}", web_admin_panel_config.addr)) + .into_iter() + .collect() +} + +pub struct ManagementService { + pub agent_controller_pool: Arc, + pub balancer_applicable_state_holder: Arc, + pub buffered_request_manager: Arc, + pub chat_template_override_sender_collection: Arc, + pub configuration: ManagementServiceConfiguration, + pub embedding_sender_collection: Arc, + pub generate_tokens_sender_collection: Arc, + pub model_metadata_sender_collection: Arc, + pub shutdown_options: ServiceShutdownOptions, + pub state_database: Arc, + pub statsd_prefix: String, + #[cfg(feature = "web_admin_panel")] + pub web_admin_panel_service_configuration: Option, +} + +#[async_trait] +impl Service for ManagementService { + fn name(&self) -> &'static str { + "balancer::management_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let web_admin_panel_cors_allowed_hosts: Vec = { + #[cfg(feature = "web_admin_panel")] + { + collect_web_admin_panel_cors_allowed_hosts( + self.web_admin_panel_service_configuration.as_ref(), + ) + } + #[cfg(not(feature = "web_admin_panel"))] + { + Vec::new() + } + }; + + let cors_allowed_hosts_arc = Arc::new( + self.configuration + .cors_allowed_hosts + .iter() + .cloned() + .chain(web_admin_panel_cors_allowed_hosts) + .collect::>(), + ); + + let app_data = Data::new(AppData { + agent_controller_pool: self.agent_controller_pool.clone(), + balancer_applicable_state_holder: self.balancer_applicable_state_holder.clone(), + buffered_request_manager: self.buffered_request_manager.clone(), + chat_template_override_sender_collection: self + .chat_template_override_sender_collection + .clone(), + embedding_sender_collection: self.embedding_sender_collection.clone(), + generate_tokens_sender_collection: self.generate_tokens_sender_collection.clone(), + model_metadata_sender_collection: self.model_metadata_sender_collection.clone(), + shutdown: shutdown.clone(), + state_database: self.state_database.clone(), + statsd_prefix: self.statsd_prefix.clone(), + }); + + let bind_addr = self.configuration.addr; + + let server = HttpServer::new(move || { + App::new() + .wrap(create_cors_middleware(&cors_allowed_hosts_arc)) + .app_data(app_data.clone()) + .configure(common_http_route::get_health::register) + .configure(http_route::api::get_agents::register) + .configure(http_route::api::get_agents_stream::register) + .configure(http_route::api::get_balancer_applicable_state::register) + .configure(http_route::api::get_balancer_desired_state::register) + .configure(http_route::api::get_buffered_requests::register) + .configure(http_route::api::get_buffered_requests_stream::register) + .configure(http_route::api::get_chat_template_override::register) + .configure(http_route::api::get_model_metadata::register) + .configure(http_route::api::put_balancer_desired_state::register) + .configure(http_route::api::ws_agent_socket::register) + .configure(http_route::get_metrics::register) + }) + .shutdown_signal(async move { + shutdown.cancelled().await; + }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) + .disable_signals() + .bind(bind_addr) + .with_context(|| format!("Unable to bind balancer management service to {bind_addr}"))?; + + server.run().await?; + + Ok(()) + } +} + +#[cfg(all(test, feature = "web_admin_panel"))] +mod tests { + use std::net::SocketAddr; + use std::net::TcpListener; + use std::sync::Arc; + use std::time::Duration; + + use tokio::sync::broadcast; + use tokio_util::sync::CancellationToken; + use trzcina::Service as _; + use trzcina::ServiceShutdownOptions; + + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::buffered_request_manager::BufferedRequestManager; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::management_service::configuration::Configuration as ManagementServiceConfiguration; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use crate::resolved_socket_addr::ResolvedSocketAddr; + use crate::state_database::memory::Memory; + use crate::web_admin_panel_service::template_data::TemplateData; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + + use super::ManagementService; + use super::WebAdminPanelServiceConfiguration; + use super::collect_web_admin_panel_cors_allowed_hosts; + + fn build_service(addr: SocketAddr) -> ManagementService { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(1); + + ManagementService { + agent_controller_pool: agent_controller_pool.clone(), + balancer_applicable_state_holder: Arc::new(BalancerApplicableStateHolder::default()), + buffered_request_manager: Arc::new(BufferedRequestManager::new( + agent_controller_pool, + Duration::from_secs(30), + 32, + )), + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + configuration: ManagementServiceConfiguration { + addr, + cors_allowed_hosts: vec!["http://127.0.0.1:8080".to_owned()], + }, + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + shutdown_options: ServiceShutdownOptions::default(), + state_database: Arc::new(Memory::new( + balancer_desired_state_notify_tx, + BalancerDesiredState::default(), + )), + statsd_prefix: "paddler".to_owned(), + web_admin_panel_service_configuration: None, + } + } + + #[expect( + clippy::unwrap_used, + reason = "test fixture helper; allow-unwrap-in-tests is not applied to non-#[test] helpers inside cfg(all(test, feature)) modules" + )] + fn make_resolved_socket_addr(input_addr: &str) -> ResolvedSocketAddr { + ResolvedSocketAddr { + input_addr: input_addr.to_owned(), + socket_addr: input_addr.parse().unwrap(), + } + } + + fn make_web_admin_panel_configuration(addr: SocketAddr) -> WebAdminPanelServiceConfiguration { + WebAdminPanelServiceConfiguration { + addr, + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(1), + compat_openai_addr: None, + inference_addr: make_resolved_socket_addr("127.0.0.1:8081"), + management_addr: make_resolved_socket_addr("127.0.0.1:8082"), + max_buffered_requests: 1, + statsd_addr: None, + statsd_prefix: "paddler".to_owned(), + statsd_reporting_interval: Duration::from_secs(1), + }, + } + } + + #[test] + fn builds_http_origin_from_web_admin_panel_addr() { + let configuration = make_web_admin_panel_configuration("127.0.0.1:9000".parse().unwrap()); + + let allowed_hosts = collect_web_admin_panel_cors_allowed_hosts(Some(&configuration)); + + assert_eq!(allowed_hosts, vec!["http://127.0.0.1:9000".to_owned()]); + } + + #[test] + fn yields_no_hosts_when_web_admin_panel_is_absent() { + let allowed_hosts = collect_web_admin_panel_cors_allowed_hosts(None); + + assert!(allowed_hosts.is_empty()); + } + + #[actix_web::test] + async fn run_returns_error_when_address_is_already_in_use() { + let occupied_listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap(); + let occupied_addr = occupied_listener.local_addr().unwrap(); + + let service = Box::new(build_service(occupied_addr)); + let result = service.run(CancellationToken::new()).await; + + let error_message = result.unwrap_err().to_string(); + let expected_addr_fragment = occupied_addr.to_string(); + + assert!(error_message.contains(&expected_addr_fragment)); + } +} diff --git a/paddler/src/balancer/manages_senders.rs b/paddler_balancer/src/manages_senders.rs similarity index 66% rename from paddler/src/balancer/manages_senders.rs rename to paddler_balancer/src/manages_senders.rs index 8d64a724..029ff567 100644 --- a/paddler/src/balancer/manages_senders.rs +++ b/paddler_balancer/src/manages_senders.rs @@ -60,3 +60,33 @@ pub trait ManagesSenders: Send + Sync { Ok(()) } } + +#[cfg(test)] +mod tests { + use tokio::sync::mpsc; + + use super::ManagesSenders; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + + #[test] + fn register_sender_rejects_duplicate_request_id() { + let sender_collection = EmbeddingSenderCollection::default(); + let request_id = "duplicate-request".to_owned(); + let (first_sender, _first_receiver) = mpsc::unbounded_channel(); + let (second_sender, _second_receiver) = mpsc::unbounded_channel(); + + sender_collection + .register_sender(request_id.clone(), first_sender) + .unwrap(); + + let duplicate_error = sender_collection + .register_sender(request_id, second_sender) + .err() + .unwrap(); + + assert_eq!( + duplicate_error.to_string(), + "Sender for request_id duplicate-request already exists" + ); + } +} diff --git a/paddler_balancer/src/manages_senders_controller.rs b/paddler_balancer/src/manages_senders_controller.rs new file mode 100644 index 00000000..f571b767 --- /dev/null +++ b/paddler_balancer/src/manages_senders_controller.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use anyhow::Result; +use log::error; +use tokio::sync::mpsc; + +use crate::manages_senders::ManagesSenders; + +pub struct ManagesSendersController +where + TSenderCollection: ManagesSenders, +{ + pub request_id: String, + pub response_rx: mpsc::UnboundedReceiver, + pub response_sender_collection: Arc, +} + +impl ManagesSendersController +where + TSenderCollection: ManagesSenders, +{ + pub fn from_request_id( + request_id: String, + response_sender_collection: Arc, + ) -> Result { + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + response_sender_collection.register_sender(request_id.clone(), response_tx)?; + + Ok(Self { + request_id, + response_rx, + response_sender_collection, + }) + } +} + +impl Drop for ManagesSendersController +where + TSenderCollection: ManagesSenders, +{ + fn drop(&mut self) { + self.response_sender_collection + .deregister_sender(self.request_id.clone()) + .unwrap_or_else(|err| { + error!( + "Failed to deregister sender for request_id {}: {err}", + self.request_id + ); + }); + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::ManagesSendersController; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::manages_senders::ManagesSenders; + + #[test] + fn registers_sender_on_construction() { + let response_sender_collection = Arc::new(EmbeddingSenderCollection::default()); + + let controller = ManagesSendersController::from_request_id( + "request-1".to_owned(), + response_sender_collection.clone(), + ) + .unwrap(); + + assert!( + response_sender_collection + .get_sender_collection() + .contains_key("request-1") + ); + + drop(controller); + } + + #[test] + fn returns_error_when_sender_already_registered() { + let response_sender_collection = Arc::new(EmbeddingSenderCollection::default()); + + let _first_controller = ManagesSendersController::from_request_id( + "request-1".to_owned(), + response_sender_collection.clone(), + ) + .unwrap(); + + let result = ManagesSendersController::from_request_id( + "request-1".to_owned(), + response_sender_collection, + ); + + let error = result.err().unwrap(); + + assert_eq!( + error.to_string(), + "Sender for request_id request-1 already exists" + ); + } +} diff --git a/paddler/src/balancer/model_metadata_sender_collection.rs b/paddler_balancer/src/model_metadata_sender_collection.rs similarity index 84% rename from paddler/src/balancer/model_metadata_sender_collection.rs rename to paddler_balancer/src/model_metadata_sender_collection.rs index 67a33399..b130cd60 100644 --- a/paddler/src/balancer/model_metadata_sender_collection.rs +++ b/paddler_balancer/src/model_metadata_sender_collection.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use dashmap::DashMap; -use paddler_types::model_metadata::ModelMetadata; +use paddler_messaging::model_metadata::ModelMetadata; use tokio::sync::mpsc; -use crate::balancer::manages_senders::ManagesSenders; +use crate::manages_senders::ManagesSenders; pub struct ModelMetadataSenderCollection { senders: DashMap>>, diff --git a/paddler_balancer/src/reconciliation_service.rs b/paddler_balancer/src/reconciliation_service.rs new file mode 100644 index 00000000..41bff7fb --- /dev/null +++ b/paddler_balancer/src/reconciliation_service.rs @@ -0,0 +1,248 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use log::error; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use tokio::sync::broadcast; +use tokio::time::Duration; +use tokio::time::MissedTickBehavior; +use tokio::time::interval; +use tokio_util::sync::CancellationToken; +use trzcina::Service; + +use crate::agent_controller_pool::AgentControllerPool; +use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use crate::balancer_desired_state_converter::BalancerDesiredStateConverter; +use crate::sets_desired_state::SetsDesiredState as _; +use paddler_state_conversion::converts_to_applicable_state::ConvertsToApplicableState as _; + +async fn convert_to_applicable_state( + balancer_desired_state: &BalancerDesiredState, + agent_controller_pool: &AgentControllerPool, + balancer_applicable_state_holder: &BalancerApplicableStateHolder, + is_converted_to_applicable_state: &mut bool, +) -> Result<()> { + let balancer_applicable_state = BalancerDesiredStateConverter + .to_applicable_state(balancer_desired_state.clone()) + .await?; + + agent_controller_pool + .set_desired_state(balancer_applicable_state.agent_desired_state.clone()) + .await?; + balancer_applicable_state_holder.set_balancer_applicable_state(Some(balancer_applicable_state)); + + *is_converted_to_applicable_state = true; + + Ok(()) +} + +async fn try_convert_to_applicable_state( + balancer_desired_state: &BalancerDesiredState, + agent_controller_pool: &AgentControllerPool, + balancer_applicable_state_holder: &BalancerApplicableStateHolder, + is_converted_to_applicable_state: &mut bool, +) { + if let Err(err) = convert_to_applicable_state( + balancer_desired_state, + agent_controller_pool, + balancer_applicable_state_holder, + is_converted_to_applicable_state, + ) + .await + { + error!("Failed to convert to applicable state: {err}"); + } +} + +pub struct ReconciliationService { + pub agent_controller_pool: Arc, + pub balancer_applicable_state_holder: Arc, + pub balancer_desired_state: BalancerDesiredState, + pub balancer_desired_state_rx: broadcast::Receiver, + pub is_converted_to_applicable_state: bool, +} + +#[async_trait] +impl Service for ReconciliationService { + fn name(&self) -> &'static str { + "balancer::reconciliation_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let Self { + agent_controller_pool, + balancer_applicable_state_holder, + mut balancer_desired_state, + mut balancer_desired_state_rx, + mut is_converted_to_applicable_state, + } = *self; + + let mut ticker = interval(Duration::from_secs(1)); + + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + () = shutdown.cancelled() => break Ok(()), + _ = ticker.tick() => { + if !is_converted_to_applicable_state { + try_convert_to_applicable_state( + &balancer_desired_state, + &agent_controller_pool, + &balancer_applicable_state_holder, + &mut is_converted_to_applicable_state, + ).await; + } + }, + received_balancer_desired_state = balancer_desired_state_rx.recv() => { + is_converted_to_applicable_state = false; + balancer_desired_state = received_balancer_desired_state?; + try_convert_to_applicable_state( + &balancer_desired_state, + &agent_controller_pool, + &balancer_applicable_state_holder, + &mut is_converted_to_applicable_state, + ).await; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::convert_to_applicable_state; + use super::try_convert_to_applicable_state; + use crate::agent_controller::AgentController; + use crate::agent_controller_pool::AgentControllerPool; + use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use crate::balancer_desired_state_converter::BalancerDesiredStateConverter; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + use paddler_state_conversion::converts_to_desired_state::ConvertsToDesiredState as _; + + fn agent_controller_with_dropped_receiver() -> Arc { + let (agent_message_tx, agent_message_rx) = mpsc::unbounded_channel(); + + drop(agent_message_rx); + + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: "agent-test".to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(0), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }) + } + + #[tokio::test] + async fn convert_to_applicable_state_sets_flag_and_stores_state_for_empty_pool() { + let balancer_desired_state = BalancerDesiredState::default(); + let agent_controller_pool = AgentControllerPool::default(); + let balancer_applicable_state_holder = BalancerApplicableStateHolder::default(); + let mut is_converted_to_applicable_state = false; + + convert_to_applicable_state( + &balancer_desired_state, + &agent_controller_pool, + &balancer_applicable_state_holder, + &mut is_converted_to_applicable_state, + ) + .await + .unwrap(); + + assert!(is_converted_to_applicable_state); + assert_eq!( + balancer_applicable_state_holder.get_agent_desired_state(), + Some(BalancerDesiredStateConverter.to_desired_state(balancer_desired_state.clone())) + ); + } + + #[tokio::test] + async fn convert_to_applicable_state_errors_when_agent_message_receiver_dropped() { + let balancer_desired_state = BalancerDesiredState::default(); + let agent_controller_pool = AgentControllerPool::default(); + + agent_controller_pool + .register_agent_controller( + "agent-test".to_owned(), + agent_controller_with_dropped_receiver(), + ) + .unwrap(); + + let balancer_applicable_state_holder = BalancerApplicableStateHolder::default(); + let mut is_converted_to_applicable_state = false; + + let result = convert_to_applicable_state( + &balancer_desired_state, + &agent_controller_pool, + &balancer_applicable_state_holder, + &mut is_converted_to_applicable_state, + ) + .await; + + assert!(result.err().is_some()); + assert!(!is_converted_to_applicable_state); + } + + #[tokio::test] + async fn try_convert_to_applicable_state_keeps_flag_false_when_agent_send_fails() { + let balancer_desired_state = BalancerDesiredState::default(); + let agent_controller_pool = AgentControllerPool::default(); + + agent_controller_pool + .register_agent_controller( + "agent-test".to_owned(), + agent_controller_with_dropped_receiver(), + ) + .unwrap(); + + let balancer_applicable_state_holder = BalancerApplicableStateHolder::default(); + let mut is_converted_to_applicable_state = false; + + try_convert_to_applicable_state( + &balancer_desired_state, + &agent_controller_pool, + &balancer_applicable_state_holder, + &mut is_converted_to_applicable_state, + ) + .await; + + assert!(!is_converted_to_applicable_state); + } +} diff --git a/paddler_balancer/src/request_from_agent.rs b/paddler_balancer/src/request_from_agent.rs new file mode 100644 index 00000000..f836d780 --- /dev/null +++ b/paddler_balancer/src/request_from_agent.rs @@ -0,0 +1,793 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use log::debug; +use log::error; +use log::warn; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::jsonrpc::error::Error as JsonRpcError; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; +use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; +use paddler_messaging::streamable_result::StreamableResult; +use tokio::time::sleep; +use tokio_util::sync::CancellationToken; + +use crate::agent_controller::AgentController; +use crate::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::controls_session::ControlsSession; +use crate::dispatched_agent::DispatchedAgent; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::manages_senders::ManagesSenders; +use crate::manages_senders_controller::ManagesSendersController; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; + +pub async fn request_from_agent( + buffered_request_manager: Arc, + connection_close: CancellationToken, + inference_service_configuration: InferenceServiceConfiguration, + params: TParams, + request_id: String, + mut session_controller: TControlsSession, + shutdown: CancellationToken, +) +where + TControlsSession: ControlsSession, + TParams: Debug + Into + Send, + AgentController: HandlesAgentStreamingResponse, + <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, +{ + let Some(dispatched_agent) = wait_for_agent_controller( + buffered_request_manager.clone(), + connection_close.clone(), + request_id.clone(), + &mut session_controller, + shutdown.clone(), + ) + .await + else { + return; + }; + + let receive_response_controller = match dispatched_agent + .agent_controller + .handle_streaming_response(request_id.clone(), params) + .await + { + Ok(receive_response_controller) => receive_response_controller, + Err(err) => { + error!("Failed to handle request {request_id:?}: {err}"); + + respond_with_error( + JsonRpcError { + code: 500, + description: "Failed to generate response".to_owned(), + }, + request_id.clone(), + &mut session_controller, + ) + .await; + + return; + } + }; + + forward_responses_stream( + dispatched_agent.agent_controller.clone(), + connection_close, + inference_service_configuration, + receive_response_controller, + request_id, + session_controller, + shutdown, + ) + .await; +} + +pub async fn forward_responses_stream( + agent_controller: Arc, + connection_close: CancellationToken, + inference_service_configuration: InferenceServiceConfiguration, + mut receive_response_controller: ManagesSendersController, + request_id: String, + mut session_controller: TControlsSession, + shutdown: CancellationToken, +) where + TControlsSession: ControlsSession, + TManagesSenders: ManagesSenders + Send + Sync, + TManagesSenders::Value: Debug + Into + Send + StreamableResult, +{ + debug!("Found available agent controller for request: {request_id:?}"); + + let agent_connection_close = agent_controller.connection_close.clone(); + + loop { + tokio::select! { + () = agent_connection_close.cancelled() => { + error!("Agent controller connection closed"); + + respond_with_error( + JsonRpcError { + code: 502, + description: "Agent controller connection closed".to_owned(), + }, + request_id, + &mut session_controller, + ).await; + + break; + } + () = connection_close.cancelled() => { + agent_controller.stop_responding_to(request_id.clone()).await.unwrap_or_else(|err| { + error!("Failed to stop request {request_id:?}: {err}"); + }); + + break; + } + () = shutdown.cancelled() => { + respond_with_error( + JsonRpcError { + code: 503, + description: "balancer is shutting down".to_owned(), + }, + request_id.clone(), + &mut session_controller, + ).await; + + agent_controller.stop_responding_to(request_id.clone()).await.unwrap_or_else(|err| { + error!("Failed to stop request {request_id:?}: {err}"); + }); + + break; + } + () = sleep(inference_service_configuration.inference_item_timeout) => { + let timeout_ms = inference_service_configuration.inference_item_timeout.as_millis(); + + warn!( + "Timed out after {timeout_ms}ms waiting for next token for request {request_id:?}. \ + Consider increasing --inference-item-timeout if the model needs more time to process the prompt." + ); + + respond_with_error( + JsonRpcError { + code: 504, + description: format!( + "Inference timed out after {timeout_ms}ms waiting for next token. \ + Increase --inference-item-timeout if the prompt requires longer processing." + ), + }, + request_id.clone(), + &mut session_controller, + ).await; + + agent_controller.stop_responding_to(request_id.clone()).await.unwrap_or_else(|err| { + error!("Failed to stop responding to request {request_id:?}: {err}"); + }); + + break; + } + response = receive_response_controller.response_rx.recv() => { + if let Some(response) = response { + let is_done = response.is_done(); + + let send_succeeded = send_response_to_client( + agent_controller.clone(), + response, + request_id.clone(), + &mut session_controller, + ).await; + + if is_done || !send_succeeded { + break; + } + } else { + error!( + "Response channel closed before terminator for request {request_id:?}" + ); + + respond_with_error( + JsonRpcError { + code: 502, + description: + "Response channel closed before terminator".to_owned(), + }, + request_id, + &mut session_controller, + ).await; + + break; + } + } + } + } +} + +pub async fn respond_with_error( + error: JsonRpcError, + request_id: String, + session_controller: &mut TControlsSession, +) where + TControlsSession: ControlsSession, +{ + session_controller + .send_response(OutgoingMessage::Error(ErrorEnvelope { + request_id: request_id.clone(), + error, + })) + .await + .unwrap_or_else(|err| { + error!("Failed to send response for request {request_id:?}: {err}"); + }); +} + +async fn send_response_to_client( + agent_controller: Arc, + response: TResponse, + request_id: String, + session_controller: &mut TControlsSession, +) -> bool +where + TControlsSession: ControlsSession, + TResponse: Into + Send, +{ + if let Err(err) = session_controller + .send_response(OutgoingMessage::Response(ResponseEnvelope { + generated_by: agent_controller.name.clone(), + request_id: request_id.clone(), + response: response.into(), + })) + .await + { + error!("Failed to send response for request {request_id:?}: {err}"); + + agent_controller + .stop_responding_to(request_id.clone()) + .await + .unwrap_or_else(|err| { + error!("Failed to stop responding to request {request_id:?}: {err}"); + }); + + return false; + } + + true +} + +async fn wait_for_agent_controller( + buffered_request_manager: Arc, + connection_close: CancellationToken, + request_id: String, + session_controller: &mut TControlsSession, + shutdown: CancellationToken, +) -> Option +where + TControlsSession: ControlsSession, +{ + let buffered_request_manager = buffered_request_manager.clone(); + + tokio::select! { + () = connection_close.cancelled() => { + debug!("Connection close signal received, stopping GenerateTokens loop."); + + None + }, + () = shutdown.cancelled() => { + respond_with_error( + JsonRpcError { + code: 503, + description: "balancer is shutting down".to_owned(), + }, + request_id.clone(), + session_controller, + ).await; + + None + }, + buffered_request_agent_wait_result = buffered_request_manager.wait_for_available_agent() => { + match buffered_request_agent_wait_result { + Ok(BufferedRequestAgentWaitResult::Found(dispatched_agent)) => Some(dispatched_agent), + Ok(BufferedRequestAgentWaitResult::BufferOverflow) => { + warn!("Too many buffered requests, dropping request: {request_id:?}"); + + respond_with_error( + JsonRpcError { + code: 503, + description: "Buffered requests overflow".to_owned(), + }, + request_id.clone(), + session_controller, + ).await; + + None + } + Ok(BufferedRequestAgentWaitResult::Timeout(err)) => { + warn!("Buffered request {request_id:?} timed out: {err:?}"); + + respond_with_error( + JsonRpcError { + code: 504, + description: "Waiting for available slot timed out".to_owned(), + }, + request_id.clone(), + session_controller, + ).await; + + None + } + Err(err) => { + error!("Error while waiting for available agent controller for GenerateTokens request: {err}"); + + respond_with_error( + JsonRpcError { + code: 500, + description: "Internal server error".to_owned(), + }, + request_id.clone(), + session_controller, + ).await; + + None + } + } + } + } +} + +#[cfg(test)] +mod tests { + use parking_lot::RwLock; + use std::collections::BTreeSet; + use std::mem::discriminant; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use tokio::sync::mpsc; + + use super::*; + use crate::agent_controller_pool::AgentControllerPool; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::chunk_forwarding_session_controller::ChunkForwardingSessionController; + use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::management_socket::agent::message::Message as AgentJsonRpcMessage; + use paddler_messaging::management_socket::agent::notification::Notification as AgentJsonRpcNotification; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + struct AgentControllerWithIncomingChannel { + agent_controller: Arc, + agent_message_rx: mpsc::UnboundedReceiver, + } + + fn agent_controller_with_one_free_slot(id: &str) -> AgentControllerWithIncomingChannel { + let (agent_message_tx, agent_message_rx) = mpsc::unbounded_channel(); + + let agent_controller = Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(1), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), + id: id.to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new(ModelMetadataSenderCollection::default()), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(0), + slots_total: AtomicValue::::new(1), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }); + + AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx, + } + } + + fn raw_prompt_params() -> ContinueFromRawPromptParams { + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "fixture prompt".to_owned(), + } + } + + fn inference_service_configuration_with_long_timeout() -> InferenceServiceConfiguration { + const TIMEOUT_LONGER_THAN_ANY_TEST_RUN: Duration = Duration::from_hours(1); + + InferenceServiceConfiguration { + addr: "127.0.0.1:0".parse().unwrap(), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: TIMEOUT_LONGER_THAN_ANY_TEST_RUN, + } + } + + #[tokio::test] + async fn request_from_agent_forwards_error_when_agent_connection_closes() { + let pool = Arc::new(AgentControllerPool::default()); + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx: _agent_message_rx, + } = agent_controller_with_one_free_slot("agent-close"); + + agent_controller.connection_close.cancel(); + + pool.register_agent_controller("agent-close".to_owned(), agent_controller) + .unwrap(); + + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(1), + 10, + )); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + request_from_agent( + buffered_request_manager, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + raw_prompt_params(), + "request-close".to_owned(), + session_controller, + CancellationToken::new(), + ) + .await; + + let forwarded = chunk_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&forwarded), + discriminant(&TransformResult::Chunk(String::new())) + ); + } + + #[tokio::test] + async fn request_from_agent_responds_with_error_when_streaming_setup_fails() { + let pool = Arc::new(AgentControllerPool::default()); + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx, + } = agent_controller_with_one_free_slot("agent-setup-fail"); + + drop(agent_message_rx); + + pool.register_agent_controller("agent-setup-fail".to_owned(), agent_controller) + .unwrap(); + + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(1), + 10, + )); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + request_from_agent( + buffered_request_manager, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + raw_prompt_params(), + "request-setup-fail".to_owned(), + session_controller, + CancellationToken::new(), + ) + .await; + + let forwarded = chunk_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&forwarded), + discriminant(&TransformResult::Chunk(String::new())) + ); + } + + #[tokio::test] + async fn forward_responses_stream_stops_responding_when_client_send_fails() { + let AgentControllerWithIncomingChannel { + agent_controller, + mut agent_message_rx, + } = agent_controller_with_one_free_slot("agent-send-fail"); + + let request_id = "request-send-fail".to_owned(); + let sender_collection = agent_controller.generate_tokens_sender_collection.clone(); + let receive_response_controller = ManagesSendersController::from_request_id( + request_id.clone(), + sender_collection.clone(), + ) + .unwrap(); + + sender_collection + .forward_response( + request_id.clone(), + GeneratedTokenResult::ContentToken("token".to_owned()), + ) + .await + .unwrap(); + + let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); + + drop(chunk_rx); + + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + forward_responses_stream( + agent_controller, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + receive_response_controller, + request_id, + session_controller, + CancellationToken::new(), + ) + .await; + + let stop_message = agent_message_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&stop_message), + discriminant(&AgentJsonRpcMessage::Notification( + AgentJsonRpcNotification::StopRespondingTo(String::new()) + )) + ); + } + + #[tokio::test] + async fn request_from_agent_responds_with_error_on_shutdown_while_waiting() { + let pool = Arc::new(AgentControllerPool::default()); + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(1), + 10, + )); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let shutdown = CancellationToken::new(); + + shutdown.cancel(); + + request_from_agent( + buffered_request_manager, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + raw_prompt_params(), + "request-shutdown".to_owned(), + session_controller, + shutdown, + ) + .await; + + let forwarded = chunk_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&forwarded), + discriminant(&TransformResult::Chunk(String::new())) + ); + } + + #[tokio::test] + async fn request_from_agent_responds_with_error_on_buffer_overflow() { + let pool = Arc::new(AgentControllerPool::default()); + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx: _agent_message_rx, + } = agent_controller_with_one_free_slot("agent-overflow"); + + agent_controller.slots_processing.set(1); + + pool.register_agent_controller("agent-overflow".to_owned(), agent_controller) + .unwrap(); + + let buffered_request_manager = + Arc::new(BufferedRequestManager::new(pool, Duration::from_secs(1), 0)); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + request_from_agent( + buffered_request_manager, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + raw_prompt_params(), + "request-overflow".to_owned(), + session_controller, + CancellationToken::new(), + ) + .await; + + let forwarded = chunk_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&forwarded), + discriminant(&TransformResult::Chunk(String::new())) + ); + } + + #[tokio::test] + async fn request_from_agent_stops_waiting_when_connection_closes() { + let pool = Arc::new(AgentControllerPool::default()); + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(1), + 10, + )); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let connection_close = CancellationToken::new(); + + connection_close.cancel(); + + request_from_agent( + buffered_request_manager, + connection_close, + inference_service_configuration_with_long_timeout(), + raw_prompt_params(), + "request-connection-close".to_owned(), + session_controller, + CancellationToken::new(), + ) + .await; + + assert!(chunk_rx.recv().await.is_none()); + } + + #[tokio::test] + async fn forward_responses_stream_stops_agent_when_client_connection_closes() { + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx, + } = agent_controller_with_one_free_slot("agent-client-close"); + + drop(agent_message_rx); + + let request_id = "request-client-close".to_owned(); + let receive_response_controller = ManagesSendersController::from_request_id( + request_id.clone(), + agent_controller.generate_tokens_sender_collection.clone(), + ) + .unwrap(); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let connection_close = CancellationToken::new(); + + connection_close.cancel(); + + forward_responses_stream( + agent_controller, + connection_close, + inference_service_configuration_with_long_timeout(), + receive_response_controller, + request_id, + session_controller, + CancellationToken::new(), + ) + .await; + + assert!(chunk_rx.recv().await.is_none()); + } + + #[tokio::test] + async fn forward_responses_stream_responds_with_error_on_shutdown() { + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx, + } = agent_controller_with_one_free_slot("agent-shutdown"); + + drop(agent_message_rx); + + let request_id = "request-stream-shutdown".to_owned(); + let receive_response_controller = ManagesSendersController::from_request_id( + request_id.clone(), + agent_controller.generate_tokens_sender_collection.clone(), + ) + .unwrap(); + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + let session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let shutdown = CancellationToken::new(); + + shutdown.cancel(); + + forward_responses_stream( + agent_controller, + CancellationToken::new(), + inference_service_configuration_with_long_timeout(), + receive_response_controller, + request_id, + session_controller, + shutdown, + ) + .await; + + let forwarded = chunk_rx.recv().await.unwrap(); + + assert_eq!( + discriminant(&forwarded), + discriminant(&TransformResult::Chunk(String::new())) + ); + } + + #[tokio::test] + async fn respond_with_error_logs_when_client_send_fails() { + let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); + + drop(chunk_rx); + + let mut session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + respond_with_error( + JsonRpcError { + code: 500, + description: "send fails".to_owned(), + }, + "request-send-error".to_owned(), + &mut session_controller, + ) + .await; + } + + #[tokio::test] + async fn send_response_to_client_returns_false_and_logs_when_stop_fails() { + let AgentControllerWithIncomingChannel { + agent_controller, + agent_message_rx, + } = agent_controller_with_one_free_slot("agent-stop-fails"); + + drop(agent_message_rx); + + let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); + + drop(chunk_rx); + + let mut session_controller = + ChunkForwardingSessionController::new(chunk_tx, IdentityTransformer::new()); + + let send_succeeded = send_response_to_client( + agent_controller, + GeneratedTokenResult::ContentToken("token".to_owned()), + "request-stop-fails".to_owned(), + &mut session_controller, + ) + .await; + + assert!(!send_succeeded); + } +} diff --git a/paddler/src/resolved_socket_addr.rs b/paddler_balancer/src/resolved_socket_addr.rs similarity index 100% rename from paddler/src/resolved_socket_addr.rs rename to paddler_balancer/src/resolved_socket_addr.rs diff --git a/paddler/src/balancer/response/mod.rs b/paddler_balancer/src/response/mod.rs similarity index 50% rename from paddler/src/balancer/response/mod.rs rename to paddler_balancer/src/response/mod.rs index b1d22db1..62904c7b 100644 --- a/paddler/src/balancer/response/mod.rs +++ b/paddler_balancer/src/response/mod.rs @@ -1,4 +1,2 @@ -mod view; +pub mod view; mod view_from_http_response_builder; - -pub use self::view::view; diff --git a/paddler_balancer/src/response/view.rs b/paddler_balancer/src/response/view.rs new file mode 100644 index 00000000..a2670e13 --- /dev/null +++ b/paddler_balancer/src/response/view.rs @@ -0,0 +1,51 @@ +use actix_web::HttpResponse; +use actix_web::Result; +use askama::Template; + +use super::view_from_http_response_builder::view_from_http_response_builder; + +pub fn view(template: TTemplate) -> Result { + view_from_http_response_builder(HttpResponse::Ok(), template) +} + +#[cfg(test)] +mod tests { + use actix_web::http::StatusCode; + use actix_web::http::header::CONTENT_TYPE; + use askama::Template; + + use super::view; + + #[derive(Template)] + #[template(ext = "html", source = "

{{ greeting }}

")] + struct GreetingTemplate { + greeting: String, + } + + #[test] + fn responds_with_ok_status() { + let response = view(GreetingTemplate { + greeting: "hello".to_owned(), + }) + .unwrap(); + + assert_eq!(StatusCode::OK, response.status()); + } + + #[test] + fn responds_with_html_content_type() { + let response = view(GreetingTemplate { + greeting: "hello".to_owned(), + }) + .unwrap(); + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!("text/html; charset=utf-8", content_type); + } +} diff --git a/paddler_balancer/src/response/view_from_http_response_builder.rs b/paddler_balancer/src/response/view_from_http_response_builder.rs new file mode 100644 index 00000000..6b49a598 --- /dev/null +++ b/paddler_balancer/src/response/view_from_http_response_builder.rs @@ -0,0 +1,186 @@ +use actix_web::HttpResponse; +use actix_web::HttpResponseBuilder; +use actix_web::Result; +use actix_web::error::ErrorInternalServerError; +use askama::Template; + +pub fn view_from_http_response_builder( + mut http_response_builder: HttpResponseBuilder, + template: TTemplate, +) -> Result { + let rendered = template.render().map_err(ErrorInternalServerError)?; + + Ok(http_response_builder + .content_type("text/html; charset=utf-8") + .body(rendered)) +} + +#[cfg(test)] +mod tests { + use std::fmt; + use std::mem; + + use actix_web::HttpResponse; + use actix_web::http::StatusCode; + use actix_web::http::header::CONTENT_TYPE; + use askama::Error as AskamaError; + use askama::FastWritable; + use askama::Template; + use askama::Values; + + use super::view_from_http_response_builder; + + struct FailingWriter; + + impl fmt::Write for FailingWriter { + fn write_str(&mut self, _content: &str) -> fmt::Result { + Err(fmt::Error) + } + } + + struct RenderingTemplate; + + impl fmt::Display for RenderingTemplate { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + self.render_into(formatter) + .map_err(|_askama_error| fmt::Error) + } + } + + impl FastWritable for RenderingTemplate { + fn write_into( + &self, + destination: &mut TWriter, + values: &dyn Values, + ) -> askama::Result<()> { + self.render_into_with_values(destination, values) + } + } + + impl Template for RenderingTemplate { + const SIZE_HINT: usize = 0; + + fn render_into_with_values( + &self, + writer: &mut TWriter, + _values: &dyn Values, + ) -> askama::Result<()> { + writer.write_str("

rendered

")?; + + Ok(()) + } + } + + struct FailingTemplate; + + impl fmt::Display for FailingTemplate { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + self.render_into(formatter) + .map_err(|_askama_error| fmt::Error) + } + } + + impl FastWritable for FailingTemplate { + fn write_into( + &self, + destination: &mut TWriter, + values: &dyn Values, + ) -> askama::Result<()> { + self.render_into_with_values(destination, values) + } + } + + impl Template for FailingTemplate { + const SIZE_HINT: usize = 0; + + fn render_into_with_values( + &self, + _writer: &mut TWriter, + _values: &dyn Values, + ) -> askama::Result<()> { + Err(AskamaError::ValueMissing) + } + } + + #[test] + fn renders_template_into_ok_html_response() { + let http_response = + view_from_http_response_builder(HttpResponse::Ok(), RenderingTemplate).unwrap(); + + assert_eq!(http_response.status(), StatusCode::OK); + assert_eq!( + http_response + .headers() + .get(CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "text/html; charset=utf-8", + ); + } + + #[test] + fn maps_render_failure_to_internal_server_error() { + let render_error = + view_from_http_response_builder(HttpResponse::Ok(), FailingTemplate).unwrap_err(); + + assert_eq!( + render_error.as_response_error().status_code(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + } + + #[test] + fn displays_rendering_template_markup() { + assert_eq!(RenderingTemplate.to_string(), "

rendered

"); + } + + #[test] + fn displays_failing_template_as_fmt_error() { + use std::fmt::Write as _; + + let mut destination = String::new(); + let display_error = write!(destination, "{FailingTemplate}").err().unwrap(); + + assert_eq!(display_error, fmt::Error); + } + + #[test] + fn rendering_template_fast_writable_writes_markup() { + let mut destination = String::new(); + + FastWritable::write_into(&RenderingTemplate, &mut destination, askama::NO_VALUES).unwrap(); + + assert_eq!(destination, "

rendered

"); + } + + #[test] + fn failing_template_fast_writable_propagates_value_missing() { + let mut destination = String::new(); + + let write_error = + FastWritable::write_into(&FailingTemplate, &mut destination, askama::NO_VALUES) + .err() + .unwrap(); + + assert_eq!( + mem::discriminant(&write_error), + mem::discriminant(&AskamaError::ValueMissing), + ); + } + + #[test] + fn rendering_template_maps_writer_failure_to_fmt_error() { + let mut failing_writer = FailingWriter; + + let write_error = RenderingTemplate + .render_into_with_values(&mut failing_writer, askama::NO_VALUES) + .err() + .unwrap(); + + assert_eq!( + mem::discriminant(&write_error), + mem::discriminant(&AskamaError::Fmt), + ); + } +} diff --git a/paddler/src/sends_rpc_message.rs b/paddler_balancer/src/sends_rpc_message.rs similarity index 80% rename from paddler/src/sends_rpc_message.rs rename to paddler_balancer/src/sends_rpc_message.rs index c6e0cfc4..f461e8d4 100644 --- a/paddler/src/sends_rpc_message.rs +++ b/paddler_balancer/src/sends_rpc_message.rs @@ -1,6 +1,6 @@ use anyhow::Result; use async_trait::async_trait; -use paddler_types::rpc_message::RpcMessage; +use paddler_messaging::rpc_message::RpcMessage; #[async_trait] pub trait SendsRpcMessage { diff --git a/paddler/src/sets_desired_state.rs b/paddler_balancer/src/sets_desired_state.rs similarity index 74% rename from paddler/src/sets_desired_state.rs rename to paddler_balancer/src/sets_desired_state.rs index 8d182a39..c6a2a94f 100644 --- a/paddler/src/sets_desired_state.rs +++ b/paddler_balancer/src/sets_desired_state.rs @@ -1,6 +1,6 @@ use anyhow::Result; use async_trait::async_trait; -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_messaging::agent_desired_state::AgentDesiredState; #[async_trait] pub trait SetsDesiredState { diff --git a/paddler/src/snapshots_stream.rs b/paddler_balancer/src/snapshots_stream.rs similarity index 68% rename from paddler/src/snapshots_stream.rs rename to paddler_balancer/src/snapshots_stream.rs index 13704fd6..aee5a80d 100644 --- a/paddler/src/snapshots_stream.rs +++ b/paddler_balancer/src/snapshots_stream.rs @@ -5,8 +5,8 @@ use futures::Stream; use log::error; use tokio_util::sync::CancellationToken; -use crate::produces_snapshot::ProducesSnapshot; -use crate::subscribes_to_updates::SubscribesToUpdates; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates; pub fn snapshots_stream( producer: Arc, @@ -50,6 +50,8 @@ mod tests { use super::*; + const SNAPSHOT_TIMEOUT: Duration = Duration::from_secs(1); + struct CounterProducer { update_tx: watch::Sender<()>, value: AtomicI32, @@ -86,63 +88,49 @@ mod tests { } #[tokio::test] - async fn snapshots_stream_emits_initial_snapshot() -> Result<()> { + async fn snapshots_stream_emits_initial_snapshot() { let producer = Arc::new(CounterProducer::new()); let shutdown = CancellationToken::new(); let mut stream = Box::pin(snapshots_stream(producer.clone(), shutdown.clone())); - let first = timeout(Duration::from_secs(1), stream.next()) + let first = timeout(SNAPSHOT_TIMEOUT, stream.next()) .await - .map_err(|err| anyhow::anyhow!("initial snapshot did not arrive: {err}"))? - .ok_or_else(|| anyhow::anyhow!("stream ended before yielding initial snapshot"))?; + .unwrap() + .unwrap(); assert_eq!(first, 0); - - Ok(()) } #[tokio::test] - async fn snapshots_stream_emits_after_subscribed_signal() -> Result<()> { + async fn snapshots_stream_emits_after_subscribed_signal() { let producer = Arc::new(CounterProducer::new()); let shutdown = CancellationToken::new(); let mut stream = Box::pin(snapshots_stream(producer.clone(), shutdown.clone())); - stream - .next() - .await - .ok_or_else(|| anyhow::anyhow!("stream ended before initial snapshot"))?; + stream.next().await.unwrap(); producer.bump(); - let next = timeout(Duration::from_secs(1), stream.next()) + let next = timeout(SNAPSHOT_TIMEOUT, stream.next()) .await - .map_err(|err| anyhow::anyhow!("snapshot after signal did not arrive: {err}"))? - .ok_or_else(|| anyhow::anyhow!("stream ended before yielding bumped snapshot"))?; + .unwrap() + .unwrap(); assert_eq!(next, 1); - - Ok(()) } #[tokio::test] - async fn snapshots_stream_terminates_on_shutdown() -> Result<()> { + async fn snapshots_stream_terminates_on_shutdown() { let producer = Arc::new(CounterProducer::new()); let shutdown = CancellationToken::new(); let mut stream = Box::pin(snapshots_stream(producer.clone(), shutdown.clone())); - stream - .next() - .await - .ok_or_else(|| anyhow::anyhow!("stream ended before initial snapshot"))?; + stream.next().await.unwrap(); shutdown.cancel(); - let terminated = timeout(Duration::from_secs(1), stream.next()) - .await - .map_err(|err| anyhow::anyhow!("stream did not close after shutdown: {err}"))?; + let terminated = timeout(SNAPSHOT_TIMEOUT, stream.next()).await.unwrap(); assert!(terminated.is_none()); - - Ok(()) } } diff --git a/paddler_balancer/src/state_database/file/mod.rs b/paddler_balancer/src/state_database/file/mod.rs new file mode 100644 index 00000000..132d46a8 --- /dev/null +++ b/paddler_balancer/src/state_database/file/mod.rs @@ -0,0 +1,303 @@ +mod schema; + +use std::path::PathBuf; + +use anyhow::Context; +use anyhow::Result; +use async_trait::async_trait; +use log::warn; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use tokio::fs; +use tokio::io::AsyncWriteExt; +use tokio::sync::RwLock; +use tokio::sync::broadcast; + +use self::schema::Schema; +use super::StateDatabase; + +pub struct File { + balancer_desired_state_notify_tx: broadcast::Sender, + path: PathBuf, + write_lock: RwLock<()>, +} + +impl File { + #[must_use] + pub fn new( + balancer_desired_state_notify_tx: broadcast::Sender, + path: PathBuf, + ) -> Self { + Self { + balancer_desired_state_notify_tx, + path, + write_lock: RwLock::new(()), + } + } + + async fn read_schema_from_file(&self) -> Result { + match fs::read_to_string(&self.path).await { + Ok(content) => { + if content.is_empty() { + return self.store_default_schema().await; + } + + let schema: Schema = serde_json::from_str(&content).context(format!("Unable to parse database file contents: '{}'. Either that is not a valid database file, or this version of Paddler is incompatible with it.", self.path.display()))?; + + Ok(schema) + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + warn!( + "State database file not found; trying to store the default state: '{}'", + self.path.display() + ); + + self.store_default_schema().await + } + Err(err) => Err(err.into()), + } + } + + async fn store_default_schema(&self) -> Result { + let schema = Schema::default(); + + self.store_schema(&schema) + .await + .context("Failed to store default state")?; + + Ok(schema) + } + + async fn store_schema(&self, schema: &Schema) -> Result<()> { + let balancer_desired_state = schema.balancer_desired_state.clone(); + let _lock = self.write_lock.write().await; + + let serialized_schema = serde_json::to_string_pretty(schema) + .context("Failed to serialize the state database schema")?; + let mut file = fs::File::create(&self.path).await?; + + file.write_all(serialized_schema.as_bytes()).await?; + file.sync_all().await?; + + self.balancer_desired_state_notify_tx + .send(balancer_desired_state)?; + + Ok(()) + } + + async fn update_schema(&self, modifier: TModifier) -> Result<()> + where + TModifier: FnOnce(&mut Schema), + { + let mut schema = self + .read_schema_from_file() + .await + .context("Unable to read current state from file")?; + + modifier(&mut schema); + + self.store_schema(&schema).await + } +} + +#[async_trait] +impl StateDatabase for File { + async fn read_balancer_desired_state(&self) -> Result { + Ok(self + .read_schema_from_file() + .await + .context("Unable to read state from file")? + .balancer_desired_state) + } + + async fn store_balancer_desired_state( + &self, + balancer_desired_state: &BalancerDesiredState, + ) -> Result<()> { + self.update_schema(|schema| { + schema.balancer_desired_state = balancer_desired_state.clone(); + }) + .await + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use log::LevelFilter; + use tempfile::NamedTempFile; + use tempfile::TempDir; + use tokio::fs; + use tokio::sync::broadcast; + + use super::File; + use super::schema::Schema; + use crate::state_database::StateDatabase; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::balancer_desired_state::BalancerDesiredState; + + #[tokio::test] + async fn store_then_read_round_trips_through_real_file() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().join("state.json"); + let database = File::new(balancer_desired_state_notify_tx, path.clone()); + + let desired_state = BalancerDesiredState { + chat_template_override: None, + inference_parameters: BalancerDesiredState::default().inference_parameters, + model: AgentDesiredModel::LocalToAgent("stored_model_path".to_owned()), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }; + + database + .store_balancer_desired_state(&desired_state) + .await + .unwrap(); + + let read_back = database.read_balancer_desired_state().await.unwrap(); + + assert_eq!(read_back.model, desired_state.model); + assert!(fs::metadata(&path).await.unwrap().is_file()); + } + + #[tokio::test] + async fn reading_missing_file_stores_and_returns_default_state() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().join("not_yet_created.json"); + let database = File::new(balancer_desired_state_notify_tx, path.clone()); + + let read_state = database.read_balancer_desired_state().await.unwrap(); + + assert_eq!(read_state, BalancerDesiredState::default()); + assert!(fs::metadata(&path).await.unwrap().is_file()); + } + + #[tokio::test] + async fn reading_invalid_json_returns_parse_error() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_file = NamedTempFile::new().unwrap(); + let path = temp_file.path().to_path_buf(); + fs::write(&path, b"this is not valid json").await.unwrap(); + let database = File::new(balancer_desired_state_notify_tx, path); + + let read_result = database.read_balancer_desired_state().await; + + assert!(read_result.is_err()); + } + + #[tokio::test] + async fn reading_a_directory_path_returns_non_not_found_error() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let database = File::new( + balancer_desired_state_notify_tx, + temp_dir.path().to_path_buf(), + ); + + let read_result = database.read_balancer_desired_state().await; + + assert!(read_result.is_err()); + } + + #[tokio::test] + async fn storing_default_state_fails_when_parent_directory_is_missing() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let path: PathBuf = temp_dir.path().join("missing_directory").join("state.json"); + let database = File::new(balancer_desired_state_notify_tx, path); + + let read_result = database.read_balancer_desired_state().await; + + assert!(read_result.is_err()); + } + + #[tokio::test] + async fn updating_schema_fails_when_path_is_a_directory() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let database = File::new( + balancer_desired_state_notify_tx, + temp_dir.path().to_path_buf(), + ); + + let store_result = database + .store_balancer_desired_state(&BalancerDesiredState::default()) + .await; + + assert!(store_result.is_err()); + } + + #[tokio::test] + async fn storing_fails_when_no_receivers_are_listening() { + let (balancer_desired_state_notify_tx, balancer_desired_state_notify_rx) = + broadcast::channel(8); + drop(balancer_desired_state_notify_rx); + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().join("state.json"); + let database = File::new(balancer_desired_state_notify_tx, path); + + let store_result = database + .store_balancer_desired_state(&BalancerDesiredState::default()) + .await; + + assert!(store_result.is_err()); + } + + #[tokio::test] + async fn reading_missing_file_logs_path_when_warnings_are_enabled() { + log::set_max_level(LevelFilter::Warn); + + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().join("warned_missing.json"); + let database = File::new(balancer_desired_state_notify_tx, path.clone()); + + let read_state = database.read_balancer_desired_state().await.unwrap(); + + assert_eq!(read_state, BalancerDesiredState::default()); + assert!(fs::metadata(&path).await.unwrap().is_file()); + } + + #[tokio::test] + async fn storing_schema_fails_when_target_is_unwritable() { + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let database = File::new(balancer_desired_state_notify_tx, PathBuf::from("/dev/full")); + + let store_result = database.store_schema(&Schema::default()).await; + + let store_error = store_result.err().unwrap(); + + assert!(store_error.downcast_ref::().is_some()); + } + + #[tokio::test] + async fn storing_a_large_schema_surfaces_the_write_error_during_write_all() { + const TOKIO_FILE_BUFFER_BYTES: usize = 2 * 1024 * 1024; + + let (balancer_desired_state_notify_tx, _balancer_desired_state_notify_rx) = + broadcast::channel(8); + let database = File::new(balancer_desired_state_notify_tx, PathBuf::from("/dev/full")); + + let mut schema = Schema::default(); + schema.balancer_desired_state.model = + AgentDesiredModel::LocalToAgent("x".repeat(TOKIO_FILE_BUFFER_BYTES * 2)); + + let store_result = database.store_schema(&schema).await; + + let store_error = store_result.err().unwrap(); + let io_error = store_error.downcast_ref::().unwrap(); + + assert_eq!(io_error.kind(), std::io::ErrorKind::StorageFull); + } +} diff --git a/paddler/src/balancer/state_database/file/schema.rs b/paddler_balancer/src/state_database/file/schema.rs similarity index 59% rename from paddler/src/balancer/state_database/file/schema.rs rename to paddler_balancer/src/state_database/file/schema.rs index 3f18fe7a..11b64f46 100644 --- a/paddler/src/balancer/state_database/file/schema.rs +++ b/paddler_balancer/src/state_database/file/schema.rs @@ -1,4 +1,4 @@ -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use serde::Deserialize; use serde::Serialize; @@ -13,3 +13,13 @@ pub struct Schema { #[serde(default = "default_version")] pub version: String, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_version_is_one() { + assert_eq!(default_version(), "1"); + } +} diff --git a/paddler/src/balancer/state_database/memory.rs b/paddler_balancer/src/state_database/memory.rs similarity index 64% rename from paddler/src/balancer/state_database/memory.rs rename to paddler_balancer/src/state_database/memory.rs index 014f774d..9b89ae27 100644 --- a/paddler/src/balancer/state_database/memory.rs +++ b/paddler_balancer/src/state_database/memory.rs @@ -1,8 +1,7 @@ -use std::sync::RwLock; - use anyhow::Result; use async_trait::async_trait; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use parking_lot::RwLock; use tokio::sync::broadcast; use super::StateDatabase; @@ -27,22 +26,13 @@ impl Memory { #[async_trait] impl StateDatabase for Memory { - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] async fn read_balancer_desired_state(&self) -> Result { - Ok(self - .balancer_desired_state - .read() - .expect("Failed to acquire read lock") - .clone()) + Ok(self.balancer_desired_state.read().clone()) } - #[expect(clippy::expect_used, reason = "mutex lock poison is unrecoverable")] async fn store_balancer_desired_state(&self, state: &BalancerDesiredState) -> Result<()> { { - let mut balancer_desired_state = self - .balancer_desired_state - .write() - .expect("Failed to acquire write lock"); + let mut balancer_desired_state = self.balancer_desired_state.write(); *balancer_desired_state = state.clone(); } diff --git a/paddler/src/balancer/state_database/mod.rs b/paddler_balancer/src/state_database/mod.rs similarity index 55% rename from paddler/src/balancer/state_database/mod.rs rename to paddler_balancer/src/state_database/mod.rs index 74c977ea..69ea70db 100644 --- a/paddler/src/balancer/state_database/mod.rs +++ b/paddler_balancer/src/state_database/mod.rs @@ -1,12 +1,9 @@ -mod file; -mod memory; +pub mod file; +pub mod memory; use anyhow::Result; use async_trait::async_trait; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -pub use self::file::File; -pub use self::memory::Memory; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; #[async_trait] pub trait StateDatabase: Send + Sync { @@ -17,16 +14,17 @@ pub trait StateDatabase: Send + Sync { #[cfg(test)] mod tests { - use anyhow::Result; - use paddler_types::agent_desired_model::AgentDesiredModel; - use paddler_types::chat_template::ChatTemplate; - use paddler_types::inference_parameters::InferenceParameters; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::chat_template::ChatTemplate; + use paddler_messaging::inference_parameters::InferenceParameters; use tempfile::NamedTempFile; use tokio::sync::broadcast; + use super::file::File; + use super::memory::Memory; use super::*; - async fn subtest_store_desired_state(db: &TDatabase) -> Result<()> { + async fn subtest_store_desired_state(database: &TDatabase) { let desired_state = BalancerDesiredState { chat_template_override: None, inference_parameters: InferenceParameters::default(), @@ -35,40 +33,36 @@ mod tests { use_chat_template_override: false, }; - db.store_balancer_desired_state(&desired_state).await?; + database + .store_balancer_desired_state(&desired_state) + .await + .unwrap(); - let read_state = db.read_balancer_desired_state().await?; + let read_state = database.read_balancer_desired_state().await.unwrap(); assert_eq!(read_state.model, desired_state.model); - - Ok(()) } #[tokio::test] - async fn test_file_database() -> Result<()> { + async fn test_file_database() { let (balancer_desired_state_tx, _balancer_desired_state_rx) = broadcast::channel(100); - let tempfile = NamedTempFile::new()?; - let db = File::new(balancer_desired_state_tx, tempfile.path().to_path_buf()); - - subtest_store_desired_state(&db).await?; + let tempfile = NamedTempFile::new().unwrap(); + let database = File::new(balancer_desired_state_tx, tempfile.path().to_path_buf()); - Ok(()) + subtest_store_desired_state(&database).await; } #[tokio::test] - async fn test_memory_database() -> Result<()> { + async fn test_memory_database() { let (balancer_desired_state_tx, _balancer_desired_state_rx) = broadcast::channel(100); - let db = Memory::new(balancer_desired_state_tx, BalancerDesiredState::default()); + let database = Memory::new(balancer_desired_state_tx, BalancerDesiredState::default()); - subtest_store_desired_state(&db).await?; - - Ok(()) + subtest_store_desired_state(&database).await; } #[tokio::test] - async fn test_file_database_persists_chat_template_override_across_fresh_instance() -> Result<()> - { - let tempfile = NamedTempFile::new()?; + async fn test_file_database_persists_chat_template_override_across_fresh_instance() { + let tempfile = NamedTempFile::new().unwrap(); let path = tempfile.path().to_path_buf(); let chat_template = ChatTemplate { @@ -83,20 +77,21 @@ mod tests { }; { - let (tx, _rx) = broadcast::channel(100); - let db = File::new(tx, path.clone()); + let (balancer_desired_state_tx, _balancer_desired_state_rx) = broadcast::channel(100); + let database = File::new(balancer_desired_state_tx, path.clone()); - db.store_balancer_desired_state(&desired_state).await?; + database + .store_balancer_desired_state(&desired_state) + .await + .unwrap(); } - let (tx, _rx) = broadcast::channel(100); - let db = File::new(tx, path); - let read_back = db.read_balancer_desired_state().await?; + let (balancer_desired_state_tx, _balancer_desired_state_rx) = broadcast::channel(100); + let database = File::new(balancer_desired_state_tx, path); + let read_back = database.read_balancer_desired_state().await.unwrap(); assert_eq!(read_back.chat_template_override, Some(chat_template)); assert!(read_back.use_chat_template_override); assert_eq!(read_back.model, desired_state.model); - - Ok(()) } } diff --git a/paddler/src/balancer/state_database_type.rs b/paddler_balancer/src/state_database_type.rs similarity index 71% rename from paddler/src/balancer/state_database_type.rs rename to paddler_balancer/src/state_database_type.rs index 0f22655d..1286912e 100644 --- a/paddler/src/balancer/state_database_type.rs +++ b/paddler_balancer/src/state_database_type.rs @@ -7,7 +7,7 @@ use anyhow::Error; use anyhow::Result; use anyhow::anyhow; use indoc::formatdoc; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use url::Url; #[derive(Clone)] @@ -53,18 +53,19 @@ impl FromStr for StateDatabaseType { #[cfg(test)] mod tests { + use std::mem::discriminant; use std::str::FromStr; - use anyhow::Result; - use super::*; #[test] - fn test_memory_basic() -> Result<()> { - let result = StateDatabaseType::from_str("memory://")?; - assert!(matches!(result, StateDatabaseType::Memory(_))); + fn test_memory_basic() { + let result = StateDatabaseType::from_str("memory://").unwrap(); - Ok(()) + assert_eq!( + discriminant(&result), + discriminant(&StateDatabaseType::Memory(Box::default())), + ); } #[test] @@ -75,24 +76,28 @@ mod tests { } #[test] - fn test_file_absolute_path() -> Result<()> { + fn test_file_scheme_without_authority_is_error() { + let result = StateDatabaseType::from_str("file:relative"); + + assert!(result.is_err()); + } + + #[test] + fn test_file_absolute_path() { #[cfg(unix)] - let (url, expected_path) = ("file:///absolute/path", "/absolute/path"); + let expected_path = "/absolute/path"; + #[cfg(unix)] + let url = "file:///absolute/path"; #[cfg(windows)] - let (url, expected_path) = ("file://C:/absolute/path", "C:/absolute/path"); - - let result = StateDatabaseType::from_str(url)?; + let expected_path = "C:/absolute/path"; + #[cfg(windows)] + let url = "file://C:/absolute/path"; - match result { - StateDatabaseType::File(path) => { - assert_eq!(path, PathBuf::from(expected_path)); - } - StateDatabaseType::Memory(_) => { - return Err(anyhow!("Expected File variant")); - } - } + let result = StateDatabaseType::from_str(url).unwrap(); - Ok(()) + assert!( + matches!(&result, StateDatabaseType::File(path) if path == &PathBuf::from(expected_path)) + ); } #[test] diff --git a/paddler_balancer/src/static_files.rs b/paddler_balancer/src/static_files.rs new file mode 100644 index 00000000..bbae0ede --- /dev/null +++ b/paddler_balancer/src/static_files.rs @@ -0,0 +1,51 @@ +use rust_embed::Embed; + +#[derive(Embed)] +#[folder = "../static"] +pub struct StaticFiles; + +#[cfg(test)] +mod tests { + use std::hint::black_box; + + use rust_embed::EmbeddedFile; + + use super::StaticFiles; + + fn any_embedded_file_name() -> String { + StaticFiles::iter() + .next() + .map(|file_name| file_name.as_ref().to_owned()) + .unwrap() + } + + #[test] + fn returns_embedded_file_for_existing_path() { + let embedded_file = StaticFiles::get(&any_embedded_file_name()).unwrap(); + + assert!(!embedded_file.data.is_empty()); + } + + #[test] + fn returns_none_for_missing_path() { + assert!(StaticFiles::get("this_file_does_not_exist.txt").is_none()); + } + + #[test] + fn returns_none_for_path_traversal_outside_embedded_folder() { + assert!(StaticFiles::get("../Cargo.toml").is_none()); + } + + #[test] + fn iterates_over_embedded_file_names() { + assert!(StaticFiles::iter().next().is_some()); + } + + #[test] + fn returns_embedded_file_when_called_through_indirect_call() { + let lookup: fn(&str) -> Option = black_box(StaticFiles::get); + let embedded_file = lookup(&any_embedded_file_name()).unwrap(); + + assert!(!embedded_file.data.is_empty()); + } +} diff --git a/paddler/src/balancer/statsd_service/configuration.rs b/paddler_balancer/src/statsd_service/configuration.rs similarity index 100% rename from paddler/src/balancer/statsd_service/configuration.rs rename to paddler_balancer/src/statsd_service/configuration.rs diff --git a/paddler_balancer/src/statsd_service/mod.rs b/paddler_balancer/src/statsd_service/mod.rs new file mode 100644 index 00000000..a683152e --- /dev/null +++ b/paddler_balancer/src/statsd_service/mod.rs @@ -0,0 +1,351 @@ +pub mod configuration; + +use std::net::UdpSocket; +use std::sync::Arc; + +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use cadence::Gauged; +use cadence::MetricError; +use cadence::StatsdClient; +use cadence::UdpMetricSink; +use log::error; +use tokio::time::MissedTickBehavior; +use tokio::time::interval; +use tokio_util::sync::CancellationToken; +use trzcina::Service; + +use crate::agent_controller_pool::AgentControllerPool; +use crate::agent_controller_pool_total_slots::AgentControllerPoolTotalSlots; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::statsd_service::configuration::Configuration as StatsdServiceConfiguration; + +#[expect( + clippy::needless_pass_by_value, + reason = "cadence StatsdClient::with_error_handler requires an owned Fn(MetricError) handler" +)] +fn log_statsd_error(error: MetricError) { + error!("Statsd error: {error}"); +} + +pub struct StatsdService { + pub agent_controller_pool: Arc, + pub buffered_request_manager: Arc, + pub configuration: StatsdServiceConfiguration, +} + +impl StatsdService { + fn report_metrics(&self, client: &StatsdClient) -> Result<()> { + let AgentControllerPoolTotalSlots { + slots_processing, + slots_total, + } = self.agent_controller_pool.total_slots(); + let requests_buffered = self.buffered_request_manager.buffered_request_counter.get(); + + let slots_processing = + u64::try_from(slots_processing).context("slots_processing count is negative")?; + let slots_total = u64::try_from(slots_total).context("slots_total count is negative")?; + let requests_buffered = + u64::try_from(requests_buffered).context("requests_buffered count is negative")?; + + client.gauge("slots_processing", slots_processing)?; + client.gauge("slots_total", slots_total)?; + client.gauge("requests_buffered", requests_buffered)?; + client.flush()?; + + Ok(()) + } +} + +#[async_trait] +impl Service for StatsdService { + fn name(&self) -> &'static str { + "balancer::statsd_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let statsd_sink_socket = UdpSocket::bind("0.0.0.0:0")?; + let statsd_sink = UdpMetricSink::from(self.configuration.statsd_addr, statsd_sink_socket)?; + + let client = StatsdClient::builder(&self.configuration.statsd_prefix.clone(), statsd_sink) + .with_error_handler(log_statsd_error) + .build(); + + let mut ticker = interval(self.configuration.statsd_reporting_interval); + + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + () = shutdown.cancelled() => break Ok(()), + _ = ticker.tick() => { + if let Err(err) = self.report_metrics(&client) { + error!("Failed to report metrics: {err}"); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + use std::net::SocketAddr; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicI32; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use cadence::BufferedSpyMetricSink; + use cadence::ErrorKind; + use cadence::MetricError; + use cadence::SpyMetricSink; + use parking_lot::RwLock; + use tokio::net::UdpSocket as TokioUdpSocket; + use tokio::sync::mpsc; + + use super::*; + use crate::agent_controller::AgentController; + use crate::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use crate::embedding_sender_collection::EmbeddingSenderCollection; + use crate::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use crate::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + + const REPORTING_INTERVAL: Duration = Duration::from_secs(1); + const STATSD_PREFIX: &str = "paddler"; + + fn register_agent_controller_with_slots( + pool: &AgentControllerPool, + agent_id: &str, + slots_processing: i32, + slots_total: i32, + ) { + let (agent_message_tx, _agent_message_rx) = mpsc::unbounded_channel(); + + pool.register_agent_controller( + agent_id.to_owned(), + Arc::new(AgentController { + agent_message_tx, + chat_template_override_sender_collection: Arc::new( + ChatTemplateOverrideSenderCollection::default(), + ), + connection_close: CancellationToken::new(), + desired_slots_total: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), + download_filename: RwLock::new(None), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), + embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), + generate_tokens_sender_collection: Arc::new( + GenerateTokensSenderCollection::default(), + ), + id: agent_id.to_owned(), + issues: RwLock::new(BTreeSet::new()), + model_metadata_sender_collection: Arc::new( + ModelMetadataSenderCollection::default(), + ), + model_path: RwLock::new(None), + name: None, + newest_update_version: AtomicValue::::new(0), + slots_processing: AtomicValue::::new(slots_processing), + slots_total: AtomicValue::::new(slots_total), + state_application_status_code: AtomicValue::::new( + AgentStateApplicationStatus::Fresh as i32, + ), + uses_chat_template_override: AtomicValue::::new(false), + }), + ) + .unwrap(); + } + + fn build_service(statsd_addr: SocketAddr) -> StatsdService { + let agent_controller_pool = Arc::new(AgentControllerPool::default()); + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + agent_controller_pool.clone(), + REPORTING_INTERVAL, + 10, + )); + + StatsdService { + agent_controller_pool, + buffered_request_manager, + configuration: StatsdServiceConfiguration { + statsd_addr, + statsd_prefix: STATSD_PREFIX.to_owned(), + statsd_reporting_interval: REPORTING_INTERVAL, + }, + } + } + + #[test] + fn name_identifies_the_statsd_service() { + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + assert_eq!(service.name(), "balancer::statsd_service"); + } + + #[tokio::test] + async fn report_metrics_emits_a_gauge_datagram_for_each_pool_metric() { + let receiver = TokioUdpSocket::bind("127.0.0.1:0").await.unwrap(); + let receiver_addr = receiver.local_addr().unwrap(); + let service = build_service(receiver_addr); + + let sender_socket = UdpSocket::bind("0.0.0.0:0").unwrap(); + let sink = UdpMetricSink::from(receiver_addr, sender_socket).unwrap(); + let client = StatsdClient::builder(STATSD_PREFIX, sink).build(); + + service.report_metrics(&client).unwrap(); + + let mut received_lines: Vec = Vec::new(); + let mut datagram = [0_u8; 1024]; + + for _ in 0..3 { + let byte_count = receiver.recv(&mut datagram).await.unwrap(); + + received_lines.push(String::from_utf8(datagram[..byte_count].to_vec()).unwrap()); + } + + assert!(received_lines.contains(&"paddler.slots_processing:0|g".to_owned())); + assert!(received_lines.contains(&"paddler.slots_total:0|g".to_owned())); + assert!(received_lines.contains(&"paddler.requests_buffered:0|g".to_owned())); + } + + #[tokio::test] + async fn run_reports_on_first_tick_then_stops_on_cancellation() { + let receiver = TokioUdpSocket::bind("127.0.0.1:0").await.unwrap(); + let receiver_addr = receiver.local_addr().unwrap(); + let service = Box::new(build_service(receiver_addr)); + + let shutdown = CancellationToken::new(); + let run_handle = tokio::spawn(service.run(shutdown.clone())); + + let mut datagram = [0_u8; 1024]; + let byte_count = receiver.recv(&mut datagram).await.unwrap(); + let first_line = String::from_utf8(datagram[..byte_count].to_vec()).unwrap(); + + let expected_first_tick_lines = [ + "paddler.slots_processing:0|g".to_owned(), + "paddler.slots_total:0|g".to_owned(), + "paddler.requests_buffered:0|g".to_owned(), + ]; + + assert!(expected_first_tick_lines.contains(&first_line)); + + shutdown.cancel(); + + assert!(run_handle.await.unwrap().is_ok()); + } + + #[test] + fn report_metrics_propagates_error_from_the_first_gauge_emit() { + let (receiver, sink) = SpyMetricSink::new(); + + drop(receiver); + + let client = StatsdClient::builder(STATSD_PREFIX, sink).build(); + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + let result = service.report_metrics(&client); + + assert!(result.err().unwrap().is::()); + } + + #[test] + fn report_metrics_propagates_error_from_the_second_gauge_emit() { + let (receiver, sink) = SpyMetricSink::with_capacity(1); + let client = StatsdClient::builder(STATSD_PREFIX, sink).build(); + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + let result = service.report_metrics(&client); + + assert!(result.err().unwrap().is::()); + assert_eq!(receiver.len(), 1); + } + + #[test] + fn report_metrics_propagates_error_from_the_third_gauge_emit() { + let (receiver, sink) = SpyMetricSink::with_capacity(2); + let client = StatsdClient::builder(STATSD_PREFIX, sink).build(); + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + let result = service.report_metrics(&client); + + assert!(result.err().unwrap().is::()); + assert_eq!(receiver.len(), 2); + } + + #[test] + fn report_metrics_propagates_error_from_the_flush() { + let (receiver, sink) = BufferedSpyMetricSink::new(); + let client = StatsdClient::builder(STATSD_PREFIX, sink).build(); + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + drop(receiver); + + let result = service.report_metrics(&client); + + assert!(result.err().unwrap().is::()); + } + + #[test] + fn report_metrics_rejects_negative_slots_processing() { + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + register_agent_controller_with_slots(&service.agent_controller_pool, "agent", -1, 0); + + let client = StatsdClient::builder(STATSD_PREFIX, SpyMetricSink::new().1).build(); + let result = service.report_metrics(&client); + + assert_eq!( + result.err().unwrap().to_string(), + "slots_processing count is negative" + ); + } + + #[test] + fn report_metrics_rejects_negative_slots_total() { + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + register_agent_controller_with_slots(&service.agent_controller_pool, "agent", 0, -1); + + let client = StatsdClient::builder(STATSD_PREFIX, SpyMetricSink::new().1).build(); + let result = service.report_metrics(&client); + + assert_eq!( + result.err().unwrap().to_string(), + "slots_total count is negative" + ); + } + + #[test] + fn report_metrics_rejects_negative_requests_buffered() { + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + service + .buffered_request_manager + .buffered_request_counter + .decrement(); + + let client = StatsdClient::builder(STATSD_PREFIX, SpyMetricSink::new().1).build(); + let result = service.report_metrics(&client); + + assert_eq!( + result.err().unwrap().to_string(), + "requests_buffered count is negative" + ); + } + + #[test] + fn log_statsd_error_logs_the_metric_error() { + log::set_max_level(log::LevelFilter::Trace); + + log_statsd_error(MetricError::from(( + ErrorKind::InvalidInput, + "statsd error fixture", + ))); + } +} diff --git a/paddler_balancer/src/unbounded_stream_from_agent.rs b/paddler_balancer/src/unbounded_stream_from_agent.rs new file mode 100644 index 00000000..71935980 --- /dev/null +++ b/paddler_balancer/src/unbounded_stream_from_agent.rs @@ -0,0 +1,126 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use actix_web::rt; +use futures_util::Stream; +use nanoid::nanoid; +use paddler_messaging::inference_client::response::Response as OutgoingResponse; +use paddler_messaging::streamable_result::StreamableResult; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_util::sync::CancellationToken; + +use crate::agent_controller::AgentController; +use crate::buffered_request_manager::BufferedRequestManager; +use crate::cancellation_token_stream_guard::CancellationTokenStreamGuard; +use crate::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use crate::chunk_forwarding_session_controller::transforms_outgoing_message::TransformsOutgoingMessage; +use crate::handles_agent_streaming_response::HandlesAgentStreamingResponse; +use crate::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use crate::manages_senders::ManagesSenders; +use crate::request_from_agent::request_from_agent; +use paddler_messaging::management_socket::agent::request::Request as AgentJsonRpcRequest; + +pub fn unbounded_stream_from_agent( + buffered_request_manager: Arc, + inference_service_configuration: InferenceServiceConfiguration, + params: TParams, + transformer: TTransformsOutgoingMessage, + shutdown: CancellationToken, +) -> impl Stream +where + TParams: Debug + Into + Send + 'static, + AgentController: HandlesAgentStreamingResponse, + <>::SenderCollection as ManagesSenders>::Value: Debug + Into + StreamableResult, + TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync + 'static, +{ + let request_id: String = nanoid!(); + let connection_close = CancellationToken::new(); + let (chunk_tx, chunk_rx) = mpsc::unbounded_channel(); + + rt::spawn({ + let connection_close = connection_close.clone(); + + async move { + let session_controller = ChunkForwardingSessionController::new(chunk_tx, transformer); + + request_from_agent( + buffered_request_manager, + connection_close, + inference_service_configuration, + params, + request_id, + session_controller, + shutdown, + ) + .await; + } + }); + + CancellationTokenStreamGuard::new(UnboundedReceiverStream::new(chunk_rx), connection_close) +} + +#[cfg(test)] +mod tests { + use std::mem::discriminant; + use std::time::Duration; + + use futures_util::StreamExt as _; + + use super::*; + use crate::agent_controller_pool::AgentControllerPool; + use crate::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; + use crate::chunk_forwarding_session_controller::transform_result::TransformResult; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + + fn inference_service_configuration() -> InferenceServiceConfiguration { + const TIMEOUT_LONGER_THAN_ANY_TEST_RUN: Duration = Duration::from_hours(1); + + InferenceServiceConfiguration { + addr: "127.0.0.1:0".parse().unwrap(), + cors_allowed_hosts: Vec::new(), + inference_item_timeout: TIMEOUT_LONGER_THAN_ANY_TEST_RUN, + } + } + + #[actix_web::test] + async fn spawned_task_runs_request_from_agent_and_closes_stream_on_shutdown() { + let pool = Arc::new(AgentControllerPool::default()); + let buffered_request_manager = Arc::new(BufferedRequestManager::new( + pool, + Duration::from_secs(1), + 10, + )); + + let shutdown = CancellationToken::new(); + + shutdown.cancel(); + + let mut stream = Box::pin(unbounded_stream_from_agent( + buffered_request_manager, + inference_service_configuration(), + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 1, + raw_prompt: "fixture prompt".to_owned(), + }, + IdentityTransformer::new(), + shutdown, + )); + + let shutdown_chunk = stream.next().await.unwrap(); + + assert_eq!( + discriminant(&TransformResult::Chunk(String::new())), + discriminant(&shutdown_chunk), + ); + + let chunk_text = match shutdown_chunk { + TransformResult::Chunk(chunk_text) | TransformResult::Error(chunk_text) => chunk_text, + TransformResult::Discard => String::new(), + }; + + assert!(chunk_text.contains("shutting down")); + assert!(stream.next().await.is_none()); + } +} diff --git a/paddler_balancer/src/web_admin_panel_service/app_data.rs b/paddler_balancer/src/web_admin_panel_service/app_data.rs new file mode 100644 index 00000000..2b6f52f9 --- /dev/null +++ b/paddler_balancer/src/web_admin_panel_service/app_data.rs @@ -0,0 +1,5 @@ +use crate::web_admin_panel_service::template_data::TemplateData; + +pub struct AppData { + pub template_data: TemplateData, +} diff --git a/paddler/src/balancer/web_admin_panel_service/configuration.rs b/paddler_balancer/src/web_admin_panel_service/configuration.rs similarity index 100% rename from paddler/src/balancer/web_admin_panel_service/configuration.rs rename to paddler_balancer/src/web_admin_panel_service/configuration.rs diff --git a/paddler_balancer/src/web_admin_panel_service/http_route/favicon.rs b/paddler_balancer/src/web_admin_panel_service/http_route/favicon.rs new file mode 100644 index 00000000..116fe59e --- /dev/null +++ b/paddler_balancer/src/web_admin_panel_service/http_route/favicon.rs @@ -0,0 +1,50 @@ +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::get; +use actix_web::web; + +const FAVICON: &[u8] = include_bytes!("../../../../resources/images/favicon.svg"); + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/favicon.ico")] +async fn respond() -> impl Responder { + HttpResponse::Ok() + .content_type("image/svg+xml") + .body(FAVICON) +} + +#[cfg(test)] +mod tests { + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::http::header; + use actix_web::test; + + use super::FAVICON; + use super::register; + + #[actix_web::test] + async fn serves_embedded_favicon_as_svg() { + let app = test::init_service(App::new().configure(register)).await; + let request = test::TestRequest::get().uri("/favicon.ico").to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response + .headers() + .get(header::CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "image/svg+xml" + ); + + let body = test::read_body(response).await; + + assert_eq!(body.as_ref(), FAVICON); + } +} diff --git a/paddler_balancer/src/web_admin_panel_service/http_route/home.rs b/paddler_balancer/src/web_admin_panel_service/http_route/home.rs new file mode 100644 index 00000000..cd98ab99 --- /dev/null +++ b/paddler_balancer/src/web_admin_panel_service/http_route/home.rs @@ -0,0 +1,163 @@ +use actix_web::Responder; +use actix_web::get; +use actix_web::web; +use askama::Template; +use esbuild_metafile::HttpPreloader; +use esbuild_metafile::filters; + +use crate::response::view::view; +use crate::web_admin_panel_service::app_data::AppData; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[derive(Template)] +#[template(path = "web_admin_panel.html")] +struct WebAdminPanelTemplate { + buffered_request_timeout_millis: u128, + compat_openai_addr: String, + inference_addr: String, + management_addr: String, + max_buffered_requests: i32, + preloads: HttpPreloader, + statsd_addr: String, + statsd_prefix: String, + statsd_reporting_interval_millis: u128, +} + +#[get("/{_:.*}")] +async fn respond(preloads: HttpPreloader, app_data: web::Data) -> impl Responder { + view(WebAdminPanelTemplate { + buffered_request_timeout_millis: app_data + .template_data + .buffered_request_timeout + .as_millis(), + compat_openai_addr: match app_data.template_data.compat_openai_addr.clone() { + Some(addr) => addr.input_addr, + None => String::new(), + }, + inference_addr: app_data.template_data.inference_addr.input_addr.clone(), + management_addr: app_data.template_data.management_addr.input_addr.clone(), + max_buffered_requests: app_data.template_data.max_buffered_requests, + preloads, + statsd_addr: match app_data.template_data.statsd_addr.clone() { + Some(addr) => addr.input_addr, + None => String::new(), + }, + statsd_prefix: app_data.template_data.statsd_prefix.clone(), + statsd_reporting_interval_millis: app_data + .template_data + .statsd_reporting_interval + .as_millis(), + }) +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::time::Duration; + + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::test; + use actix_web::web::Data; + use parking_lot::Once; + + use super::register; + use crate::resolved_socket_addr::ResolvedSocketAddr; + use crate::web_admin_panel_service::app_data::AppData; + use crate::web_admin_panel_service::template_data::TemplateData; + + static INIT_ESBUILD_METAFILE: Once = Once::new(); + + fn ensure_esbuild_metafile_initialized() { + INIT_ESBUILD_METAFILE.call_once(|| { + esbuild_metafile::instance::initialize_instance(include_str!( + "../../../../esbuild-meta.json" + )); + }); + } + + #[actix_web::test] + async fn renders_optional_addresses_when_present() { + ensure_esbuild_metafile_initialized(); + + let app_data = Data::new(AppData { + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(1), + compat_openai_addr: Some(ResolvedSocketAddr { + input_addr: "127.0.0.1:8081".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8081)), + }), + inference_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:8082".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8082)), + }, + management_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:8083".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8083)), + }, + max_buffered_requests: 32, + statsd_addr: Some(ResolvedSocketAddr { + input_addr: "127.0.0.1:8125".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8125)), + }), + statsd_prefix: "paddler".to_owned(), + statsd_reporting_interval: Duration::from_millis(500), + }, + }); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get().uri("/").to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let body = test::read_body(response).await; + let body_text = std::str::from_utf8(body.as_ref()).unwrap(); + + assert!(body_text.contains("data-compat-openai-addr=\"127.0.0.1:8081\"")); + assert!(body_text.contains("data-statsd-addr=\"127.0.0.1:8125\"")); + assert!(body_text.contains("data-inference-addr=\"127.0.0.1:8082\"")); + assert!(body_text.contains("data-management-addr=\"127.0.0.1:8083\"")); + assert!(body_text.contains("data-buffered-request-timeout-millis=\"1000\"")); + assert!(body_text.contains("data-max-buffered-requests=\"32\"")); + assert!(body_text.contains("data-statsd-prefix=\"paddler\"")); + assert!(body_text.contains("data-statsd-reporting-interval-millis=\"500\"")); + } + + #[actix_web::test] + async fn renders_empty_addresses_when_absent() { + ensure_esbuild_metafile_initialized(); + + let app_data = Data::new(AppData { + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(1), + compat_openai_addr: None, + inference_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:8082".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8082)), + }, + management_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:8083".to_owned(), + socket_addr: SocketAddr::from(([127, 0, 0, 1], 8083)), + }, + max_buffered_requests: 32, + statsd_addr: None, + statsd_prefix: "paddler".to_owned(), + statsd_reporting_interval: Duration::from_millis(500), + }, + }); + let app = test::init_service(App::new().app_data(app_data).configure(register)).await; + let request = test::TestRequest::get().uri("/").to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let body = test::read_body(response).await; + let body_text = std::str::from_utf8(body.as_ref()).unwrap(); + + assert!(body_text.contains("data-compat-openai-addr=\"\"")); + assert!(body_text.contains("data-statsd-addr=\"\"")); + } +} diff --git a/paddler/src/balancer/web_admin_panel_service/http_route/mod.rs b/paddler_balancer/src/web_admin_panel_service/http_route/mod.rs similarity index 100% rename from paddler/src/balancer/web_admin_panel_service/http_route/mod.rs rename to paddler_balancer/src/web_admin_panel_service/http_route/mod.rs diff --git a/paddler_balancer/src/web_admin_panel_service/http_route/static_files.rs b/paddler_balancer/src/web_admin_panel_service/http_route/static_files.rs new file mode 100644 index 00000000..554b8faa --- /dev/null +++ b/paddler_balancer/src/web_admin_panel_service/http_route/static_files.rs @@ -0,0 +1,83 @@ +use actix_web::HttpResponse; +use actix_web::Responder; +use actix_web::get; +use actix_web::web; +use mime_guess::from_path; + +use crate::static_files::StaticFiles; + +pub fn register(cfg: &mut web::ServiceConfig) { + cfg.service(respond); +} + +#[get("/static/{path:.*}")] +async fn respond(path: web::Path) -> impl Responder { + let path = path.into_inner(); + + match StaticFiles::get(path.as_str()) { + Some(content) => HttpResponse::Ok() + .content_type(from_path(path).first_or_octet_stream().as_ref()) + .body(content.data.into_owned()), + None => HttpResponse::NotFound().body("File not found"), + } +} + +#[cfg(test)] +mod tests { + use actix_web::App; + use actix_web::http::StatusCode; + use actix_web::http::header::CONTENT_TYPE; + use actix_web::test; + use mime_guess::from_path; + + use super::register; + use crate::static_files::StaticFiles; + + fn any_embedded_file_name() -> String { + StaticFiles::iter() + .next() + .map(|file_name| file_name.as_ref().to_owned()) + .unwrap() + } + + #[actix_web::test] + async fn serves_embedded_file_with_guessed_content_type() { + let existing_file_path = any_embedded_file_name(); + + let app = test::init_service(App::new().configure(register)).await; + let request = test::TestRequest::get() + .uri(&format!("/static/{existing_file_path}")) + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::OK); + + let content_type = response.headers().get(CONTENT_TYPE).unwrap(); + let expected_content_type = from_path(&existing_file_path).first_or_octet_stream(); + + assert_eq!(content_type, expected_content_type.as_ref()); + + let expected_body = StaticFiles::get(&existing_file_path) + .unwrap() + .data + .into_owned(); + let body = test::read_body(response).await; + + assert_eq!(body.as_ref(), expected_body.as_slice()); + } + + #[actix_web::test] + async fn responds_with_not_found_for_missing_file() { + let app = test::init_service(App::new().configure(register)).await; + let request = test::TestRequest::get() + .uri("/static/this_file_does_not_exist.txt") + .to_request(); + let response = test::call_service(&app, request).await; + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let body = test::read_body(response).await; + + assert_eq!(body.as_ref(), b"File not found"); + } +} diff --git a/paddler_balancer/src/web_admin_panel_service/mod.rs b/paddler_balancer/src/web_admin_panel_service/mod.rs new file mode 100644 index 00000000..25a9592d --- /dev/null +++ b/paddler_balancer/src/web_admin_panel_service/mod.rs @@ -0,0 +1,119 @@ +pub mod app_data; +pub mod configuration; +pub mod http_route; +pub mod template_data; + +use actix_web::App; +use actix_web::HttpServer; +use actix_web::web::Data; +use anyhow::Context as _; +use anyhow::Result; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; +use trzcina::Service; +use trzcina::ServiceShutdownOptions; + +use crate::web_admin_panel_service::app_data::AppData; +use crate::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; + +pub struct WebAdminPanelService { + pub configuration: WebAdminPanelServiceConfiguration, + pub shutdown_options: ServiceShutdownOptions, +} + +#[async_trait] +impl Service for WebAdminPanelService { + fn name(&self) -> &'static str { + "balancer::web_admin_panel_service" + } + + async fn run(self: Box, shutdown: CancellationToken) -> Result<()> { + let app_data: Data = Data::new(AppData { + template_data: self.configuration.template_data.clone(), + }); + + let bind_addr = self.configuration.addr; + + let server = HttpServer::new(move || { + App::new() + .app_data(app_data.clone()) + .configure(http_route::favicon::register) + .configure(http_route::static_files::register) + .configure(http_route::home::register) + }) + .shutdown_signal(async move { + shutdown.cancelled().await; + }) + .shutdown_timeout(self.shutdown_options.cooperative_deadline.as_secs()) + .disable_signals() + .bind(bind_addr) + .with_context(|| { + format!("Unable to bind balancer web admin panel service to {bind_addr}") + })?; + + server.run().await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use std::net::TcpListener; + use std::time::Duration; + + use tokio_util::sync::CancellationToken; + use trzcina::Service as _; + use trzcina::ServiceShutdownOptions; + + use super::WebAdminPanelService; + use crate::resolved_socket_addr::ResolvedSocketAddr; + use crate::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; + use crate::web_admin_panel_service::template_data::TemplateData; + + fn build_service(addr: SocketAddr) -> WebAdminPanelService { + let loopback_addr = ResolvedSocketAddr { + input_addr: "127.0.0.1:0".to_owned(), + socket_addr: addr, + }; + + WebAdminPanelService { + configuration: WebAdminPanelServiceConfiguration { + addr, + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(30), + compat_openai_addr: None, + inference_addr: loopback_addr.clone(), + management_addr: loopback_addr, + max_buffered_requests: 32, + statsd_addr: None, + statsd_prefix: "paddler".to_owned(), + statsd_reporting_interval: Duration::from_secs(10), + }, + }, + shutdown_options: ServiceShutdownOptions::default(), + } + } + + #[test] + fn name_identifies_the_web_admin_panel_service() { + let service = build_service(SocketAddr::from(([127, 0, 0, 1], 0))); + + assert_eq!(service.name(), "balancer::web_admin_panel_service"); + } + + #[actix_web::test] + async fn run_returns_error_when_address_is_already_in_use() { + let occupied_listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).unwrap(); + let occupied_addr = occupied_listener.local_addr().unwrap(); + + let service = Box::new(build_service(occupied_addr)); + let result = service.run(CancellationToken::new()).await; + + let error_message = result.unwrap_err().to_string(); + let expected_addr_fragment = occupied_addr.to_string(); + + assert!(error_message.contains(&expected_addr_fragment)); + } +} diff --git a/paddler/src/balancer/web_admin_panel_service/template_data.rs b/paddler_balancer/src/web_admin_panel_service/template_data.rs similarity index 100% rename from paddler/src/balancer/web_admin_panel_service/template_data.rs rename to paddler_balancer/src/web_admin_panel_service/template_data.rs diff --git a/paddler/src/websocket_session_controller.rs b/paddler_balancer/src/websocket_session_controller.rs similarity index 95% rename from paddler/src/websocket_session_controller.rs rename to paddler_balancer/src/websocket_session_controller.rs index 98dcf7cd..ecab4818 100644 --- a/paddler/src/websocket_session_controller.rs +++ b/paddler_balancer/src/websocket_session_controller.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use actix_ws::Session; use anyhow::Result; use async_trait::async_trait; -use paddler_types::rpc_message::RpcMessage; +use paddler_messaging::rpc_message::RpcMessage; use crate::controls_session::ControlsSession; diff --git a/paddler/templates/web_admin_panel.html b/paddler_balancer/templates/web_admin_panel.html similarity index 100% rename from paddler/templates/web_admin_panel.html rename to paddler_balancer/templates/web_admin_panel.html diff --git a/paddler_bootstrap/Cargo.toml b/paddler_bootstrap/Cargo.toml index 32d86412..6024fb3c 100644 --- a/paddler_bootstrap/Cargo.toml +++ b/paddler_bootstrap/Cargo.toml @@ -12,18 +12,26 @@ anyhow = { workspace = true } async-trait = { workspace = true } log = { workspace = true } nanoid = { workspace = true } -paddler = { workspace = true } -paddler_types = { workspace = true } +paddler_agent = { workspace = true } +paddler_balancer = { workspace = true } +paddler_messaging = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } trzcina = { workspace = true } [dev-dependencies] +reqwest = { workspace = true } tempfile = { workspace = true } +[target.'cfg(unix)'.dev-dependencies] +nix = { workspace = true } + [lints] workspace = true [features] default = [] -web_admin_panel = ["paddler/web_admin_panel"] +cuda = ["paddler_agent/cuda"] +metal = ["paddler_agent/metal"] +vulkan = ["paddler_agent/vulkan"] +web_admin_panel = ["paddler_balancer/web_admin_panel"] diff --git a/paddler_bootstrap/src/agent_runner.rs b/paddler_bootstrap/src/agent_runner.rs index 758e4cde..aa5aa2de 100644 --- a/paddler_bootstrap/src/agent_runner.rs +++ b/paddler_bootstrap/src/agent_runner.rs @@ -2,12 +2,12 @@ use std::future::Future; use std::sync::Arc; use anyhow::Result; -use paddler::slot_aggregated_status::SlotAggregatedStatus; +use paddler_agent::slot_aggregated_status::SlotAggregatedStatus; use tokio_util::sync::CancellationToken; -use trzcina::ServiceManager; use trzcina::ServiceShutdownOptions; use crate::agent_service_bundle::AgentServiceBundle; +use crate::run_service_manager::run_service_manager; use crate::service_thread::ServiceThread; pub struct AgentRunnerParams { @@ -35,15 +35,8 @@ impl AgentRunner { let bundle = AgentServiceBundle::new(agent_name, &management_address, slots); let slot_aggregated_status = bundle.slot_aggregated_status.clone(); - let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| async move { - let mut service_manager = ServiceManager::default(); - service_manager.register_bundle(bundle).await?; - service_manager - .start(task_shutdown) - .run_to_completion(ServiceShutdownOptions::default()) - .await - .into_result() - .map_err(anyhow::Error::from) + let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| { + run_service_manager(bundle, task_shutdown, ServiceShutdownOptions::default()) }); Self { diff --git a/paddler_bootstrap/src/agent_service_bundle.rs b/paddler_bootstrap/src/agent_service_bundle.rs index ab1573bc..3e7019d9 100644 --- a/paddler_bootstrap/src/agent_service_bundle.rs +++ b/paddler_bootstrap/src/agent_service_bundle.rs @@ -3,17 +3,17 @@ use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use nanoid::nanoid; -use paddler::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; -use paddler::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; -use paddler::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; -use paddler::agent::llamacpp_arbiter_service::LlamaCppArbiterService; -use paddler::agent::management_socket_client_service::ManagementSocketClientService; -use paddler::agent::model_metadata_holder::ModelMetadataHolder; -use paddler::agent::reconciliation_service::ReconciliationService; -use paddler::agent_applicable_state_holder::AgentApplicableStateHolder; -use paddler::slot_aggregated_status::SlotAggregatedStatus; -use paddler::slot_aggregated_status_manager::SlotAggregatedStatusManager; -use paddler_types::agent_desired_state::AgentDesiredState; +use paddler_agent::agent_applicable_state_holder::AgentApplicableStateHolder; +use paddler_agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest; +use paddler_agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest; +use paddler_agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest; +use paddler_agent::llamacpp_arbiter_service::LlamaCppArbiterService; +use paddler_agent::management_socket_client_service::ManagementSocketClientService; +use paddler_agent::model_metadata_holder::ModelMetadataHolder; +use paddler_agent::reconciliation_service::ReconciliationService; +use paddler_agent::slot_aggregated_status::SlotAggregatedStatus; +use paddler_agent::slot_aggregated_status_manager::SlotAggregatedStatusManager; +use paddler_messaging::agent_desired_state::AgentDesiredState; use tokio::sync::mpsc; use trzcina::Service; use trzcina::ServiceBundle; @@ -42,7 +42,9 @@ impl AgentServiceBundle { let agent_applicable_state_holder = Arc::new(AgentApplicableStateHolder::default()); let model_metadata_holder = Arc::new(ModelMetadataHolder::default()); let slot_aggregated_status_manager = Arc::new(SlotAggregatedStatusManager::new(slots)); - let slot_aggregated_status = slot_aggregated_status_manager.slot_aggregated_status.clone(); + let slot_aggregated_status = slot_aggregated_status_manager + .slot_aggregated_status + .clone(); let llamacpp_arbiter_service = LlamaCppArbiterService { agent_applicable_state: None, diff --git a/paddler_bootstrap/src/balancer_runner.rs b/paddler_bootstrap/src/balancer_runner.rs index 76bdbb34..b1a857e6 100644 --- a/paddler_bootstrap/src/balancer_runner.rs +++ b/paddler_bootstrap/src/balancer_runner.rs @@ -3,23 +3,23 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use paddler::balancer::agent_controller_pool::AgentControllerPool; -use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::state_database_type::StateDatabaseType; -use paddler::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; +use paddler_balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use paddler_balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler_balancer::state_database_type::StateDatabaseType; +use paddler_balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; -use trzcina::ServiceManager; use trzcina::ServiceShutdownOptions; use crate::balancer_service_bundle::BalancerBootstrapConfig; use crate::balancer_service_bundle::BalancerServiceBundle; +use crate::run_service_manager::run_service_manager; use crate::service_thread::ServiceThread; pub struct BalancerRunnerParams { @@ -29,6 +29,7 @@ pub struct BalancerRunnerParams { pub max_buffered_requests: i32, pub openai_service_configuration: Option, pub cancellation_token: CancellationToken, + pub shutdown_options: ServiceShutdownOptions, pub state_database_type: StateDatabaseType, pub statsd_prefix: String, pub statsd_service_configuration: Option, @@ -53,6 +54,7 @@ impl BalancerRunner { max_buffered_requests, openai_service_configuration, cancellation_token, + shutdown_options, state_database_type, statsd_prefix, statsd_service_configuration, @@ -60,8 +62,6 @@ impl BalancerRunner { web_admin_panel_service_configuration, }: BalancerRunnerParams, ) -> Result { - let shutdown_options = ServiceShutdownOptions::default(); - let bundle = BalancerServiceBundle::new(BalancerBootstrapConfig { buffered_request_timeout, inference_service_configuration, @@ -82,15 +82,8 @@ impl BalancerRunner { let balancer_desired_state_tx = bundle.balancer_desired_state_tx.clone(); let initial_desired_state = bundle.initial_desired_state.clone(); - let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| async move { - let mut service_manager = ServiceManager::default(); - service_manager.register_bundle(bundle).await?; - service_manager - .start(task_shutdown) - .run_to_completion(shutdown_options) - .await - .into_result() - .map_err(anyhow::Error::from) + let thread = ServiceThread::spawn(cancellation_token, move |task_shutdown| { + run_service_manager(bundle, task_shutdown, shutdown_options) }); Ok(Self { diff --git a/paddler_bootstrap/src/balancer_service_bundle.rs b/paddler_bootstrap/src/balancer_service_bundle.rs index 4b852e10..a7399a15 100644 --- a/paddler_bootstrap/src/balancer_service_bundle.rs +++ b/paddler_bootstrap/src/balancer_service_bundle.rs @@ -3,31 +3,31 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; -use paddler::balancer::agent_controller_pool::AgentControllerPool; -use paddler::balancer::buffered_request_manager::BufferedRequestManager; -use paddler::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use paddler::balancer::compatibility::openai_service::OpenAIService; -use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use paddler::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use paddler::balancer::inference_service::InferenceService; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::ManagementService; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use paddler::balancer::reconciliation_service::ReconciliationService; -use paddler::balancer::state_database::File; -use paddler::balancer::state_database::Memory; -use paddler::balancer::state_database::StateDatabase; -use paddler::balancer::state_database_type::StateDatabaseType; -use paddler::balancer::statsd_service::StatsdService; -use paddler::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; +use paddler_balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use paddler_balancer::buffered_request_manager::BufferedRequestManager; +use paddler_balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use paddler_balancer::compatibility::openai_service::OpenAIService; +use paddler_balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use paddler_balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler_balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use paddler_balancer::inference_service::InferenceService; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::ManagementService; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler_balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; +use paddler_balancer::reconciliation_service::ReconciliationService; +use paddler_balancer::state_database::StateDatabase; +use paddler_balancer::state_database::file::File; +use paddler_balancer::state_database::memory::Memory; +use paddler_balancer::state_database_type::StateDatabaseType; +use paddler_balancer::statsd_service::StatsdService; +use paddler_balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::WebAdminPanelService; +use paddler_balancer::web_admin_panel_service::WebAdminPanelService; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; -use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use tokio::sync::broadcast; use trzcina::Service; use trzcina::ServiceBundle; @@ -138,14 +138,13 @@ impl BalancerServiceBundle { is_converted_to_applicable_state: false, }; - let openai_service = openai_service_configuration.map(|openai_service_configuration| { - OpenAIService { + let openai_service = + openai_service_configuration.map(|openai_service_configuration| OpenAIService { buffered_request_manager: buffered_request_manager.clone(), inference_service_configuration, openai_service_configuration, shutdown_options: shutdown_options.clone(), - } - }); + }); let statsd_service = statsd_service_configuration.map(|configuration| StatsdService { agent_controller_pool: agent_controller_pool.clone(), @@ -202,3 +201,78 @@ impl ServiceBundle for BalancerServiceBundle { Ok(services) } } + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + + #[cfg(feature = "web_admin_panel")] + use paddler_balancer::resolved_socket_addr::ResolvedSocketAddr; + #[cfg(feature = "web_admin_panel")] + use paddler_balancer::web_admin_panel_service::template_data::TemplateData; + + use super::*; + + #[cfg(feature = "web_admin_panel")] + const EXPECTED_SERVICE_COUNT: usize = 6; + #[cfg(not(feature = "web_admin_panel"))] + const EXPECTED_SERVICE_COUNT: usize = 5; + + fn loopback_addr() -> SocketAddr { + SocketAddr::from(([127, 0, 0, 1], 0)) + } + + #[tokio::test] + async fn services_includes_every_optional_service_when_configured() { + let bundle = BalancerServiceBundle::new(BalancerBootstrapConfig { + buffered_request_timeout: Duration::from_secs(10), + inference_service_configuration: InferenceServiceConfiguration { + addr: loopback_addr(), + cors_allowed_hosts: vec![], + inference_item_timeout: Duration::from_secs(30), + }, + management_service_configuration: ManagementServiceConfiguration { + addr: loopback_addr(), + cors_allowed_hosts: vec![], + }, + max_buffered_requests: 30, + openai_service_configuration: Some(OpenAIServiceConfiguration { + addr: loopback_addr(), + }), + shutdown_options: ServiceShutdownOptions::default(), + state_database_type: StateDatabaseType::Memory(Box::default()), + statsd_prefix: "paddler_bootstrap_test_".to_owned(), + statsd_service_configuration: Some(StatsdServiceConfiguration { + statsd_addr: loopback_addr(), + statsd_prefix: "paddler_bootstrap_test_".to_owned(), + statsd_reporting_interval: Duration::from_secs(10), + }), + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration: Some(WebAdminPanelServiceConfiguration { + addr: loopback_addr(), + template_data: TemplateData { + buffered_request_timeout: Duration::from_secs(10), + compat_openai_addr: None, + inference_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:0".to_owned(), + socket_addr: loopback_addr(), + }, + management_addr: ResolvedSocketAddr { + input_addr: "127.0.0.1:0".to_owned(), + socket_addr: loopback_addr(), + }, + max_buffered_requests: 30, + statsd_addr: None, + statsd_prefix: "paddler_bootstrap_test_".to_owned(), + statsd_reporting_interval: Duration::from_secs(10), + }, + }), + }) + .await + .unwrap(); + + let services = bundle.services().await.unwrap(); + + assert_eq!(services.len(), EXPECTED_SERVICE_COUNT); + } +} diff --git a/paddler_bootstrap/src/lib.rs b/paddler_bootstrap/src/lib.rs index f98af427..2bd68012 100644 --- a/paddler_bootstrap/src/lib.rs +++ b/paddler_bootstrap/src/lib.rs @@ -2,5 +2,6 @@ pub mod agent_runner; pub mod agent_service_bundle; pub mod balancer_runner; pub mod balancer_service_bundle; +pub mod run_service_manager; pub mod service_thread; pub mod shutdown_signal; diff --git a/paddler_bootstrap/src/run_service_manager.rs b/paddler_bootstrap/src/run_service_manager.rs new file mode 100644 index 00000000..711fda90 --- /dev/null +++ b/paddler_bootstrap/src/run_service_manager.rs @@ -0,0 +1,89 @@ +use anyhow::Result; +use tokio_util::sync::CancellationToken; +use trzcina::ServiceBundle; +use trzcina::ServiceManager; +use trzcina::ServiceShutdownOptions; + +pub async fn run_service_manager( + bundle: TServiceBundle, + task_shutdown: CancellationToken, + shutdown_options: ServiceShutdownOptions, +) -> Result<()> { + let mut service_manager = ServiceManager::default(); + + service_manager.register_bundle(bundle).await?; + service_manager + .start(task_shutdown) + .run_to_completion(shutdown_options) + .await + .into_result() + .map_err(anyhow::Error::from) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use anyhow::anyhow; + use async_trait::async_trait; + use tokio_util::sync::CancellationToken; + use trzcina::Service; + use trzcina::ServiceBundle; + use trzcina::ServiceShutdownOptions; + + use super::run_service_manager; + + struct FailingServiceBundle; + + #[async_trait] + impl ServiceBundle for FailingServiceBundle { + async fn services(self) -> Result>> { + Err(anyhow!("service bundle failed to produce services")) + } + } + + struct FailingService; + + #[async_trait] + impl Service for FailingService { + fn name(&self) -> &'static str { + "failing_service" + } + + async fn run(self: Box, _shutdown: CancellationToken) -> Result<()> { + Err(anyhow!("service run failed")) + } + } + + struct BundleWithFailingService; + + #[async_trait] + impl ServiceBundle for BundleWithFailingService { + async fn services(self) -> Result>> { + Ok(vec![Box::new(FailingService)]) + } + } + + #[tokio::test] + async fn propagates_bundle_registration_error() { + let result = run_service_manager( + FailingServiceBundle, + CancellationToken::new(), + ServiceShutdownOptions::default(), + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn propagates_service_run_error() { + let result = run_service_manager( + BundleWithFailingService, + CancellationToken::new(), + ServiceShutdownOptions::default(), + ) + .await; + + assert!(result.is_err()); + } +} diff --git a/paddler_bootstrap/src/service_thread.rs b/paddler_bootstrap/src/service_thread.rs index f3256a65..9abbe9dd 100644 --- a/paddler_bootstrap/src/service_thread.rs +++ b/paddler_bootstrap/src/service_thread.rs @@ -73,3 +73,97 @@ impl Drop for ServiceThread { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn service_thread_finishes_cleanly_when_completion_receiver_dropped_after_success() { + let cancellation_token = CancellationToken::new(); + let mut service_thread = ServiceThread::spawn( + cancellation_token.clone(), + |task_cancellation_token| async move { + task_cancellation_token.cancelled().await; + + Ok(()) + }, + ); + + drop(service_thread.wait_for_completion()); + + cancellation_token.cancel(); + + drop(service_thread); + } + + #[tokio::test] + async fn service_thread_finishes_cleanly_when_completion_receiver_dropped_after_failure() { + let cancellation_token = CancellationToken::new(); + let mut service_thread = ServiceThread::spawn( + cancellation_token.clone(), + |task_cancellation_token| async move { + task_cancellation_token.cancelled().await; + + Err(anyhow!("service run failed")) + }, + ); + + drop(service_thread.wait_for_completion()); + + cancellation_token.cancel(); + + drop(service_thread); + } + + #[tokio::test] + async fn wait_for_completion_errors_when_service_thread_panics() { + let mut service_thread = + ServiceThread::spawn(CancellationToken::new(), |_task_cancellation_token| async { + panic!("service thread crashed") + }); + + let completion_result = service_thread.wait_for_completion().await; + + assert!(completion_result.is_err()); + } + + #[tokio::test] + async fn wait_for_completion_errors_when_called_twice() { + let cancellation_token = CancellationToken::new(); + let mut service_thread = ServiceThread::spawn( + cancellation_token.clone(), + |task_cancellation_token| async move { + task_cancellation_token.cancelled().await; + + Ok(()) + }, + ); + + cancellation_token.cancel(); + + let first_completion = service_thread.wait_for_completion().await; + let second_completion = service_thread.wait_for_completion().await; + + assert!(first_completion.is_ok()); + assert!(second_completion.is_err()); + } + + #[tokio::test] + async fn cancel_stops_the_running_service_thread() { + let mut service_thread = ServiceThread::spawn( + CancellationToken::new(), + |task_cancellation_token| async move { + task_cancellation_token.cancelled().await; + + Ok(()) + }, + ); + + service_thread.cancel(); + + let completion_result = service_thread.wait_for_completion().await; + + assert!(completion_result.is_ok()); + } +} diff --git a/paddler_bootstrap/src/shutdown_signal/unix.rs b/paddler_bootstrap/src/shutdown_signal/unix.rs index 411d5531..05a4a86b 100644 --- a/paddler_bootstrap/src/shutdown_signal/unix.rs +++ b/paddler_bootstrap/src/shutdown_signal/unix.rs @@ -23,14 +23,77 @@ impl ShutdownSignals { } } -pub fn register_shutdown_signals() -> Result { - let sigterm = signal(SignalKind::terminate()).context("failed to listen for SIGTERM")?; - let sigint = signal(SignalKind::interrupt()).context("failed to listen for SIGINT")?; - let sighup = signal(SignalKind::hangup()).context("failed to listen for SIGHUP")?; +fn listen_for_signal(kind: SignalKind, description: &str) -> Result { + signal(kind).with_context(|| format!("failed to listen for {description}")) +} +fn register_signals( + terminate_kind: SignalKind, + interrupt_kind: SignalKind, + hangup_kind: SignalKind, +) -> Result { Ok(ShutdownSignals { - sigterm, - sigint, - sighup, + sigterm: listen_for_signal(terminate_kind, "SIGTERM")?, + sigint: listen_for_signal(interrupt_kind, "SIGINT")?, + sighup: listen_for_signal(hangup_kind, "SIGHUP")?, }) } + +pub fn register_shutdown_signals() -> Result { + register_signals( + SignalKind::terminate(), + SignalKind::interrupt(), + SignalKind::hangup(), + ) +} + +#[cfg(test)] +mod tests { + use nix::sys::signal::Signal as UnixSignal; + use nix::sys::signal::raise; + use tokio::signal::unix::SignalKind; + + use super::register_shutdown_signals; + use super::register_signals; + + #[tokio::test] + async fn wait_returns_on_each_shutdown_signal() { + for shutdown_signal in [UnixSignal::SIGTERM, UnixSignal::SIGINT, UnixSignal::SIGHUP] { + let shutdown_signals = register_shutdown_signals().unwrap(); + + raise(shutdown_signal).unwrap(); + + shutdown_signals.wait().await.unwrap(); + } + } + + #[tokio::test] + async fn register_signals_errors_for_unregisterable_signal() { + let unregisterable = SignalKind::from_raw(UnixSignal::SIGKILL as i32); + + assert!( + register_signals( + unregisterable, + SignalKind::interrupt(), + SignalKind::hangup() + ) + .is_err() + ); + assert!( + register_signals( + SignalKind::terminate(), + unregisterable, + SignalKind::hangup() + ) + .is_err() + ); + assert!( + register_signals( + SignalKind::terminate(), + SignalKind::interrupt(), + unregisterable + ) + .is_err() + ); + } +} diff --git a/paddler_bootstrap/tests/runners.rs b/paddler_bootstrap/tests/runners.rs index 3567cb91..bb5de94e 100644 --- a/paddler_bootstrap/tests/runners.rs +++ b/paddler_bootstrap/tests/runners.rs @@ -1,26 +1,29 @@ +use std::fs; use std::net::SocketAddr; use std::net::TcpListener; use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::state_database::File as StateDatabaseFile; -use paddler::balancer::state_database::StateDatabase; -use paddler::balancer::state_database_type::StateDatabaseType; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler_balancer::state_database::StateDatabase; +use paddler_balancer::state_database::file::File as StateDatabaseFile; +use paddler_balancer::state_database_type::StateDatabaseType; use paddler_bootstrap::agent_runner::AgentRunner; use paddler_bootstrap::agent_runner::AgentRunnerParams; use paddler_bootstrap::balancer_runner::BalancerRunner; use paddler_bootstrap::balancer_runner::BalancerRunnerParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; use tempfile::NamedTempFile; use tokio::net::TcpStream; use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; +use trzcina::ServiceShutdownOptions; fn pick_free_loopback_addr() -> Result { let probe = @@ -62,6 +65,7 @@ fn make_balancer_runner_params( max_buffered_requests: 30, openai_service_configuration: None, cancellation_token, + shutdown_options: ServiceShutdownOptions::default(), state_database_type: StateDatabaseType::Memory(Box::default()), statsd_prefix: "paddler_bootstrap_test_".to_owned(), statsd_service_configuration: None, @@ -234,3 +238,100 @@ async fn agent_runner_cancels_from_parent_token() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn in_flight_request_is_released_when_balancer_shuts_down() -> Result<()> { + let management_addr = pick_free_loopback_addr()?; + let inference_addr = pick_free_loopback_addr()?; + + let mut params = + make_balancer_runner_params(management_addr, inference_addr, CancellationToken::new()); + + // The request never finds an agent, so it stays buffered (in flight) until this long + // timeout — which is far longer than the short shutdown deadline below. + params.buffered_request_timeout = Duration::from_mins(1); + + // A short, non-zero deadline: if the in-flight request fails to observe shutdown, actix + // runs its graceful drain to this deadline and trzcina aborts the service — the bug. + params.shutdown_options = ServiceShutdownOptions { + cooperative_deadline: Duration::from_secs(2), + abort_deadline: Duration::from_secs(2), + }; + + let mut runner = BalancerRunner::start(params).await?; + + wait_until_bound(inference_addr).await?; + + // No agents are registered, so this request buffers and holds its streaming response open. + let held_response = reqwest::Client::new() + .post(format!( + "http://{inference_addr}/api/v1/continue_from_raw_prompt" + )) + .json(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 10, + raw_prompt: "hold the connection open during shutdown".to_owned(), + }) + .send() + .await + .context("inference request headers should be received")?; + + runner.cancel(); + + let body = held_response + .text() + .await + .context("held in-flight response body should be readable")?; + + let shutdown_result = runner.wait_for_completion().await; + + assert!( + body.contains("shutting down"), + "in-flight request must be released with a shutdown error, got body: {body:?}" + ); + assert!( + shutdown_result.is_ok(), + "balancer shutdown must complete cleanly below the deadline, got: {shutdown_result:?}" + ); + + Ok(()) +} + +#[tokio::test] +async fn agent_runner_completes_after_explicit_cancel() -> Result<()> { + let management_addr = pick_free_loopback_addr()?; + + let mut runner = AgentRunner::start(make_agent_runner_params( + management_addr, + CancellationToken::new(), + )); + + runner.cancel(); + runner.wait_for_completion().await?; + + Ok(()) +} + +#[tokio::test] +async fn balancer_runner_fails_to_start_when_state_database_file_is_corrupt() -> Result<()> { + let corrupt_state_database = NamedTempFile::new()?; + fs::write( + corrupt_state_database.path(), + b"this is not a valid state database", + )?; + + let management_addr = pick_free_loopback_addr()?; + let inference_addr = pick_free_loopback_addr()?; + + let mut params = + make_balancer_runner_params(management_addr, inference_addr, CancellationToken::new()); + + params.state_database_type = + StateDatabaseType::File(corrupt_state_database.path().to_path_buf()); + + let start_result = BalancerRunner::start(params).await; + + assert!(start_result.is_err()); + + Ok(()) +} diff --git a/paddler_cache_dir/src/cached_downloaded_model.rs b/paddler_cache_dir/src/cached_downloaded_model.rs index 8c6bdcf2..2fbd586f 100644 --- a/paddler_cache_dir/src/cached_downloaded_model.rs +++ b/paddler_cache_dir/src/cached_downloaded_model.rs @@ -227,11 +227,7 @@ mod tests { let result = cached.try_acquire_download_lock(); - assert!( - result - .unwrap_err() - .is_another_process_downloading() - ); + assert!(result.unwrap_err().is_another_process_downloading()); } #[test] diff --git a/paddler_cache_dir/src/download_lock_acquisition_error.rs b/paddler_cache_dir/src/download_lock_acquisition_error.rs index edada94c..e89647d6 100644 --- a/paddler_cache_dir/src/download_lock_acquisition_error.rs +++ b/paddler_cache_dir/src/download_lock_acquisition_error.rs @@ -29,8 +29,7 @@ mod tests { #[test] fn is_another_process_downloading_returns_true_only_for_that_variant() { let another_process = DownloadLockAcquisitionError::AnotherProcessIsDownloading; - let io_error = - DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); + let io_error = DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); assert!(another_process.is_another_process_downloading()); assert!(!io_error.is_another_process_downloading()); @@ -38,8 +37,7 @@ mod tests { #[test] fn is_io_returns_true_only_for_io_variant() { - let io_error = - DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); + let io_error = DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); let another_process = DownloadLockAcquisitionError::AnotherProcessIsDownloading; assert!(io_error.is_io()); diff --git a/paddler_cache_dir/src/lib.rs b/paddler_cache_dir/src/lib.rs index d891ebc8..45a09025 100644 --- a/paddler_cache_dir/src/lib.rs +++ b/paddler_cache_dir/src/lib.rs @@ -1,9 +1,4 @@ -mod cache_dir; -mod cached_downloaded_model; -mod cached_downloaded_model_lock; -mod download_lock_acquisition_error; - -pub use crate::cache_dir::CacheDir; -pub use crate::cached_downloaded_model::CachedDownloadedModel; -pub use crate::cached_downloaded_model_lock::CachedDownloadedModelLock; -pub use crate::download_lock_acquisition_error::DownloadLockAcquisitionError; +pub mod cache_dir; +pub mod cached_downloaded_model; +pub mod cached_downloaded_model_lock; +pub mod download_lock_acquisition_error; diff --git a/paddler_cli/Cargo.toml b/paddler_cli/Cargo.toml index 8bec93a2..046f538a 100644 --- a/paddler_cli/Cargo.toml +++ b/paddler_cli/Cargo.toml @@ -19,9 +19,8 @@ clap = { workspace = true } env_logger = { workspace = true } log = { workspace = true } nanoid = { workspace = true } -paddler = { workspace = true } +paddler_balancer = { workspace = true } paddler_bootstrap = { workspace = true } -paddler_types = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } trzcina = { workspace = true } @@ -34,11 +33,11 @@ workspace = true [features] default = [] -cuda = ["paddler/cuda"] -metal = ["paddler/metal"] -vulkan = ["paddler/vulkan"] +cuda = ["paddler_bootstrap/cuda"] +metal = ["paddler_bootstrap/metal"] +vulkan = ["paddler_bootstrap/vulkan"] web_admin_panel = [ "dep:esbuild-metafile", - "paddler/web_admin_panel", + "paddler_balancer/web_admin_panel", "paddler_bootstrap/web_admin_panel", ] diff --git a/paddler_cli/src/cmd/agent.rs b/paddler_cli/src/cmd/agent.rs index 198268ab..8ae21863 100644 --- a/paddler_cli/src/cmd/agent.rs +++ b/paddler_cli/src/cmd/agent.rs @@ -1,14 +1,14 @@ use anyhow::Result; use async_trait::async_trait; use clap::Parser; -use paddler::resolved_socket_addr::ResolvedSocketAddr; +use paddler_balancer::resolved_socket_addr::ResolvedSocketAddr; use paddler_bootstrap::agent_service_bundle::AgentServiceBundle; use tokio_util::sync::CancellationToken; use trzcina::ServiceManager; use trzcina::ServiceShutdownOptions; use super::handler::Handler; -use super::value_parser::parse_socket_addr; +use super::value_parser::parse_socket_addr::parse_socket_addr; #[derive(Parser)] pub struct Agent { diff --git a/paddler_cli/src/cmd/balancer.rs b/paddler_cli/src/cmd/balancer.rs index 94d2248b..b33128fe 100644 --- a/paddler_cli/src/cmd/balancer.rs +++ b/paddler_cli/src/cmd/balancer.rs @@ -3,16 +3,16 @@ use std::time::Duration; use anyhow::Result; use async_trait::async_trait; use clap::Parser; -use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::state_database_type::StateDatabaseType; -use paddler::balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; +use paddler_balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler_balancer::resolved_socket_addr::ResolvedSocketAddr; +use paddler_balancer::state_database_type::StateDatabaseType; +use paddler_balancer::statsd_service::configuration::Configuration as StatsdServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; +use paddler_balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::template_data::TemplateData; -use paddler::resolved_socket_addr::ResolvedSocketAddr; +use paddler_balancer::web_admin_panel_service::template_data::TemplateData; use paddler_bootstrap::balancer_service_bundle::BalancerBootstrapConfig; use paddler_bootstrap::balancer_service_bundle::BalancerServiceBundle; use tokio_util::sync::CancellationToken; @@ -20,8 +20,8 @@ use trzcina::ServiceManager; use trzcina::ServiceShutdownOptions; use super::handler::Handler; -use super::value_parser::parse_duration; -use super::value_parser::parse_socket_addr; +use super::value_parser::parse_duration::parse_duration; +use super::value_parser::parse_socket_addr::parse_socket_addr; #[derive(Parser)] pub struct Balancer { @@ -158,3 +158,28 @@ impl Handler for Balancer { .map_err(anyhow::Error::from) } } + +#[cfg(all(test, feature = "web_admin_panel"))] +mod tests { + use clap::Parser as _; + + use super::Balancer; + + #[test] + fn web_admin_panel_configuration_is_built_from_the_provided_address() { + let balancer = Balancer::parse_from([ + "balancer", + "--web-admin-panel-addr", + "127.0.0.1:8062", + "--max-buffered-requests", + "7", + ]); + + let configuration = balancer + .get_web_admin_panel_service_configuration() + .unwrap(); + + assert_eq!(configuration.addr.port(), 8062); + assert_eq!(configuration.template_data.max_buffered_requests, 7); + } +} diff --git a/paddler_cli/src/cmd/value_parser/mod.rs b/paddler_cli/src/cmd/value_parser/mod.rs index 6133236b..17797fe7 100644 --- a/paddler_cli/src/cmd/value_parser/mod.rs +++ b/paddler_cli/src/cmd/value_parser/mod.rs @@ -1,5 +1,2 @@ -mod parse_duration; -mod parse_socket_addr; - -pub use self::parse_duration::parse_duration; -pub use self::parse_socket_addr::parse_socket_addr; +pub mod parse_duration; +pub mod parse_socket_addr; diff --git a/paddler_cli/src/cmd/value_parser/parse_duration.rs b/paddler_cli/src/cmd/value_parser/parse_duration.rs index 87a35bc0..182244fa 100644 --- a/paddler_cli/src/cmd/value_parser/parse_duration.rs +++ b/paddler_cli/src/cmd/value_parser/parse_duration.rs @@ -7,3 +7,20 @@ pub fn parse_duration(arg: &str) -> Result { Ok(std::time::Duration::from_millis(milliseconds)) } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::parse_duration; + + #[test] + fn parses_milliseconds() { + assert_eq!(parse_duration("1500").unwrap(), Duration::from_millis(1500)); + } + + #[test] + fn rejects_a_non_numeric_value() { + assert!(parse_duration("not-a-number").is_err()); + } +} diff --git a/paddler_cli/src/cmd/value_parser/parse_socket_addr.rs b/paddler_cli/src/cmd/value_parser/parse_socket_addr.rs index f8cad109..152ecf51 100644 --- a/paddler_cli/src/cmd/value_parser/parse_socket_addr.rs +++ b/paddler_cli/src/cmd/value_parser/parse_socket_addr.rs @@ -4,7 +4,7 @@ use std::net::ToSocketAddrs; use anyhow::Result; use anyhow::anyhow; use log::warn; -use paddler::resolved_socket_addr::ResolvedSocketAddr; +use paddler_balancer::resolved_socket_addr::ResolvedSocketAddr; fn resolve_socket_addr(input_addr: &str) -> Result { let addrs: Vec = input_addr.to_socket_addrs()?.collect(); @@ -46,46 +46,38 @@ pub fn parse_socket_addr(input_addr: &str) -> Result { #[cfg(test)] mod tests { - use anyhow::Result; - use crate::cmd::value_parser::parse_socket_addr::parse_socket_addr; #[test] - fn test_parses_ip_and_port_directly() -> Result<()> { - let result = parse_socket_addr("127.0.0.1:8080")?; + fn parses_ip_and_port_directly() { + let result = parse_socket_addr("127.0.0.1:8080").unwrap(); assert_eq!(result.input_addr, "127.0.0.1:8080"); assert_eq!(result.socket_addr.port(), 8080); assert!(result.socket_addr.is_ipv4()); - - Ok(()) } #[test] - fn test_resolves_localhost_via_dns() -> Result<()> { - let result = parse_socket_addr("localhost:9090")?; + fn resolves_localhost_via_dns() { + let result = parse_socket_addr("localhost:9090").unwrap(); assert_eq!(result.input_addr, "localhost:9090"); assert_eq!(result.socket_addr.port(), 9090); - - Ok(()) } #[test] - fn test_rejects_invalid_address() { + fn rejects_invalid_address() { let result = parse_socket_addr("not-a-valid-host-that-does-not-exist.invalid:1234"); assert!(result.is_err()); } #[test] - fn test_parses_ipv6_address() -> Result<()> { - let result = parse_socket_addr("[::1]:8080")?; + fn parses_ipv6_address() { + let result = parse_socket_addr("[::1]:8080").unwrap(); assert_eq!(result.input_addr, "[::1]:8080"); assert_eq!(result.socket_addr.port(), 8080); assert!(result.socket_addr.is_ipv6()); - - Ok(()) } } diff --git a/paddler_cli/src/lib.rs b/paddler_cli/src/lib.rs new file mode 100644 index 00000000..c6bd3849 --- /dev/null +++ b/paddler_cli/src/lib.rs @@ -0,0 +1,72 @@ +mod cmd; + +use anyhow::Result; +use clap::Parser; +use clap::Subcommand; +use cmd::agent::Agent; +use cmd::balancer::Balancer; +use cmd::handler::Handler as _; +#[cfg(feature = "web_admin_panel")] +use esbuild_metafile::instance::initialize_instance; +use paddler_bootstrap::shutdown_signal::register_shutdown_signals; +use tokio_util::sync::CancellationToken; + +#[cfg(feature = "web_admin_panel")] +pub const ESBUILD_META_CONTENTS: &str = include_str!("../../esbuild-meta.json"); + +pub const CUDA_DISCLAIMER_DOCS: &str = " +This software includes NVIDIA CUDA runtime components, subject to the NVIDIA CUDA Toolkit End User License Agreement: https://docs.nvidia.com/cuda/eula/index.html +This software contains source code provided by NVIDIA Corporation. +Paddler is not affiliated with, endorsed by, or sponsored by NVIDIA Corporation."; + +#[derive(Parser)] +#[command(arg_required_else_help(true), version, about, long_about = None)] +#[cfg_attr(feature = "cuda", command(before_help = CUDA_DISCLAIMER_DOCS))] +/// `LLMOps` platform for hosting and scaling open-source LLMs in your own infrastructure +struct Cli { + #[command(subcommand)] + command: Option, +} + +#[expect( + clippy::large_enum_variant, + reason = "clap's #[derive(Subcommand)] requires unboxed `Args` payloads (Box is unsupported by the derive); the command is parsed once at startup, so the variant size difference is immaterial" +)] +#[derive(Subcommand)] +enum Commands { + /// Generates tokens and embeddings; connects to the balancer + Agent(Agent), + /// Distributes incoming requests among agents + Balancer(Balancer), +} + +async fn run_async() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let shutdown_signals = register_shutdown_signals()?; + let shutdown = CancellationToken::new(); + let signal_shutdown = shutdown.clone(); + + tokio::spawn(async move { + if let Err(error) = shutdown_signals.wait().await { + log::error!("shutdown signal listener failed: {error}"); + return; + } + signal_shutdown.cancel(); + }); + + match Cli::parse().command { + Some(Commands::Agent(handler)) => handler.handle(shutdown).await, + Some(Commands::Balancer(handler)) => { + #[cfg(feature = "web_admin_panel")] + initialize_instance(ESBUILD_META_CONTENTS); + + handler.handle(shutdown).await + } + None => Ok(()), + } +} + +pub fn run() -> Result<()> { + actix_web::rt::System::new().block_on(run_async()) +} diff --git a/paddler_cli/src/main.rs b/paddler_cli/src/main.rs index 56fb8e46..ccd58c6a 100644 --- a/paddler_cli/src/main.rs +++ b/paddler_cli/src/main.rs @@ -1,67 +1,3 @@ -mod cmd; - -use anyhow::Result; -use clap::Parser; -use clap::Subcommand; -use cmd::agent::Agent; -use cmd::balancer::Balancer; -use cmd::handler::Handler as _; -#[cfg(feature = "web_admin_panel")] -use esbuild_metafile::instance::initialize_instance; -use paddler_bootstrap::shutdown_signal::register_shutdown_signals; -use tokio_util::sync::CancellationToken; - -#[cfg(feature = "web_admin_panel")] -pub const ESBUILD_META_CONTENTS: &str = include_str!("../../esbuild-meta.json"); - -pub const CUDA_DISCLAIMER_DOCS: &str = " -This software includes NVIDIA CUDA runtime components, -subject to the NVIDIA CUDA Toolkit End User License Agreement: https://docs.nvidia.com/cuda/eula/index.html -This software contains source code provided by NVIDIA Corporation. -Paddler is not affiliated with, endorsed by, or sponsored by NVIDIA Corporation."; - -#[derive(Parser)] -#[command(arg_required_else_help(true), version, about, long_about = None)] -#[cfg_attr(feature = "cuda", command(before_help = CUDA_DISCLAIMER_DOCS))] -/// `LLMOps` platform for hosting and scaling open-source LLMs in your own infrastructure -struct Cli { - #[command(subcommand)] - command: Option, -} - -#[expect(clippy::large_enum_variant)] -#[derive(Subcommand)] -enum Commands { - /// Generates tokens and embeddings; connects to the balancer - Agent(Agent), - /// Distributes incoming requests among agents - Balancer(Balancer), -} - -#[actix_web::main] -async fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - - let shutdown_signals = register_shutdown_signals()?; - let shutdown = CancellationToken::new(); - let signal_shutdown = shutdown.clone(); - - tokio::spawn(async move { - if let Err(error) = shutdown_signals.wait().await { - log::error!("shutdown signal listener failed: {error}"); - return; - } - signal_shutdown.cancel(); - }); - - match Cli::parse().command { - Some(Commands::Agent(handler)) => handler.handle(shutdown).await, - Some(Commands::Balancer(handler)) => { - #[cfg(feature = "web_admin_panel")] - initialize_instance(ESBUILD_META_CONTENTS); - - handler.handle(shutdown).await - } - None => Ok(()), - } +fn main() -> anyhow::Result<()> { + paddler_cli::run() } diff --git a/paddler_cli_tests/Cargo.toml b/paddler_cli_tests/Cargo.toml new file mode 100644 index 00000000..6ce65003 --- /dev/null +++ b/paddler_cli_tests/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "paddler_cli_tests" +authors.workspace = true +description = "Subprocess-isolated multi-agent integration tests for Paddler" +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[[bin]] +name = "paddler_cluster_node" +path = "src/bin/paddler_cluster_node.rs" + +[features] +default = [] +cuda = ["paddler_cli/cuda"] +metal = ["paddler_cli/metal"] +tests_that_use_llms = [] +web_admin_panel = ["paddler_cli/web_admin_panel"] + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +log = { workspace = true } +nix = { workspace = true } +paddler_cli = { workspace = true } +paddler_client = { workspace = true } +paddler_messaging = { workspace = true } +paddler_test_cluster_harness = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +base64 = { workspace = true } +futures-util = { workspace = true } +reqwest = { workspace = true } +serde_json = { workspace = true } +serial_test = { workspace = true } +url = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_cli_tests/src/bin/paddler_cluster_node.rs b/paddler_cli_tests/src/bin/paddler_cluster_node.rs new file mode 100644 index 00000000..ccd58c6a --- /dev/null +++ b/paddler_cli_tests/src/bin/paddler_cluster_node.rs @@ -0,0 +1,3 @@ +fn main() -> anyhow::Result<()> { + paddler_cli::run() +} diff --git a/paddler_cli_tests/src/lib.rs b/paddler_cli_tests/src/lib.rs new file mode 100644 index 00000000..fda7b016 --- /dev/null +++ b/paddler_cli_tests/src/lib.rs @@ -0,0 +1,11 @@ +pub mod model_card; +pub mod paddler_command; +pub mod qwen3_embedding_cluster_params; +pub mod spawn_agent_subprocess; +pub mod spawn_agent_subprocess_params; +pub mod start_subprocess_cluster; +pub mod start_subprocess_cluster_with_qwen3; +pub mod start_subprocess_embedding_cluster; +pub mod subprocess_agent_spawner; +pub mod subprocess_process; +pub mod terminate_child; diff --git a/paddler_cli_tests/src/model_card/mod.rs b/paddler_cli_tests/src/model_card/mod.rs new file mode 100644 index 00000000..cfba79c6 --- /dev/null +++ b/paddler_cli_tests/src/model_card/mod.rs @@ -0,0 +1,9 @@ +pub mod qwen3_0_6b; +pub mod qwen3_embedding_0_6b; + +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; + +pub struct ModelCard { + pub gpu_layer_count: u32, + pub reference: HuggingFaceModelReference, +} diff --git a/paddler_cli_tests/src/model_card/qwen3_0_6b.rs b/paddler_cli_tests/src/model_card/qwen3_0_6b.rs new file mode 100644 index 00000000..4a421502 --- /dev/null +++ b/paddler_cli_tests/src/model_card/qwen3_0_6b.rs @@ -0,0 +1,15 @@ +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn qwen3_0_6b() -> ModelCard { + ModelCard { + gpu_layer_count: 28, + reference: HuggingFaceModelReference { + filename: "Qwen3-0.6B-Q8_0.gguf".to_owned(), + repo_id: "Qwen/Qwen3-0.6B-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_cli_tests/src/model_card/qwen3_embedding_0_6b.rs b/paddler_cli_tests/src/model_card/qwen3_embedding_0_6b.rs new file mode 100644 index 00000000..fcdf63b9 --- /dev/null +++ b/paddler_cli_tests/src/model_card/qwen3_embedding_0_6b.rs @@ -0,0 +1,15 @@ +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; + +use crate::model_card::ModelCard; + +#[must_use] +pub fn qwen3_embedding_0_6b() -> ModelCard { + ModelCard { + gpu_layer_count: 28, + reference: HuggingFaceModelReference { + filename: "Qwen3-Embedding-0.6B-Q8_0.gguf".to_owned(), + repo_id: "Qwen/Qwen3-Embedding-0.6B-GGUF".to_owned(), + revision: "main".to_owned(), + }, + } +} diff --git a/paddler_cli_tests/src/paddler_command.rs b/paddler_cli_tests/src/paddler_command.rs new file mode 100644 index 00000000..fc1b3eb6 --- /dev/null +++ b/paddler_cli_tests/src/paddler_command.rs @@ -0,0 +1,16 @@ +use std::env; + +use tokio::process::Command; + +#[must_use] +pub fn paddler_command(binary_path: &str) -> Command { + let mut command = Command::new(binary_path); + + command.kill_on_drop(true); + + if let Ok(profile_file) = env::var("LLVM_PROFILE_FILE") { + command.env("LLVM_PROFILE_FILE", profile_file); + } + + command +} diff --git a/paddler_cli_tests/src/qwen3_embedding_cluster_params.rs b/paddler_cli_tests/src/qwen3_embedding_cluster_params.rs new file mode 100644 index 00000000..49ee2bc1 --- /dev/null +++ b/paddler_cli_tests/src/qwen3_embedding_cluster_params.rs @@ -0,0 +1,23 @@ +use std::time::Duration; + +use paddler_messaging::inference_parameters::InferenceParameters; + +use paddler_test_cluster_harness::agent_config::AgentConfig; + +pub struct Qwen3EmbeddingClusterParams { + pub agents: Vec, + pub buffered_request_timeout: Duration, + pub inference_parameters: InferenceParameters, + pub max_buffered_requests: i32, +} + +impl Default for Qwen3EmbeddingClusterParams { + fn default() -> Self { + Self { + agents: AgentConfig::uniform(1, 4), + buffered_request_timeout: Duration::from_secs(10), + inference_parameters: InferenceParameters::default(), + max_buffered_requests: 10, + } + } +} diff --git a/paddler_tests/src/spawn_agent_subprocess.rs b/paddler_cli_tests/src/spawn_agent_subprocess.rs similarity index 77% rename from paddler_tests/src/spawn_agent_subprocess.rs rename to paddler_cli_tests/src/spawn_agent_subprocess.rs index af0012fb..daeeac0d 100644 --- a/paddler_tests/src/spawn_agent_subprocess.rs +++ b/paddler_cli_tests/src/spawn_agent_subprocess.rs @@ -9,27 +9,22 @@ use crate::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; pub fn spawn_agent_subprocess( SpawnAgentSubprocessParams { + binary_path, management_addr, name, slots, }: SpawnAgentSubprocessParams, ) -> Result { - let mut command = paddler_command(); - - command + paddler_command(&binary_path) .arg("agent") .arg("--management-addr") .arg(management_addr.to_string()) + .arg("--name") + .arg(name) .arg("--slots") .arg(slots.to_string()) .stdout(Stdio::null()) - .stderr(Stdio::null()); - - if let Some(agent_name) = name { - command.arg("--name").arg(agent_name); - } - - command + .stderr(Stdio::null()) .spawn() .context("failed to spawn paddler agent subprocess") } diff --git a/paddler_tests/src/spawn_agent_subprocess_params.rs b/paddler_cli_tests/src/spawn_agent_subprocess_params.rs similarity index 71% rename from paddler_tests/src/spawn_agent_subprocess_params.rs rename to paddler_cli_tests/src/spawn_agent_subprocess_params.rs index e4084be6..a5a2438f 100644 --- a/paddler_tests/src/spawn_agent_subprocess_params.rs +++ b/paddler_cli_tests/src/spawn_agent_subprocess_params.rs @@ -1,7 +1,8 @@ use std::net::SocketAddr; pub struct SpawnAgentSubprocessParams { + pub binary_path: String, pub management_addr: SocketAddr, - pub name: Option, + pub name: String, pub slots: i32, } diff --git a/paddler_cli_tests/src/start_subprocess_cluster.rs b/paddler_cli_tests/src/start_subprocess_cluster.rs new file mode 100644 index 00000000..ab67e9ed --- /dev/null +++ b/paddler_cli_tests/src/start_subprocess_cluster.rs @@ -0,0 +1,114 @@ +use std::process::Stdio; + +use anyhow::Context as _; +use anyhow::Result; + +use paddler_test_cluster_harness::balancer_addresses::BalancerAddresses; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::running_balancer::RunningBalancer; + +use crate::paddler_command::paddler_command; +use crate::subprocess_agent_spawner::SubprocessAgentSpawner; +use crate::subprocess_process::SubprocessProcess; + +pub async fn start_subprocess_cluster( + binary_path: &str, + ClusterParams { + agents, + buffered_request_timeout, + desired_state, + inference_cors_allowed_hosts, + inference_item_timeout, + management_cors_allowed_hosts, + max_buffered_requests, + state_database_url, + wait_for_slots_ready, + }: ClusterParams, +) -> Result { + let addresses = BalancerAddresses::pick()?; + let management_addr = addresses.management; + + let mut balancer_command = paddler_command(binary_path); + + balancer_command + .arg("balancer") + .arg("--inference-addr") + .arg(addresses.inference.to_string()) + .arg("--management-addr") + .arg(addresses.management.to_string()) + .arg("--compat-openai-addr") + .arg(addresses.compat_openai.to_string()) + .arg("--state-database") + .arg(&state_database_url) + .arg("--max-buffered-requests") + .arg(max_buffered_requests.to_string()) + .arg("--buffered-request-timeout") + .arg(buffered_request_timeout.as_millis().to_string()) + .arg("--inference-item-timeout") + .arg(inference_item_timeout.as_millis().to_string()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + for allowed_host in &inference_cors_allowed_hosts { + balancer_command + .arg("--inference-cors-allowed-host") + .arg(allowed_host); + } + + for allowed_host in &management_cors_allowed_hosts { + balancer_command + .arg("--management-cors-allowed-host") + .arg(allowed_host); + } + + let balancer_subprocess = balancer_command + .spawn() + .context("failed to spawn paddler balancer subprocess")?; + + let running_balancer = RunningBalancer::new( + addresses, + Box::new(SubprocessProcess::new(balancer_subprocess)), + ); + + let mut cluster = Cluster::connect( + running_balancer, + Box::new(SubprocessAgentSpawner::new( + binary_path.to_owned(), + management_addr, + )), + desired_state.as_ref(), + ) + .await?; + + let expected_agent_count = agents.len(); + let mut last_ready_snapshot = None; + + for agent in &agents { + cluster.spawn_additional_agent(agent)?; + + if wait_for_slots_ready { + last_ready_snapshot = Some( + cluster + .wait_for_agent_ready(&agent.name, agent.slot_count) + .await?, + ); + } + } + + let registered_snapshot = match last_ready_snapshot { + Some(snapshot) => snapshot, + None => cluster + .wait_for_agent_count(expected_agent_count) + .await + .context("not all subprocess agents registered")?, + }; + + cluster.agent_ids = registered_snapshot + .agents + .iter() + .map(|registered_agent| registered_agent.id.clone()) + .collect(); + + Ok(cluster) +} diff --git a/paddler_cli_tests/src/start_subprocess_cluster_with_qwen3.rs b/paddler_cli_tests/src/start_subprocess_cluster_with_qwen3.rs new file mode 100644 index 00000000..34e6220b --- /dev/null +++ b/paddler_cli_tests/src/start_subprocess_cluster_with_qwen3.rs @@ -0,0 +1,41 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_0_6b::qwen3_0_6b; +use crate::start_subprocess_cluster::start_subprocess_cluster; + +pub async fn start_subprocess_cluster_with_qwen3( + binary_path: &str, + agents: Vec, +) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + start_subprocess_cluster( + binary_path, + ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }, + ) + .await +} diff --git a/paddler_cli_tests/src/start_subprocess_embedding_cluster.rs b/paddler_cli_tests/src/start_subprocess_embedding_cluster.rs new file mode 100644 index 00000000..51dd5842 --- /dev/null +++ b/paddler_cli_tests/src/start_subprocess_embedding_cluster.rs @@ -0,0 +1,50 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_embedding_0_6b::qwen3_embedding_0_6b; +use crate::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use crate::start_subprocess_cluster::start_subprocess_cluster; + +pub async fn start_subprocess_embedding_cluster( + binary_path: &str, + Qwen3EmbeddingClusterParams { + agents, + buffered_request_timeout, + inference_parameters, + max_buffered_requests, + }: Qwen3EmbeddingClusterParams, +) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_embedding_0_6b(); + + let inference_parameters_with_offload = InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..inference_parameters + }; + + start_subprocess_cluster( + binary_path, + ClusterParams { + agents, + buffered_request_timeout, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: inference_parameters_with_offload, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + max_buffered_requests, + wait_for_slots_ready: true, + ..ClusterParams::default() + }, + ) + .await +} diff --git a/paddler_cli_tests/src/subprocess_agent_spawner.rs b/paddler_cli_tests/src/subprocess_agent_spawner.rs new file mode 100644 index 00000000..0093c242 --- /dev/null +++ b/paddler_cli_tests/src/subprocess_agent_spawner.rs @@ -0,0 +1,39 @@ +use std::net::SocketAddr; + +use anyhow::Result; + +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::agent_spawner::AgentSpawner; +use paddler_test_cluster_harness::managed_process::ManagedProcess; + +use crate::spawn_agent_subprocess::spawn_agent_subprocess; +use crate::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; +use crate::subprocess_process::SubprocessProcess; + +pub struct SubprocessAgentSpawner { + binary_path: String, + management_addr: SocketAddr, +} + +impl SubprocessAgentSpawner { + #[must_use] + pub const fn new(binary_path: String, management_addr: SocketAddr) -> Self { + Self { + binary_path, + management_addr, + } + } +} + +impl AgentSpawner for SubprocessAgentSpawner { + fn spawn(&self, config: &AgentConfig) -> Result> { + let child = spawn_agent_subprocess(SpawnAgentSubprocessParams { + binary_path: self.binary_path.clone(), + management_addr: self.management_addr, + name: config.name.clone(), + slots: config.slot_count, + })?; + + Ok(Box::new(SubprocessProcess::new(child))) + } +} diff --git a/paddler_cli_tests/src/subprocess_process.rs b/paddler_cli_tests/src/subprocess_process.rs new file mode 100644 index 00000000..ea76d83b --- /dev/null +++ b/paddler_cli_tests/src/subprocess_process.rs @@ -0,0 +1,37 @@ +use anyhow::Result; +use async_trait::async_trait; +use log::warn; +use tokio::process::Child; + +use paddler_test_cluster_harness::managed_process::ManagedProcess; + +use crate::terminate_child::terminate_child; + +pub struct SubprocessProcess { + child: Child, +} + +impl SubprocessProcess { + #[must_use] + pub const fn new(child: Child) -> Self { + Self { child } + } +} + +#[async_trait] +impl ManagedProcess for SubprocessProcess { + async fn shutdown(&mut self) -> Result<()> { + terminate_child(&mut self.child)?; + self.child.wait().await?; + + Ok(()) + } +} + +impl Drop for SubprocessProcess { + fn drop(&mut self) { + if let Err(error) = terminate_child(&mut self.child) { + warn!("SubprocessProcess drop: failed to terminate subprocess: {error:#}"); + } + } +} diff --git a/paddler_tests/src/terminate_child.rs b/paddler_cli_tests/src/terminate_child.rs similarity index 100% rename from paddler_tests/src/terminate_child.rs rename to paddler_cli_tests/src/terminate_child.rs diff --git a/paddler_cli_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs b/paddler_cli_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs new file mode 100644 index 00000000..250ceb6b --- /dev/null +++ b/paddler_cli_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::time::Duration; + +use anyhow::Result; +use futures_util::StreamExt as _; +use paddler_cli_tests::model_card::ModelCard; +use paddler_cli_tests::model_card::qwen3_0_6b::qwen3_0_6b; +use paddler_cli_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_buffered_requests_across_two_agents() -> Result<()> { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + let cluster = start_subprocess_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + ClusterParams { + agents: vec![ + AgentConfig { + name: "distributed-agent-0".to_owned(), + slot_count: 2, + }, + AgentConfig { + name: "distributed-agent-1".to_owned(), + slot_count: 2, + }, + ], + wait_for_slots_ready: true, + buffered_request_timeout: Duration::from_mins(2), + max_buffered_requests: 10, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..ClusterParams::default() + }, + ) + .await?; + + let mut streams = Vec::new(); + + for _ in 0..5 { + let stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 10, + raw_prompt: "Hello".to_owned(), + }) + .await?; + + streams.push(stream); + } + + let mut successful_responses = 0; + + for mut stream in streams { + if let Some(item) = stream.next().await { + match item? { + Message::Response(_) => successful_responses += 1, + Message::Error(envelope) => { + anyhow::bail!( + "expected success, got error {}: {}", + envelope.error.code, + envelope.error.description + ); + } + } + } + } + + assert_eq!(successful_responses, 5); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents.rs b/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents.rs new file mode 100644 index 00000000..bce06beb --- /dev/null +++ b/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents.rs @@ -0,0 +1,62 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::collections::BTreeSet; + +use anyhow::Result; +use paddler_cli_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_cli_tests::start_subprocess_embedding_cluster::start_subprocess_embedding_cluster; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_embedding_batch_across_agents() -> Result<()> { + let cluster = start_subprocess_embedding_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(2, 4), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }, + ) + .await?; + + let filler = "x".repeat(380); + let input_batch: Vec = (0..12) + .map(|index| EmbeddingInputDocument { + content: format!("Document number {index:02}: {filler}"), + id: format!("doc-{index}"), + }) + .collect(); + let params = GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }; + + let collected = cluster.generate_embedding_batch(¶ms).await?; + + assert_eq!(collected.embeddings.len(), 12); + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert!( + producers.len() >= 2, + "expected the embedding batch to be distributed across at least two agents, but only saw producers: {producers:?}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs b/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs new file mode 100644 index 00000000..939ef3f6 --- /dev/null +++ b/paddler_cli_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs @@ -0,0 +1,89 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::collections::BTreeSet; + +use anyhow::Result; +use paddler_cli_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_cli_tests::start_subprocess_embedding_cluster::start_subprocess_embedding_cluster; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_embedding_batch_across_agents_with_uneven_slots() -> Result<()> { + let cluster = start_subprocess_embedding_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + Qwen3EmbeddingClusterParams { + agents: vec![ + AgentConfig { + name: "agent-fat".to_owned(), + slot_count: 4, + }, + AgentConfig { + name: "agent-thin-a".to_owned(), + slot_count: 1, + }, + AgentConfig { + name: "agent-medium".to_owned(), + slot_count: 2, + }, + AgentConfig { + name: "agent-thin-b".to_owned(), + slot_count: 1, + }, + ], + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }, + ) + .await?; + + let input_batch: Vec = (0..8) + .map(|index| EmbeddingInputDocument { + content: format!("Uneven-slot document number {index}."), + id: format!("doc-{index}"), + }) + .collect(); + + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) + .await?; + + assert_eq!(collected.embeddings.len(), 8); + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + + let returned_document_ids: BTreeSet = collected + .embeddings + .iter() + .map(|produced| produced.embedding.source_document_id.clone()) + .collect(); + let expected_document_ids: BTreeSet = + (0..8).map(|index| format!("doc-{index}")).collect(); + assert_eq!(returned_document_ids, expected_document_ids); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers.len(), + 4, + "embedding batch must fan out across all agents even when slot counts are uneven, but only saw producers: {producers:?}", + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_cli_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs b/paddler_cli_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs new file mode 100644 index 00000000..1148b432 --- /dev/null +++ b/paddler_cli_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs @@ -0,0 +1,78 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::collections::BTreeSet; + +use std::time::Duration; + +use anyhow::Result; +use futures_util::future; +use paddler_cli_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_cli_tests::start_subprocess_embedding_cluster::start_subprocess_embedding_cluster; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_distributes_embedding_burst_evenly_across_agents() -> Result<()> { + const AGENT_COUNT: usize = 4; + const SLOTS_PER_AGENT: i32 = 2; + const CONCURRENT_REQUESTS: usize = 8; + + let cluster = start_subprocess_embedding_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT), + buffered_request_timeout: Duration::from_mins(1), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + max_buffered_requests: 32, + }, + ) + .await?; + + let collection_futures = (0..CONCURRENT_REQUESTS).map(|request_index| { + let input_batch: Vec = (0..4) + .map(|document_index| EmbeddingInputDocument { + content: format!( + "Burst request {request_index}, document {document_index}: \ + provide an embedding for evaluation." + ), + id: format!("req-{request_index}-doc-{document_index}"), + }) + .collect(); + + cluster.generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) + }); + + let collected_streams = future::try_join_all(collection_futures).await?; + + let producers_across_streams: BTreeSet<&str> = collected_streams + .iter() + .flat_map(|collected| collected.embeddings.iter()) + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers_across_streams.len(), + AGENT_COUNT, + "burst of {CONCURRENT_REQUESTS} embedding batches across {AGENT_COUNT} agents must reach every agent, but saw producers: {producers_across_streams:?}", + ); + + for collected in &collected_streams { + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + assert_eq!(collected.embeddings.len(), 4); + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs b/paddler_cli_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs similarity index 61% rename from paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs rename to paddler_cli_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs index cf5cded4..8345fe6f 100644 --- a/paddler_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs +++ b/paddler_cli_tests/tests/balancer_distributes_token_burst_evenly_across_agents.rs @@ -1,19 +1,13 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::collections::BTreeSet; use anyhow::Result; use anyhow::anyhow; use futures_util::future; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_cli_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -21,31 +15,22 @@ async fn balancer_distributes_token_burst_evenly_across_agents() -> Result<()> { const AGENT_COUNT: usize = 4; const SLOTS_PER_AGENT: i32 = 1; - let cluster = - start_subprocess_cluster_with_qwen3(AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT)) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_subprocess_cluster_with_qwen3( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT), + ) + .await?; let prompts: Vec = (0..AGENT_COUNT) .map(|index| format!("Burst request number {index}: Count from one to five.")) .collect(); let collection_futures = prompts.iter().map(|prompt| { - let inference_client = inference_client.clone(); - let raw_prompt = prompt.clone(); - async move { - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 16, - raw_prompt, - }) - .await?; - - collect_generated_tokens(stream).await - } + cluster.continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: prompt.clone(), + }) }); let collected_streams = future::try_join_all(collection_futures).await?; diff --git a/paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs b/paddler_cli_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs similarity index 54% rename from paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs rename to paddler_cli_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs index a8b42453..ba7f4531 100644 --- a/paddler_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs +++ b/paddler_cli_tests/tests/balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffered_requests.rs @@ -1,22 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Result; use anyhow::anyhow; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_cli_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_cli_tests::start_subprocess_embedding_cluster::start_subprocess_embedding_cluster; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; use tokio::time::timeout; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] @@ -25,21 +19,21 @@ async fn balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffere -> Result<()> { const TOTAL_DOCUMENTS: usize = 16; - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(4, 1), - buffered_request_timeout: Duration::from_secs(2), - inference_parameters: InferenceParameters { - embedding_batch_size: 1, - enable_embeddings: true, - ..InferenceParameters::default() + let cluster = start_subprocess_embedding_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(4, 1), + buffered_request_timeout: Duration::from_secs(2), + inference_parameters: InferenceParameters { + embedding_batch_size: 1, + enable_embeddings: true, + ..InferenceParameters::default() + }, + max_buffered_requests: 4, }, - max_buffered_requests: 4, - }) + ) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let input_batch: Vec = (0..TOTAL_DOCUMENTS) .map(|index| EmbeddingInputDocument { content: format!("Overflow probe document {index}."), @@ -47,16 +41,15 @@ async fn balancer_emits_overflow_errors_when_embedding_burst_exceeds_max_buffere }) .collect(); - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = timeout( + Duration::from_secs(15), + cluster.generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch, normalization_method: EmbeddingNormalizationMethod::None, - }) - .await?; - - let collected = timeout(Duration::from_secs(15), collect_embedding_results(stream)) - .await - .map_err(|_| anyhow!("burst-overflow embedding stream did not finish within 15s"))??; + }), + ) + .await + .map_err(|_| anyhow!("burst-overflow embedding stream did not finish within 15s"))??; let overflow_errors: Vec<_> = collected .wire_errors diff --git a/paddler_cli_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs b/paddler_cli_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs new file mode 100644 index 00000000..38ba3a4a --- /dev/null +++ b/paddler_cli_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs @@ -0,0 +1,65 @@ +#![cfg(feature = "tests_that_use_llms")] + +use std::collections::BTreeSet; + +use anyhow::Result; +use paddler_cli_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_cli_tests::start_subprocess_embedding_cluster::start_subprocess_embedding_cluster; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_fans_out_embedding_batch_to_all_agents() -> Result<()> { + let agent_count: usize = 4; + + let cluster = start_subprocess_embedding_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + Qwen3EmbeddingClusterParams { + agents: AgentConfig::uniform(agent_count, 2), + inference_parameters: InferenceParameters { + enable_embeddings: true, + ..InferenceParameters::default() + }, + ..Qwen3EmbeddingClusterParams::default() + }, + ) + .await?; + + let filler = "x".repeat(380); + let input_batch: Vec = (0..16) + .map(|index| EmbeddingInputDocument { + content: format!("Document number {index:02}: {filler}"), + id: format!("doc-{index}"), + }) + .collect(); + let params = GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }; + + let collected = cluster.generate_embedding_batch(¶ms).await?; + + assert_eq!(collected.embeddings.len(), 16); + assert!(collected.saw_done); + assert!(collected.errors.is_empty()); + + let producers: BTreeSet<&str> = collected + .embeddings + .iter() + .filter_map(|produced| produced.generated_by.as_deref()) + .collect(); + + assert_eq!( + producers.len(), + agent_count, + "expected the embedding batch to fan out across every agent, but only saw producers: {producers:?}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_cli_tests/tests/balancer_registers_multiple_agents_over_time.rs b/paddler_cli_tests/tests/balancer_registers_multiple_agents_over_time.rs new file mode 100644 index 00000000..6b70c49d --- /dev/null +++ b/paddler_cli_tests/tests/balancer_registers_multiple_agents_over_time.rs @@ -0,0 +1,42 @@ +use anyhow::Context as _; +use anyhow::Result; +use paddler_cli_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +#[tokio::test(flavor = "multi_thread")] +async fn balancer_registers_multiple_agents_over_time() -> Result<()> { + let mut cluster = start_subprocess_cluster( + env!("CARGO_BIN_EXE_paddler_cluster_node"), + ClusterParams { + agents: Vec::new(), + wait_for_slots_ready: false, + ..ClusterParams::default() + }, + ) + .await?; + + cluster.spawn_additional_agent(&AgentConfig { + name: "test-agent-1".to_owned(), + slot_count: 1, + })?; + + cluster + .wait_for_agent_count(1) + .await + .context("first agent should register")?; + + cluster.spawn_additional_agent(&AgentConfig { + name: "test-agent-2".to_owned(), + slot_count: 1, + })?; + + cluster + .wait_for_agent_count(2) + .await + .context("second agent should register")?; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_client/Cargo.toml b/paddler_client/Cargo.toml index 033f16aa..46fe0e72 100644 --- a/paddler_client/Cargo.toml +++ b/paddler_client/Cargo.toml @@ -14,7 +14,7 @@ dashmap = { workspace = true } futures-util = { workspace = true } log = { workspace = true } nanoid = { workspace = true } -paddler_types = { workspace = true } +paddler_messaging = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -24,6 +24,9 @@ tokio-stream = { workspace = true } tokio-tungstenite = { workspace = true } url = { workspace = true } +[dev-dependencies] +http = { workspace = true } + [lints] workspace = true diff --git a/paddler_client/src/agents_stream.rs b/paddler_client/src/agents_stream.rs index 5f2dfa26..7db15fdd 100644 --- a/paddler_client/src/agents_stream.rs +++ b/paddler_client/src/agents_stream.rs @@ -1,8 +1,8 @@ use std::pin::Pin; use futures_util::Stream; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; -use crate::Result; +use crate::error::Result; pub type AgentsStream = Pin> + Send>>; diff --git a/paddler_client/src/buffered_requests_stream.rs b/paddler_client/src/buffered_requests_stream.rs index 383177bb..f6bef507 100644 --- a/paddler_client/src/buffered_requests_stream.rs +++ b/paddler_client/src/buffered_requests_stream.rs @@ -1,9 +1,9 @@ use std::pin::Pin; use futures_util::Stream; -use paddler_types::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; -use crate::Result; +use crate::error::Result; pub type BufferedRequestsStream = Pin> + Send>>; diff --git a/paddler_client/src/client_inference.rs b/paddler_client/src/client_inference.rs index 86fe6ddf..3df78b80 100644 --- a/paddler_client/src/client_inference.rs +++ b/paddler_client/src/client_inference.rs @@ -1,19 +1,19 @@ use std::sync::OnceLock; use nanoid::nanoid; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::inference_server::Message as InferenceServerMessage; -use paddler_types::inference_server::Request as InferenceServerRequest; -use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_server::message::Message as InferenceServerMessage; +use paddler_messaging::inference_server::request::Request as InferenceServerRequest; +use paddler_messaging::jsonrpc::request_envelope::RequestEnvelope; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use reqwest::Client; use tokio_stream::wrappers::UnboundedReceiverStream; use url::Url; -use crate::Result; +use crate::error::Result; use crate::format_api_url::format_api_url; use crate::inference_message_stream::InferenceMessageStream; use crate::inference_socket::pool::Pool; @@ -44,7 +44,7 @@ impl<'client> ClientInference<'client> { pub async fn get_health(&self) -> Result { let response = self .http_client - .get(format_api_url(self.url, "/health")?) + .get(format_api_url(self.url, "/health")) .send() .await? .error_for_status()?; @@ -102,7 +102,7 @@ impl<'client> ClientInference<'client> { .post(format_api_url( self.url, "/api/v1/continue_from_conversation_history", - )?) + )) .json(params) .send() .await? @@ -119,10 +119,7 @@ impl<'client> ClientInference<'client> { ) -> Result { let response = self .http_client - .post(format_api_url( - self.url, - "/api/v1/continue_from_raw_prompt", - )?) + .post(format_api_url(self.url, "/api/v1/continue_from_raw_prompt")) .json(params) .send() .await? @@ -139,10 +136,7 @@ impl<'client> ClientInference<'client> { ) -> Result { let response = self .http_client - .post(format_api_url( - self.url, - "/api/v1/generate_embedding_batch", - )?) + .post(format_api_url(self.url, "/api/v1/generate_embedding_batch")) .json(params) .send() .await? @@ -153,3 +147,64 @@ impl<'client> ClientInference<'client> { Ok(Box::pin(stream)) } } + +#[cfg(test)] +mod tests { + use paddler_messaging::conversation_history::ConversationHistory; + use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; + use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; + use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + use reqwest::Client; + use url::Url; + + use super::ClientInference; + + fn raw_prompt_params() -> ContinueFromRawPromptParams { + ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: "hello".to_owned(), + } + } + + fn conversation_history_params() + -> ContinueFromConversationHistoryParams { + ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(Vec::new()), + enable_thinking: false, + grammar: None, + max_tokens: 16, + parse_tool_calls: false, + tools: Vec::new(), + } + } + + #[tokio::test] + async fn continue_from_raw_prompt_errors_for_an_unreachable_server() { + let url = Url::parse("http://127.0.0.1:1").unwrap(); + let http_client = Client::new(); + let inference = ClientInference::new(&url, &http_client, 1); + + assert!( + inference + .continue_from_raw_prompt(raw_prompt_params()) + .await + .is_err() + ); + } + + #[tokio::test] + async fn continue_from_conversation_history_errors_for_an_unreachable_server() { + let url = Url::parse("http://127.0.0.1:1").unwrap(); + let http_client = Client::new(); + let inference = ClientInference::new(&url, &http_client, 1); + + assert!( + inference + .continue_from_conversation_history(conversation_history_params()) + .await + .is_err() + ); + } +} diff --git a/paddler_client/src/client_management.rs b/paddler_client/src/client_management.rs index a84d4a71..adfa4be9 100644 --- a/paddler_client/src/client_management.rs +++ b/paddler_client/src/client_management.rs @@ -1,17 +1,19 @@ use futures_util::StreamExt; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; -use paddler_types::agent_desired_state::AgentDesiredState; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::model_metadata::ModelMetadata; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_desired_state::AgentDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::model_metadata::ModelMetadata; use reqwest::Client; +use reqwest::Response; +use serde::de::DeserializeOwned; use serde_json::from_str; use url::Url; -use crate::Result; use crate::agents_stream::AgentsStream; use crate::buffered_requests_stream::BufferedRequestsStream; +use crate::error::Result; use crate::format_api_url::format_api_url; use crate::stream::sse::Sse; @@ -26,56 +28,42 @@ impl<'client> ClientManagement<'client> { Self { url, http_client } } - pub async fn get_health(&self) -> Result { - let response = self + async fn get(&self, path: &str) -> Result { + Ok(self .http_client - .get(format_api_url(self.url, "/health")?) + .get(format_api_url(self.url, path)) .send() .await? - .error_for_status()?; + .error_for_status()?) + } - Ok(response.text().await?) + async fn get_text(&self, path: &str) -> Result { + Ok(self.get(path).await?.text().await?) } - pub async fn get_agents(&self) -> Result { - let response = self - .http_client - .get(format_api_url(self.url, "/api/v1/agents")?) - .send() - .await? - .error_for_status()?; + async fn get_json(&self, path: &str) -> Result { + Ok(self.get(path).await?.json().await?) + } - Ok(response.json().await?) + pub async fn get_health(&self) -> Result { + self.get_text("/health").await } - pub async fn get_balancer_desired_state(&self) -> Result { - let response = self - .http_client - .get(format_api_url(self.url, "/api/v1/balancer_desired_state")?) - .send() - .await? - .error_for_status()?; + pub async fn get_agents(&self) -> Result { + self.get_json("/api/v1/agents").await + } - Ok(response.json().await?) + pub async fn get_balancer_desired_state(&self) -> Result { + self.get_json("/api/v1/balancer_desired_state").await } pub async fn get_balancer_applicable_state(&self) -> Result> { - let response = self - .http_client - .get(format_api_url( - self.url, - "/api/v1/balancer_applicable_state", - )?) - .send() - .await? - .error_for_status()?; - - Ok(response.json().await?) + self.get_json("/api/v1/balancer_applicable_state").await } pub async fn put_balancer_desired_state(&self, state: &BalancerDesiredState) -> Result<()> { self.http_client - .put(format_api_url(self.url, "/api/v1/balancer_desired_state")?) + .put(format_api_url(self.url, "/api/v1/balancer_desired_state")) .json(state) .send() .await? @@ -85,23 +73,11 @@ impl<'client> ClientManagement<'client> { } pub async fn get_buffered_requests(&self) -> Result { - let response = self - .http_client - .get(format_api_url(self.url, "/api/v1/buffered_requests")?) - .send() - .await? - .error_for_status()?; - - Ok(response.json().await?) + self.get_json("/api/v1/buffered_requests").await } pub async fn get_agents_stream(&self) -> Result { - let response = self - .http_client - .get(format_api_url(self.url, "/api/v1/agents/stream")?) - .send() - .await? - .error_for_status()?; + let response = self.get("/api/v1/agents/stream").await?; let stream = Sse::from_response(response) .map(|result| result.and_then(|data| from_str(&data).map_err(Into::into))); @@ -110,15 +86,7 @@ impl<'client> ClientManagement<'client> { } pub async fn get_buffered_requests_stream(&self) -> Result { - let response = self - .http_client - .get(format_api_url( - self.url, - "/api/v1/buffered_requests/stream", - )?) - .send() - .await? - .error_for_status()?; + let response = self.get("/api/v1/buffered_requests/stream").await?; let stream = Sse::from_response(response) .map(|result| result.and_then(|data| from_str(&data).map_err(Into::into))); @@ -127,41 +95,16 @@ impl<'client> ClientManagement<'client> { } pub async fn get_chat_template_override(&self, agent_id: &str) -> Result> { - let response = self - .http_client - .get(format_api_url( - self.url, - &format!("/api/v1/agent/{agent_id}/chat_template_override"), - )?) - .send() - .await? - .error_for_status()?; - - Ok(response.json().await?) + self.get_json(&format!("/api/v1/agent/{agent_id}/chat_template_override")) + .await } pub async fn get_model_metadata(&self, agent_id: &str) -> Result> { - let response = self - .http_client - .get(format_api_url( - self.url, - &format!("/api/v1/agent/{agent_id}/model_metadata"), - )?) - .send() - .await? - .error_for_status()?; - - Ok(response.json().await?) + self.get_json(&format!("/api/v1/agent/{agent_id}/model_metadata")) + .await } pub async fn get_metrics(&self) -> Result { - let response = self - .http_client - .get(format_api_url(self.url, "/metrics")?) - .send() - .await? - .error_for_status()?; - - Ok(response.text().await?) + self.get_text("/metrics").await } } diff --git a/paddler_client/src/error.rs b/paddler_client/src/error.rs index fd84c39a..de282c30 100644 --- a/paddler_client/src/error.rs +++ b/paddler_client/src/error.rs @@ -1,5 +1,8 @@ #[derive(Debug, thiserror::Error)] -#[expect(clippy::error_impl_error)] +#[expect( + clippy::error_impl_error, + reason = "the crate's single public error type is idiomatically named `Error` and correctly implements std::error::Error; renaming would be smurf naming and break the module-named-after-its-item convention" +)] pub enum Error { #[error("HTTP request failed: {0}")] Http(#[from] reqwest::Error), @@ -26,10 +29,4 @@ pub enum Error { Other(String), } -impl From for Error { - fn from(err: anyhow::Error) -> Self { - Self::Other(err.to_string()) - } -} - pub type Result = std::result::Result; diff --git a/paddler_client/src/format_api_url.rs b/paddler_client/src/format_api_url.rs index 6907d76f..9d69ac89 100644 --- a/paddler_client/src/format_api_url.rs +++ b/paddler_client/src/format_api_url.rs @@ -1,18 +1,7 @@ use url::Url; -use crate::error::Error; -use crate::error::Result; - -pub fn format_api_url(base_url: &Url, path: &str) -> Result { - if !path.starts_with('/') { - return Err(Error::Other(format!("path must start with '/': {path}"))); - } - - Ok(format!( - "{}{}", - base_url.as_str().trim_end_matches('/'), - path, - )) +pub fn format_api_url(base_url: &Url, path: &str) -> String { + format!("{}{}", base_url.as_str().trim_end_matches('/'), path) } #[cfg(test)] @@ -20,27 +9,14 @@ mod tests { use url::Url; use super::format_api_url; - use crate::error::Error; #[test] - fn test_formats_valid_url() -> std::result::Result<(), Error> { - let base_url = Url::parse("http://localhost:8080")?; + fn joins_the_path_onto_the_trimmed_base() { + let base_url = Url::parse("http://localhost:8080/").unwrap(); assert_eq!( - format_api_url(&base_url, "/api/v1/health")?, + format_api_url(&base_url, "/api/v1/health"), "http://localhost:8080/api/v1/health" ); - - Ok(()) - } - - #[test] - fn test_rejects_path_without_leading_slash() -> std::result::Result<(), Error> { - let base_url = Url::parse("http://localhost:8080")?; - let result = format_api_url(&base_url, "api/v1/health"); - - assert!(result.is_err()); - - Ok(()) } } diff --git a/paddler_client/src/inference_message_stream.rs b/paddler_client/src/inference_message_stream.rs index ac33bc73..f505bc58 100644 --- a/paddler_client/src/inference_message_stream.rs +++ b/paddler_client/src/inference_message_stream.rs @@ -1,9 +1,9 @@ use std::pin::Pin; use futures_util::Stream; -use paddler_types::inference_client::Message as InferenceMessage; +use paddler_messaging::inference_client::message::Message as InferenceMessage; -use crate::Result; +use crate::error::Result; pub type InferenceMessageStream = Pin> + Send + 'static>>; diff --git a/paddler_client/src/inference_socket/connection.rs b/paddler_client/src/inference_socket/connection.rs index 32500a98..3414ea08 100644 --- a/paddler_client/src/inference_socket/connection.rs +++ b/paddler_client/src/inference_socket/connection.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use dashmap::DashMap; use futures_util::StreamExt; -use paddler_types::inference_client::Message as InferenceMessage; +use paddler_messaging::inference_client::message::Message as InferenceMessage; use tokio::sync::mpsc; use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedSender; @@ -65,3 +65,17 @@ impl Connection { Ok(response_rx) } } + +#[cfg(test)] +mod tests { + use url::Url; + + use super::Connection; + + #[tokio::test] + async fn connect_fails_for_an_unreachable_server() { + let url = Url::parse("http://127.0.0.1:1").unwrap(); + + assert!(Connection::connect(url).await.is_err()); + } +} diff --git a/paddler_client/src/inference_socket/pending_requests.rs b/paddler_client/src/inference_socket/pending_requests.rs index a73661d8..76934a80 100644 --- a/paddler_client/src/inference_socket/pending_requests.rs +++ b/paddler_client/src/inference_socket/pending_requests.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use dashmap::DashMap; -use paddler_types::inference_client::Message as InferenceMessage; +use paddler_messaging::inference_client::message::Message as InferenceMessage; use tokio::sync::mpsc::UnboundedSender; use crate::error::Result; diff --git a/paddler_client/src/inference_socket/pool.rs b/paddler_client/src/inference_socket/pool.rs index cb701c70..bbea0b98 100644 --- a/paddler_client/src/inference_socket/pool.rs +++ b/paddler_client/src/inference_socket/pool.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use paddler_types::inference_client::Message as InferenceMessage; +use paddler_messaging::inference_client::message::Message as InferenceMessage; use serde::Serialize; use serde_json::to_string; use tokio::sync::Mutex; @@ -88,3 +88,18 @@ impl Pool { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::Pool; + + #[tokio::test] + async fn round_robins_across_connection_slots() { + let pool = Pool::new(url::Url::parse("http://127.0.0.1:1").unwrap(), 3); + + assert_eq!(pool.next_connection_index().await, 0); + assert_eq!(pool.next_connection_index().await, 1); + assert_eq!(pool.next_connection_index().await, 2); + assert_eq!(pool.next_connection_index().await, 0); + } +} diff --git a/paddler_client/src/inference_socket/spawn_read_task.rs b/paddler_client/src/inference_socket/spawn_read_task.rs index 7262c258..3bf33be4 100644 --- a/paddler_client/src/inference_socket/spawn_read_task.rs +++ b/paddler_client/src/inference_socket/spawn_read_task.rs @@ -2,9 +2,9 @@ use futures_util::StreamExt; use futures_util::stream::SplitStream; use log::error; use log::warn; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::inference_client::Response; -use paddler_types::streamable_result::StreamableResult; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_client::response::Response; +use paddler_messaging::streamable_result::StreamableResult; use serde_json::from_str; use tokio::task::JoinHandle; use tokio_tungstenite::MaybeTlsStream; diff --git a/paddler_client/src/inference_socket/url.rs b/paddler_client/src/inference_socket/url.rs index 6bdbd461..31f2669e 100644 --- a/paddler_client/src/inference_socket/url.rs +++ b/paddler_client/src/inference_socket/url.rs @@ -26,50 +26,37 @@ mod tests { use url::Url; use super::url; - use crate::error::Result; #[test] - fn test_http_becomes_ws() -> Result<()> { - let input = Url::parse("http://localhost:8080/some/path")?; - let result = url(input)?; + fn http_becomes_ws() { + let result = url(Url::parse("http://localhost:8080/some/path").unwrap()).unwrap(); assert_eq!(result.scheme(), "ws"); assert_eq!(result.path(), "/api/v1/inference_socket"); assert_eq!(result.host_str(), Some("localhost")); assert_eq!(result.port(), Some(8080)); - - Ok(()) } #[test] - fn test_https_becomes_wss() -> Result<()> { - let input = Url::parse("https://example.com/ignored")?; - let result = url(input)?; + fn https_becomes_wss() { + let result = url(Url::parse("https://example.com/ignored").unwrap()).unwrap(); assert_eq!(result.scheme(), "wss"); assert_eq!(result.path(), "/api/v1/inference_socket"); - - Ok(()) } #[test] - fn test_ws_scheme_preserved() -> Result<()> { - let input = Url::parse("ws://localhost:9090")?; - let result = url(input)?; + fn ws_scheme_preserved() { + let result = url(Url::parse("ws://localhost:9090").unwrap()).unwrap(); assert_eq!(result.scheme(), "ws"); assert_eq!(result.path(), "/api/v1/inference_socket"); - - Ok(()) } #[test] - fn test_original_path_replaced() -> Result<()> { - let input = Url::parse("http://host/deeply/nested/path?query=1")?; - let result = url(input)?; + fn original_path_replaced() { + let result = url(Url::parse("http://host/deeply/nested/path?query=1").unwrap()).unwrap(); assert_eq!(result.path(), "/api/v1/inference_socket"); - - Ok(()) } } diff --git a/paddler_client/src/lib.rs b/paddler_client/src/lib.rs index 0939d664..6ceb5c03 100644 --- a/paddler_client/src/lib.rs +++ b/paddler_client/src/lib.rs @@ -1,23 +1,18 @@ -mod agents_stream; -mod buffered_requests_stream; -mod client_inference; -mod client_management; -mod error; +pub mod agents_stream; +pub mod buffered_requests_stream; +pub mod client_inference; +pub mod client_management; +pub mod error; mod format_api_url; -mod inference_message_stream; +pub mod inference_message_stream; mod inference_socket; mod stream; use reqwest::Client; use url::Url; -pub use agents_stream::AgentsStream; -pub use buffered_requests_stream::BufferedRequestsStream; -pub use client_inference::ClientInference; -pub use client_management::ClientManagement; -pub use error::Error; -pub use error::Result; -pub use inference_message_stream::InferenceMessageStream; +use crate::client_inference::ClientInference; +use crate::client_management::ClientManagement; pub struct PaddlerClient { inference_url: Url, diff --git a/paddler_client/src/stream/ndjson.rs b/paddler_client/src/stream/ndjson.rs index a60aa164..80e392da 100644 --- a/paddler_client/src/stream/ndjson.rs +++ b/paddler_client/src/stream/ndjson.rs @@ -9,7 +9,7 @@ use reqwest::Response; use serde::de::DeserializeOwned; use serde_json::from_str; -use crate::Result; +use crate::error::Result; fn make_stream( response: Response, @@ -77,3 +77,95 @@ impl Stream for Ndjson { self.inner.as_mut().poll_next(cx) } } + +#[cfg(test)] +mod tests { + use std::io::Error as IoError; + use std::io::ErrorKind; + + use futures_util::StreamExt as _; + use serde_json::Value; + use serde_json::json; + + use super::Ndjson; + use crate::error::Result; + + fn response_from_chunks( + chunks: Vec>, + ) -> reqwest::Response { + let stream = futures_util::stream::iter( + chunks + .into_iter() + .map(|chunk| chunk.map(|text| text.as_bytes().to_vec())), + ); + + reqwest::Response::from(http::Response::new(reqwest::Body::wrap_stream(stream))) + } + + async fn collect_items( + chunks: Vec>, + ) -> Vec> { + Ndjson::::from_response(response_from_chunks(chunks)) + .collect() + .await + } + + #[tokio::test] + async fn parses_multiple_lines_in_one_chunk() { + let items = collect_items(vec![Ok("{\"a\":1}\n{\"a\":2}\n")]).await; + + assert_eq!(items.len(), 2); + assert_eq!(*items[0].as_ref().unwrap(), json!({ "a": 1 })); + assert_eq!(*items[1].as_ref().unwrap(), json!({ "a": 2 })); + } + + #[tokio::test] + async fn reassembles_a_line_split_across_chunks() { + let items = collect_items(vec![Ok("{\"a\""), Ok(":1}\n")]).await; + + assert_eq!(items.len(), 1); + assert_eq!(*items[0].as_ref().unwrap(), json!({ "a": 1 })); + } + + #[tokio::test] + async fn skips_blank_lines() { + let items = collect_items(vec![Ok("\n \n{\"a\":1}\n")]).await; + + assert_eq!(items.len(), 1); + assert_eq!(*items[0].as_ref().unwrap(), json!({ "a": 1 })); + } + + #[tokio::test] + async fn parses_trailing_remainder_without_newline() { + let items = collect_items(vec![Ok("{\"a\":1}")]).await; + + assert_eq!(items.len(), 1); + assert_eq!(*items[0].as_ref().unwrap(), json!({ "a": 1 })); + } + + #[tokio::test] + async fn empty_response_yields_no_items() { + let items = collect_items(vec![]).await; + + assert!(items.is_empty()); + } + + #[tokio::test] + async fn malformed_line_yields_error() { + let items = collect_items(vec![Ok("not json\n")]).await; + + assert_eq!(items.len(), 1); + assert!(items[0].is_err()); + } + + #[tokio::test] + async fn stream_error_yields_error() { + let items = collect_items(vec![ + Ok("{\"a\""), + Err(IoError::new(ErrorKind::ConnectionReset, "boom")), + ]) + .await; + + assert!(items.iter().any(Result::is_err)); + } +} diff --git a/paddler_client/src/stream/sse.rs b/paddler_client/src/stream/sse.rs index ea73b31a..0e513642 100644 --- a/paddler_client/src/stream/sse.rs +++ b/paddler_client/src/stream/sse.rs @@ -6,7 +6,7 @@ use futures_util::Stream; use futures_util::stream::unfold; use reqwest::Response; -use crate::Result; +use crate::error::Result; fn make_stream(response: Response) -> impl Stream> + Send { unfold( @@ -60,3 +60,85 @@ impl Stream for Sse { self.lines.as_mut().poll_next(cx) } } + +#[cfg(test)] +mod tests { + use std::io::Error as IoError; + use std::io::ErrorKind; + + use futures_util::StreamExt as _; + + use super::Sse; + use crate::error::Result; + + fn response_from_chunks( + chunks: Vec>, + ) -> reqwest::Response { + let stream = futures_util::stream::iter( + chunks + .into_iter() + .map(|chunk| chunk.map(|text| text.as_bytes().to_vec())), + ); + + reqwest::Response::from(http::Response::new(reqwest::Body::wrap_stream(stream))) + } + + async fn collect_lines( + chunks: Vec>, + ) -> Vec> { + Sse::from_response(response_from_chunks(chunks)) + .collect() + .await + } + + #[tokio::test] + async fn yields_data_payloads() { + let lines = collect_lines(vec![Ok("data: hello\ndata: world\n")]).await; + + assert_eq!(lines.len(), 2); + assert_eq!(lines[0].as_ref().unwrap(), "hello"); + assert_eq!(lines[1].as_ref().unwrap(), "world"); + } + + #[tokio::test] + async fn strips_trailing_carriage_return() { + let lines = collect_lines(vec![Ok("data: hello\r\n")]).await; + + assert_eq!(lines.len(), 1); + assert_eq!(lines[0].as_ref().unwrap(), "hello"); + } + + #[tokio::test] + async fn skips_non_data_lines() { + let lines = collect_lines(vec![Ok("event: ping\ndata: kept\n")]).await; + + assert_eq!(lines.len(), 1); + assert_eq!(lines[0].as_ref().unwrap(), "kept"); + } + + #[tokio::test] + async fn reassembles_a_payload_split_across_chunks() { + let lines = collect_lines(vec![Ok("data: hel"), Ok("lo\n")]).await; + + assert_eq!(lines.len(), 1); + assert_eq!(lines[0].as_ref().unwrap(), "hello"); + } + + #[tokio::test] + async fn empty_response_yields_no_lines() { + let lines = collect_lines(vec![]).await; + + assert!(lines.is_empty()); + } + + #[tokio::test] + async fn stream_error_yields_error() { + let lines = collect_lines(vec![ + Ok("data: partial"), + Err(IoError::new(ErrorKind::ConnectionReset, "boom")), + ]) + .await; + + assert!(lines.iter().any(Result::is_err)); + } +} diff --git a/paddler_client_cli/Cargo.toml b/paddler_client_cli/Cargo.toml deleted file mode 100644 index 9b56ade1..00000000 --- a/paddler_client_cli/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "paddler_client_cli" -version.workspace = true -edition.workspace = true -authors.workspace = true -description = "Client CLI/TUI binary for Paddler" -license.workspace = true - -[[bin]] -name = "paddler_client_cli" -path = "src/main.rs" - -[dependencies] -anyhow = { workspace = true } -async-trait = { workspace = true } -clap = { workspace = true } -crossterm = { workspace = true } -env_logger = { workspace = true } -futures-util = { workspace = true } -llama-cpp-bindings-types = { workspace = true } -log = { workspace = true } -paddler_bootstrap = { workspace = true } -paddler_client = { workspace = true } -paddler_types = { workspace = true } -ratatui = { workspace = true } -reqwest = { workspace = true } -serde_json = { workspace = true } -tokio = { workspace = true } -tokio-util = { workspace = true } -url = { workspace = true } - -[lints] -workspace = true diff --git a/paddler_client_cli/examples/calculator.json b/paddler_client_cli/examples/calculator.json deleted file mode 100644 index a857e795..00000000 --- a/paddler_client_cli/examples/calculator.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "type": "function", - "function": { - "name": "calculator", - "description": "Perform a basic arithmetic operation on two numbers", - "parameters": { - "type": "object", - "properties": { - "operation": { - "type": "string", - "enum": ["add", "subtract", "multiply", "divide"], - "description": "The operation to perform" - }, - "a": { - "type": "number", - "description": "First operand" - }, - "b": { - "type": "number", - "description": "Second operand" - } - }, - "required": ["operation", "a", "b"], - "additionalProperties": false - } - } -} diff --git a/paddler_client_cli/examples/get_weather.json b/paddler_client_cli/examples/get_weather.json deleted file mode 100644 index d3948491..00000000 --- a/paddler_client_cli/examples/get_weather.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather for a city", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name, e.g. 'Paris' or 'Tokyo'" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature unit" - } - }, - "required": ["location"], - "additionalProperties": false - } - } -} diff --git a/paddler_client_cli/examples/negotiate_with_cat.json b/paddler_client_cli/examples/negotiate_with_cat.json deleted file mode 100644 index 6b74c15f..00000000 --- a/paddler_client_cli/examples/negotiate_with_cat.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "type": "function", - "function": { - "name": "negotiate_with_cat", - "description": "Attempt to negotiate with a cat. Outcomes are not guaranteed and may include the silent treatment.", - "parameters": { - "type": "object", - "properties": { - "topic": { - "type": "string", - "description": "What you are trying to negotiate, e.g. 'get off the keyboard' or 'stop knocking things off the table'" - }, - "bribe": { - "type": "string", - "enum": ["tuna", "salmon", "treats", "ear_scritches", "cardboard_box", "none"], - "description": "What you are offering in exchange" - }, - "desperation_level": { - "type": "integer", - "description": "How desperate you are, on a scale from 1 (mildly annoyed human) to 10 (it is 3am)", - "minimum": 1, - "maximum": 10 - } - }, - "required": ["topic"], - "additionalProperties": false - } - } -} diff --git a/paddler_client_cli/src/chat_session.rs b/paddler_client_cli/src/chat_session.rs deleted file mode 100644 index 5742499a..00000000 --- a/paddler_client_cli/src/chat_session.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::io; - -use anyhow::Result; -use anyhow::anyhow; -use crossterm::event::Event as CrosstermEvent; -use crossterm::event::EventStream; -use crossterm::event::KeyCode; -use crossterm::event::KeyEvent; -use crossterm::event::KeyModifiers; -use crossterm::event::MouseButton; -use crossterm::event::MouseEvent; -use crossterm::event::MouseEventKind; -use futures_util::StreamExt; -use paddler_client::InferenceMessageStream; -use ratatui::Terminal; -use ratatui::backend::CrosstermBackend; -use ratatui::layout::Rect; -use tokio_util::sync::CancellationToken; - -use crate::chat_session_event::ChatSessionEvent; -use crate::streaming_response::StreamingResponse; -use crate::view_chat_panels::view_chat_panels; -use crate::view_panel_layout::ViewPanelLayout; -use crate::view_panel_navigation::ViewPanelNavigation; -use crate::view_terminal_guard::ViewTerminalGuard; - -const MOUSE_WHEEL_LINES: u16 = 3; -const ARROW_KEY_LINES: u16 = 1; - -pub struct ChatSession { - inference_stream: InferenceMessageStream, - state: StreamingResponse, - navigation: ViewPanelNavigation, - shutdown: CancellationToken, -} - -impl ChatSession { - pub fn new(inference_stream: InferenceMessageStream, shutdown: CancellationToken) -> Self { - Self { - inference_stream, - state: StreamingResponse::default(), - navigation: ViewPanelNavigation::default(), - shutdown, - } - } - - pub async fn run(mut self) -> Result<()> { - let _terminal_guard = ViewTerminalGuard::enter()?; - let mut terminal = Terminal::new(CrosstermBackend::new(io::stdout()))?; - let mut events = EventStream::new(); - - let mut layout = compute_layout(&terminal)?; - terminal.draw(|frame| { - view_chat_panels(&self.state, &mut self.navigation, &layout, frame); - })?; - - loop { - match self.next_event(&mut events).await { - ChatSessionEvent::InferenceMessage(message) => { - self.state.apply_message(message); - } - ChatSessionEvent::InferenceStreamEnded => { - if !self.state.is_finished() { - self.state.record_wire_error(&anyhow!( - "inference stream ended before sending Done" - )); - } - } - ChatSessionEvent::InferenceStreamError(error) => { - self.state.record_wire_error(&error); - } - ChatSessionEvent::Key(key_event) => { - if is_quit(key_event) { - return Ok(()); - } - self.handle_navigation_key(key_event, &layout); - } - ChatSessionEvent::Mouse(mouse_event) => { - self.handle_mouse(mouse_event, &layout); - } - ChatSessionEvent::Repaint => {} - ChatSessionEvent::Shutdown => return Ok(()), - } - layout = compute_layout(&terminal)?; - terminal.draw(|frame| { - view_chat_panels(&self.state, &mut self.navigation, &layout, frame); - })?; - } - } - - async fn next_event(&mut self, events: &mut EventStream) -> ChatSessionEvent { - let inference_active = !self.state.is_finished(); - loop { - tokio::select! { - biased; - () = self.shutdown.cancelled() => return ChatSessionEvent::Shutdown, - maybe_event = events.next() => match maybe_event { - Some(Ok(CrosstermEvent::Key(key))) => return ChatSessionEvent::Key(key), - Some(Ok(CrosstermEvent::Mouse(mouse))) => return ChatSessionEvent::Mouse(mouse), - Some(Ok(CrosstermEvent::Resize(_, _))) => return ChatSessionEvent::Repaint, - Some(Ok(_)) => {} - Some(Err(read_error)) => { - log::error!("terminal event read error: {read_error}"); - return ChatSessionEvent::Shutdown; - } - None => return ChatSessionEvent::Shutdown, - }, - maybe_message = self.inference_stream.next(), if inference_active => match maybe_message { - Some(Ok(message)) => return ChatSessionEvent::InferenceMessage(message), - Some(Err(stream_error)) => return ChatSessionEvent::InferenceStreamError(stream_error.into()), - None => return ChatSessionEvent::InferenceStreamEnded, - }, - } - } - } - - fn handle_navigation_key(&mut self, key_event: KeyEvent, layout: &ViewPanelLayout) { - let focused = self.navigation.focused(); - let viewport_rows = layout.viewport_rows(focused); - let page_lines = viewport_rows.saturating_sub(1).max(1); - match key_event.code { - KeyCode::Up => self.navigation.scroll_up(focused, ARROW_KEY_LINES), - KeyCode::Down => self.navigation.scroll_down(focused, ARROW_KEY_LINES), - KeyCode::PageUp => self.navigation.scroll_up(focused, page_lines), - KeyCode::PageDown => self.navigation.scroll_down(focused, page_lines), - KeyCode::Home => self.navigation.jump_to_top(focused), - KeyCode::End => self.navigation.jump_to_bottom(focused), - KeyCode::Tab => self.navigation.cycle_focus_forward(), - KeyCode::BackTab => self.navigation.cycle_focus_backward(), - _ => {} - } - } - - fn handle_mouse(&mut self, mouse_event: MouseEvent, layout: &ViewPanelLayout) { - let Some(panel) = layout.panel_at(mouse_event.column, mouse_event.row) else { - return; - }; - match mouse_event.kind { - MouseEventKind::ScrollUp => { - self.navigation.focus(panel); - self.navigation.scroll_up(panel, MOUSE_WHEEL_LINES); - } - MouseEventKind::ScrollDown => { - self.navigation.focus(panel); - self.navigation.scroll_down(panel, MOUSE_WHEEL_LINES); - } - MouseEventKind::Down(MouseButton::Left) => { - self.navigation.focus(panel); - } - _ => {} - } - } -} - -fn compute_layout(terminal: &Terminal>) -> Result { - let size = terminal.size()?; - Ok(ViewPanelLayout::compute(Rect::new( - 0, - 0, - size.width, - size.height, - ))) -} - -const fn is_quit(key_event: KeyEvent) -> bool { - if key_event.modifiers.contains(KeyModifiers::CONTROL) - && matches!(key_event.code, KeyCode::Char('c')) - { - return true; - } - matches!(key_event.code, KeyCode::Char('q' | 'Q') | KeyCode::Esc) -} diff --git a/paddler_client_cli/src/chat_session_event.rs b/paddler_client_cli/src/chat_session_event.rs deleted file mode 100644 index 8c1e0edd..00000000 --- a/paddler_client_cli/src/chat_session_event.rs +++ /dev/null @@ -1,14 +0,0 @@ -use crossterm::event::KeyEvent; -use crossterm::event::MouseEvent; -use paddler_types::inference_client::Message; - -#[derive(Debug)] -pub enum ChatSessionEvent { - InferenceMessage(Message), - InferenceStreamEnded, - InferenceStreamError(anyhow::Error), - Key(KeyEvent), - Mouse(MouseEvent), - Repaint, - Shutdown, -} diff --git a/paddler_client_cli/src/cmd/handler.rs b/paddler_client_cli/src/cmd/handler.rs deleted file mode 100644 index 2840065c..00000000 --- a/paddler_client_cli/src/cmd/handler.rs +++ /dev/null @@ -1,8 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use tokio_util::sync::CancellationToken; - -#[async_trait] -pub trait Handler { - async fn handle(&self, shutdown: CancellationToken) -> Result<()>; -} diff --git a/paddler_client_cli/src/cmd/mod.rs b/paddler_client_cli/src/cmd/mod.rs deleted file mode 100644 index bc27b40c..00000000 --- a/paddler_client_cli/src/cmd/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod handler; -pub mod prompt; diff --git a/paddler_client_cli/src/cmd/prompt.rs b/paddler_client_cli/src/cmd/prompt.rs deleted file mode 100644 index 13efb6ed..00000000 --- a/paddler_client_cli/src/cmd/prompt.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::path::PathBuf; - -use anyhow::Result; -use async_trait::async_trait; -use clap::Parser; -use paddler_client::ClientInference; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; -use tokio_util::sync::CancellationToken; -use url::Url; - -use super::handler::Handler; -use crate::chat_session::ChatSession; -use crate::prompt_load_tool::prompt_load_tool; -use crate::prompt_parse_inference_url::prompt_parse_inference_url; -use crate::prompt_thinking_mode::PromptThinkingMode; - -#[derive(Parser)] -pub struct Prompt { - #[arg(long, value_parser = prompt_parse_inference_url)] - /// Address of the inference server (e.g. 127.0.0.1:8061) - inference_addr: Url, - - #[arg(long)] - /// Maximum number of tokens to generate - max_tokens: i32, - - #[arg(long, value_enum)] - /// Whether chain-of-thought thinking is on or off - thinking: PromptThinkingMode, - - #[arg(long, action = clap::ArgAction::Append)] - /// Path to a JSON file describing one tool (repeatable) - tool: Vec, - - /// Prompt to send to the model - message: String, -} - -#[async_trait] -impl Handler for Prompt { - async fn handle(&self, shutdown: CancellationToken) -> Result<()> { - let tools = self - .tool - .iter() - .map(|path| prompt_load_tool(path)) - .collect::>>()?; - - let request = ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new(vec![ConversationMessage { - content: ConversationMessageContent::Text(self.message.clone()), - role: "user".to_owned(), - }]), - enable_thinking: self.thinking.is_enabled(), - grammar: None, - max_tokens: self.max_tokens, - parse_tool_calls: !tools.is_empty(), - tools, - }; - - let http_client = Client::new(); - let inference = ClientInference::new(&self.inference_addr, &http_client, 1); - let stream = inference - .post_continue_from_conversation_history(&request) - .await?; - - ChatSession::new(stream, shutdown).run().await - } -} diff --git a/paddler_client_cli/src/main.rs b/paddler_client_cli/src/main.rs deleted file mode 100644 index fb47974c..00000000 --- a/paddler_client_cli/src/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -mod chat_session; -mod chat_session_event; -mod cmd; -mod prompt_load_tool; -mod prompt_parse_inference_url; -mod prompt_thinking_mode; -mod stop_reason; -mod streaming_response; -mod view_chat_panels; -mod view_panel_kind; -mod view_panel_layout; -mod view_panel_navigation; -mod view_terminal_guard; - -use anyhow::Result; -use clap::Parser; -use clap::Subcommand; -use cmd::handler::Handler as _; -use cmd::prompt::Prompt; -use paddler_bootstrap::shutdown_signal::register_shutdown_signals; -use tokio_util::sync::CancellationToken; - -#[derive(Parser)] -#[command(arg_required_else_help(true), version, about, long_about = None)] -struct Cli { - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand)] -enum Commands { - Prompt(Prompt), -} - -#[tokio::main] -async fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - - let shutdown_signals = register_shutdown_signals()?; - let shutdown = CancellationToken::new(); - let signal_shutdown = shutdown.clone(); - - tokio::spawn(async move { - if let Err(error) = shutdown_signals.wait().await { - log::error!("shutdown signal listener failed: {error}"); - return; - } - signal_shutdown.cancel(); - }); - - match Cli::parse().command { - Commands::Prompt(handler) => handler.handle(shutdown).await, - } -} diff --git a/paddler_client_cli/src/prompt_load_tool.rs b/paddler_client_cli/src/prompt_load_tool.rs deleted file mode 100644 index 54a20786..00000000 --- a/paddler_client_cli/src/prompt_load_tool.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::fs::File; -use std::path::Path; - -use anyhow::Context; -use anyhow::Result; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use paddler_types::validates::Validates; - -pub fn prompt_load_tool(path: &Path) -> Result> { - let file = File::open(path).with_context(|| format!("opening tool file {}", path.display()))?; - let raw: Tool = serde_json::from_reader(file) - .with_context(|| format!("parsing tool file {}", path.display()))?; - raw.validate() - .with_context(|| format!("validating tool from {}", path.display())) -} diff --git a/paddler_client_cli/src/prompt_parse_inference_url.rs b/paddler_client_cli/src/prompt_parse_inference_url.rs deleted file mode 100644 index 2678ed7c..00000000 --- a/paddler_client_cli/src/prompt_parse_inference_url.rs +++ /dev/null @@ -1,6 +0,0 @@ -use url::Url; - -pub fn prompt_parse_inference_url(input_addr: &str) -> Result { - Url::parse(&format!("http://{input_addr}")) - .map_err(|err| format!("invalid address '{input_addr}': {err}")) -} diff --git a/paddler_client_cli/src/prompt_thinking_mode.rs b/paddler_client_cli/src/prompt_thinking_mode.rs deleted file mode 100644 index 573a2a64..00000000 --- a/paddler_client_cli/src/prompt_thinking_mode.rs +++ /dev/null @@ -1,14 +0,0 @@ -use clap::ValueEnum; - -#[derive(Clone, Copy, ValueEnum)] -pub enum PromptThinkingMode { - On, - Off, -} - -impl PromptThinkingMode { - #[must_use] - pub const fn is_enabled(self) -> bool { - matches!(self, Self::On) - } -} diff --git a/paddler_client_cli/src/stop_reason.rs b/paddler_client_cli/src/stop_reason.rs deleted file mode 100644 index 391ca9f4..00000000 --- a/paddler_client_cli/src/stop_reason.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::fmt; - -use paddler_types::oversized_image_details::OversizedImageDetails; - -#[derive(Debug)] -pub enum StopReason { - Completed, - ChatTemplateError(String), - GrammarIncompatibleWithThinking(String), - GrammarInitializationFailed(String), - GrammarRejectedModelOutput(String), - GrammarSyntaxError(String), - ImageDecodingFailed(String), - ImageExceedsBatchSize(OversizedImageDetails), - InferenceError { code: i32, description: String }, - MultimodalNotSupported(String), - SamplerError(String), - Timeout, - TooManyBufferedRequests, - ToolCallParseFailed(String), - ToolCallValidationFailed(Vec), - ToolSchemaInvalid(String), - WireStreamError(String), -} - -impl fmt::Display for StopReason { - fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Completed => formatter.write_str("completed"), - Self::ChatTemplateError(detail) => { - write!(formatter, "chat template error: {detail}") - } - Self::GrammarIncompatibleWithThinking(detail) => { - write!(formatter, "grammar incompatible with thinking: {detail}") - } - Self::GrammarInitializationFailed(detail) => { - write!(formatter, "grammar initialization failed: {detail}") - } - Self::GrammarRejectedModelOutput(detail) => { - write!(formatter, "grammar rejected model output: {detail}") - } - Self::GrammarSyntaxError(detail) => { - write!(formatter, "grammar syntax error: {detail}") - } - Self::ImageDecodingFailed(detail) => { - write!(formatter, "image decoding failed: {detail}") - } - Self::ImageExceedsBatchSize(details) => { - write!( - formatter, - "image required {} tokens but agent n_batch is {}", - details.image_tokens, details.n_batch, - ) - } - Self::InferenceError { code, description } => { - write!(formatter, "inference error {code}: {description}") - } - Self::MultimodalNotSupported(detail) => { - write!(formatter, "multimodal input not supported: {detail}") - } - Self::SamplerError(detail) => write!(formatter, "sampler error: {detail}"), - Self::Timeout => formatter.write_str("balancer timed out the request"), - Self::TooManyBufferedRequests => { - formatter.write_str("balancer rejected the request: queue is full") - } - Self::ToolCallParseFailed(detail) => { - write!(formatter, "tool-call parse failed: {detail}") - } - Self::ToolCallValidationFailed(field_errors) => { - write!( - formatter, - "tool-call validation failed: {}", - field_errors.join("; ") - ) - } - Self::ToolSchemaInvalid(detail) => { - write!(formatter, "tool schema invalid: {detail}") - } - Self::WireStreamError(detail) => { - write!(formatter, "wire stream error: {detail}") - } - } - } -} diff --git a/paddler_client_cli/src/streaming_response.rs b/paddler_client_cli/src/streaming_response.rs deleted file mode 100644 index eb321e53..00000000 --- a/paddler_client_cli/src/streaming_response.rs +++ /dev/null @@ -1,229 +0,0 @@ -use llama_cpp_bindings_types::ParsedToolCall; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::generation_summary::GenerationSummary; -use paddler_types::inference_client::Message; -use paddler_types::inference_client::Response; -use paddler_types::raw_tool_call_tokens::RawToolCallTokens; - -use crate::stop_reason::StopReason; - -#[derive(Debug, Default)] -pub struct StreamingResponse { - pub thinking: Vec, - pub response: Vec, - pub tool_call_tokens: Vec, - pub tool_calls: Vec, - pub undetermined: Vec, - pub unrecognized_tool_call_format: Vec, - pub summary: Option, - pub stop_reason: Option, -} - -impl StreamingResponse { - pub fn apply_message(&mut self, message: Message) { - match message { - Message::Error(envelope) => { - self.stop_reason = Some(StopReason::InferenceError { - code: envelope.error.code, - description: envelope.error.description, - }); - } - Message::Response(envelope) => self.apply_response(envelope.response), - } - } - - pub fn record_wire_error(&mut self, error: &anyhow::Error) { - self.stop_reason = Some(StopReason::WireStreamError(error.to_string())); - } - - pub const fn is_finished(&self) -> bool { - self.stop_reason.is_some() - } - - fn apply_response(&mut self, response: Response) { - match response { - Response::GeneratedToken(token_result) => self.apply_token_result(token_result), - Response::Timeout => { - self.stop_reason = Some(StopReason::Timeout); - } - Response::TooManyBufferedRequests => { - self.stop_reason = Some(StopReason::TooManyBufferedRequests); - } - Response::Embedding(_) => { - unreachable!("server sent an embedding response on a token-generation stream") - } - } - } - - fn apply_token_result(&mut self, token_result: GeneratedTokenResult) { - match token_result { - GeneratedTokenResult::ContentToken(piece) => self.response.push(piece), - GeneratedTokenResult::ReasoningToken(piece) => self.thinking.push(piece), - GeneratedTokenResult::UndeterminableToken(piece) => self.undetermined.push(piece), - GeneratedTokenResult::ToolCallToken(piece) => self.tool_call_tokens.push(piece), - GeneratedTokenResult::ToolCallParsed(calls) => { - self.tool_calls.extend(calls); - } - GeneratedTokenResult::Done(summary) => { - self.summary = Some(summary); - self.stop_reason = Some(StopReason::Completed); - } - GeneratedTokenResult::ChatTemplateError(detail) => { - self.stop_reason = Some(StopReason::ChatTemplateError(detail)); - } - GeneratedTokenResult::GrammarIncompatibleWithThinking(detail) => { - self.stop_reason = Some(StopReason::GrammarIncompatibleWithThinking(detail)); - } - GeneratedTokenResult::GrammarInitializationFailed(detail) => { - self.stop_reason = Some(StopReason::GrammarInitializationFailed(detail)); - } - GeneratedTokenResult::GrammarRejectedModelOutput(detail) => { - self.stop_reason = Some(StopReason::GrammarRejectedModelOutput(detail)); - } - GeneratedTokenResult::GrammarSyntaxError(detail) => { - self.stop_reason = Some(StopReason::GrammarSyntaxError(detail)); - } - GeneratedTokenResult::ImageDecodingFailed(detail) => { - self.stop_reason = Some(StopReason::ImageDecodingFailed(detail)); - } - GeneratedTokenResult::ImageExceedsBatchSize(details) => { - self.stop_reason = Some(StopReason::ImageExceedsBatchSize(details)); - } - GeneratedTokenResult::MultimodalNotSupported(detail) => { - self.stop_reason = Some(StopReason::MultimodalNotSupported(detail)); - } - GeneratedTokenResult::SamplerError(detail) => { - self.stop_reason = Some(StopReason::SamplerError(detail)); - } - GeneratedTokenResult::ToolCallParseFailed(detail) => { - self.stop_reason = Some(StopReason::ToolCallParseFailed(detail)); - } - GeneratedTokenResult::ToolCallValidationFailed(field_errors) => { - self.stop_reason = Some(StopReason::ToolCallValidationFailed(field_errors)); - } - GeneratedTokenResult::ToolSchemaInvalid(detail) => { - self.stop_reason = Some(StopReason::ToolSchemaInvalid(detail)); - } - GeneratedTokenResult::UnrecognizedToolCallFormat(raw) => { - self.unrecognized_tool_call_format.push(raw); - } - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::anyhow; - use paddler_types::jsonrpc::Error; - use paddler_types::jsonrpc::ErrorEnvelope; - use paddler_types::jsonrpc::ResponseEnvelope; - - use super::*; - - fn token_message(token_result: GeneratedTokenResult) -> Message { - Message::Response(ResponseEnvelope { - generated_by: None, - request_id: "test-request".to_owned(), - response: Response::GeneratedToken(token_result), - }) - } - - #[test] - fn content_token_appended_to_response_stream() { - let mut state = StreamingResponse::default(); - - state.apply_message(token_message(GeneratedTokenResult::ContentToken( - "hello ".to_owned(), - ))); - state.apply_message(token_message(GeneratedTokenResult::ContentToken( - "world".to_owned(), - ))); - - assert_eq!( - state.response, - vec!["hello ".to_owned(), "world".to_owned()] - ); - assert!(state.thinking.is_empty()); - assert!(state.undetermined.is_empty()); - assert!(!state.is_finished()); - } - - #[test] - fn raw_tool_call_token_appended_to_token_stream() { - let mut state = StreamingResponse::default(); - - state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( - "{\"name\":".to_owned(), - ))); - state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( - "\"calc\"}".to_owned(), - ))); - - assert_eq!( - state.tool_call_tokens, - vec!["{\"name\":".to_owned(), "\"calc\"}".to_owned()] - ); - assert!(state.tool_calls.is_empty()); - } - - #[test] - fn tool_call_parsed_extends_calls_without_dropping_token_stream() { - let mut state = StreamingResponse::default(); - state.apply_message(token_message(GeneratedTokenResult::ToolCallToken( - "{\"name\":\"calc\"}".to_owned(), - ))); - let parsed = vec![ParsedToolCall::default()]; - - state.apply_message(token_message(GeneratedTokenResult::ToolCallParsed( - parsed.clone(), - ))); - - assert_eq!(state.tool_calls, parsed); - assert_eq!( - state.tool_call_tokens, - vec!["{\"name\":\"calc\"}".to_owned()] - ); - } - - #[test] - fn done_records_summary_and_completed_stop_reason() { - let mut state = StreamingResponse::default(); - let summary = GenerationSummary::default(); - - state.apply_message(token_message(GeneratedTokenResult::Done(summary))); - - assert!(state.summary.is_some()); - assert!(matches!(state.stop_reason, Some(StopReason::Completed))); - assert!(state.is_finished()); - } - - #[test] - fn message_error_sets_inference_error_stop_reason() { - let mut state = StreamingResponse::default(); - - state.apply_message(Message::Error(ErrorEnvelope { - request_id: "test-request".to_owned(), - error: Error { - code: 503, - description: "agent unavailable".to_owned(), - }, - })); - - assert!(matches!( - state.stop_reason, - Some(StopReason::InferenceError { code: 503, .. }) - )); - } - - #[test] - fn wire_error_sets_wire_stream_error_stop_reason() { - let mut state = StreamingResponse::default(); - - state.record_wire_error(&anyhow!("connection reset")); - - assert!(matches!( - state.stop_reason, - Some(StopReason::WireStreamError(ref message)) if message.contains("connection reset") - )); - } -} diff --git a/paddler_client_cli/src/view_chat_panels.rs b/paddler_client_cli/src/view_chat_panels.rs deleted file mode 100644 index 80e99892..00000000 --- a/paddler_client_cli/src/view_chat_panels.rs +++ /dev/null @@ -1,382 +0,0 @@ -use llama_cpp_bindings_types::ParsedToolCall; -use llama_cpp_bindings_types::ToolCallArguments; -use paddler_types::generation_summary::GenerationSummary; -use ratatui::Frame; -use ratatui::layout::Margin; -use ratatui::layout::Rect; -use ratatui::style::Color; -use ratatui::style::Style; -use ratatui::text::Line; -use ratatui::text::Span; -use ratatui::text::Text; -use ratatui::widgets::Block; -use ratatui::widgets::Paragraph; -use ratatui::widgets::Scrollbar; -use ratatui::widgets::ScrollbarOrientation; -use ratatui::widgets::ScrollbarState; -use ratatui::widgets::Wrap; - -use crate::streaming_response::StreamingResponse; -use crate::view_panel_kind::ViewPanelKind; -use crate::view_panel_layout::ViewPanelLayout; -use crate::view_panel_navigation::ViewPanelNavigation; - -const TOKEN_PALETTE: [Color; 6] = [ - Color::LightCyan, - Color::LightYellow, - Color::LightMagenta, - Color::LightGreen, - Color::LightBlue, - Color::LightRed, -]; - -pub fn view_chat_panels( - state: &StreamingResponse, - navigation: &mut ViewPanelNavigation, - layout: &ViewPanelLayout, - frame: &mut Frame<'_>, -) { - render_panel_text( - frame, - layout.thinking, - ViewPanelKind::Thinking, - Text::from(build_colored_token_lines(&state.thinking)), - navigation, - ); - render_panel_text( - frame, - layout.response, - ViewPanelKind::Response, - Text::from(build_colored_token_lines(&state.response)), - navigation, - ); - render_panel_text( - frame, - layout.tool_calls, - ViewPanelKind::ToolCalls, - Text::from(build_tool_calls_lines( - &state.tool_call_tokens, - &state.tool_calls, - Block::bordered().inner(layout.tool_calls).width, - )), - navigation, - ); - render_panel_text( - frame, - layout.undetermined, - ViewPanelKind::Undetermined, - Text::from(build_colored_token_lines(&state.undetermined)), - navigation, - ); - render_status_bar(frame, layout.status_bar, state); -} - -fn render_panel_text( - frame: &mut Frame<'_>, - area: Rect, - panel: ViewPanelKind, - text: Text<'_>, - navigation: &mut ViewPanelNavigation, -) { - let title = if navigation.focused() == panel { - format!("[ {} ]", panel.label()) - } else { - format!(" {} ", panel.label()) - }; - let block = Block::bordered().title(title); - let inner = block.inner(area); - let visible_rows = count_text_rows(&text, inner.width); - navigation.settle(panel, visible_rows.into(), inner.height.into()); - let position = navigation.position(panel); - let scroll_offset = u16::try_from(position).unwrap_or(u16::MAX); - - let paragraph = Paragraph::new(text) - .wrap(Wrap { trim: false }) - .scroll((scroll_offset, 0)) - .block(block); - frame.render_widget(paragraph, area); - - if visible_rows > inner.height { - let mut scrollbar_state = ScrollbarState::new(visible_rows.into()) - .position(position) - .viewport_content_length(inner.height.into()); - let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight) - .thumb_symbol("┃") - .track_symbol(Some("│")) - .begin_symbol(None) - .end_symbol(None); - frame.render_stateful_widget( - scrollbar, - area.inner(Margin { - vertical: 1, - horizontal: 0, - }), - &mut scrollbar_state, - ); - } -} - -fn build_colored_token_lines(tokens: &[String]) -> Vec> { - let mut lines: Vec> = Vec::new(); - let mut current_line: Vec> = Vec::new(); - let mut had_token = false; - - for (token_index, token) in tokens.iter().enumerate() { - if token.is_empty() { - continue; - } - had_token = true; - let is_whitespace_only = token.chars().all(char::is_whitespace); - let style = if is_whitespace_only { - Style::default() - .bg(palette_color(token_index)) - .fg(Color::Black) - } else { - Style::default().fg(palette_color(token_index)) - }; - for (piece_index, piece) in token.split('\n').enumerate() { - if piece_index > 0 { - if is_whitespace_only { - current_line.push(Span::styled("↵", style)); - } - lines.push(Line::from(std::mem::take(&mut current_line))); - } - if !piece.is_empty() { - let rendered = if is_whitespace_only { - piece.replace('\t', "→") - } else { - piece.to_owned() - }; - if !rendered.is_empty() { - current_line.push(Span::styled(rendered, style)); - } - } - } - } - - if had_token { - lines.push(Line::from(current_line)); - } - - lines -} - -fn build_tool_calls_lines( - token_stream: &[String], - parsed_calls: &[ParsedToolCall], - inner_width: u16, -) -> Vec> { - let mut lines = build_colored_token_lines(token_stream); - let has_tokens = !token_stream.is_empty(); - let has_parsed = !parsed_calls.is_empty(); - if has_tokens && has_parsed { - lines.push(divider_line(inner_width)); - } - lines.extend(parsed_call_lines(parsed_calls)); - lines -} - -fn divider_line(width: u16) -> Line<'static> { - let label = " parsed "; - let label_width = u16::try_from(label.chars().count()).unwrap_or(u16::MAX); - let total = width.max(label_width.saturating_add(2)); - let dash_count = total.saturating_sub(label_width); - let left = dash_count / 2; - let right = dash_count - left; - let mut text = String::new(); - for _ in 0..left { - text.push('─'); - } - text.push_str(label); - for _ in 0..right { - text.push('─'); - } - Line::from(Span::styled(text, Style::default().fg(Color::Gray))) -} - -fn parsed_call_lines(calls: &[ParsedToolCall]) -> Vec> { - let mut lines = Vec::new(); - for call in calls { - lines.push(Line::raw(call.name.clone())); - match &call.arguments { - ToolCallArguments::ValidJson(value) => match serde_json::to_string_pretty(value) { - Ok(formatted) => { - for inner_line in formatted.lines() { - lines.push(Line::raw(format!(" {inner_line}"))); - } - } - Err(format_error) => { - log::error!( - "failed to pretty-print tool-call arguments for {name}: {format_error}", - name = call.name - ); - lines.push(Line::raw(format!(" {value}"))); - } - }, - ToolCallArguments::InvalidJson(raw) => { - lines.push(Line::raw(format!(" invalid JSON: {raw}"))); - } - } - } - lines -} - -const fn palette_color(token_index: usize) -> Color { - TOKEN_PALETTE[token_index % TOKEN_PALETTE.len()] -} - -fn count_text_rows(text: &Text<'_>, width: u16) -> u16 { - if width == 0 { - return 0; - } - let mut total: u16 = 0; - for line in &text.lines { - let chars_count: usize = line - .spans - .iter() - .map(|span| span.content.chars().count()) - .sum(); - let chars = u16::try_from(chars_count).unwrap_or(u16::MAX); - let rows = chars.div_ceil(width).max(1); - total = total.saturating_add(rows); - } - total -} - -fn render_status_bar(frame: &mut Frame<'_>, area: Rect, state: &StreamingResponse) { - let text = match (&state.stop_reason, &state.summary) { - (None, _) => { - "generating… · tab/shift-tab focus · ↑↓ pgup/pgdn home/end scroll · q quit".to_owned() - } - (Some(_), Some(summary)) => format_completion_status(summary), - (Some(reason), None) => format!("stopped — {reason} · press q to quit"), - }; - frame.render_widget(Paragraph::new(text), area); -} - -fn format_completion_status(summary: &GenerationSummary) -> String { - let usage = summary.usage; - format!( - "done · response {response} · thinking {thinking} · tools {tools} · undet {undet} · prompt {prompt} · total {total} · press q to quit", - response = usage.content_tokens, - thinking = usage.reasoning_tokens, - tools = usage.tool_call_tokens, - undet = usage.undeterminable_tokens, - prompt = usage.prompt_tokens, - total = usage.total_tokens(), - ) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use paddler_types::generation_summary::GenerationSummary; - use ratatui::Terminal; - use ratatui::backend::TestBackend; - use ratatui::buffer::Buffer; - - use super::*; - use crate::stop_reason::StopReason; - - fn render_to_string(state: &StreamingResponse, width: u16, height: u16) -> Result { - let mut navigation = ViewPanelNavigation::default(); - let mut terminal = Terminal::new(TestBackend::new(width, height))?; - terminal.draw(|frame| { - let layout = ViewPanelLayout::compute(frame.area()); - view_chat_panels(state, &mut navigation, &layout, frame); - })?; - Ok(buffer_text(terminal.backend().buffer())) - } - - fn buffer_text(buffer: &Buffer) -> String { - let area = buffer.area; - let mut output = String::with_capacity((area.width as usize + 1) * area.height as usize); - for y in 0..area.height { - for x in 0..area.width { - output.push_str(buffer[(x, y)].symbol()); - } - output.push('\n'); - } - output - } - - #[test] - fn empty_state_shows_all_four_panels_and_generating_status() -> Result<()> { - let state = StreamingResponse::default(); - - let rendered = render_to_string(&state, 100, 30)?; - - assert!(rendered.contains("Thinking")); - assert!(rendered.contains("Response")); - assert!(rendered.contains("Tool Calls")); - assert!(rendered.contains("Undetermined")); - assert!(rendered.contains("generating")); - assert!(!rendered.contains("done")); - Ok(()) - } - - #[test] - fn focused_panel_title_uses_brackets() -> Result<()> { - let state = StreamingResponse::default(); - - let rendered = render_to_string(&state, 100, 30)?; - - assert!(rendered.contains("[ Response ]")); - assert!(rendered.contains(" Thinking ")); - Ok(()) - } - - #[test] - fn response_buffer_text_is_visible() -> Result<()> { - let mut state = StreamingResponse::default(); - state.response.push("hello world".to_owned()); - - let rendered = render_to_string(&state, 80, 30)?; - - assert!(rendered.contains("hello world")); - Ok(()) - } - - #[test] - fn completed_state_shows_summary_and_quit_hint() -> Result<()> { - let state = StreamingResponse { - summary: Some(GenerationSummary::default()), - stop_reason: Some(StopReason::Completed), - ..StreamingResponse::default() - }; - - let rendered = render_to_string(&state, 140, 30)?; - - assert!(rendered.contains("done")); - assert!(rendered.contains("press q to quit")); - assert!(!rendered.contains("generating")); - Ok(()) - } - - #[test] - fn whitespace_only_newline_token_renders_return_marker() -> Result<()> { - let mut state = StreamingResponse::default(); - state.response.push("hello".to_owned()); - state.response.push("\n".to_owned()); - state.response.push("world".to_owned()); - - let rendered = render_to_string(&state, 80, 30)?; - - assert!(rendered.contains("↵")); - Ok(()) - } - - #[test] - fn tool_calls_panel_shows_divider_when_tokens_and_parsed_both_present() -> Result<()> { - let mut state = StreamingResponse::default(); - state - .tool_call_tokens - .push("{\"name\":\"calc\"}".to_owned()); - state.tool_calls.push(ParsedToolCall::default()); - - let rendered = render_to_string(&state, 120, 30)?; - - assert!(rendered.contains("parsed")); - Ok(()) - } -} diff --git a/paddler_client_cli/src/view_panel_kind.rs b/paddler_client_cli/src/view_panel_kind.rs deleted file mode 100644 index 0a8eac0e..00000000 --- a/paddler_client_cli/src/view_panel_kind.rs +++ /dev/null @@ -1,19 +0,0 @@ -#[repr(u8)] -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum ViewPanelKind { - Thinking = 0, - Response = 1, - ToolCalls = 2, - Undetermined = 3, -} - -impl ViewPanelKind { - pub const fn label(self) -> &'static str { - match self { - Self::Thinking => "Thinking", - Self::Response => "Response", - Self::ToolCalls => "Tool Calls", - Self::Undetermined => "Undetermined", - } - } -} diff --git a/paddler_client_cli/src/view_panel_layout.rs b/paddler_client_cli/src/view_panel_layout.rs deleted file mode 100644 index 0873dda1..00000000 --- a/paddler_client_cli/src/view_panel_layout.rs +++ /dev/null @@ -1,61 +0,0 @@ -use ratatui::layout::Constraint; -use ratatui::layout::Layout; -use ratatui::layout::Position; -use ratatui::layout::Rect; - -use crate::view_panel_kind::ViewPanelKind; - -const STATUS_BAR_HEIGHT: u16 = 1; - -pub struct ViewPanelLayout { - pub thinking: Rect, - pub response: Rect, - pub tool_calls: Rect, - pub undetermined: Rect, - pub status_bar: Rect, -} - -impl ViewPanelLayout { - pub fn compute(area: Rect) -> Self { - let outer = Layout::vertical([Constraint::Min(0), Constraint::Length(STATUS_BAR_HEIGHT)]) - .split(area); - let rows = Layout::vertical([Constraint::Percentage(50), Constraint::Percentage(50)]) - .split(outer[0]); - let top = Layout::horizontal([Constraint::Percentage(50), Constraint::Percentage(50)]) - .split(rows[0]); - let bottom = Layout::horizontal([Constraint::Percentage(50), Constraint::Percentage(50)]) - .split(rows[1]); - Self { - thinking: top[0], - response: top[1], - tool_calls: bottom[0], - undetermined: bottom[1], - status_bar: outer[1], - } - } - - pub const fn rect_for(&self, panel: ViewPanelKind) -> Rect { - match panel { - ViewPanelKind::Thinking => self.thinking, - ViewPanelKind::Response => self.response, - ViewPanelKind::ToolCalls => self.tool_calls, - ViewPanelKind::Undetermined => self.undetermined, - } - } - - pub const fn viewport_rows(&self, panel: ViewPanelKind) -> u16 { - self.rect_for(panel).height.saturating_sub(2) - } - - pub fn panel_at(&self, column: u16, row: u16) -> Option { - let position = Position { x: column, y: row }; - [ - ViewPanelKind::Thinking, - ViewPanelKind::Response, - ViewPanelKind::ToolCalls, - ViewPanelKind::Undetermined, - ] - .into_iter() - .find(|panel| self.rect_for(*panel).contains(position)) - } -} diff --git a/paddler_client_cli/src/view_panel_navigation.rs b/paddler_client_cli/src/view_panel_navigation.rs deleted file mode 100644 index 62416c7b..00000000 --- a/paddler_client_cli/src/view_panel_navigation.rs +++ /dev/null @@ -1,185 +0,0 @@ -use crate::view_panel_kind::ViewPanelKind; - -const PANEL_COUNT: usize = 4; - -pub struct ViewPanelNavigation { - focused: ViewPanelKind, - views: [PanelView; PANEL_COUNT], -} - -#[derive(Clone, Copy)] -struct PanelView { - position: usize, - follow_bottom: bool, -} - -impl Default for PanelView { - fn default() -> Self { - Self { - position: 0, - follow_bottom: true, - } - } -} - -impl Default for ViewPanelNavigation { - fn default() -> Self { - Self { - focused: ViewPanelKind::Response, - views: [PanelView::default(); PANEL_COUNT], - } - } -} - -impl ViewPanelNavigation { - pub const fn focused(&self) -> ViewPanelKind { - self.focused - } - - pub const fn focus(&mut self, panel: ViewPanelKind) { - self.focused = panel; - } - - pub const fn cycle_focus_forward(&mut self) { - self.focused = match self.focused { - ViewPanelKind::Thinking => ViewPanelKind::Response, - ViewPanelKind::Response => ViewPanelKind::ToolCalls, - ViewPanelKind::ToolCalls => ViewPanelKind::Undetermined, - ViewPanelKind::Undetermined => ViewPanelKind::Thinking, - }; - } - - pub const fn cycle_focus_backward(&mut self) { - self.focused = match self.focused { - ViewPanelKind::Thinking => ViewPanelKind::Undetermined, - ViewPanelKind::Response => ViewPanelKind::Thinking, - ViewPanelKind::ToolCalls => ViewPanelKind::Response, - ViewPanelKind::Undetermined => ViewPanelKind::ToolCalls, - }; - } - - pub fn scroll_up(&mut self, panel: ViewPanelKind, lines: u16) { - let view = &mut self.views[panel as usize]; - view.follow_bottom = false; - view.position = view.position.saturating_sub(lines.into()); - } - - pub fn scroll_down(&mut self, panel: ViewPanelKind, lines: u16) { - let view = &mut self.views[panel as usize]; - view.position = view.position.saturating_add(lines.into()); - } - - pub const fn jump_to_top(&mut self, panel: ViewPanelKind) { - let view = &mut self.views[panel as usize]; - view.follow_bottom = false; - view.position = 0; - } - - pub const fn jump_to_bottom(&mut self, panel: ViewPanelKind) { - self.views[panel as usize].follow_bottom = true; - } - - pub const fn settle( - &mut self, - panel: ViewPanelKind, - content_rows: usize, - viewport_rows: usize, - ) { - let view = &mut self.views[panel as usize]; - let max_position = content_rows.saturating_sub(viewport_rows); - if view.follow_bottom || view.position >= max_position { - view.position = max_position; - view.follow_bottom = true; - } - } - - pub const fn position(&self, panel: ViewPanelKind) -> usize { - self.views[panel as usize].position - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn defaults_focus_response_and_follow_bottom() { - let mut nav = ViewPanelNavigation::default(); - - assert_eq!(nav.focused(), ViewPanelKind::Response); - nav.settle(ViewPanelKind::Response, 100, 10); - assert_eq!(nav.position(ViewPanelKind::Response), 90); - } - - #[test] - fn scroll_up_disengages_follow_and_decrements_position() { - let mut nav = ViewPanelNavigation::default(); - - nav.scroll_up(ViewPanelKind::Response, 5); - - nav.settle(ViewPanelKind::Response, 100, 10); - assert_eq!(nav.position(ViewPanelKind::Response), 0); - } - - #[test] - fn scroll_down_after_scroll_up_advances_within_content() { - let mut nav = ViewPanelNavigation::default(); - nav.scroll_up(ViewPanelKind::Response, 50); - - nav.scroll_down(ViewPanelKind::Response, 10); - - nav.settle(ViewPanelKind::Response, 100, 10); - assert_eq!(nav.position(ViewPanelKind::Response), 10); - } - - #[test] - fn jump_to_bottom_re_engages_auto_follow() { - let mut nav = ViewPanelNavigation::default(); - nav.scroll_up(ViewPanelKind::Response, 50); - - nav.jump_to_bottom(ViewPanelKind::Response); - - nav.settle(ViewPanelKind::Response, 200, 10); - assert_eq!(nav.position(ViewPanelKind::Response), 190); - } - - #[test] - fn cycle_focus_forward_walks_panels_in_reading_order() { - let mut nav = ViewPanelNavigation::default(); - - nav.cycle_focus_forward(); - assert_eq!(nav.focused(), ViewPanelKind::ToolCalls); - nav.cycle_focus_forward(); - assert_eq!(nav.focused(), ViewPanelKind::Undetermined); - nav.cycle_focus_forward(); - assert_eq!(nav.focused(), ViewPanelKind::Thinking); - nav.cycle_focus_forward(); - assert_eq!(nav.focused(), ViewPanelKind::Response); - } - - #[test] - fn scrolling_back_to_bottom_re_engages_auto_follow_for_subsequent_growth() { - let mut nav = ViewPanelNavigation::default(); - nav.settle(ViewPanelKind::Response, 100, 10); - - nav.scroll_up(ViewPanelKind::Response, 5); - nav.settle(ViewPanelKind::Response, 100, 10); - - nav.scroll_down(ViewPanelKind::Response, 10); - nav.settle(ViewPanelKind::Response, 100, 10); - - nav.settle(ViewPanelKind::Response, 110, 10); - - assert_eq!(nav.position(ViewPanelKind::Response), 100); - } - - #[test] - fn position_is_clamped_when_content_shorter_than_stored_offset() { - let mut nav = ViewPanelNavigation::default(); - nav.scroll_up(ViewPanelKind::Response, 0); - nav.scroll_down(ViewPanelKind::Response, 80); - - nav.settle(ViewPanelKind::Response, 30, 10); - assert_eq!(nav.position(ViewPanelKind::Response), 20); - } -} diff --git a/paddler_client_cli/src/view_terminal_guard.rs b/paddler_client_cli/src/view_terminal_guard.rs deleted file mode 100644 index d23f6bba..00000000 --- a/paddler_client_cli/src/view_terminal_guard.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::io; - -use anyhow::Context; -use anyhow::Result; -use crossterm::ExecutableCommand; -use crossterm::event::DisableMouseCapture; -use crossterm::event::EnableMouseCapture; -use crossterm::terminal::EnterAlternateScreen; -use crossterm::terminal::LeaveAlternateScreen; -use crossterm::terminal::disable_raw_mode; -use crossterm::terminal::enable_raw_mode; - -pub struct ViewTerminalGuard; - -impl ViewTerminalGuard { - pub fn enter() -> Result { - enable_raw_mode().context("enabling raw mode")?; - if let Err(enter_alt_screen_error) = io::stdout().execute(EnterAlternateScreen) { - if let Err(rollback_error) = disable_raw_mode() { - log::error!( - "failed to disable raw mode while rolling back alt-screen entry: {rollback_error}" - ); - } - return Err( - anyhow::Error::from(enter_alt_screen_error).context("entering alternate screen") - ); - } - if let Err(enable_mouse_error) = io::stdout().execute(EnableMouseCapture) { - if let Err(leave_alt_screen_error) = io::stdout().execute(LeaveAlternateScreen) { - log::error!( - "failed to leave alt screen while rolling back mouse-capture: {leave_alt_screen_error}" - ); - } - if let Err(rollback_error) = disable_raw_mode() { - log::error!( - "failed to disable raw mode while rolling back mouse-capture: {rollback_error}" - ); - } - return Err(anyhow::Error::from(enable_mouse_error).context("enabling mouse capture")); - } - Ok(Self) - } -} - -impl Drop for ViewTerminalGuard { - fn drop(&mut self) { - if let Err(disable_mouse_error) = io::stdout().execute(DisableMouseCapture) { - log::error!("failed to disable mouse capture: {disable_mouse_error}"); - } - if let Err(leave_alt_screen_error) = io::stdout().execute(LeaveAlternateScreen) { - log::error!("failed to leave alternate screen: {leave_alt_screen_error}"); - } - if let Err(disable_raw_mode_error) = disable_raw_mode() { - log::error!("failed to disable raw mode: {disable_raw_mode_error}"); - } - } -} diff --git a/paddler_client_python/shell.nix b/paddler_client_python/shell.nix deleted file mode 100644 index 91d309ad..00000000 --- a/paddler_client_python/shell.nix +++ /dev/null @@ -1,9 +0,0 @@ -{ pkgs ? import {} }: - -pkgs.mkShell { - nativeBuildInputs = with pkgs; [ - poetry - python3 - ruff - ]; -} diff --git a/paddler_download_manager/src/download_manager.rs b/paddler_download_manager/src/download_manager.rs index e06d8125..cacfbbf5 100644 --- a/paddler_download_manager/src/download_manager.rs +++ b/paddler_download_manager/src/download_manager.rs @@ -135,7 +135,9 @@ impl DownloadManager { let response = match request.send().await { Ok(response) => response, Err(send_error) => { - return Err(DownloadAttemptError::Unreachable(anyhow::Error::new(send_error))); + return Err(DownloadAttemptError::Unreachable(anyhow::Error::new( + send_error, + ))); } }; @@ -165,7 +167,10 @@ impl DownloadManager { | ResponseClassification::StreamFromStart => {} } - if matches!(classification, ResponseClassification::StreamFromCurrentOffset) { + if matches!( + classification, + ResponseClassification::StreamFromCurrentOffset + ) { let server_start = response .headers() .typed_get::() @@ -206,5 +211,3 @@ impl DownloadManager { Ok(()) } } - - diff --git a/paddler_download_manager/src/partial_file.rs b/paddler_download_manager/src/partial_file.rs index 335ac57f..c6a261a7 100644 --- a/paddler_download_manager/src/partial_file.rs +++ b/paddler_download_manager/src/partial_file.rs @@ -68,10 +68,7 @@ impl PartialFile { } async fn ensure_partial_parent_exists(&self) -> Result<(), io::Error> { - let parent = self - .partial_path - .parent() - .unwrap_or_else(|| Path::new(".")); + let parent = self.partial_path.parent().unwrap_or_else(|| Path::new(".")); fs::create_dir_all(parent).await } @@ -79,6 +76,9 @@ impl PartialFile { #[cfg(test)] mod tests { + #[cfg(unix)] + use std::path::PathBuf; + use tempfile::TempDir; use tokio::io::AsyncWriteExt; @@ -229,6 +229,16 @@ mod tests { assert!(result.is_err()); } + #[cfg(unix)] + #[tokio::test] + async fn open_for_append_returns_io_error_when_path_has_no_parent() { + let partial = PartialFile::new(PathBuf::from("/")); + + let result = partial.open_for_append().await; + + assert!(result.is_err()); + } + #[cfg(unix)] #[tokio::test] async fn finalize_returns_io_error_when_final_is_a_non_empty_directory() { diff --git a/paddler_download_manager/tests/download.rs b/paddler_download_manager/tests/download.rs index 7012ea5f..1faba30a 100644 --- a/paddler_download_manager/tests/download.rs +++ b/paddler_download_manager/tests/download.rs @@ -450,7 +450,10 @@ async fn fixture_drops_connection_when_configured_to() -> Result<()> { let response = reqwest::get(fixture.url("/x")).await?; let body_result = response.bytes().await; - assert!(body_result.is_err(), "expected dropped connection during body read"); + assert!( + body_result.is_err(), + "expected dropped connection during body read" + ); Ok(()) } @@ -544,10 +547,7 @@ async fn send_error_returns_download_server_is_unreachable() -> Result<()> { let url = "http://127.0.0.1:1/never-listens".to_owned(); let result = DownloadManager::new()?.download(&url, &dest, sink).await; - let Err(DownloadError::DownloadServerIsUnreachable { - url: error_url, .. - }) = result - else { + let Err(DownloadError::DownloadServerIsUnreachable { url: error_url, .. }) = result else { bail!("expected DownloadServerIsUnreachable, got {result:?}"); }; assert_eq!(error_url, url); @@ -674,8 +674,8 @@ async fn truncate_error_during_ignore_range_returns_io() -> Result<()> { let partial_path = dest.with_extension("partial"); tokio::fs::create_dir(&partial_path).await?; - let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))) - .await?; + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; let sink: Arc = Arc::new(RecordingSink::new()); let result = DownloadManager::new()? diff --git a/paddler_gui/Cargo.toml b/paddler_gui/Cargo.toml index cbc4a758..61218f74 100644 --- a/paddler_gui/Cargo.toml +++ b/paddler_gui/Cargo.toml @@ -16,12 +16,14 @@ iced = { workspace = true } if-addrs = { workspace = true } log = { workspace = true } open = { workspace = true } -paddler = { workspace = true } +paddler_balancer = { workspace = true } paddler_bootstrap = { workspace = true } -paddler_types = { workspace = true } +paddler_messaging = { workspace = true } +parking_lot = { workspace = true } statum = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } +trzcina = { workspace = true } [dev-dependencies] dashmap = { workspace = true } @@ -31,10 +33,10 @@ workspace = true [features] default = [] -cuda = ["paddler/cuda"] -metal = ["paddler/metal"] +cuda = ["paddler_bootstrap/cuda"] +metal = ["paddler_bootstrap/metal"] web_admin_panel = [ "dep:esbuild-metafile", - "paddler/web_admin_panel", + "paddler_balancer/web_admin_panel", "paddler_bootstrap/web_admin_panel", ] diff --git a/paddler_gui/src/agent_running_data.rs b/paddler_gui/src/agent_running_data.rs index 91fce71a..dae4782a 100644 --- a/paddler_gui/src/agent_running_data.rs +++ b/paddler_gui/src/agent_running_data.rs @@ -1,5 +1,5 @@ -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; pub struct AgentRunningData { pub balancer_address: String, diff --git a/paddler_gui/src/agent_running_handler.rs b/paddler_gui/src/agent_running_handler.rs index 78bfb814..2c11af37 100644 --- a/paddler_gui/src/agent_running_handler.rs +++ b/paddler_gui/src/agent_running_handler.rs @@ -1,4 +1,4 @@ -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use crate::agent_running_data::AgentRunningData; diff --git a/paddler_gui/src/app.rs b/paddler_gui/src/app.rs index 7bfa8024..fc71e24f 100644 --- a/paddler_gui/src/app.rs +++ b/paddler_gui/src/app.rs @@ -19,25 +19,26 @@ use iced::widget::image::Handle as ImageHandle; use iced::widget::operation; use iced::widget::stack; use iced::window; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::state_database_type::StateDatabaseType; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; +use paddler_balancer::resolved_socket_addr::ResolvedSocketAddr; +use paddler_balancer::state_database_type::StateDatabaseType; #[cfg(feature = "web_admin_panel")] -use paddler::balancer::web_admin_panel_service::template_data::TemplateData; -use paddler::produces_snapshot::ProducesSnapshot; +use paddler_balancer::web_admin_panel_service::configuration::Configuration as WebAdminPanelServiceConfiguration; #[cfg(feature = "web_admin_panel")] -use paddler::resolved_socket_addr::ResolvedSocketAddr; -use paddler::subscribes_to_updates::SubscribesToUpdates as _; +use paddler_balancer::web_admin_panel_service::template_data::TemplateData; use paddler_bootstrap::agent_runner::AgentRunner; use paddler_bootstrap::agent_runner::AgentRunnerParams; use paddler_bootstrap::balancer_runner::BalancerRunner; use paddler_bootstrap::balancer_runner::BalancerRunnerParams; use paddler_bootstrap::shutdown_signal::register_shutdown_signals; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::produces_snapshot::ProducesSnapshot; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates as _; use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; +use trzcina::ServiceShutdownOptions; use crate::agent_running_handler; use crate::current_screen::CurrentScreen; @@ -503,6 +504,7 @@ impl App { max_buffered_requests, openai_service_configuration: None, cancellation_token: cancel, + shutdown_options: ServiceShutdownOptions::default(), state_database_type: StateDatabaseType::Memory(Box::new(desired_state.clone())), statsd_prefix: statsd_prefix.to_owned(), statsd_service_configuration: None, diff --git a/paddler_gui/src/model_preset.rs b/paddler_gui/src/model_preset.rs index 923d7ef0..549aac2e 100644 --- a/paddler_gui/src/model_preset.rs +++ b/paddler_gui/src/model_preset.rs @@ -1,9 +1,9 @@ use std::fmt; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::inference_parameters::InferenceParameters; #[derive(Clone, Debug, PartialEq)] pub struct ModelPreset { diff --git a/paddler_gui/src/running_balancer_snapshot.rs b/paddler_gui/src/running_balancer_snapshot.rs index c2ecab0b..91f6960e 100644 --- a/paddler_gui/src/running_balancer_snapshot.rs +++ b/paddler_gui/src/running_balancer_snapshot.rs @@ -1,11 +1,11 @@ use anyhow::Context; use anyhow::Result; -use paddler::balancer::agent_controller_pool::AgentControllerPool; -use paddler::balancer_applicable_state::BalancerApplicableState; -use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; -use paddler::produces_snapshot::ProducesSnapshot as _; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::balancer_applicable_state::BalancerApplicableState; +use paddler_balancer::balancer_applicable_state_holder::BalancerApplicableStateHolder; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::produces_snapshot::ProducesSnapshot as _; #[derive(Clone, Debug, Default)] pub struct RunningBalancerSnapshot { @@ -45,27 +45,27 @@ impl RunningBalancerSnapshot { #[cfg(test)] mod tests { + use parking_lot::RwLock; use std::collections::BTreeSet; use std::sync::Arc; - use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; use anyhow::Result; - use paddler::atomic_value::AtomicValue; - use paddler::balancer::agent_controller::AgentController; - use paddler::balancer::agent_controller_pool::AgentControllerPool; - use paddler::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; - use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; - use paddler::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; - use paddler::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; - use paddler::balancer_applicable_state::BalancerApplicableState; - use paddler::balancer_applicable_state_holder::BalancerApplicableStateHolder; - use paddler_types::agent_desired_model::AgentDesiredModel; - use paddler_types::agent_desired_state::AgentDesiredState; - use paddler_types::agent_state_application_status::AgentStateApplicationStatus; - use paddler_types::inference_parameters::InferenceParameters; + use paddler_balancer::agent_controller::AgentController; + use paddler_balancer::agent_controller_pool::AgentControllerPool; + use paddler_balancer::balancer_applicable_state::BalancerApplicableState; + use paddler_balancer::balancer_applicable_state_holder::BalancerApplicableStateHolder; + use paddler_balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; + use paddler_balancer::embedding_sender_collection::EmbeddingSenderCollection; + use paddler_balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; + use paddler_balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; + use paddler_messaging::agent_desired_model::AgentDesiredModel; + use paddler_messaging::agent_desired_state::AgentDesiredState; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::atomic_value::AtomicValue; + use paddler_messaging::inference_parameters::InferenceParameters; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; diff --git a/paddler_gui/src/screen.rs b/paddler_gui/src/screen.rs index a46fff74..0d180df6 100644 --- a/paddler_gui/src/screen.rs +++ b/paddler_gui/src/screen.rs @@ -1,7 +1,7 @@ use std::collections::BTreeSet; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use statum::machine; use statum::state; use statum::transition; diff --git a/paddler_gui/src/start_balancer_form_handler.rs b/paddler_gui/src/start_balancer_form_handler.rs index f3254518..28206754 100644 --- a/paddler_gui/src/start_balancer_form_handler.rs +++ b/paddler_gui/src/start_balancer_form_handler.rs @@ -2,7 +2,7 @@ use std::io; use std::net::SocketAddr; use std::net::TcpListener; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use crate::model_preset::ModelPreset; use crate::start_balancer_form_data::StartBalancerFormData; diff --git a/paddler_gui/src/ui/view_agent_card.rs b/paddler_gui/src/ui/view_agent_card.rs index c680f0ed..e50cdef3 100644 --- a/paddler_gui/src/ui/view_agent_card.rs +++ b/paddler_gui/src/ui/view_agent_card.rs @@ -5,8 +5,8 @@ use iced::widget::container; use iced::widget::progress_bar; use iced::widget::row; use iced::widget::text; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use super::font::BOLD; use super::font::REGULAR; diff --git a/paddler_gui/src/ui/view_running_balancer.rs b/paddler_gui/src/ui/view_running_balancer.rs index 1b7bf2f3..40d59712 100644 --- a/paddler_gui/src/ui/view_running_balancer.rs +++ b/paddler_gui/src/ui/view_running_balancer.rs @@ -8,7 +8,7 @@ use iced::widget::row; use iced::widget::svg; use iced::widget::svg::Handle as SvgHandle; use iced::widget::text; -use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_desired_model::AgentDesiredModel; use super::font::BOLD; use super::font::REGULAR; diff --git a/paddler_messaging/Cargo.toml b/paddler_messaging/Cargo.toml new file mode 100644 index 00000000..66f7c79b --- /dev/null +++ b/paddler_messaging/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "paddler_messaging" +authors.workspace = true +description.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +anyhow = { workspace = true } +base64 = { workspace = true } +encoding_rs = { workspace = true } +llama-cpp-bindings-types = { workspace = true } +log = { workspace = true } +nanoid = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +url = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_types/src/agent_controller_pool_snapshot.rs b/paddler_messaging/src/agent_controller_pool_snapshot.rs similarity index 100% rename from paddler_types/src/agent_controller_pool_snapshot.rs rename to paddler_messaging/src/agent_controller_pool_snapshot.rs diff --git a/paddler_types/src/agent_controller_snapshot.rs b/paddler_messaging/src/agent_controller_snapshot.rs similarity index 100% rename from paddler_types/src/agent_controller_snapshot.rs rename to paddler_messaging/src/agent_controller_snapshot.rs diff --git a/paddler_types/src/agent_desired_model.rs b/paddler_messaging/src/agent_desired_model.rs similarity index 100% rename from paddler_types/src/agent_desired_model.rs rename to paddler_messaging/src/agent_desired_model.rs diff --git a/paddler_types/src/agent_desired_state.rs b/paddler_messaging/src/agent_desired_state.rs similarity index 100% rename from paddler_types/src/agent_desired_state.rs rename to paddler_messaging/src/agent_desired_state.rs diff --git a/paddler_types/src/agent_issue.rs b/paddler_messaging/src/agent_issue.rs similarity index 76% rename from paddler_types/src/agent_issue.rs rename to paddler_messaging/src/agent_issue.rs index 60c9c47e..bd59193b 100644 --- a/paddler_types/src/agent_issue.rs +++ b/paddler_messaging/src/agent_issue.rs @@ -1,10 +1,10 @@ use serde::Deserialize; use serde::Serialize; -use crate::agent_issue_params::ChatTemplateDoesNotCompileParams; -use crate::agent_issue_params::HuggingFaceDownloadLock; -use crate::agent_issue_params::ModelPath; -use crate::agent_issue_params::SlotCannotStartParams; +use crate::agent_issue_params::chat_template_does_not_compile_params::ChatTemplateDoesNotCompileParams; +use crate::agent_issue_params::hugging_face_download_lock::HuggingFaceDownloadLock; +use crate::agent_issue_params::model_path::ModelPath; +use crate::agent_issue_params::slot_cannot_start_params::SlotCannotStartParams; #[derive(Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_types/src/agent_issue_params/chat_template_does_not_compile_params.rs b/paddler_messaging/src/agent_issue_params/chat_template_does_not_compile_params.rs similarity index 84% rename from paddler_types/src/agent_issue_params/chat_template_does_not_compile_params.rs rename to paddler_messaging/src/agent_issue_params/chat_template_does_not_compile_params.rs index d704fcd1..c36186cf 100644 --- a/paddler_types/src/agent_issue_params/chat_template_does_not_compile_params.rs +++ b/paddler_messaging/src/agent_issue_params/chat_template_does_not_compile_params.rs @@ -1,7 +1,7 @@ use serde::Deserialize; use serde::Serialize; -use crate::agent_issue_params::ModelPath; +use crate::agent_issue_params::model_path::ModelPath; #[derive(Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_types/src/agent_issue_params/hugging_face_download_lock.rs b/paddler_messaging/src/agent_issue_params/hugging_face_download_lock.rs similarity index 82% rename from paddler_types/src/agent_issue_params/hugging_face_download_lock.rs rename to paddler_messaging/src/agent_issue_params/hugging_face_download_lock.rs index 3e0b87c1..8aacc73b 100644 --- a/paddler_types/src/agent_issue_params/hugging_face_download_lock.rs +++ b/paddler_messaging/src/agent_issue_params/hugging_face_download_lock.rs @@ -1,7 +1,7 @@ use serde::Deserialize; use serde::Serialize; -use crate::agent_issue_params::ModelPath; +use crate::agent_issue_params::model_path::ModelPath; #[derive(Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_messaging/src/agent_issue_params/mod.rs b/paddler_messaging/src/agent_issue_params/mod.rs new file mode 100644 index 00000000..f507cd1e --- /dev/null +++ b/paddler_messaging/src/agent_issue_params/mod.rs @@ -0,0 +1,4 @@ +pub mod chat_template_does_not_compile_params; +pub mod hugging_face_download_lock; +pub mod model_path; +pub mod slot_cannot_start_params; diff --git a/paddler_types/src/agent_issue_params/model_path.rs b/paddler_messaging/src/agent_issue_params/model_path.rs similarity index 100% rename from paddler_types/src/agent_issue_params/model_path.rs rename to paddler_messaging/src/agent_issue_params/model_path.rs diff --git a/paddler_types/src/agent_issue_params/slot_cannot_start_params.rs b/paddler_messaging/src/agent_issue_params/slot_cannot_start_params.rs similarity index 100% rename from paddler_types/src/agent_issue_params/slot_cannot_start_params.rs rename to paddler_messaging/src/agent_issue_params/slot_cannot_start_params.rs diff --git a/paddler_types/src/agent_state_application_status.rs b/paddler_messaging/src/agent_state_application_status.rs similarity index 87% rename from paddler_types/src/agent_state_application_status.rs rename to paddler_messaging/src/agent_state_application_status.rs index a9512498..ec778eb7 100644 --- a/paddler_types/src/agent_state_application_status.rs +++ b/paddler_messaging/src/agent_state_application_status.rs @@ -72,29 +72,27 @@ mod tests { } #[test] - fn try_from_valid_values() -> Result<()> { + fn try_from_valid_values() { assert_eq!( - AgentStateApplicationStatus::try_from(0)?, + AgentStateApplicationStatus::try_from(0).unwrap(), AgentStateApplicationStatus::Applied ); assert_eq!( - AgentStateApplicationStatus::try_from(1)?, + AgentStateApplicationStatus::try_from(1).unwrap(), AgentStateApplicationStatus::AttemptedAndNotAppliable ); assert_eq!( - AgentStateApplicationStatus::try_from(2)?, + AgentStateApplicationStatus::try_from(2).unwrap(), AgentStateApplicationStatus::AttemptedAndRetrying ); assert_eq!( - AgentStateApplicationStatus::try_from(3)?, + AgentStateApplicationStatus::try_from(3).unwrap(), AgentStateApplicationStatus::Fresh ); assert_eq!( - AgentStateApplicationStatus::try_from(4)?, + AgentStateApplicationStatus::try_from(4).unwrap(), AgentStateApplicationStatus::Stuck ); - - Ok(()) } #[test] diff --git a/paddler/src/atomic_value.rs b/paddler_messaging/src/atomic_value.rs similarity index 65% rename from paddler/src/atomic_value.rs rename to paddler_messaging/src/atomic_value.rs index 6061fe1a..cf1fe2d8 100644 --- a/paddler/src/atomic_value.rs +++ b/paddler_messaging/src/atomic_value.rs @@ -141,3 +141,76 @@ impl AtomicValue { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bool_set_check_reports_and_applies_changes() { + let value = AtomicValue::::new(false); + + assert!(!value.get()); + assert!(value.set_check(true)); + assert!(value.get()); + assert!(!value.set_check(true)); + + value.set(false); + + assert!(!value.get()); + } + + #[test] + fn i32_arithmetic_and_compare_and_swap() { + let value = AtomicValue::::new(0); + + value.increment(); + value.increment(); + value.decrement(); + + assert_eq!(value.get(), 1); + assert!(value.compare_and_swap(1, 5)); + assert_eq!(value.get(), 5); + assert!(!value.compare_and_swap(1, 9)); + assert_eq!(value.get(), 5); + + value.set(7); + + assert!(value.set_check(8)); + assert!(!value.set_check(8)); + + value.reset(); + + assert_eq!(value.get(), 0); + } + + #[test] + fn u64_increment_by_and_set_check() { + let value = AtomicValue::::new(0); + + value.increment_by(10); + + assert_eq!(value.get(), 10); + assert!(value.set_check(20)); + assert!(!value.set_check(20)); + + value.set(0); + + assert_eq!(value.get(), 0); + } + + #[test] + fn usize_increment_by_and_set_check() { + let value = AtomicValue::::new(0); + + value.increment_by(3); + + assert_eq!(value.get(), 3); + assert!(value.set_check(4)); + assert!(!value.set_check(4)); + + value.set(0); + + assert_eq!(value.get(), 0); + } +} diff --git a/paddler_messaging/src/balancer_desired_state.rs b/paddler_messaging/src/balancer_desired_state.rs new file mode 100644 index 00000000..a6fb2ce1 --- /dev/null +++ b/paddler_messaging/src/balancer_desired_state.rs @@ -0,0 +1,16 @@ +use serde::Deserialize; +use serde::Serialize; + +use crate::agent_desired_model::AgentDesiredModel; +use crate::chat_template::ChatTemplate; +use crate::inference_parameters::InferenceParameters; + +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[serde(deny_unknown_fields)] +pub struct BalancerDesiredState { + pub chat_template_override: Option, + pub inference_parameters: InferenceParameters, + pub model: AgentDesiredModel, + pub multimodal_projection: AgentDesiredModel, + pub use_chat_template_override: bool, +} diff --git a/paddler_types/src/buffered_request_manager_snapshot.rs b/paddler_messaging/src/buffered_request_manager_snapshot.rs similarity index 100% rename from paddler_types/src/buffered_request_manager_snapshot.rs rename to paddler_messaging/src/buffered_request_manager_snapshot.rs diff --git a/paddler_types/src/chat_template.rs b/paddler_messaging/src/chat_template.rs similarity index 100% rename from paddler_types/src/chat_template.rs rename to paddler_messaging/src/chat_template.rs diff --git a/paddler_types/src/chat_template_message.rs b/paddler_messaging/src/chat_template_message.rs similarity index 100% rename from paddler_types/src/chat_template_message.rs rename to paddler_messaging/src/chat_template_message.rs diff --git a/paddler_types/src/chat_template_message_content.rs b/paddler_messaging/src/chat_template_message_content.rs similarity index 100% rename from paddler_types/src/chat_template_message_content.rs rename to paddler_messaging/src/chat_template_message_content.rs diff --git a/paddler_types/src/chat_template_message_content_part.rs b/paddler_messaging/src/chat_template_message_content_part.rs similarity index 100% rename from paddler_types/src/chat_template_message_content_part.rs rename to paddler_messaging/src/chat_template_message_content_part.rs diff --git a/paddler_types/src/chat_template_messages.rs b/paddler_messaging/src/chat_template_messages.rs similarity index 100% rename from paddler_types/src/chat_template_messages.rs rename to paddler_messaging/src/chat_template_messages.rs diff --git a/paddler_types/src/conversation_history.rs b/paddler_messaging/src/conversation_history.rs similarity index 89% rename from paddler_types/src/conversation_history.rs rename to paddler_messaging/src/conversation_history.rs index 512c45c3..7af48067 100644 --- a/paddler_types/src/conversation_history.rs +++ b/paddler_messaging/src/conversation_history.rs @@ -141,14 +141,23 @@ mod tests { let marker = MediaMarker::new("[IMAGE]".to_owned()); let result = history.replace_images_with_marker(&marker); - let ChatTemplateMessageContent::Parts(parts) = &result.messages[0].content else { - unreachable!("expected Parts variant"); - }; - - assert_eq!(parts.len(), 3); - assert_eq!(parts[0].text, "before"); - assert_eq!(parts[1].text, "[IMAGE]"); - assert_eq!(parts[2].text, "after"); + assert_eq!( + result.messages[0].content, + ChatTemplateMessageContent::Parts(vec![ + ChatTemplateMessageContentPart { + content_type: "text".to_owned(), + text: "before".to_owned(), + }, + ChatTemplateMessageContentPart { + content_type: "text".to_owned(), + text: "[IMAGE]".to_owned(), + }, + ChatTemplateMessageContentPart { + content_type: "text".to_owned(), + text: "after".to_owned(), + }, + ]) + ); } #[test] diff --git a/paddler_types/src/conversation_message.rs b/paddler_messaging/src/conversation_message.rs similarity index 100% rename from paddler_types/src/conversation_message.rs rename to paddler_messaging/src/conversation_message.rs diff --git a/paddler_types/src/conversation_message_content.rs b/paddler_messaging/src/conversation_message_content.rs similarity index 100% rename from paddler_types/src/conversation_message_content.rs rename to paddler_messaging/src/conversation_message_content.rs diff --git a/paddler_types/src/conversation_message_content_part.rs b/paddler_messaging/src/conversation_message_content_part.rs similarity index 100% rename from paddler_types/src/conversation_message_content_part.rs rename to paddler_messaging/src/conversation_message_content_part.rs diff --git a/paddler_messaging/src/embedding.rs b/paddler_messaging/src/embedding.rs new file mode 100644 index 00000000..ad2a3d6c --- /dev/null +++ b/paddler_messaging/src/embedding.rs @@ -0,0 +1,14 @@ +use serde::Deserialize; +use serde::Serialize; + +use crate::embedding_normalization_method::EmbeddingNormalizationMethod; +use crate::pooling_type::PoolingType; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Embedding { + pub embedding: Vec, + pub normalization_method: EmbeddingNormalizationMethod, + pub pooling_type: PoolingType, + pub source_document_id: String, +} diff --git a/paddler_types/src/embedding_input_document.rs b/paddler_messaging/src/embedding_input_document.rs similarity index 100% rename from paddler_types/src/embedding_input_document.rs rename to paddler_messaging/src/embedding_input_document.rs diff --git a/paddler_types/src/embedding_normalization_method.rs b/paddler_messaging/src/embedding_normalization_method.rs similarity index 100% rename from paddler_types/src/embedding_normalization_method.rs rename to paddler_messaging/src/embedding_normalization_method.rs diff --git a/paddler_types/src/embedding_result.rs b/paddler_messaging/src/embedding_result.rs similarity index 100% rename from paddler_types/src/embedding_result.rs rename to paddler_messaging/src/embedding_result.rs diff --git a/paddler_types/src/generated_token_result.rs b/paddler_messaging/src/generated_token_result.rs similarity index 100% rename from paddler_types/src/generated_token_result.rs rename to paddler_messaging/src/generated_token_result.rs diff --git a/paddler_types/src/generation_summary.rs b/paddler_messaging/src/generation_summary.rs similarity index 100% rename from paddler_types/src/generation_summary.rs rename to paddler_messaging/src/generation_summary.rs diff --git a/paddler_types/src/grammar_constraint.rs b/paddler_messaging/src/grammar_constraint.rs similarity index 100% rename from paddler_types/src/grammar_constraint.rs rename to paddler_messaging/src/grammar_constraint.rs diff --git a/paddler_types/src/huggingface_model_reference.rs b/paddler_messaging/src/huggingface_model_reference.rs similarity index 100% rename from paddler_types/src/huggingface_model_reference.rs rename to paddler_messaging/src/huggingface_model_reference.rs diff --git a/paddler_types/src/image_url.rs b/paddler_messaging/src/image_url.rs similarity index 100% rename from paddler_types/src/image_url.rs rename to paddler_messaging/src/image_url.rs diff --git a/paddler_types/src/inference_client/message.rs b/paddler_messaging/src/inference_client/message.rs similarity index 62% rename from paddler_types/src/inference_client/message.rs rename to paddler_messaging/src/inference_client/message.rs index 9dd5e99a..1c338222 100644 --- a/paddler_types/src/inference_client/message.rs +++ b/paddler_messaging/src/inference_client/message.rs @@ -1,10 +1,10 @@ use serde::Deserialize; use serde::Serialize; -use super::Response; -use crate::jsonrpc::Error; -use crate::jsonrpc::ErrorEnvelope; -use crate::jsonrpc::ResponseEnvelope; +use super::response::Response; +use crate::jsonrpc::error::Error; +use crate::jsonrpc::error_envelope::ErrorEnvelope; +use crate::jsonrpc::response_envelope::ResponseEnvelope; use crate::rpc_message::RpcMessage; #[derive(Debug, Deserialize, Serialize)] diff --git a/paddler_messaging/src/inference_client/mod.rs b/paddler_messaging/src/inference_client/mod.rs new file mode 100644 index 00000000..0c73fa22 --- /dev/null +++ b/paddler_messaging/src/inference_client/mod.rs @@ -0,0 +1,2 @@ +pub mod message; +pub mod response; diff --git a/paddler_types/src/inference_client/response.rs b/paddler_messaging/src/inference_client/response.rs similarity index 100% rename from paddler_types/src/inference_client/response.rs rename to paddler_messaging/src/inference_client/response.rs diff --git a/paddler_types/src/inference_parameters.rs b/paddler_messaging/src/inference_parameters.rs similarity index 78% rename from paddler_types/src/inference_parameters.rs rename to paddler_messaging/src/inference_parameters.rs index 63e7a962..9a535009 100644 --- a/paddler_types/src/inference_parameters.rs +++ b/paddler_messaging/src/inference_parameters.rs @@ -59,8 +59,8 @@ impl Default for InferenceParameters { embedding_batch_size: 256, enable_embeddings: false, image_resize_to_fit: 1024, - k_cache_dtype: KvCacheDtype::Q8_0, - v_cache_dtype: KvCacheDtype::Q8_0, + k_cache_dtype: KvCacheDtype::Q80, + v_cache_dtype: KvCacheDtype::Q80, min_p: 0.05, n_gpu_layers: 0, penalty_frequency: 0.0, @@ -75,6 +75,22 @@ impl Default for InferenceParameters { } } +impl InferenceParameters { + #[must_use] + pub fn deterministic() -> Self { + Self { + min_p: 0.0, + penalty_frequency: 0.0, + penalty_presence: 0.0, + penalty_repeat: 1.0, + temperature: 0.0, + top_k: 1, + top_p: 1.0, + ..Self::default() + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -112,4 +128,21 @@ mod tests { assert_eq!(params.embedding_batch_size, 256); } + + #[test] + fn deterministic_applies_greedy_sampling_over_defaults() { + assert_eq!( + InferenceParameters::deterministic(), + InferenceParameters { + min_p: 0.0, + penalty_frequency: 0.0, + penalty_presence: 0.0, + penalty_repeat: 1.0, + temperature: 0.0, + top_k: 1, + top_p: 1.0, + ..InferenceParameters::default() + } + ); + } } diff --git a/paddler_types/src/inference_server/message.rs b/paddler_messaging/src/inference_server/message.rs similarity index 68% rename from paddler_types/src/inference_server/message.rs rename to paddler_messaging/src/inference_server/message.rs index ea0ad184..141996ec 100644 --- a/paddler_types/src/inference_server/message.rs +++ b/paddler_messaging/src/inference_server/message.rs @@ -1,10 +1,10 @@ use serde::Deserialize; use serde::Serialize; -use super::Request; -use crate::jsonrpc::Error; -use crate::jsonrpc::ErrorEnvelope; -use crate::jsonrpc::RequestEnvelope; +use super::request::Request; +use crate::jsonrpc::error::Error; +use crate::jsonrpc::error_envelope::ErrorEnvelope; +use crate::jsonrpc::request_envelope::RequestEnvelope; use crate::rpc_message::RpcMessage; #[derive(Deserialize, Serialize)] diff --git a/paddler_messaging/src/inference_server/mod.rs b/paddler_messaging/src/inference_server/mod.rs new file mode 100644 index 00000000..54de1457 --- /dev/null +++ b/paddler_messaging/src/inference_server/mod.rs @@ -0,0 +1,2 @@ +pub mod message; +pub mod request; diff --git a/paddler_types/src/inference_server/request.rs b/paddler_messaging/src/inference_server/request.rs similarity index 82% rename from paddler_types/src/inference_server/request.rs rename to paddler_messaging/src/inference_server/request.rs index ba47597c..ddf413a4 100644 --- a/paddler_types/src/inference_server/request.rs +++ b/paddler_messaging/src/inference_server/request.rs @@ -1,8 +1,8 @@ use serde::Deserialize; use serde::Serialize; -use crate::request_params::ContinueFromRawPromptParams; use crate::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use crate::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; #[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_types/src/jsonrpc/error.rs b/paddler_messaging/src/jsonrpc/error.rs similarity index 57% rename from paddler_types/src/jsonrpc/error.rs rename to paddler_messaging/src/jsonrpc/error.rs index 0432e9eb..e4dc6590 100644 --- a/paddler_types/src/jsonrpc/error.rs +++ b/paddler_messaging/src/jsonrpc/error.rs @@ -17,3 +17,18 @@ impl Display for Error { write!(formatter, "jsonrpc_error(code={})", self.code) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn formats_only_code_ignoring_description() { + let error = Error { + code: -32_600, + description: "Invalid Request".to_owned(), + }; + + assert_eq!("jsonrpc_error(code=-32600)", error.to_string()); + } +} diff --git a/paddler_types/src/jsonrpc/error_envelope.rs b/paddler_messaging/src/jsonrpc/error_envelope.rs similarity index 100% rename from paddler_types/src/jsonrpc/error_envelope.rs rename to paddler_messaging/src/jsonrpc/error_envelope.rs diff --git a/paddler_messaging/src/jsonrpc/mod.rs b/paddler_messaging/src/jsonrpc/mod.rs new file mode 100644 index 00000000..d1bac175 --- /dev/null +++ b/paddler_messaging/src/jsonrpc/mod.rs @@ -0,0 +1,4 @@ +pub mod error; +pub mod error_envelope; +pub mod request_envelope; +pub mod response_envelope; diff --git a/paddler_types/src/jsonrpc/request_envelope.rs b/paddler_messaging/src/jsonrpc/request_envelope.rs similarity index 100% rename from paddler_types/src/jsonrpc/request_envelope.rs rename to paddler_messaging/src/jsonrpc/request_envelope.rs diff --git a/paddler_types/src/jsonrpc/response_envelope.rs b/paddler_messaging/src/jsonrpc/response_envelope.rs similarity index 100% rename from paddler_types/src/jsonrpc/response_envelope.rs rename to paddler_messaging/src/jsonrpc/response_envelope.rs diff --git a/paddler_messaging/src/kv_cache_dtype.rs b/paddler_messaging/src/kv_cache_dtype.rs new file mode 100644 index 00000000..db6945a9 --- /dev/null +++ b/paddler_messaging/src/kv_cache_dtype.rs @@ -0,0 +1,22 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] +pub enum KvCacheDtype { + F32, + F16, + #[serde(rename = "BF16")] + Bf16, + #[serde(rename = "Q8_0")] + Q80, + #[serde(rename = "Q4_0")] + Q40, + #[serde(rename = "Q4_1")] + Q41, + #[serde(rename = "IQ4_NL")] + Iq4Nl, + #[serde(rename = "Q5_0")] + Q50, + #[serde(rename = "Q5_1")] + Q51, +} diff --git a/paddler_types/src/lib.rs b/paddler_messaging/src/lib.rs similarity index 90% rename from paddler_types/src/lib.rs rename to paddler_messaging/src/lib.rs index b5d3bb8b..242cd6b5 100644 --- a/paddler_types/src/lib.rs +++ b/paddler_messaging/src/lib.rs @@ -5,6 +5,7 @@ pub mod agent_desired_state; pub mod agent_issue; pub mod agent_issue_params; pub mod agent_state_application_status; +pub mod atomic_value; pub mod balancer_desired_state; pub mod buffered_request_manager_snapshot; pub mod chat_template; @@ -30,16 +31,19 @@ pub mod inference_parameters; pub mod inference_server; pub mod jsonrpc; pub mod kv_cache_dtype; +pub mod management_socket; pub mod media_marker; pub mod model_metadata; -pub mod normalization; pub mod oversized_embedding_document_details; pub mod oversized_image_details; pub mod pooling_type; +pub mod produces_snapshot; pub mod raw_tool_call_tokens; pub mod request_params; pub mod rpc_message; pub mod slot_aggregated_status_snapshot; pub mod streamable_result; +pub mod subscribes_to_updates; +pub mod tool_call_validation_error; pub mod url_model_reference; pub mod validates; diff --git a/paddler/src/agent/jsonrpc/message.rs b/paddler_messaging/src/management_socket/agent/message.rs similarity index 52% rename from paddler/src/agent/jsonrpc/message.rs rename to paddler_messaging/src/management_socket/agent/message.rs index 57950fcb..180dfe99 100644 --- a/paddler/src/agent/jsonrpc/message.rs +++ b/paddler_messaging/src/management_socket/agent/message.rs @@ -1,12 +1,12 @@ -use paddler_types::jsonrpc::Error; -use paddler_types::jsonrpc::ErrorEnvelope; -use paddler_types::jsonrpc::RequestEnvelope; -use paddler_types::rpc_message::RpcMessage; +use crate::jsonrpc::error::Error; +use crate::jsonrpc::error_envelope::ErrorEnvelope; +use crate::jsonrpc::request_envelope::RequestEnvelope; +use crate::rpc_message::RpcMessage; use serde::Deserialize; use serde::Serialize; -use super::Notification; -use super::Request; +use super::notification::Notification; +use super::request::Request; #[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_messaging/src/management_socket/agent/mod.rs b/paddler_messaging/src/management_socket/agent/mod.rs new file mode 100644 index 00000000..e7309624 --- /dev/null +++ b/paddler_messaging/src/management_socket/agent/mod.rs @@ -0,0 +1,5 @@ +pub mod message; +pub mod notification; +pub mod notification_params; +pub mod request; +pub mod response; diff --git a/paddler/src/agent/jsonrpc/notification.rs b/paddler_messaging/src/management_socket/agent/notification.rs similarity index 64% rename from paddler/src/agent/jsonrpc/notification.rs rename to paddler_messaging/src/management_socket/agent/notification.rs index eb05cc7c..6c2695ce 100644 --- a/paddler/src/agent/jsonrpc/notification.rs +++ b/paddler_messaging/src/management_socket/agent/notification.rs @@ -1,8 +1,8 @@ use serde::Deserialize; use serde::Serialize; -use super::notification_params::SetStateParams; -use super::notification_params::VersionParams; +use super::notification_params::set_state_params::SetStateParams; +use super::notification_params::version_params::VersionParams; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_messaging/src/management_socket/agent/notification_params/mod.rs b/paddler_messaging/src/management_socket/agent/notification_params/mod.rs new file mode 100644 index 00000000..4dcfe358 --- /dev/null +++ b/paddler_messaging/src/management_socket/agent/notification_params/mod.rs @@ -0,0 +1,2 @@ +pub mod set_state_params; +pub mod version_params; diff --git a/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs b/paddler_messaging/src/management_socket/agent/notification_params/set_state_params.rs similarity index 76% rename from paddler/src/agent/jsonrpc/notification_params/set_state_params.rs rename to paddler_messaging/src/management_socket/agent/notification_params/set_state_params.rs index 99876259..21f23aa2 100644 --- a/paddler/src/agent/jsonrpc/notification_params/set_state_params.rs +++ b/paddler_messaging/src/management_socket/agent/notification_params/set_state_params.rs @@ -1,4 +1,4 @@ -use paddler_types::agent_desired_state::AgentDesiredState; +use crate::agent_desired_state::AgentDesiredState; use serde::Deserialize; use serde::Serialize; diff --git a/paddler/src/agent/jsonrpc/notification_params/version_params.rs b/paddler_messaging/src/management_socket/agent/notification_params/version_params.rs similarity index 100% rename from paddler/src/agent/jsonrpc/notification_params/version_params.rs rename to paddler_messaging/src/management_socket/agent/notification_params/version_params.rs diff --git a/paddler/src/agent/jsonrpc/request.rs b/paddler_messaging/src/management_socket/agent/request.rs similarity index 68% rename from paddler/src/agent/jsonrpc/request.rs rename to paddler_messaging/src/management_socket/agent/request.rs index 88d95652..e5829d32 100644 --- a/paddler/src/agent/jsonrpc/request.rs +++ b/paddler_messaging/src/management_socket/agent/request.rs @@ -1,10 +1,10 @@ use serde::Deserialize; use serde::Serialize; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use crate::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use crate::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use crate::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; #[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_messaging/src/management_socket/agent/response.rs b/paddler_messaging/src/management_socket/agent/response.rs new file mode 100644 index 00000000..3893b04d --- /dev/null +++ b/paddler_messaging/src/management_socket/agent/response.rs @@ -0,0 +1,145 @@ +use crate::chat_template::ChatTemplate; +use crate::embedding_result::EmbeddingResult; +use crate::generated_token_result::GeneratedTokenResult; +use crate::model_metadata::ModelMetadata; +use serde::Deserialize; +use serde::Serialize; + +#[derive(Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub enum Response { + ChatTemplateOverride(Option), + Embedding(EmbeddingResult), + GeneratedToken(GeneratedTokenResult), + ModelMetadata(Option), +} + +impl From> for Response { + fn from(chat_template: Option) -> Self { + Self::ChatTemplateOverride(chat_template) + } +} + +impl From for Response { + fn from(embedding_result: EmbeddingResult) -> Self { + Self::Embedding(embedding_result) + } +} + +impl From for Response { + fn from(generated_token_result: GeneratedTokenResult) -> Self { + Self::GeneratedToken(generated_token_result) + } +} + +impl From> for Response { + fn from(model_metadata: Option) -> Self { + Self::ModelMetadata(model_metadata) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use std::mem::discriminant; + + use super::ChatTemplate; + use super::GeneratedTokenResult; + use super::ModelMetadata; + use super::Response; + + fn chat_template_payload(response: &Response) -> Option<&ChatTemplate> { + match response { + Response::ChatTemplateOverride(chat_template) => chat_template.as_ref(), + Response::Embedding(_) | Response::GeneratedToken(_) | Response::ModelMetadata(_) => { + None + } + } + } + + fn model_metadata_payload(response: &Response) -> Option<&ModelMetadata> { + match response { + Response::ModelMetadata(model_metadata) => model_metadata.as_ref(), + Response::ChatTemplateOverride(_) + | Response::Embedding(_) + | Response::GeneratedToken(_) => None, + } + } + + #[test] + fn converts_some_chat_template_into_chat_template_override_variant() { + let chat_template = ChatTemplate { + content: "{{ messages }}".to_owned(), + }; + let response = Response::from(Some(chat_template.clone())); + + assert_eq!( + discriminant(&response), + discriminant(&Response::ChatTemplateOverride(None)) + ); + assert_eq!(chat_template_payload(&response), Some(&chat_template)); + } + + #[test] + fn converts_none_chat_template_into_chat_template_override_variant() { + let response = Response::from(Option::::None); + + assert_eq!( + discriminant(&response), + discriminant(&Response::ChatTemplateOverride(None)) + ); + assert_eq!(chat_template_payload(&response), None); + } + + #[test] + fn generated_token_response_is_not_chat_template_override_variant() { + let response = Response::from(GeneratedTokenResult::ContentToken("hello".to_owned())); + + assert_ne!( + discriminant(&response), + discriminant(&Response::ChatTemplateOverride(None)) + ); + assert_eq!(chat_template_payload(&response), None); + } + + #[test] + fn converts_some_model_metadata_into_model_metadata_variant() { + let mut metadata = BTreeMap::new(); + metadata.insert("architecture".to_owned(), "llama".to_owned()); + let response = Response::from(Some(ModelMetadata { + metadata: metadata.clone(), + })); + + assert_eq!( + discriminant(&response), + discriminant(&Response::ModelMetadata(None)) + ); + + let extracted_metadata = model_metadata_payload(&response) + .expect("invariant: Some(ModelMetadata) carries a value"); + + assert_eq!(extracted_metadata.metadata, metadata); + } + + #[test] + fn converts_none_model_metadata_into_model_metadata_variant() { + let response = Response::from(Option::::None); + + assert_eq!( + discriminant(&response), + discriminant(&Response::ModelMetadata(None)) + ); + assert!(model_metadata_payload(&response).is_none()); + } + + #[test] + fn generated_token_response_is_not_model_metadata_variant() { + let response = Response::from(GeneratedTokenResult::ContentToken("hello".to_owned())); + + assert_ne!( + discriminant(&response), + discriminant(&Response::ModelMetadata(None)) + ); + assert!(model_metadata_payload(&response).is_none()); + } +} diff --git a/paddler_messaging/src/management_socket/balancer/message.rs b/paddler_messaging/src/management_socket/balancer/message.rs new file mode 100644 index 00000000..7f98a11e --- /dev/null +++ b/paddler_messaging/src/management_socket/balancer/message.rs @@ -0,0 +1,19 @@ +use crate::jsonrpc::error::Error; +use crate::jsonrpc::error_envelope::ErrorEnvelope; +use crate::jsonrpc::response_envelope::ResponseEnvelope; +use crate::rpc_message::RpcMessage; +use serde::Deserialize; +use serde::Serialize; + +use super::notification::Notification; +use crate::management_socket::agent::response::Response; + +#[derive(Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub enum Message { + Error(ErrorEnvelope), + Notification(Notification), + Response(ResponseEnvelope), +} + +impl RpcMessage for Message {} diff --git a/paddler_messaging/src/management_socket/balancer/mod.rs b/paddler_messaging/src/management_socket/balancer/mod.rs new file mode 100644 index 00000000..bbaa192a --- /dev/null +++ b/paddler_messaging/src/management_socket/balancer/mod.rs @@ -0,0 +1,3 @@ +pub mod message; +pub mod notification; +pub mod notification_params; diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification.rs b/paddler_messaging/src/management_socket/balancer/notification.rs similarity index 60% rename from paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification.rs rename to paddler_messaging/src/management_socket/balancer/notification.rs index 24561089..1821a54e 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification.rs +++ b/paddler_messaging/src/management_socket/balancer/notification.rs @@ -1,8 +1,8 @@ use serde::Deserialize; use serde::Serialize; -use super::notification_params::RegisterAgentParams; -use super::notification_params::UpdateAgentStatusParams; +use super::notification_params::register_agent_params::RegisterAgentParams; +use super::notification_params::update_agent_status_params::UpdateAgentStatusParams; #[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] diff --git a/paddler_messaging/src/management_socket/balancer/notification_params/mod.rs b/paddler_messaging/src/management_socket/balancer/notification_params/mod.rs new file mode 100644 index 00000000..646e2c7e --- /dev/null +++ b/paddler_messaging/src/management_socket/balancer/notification_params/mod.rs @@ -0,0 +1,2 @@ +pub mod register_agent_params; +pub mod update_agent_status_params; diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/register_agent_params.rs b/paddler_messaging/src/management_socket/balancer/notification_params/register_agent_params.rs similarity index 77% rename from paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/register_agent_params.rs rename to paddler_messaging/src/management_socket/balancer/notification_params/register_agent_params.rs index 00dae0ad..e6153ebc 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/register_agent_params.rs +++ b/paddler_messaging/src/management_socket/balancer/notification_params/register_agent_params.rs @@ -1,4 +1,4 @@ -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use crate::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use serde::Deserialize; use serde::Serialize; diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/update_agent_status_params.rs b/paddler_messaging/src/management_socket/balancer/notification_params/update_agent_status_params.rs similarity index 75% rename from paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/update_agent_status_params.rs rename to paddler_messaging/src/management_socket/balancer/notification_params/update_agent_status_params.rs index 705942ff..f6e17c39 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/jsonrpc/notification_params/update_agent_status_params.rs +++ b/paddler_messaging/src/management_socket/balancer/notification_params/update_agent_status_params.rs @@ -1,4 +1,4 @@ -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; +use crate::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use serde::Deserialize; use serde::Serialize; diff --git a/paddler_messaging/src/management_socket/mod.rs b/paddler_messaging/src/management_socket/mod.rs new file mode 100644 index 00000000..592aaa0d --- /dev/null +++ b/paddler_messaging/src/management_socket/mod.rs @@ -0,0 +1,2 @@ +pub mod agent; +pub mod balancer; diff --git a/paddler_types/src/media_marker.rs b/paddler_messaging/src/media_marker.rs similarity index 100% rename from paddler_types/src/media_marker.rs rename to paddler_messaging/src/media_marker.rs diff --git a/paddler_types/src/model_metadata.rs b/paddler_messaging/src/model_metadata.rs similarity index 100% rename from paddler_types/src/model_metadata.rs rename to paddler_messaging/src/model_metadata.rs diff --git a/paddler_types/src/oversized_embedding_document_details.rs b/paddler_messaging/src/oversized_embedding_document_details.rs similarity index 100% rename from paddler_types/src/oversized_embedding_document_details.rs rename to paddler_messaging/src/oversized_embedding_document_details.rs diff --git a/paddler_types/src/oversized_image_details.rs b/paddler_messaging/src/oversized_image_details.rs similarity index 100% rename from paddler_types/src/oversized_image_details.rs rename to paddler_messaging/src/oversized_image_details.rs diff --git a/paddler_types/src/pooling_type.rs b/paddler_messaging/src/pooling_type.rs similarity index 100% rename from paddler_types/src/pooling_type.rs rename to paddler_messaging/src/pooling_type.rs diff --git a/paddler/src/produces_snapshot.rs b/paddler_messaging/src/produces_snapshot.rs similarity index 100% rename from paddler/src/produces_snapshot.rs rename to paddler_messaging/src/produces_snapshot.rs diff --git a/paddler_messaging/src/raw_tool_call_tokens.rs b/paddler_messaging/src/raw_tool_call_tokens.rs new file mode 100644 index 00000000..e797c2aa --- /dev/null +++ b/paddler_messaging/src/raw_tool_call_tokens.rs @@ -0,0 +1,9 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct RawToolCallTokens { + pub text: String, + pub ffi_error_message: String, +} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/mod.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/mod.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/mod.rs diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/mod.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/mod.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/mod.rs diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/function.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/function.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/function.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/function.rs diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/mod.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/mod.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/mod.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/mod.rs diff --git a/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs new file mode 100644 index 00000000..32d71448 --- /dev/null +++ b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs @@ -0,0 +1,96 @@ +use anyhow::Result; +use serde::Deserialize; +use serde::Serialize; + +use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; +use crate::validates::Validates; +use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; + +#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] +#[serde(untagged)] +pub enum Parameters { + #[default] + Empty, + Schema(TParametersSchema), +} + +impl Parameters { + pub const fn is_empty(&self) -> bool { + matches!(self, Self::Empty) + } +} + +impl Validates> for Parameters { + fn validate(self) -> Result> { + match self { + Self::Empty => Ok(Parameters::Empty), + Self::Schema(schema) => Ok(Parameters::Schema(schema.validate()?)), + } + } +} + +#[cfg(test)] +mod tests { + use serde_json::Map; + use serde_json::Value; + + use super::*; + + fn properties_with_name() -> Map { + let mut properties = Map::new(); + properties.insert("name".to_owned(), Value::String("string".to_owned())); + + properties + } + + #[test] + fn is_empty_returns_true_for_empty_variant() { + let parameters: Parameters = Parameters::Empty; + + assert!(parameters.is_empty()); + } + + #[test] + fn is_empty_returns_false_for_schema_variant() { + let parameters = Parameters::Schema(RawParametersSchema { + schema_type: "object".to_owned(), + properties: Some(Map::new()), + required: None, + additional_properties: None, + }); + + assert!(!parameters.is_empty()); + } + + #[test] + fn validate_keeps_empty_variant_empty() { + let parameters: Parameters = Parameters::Empty; + + let validated = parameters.validate().unwrap(); + + assert!(validated.is_empty()); + } + + #[test] + fn validate_carries_schema_into_validated_variant() { + let parameters = Parameters::Schema(RawParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties_with_name()), + required: Some(vec!["name".to_owned()]), + additional_properties: None, + }); + + let validated = parameters.validate().unwrap(); + + assert!(!validated.is_empty()); + + let expected = Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties_with_name()), + required: Some(vec!["name".to_owned()]), + additional_properties: None, + }); + + assert_eq!(validated, expected); + } +} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/mod.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/mod.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/mod.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/mod.rs diff --git a/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs new file mode 100644 index 00000000..bf6e9572 --- /dev/null +++ b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs @@ -0,0 +1,106 @@ +use anyhow::Result; +use anyhow::anyhow; +use serde::Deserialize; +use serde::Serialize; +use serde_json::Map; +use serde_json::Value; + +use super::validated_parameters_schema::ValidatedParametersSchema; +use crate::validates::Validates; + +#[derive(Default, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct RawParametersSchema { + #[serde(rename = "type")] + pub schema_type: String, + pub properties: Option>, + pub required: Option>, + #[serde(rename = "additionalProperties")] + pub additional_properties: Option, +} + +impl Validates for RawParametersSchema { + fn validate(self) -> Result { + if let (Some(required), Some(properties)) = (&self.required, &self.properties) { + for field in required { + if !properties.contains_key(field) { + return Err(anyhow!("Required field '{field}' not found in properties")); + } + } + } + + Ok(ValidatedParametersSchema { + schema_type: self.schema_type, + properties: self.properties, + required: self.required, + additional_properties: self.additional_properties, + }) + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + + #[test] + fn validate_passes_when_every_required_field_is_present() { + let mut properties = Map::new(); + properties.insert("name".to_owned(), json!({"type": "string"})); + properties.insert("age".to_owned(), json!({"type": "integer"})); + + let raw_schema = RawParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties), + required: Some(vec!["name".to_owned()]), + additional_properties: Some(json!(false)), + }; + + let schema = raw_schema.validate().unwrap(); + + assert_eq!(schema.schema_type, "object"); + assert_eq!(schema.properties.as_ref().unwrap().len(), 2); + assert_eq!(schema.required, Some(vec!["name".to_owned()])); + assert_eq!(schema.additional_properties, Some(json!(false))); + } + + #[test] + fn validate_passes_when_required_is_absent() { + let mut properties = Map::new(); + properties.insert("name".to_owned(), json!({"type": "string"})); + + let raw_schema = RawParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties), + required: None, + additional_properties: None, + }; + + let schema = raw_schema.validate().unwrap(); + + assert_eq!(schema.schema_type, "object"); + assert_eq!(schema.required, None); + assert_eq!(schema.additional_properties, None); + } + + #[test] + fn validate_fails_when_required_field_is_missing_from_properties() { + let mut properties = Map::new(); + properties.insert("name".to_owned(), json!({"type": "string"})); + + let raw_schema = RawParametersSchema { + schema_type: "object".to_owned(), + properties: Some(properties), + required: Some(vec!["name".to_owned(), "missing_field".to_owned()]), + additional_properties: None, + }; + + let error = raw_schema.validate().unwrap_err(); + + assert_eq!( + error.to_string(), + "Required field 'missing_field' not found in properties" + ); + } +} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/validated_parameters_schema.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/validated_parameters_schema.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/validated_parameters_schema.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/validated_parameters_schema.rs diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs b/paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs rename to paddler_messaging/src/request_params/continue_from_conversation_history_params/tool/tool_params/mod.rs diff --git a/paddler_types/src/request_params/continue_from_raw_prompt_params.rs b/paddler_messaging/src/request_params/continue_from_raw_prompt_params.rs similarity index 100% rename from paddler_types/src/request_params/continue_from_raw_prompt_params.rs rename to paddler_messaging/src/request_params/continue_from_raw_prompt_params.rs diff --git a/paddler_types/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs b/paddler_messaging/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs similarity index 100% rename from paddler_types/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs rename to paddler_messaging/src/request_params/generate_embedding_batch_params/chunk_evenly_with_cap_error.rs diff --git a/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs b/paddler_messaging/src/request_params/generate_embedding_batch_params/mod.rs similarity index 72% rename from paddler_types/src/request_params/generate_embedding_batch_params/mod.rs rename to paddler_messaging/src/request_params/generate_embedding_batch_params/mod.rs index b9a0fc42..a5682c6c 100644 --- a/paddler_types/src/request_params/generate_embedding_batch_params/mod.rs +++ b/paddler_messaging/src/request_params/generate_embedding_batch_params/mod.rs @@ -1,9 +1,9 @@ -mod chunk_evenly_with_cap_error; +pub mod chunk_evenly_with_cap_error; use serde::Deserialize; use serde::Serialize; -pub use self::chunk_evenly_with_cap_error::ChunkEvenlyWithCapError; +use self::chunk_evenly_with_cap_error::ChunkEvenlyWithCapError; use crate::embedding_input_document::EmbeddingInputDocument; use crate::embedding_normalization_method::EmbeddingNormalizationMethod; @@ -65,8 +65,6 @@ impl GenerateEmbeddingBatchParams { #[cfg(test)] mod tests { - use anyhow::Result; - use super::*; fn make_doc(id: &str, content: &str) -> EmbeddingInputDocument { @@ -90,209 +88,188 @@ mod tests { } #[test] - fn chunk_evenly_with_cap_empty_input() -> Result<()> { + fn chunk_evenly_with_cap_empty_input() { let params = make_params(vec![]); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert!(sub_batches.is_empty()); - - Ok(()) } #[test] - fn chunk_evenly_with_cap_single_doc_single_agent() -> Result<()> { + fn chunk_evenly_with_cap_single_doc_single_agent() { let params = make_params(vec![make_doc("only", "content")]); - let sub_batches = params.chunk_evenly_with_cap(1, 256)?; + let sub_batches = params.chunk_evenly_with_cap(1, 256).unwrap(); assert_eq!(sub_batches.len(), 1); assert_eq!(sub_batches[0].input_batch.len(), 1); assert_eq!(sub_batches[0].input_batch[0].id, "only"); - - Ok(()) } #[test] - fn chunk_evenly_with_cap_single_doc_many_agents() -> Result<()> { + fn chunk_evenly_with_cap_single_doc_many_agents() { let params = make_params(vec![make_doc("only", "content")]); - let sub_batches = params.chunk_evenly_with_cap(5, 256)?; + let sub_batches = params.chunk_evenly_with_cap(5, 256).unwrap(); assert_eq!(sub_batches.len(), 1); assert_eq!(sub_batches[0].input_batch.len(), 1); assert_eq!(sub_batches[0].input_batch[0].id, "only"); - - Ok(()) } #[test] - fn chunk_evenly_with_cap_more_agents_than_docs_uses_n_chunks() -> Result<()> { + fn chunk_evenly_with_cap_more_agents_than_docs_uses_n_chunks() { let params = make_params(make_docs(3)); - let sub_batches = params.chunk_evenly_with_cap(5, 256)?; + let sub_batches = params.chunk_evenly_with_cap(5, 256).unwrap(); assert_eq!(sub_batches.len(), 3); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 1); } - - Ok(()) } #[test] fn chunk_evenly_with_cap_rejects_zero_agent_count() { let params = make_params(make_docs(5)); - let result = params.chunk_evenly_with_cap(0, 256); + let is_zero_agent_count = + |result: Result, ChunkEvenlyWithCapError>| { + matches!(result, Err(ChunkEvenlyWithCapError::ZeroAgentCount)) + }; - assert!(matches!( - result, - Err(ChunkEvenlyWithCapError::ZeroAgentCount) - )); + assert!(is_zero_agent_count(params.chunk_evenly_with_cap(0, 256))); + assert!(!is_zero_agent_count(params.chunk_evenly_with_cap(2, 0))); } #[test] fn chunk_evenly_with_cap_rejects_zero_max_documents_per_chunk() { let params = make_params(make_docs(4)); - let result = params.chunk_evenly_with_cap(2, 0); + let is_zero_max_documents = + |result: Result, ChunkEvenlyWithCapError>| { + matches!( + result, + Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) + ) + }; - assert!(matches!( - result, - Err(ChunkEvenlyWithCapError::ZeroMaxDocumentsPerChunk) - )); + assert!(is_zero_max_documents(params.chunk_evenly_with_cap(2, 0))); + assert!(!is_zero_max_documents(params.chunk_evenly_with_cap(0, 0))); } #[test] - fn chunk_evenly_with_cap_below_cap_splits_per_agent() -> Result<()> { + fn chunk_evenly_with_cap_below_cap_splits_per_agent() { let params = make_params(make_docs(4)); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert_eq!(sub_batches.len(), 4); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 1); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_below_cap_uneven_split() -> Result<()> { + fn chunk_evenly_with_cap_below_cap_uneven_split() { let params = make_params(make_docs(11)); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert_eq!(sub_batches.len(), 4); assert_eq!(sub_batches[0].input_batch.len(), 3); assert_eq!(sub_batches[1].input_batch.len(), 3); assert_eq!(sub_batches[2].input_batch.len(), 3); assert_eq!(sub_batches[3].input_batch.len(), 2); - - Ok(()) } #[test] - fn chunk_evenly_with_cap_user_example_80_docs_4_agents_cap_100() -> Result<()> { + fn chunk_evenly_with_cap_user_example_80_docs_4_agents_cap_100() { let params = make_params(make_docs(80)); - let sub_batches = params.chunk_evenly_with_cap(4, 100)?; + let sub_batches = params.chunk_evenly_with_cap(4, 100).unwrap(); assert_eq!(sub_batches.len(), 4); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 20); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_user_example_1000_docs_4_agents_cap_100() -> Result<()> { + fn chunk_evenly_with_cap_user_example_1000_docs_4_agents_cap_100() { let params = make_params(make_docs(1000)); - let sub_batches = params.chunk_evenly_with_cap(4, 100)?; + let sub_batches = params.chunk_evenly_with_cap(4, 100).unwrap(); assert_eq!(sub_batches.len(), 10); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 100); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_at_cap_boundary_uses_agent_count() -> Result<()> { + fn chunk_evenly_with_cap_at_cap_boundary_uses_agent_count() { let params = make_params(make_docs(1024)); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert_eq!(sub_batches.len(), 4); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 256); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_above_cap_boundary_creates_extra_chunks() -> Result<()> { + fn chunk_evenly_with_cap_above_cap_boundary_creates_extra_chunks() { let params = make_params(make_docs(2000)); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert_eq!(sub_batches.len(), 8); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 250); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_far_above_cap_distributes_evenly() -> Result<()> { + fn chunk_evenly_with_cap_far_above_cap_distributes_evenly() { let params = make_params(make_docs(1100)); - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); assert_eq!(sub_batches.len(), 5); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 220); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_extreme_large_n_small_cap() -> Result<()> { + fn chunk_evenly_with_cap_extreme_large_n_small_cap() { let params = make_params(make_docs(10_000)); - let sub_batches = params.chunk_evenly_with_cap(4, 1)?; + let sub_batches = params.chunk_evenly_with_cap(4, 1).unwrap(); assert_eq!(sub_batches.len(), 10_000); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 1); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_extreme_one_doc_per_chunk() -> Result<()> { + fn chunk_evenly_with_cap_extreme_one_doc_per_chunk() { let params = make_params(make_docs(100)); - let sub_batches = params.chunk_evenly_with_cap(100, 256)?; + let sub_batches = params.chunk_evenly_with_cap(100, 256).unwrap(); assert_eq!(sub_batches.len(), 100); for sub_batch in &sub_batches { assert_eq!(sub_batch.input_batch.len(), 1); } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_no_sub_batch_exceeds_cap_sweep() -> Result<()> { + fn chunk_evenly_with_cap_no_sub_batch_exceeds_cap_sweep() { let document_counts: Vec = (0..=50).chain([256, 257, 1000, 2001]).collect(); let agent_counts: Vec = (1..=8).collect(); let caps: Vec = vec![1, 2, 4, 100, 256]; @@ -302,7 +279,7 @@ mod tests { for &cap in &caps { let params = make_params(make_docs(document_count)); - let sub_batches = params.chunk_evenly_with_cap(agent_count, cap)?; + let sub_batches = params.chunk_evenly_with_cap(agent_count, cap).unwrap(); let total_documents: usize = sub_batches.iter().map(|sub| sub.input_batch.len()).sum(); @@ -311,14 +288,15 @@ mod tests { "total documents must equal N (N={document_count}, agents={agent_count}, cap={cap})", ); - for sub_batch in &sub_batches { - assert!( - sub_batch.input_batch.len() <= cap, - "sub-batch size {} exceeds cap {} (N={document_count}, agents={agent_count}, cap={cap})", - sub_batch.input_batch.len(), - cap, - ); - } + let largest_sub_batch_size = sub_batches + .iter() + .map(|sub| sub.input_batch.len()) + .max() + .unwrap_or_default(); + assert!( + largest_sub_batch_size <= cap, + "largest sub-batch size {largest_sub_batch_size} exceeds cap {cap} (N={document_count}, agents={agent_count})", + ); let collected_ids: Vec = sub_batches .iter() @@ -344,40 +322,40 @@ mod tests { ); } } else { - assert!(sub_batches.is_empty(), "empty input must produce empty Vec",); + assert!(sub_batches.is_empty(), "empty input must produce empty Vec"); } } } } - - Ok(()) } #[test] - fn chunk_evenly_with_cap_preserves_normalization_method() -> Result<()> { + fn chunk_evenly_with_cap_preserves_normalization_method() { let params = GenerateEmbeddingBatchParams { input_batch: make_docs(8), normalization_method: EmbeddingNormalizationMethod::L2, }; - let sub_batches = params.chunk_evenly_with_cap(4, 256)?; + let sub_batches = params.chunk_evenly_with_cap(4, 256).unwrap(); - assert_eq!(sub_batches.len(), 4); - for sub_batch in &sub_batches { - assert!(matches!( - sub_batch.normalization_method, - EmbeddingNormalizationMethod::L2 - )); - } + let is_l2 = |normalization_method: &EmbeddingNormalizationMethod| { + matches!(normalization_method, EmbeddingNormalizationMethod::L2) + }; - Ok(()) + assert_eq!(sub_batches.len(), 4); + assert!( + sub_batches + .iter() + .all(|sub_batch| is_l2(&sub_batch.normalization_method)) + ); + assert!(!is_l2(&EmbeddingNormalizationMethod::None)); } #[test] - fn chunk_evenly_with_cap_preserves_document_ids_and_order() -> Result<()> { + fn chunk_evenly_with_cap_preserves_document_ids_and_order() { let params = make_params(make_docs(12)); - let sub_batches = params.chunk_evenly_with_cap(5, 256)?; + let sub_batches = params.chunk_evenly_with_cap(5, 256).unwrap(); let collected_ids: Vec = sub_batches .iter() @@ -386,7 +364,5 @@ mod tests { let expected_ids: Vec = (0..12).map(|index| format!("doc{index:05}")).collect(); assert_eq!(collected_ids, expected_ids); - - Ok(()) } } diff --git a/paddler_messaging/src/request_params/mod.rs b/paddler_messaging/src/request_params/mod.rs new file mode 100644 index 00000000..036cc852 --- /dev/null +++ b/paddler_messaging/src/request_params/mod.rs @@ -0,0 +1,3 @@ +pub mod continue_from_conversation_history_params; +pub mod continue_from_raw_prompt_params; +pub mod generate_embedding_batch_params; diff --git a/paddler_types/src/rpc_message.rs b/paddler_messaging/src/rpc_message.rs similarity index 100% rename from paddler_types/src/rpc_message.rs rename to paddler_messaging/src/rpc_message.rs diff --git a/paddler_types/src/slot_aggregated_status_snapshot.rs b/paddler_messaging/src/slot_aggregated_status_snapshot.rs similarity index 100% rename from paddler_types/src/slot_aggregated_status_snapshot.rs rename to paddler_messaging/src/slot_aggregated_status_snapshot.rs diff --git a/paddler_types/src/streamable_result.rs b/paddler_messaging/src/streamable_result.rs similarity index 100% rename from paddler_types/src/streamable_result.rs rename to paddler_messaging/src/streamable_result.rs diff --git a/paddler/src/subscribes_to_updates.rs b/paddler_messaging/src/subscribes_to_updates.rs similarity index 100% rename from paddler/src/subscribes_to_updates.rs rename to paddler_messaging/src/subscribes_to_updates.rs diff --git a/paddler/src/tool_call_validation_error.rs b/paddler_messaging/src/tool_call_validation_error.rs similarity index 100% rename from paddler/src/tool_call_validation_error.rs rename to paddler_messaging/src/tool_call_validation_error.rs diff --git a/paddler_types/src/url_model_reference.rs b/paddler_messaging/src/url_model_reference.rs similarity index 100% rename from paddler_types/src/url_model_reference.rs rename to paddler_messaging/src/url_model_reference.rs diff --git a/paddler_types/src/validates.rs b/paddler_messaging/src/validates.rs similarity index 100% rename from paddler_types/src/validates.rs rename to paddler_messaging/src/validates.rs diff --git a/paddler_openai_client_python_test/.gitignore b/paddler_openai_client_python_test/.gitignore new file mode 100644 index 00000000..0d778d0f --- /dev/null +++ b/paddler_openai_client_python_test/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.venv/ diff --git a/paddler_openai_client_python_test/README.md b/paddler_openai_client_python_test/README.md new file mode 100644 index 00000000..ad4522f9 --- /dev/null +++ b/paddler_openai_client_python_test/README.md @@ -0,0 +1,20 @@ +# paddler_openai_client_python_test + +Verifies that the **official OpenAI Python client** works against Paddler's OpenAI-compatible +endpoints (`/v1/chat/completions` and `/v1/responses`). It depends on the official `openai` package +only — never on Paddler's own client — and exercises nothing but the OpenAI endpoints, so a passing +run is objective evidence that a real OpenAI client is compatible with the server. + +It does not start or configure a cluster. Point it at an already-running, model-configured Paddler +server via `PADDLER_OPENAI_BASE_URL`; the suite fails if that variable is not set. + +## Running + +```sh +poetry install +PADDLER_OPENAI_BASE_URL=http://127.0.0.1:8062/v1 poetry run pytest +``` + +- `PADDLER_OPENAI_BASE_URL` (required): base URL of the running endpoint, ending in `/v1`. +- `PADDLER_OPENAI_MODEL` (optional, default `qwen3`): the model name to send. Paddler ignores it, + but the OpenAI client requires one. diff --git a/paddler_openai_client_python_test/poetry.lock b/paddler_openai_client_python_test/poetry.lock new file mode 100644 index 00000000..5fd96baf --- /dev/null +++ b/paddler_openai_client_python_test/poetry.lock @@ -0,0 +1,850 @@ +# This file is automatically @generated by Poetry 2.4.1 and should not be changed by hand. + +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[[package]] +name = "anyio" +version = "4.13.0" +description = "High-level concurrency and networking framework on top of asyncio or Trio" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708"}, + {file = "anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc"}, +] + +[package.dependencies] +idna = ">=2.8" +typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} + +[package.extras] +trio = ["trio (>=0.32.0)"] + +[[package]] +name = "certifi" +version = "2026.5.20" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "certifi-2026.5.20-py3-none-any.whl", hash = "sha256:3c52e209ba0a4ad7aebe60436a4ab349c39e1e602e8c134221e546902ad25897"}, + {file = "certifi-2026.5.20.tar.gz", hash = "sha256:69dea482ab64caa7b9f6aba1c6bf48bb6a5448d1c0f1b17ab42ad8c763a5344d"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] +markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} + +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + +[[package]] +name = "h11" +version = "0.16.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"}, + {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"}, + {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.16" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.28.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" + +[package.extras] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + +[[package]] +name = "idna" +version = "3.18" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "idna-3.18-py3-none-any.whl", hash = "sha256:7f952cbe720b688055e3f87de14f5c3e5fdaa8bc3928985c4077ca689de849a2"}, + {file = "idna-3.18.tar.gz", hash = "sha256:ffb385a7e039654cef1ab9ef32c6fafe283c0c0467bba1d9029738ce4a14a848"}, +] + +[package.extras] +all = ["mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] + +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + +[[package]] +name = "jiter" +version = "0.15.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "jiter-0.15.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:edebcf7d1f601199084bb6e844d7dc67e03e04f6ac786b0332d616635c4ff7a4"}, + {file = "jiter-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f924585cdacf631cd382b657966847bb537bf9ed0a6f9b991da5f05a631480f"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abbf258599526ad0326fe51e252e24f2bd6f24f1852681b4b78feda3808f1d18"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c468136b8bd6bb18c8786e4236a1fa27362f24cb23450ba0cb204ab379b8e6f"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05906b93d72f03339e6bb7cf8dc10ebda64a0266126eed6beba79e20abcf5fd4"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30ce785d2adb8e32c3f7741442370a74834ec4c01f3c48f0750227a0b4ef27d6"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd73e3da91a0a722d67165e849ce2cdc10de0e0d48738c142be8c6c5f310f4c"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:ceb8fc27d38793f9c97149be8302720c5b22e5c195a37bf2c45dc36c4600a512"}, + {file = "jiter-0.15.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d726e3ceeb337191324b49de298142f27c3ad10886341555d1d5315b5f252c6a"}, + {file = "jiter-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:2c8aea7781d2a372227871de4e1a1332aa96f5a89fd76c5e835dafdbad102887"}, + {file = "jiter-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cf4bd113a69c0a740e27cb962ce10630c36d2b8f59d759a651b955ee9d18a823"}, + {file = "jiter-0.15.0-cp310-cp310-win32.whl", hash = "sha256:d92a5cd21fdb083931d546c207aa29633787c5dc5b02daab2d32b843f88a2c53"}, + {file = "jiter-0.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:e58585a58209d72691ce2d62a9147445f5a87beb0bde97fde284c96ae392a3d1"}, + {file = "jiter-0.15.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:0f862193b8696249d22ec433e85fd2ab0ad9596bc3e45e6c0bc55e8aeba97be2"}, + {file = "jiter-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1303d4d68a9b051ea90502402063ecf3807da00ad2affa19ca1ae3b90b3c5f67"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:392b8ab019e5502d08aff85c6272209c24bc2cbe706ea82a56368f524236614a"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:773b6eb282ce11ee19f05f6b2d4404fa308e5bbd353b0b80a0262caad6db2cd7"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8d2c0c44d569ce0f2850f5c926f8caeb5f245fbc84475aeb36efccc2103e6dbd"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:032396229564bca02440396bd327710719f724f5e7b7e9f7a8eb3faa4a2c2281"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3d37768fce7f88dd2a8c6091f2325dea27d30d30d5c6e7a1c0f0af77723b708"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:2c9cb907439d20bd0c7d7565ca01ee52234203208433749bae5b516907526928"}, + {file = "jiter-0.15.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9100ddbec09741cc66feb0fc6773f8bdbd0e3c345689368f260082ff85dcc0cd"}, + {file = "jiter-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ae1b0d82ac2d987f9ea512b1c9adfcc71a28de3dea3a6039b54d76cffda9901e"}, + {file = "jiter-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8020c99ec13a7db2b6f96cbe82ef4721c88b426a4892f27478044af0284615ef"}, + {file = "jiter-0.15.0-cp311-cp311-win32.whl", hash = "sha256:42bfb257930800cf43e7c62c832402c704ab60797c992faf88d20e903eac8f32"}, + {file = "jiter-0.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:860a74063284a2ae9bfedd694f299cc2c68e2696c5f3d440cc9d18bb81b9dd04"}, + {file = "jiter-0.15.0-cp311-cp311-win_arm64.whl", hash = "sha256:37a10c377ce3a4a85f4a67f28b7afe093154cde77eaf248a72e856aa08b4d865"}, + {file = "jiter-0.15.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0e90a1c315a0226ec822d973817967f9223b7701546c8c2a7913e7ab0926294d"}, + {file = "jiter-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c9004af7c8d67cce7f1aae1026fb55607f4aa600710d08ede3a3ce4aeefe7e0"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c210f8b35dc6f30aafd4b4365ca89b9d1189f21ab49b8e68fa6322a847aef138"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f30bae8bc1c2d613e28e5af3e8cceb09b742f1c8a8a5f839fb67afaffc03b61"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c60e71b6d10cfc284c9bf36bd885e8d44c46f688ce50aa91b5edd90181dea687"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ab068bce62a45aa3e7367eceaffb5dde60b7eb853be8dece45132e3d0ff4879"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa248c9eb220197d363f688818dac2fd4b2f0cd7d843ca7105d652034823427d"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2a77aadd57cac1682e4401a72724d2796d89a4ba129b1a5812aa94ee480826eb"}, + {file = "jiter-0.15.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2ae901f3a55bfafdde31d289590fa25e3245735a2b1e8c7cc15871710a002871"}, + {file = "jiter-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f0b271b462769543716f92d3a4f90527df6ef5ed05ee95ec4137f513e21e1b77"}, + {file = "jiter-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2fb6a5d26af81fc0f00f9360a891e05cf755e149bba391c4d563adc54812973d"}, + {file = "jiter-0.15.0-cp312-cp312-win32.whl", hash = "sha256:c2f6bb8b5216ab9e7873bc08b5d7bef2b8abbb578a3069bf1cd14a45d71d771d"}, + {file = "jiter-0.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:40b2c7e92c44a84d748d21706c68dc6ff8161d80b59c99d774721a0d2317d7c7"}, + {file = "jiter-0.15.0-cp312-cp312-win_arm64.whl", hash = "sha256:cc0bc345cf2df9d1c00ac443f50d543c1ccfa8b0422cb85b1ab70d681c0b255b"}, + {file = "jiter-0.15.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1c11465f97e2abf45a014b83b730222f8f1c5335e802c7055a67d50de6f1f4e3"}, + {file = "jiter-0.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d1e7b1776f0797956c509e123d0952d10d293a9492dea9f288ab9570ec01d1a5"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:351a341c2105aa430b7047e30f1bf7975f6313b00165d3fc07be2edaf741f279"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4ab395feec8d249ec4044e228e98a7033f043426a265df439dc3698823f0a4e4"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2a438005b6f22d0273413484d6094d7c2c5d10ec1b3a3bf128e0d1d3ba53258"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f18f85e4218d1b40f000f42a92239a7a61a902cd42c65e6c360dbd17dcb20894"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1aa62e277fc1cbd80e6deacae6f4d983b41b3d7728e0645c5d741a6149bba45"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:6550fa135c7deb8ead6af49ed7ff648532ea8334a1447fe34a36315ef79c5c29"}, + {file = "jiter-0.15.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:066f8f33f18b2419cd8213b2436fa7fbc9c499f315971cfa3ce1f9820c001b1b"}, + {file = "jiter-0.15.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:75e8a04e91432dde9f1838373cf93d23726c79d3e908d319acf0e796f85592e7"}, + {file = "jiter-0.15.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:a97261f1fccb8e50ecd2890a96e46efdc3f57c80a197324c6777827231eca712"}, + {file = "jiter-0.15.0-cp313-cp313-win32.whl", hash = "sha256:c77496cb10bd7549690fbbab3e5ec05857b83e49276f4a9423a766ddd2afcd4c"}, + {file = "jiter-0.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b15741f501469009ae0ae90b7147958a664a7dede40aa7ff174a8a4645f546d0"}, + {file = "jiter-0.15.0-cp313-cp313-win_arm64.whl", hash = "sha256:5d6a60072b44c3c2b797a7ddcbcbbf2b34ea3cfd4721580fbfd2a09d9d9b84ba"}, + {file = "jiter-0.15.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ef1fd24d9413f6209e00d3d5a453e67acfe004a25cc6c8e8484faed4311ab9e8"}, + {file = "jiter-0.15.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:144f8e72cb53dab146347b91cceac01f5481237f2b93b4a339a1ee8f8878b67c"}, + {file = "jiter-0.15.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:553fcac2ef2cb990877f9fc0833b8b629a3e6a5670b6b5fd58219b41a653ddc4"}, + {file = "jiter-0.15.0-cp313-cp313t-win_amd64.whl", hash = "sha256:774f93f65031856bf14ad9f59bdcab8b8cad501e5ceabd51ba3525f76937a25b"}, + {file = "jiter-0.15.0-cp313-cp313t-win_arm64.whl", hash = "sha256:f1e1754960f38ec40613a07e5e372df67acb3b890fb383b6fb3de3e49ddbf3c7"}, + {file = "jiter-0.15.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:ac0d9ddea4350974be7a221fc25895f251a8fee748c889bdced2141c0fec1a49"}, + {file = "jiter-0.15.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:01a8222cf05ab1128e239421156c207949808acaaea2bdfd33130ae666786e86"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:182226cbc930c9fab81bc2e41a4da672f89539906dadb05e75670ac07b94f71f"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:71683c38c825452999b5717fcae07ea708e8c93003e808be4319c1b02e3d176e"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30f2218e6a9e5c18bc10fe6d41ac189c442c88eacf11bad9f28ef95a9bef00e6"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5157de9f76eb4bc5ea74a1219366a25f945ad305641d74e04f59c54087091aa9"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c5db5527c221249a876160663ab891ace358c17f7b9c93ec1478b7f0550e5c"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:3e4540b8e74e4268811ac05db226a6a128ff572e7e0ce3f1163b693cadb184cd"}, + {file = "jiter-0.15.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:62ebd14e47e9aed9df4472afcb2663668ce4d74891cd54f86bf6e44029d6dc89"}, + {file = "jiter-0.15.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0be6f5ad41a809f303f416d17cec92a7a725902fb9b4f3de3d19362ac0ef8554"}, + {file = "jiter-0.15.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:813dfbb17d65328bf86e5f0905dd277ba2265d3ca20556e86c0c7035b7182e5a"}, + {file = "jiter-0.15.0-cp314-cp314-win32.whl", hash = "sha256:50e51156192722a9c58db112837d3f8ef96fb3c5ecc14e95f409134b08b158ec"}, + {file = "jiter-0.15.0-cp314-cp314-win_amd64.whl", hash = "sha256:30ce1a5d16b5641dc935d50ef775af6a0871e3d14ab05d6fc54dff371b78e558"}, + {file = "jiter-0.15.0-cp314-cp314-win_arm64.whl", hash = "sha256:510c8b3c17a0ed9ac69850c0438dada3c9b82d9c4d589fcb62002a5a9cf3a866"}, + {file = "jiter-0.15.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7553333dd0930c104a5a0db8df72bf7219fe663d731383b576bb6ed6351c984d"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2143ab06181d2b029eedcb6af3cebe95f11bbac62441781860f98ee9330a6a6"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6eac374c5c975709b69c10f09afd199df74150172156ad10c8d4fd785b7da995"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b3b3b775e33d3bfaec9899edc526ae97b0da0bf9d071a46124ba419149a414f8"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eda3071db3346334beae1360b46da4606da57bf3528c167b3c38533afaf9f2c5"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6694a173ecabc12eb60efbc0b474464ead1951ff65cd8b1e72100715c64512b"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:a254e10b593624d230c365b6d616b22ca0ad65e63a16e6631c2b3466022e6ba8"}, + {file = "jiter-0.15.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d8d2955167274e15d79a7a020afdd9b39c990eb80b2d89fca695d92dcfdd38ec"}, + {file = "jiter-0.15.0-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:acf4ee4d1fc55917239fe72972fb292dd773055d05eb040d36f4326e02cc2c0e"}, + {file = "jiter-0.15.0-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:e7196e56f1cd69af1dbb07dff02dcfb260a50b45a82d409d92a06fedb32473b5"}, + {file = "jiter-0.15.0-cp314-cp314t-win32.whl", hash = "sha256:7f6163c0f10b055245f814dcc59f4818da60dfe72f3e72ab89fc24b6bd5e9c52"}, + {file = "jiter-0.15.0-cp314-cp314t-win_amd64.whl", hash = "sha256:980c256edb05b78a111b99c4de3b1d32e31634b867fd1fc2cf726e7b7bba9854"}, + {file = "jiter-0.15.0-cp314-cp314t-win_arm64.whl", hash = "sha256:66b1880df2d01e206e8339769d1c7c1753bcb653efd6289e203f6f24ebada0c0"}, + {file = "jiter-0.15.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:04b400bbf8c9efb03d9bdd976475c919c1d85593b04b9fff7ae234065daf87ae"}, + {file = "jiter-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25ffbe229aa8cd98c28879d8aa1a6e34ae77992ab984a65fba800859dab16269"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5607e6013ed7e6b0ec9661e467b7ffde0aa7ab36833a04850f26fcf88ed4845b"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:50164d7610c00e7cd913a873fce30b6beeebf4b37e53983e33f22de4c900f6b8"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab596fa3837e91e7e6a31b5f639988bfc6a35d1f915ac3932d946062219d588f"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72d8af5c1013656a8870c866660627d1a75bc185814ee022c8533caa1de88ae"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c84c1b7be454b0c16f8499b4ebfbfd82ea5cca6527cceefcbbc06a7557b5ed2e"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_31_riscv64.whl", hash = "sha256:d636d5095155afd364247f65070fab7beda13498d7ff4de331046e704ab9657f"}, + {file = "jiter-0.15.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7d3d6683288c11cbab50e865f2e2f13950179aa45410e30b2cfbd3fb7b0177bf"}, + {file = "jiter-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7ce8902f939970048b233087082e7bb829db29375811c7ad50687b8624c6fd08"}, + {file = "jiter-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4363818355dbc70ae1a8e9eaba9de350d93ede4ff6992b8f8eb8cbb6e5122d42"}, + {file = "jiter-0.15.0-cp39-cp39-win32.whl", hash = "sha256:8f7e9bc0f1135039b22ee6eab588d42df1ce55842b30740a352885eb267bd941"}, + {file = "jiter-0.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:1c15024a3d892223b18f597c86d59387249dc396590844ce6b9f6131d1093bae"}, + {file = "jiter-0.15.0-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:411fa4dfa5a7ae3d11491027ffb9beadec3996010a986862db70d91abba1c750"}, + {file = "jiter-0.15.0-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:2b0074e2f56eb2dacca1689760fd2852a068f85a0547a157b82cb4cafeb6768b"}, + {file = "jiter-0.15.0-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:913d02d29c9606643418d9ccfc3b72492ab25a6bf7889934e09a3490f8d3438b"}, + {file = "jiter-0.15.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b15d3ec9b0449c40e85319bdb4caa8b77ab526e74f5532ed94bec15e2f66822c"}, + {file = "jiter-0.15.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:631f13a3d04e97d4e083993b10f4b99530e3a10d953e2eb5e196b7dc7f812ce0"}, + {file = "jiter-0.15.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:b6c0ffae686c39bf3737be60793783267628783ea42545632c10b291105aee45"}, + {file = "jiter-0.15.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d54fb5b31dea401a41af3f8a7d2512e9b6a6a005491e6166c7e4ffab9639a9c"}, + {file = "jiter-0.15.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54d5d6090cdc1b7c9e780dfb04949a990adb1e301a2fc0bbcee7de4638d33f9a"}, + {file = "jiter-0.15.0.tar.gz", hash = "sha256:4251acc80e2b7c9b7b8823456ea0fceeb0734dac2df7636d3c711b38476b5a76"}, +] + +[[package]] +name = "librt" +version = "0.11.0" +description = "Mypyc runtime library" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "platform_python_implementation != \"PyPy\"" +files = [ + {file = "librt-0.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6e94ebfcfa2d5e9926d6c3b9aa4617ffc42a845b4321fb84021b872358c82a0f"}, + {file = "librt-0.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ae627397a2f351560440d872d6f7c8dbb4072e57868e7b2fc5b8b430fe489d45"}, + {file = "librt-0.11.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc329359321b67d24efdf4bc69012b0597001649544db662c001db5a0184794c"}, + {file = "librt-0.11.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:7e82e642ab0f7608ce2fe53d76ca2280a9ee33a1b06556142c7c6fe80a86fc33"}, + {file = "librt-0.11.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:88145c15c67731d54283d135b03244028c750cc9edc334a96a4f5950ebdb2884"}, + {file = "librt-0.11.0-cp310-cp310-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9d36a51b3d93320b686588e27123f4995804dbf1bce81df78c02fc3c6eea9280"}, + {file = "librt-0.11.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3ac06a2a8b246327f11e186a53a100a4d5c7ed52346367e5ec751d51586c"}, + {file = "librt-0.11.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:461bbceede621f1ffb8839755f8663e886087ee7af16294cab7fb4d782c62eeb"}, + {file = "librt-0.11.0-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0cad8a4d6a8ff03c9b76f9414caccd78e7cfbc8a2e12fa334d8e1d9932753783"}, + {file = "librt-0.11.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f37aa505b3cf60701562eddb32df74b12a9e380c207fd8b06dd157a943ac7ea0"}, + {file = "librt-0.11.0-cp310-cp310-win32.whl", hash = "sha256:94663a21534637f0e787ec2a2a756022df6e5b7b2335a5cdd7d8e33d68a2af89"}, + {file = "librt-0.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:dec7db73758c2b54953fd8b7fe348c45188fe26b39ee18446196edd08453a5d4"}, + {file = "librt-0.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:93d95bd45b7d58343d8b90d904450a545144eec19a002511163426f8ab1fae29"}, + {file = "librt-0.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4ee278c769a713638cdacd4c0436d72156e75df3ebc0166ab2b9dc43acc386c9"}, + {file = "librt-0.11.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f230cb1cbc9faaa616f9a678f530ebcf186e414b6bcbd88b960e4ba1b92428d5"}, + {file = "librt-0.11.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:5d63c855d86938d9de93e265c9bd8c705b51ec494de5738340ee93767a686e4b"}, + {file = "librt-0.11.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:993f028be9e96a08d31df3479ac80d99be374d17f3b78e4796b3fd3c913d4e89"}, + {file = "librt-0.11.0-cp311-cp311-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:258d73a0aa66a055e65b2e4d1b8cdb23b9d132c5bb915d9547d804fcaed116cc"}, + {file = "librt-0.11.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0827efe7854718f04aaddf6496e96960a956e676fe1d0f04eb41511fd8ad06d5"}, + {file = "librt-0.11.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7753e57d6e12d019c0d8786f1c09c709f4c3fcc57c3887b24e36e6c06ec938b7"}, + {file = "librt-0.11.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:11bd19822431cc21af9f27374e7ae2e58103c7d98bda823536a6c47f6bb2bb3d"}, + {file = "librt-0.11.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:22bdf239b219d3993761a148ffa134b19e52e9989c84f845d5d7b71d70a17412"}, + {file = "librt-0.11.0-cp311-cp311-win32.whl", hash = "sha256:46c60b61e308eb535fbd6fa622b1ee1bb2815691c1ad9c98bf7b84952ec3bc8d"}, + {file = "librt-0.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:902e546ff044f579ff1c953ff5fce97b636fe9e3943996b2177710c6ef076f73"}, + {file = "librt-0.11.0-cp311-cp311-win_arm64.whl", hash = "sha256:65ac3bc20f78aa0ee5ae84baa68917f89fef4af63e941084dd019a0d0e749f0c"}, + {file = "librt-0.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b87504f1690a23b9a2cca841191a04f83895d4fc2dd04df91d82b1a04ca2ad46"}, + {file = "librt-0.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40071fc5fe0ce8daa6de616702314a01e1250711682b0523d6ab8d4525910cb3"}, + {file = "librt-0.11.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:137e79445c896a0ea7b265f52d23954e05b64222ee1af69e2cb34219067cbb67"}, + {file = "librt-0.11.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:cca6644054e78746d8d4ef238681f9c34ff8b584fe6b988ecebb8db3b15e622a"}, + {file = "librt-0.11.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5b0eea49f5562861ee8d757a32ef7d559c1d35be2aaaa1ec28941d74c9ffc8a"}, + {file = "librt-0.11.0-cp312-cp312-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0d1029d7e1ae1a7e647ed6fb5df8c4ce2dffefb7a9f5fd1376a4554d96dac09f"}, + {file = "librt-0.11.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bc3ce6b33c5828d9e80592011a5c584cb2ce86edbc4088405f70da47dc1d1b3b"}, + {file = "librt-0.11.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:936c5995f3514a42111f20099397d8177c79b4d7e70961e396c6f5a0a3566766"}, + {file = "librt-0.11.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:9bc0ca6ad9381cbe8e4aa6e5726e4c80c78115a6e9723c599ed1d73e092bc49d"}, + {file = "librt-0.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:070aa8c26c0a74774317a72df8851facc7f0f012a5b406557ac56992d92e1ec8"}, + {file = "librt-0.11.0-cp312-cp312-win32.whl", hash = "sha256:6bf14feb84b05ae945277395451998c89c54d0def4070eb5c08de544930b245a"}, + {file = "librt-0.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:75672f0bc524ede266287d532d7923dbce94c7514ad07627bac3d0c6d92cc4d9"}, + {file = "librt-0.11.0-cp312-cp312-win_arm64.whl", hash = "sha256:2f10cf143e4a9bb0f4f5af568a00df94a2d69ef41c2579584454bb0fe5cc642c"}, + {file = "librt-0.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:78dc31f7fdfe9c9d0eb0e8f42d139db230e826415bbcabd9f0e9faaaee909894"}, + {file = "librt-0.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fa475675db22290c3158e1d42326d0f5a65f04f44a0e68c3630a25b53560fb9c"}, + {file = "librt-0.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:621db29691044bdeda22e789e482e1b0f3a985d90e3426c9c6d17606416205ea"}, + {file = "librt-0.11.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:a9010e2ed5b3a9e158c5fd966b3ab7e834bb3d3aacc8f66c91dd4b57a3799230"}, + {file = "librt-0.11.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c39513d8b7477a2e1ed8c43fc21c524e8d5a0f8d4e8b7b074dbdbe7820a08e2"}, + {file = "librt-0.11.0-cp313-cp313-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7aef3cf1d5af86e770ab04bfd993dfc4ae8b8c17f66fb77dd4a7d50de7bbb1a3"}, + {file = "librt-0.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:557183ddc36babe46b27dd60facbd5adb4492181a5be887587d57cda6e092f21"}, + {file = "librt-0.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:83d3e1f72bd42f6c5c0b7daec530c3f829bd02db42c70b8ddf0c2d90a2459930"}, + {file = "librt-0.11.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:4ce1f21fbe589bc1afd7872dece84fb0e1144f794a288e58a10d2c54a55c43be"}, + {file = "librt-0.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b09f7044ea2b64c9da42fd3d335666518cfd1c6e8a182c95da73d0214b41e"}, + {file = "librt-0.11.0-cp313-cp313-win32.whl", hash = "sha256:78fddc31cd4d3caa897ad5d31f856b1faadc9474021ad6cb182b9018793e254e"}, + {file = "librt-0.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ca8aa88751a775870b764e93bad5135385f563cb8dcee399abf034ea4d3cb47"}, + {file = "librt-0.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:96f044bb325fd9cf1a723015638c219e9143f0dfbc0ca54c565df2b7fc748b44"}, + {file = "librt-0.11.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:4a017a95e5837dc15a8c5661d60e05daa96b90908b1aa6b7acdf443cd25c8ebd"}, + {file = "librt-0.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b1ecbd9819deccc39b7542bf4d2a740d8a620694d39989e58661d3763458f8d4"}, + {file = "librt-0.11.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da327dacd7be8f8ec36547373550744a3cc0e536d54665cd83f8bcd961200e8"}, + {file = "librt-0.11.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:0dc56b1f8d06e60db362cc3fdae206681817f86ce4725d34511473487f12a34b"}, + {file = "librt-0.11.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:05fb8fb2ab90e21c8d12ea240d744ad514da9baf381ebfa70d91d20d21713175"}, + {file = "librt-0.11.0-cp314-cp314-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cae74872be221df4374d10fec61f93ed1513b9546ea84f2c0bf73ab3e9bd0b03"}, + {file = "librt-0.11.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:32bcc918c0148eb7e3d57385125bac7e5f9e4359d05f07448b09f6f778c2f31c"}, + {file = "librt-0.11.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f9743fc99135d5f78d2454435615f6dec0473ca507c26ce9d92b10b562a280d3"}, + {file = "librt-0.11.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5ba067f4aadae8fda802d91d2124c90c42195ff32d9161d3549e6d05cfe26f96"}, + {file = "librt-0.11.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:de3bf945454d032f9e390b85c4072e0a0570bf825421c8be0e71209fa65e1abe"}, + {file = "librt-0.11.0-cp314-cp314-win32.whl", hash = "sha256:d2277a05f6dcb9fd13db9566aac4fabd68c3ea1ea46ee5567d4eef8efa495a2f"}, + {file = "librt-0.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:ab73e8db5e3f564d812c1f5c3a175930a5f9bc96ccb5e3b22a34d7858b401cf7"}, + {file = "librt-0.11.0-cp314-cp314-win_arm64.whl", hash = "sha256:aea3caa317752e3a466fa8af45d91ee0ea8c7fdd96e42b0a8dd9b76a7931eba1"}, + {file = "librt-0.11.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:d1b36540d7aaf9b9101b3a6f376c8d8e9f7a9aec93ed05918f2c69d493ffef72"}, + {file = "librt-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:efbb343ab2ce3540f4ecbe6315d677ed70f37cd9a72b1e58066c918ca83acbaa"}, + {file = "librt-0.11.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa0dd688aab3f7914d3e6e5e3554978e0383312fb8e771d84be008a35b9ee548"}, + {file = "librt-0.11.0-cp314-cp314t-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:f5fb36b8c6c63fdcbb1d526d94c0d1331610d43f4118cc1beb4efef4f3faacb2"}, + {file = "librt-0.11.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4a9a237d13addb93715b6fee74023d5ee3469b53fce527626c0e088aa585805f"}, + {file = "librt-0.11.0-cp314-cp314t-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5ddd17bd87b2c56ddd60e546a7984a2e64c4e8eab92fb4cf3830a48ad5469d51"}, + {file = "librt-0.11.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bd43992b4473d42f12ff9e68326079f0696d9d4e6000e8f39a0238d482ba6ee2"}, + {file = "librt-0.11.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:f8e3e8056dd674e279741485e2e512d6e9a751c7455809d0114e6ebf8d781085"}, + {file = "librt-0.11.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:c1f708d8ae9c56cf38a903c44297243d2ec83fd82b396b977e0144a3e76217e3"}, + {file = "librt-0.11.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0add982e0e7b9fc14cf4b33789d5f13f66581889b88c2f58099f6ce8f92617bd"}, + {file = "librt-0.11.0-cp314-cp314t-win32.whl", hash = "sha256:2b481d846ac894c4e8403c5fd0e87c5d11d6499e404b474602508a224ff531c8"}, + {file = "librt-0.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:28edb433edde181112a908c78907af28f964eabc15f4dd16c9d66c834302677c"}, + {file = "librt-0.11.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dee008f20b542e3cd162ba338a7f9ec0f6d23d395f66fe8aeeec3c9d067ea253"}, + {file = "librt-0.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6bd72d903911d995ab666dbd1871f8b1e80925a699af8063fbf50053329fb05f"}, + {file = "librt-0.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ef69ac715f3cd8e5cd252cb2aebfa72c015492aacc339d5d7bf8fef3c62c677"}, + {file = "librt-0.11.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:624a40c4a4ad7773315c287276cd024509b2c66ff5904f504bfc08d2c70293ab"}, + {file = "librt-0.11.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:41dc19fe150b69716c8ece4f76773a9e8813fe3e35e032a58b4d46423fb8d7c0"}, + {file = "librt-0.11.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4e8bd98ea9c47ae90b319a087ab28dac493f1ffbc1ecd1f28fcdbf3b7e1108d1"}, + {file = "librt-0.11.0-cp39-cp39-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84308fc49423ce6475d1c5d1985cd69a8ca9f0325fc7d5f81bb690a3f3625d4e"}, + {file = "librt-0.11.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ff0fbaf5f44a21beeb0110f2ab64f45135a9536a834b79c0d1ef018f2786bbfa"}, + {file = "librt-0.11.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:9c028a9442a18e266955d364ce42259136e79a7ba14d773e0d778d5f70cd56f1"}, + {file = "librt-0.11.0-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:9f1692105a02bcf853f355032a5fdc5494358ef83d8fd22d16de375c85cec3f5"}, + {file = "librt-0.11.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7a80a71e1fda83cc752a9141e87aae7fef279538597564d670e9ce513f286192"}, + {file = "librt-0.11.0-cp39-cp39-win32.whl", hash = "sha256:140695816ddf3c86eb972981a26f35efd871c44b0c3aed44c8cd01749386617f"}, + {file = "librt-0.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:92f7ff819c197fc30473190a12c2856f325ac90aabfccbeb2072d28cc2e234e3"}, + {file = "librt-0.11.0.tar.gz", hash = "sha256:075dc3ef4458a278e0195cbf6ac9d38808d9b906c5a6c7f7f79c3888276a3fb1"}, +] + +[[package]] +name = "mypy" +version = "1.20.2" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "mypy-1.20.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cf5a4db6dca263010e2c7bff081c89383c72d187ba2cf4c44759aac970e2f0c4"}, + {file = "mypy-1.20.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7b0e817b518bff7facd7f85ea05b643ad8bdcce684cf29784987b0a7c8e1f997"}, + {file = "mypy-1.20.2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97d7b9a485b40f8ca425460e89bf1da2814625b2da627c0dcc6aa46c92631d14"}, + {file = "mypy-1.20.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1e1c12f6d2db3d78b909b5f77513c11eb7f2dd2782b96a3ab6dffc7d44575c99"}, + {file = "mypy-1.20.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:89dce27e142d25ffbc154c1819383b69f2e9234dc4ed4766f42e0e8cb264ab5c"}, + {file = "mypy-1.20.2-cp310-cp310-win_amd64.whl", hash = "sha256:f376e37f9bf2a946872fc5fd1199c99310748e3c26c7a26683f13f8bdb756cbd"}, + {file = "mypy-1.20.2-cp310-cp310-win_arm64.whl", hash = "sha256:6e2b469efd811707bc530fd1effef0f5d6eebcb7fe376affae69025da4b979a2"}, + {file = "mypy-1.20.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4077797a273e56e8843d001e9dfe4ba10e33323d6ade647ff260e5cd97d9758c"}, + {file = "mypy-1.20.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cdecf62abcc4292500d7858aeae87a1f8f1150f4c4dd08fb0b336ee79b2a6df3"}, + {file = "mypy-1.20.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c566c3a88b6ece59b3d70f65bedef17304f48eb52ff040a6a18214e1917b3254"}, + {file = "mypy-1.20.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0deb80d062b2479f2c87ae568f89845afc71d11bc41b04179e58165fd9f31e98"}, + {file = "mypy-1.20.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bba9ad231e92a3e424b3e56b65aa17704993425bba97e302c832f9466bb85bac"}, + {file = "mypy-1.20.2-cp311-cp311-win_amd64.whl", hash = "sha256:baf593f2765fa3a6b1ef95807dbaa3d25b594f6a52adcc506a6b9cb115e1be67"}, + {file = "mypy-1.20.2-cp311-cp311-win_arm64.whl", hash = "sha256:20175a1c0f49863946ec20b7f63255768058ac4f07d2b9ded6a6b46cfb5a9100"}, + {file = "mypy-1.20.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4dbfcf869f6b0517f70cf0030ba6ea1d6645e132337a7d5204a18d8d5636c02b"}, + {file = "mypy-1.20.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b6481b228d072315b053210b01ac320e1be243dc17f9e5887ef167f23f5fae4"}, + {file = "mypy-1.20.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:34397cdced6b90b836e38182076049fdb41424322e0b0728c946b0939ebdf9f6"}, + {file = "mypy-1.20.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a5da6976f20cae27059ea8d0c86e7cef3de720e04c4bb9ee18e3690fdb792066"}, + {file = "mypy-1.20.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:56908d7e08318d39f85b1f0c6cfd47b0cac1a130da677630dac0de3e0623e102"}, + {file = "mypy-1.20.2-cp312-cp312-win_amd64.whl", hash = "sha256:d52ad8d78522da1d308789df651ee5379088e77c76cb1994858d40a426b343b9"}, + {file = "mypy-1.20.2-cp312-cp312-win_arm64.whl", hash = "sha256:785b08db19c9f214dc37d65f7c165d19a30fcecb48abfa30f31b01b5acaabb58"}, + {file = "mypy-1.20.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:edfbfca868cdd6bd8d974a60f8a3682f5565d3f5c99b327640cedd24c4264026"}, + {file = "mypy-1.20.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e2877a02380adfcdbc69071a0f74d6e9dbbf593c0dc9d174e1f223ffd5281943"}, + {file = "mypy-1.20.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7488448de6007cd5177c6cea0517ac33b4c0f5ee9b5e9f2be51ce75511a85517"}, + {file = "mypy-1.20.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bb9c2fa06887e21d6a3a868762acb82aec34e2c6fd0174064f27c93ede68ad15"}, + {file = "mypy-1.20.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9d56a78b646f2e3daa865bc70cd5ec5a46c50045801ca8ff17a0c43abc97e3ee"}, + {file = "mypy-1.20.2-cp313-cp313-win_amd64.whl", hash = "sha256:2a4102b03bb7481d9a91a6da8d174740c9c8c4401024684b9ca3b7cc5e49852f"}, + {file = "mypy-1.20.2-cp313-cp313-win_arm64.whl", hash = "sha256:a95a9248b0c6fd933a442c03c3b113c3b61320086b88e2c444676d3fd1ca3330"}, + {file = "mypy-1.20.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:419413398fe250aae057fd2fe50166b61077083c9b82754c341cf4fd73038f30"}, + {file = "mypy-1.20.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e73c07f23009962885c197ccb9b41356a30cc0e5a1d0c2ea8fd8fb1362d7f924"}, + {file = "mypy-1.20.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c64e5973df366b747646fc98da921f9d6eba9716d57d1db94a83c026a08e0fb"}, + {file = "mypy-1.20.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a65aa591af023864fd08a97da9974e919452cfe19cb146c8a5dc692626445dc"}, + {file = "mypy-1.20.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4fef51b01e638974a6e69885687e9bd40c8d1e09a6cd291cca0619625cf1f558"}, + {file = "mypy-1.20.2-cp314-cp314-win_amd64.whl", hash = "sha256:913485a03f1bcf5d279409a9d2b9ed565c151f61c09f29991e5faa14033da4c8"}, + {file = "mypy-1.20.2-cp314-cp314-win_arm64.whl", hash = "sha256:c3bae4f855d965b5453784300c12ffc63a548304ac7f99e55d4dc7c898673aa3"}, + {file = "mypy-1.20.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:2de3dcea53babc1c3237a19002bc3d228ce1833278f093b8d619e06e7cc79609"}, + {file = "mypy-1.20.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:52b176444e2e5054dfcbcb8c75b0b719865c96247b37407184bbfca5c353f2c2"}, + {file = "mypy-1.20.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:688c3312e5dadb573a2c69c82af3a298d43ecf9e6d264e0f95df960b5f6ac19c"}, + {file = "mypy-1.20.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29752dbbf8cc53f89f6ac096d363314333045c257c9c75cbd189ca2de0455744"}, + {file = "mypy-1.20.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:803203d2b6ea644982c644895c2f78b28d0e208bba7b27d9b921e0ec5eb207c6"}, + {file = "mypy-1.20.2-cp314-cp314t-win_amd64.whl", hash = "sha256:9bcb8aa397ff0093c824182fd76a935a9ba7ad097fcbef80ae89bf6c1731d8ec"}, + {file = "mypy-1.20.2-cp314-cp314t-win_arm64.whl", hash = "sha256:e061b58443f1736f8a37c48978d7ab581636d6ab03e3d4f99e3fa90463bb9382"}, + {file = "mypy-1.20.2-py3-none-any.whl", hash = "sha256:a94c5a76ab46c5e6257c7972b6c8cff0574201ca7dc05647e33e795d78680563"}, + {file = "mypy-1.20.2.tar.gz", hash = "sha256:e8222c26daaafd9e8626dec58ae36029f82585890589576f769a650dd20fd665"}, +] + +[package.dependencies] +librt = {version = ">=0.8.0", markers = "platform_python_implementation != \"PyPy\""} +mypy_extensions = ">=1.0.0" +pathspec = ">=1.0.0" +typing_extensions = [ + {version = ">=4.6.0", markers = "python_version < \"3.15\""}, + {version = ">=4.14.0", markers = "python_version >= \"3.15\""}, +] + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +native-parser = ["ast-serialize (>=0.1.1,<1.0.0)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.1.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"}, + {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, +] + +[[package]] +name = "nodeenv" +version = "1.10.0" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +files = [ + {file = "nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827"}, + {file = "nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb"}, +] + +[[package]] +name = "openai" +version = "2.41.0" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "openai-2.41.0-py3-none-any.whl", hash = "sha256:20cc7952e8501c7e5773dd2ef7be437bae9cb549044902e1041a83a54516e375"}, + {file = "openai-2.41.0.tar.gz", hash = "sha256:db5c362acd6604b84f076abbefa66826ea4b46ecba2954ed866e6a149a1352c0"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.10.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.14,<5" + +[package.extras] +aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.9)"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +realtime = ["websockets (>=13,<16)"] +voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"] + +[[package]] +name = "packaging" +version = "26.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "packaging-26.2-py3-none-any.whl", hash = "sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e"}, + {file = "packaging-26.2.tar.gz", hash = "sha256:ff452ff5a3e828ce110190feff1178bb1f2ea2281fa2075aadb987c2fb221661"}, +] + +[[package]] +name = "pathspec" +version = "1.1.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pathspec-1.1.1-py3-none-any.whl", hash = "sha256:a00ce642f577bf7f473932318056212bc4f8bfdf53128c78bbd5af0b9b20b189"}, + {file = "pathspec-1.1.1.tar.gz", hash = "sha256:17db5ecd524104a120e173814c90367a96a98d07c45b2e10c2f3919fff91bf5a"}, +] + +[package.extras] +hyperscan = ["hyperscan (>=0.7)"] +optional = ["typing-extensions (>=4)"] +re2 = ["google-re2 (>=1.1)"] + +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + +[[package]] +name = "pydantic" +version = "2.13.4" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pydantic-2.13.4-py3-none-any.whl", hash = "sha256:45a282cde31d808236fd7ea9d919b128653c8b38b393d1c4ab335c62924d9aba"}, + {file = "pydantic-2.13.4.tar.gz", hash = "sha256:c40756b57adaa8b1efeeced5c196f3f3b7c435f90e84ea7f443901bec8099ef6"}, +] + +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.46.4" +typing-extensions = ">=4.14.1" +typing-inspection = ">=0.4.2" + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] + +[[package]] +name = "pydantic-core" +version = "2.46.4" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pydantic_core-2.46.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a396dcc17e5a0b164dbe026896245a4fa9ff402edca1dff0be3d53a517f74de4"}, + {file = "pydantic_core-2.46.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da4b951fe36dc7c3a1ccb4e3cd1747c3542b8c9ceede8fc86cae054e764485f5"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb63e0198ca18aad131c089b9204c23079c3afa95487e561f4c522d519e55aba"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f47286a97f0bc9b8859519809077b91b2cefe4ae47fcbf5e466a009c1c5d742b"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a0ed8ea6f2d61c1738835f99b699348d7857379083e5fc497fa0c967a407c"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea793e075b70290d89d8142074262885d3f7da19634845135751bd6344f73b50"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395aebd9183f9d112f569aeb5b2214d1a10a33bec8456447f7fbdfa51d38d4cd"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:b078afbc25f3a1436c7a1d2cd3e322497ee99615ba97c563566fdf46aff1ee01"}, + {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f747929cf940cddb5b3668a390056ddd5ba2e5010615ea2dcf4f9c4f3ab8791d"}, + {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:daa27d92c36f24388fe3ad306b174781c747627f134452e4f128ea00ce1fe8c4"}, + {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:19e51f073cd3df251856a8a4189fbdf1de4012c3ebacfb1884f94f1eb406079f"}, + {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1747f85cee84c26985853c6f3d9bd3e75da5212912443fa111c113b9c246f39"}, + {file = "pydantic_core-2.46.4-cp310-cp310-win32.whl", hash = "sha256:2f84c03c8607173d16b5a854ec68a2f9079ae03237a54fb506d13af47e1d018d"}, + {file = "pydantic_core-2.46.4-cp310-cp310-win_amd64.whl", hash = "sha256:8358a950c8909158e3df31538a7e4edc2d7265a7c54b47f0864d9e5bae9dcebf"}, + {file = "pydantic_core-2.46.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:0e96592440881c74a213e5ad528e2b24d3d4f940de2766bed9010ab1d9e51594"}, + {file = "pydantic_core-2.46.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0d65b8c354be7fb5f720c3caa8bc940bc2d20ce749c8e06135f07f8ed95dd7c"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bfb192b3f4b9e8a89b6277b6ce787564f62cfd272055f6e685726b111dc7826"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9037063db01f09b09e237c282b6792bd4da634b5402c4e7f0c61effed7701a04"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc010ab034c8c7452522748bf937df58020d256ccae0874463d1f4d01758af8e"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c5dac79fa1614d1e06ca695109c6105923bd9c7d1d6c918d4e637b7e6b32fd3"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fa868638bf362d3d138ea55829cefb3d5f4b0d7f142234382a15e2485dbec4"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:17299feefe090f2caa5b8e37222bb5f663e4935a8bfa6931d4102e5df1a9f398"}, + {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4c63ebc82684aa89d9a3bcbd13d515b3be44250dc68dd3bd81526c1cb31286c3"}, + {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:aaa2a54443eff1950ba5ddc6b6ccda0d9c84a364276a62f969bdf2a390650848"}, + {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:18e5ceec2ab67e6d5f1a9085e5a24c9c4e2ac4545730bfe668680bca05e555f3"}, + {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a0f62d0a58f4e7da165457e995725421e0064f2255d8eccebc49f41bbc23b109"}, + {file = "pydantic_core-2.46.4-cp311-cp311-win32.whl", hash = "sha256:041bde0a48fd37cf71cab1c9d56d3e8625a3793fef1f7dd232b3ff37e978ecda"}, + {file = "pydantic_core-2.46.4-cp311-cp311-win_amd64.whl", hash = "sha256:6f2eeda33a839975441c86a4119e1383c50b47faf0cbb5176985565c6bb02c33"}, + {file = "pydantic_core-2.46.4-cp311-cp311-win_arm64.whl", hash = "sha256:14f4c5d6db102bd796a627bbb3a17b4cf4574b9ae861d8b7c9a9661c6dd3362d"}, + {file = "pydantic_core-2.46.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3245406455a5d98187ec35530fd772b1d799b26667980872c8d4614991e2c4a2"}, + {file = "pydantic_core-2.46.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:962ccbab7b642487b1d8b7df90ef677e03134cf1fd8880bf698649b22a69371f"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8233f2947cf85404441fd7e0085f53b10c93e0ee78611099b5c7237e36aacbf7"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a233125ac121aa3ffba9a2b59edfc4a985a76092dc8279586ab4b71390875e7"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b712b53160b79a5850310b912a5ef8e57e56947c8ad690c227f5c9d7e561712"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9401557acd873c3a7f3eb9383edef8ac4968f9510e340f4808d427e75667e7b4"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:926c9541b14b12b1681dca8a0b75feb510b06c6341b70a8e500c2fdcff837cce"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:56cb4851bcaf3d117eddcef4fe66afd750a50274b0da8e22be256d10e5611987"}, + {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c68fcd102d71ea85c5b2dfac3f4f8476eff42a9e078fd5faefff6d145063536b"}, + {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b2f69dec1725e79a012d920df1707de5caf7ed5e08f3be4435e25803efc47458"}, + {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:8d0820e8192167f80d88d64038e609c31452eeca865b4e1d9950a27a4609b00b"}, + {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fbdb89b3e1c94a30cc5edfce477c6e6a5dc4d8f84665b455c27582f211a1c72c"}, + {file = "pydantic_core-2.46.4-cp312-cp312-win32.whl", hash = "sha256:9aa768456404a8bf48a4406685ac2bec8e72b62c69313734fa3b73cf33b3a894"}, + {file = "pydantic_core-2.46.4-cp312-cp312-win_amd64.whl", hash = "sha256:e9c26f834c65f5752f3f06cb08cb86a913ceb7274d0db6e267808a708b46bc89"}, + {file = "pydantic_core-2.46.4-cp312-cp312-win_arm64.whl", hash = "sha256:4fc73cb559bdb54b1134a706a2802a4cddd27a0633f5abb7e53056268751ac6a"}, + {file = "pydantic_core-2.46.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:5d5902252db0d3cedf8d4a1bc68f70eeb430f7e4c7104c8c476753519b423008"}, + {file = "pydantic_core-2.46.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94f0688e7b8d0a67abf40e57a7eaaecd17cc9586706a31b76c031f63df052b4"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f027324c56cd5406ca49c124b0db10e56c69064fec039acc571c29020cc87c76"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e739fee756ba1010f8bcccb534252e85a35fe45ae92c295a06059ce58b74ccd3"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d56801be94b86a9da183e5f3766e6310752b99ff647e38b09a9500d88e46e76"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2412e734dcb48da14d4e4006b82b46b74f2518b8a26ee7e58c6844a6cd6d03c4"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9551187363ffc0de2a00b2e47c25aeaeb1020b69b668762966df15fc5659dd5a"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0186750b482eefa11d7f435892b09c5c606193ef3375bcf94aa00ae6bfb66262"}, + {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5855698a4856556d86e8e6cd8434bc3ac0314ee8e12089ae0e143f64c6256e4e"}, + {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:cbaf13819775b7f769bf4a1f066cb6df7a28d4480081a589828ef190226881cd"}, + {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:633147d34cf4550417f12e2b1a0383973bdf5cdfde212cb09e9a581cf10820be"}, + {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:82cf5301172168103724d49a1444d3378cb20cdee30b116a1bd6031236298a5d"}, + {file = "pydantic_core-2.46.4-cp313-cp313-win32.whl", hash = "sha256:9fa8ae11da9e2b3126c6426f147e0fba88d96d65921799bb30c6abd1cb2c97fb"}, + {file = "pydantic_core-2.46.4-cp313-cp313-win_amd64.whl", hash = "sha256:6b3ace8194b0e5204818c92802dcdca7fc6d88aabbb799d7c795540d9cd6d292"}, + {file = "pydantic_core-2.46.4-cp313-cp313-win_arm64.whl", hash = "sha256:184c081504d17f1c1066e430e117142b2c77d9448a97f7b65c6ac9fd9aee238d"}, + {file = "pydantic_core-2.46.4-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:428e04521a40150c85216fc8b85e8d39fece235a9cf5e383761238c7fa9b96fb"}, + {file = "pydantic_core-2.46.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23ace664830ee0bfe014a0c7bc248b1f7f25ed7ad103852c317624a1083af462"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce5c1d2a8b27468f433ca974829c44060b8097eedc39933e3c206a90ee49c4a9"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7283d57845ecf5a163403eb0702dfc220cc4fbdd18919cb5ccea4f95ee1cdab4"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8daafc69c93ee8a0204506a3b6b30f586ef54028f52aeeeb5c4cfc5184fd5914"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2213145bcc2ba85884d0ac63d222fece9209678f77b9b4d76f054c561adb28"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a5f930472650a82629163023e630d160863fce524c616f4e5186e5de9d9a49b"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:c1b3f518abeca3aa13c712fd202306e145abf59a18b094a6bafb2d2bbf59192c"}, + {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a7dd0b3ee80d90150e3495a3a13ac34dbcbfd4f012996a6a1d8900e91b5c0fb"}, + {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:3fb702cd90b0446a3a1c5e470bfa0dd23c0233b676a9099ddcc964fa6ca13898"}, + {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:b8458003118a712e66286df6a707db01c52c0f52f7db8e4a38f0da1d3b94fc4e"}, + {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:372429a130e469c9cd698925ce5fc50940b7a1336b0d82038e63d5bbc4edc519"}, + {file = "pydantic_core-2.46.4-cp314-cp314-win32.whl", hash = "sha256:85bb3611ff1802f3ee7fdd7dbff26b56f343fb432d57a4728fdd49b6ef35e2f4"}, + {file = "pydantic_core-2.46.4-cp314-cp314-win_amd64.whl", hash = "sha256:811ff8e9c313ab425368bcbb36e5c4ebd7108c2bbf4e4089cfbb0b01eff63fac"}, + {file = "pydantic_core-2.46.4-cp314-cp314-win_arm64.whl", hash = "sha256:bfec22eab3c8cc2ceec0248aec886624116dc079afa027ecc8ad4a7e62010f8a"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:af8244b2bef6aaad6d92cda81372de7f8c8d36c9f0c3ea36e827c60e7d9467a0"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a4330cdbc57162e4b3aa303f588ba752257694c9c9be3e7ebb11b4aca659b5d"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c61fc04a3d840155ff08e475a04809278972fe6aef51e2720554e96367e34b"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c50f2528cf200c5eed56faf3f4e22fcd5f38c157a8b78576e6ba3168ec35f000"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0cbe8b01f948de4286c74cdd6c667aceb38f5c1e26f0693b3983d9d74887c65e"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:617d7e2ca7dcb8c5cf6bcb8c59b8832c94b36196bbf1cbd1bfb56ed341905edd"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7027560ee92211647d0d34e3f7cd6f50da56399d26a9c8ad0da286d3869a53f3"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:f99626688942fb746e545232e7726926f3be91b5975f8b55327665fafda991c7"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fc3e9034a63de20e15e8ade85358bc6efc614008cab72898b4b4952bea0509ff"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:97e7cf2be5c77b7d1a9713a05605d49460d02c6078d38d8bef3cbe323c548424"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:3bf92c5d0e00fefaab325a4d27828fe6b6e2a21848686b5b60d2d9eeb09d76c6"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:3ecbc122d18468d06ca279dc26a8c2e2d5acb10943bb35e36ae92096dc3b5565"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-win32.whl", hash = "sha256:e846ae7835bf0703ae43f534ab79a867146dadd59dc9ca5c8b53d5c8f7c9ef02"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-win_amd64.whl", hash = "sha256:2108ba5c1c1eca18030634489dc544844144ee36357f2f9f780b93e7ddbb44b5"}, + {file = "pydantic_core-2.46.4-cp314-cp314t-win_arm64.whl", hash = "sha256:4fcbe087dbc2068af7eda3aa87634eba216dbda64d1ae73c8684b621d33f6596"}, + {file = "pydantic_core-2.46.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:fd8b3d9fd264be37976686c7f65cd52a83f5e84f4bfd2adf9c1d469676bbb6ae"}, + {file = "pydantic_core-2.46.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9f444c499b3eefd3a92e348059471ea0c3a6e303d9c1cec09fa748fd9f895201"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3447661d99f75a3683a4cf5c87da72f2161964611864dbbeac7fbb118bb4bfc0"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b9bab013d1c7a79d3501ff86d0bc9c31bf587db4551677b96bec07df78c6b15"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d995260fdf4e1db774581b4900e0f832abe3c7c84996726bbc161b19c8f29e76"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f13a646d65d09fbf1bc6b3a9635d30095c8e7e5cc419ff35ecc563c5fd04cd49"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432c179df7874eeb73307aad2df0755e1ae0efa61ff0ea89b93e194411ae3928"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_31_riscv64.whl", hash = "sha256:e68b7a074f65a2fd746c52a7ce6142ab7006074ac269ace0c25cd8ba171f8066"}, + {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4a05d69cba51d852c5c3e92758653245a50c0b646ced0cf05bd793ed592839d6"}, + {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:228ee9bae8bef5b1e97ec58302f80357c37199e0d0a99174e138d28e6957b9d9"}, + {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:10e17cbb10a330363733efc4d7c4d0dd827ac0909b8f6a6542298fed1ea62f29"}, + {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:91a06d2e259ecfbd8c901d70c3c507900458498142b3026a296b7de4d1322cc9"}, + {file = "pydantic_core-2.46.4-cp39-cp39-win32.whl", hash = "sha256:d80ee3d731373b24cebbc10d689ca4ee1875caf0d5703a245db18efd4dd37fc1"}, + {file = "pydantic_core-2.46.4-cp39-cp39-win_amd64.whl", hash = "sha256:3be77f45df024d789a672ae34f8b06fb346c4f9f46ea714956660ea4862e89ac"}, + {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:14d4edf427bdcf950a8a02d7cb44a08614388dd6e1bdcbf4f67504fa7887da9c"}, + {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:0ce40cd7b21210e99342afafbd4d0f76d784eb5b1d60f3bdc566be4983c6c73b"}, + {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90884113d8b48f760e9587002789ddd741e76ab9f89518cd1e43b1f1a52ec44b"}, + {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66ce7632c22d837c95301830e111ad0128a32b8207533b60896a96c4915192ea"}, + {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:1d8ba486450b14f3b1d63bc521d410ec7565e52f887b9fb671791886436a42f7"}, + {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3009f12e4e90b7f88b4f9adb1b0c4a3d58fe7820f3238c190047209d148026df"}, + {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad785e92e6dc634c21555edc8bd6b64957ab844541bcb96a1366c202951ae526"}, + {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00c603d540afdd6b80eb39f078f33ebd46211f02f33e34a32d9f053bba711de0"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:0c563b08bca408dc7f65f700633d8442fffb2421fc47b8101377e9fd65051ff0"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:db06ffe51636ffe9ca531fe9023dd64bdd794be8754cb5df57c5498ae5b518a7"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:133878133d271ade3d41d1bfb2a45ec38dbdbda40bc065921c6b04e4630127e2"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9bc519fbf2b7578398853d815009ae5e4d4603d12f4e3f91da8c06852d3da3e9"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c7a7bd4e39e8e4c12c39cd480356842b6a8a06e41b23a55a5e3e191718838ddf"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:d396ec2b979760aaf3218e76c24e65bd0aca24983298653b3a9d7a45f9e47b30"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:86e1a4418c6cd97d60c95c71164158eaf7324fae7b0923264016baa993eba6fc"}, + {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:d51026d73fcfd93610abc7b27789c26b313920fcfb20e27462d74a7f8b06e983"}, + {file = "pydantic_core-2.46.4.tar.gz", hash = "sha256:62f875393d7f270851f20523dd2e29f082bcc82292d66db2b64ea71f64b6e1c1"}, +] + +[package.dependencies] +typing-extensions = ">=4.14.1" + +[[package]] +name = "pygments" +version = "2.20.0" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176"}, + {file = "pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + +[[package]] +name = "pyright" +version = "1.1.410" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "pyright-1.1.410-py3-none-any.whl", hash = "sha256:5e961bed37cacf96b3f7cd7b1da39b350a9239aa2e69138d0e88f728cfaf296c"}, + {file = "pyright-1.1.410.tar.gz", hash = "sha256:07a073b8ba6749826773c1269773efa11b93440d9a6aa60419d9a3172d6dc488"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" +typing-extensions = ">=4.1" + +[package.extras] +all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] +nodejs = ["nodejs-wheel-binaries"] + +[[package]] +name = "pytest" +version = "9.0.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9"}, + {file = "pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "ruff" +version = "0.15.12" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c"}, + {file = "ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c"}, + {file = "ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0"}, + {file = "ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b"}, + {file = "ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e"}, + {file = "ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20"}, + {file = "ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d"}, + {file = "ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f"}, + {file = "ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6"}, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + +[[package]] +name = "tqdm" +version = "4.67.3" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf"}, + {file = "tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"] +discord = ["requests"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +description = "Backported and Experimental Type Hints for Python 3.9+" +optional = false +python-versions = ">=3.9" +groups = ["main", "dev"] +files = [ + {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, + {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +description = "Runtime typing introspection tools" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, + {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, +] + +[package.dependencies] +typing-extensions = ">=4.12.0" + +[metadata] +lock-version = "2.1" +python-versions = "^3.11" +content-hash = "b9117148cfa5b3817589147e41c62c69a1f903e85bf326cfb6a167d0f3a20be3" diff --git a/paddler_openai_client_python_test/pyproject.toml b/paddler_openai_client_python_test/pyproject.toml new file mode 100644 index 00000000..e0b8d38e --- /dev/null +++ b/paddler_openai_client_python_test/pyproject.toml @@ -0,0 +1,52 @@ +[tool.poetry] +name = "paddler-openai-client-python-test" +version = "4.0.0" +description = "Drives Paddler's OpenAI-compatible endpoints with the official OpenAI client to verify real-client compatibility" +authors = ["Intentee"] +license = "Apache-2.0" +package-mode = false + +[tool.poetry.dependencies] +python = "^3.11" +openai = "==2.41.0" + +[tool.poetry.group.dev.dependencies] +pytest = "^9" +pyright = "^1" +ruff = "0.15.12" +mypy = "^1" + +[tool.mypy] +python_version = "3.11" +strict = true +extra_checks = true + +[tool.pyright] +pythonVersion = "3.11" +pythonPlatform = "All" +typeCheckingMode = "strict" + +[tool.ruff] +line-length = 88 +target-version = "py311" + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "ANN401", # Any is acceptable for loosely-typed JSON probing + "COM812", # trailing comma - conflicts with formatter + "D", # docstrings - project prefers self-documenting code + "ISC001", # implicit string concatenation - conflicts with formatter + "TRY003", # long messages on the env-missing guard are intentional +] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "INP001", # the tests directory is a pytest suite, not an importable package + "PLR2004", # magic values are fine in test assertions + "S101", # assert statements are the point of a test suite +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-v" diff --git a/paddler_openai_client_python_test/tests/conftest.py b/paddler_openai_client_python_test/tests/conftest.py new file mode 100644 index 00000000..47b6da9f --- /dev/null +++ b/paddler_openai_client_python_test/tests/conftest.py @@ -0,0 +1,32 @@ +import os + +import pytest +from openai import OpenAI + +BASE_URL_ENV = "PADDLER_OPENAI_BASE_URL" +MODEL_ENV = "PADDLER_OPENAI_MODEL" +DEFAULT_MODEL = "qwen3" + + +@pytest.fixture(scope="session") +def base_url() -> str: + base_url = os.environ.get(BASE_URL_ENV) + + if not base_url: + raise RuntimeError( + f"{BASE_URL_ENV} must point at a running Paddler OpenAI-compatible " + "endpoint, e.g. http://127.0.0.1:8062/v1 — this suite's sole purpose " + "is to drive that endpoint with the official OpenAI client." + ) + + return base_url + + +@pytest.fixture(scope="session") +def model() -> str: + return os.environ.get(MODEL_ENV, DEFAULT_MODEL) + + +@pytest.fixture +def openai_client(base_url: str) -> OpenAI: + return OpenAI(base_url=base_url, api_key="paddler") diff --git a/paddler_openai_client_python_test/tests/test_chat_completions.py b/paddler_openai_client_python_test/tests/test_chat_completions.py new file mode 100644 index 00000000..12cebd96 --- /dev/null +++ b/paddler_openai_client_python_test/tests/test_chat_completions.py @@ -0,0 +1,52 @@ +from openai import OpenAI + + +def test_non_streaming_returns_message_content_and_usage( + openai_client: OpenAI, + model: str, +) -> None: + completion = openai_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Say hi briefly."}], + max_completion_tokens=600, + ) + + assert completion.object == "chat.completion" + assert completion.choices + assert completion.choices[0].message.content + assert completion.usage is not None + assert completion.usage.total_tokens > 0 + + +def test_streaming_accumulates_content_and_reports_usage( + openai_client: OpenAI, + model: str, +) -> None: + stream = openai_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Say hi briefly."}], + max_completion_tokens=600, + stream=True, + stream_options={"include_usage": True}, + ) + + content = "" + finish_reason: str | None = None + total_tokens = 0 + + for chunk in stream: + assert chunk.object == "chat.completion.chunk" + + if chunk.choices: + choice = chunk.choices[0] + if choice.delta.content: + content += choice.delta.content + if choice.finish_reason: + finish_reason = choice.finish_reason + + if chunk.usage is not None: + total_tokens = chunk.usage.total_tokens + + assert content + assert finish_reason == "stop" + assert total_tokens > 0 diff --git a/paddler_openai_client_python_test/tests/test_responses.py b/paddler_openai_client_python_test/tests/test_responses.py new file mode 100644 index 00000000..ba93f9a1 --- /dev/null +++ b/paddler_openai_client_python_test/tests/test_responses.py @@ -0,0 +1,43 @@ +from openai import OpenAI + + +def test_non_streaming_returns_output_text_and_usage( + openai_client: OpenAI, + model: str, +) -> None: + response = openai_client.responses.create( + model=model, + input="Say hi briefly.", + max_output_tokens=600, + ) + + assert response.object == "response" + assert response.status == "completed" + assert response.output_text + assert response.usage is not None + assert response.usage.total_tokens > 0 + + +def test_streaming_reaches_completed_and_accumulates_output_text( + openai_client: OpenAI, + model: str, +) -> None: + stream = openai_client.responses.create( + model=model, + input="Say hi briefly.", + max_output_tokens=600, + stream=True, + ) + + event_types: list[str] = [] + output_text = "" + + for event in stream: + event_types.append(event.type) + + if event.type == "response.output_text.delta": + output_text += event.delta + + assert event_types[0] == "response.created" + assert event_types[-1] == "response.completed" + assert output_text diff --git a/paddler_types/Cargo.toml b/paddler_openai_response_format_validator/Cargo.toml similarity index 78% rename from paddler_types/Cargo.toml rename to paddler_openai_response_format_validator/Cargo.toml index e6432612..c45e5037 100644 --- a/paddler_types/Cargo.toml +++ b/paddler_openai_response_format_validator/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "paddler_types" +name = "paddler_openai_response_format_validator" authors.workspace = true description.workspace = true edition.workspace = true @@ -11,10 +11,9 @@ version.workspace = true [dependencies] anyhow = { workspace = true } jsonschema = { workspace = true } -llama-cpp-bindings-types = { workspace = true } -serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } +yaml-rust2 = { workspace = true } [lints] workspace = true diff --git a/paddler_openai_response_format_validator/src/lib.rs b/paddler_openai_response_format_validator/src/lib.rs new file mode 100644 index 00000000..8dc5ba9e --- /dev/null +++ b/paddler_openai_response_format_validator/src/lib.rs @@ -0,0 +1,5 @@ +pub mod openai_spec; +pub mod openai_validator; +pub mod openai_validator_error; +pub mod strict_chat_completion_schema; +pub mod yaml_to_json_value; diff --git a/paddler_openai_response_format_validator/src/openai_spec.rs b/paddler_openai_response_format_validator/src/openai_spec.rs new file mode 100644 index 00000000..de777539 --- /dev/null +++ b/paddler_openai_response_format_validator/src/openai_spec.rs @@ -0,0 +1,82 @@ +use anyhow::Context as _; +use anyhow::Result; +use anyhow::bail; +use serde_json::Value; +use yaml_rust2::YamlLoader; + +use crate::yaml_to_json_value::yaml_to_json_value; + +pub const OPENAPI_YAML: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../vendor/openai/openai-openapi/openapi.yaml" +)); + +pub fn parse_components(openapi_yaml: &str) -> Result { + let documents = YamlLoader::load_from_str(openapi_yaml) + .context("the OpenAI OpenAPI document is not valid YAML")?; + + let document = documents + .into_iter() + .next() + .context("the OpenAI OpenAPI document is empty")?; + + let specification = yaml_to_json_value(&document)?; + + match specification.pointer("/components/schemas") { + Some(components) => Ok(components.clone()), + None => bail!("the OpenAI OpenAPI document has no components.schemas object"), + } +} + +#[cfg(test)] +mod tests { + use super::OPENAPI_YAML; + use super::parse_components; + + #[test] + fn parses_the_embedded_spec_components() { + let components = parse_components(OPENAPI_YAML).unwrap(); + + assert!(components.get("CreateChatCompletionRequest").is_some()); + assert!(components.get("CreateChatCompletionResponse").is_some()); + assert!( + components + .get("CreateChatCompletionStreamResponse") + .is_some() + ); + } + + #[test] + fn the_embedded_spec_is_the_modern_3_1_spec() { + assert!(OPENAPI_YAML.contains("reasoning_tokens")); + assert!(OPENAPI_YAML.contains("service_tier")); + } + + #[test] + fn rejects_invalid_yaml() { + let error = parse_components("key: \"unterminated").unwrap_err(); + + assert!(error.to_string().contains("not valid YAML")); + } + + #[test] + fn rejects_empty_document() { + let error = parse_components("").unwrap_err(); + + assert!(error.to_string().contains("empty")); + } + + #[test] + fn rejects_document_without_components() { + let error = parse_components("openapi: 3.1.0").unwrap_err(); + + assert!(error.to_string().contains("no components.schemas")); + } + + #[test] + fn propagates_yaml_conversion_failures() { + let error = parse_components("1: value").unwrap_err(); + + assert!(error.to_string().contains("mapping keys must be strings")); + } +} diff --git a/paddler_openai_response_format_validator/src/openai_validator.rs b/paddler_openai_response_format_validator/src/openai_validator.rs new file mode 100644 index 00000000..a7e19656 --- /dev/null +++ b/paddler_openai_response_format_validator/src/openai_validator.rs @@ -0,0 +1,558 @@ +use anyhow::Result; +use anyhow::anyhow; +use jsonschema::Validator; +use serde_json::Value; + +use crate::openai_spec::OPENAPI_YAML; +use crate::openai_spec::parse_components; +use crate::openai_validator_error::OpenAIValidatorError; +use crate::strict_chat_completion_schema::strict_chat_completion_schema; + +const REQUEST_ROOT: &str = "CreateChatCompletionRequest"; +const RESPONSE_ROOT: &str = "CreateChatCompletionResponse"; +const STREAM_ROOT: &str = "CreateChatCompletionStreamResponse"; + +const REQUEST_STRICT_POINTERS: &[&str] = &["/$defs/CreateChatCompletionRequest"]; +const RESPONSE_STRICT_POINTERS: &[&str] = &[ + "/$defs/CreateChatCompletionResponse", + "/$defs/ChatCompletionResponseMessage", + "/$defs/CompletionUsage", + "/$defs/CompletionUsage/properties/prompt_tokens_details", + "/$defs/CompletionUsage/properties/completion_tokens_details", +]; +const STREAM_STRICT_POINTERS: &[&str] = &[ + "/$defs/CreateChatCompletionStreamResponse", + "/$defs/ChatCompletionStreamResponseDelta", + "/$defs/CompletionUsage", + "/$defs/CompletionUsage/properties/prompt_tokens_details", + "/$defs/CompletionUsage/properties/completion_tokens_details", +]; + +const RESPONSES_REQUEST_ROOT: &str = "CreateResponse"; +const RESPONSES_RESPONSE_ROOT: &str = "Response"; +const RESPONSES_STREAM_EVENT_ROOT: &str = "ResponseStreamEvent"; + +const RESPONSES_REQUEST_STRICT_POINTERS: &[&str] = &[]; + +const RESPONSES_SHARED_OUTPUT_STRICT_POINTERS: &[&str] = &[ + "/$defs/Response", + "/$defs/OutputMessage", + "/$defs/OutputTextContent", + "/$defs/ReasoningItem", + "/$defs/FunctionToolCall", + "/$defs/ResponseUsage", + "/$defs/ResponseUsage/properties/input_tokens_details", + "/$defs/ResponseUsage/properties/output_tokens_details", +]; + +const RESPONSES_EMITTED_EVENT_STRICT_POINTERS: &[&str] = &[ + "/$defs/ResponseCreatedEvent", + "/$defs/ResponseInProgressEvent", + "/$defs/ResponseCompletedEvent", + "/$defs/ResponseFailedEvent", + "/$defs/ResponseOutputItemAddedEvent", + "/$defs/ResponseOutputItemDoneEvent", + "/$defs/ResponseContentPartAddedEvent", + "/$defs/ResponseContentPartDoneEvent", + "/$defs/ResponseTextDeltaEvent", + "/$defs/ResponseTextDoneEvent", + "/$defs/ResponseReasoningTextDeltaEvent", + "/$defs/ResponseReasoningTextDoneEvent", + "/$defs/ResponseFunctionCallArgumentsDeltaEvent", + "/$defs/ResponseFunctionCallArgumentsDoneEvent", +]; + +fn responses_stream_event_strict_pointers() -> Vec<&'static str> { + let mut pointers = RESPONSES_EMITTED_EVENT_STRICT_POINTERS.to_vec(); + + pointers.extend_from_slice(RESPONSES_SHARED_OUTPUT_STRICT_POINTERS); + + pointers +} + +fn compile_strict_schema( + components: &Value, + root_name: &str, + strict_pointers: &[&str], +) -> Result { + let schema = strict_chat_completion_schema(components, root_name, strict_pointers)?; + + jsonschema::validator_for(&schema) + .map_err(|error| anyhow!("compiling the strict {root_name:?} schema: {error}")) +} + +fn schema_violations(validator: &Validator, instance: &Value) -> Vec { + validator + .iter_errors(instance) + .map(|error| error.to_string()) + .collect() +} + +pub struct OpenAIValidator { + request: Validator, + response: Validator, + stream_chunk: Validator, + responses_request: Validator, + responses_response: Validator, + responses_stream_event: Validator, +} + +impl OpenAIValidator { + pub fn new() -> Result { + Self::from_openapi_yaml(OPENAPI_YAML) + } + + fn from_openapi_yaml(openapi_yaml: &str) -> Result { + Self::from_components(&parse_components(openapi_yaml)?) + } + + fn from_components(components: &Value) -> Result { + Ok(Self { + request: compile_strict_schema(components, REQUEST_ROOT, REQUEST_STRICT_POINTERS)?, + response: compile_strict_schema(components, RESPONSE_ROOT, RESPONSE_STRICT_POINTERS)?, + stream_chunk: compile_strict_schema(components, STREAM_ROOT, STREAM_STRICT_POINTERS)?, + responses_request: compile_strict_schema( + components, + RESPONSES_REQUEST_ROOT, + RESPONSES_REQUEST_STRICT_POINTERS, + )?, + responses_response: compile_strict_schema( + components, + RESPONSES_RESPONSE_ROOT, + RESPONSES_SHARED_OUTPUT_STRICT_POINTERS, + )?, + responses_stream_event: compile_strict_schema( + components, + RESPONSES_STREAM_EVENT_ROOT, + &responses_stream_event_strict_pointers(), + )?, + }) + } + + pub fn validate_chat_completion_request( + &self, + instance: &Value, + ) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.request, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::RequestDoesNotConform { violations }) + } + } + + pub fn validate_chat_completion_response( + &self, + instance: &Value, + ) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.response, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::ResponseDoesNotConform { violations }) + } + } + + pub fn validate_chat_completion_stream_chunk( + &self, + instance: &Value, + ) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.stream_chunk, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::StreamChunkDoesNotConform { violations }) + } + } + + pub fn validate_responses_request(&self, instance: &Value) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.responses_request, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::ResponsesRequestDoesNotConform { violations }) + } + } + + pub fn validate_responses_response( + &self, + instance: &Value, + ) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.responses_response, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::ResponsesResponseDoesNotConform { violations }) + } + } + + pub fn validate_responses_stream_event( + &self, + instance: &Value, + ) -> Result<(), OpenAIValidatorError> { + let violations = schema_violations(&self.responses_stream_event, instance); + + if violations.is_empty() { + Ok(()) + } else { + Err(OpenAIValidatorError::ResponsesStreamEventDoesNotConform { violations }) + } + } +} + +#[cfg(test)] +mod tests { + use serde_json::Value; + use serde_json::json; + + use super::OpenAIValidator; + use super::compile_strict_schema; + use crate::openai_spec::OPENAPI_YAML; + use crate::openai_spec::parse_components; + + fn validator() -> OpenAIValidator { + OpenAIValidator::new().unwrap() + } + + fn official_request() -> Value { + json!({ + "model": "test", + "messages": [{ "role": "user", "content": "Say hello" }] + }) + } + + fn official_response() -> Value { + json!({ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 0, + "model": "test", + "choices": [{ + "index": 0, + "message": { "role": "assistant", "content": "hello", "refusal": null }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + "prompt_tokens_details": { "cached_tokens": 0, "audio_tokens": 0 }, + "completion_tokens_details": { "reasoning_tokens": 0 } + } + }) + } + + fn official_stream_chunk() -> Value { + json!({ + "id": "chatcmpl-test", + "object": "chat.completion.chunk", + "created": 0, + "model": "test", + "choices": [{ + "index": 0, + "delta": { "role": "assistant", "content": "hello" }, + "finish_reason": null + }] + }) + } + + #[test] + fn accepts_an_official_request() { + validator() + .validate_chat_completion_request(&official_request()) + .unwrap(); + } + + #[test] + fn rejects_request_with_chat_template_kwargs() { + let mut request = official_request(); + request["chat_template_kwargs"] = json!({ "enable_thinking": false }); + + let error = validator() + .validate_chat_completion_request(&request) + .err() + .unwrap(); + + assert!(error.to_string().contains("request does not conform")); + } + + #[test] + fn accepts_an_official_response() { + validator() + .validate_chat_completion_response(&official_response()) + .unwrap(); + } + + #[test] + fn rejects_response_with_reasoning_content() { + let mut response = official_response(); + response["choices"][0]["message"]["reasoning_content"] = json!("thinking"); + + let error = validator() + .validate_chat_completion_response(&response) + .err() + .unwrap(); + + assert!(error.to_string().contains("response does not conform")); + } + + #[test] + fn rejects_response_with_image_tokens() { + let mut response = official_response(); + response["usage"]["prompt_tokens_details"]["image_tokens"] = json!(3); + + let error = validator() + .validate_chat_completion_response(&response) + .err() + .unwrap(); + + assert!(error.to_string().contains("response does not conform")); + } + + #[test] + fn accepts_an_official_stream_chunk() { + validator() + .validate_chat_completion_stream_chunk(&official_stream_chunk()) + .unwrap(); + } + + #[test] + fn rejects_stream_chunk_with_reasoning_content() { + let mut chunk = official_stream_chunk(); + chunk["choices"][0]["delta"]["reasoning_content"] = json!("thinking"); + + let error = validator() + .validate_chat_completion_stream_chunk(&chunk) + .err() + .unwrap(); + + assert!(error.to_string().contains("stream chunk does not conform")); + } + + #[test] + fn rejects_invalid_openapi_yaml() { + let error = OpenAIValidator::from_openapi_yaml("key: \"unterminated") + .err() + .unwrap(); + + assert!(error.to_string().contains("not valid YAML")); + } + + #[test] + fn fails_when_request_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components + .as_object_mut() + .unwrap() + .remove("CreateChatCompletionRequest"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!(error.to_string().contains("CreateChatCompletionRequest")); + } + + #[test] + fn fails_when_response_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components + .as_object_mut() + .unwrap() + .remove("CreateChatCompletionResponse"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!(error.to_string().contains("CreateChatCompletionResponse")); + } + + #[test] + fn fails_when_stream_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components + .as_object_mut() + .unwrap() + .remove("CreateChatCompletionStreamResponse"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!( + error + .to_string() + .contains("CreateChatCompletionStreamResponse") + ); + } + + #[test] + fn fails_to_compile_a_structurally_broken_schema() { + let components = json!({ "Broken": { "$ref": "#/$defs/Missing" } }); + + let error = compile_strict_schema(&components, "Broken", &["/$defs/Broken"]) + .err() + .unwrap(); + + assert!(error.to_string().contains("Broken")); + } + + fn official_responses_request() -> Value { + json!({ "model": "test", "input": "Say hello" }) + } + + fn official_responses_response() -> Value { + json!({ + "id": "resp_test", + "object": "response", + "created_at": 0, + "error": null, + "incomplete_details": null, + "instructions": null, + "model": "test", + "tools": [], + "output": [{ + "id": "msg_0", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{ + "type": "output_text", + "text": "hello", + "annotations": [], + "logprobs": [] + }] + }], + "parallel_tool_calls": true, + "metadata": {}, + "tool_choice": "auto", + "temperature": 1, + "top_p": 1, + "usage": { + "input_tokens": 1, + "input_tokens_details": { "cached_tokens": 0 }, + "output_tokens": 1, + "output_tokens_details": { "reasoning_tokens": 0 }, + "total_tokens": 2 + } + }) + } + + fn official_responses_stream_event() -> Value { + json!({ + "type": "response.output_text.delta", + "item_id": "msg_0", + "output_index": 0, + "content_index": 0, + "delta": "hello", + "sequence_number": 1, + "logprobs": [] + }) + } + + #[test] + fn accepts_an_official_responses_request() { + validator() + .validate_responses_request(&official_responses_request()) + .unwrap(); + } + + #[test] + fn accepts_an_official_responses_response() { + validator() + .validate_responses_response(&official_responses_response()) + .unwrap(); + } + + #[test] + fn rejects_responses_response_with_an_extra_top_level_key() { + let mut response = official_responses_response(); + response["paddler_extension"] = json!("nope"); + + let error = validator() + .validate_responses_response(&response) + .err() + .unwrap(); + + assert!( + error + .to_string() + .contains("responses response does not conform") + ); + } + + #[test] + fn rejects_responses_response_with_an_extra_output_text_field() { + let mut response = official_responses_response(); + response["output"][0]["content"][0]["reasoning_content"] = json!("nope"); + + let error = validator() + .validate_responses_response(&response) + .err() + .unwrap(); + + assert!( + error + .to_string() + .contains("responses response does not conform") + ); + } + + #[test] + fn accepts_an_official_responses_stream_event() { + validator() + .validate_responses_stream_event(&official_responses_stream_event()) + .unwrap(); + } + + #[test] + fn rejects_responses_stream_event_with_an_extra_key() { + let mut event = official_responses_stream_event(); + event["paddler_extension"] = json!("nope"); + + let error = validator() + .validate_responses_stream_event(&event) + .err() + .unwrap(); + + assert!( + error + .to_string() + .contains("responses stream event does not conform") + ); + } + + #[test] + fn fails_when_create_response_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components.as_object_mut().unwrap().remove("CreateResponse"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!(error.to_string().contains("CreateResponse")); + } + + #[test] + fn fails_when_responses_response_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components.as_object_mut().unwrap().remove("Response"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!(error.to_string().contains("Response")); + } + + #[test] + fn fails_when_response_stream_event_schema_is_absent() { + let mut components = parse_components(OPENAPI_YAML).unwrap(); + components + .as_object_mut() + .unwrap() + .remove("ResponseStreamEvent"); + + let error = OpenAIValidator::from_components(&components).err().unwrap(); + + assert!(error.to_string().contains("ResponseStreamEvent")); + } +} diff --git a/paddler_openai_response_format_validator/src/openai_validator_error.rs b/paddler_openai_response_format_validator/src/openai_validator_error.rs new file mode 100644 index 00000000..703d9a36 --- /dev/null +++ b/paddler_openai_response_format_validator/src/openai_validator_error.rs @@ -0,0 +1,23 @@ +#[derive(Debug, thiserror::Error)] +pub enum OpenAIValidatorError { + #[error( + "chat completion request does not conform to the official OpenAI schema: {violations:?}" + )] + RequestDoesNotConform { violations: Vec }, + #[error( + "chat completion response does not conform to the official OpenAI schema: {violations:?}" + )] + ResponseDoesNotConform { violations: Vec }, + #[error( + "chat completion stream chunk does not conform to the official OpenAI schema: {violations:?}" + )] + StreamChunkDoesNotConform { violations: Vec }, + #[error("responses request does not conform to the official OpenAI schema: {violations:?}")] + ResponsesRequestDoesNotConform { violations: Vec }, + #[error("responses response does not conform to the official OpenAI schema: {violations:?}")] + ResponsesResponseDoesNotConform { violations: Vec }, + #[error( + "responses stream event does not conform to the official OpenAI schema: {violations:?}" + )] + ResponsesStreamEventDoesNotConform { violations: Vec }, +} diff --git a/paddler_openai_response_format_validator/src/strict_chat_completion_schema.rs b/paddler_openai_response_format_validator/src/strict_chat_completion_schema.rs new file mode 100644 index 00000000..4d7350a9 --- /dev/null +++ b/paddler_openai_response_format_validator/src/strict_chat_completion_schema.rs @@ -0,0 +1,358 @@ +use std::collections::BTreeMap; +use std::collections::BTreeSet; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::bail; +use serde_json::Map; +use serde_json::Value; +use serde_json::json; + +const COMPONENT_REF_PREFIX: &str = "#/components/schemas/"; +const DIALECT: &str = "https://json-schema.org/draft/2020-12/schema"; + +fn rewrite_ref(reference: &Value) -> Value { + match reference { + Value::String(reference) => reference.strip_prefix(COMPONENT_REF_PREFIX).map_or_else( + || Value::String(reference.clone()), + |name| Value::String(format!("#/$defs/{name}")), + ), + other => transform_node(other), + } +} + +fn unique_strings(values: &[Value]) -> Vec { + let mut seen = BTreeSet::new(); + let mut unique = Vec::new(); + + for value in values { + let key = value.to_string(); + + if seen.insert(key) { + unique.push(value.clone()); + } + } + + unique +} + +fn transform_object(object: &Map) -> Value { + let mut transformed = Map::new(); + let mut nullable = false; + + for (key, value) in object { + match key.as_str() { + "nullable" => nullable = matches!(value, Value::Bool(true)), + // Draft 2019-09 recursion keywords the OpenAI document still carries; Draft 2020-12 + // replaced them with `$dynamicAnchor`/`$dynamicRef`. Drop them so the assembled schema + // passes 2020-12 meta-validation. The schemas that use them (recursive filters) are not + // part of any Paddler-emitted payload, so removing the recursion is inconsequential. + "$recursiveAnchor" | "$recursiveRef" => {} + // OpenAPI 3.0 expressed exclusive bounds as booleans; Draft 2020-12 expects the bound to + // be the number itself. A boolean form is meaningless under 2020-12, so drop it. + "exclusiveMinimum" | "exclusiveMaximum" if value.is_boolean() => {} + "required" => { + if let Value::Array(entries) = value { + transformed.insert(key.clone(), Value::Array(unique_strings(entries))); + } else { + transformed.insert(key.clone(), transform_node(value)); + } + } + "$ref" => { + transformed.insert("$ref".to_owned(), rewrite_ref(value)); + } + _ => { + transformed.insert(key.clone(), transform_node(value)); + } + } + } + + if nullable { + let mut nullable_wrapper = Map::new(); + nullable_wrapper.insert( + "anyOf".to_owned(), + Value::Array(vec![Value::Object(transformed), json!({ "type": "null" })]), + ); + + Value::Object(nullable_wrapper) + } else { + Value::Object(transformed) + } +} + +fn transform_node(node: &Value) -> Value { + match node { + Value::Array(items) => Value::Array(items.iter().map(transform_node).collect()), + Value::Object(object) => transform_object(object), + other => other.clone(), + } +} + +fn collect_component_refs(node: &Value, found: &mut BTreeSet) { + match node { + Value::Array(items) => { + for item in items { + collect_component_refs(item, found); + } + } + Value::Object(object) => { + for (key, value) in object { + if key == "$ref" + && let Value::String(reference) = value + && let Some(name) = reference.strip_prefix(COMPONENT_REF_PREFIX) + { + found.insert(name.to_owned()); + } else { + collect_component_refs(value, found); + } + } + } + _ => {} + } +} + +fn transitive_closure<'spec>( + components: &'spec Value, + root_name: &str, +) -> Result> { + let mut reachable: BTreeMap = BTreeMap::new(); + let mut pending = vec![root_name.to_owned()]; + + while let Some(name) = pending.pop() { + if reachable.contains_key(&name) { + continue; + } + + let component = components + .get(name.as_str()) + .with_context(|| format!("schema references unknown component {name:?}"))?; + + reachable.insert(name.clone(), component); + + let mut direct_refs = BTreeSet::new(); + collect_component_refs(component, &mut direct_refs); + + for reference in direct_refs { + pending.push(reference); + } + } + + Ok(reachable) +} + +pub fn strict_chat_completion_schema( + components: &Value, + root_name: &str, + strict_pointers: &[&str], +) -> Result { + let closure = transitive_closure(components, root_name)?; + + let mut definitions = Map::new(); + + for (name, component) in closure { + definitions.insert(name, transform_node(component)); + } + + let mut schema = json!({ + "$schema": DIALECT, + "$ref": format!("#/$defs/{root_name}"), + "$defs": Value::Object(definitions), + }); + + for pointer in strict_pointers { + match schema.pointer_mut(pointer) { + Some(Value::Object(target)) => { + target.insert("unevaluatedProperties".to_owned(), Value::Bool(false)); + } + Some(other) => bail!("strict target {pointer:?} is not an object: {other}"), + None => bail!("strict target {pointer:?} was not found in the assembled schema"), + } + } + + Ok(schema) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::strict_chat_completion_schema; + use super::transform_node; + + #[test] + fn rewrites_component_refs_to_defs() { + let rewritten = transform_node(&json!({ "$ref": "#/components/schemas/Foo" })); + + assert_eq!(rewritten, json!({ "$ref": "#/$defs/Foo" })); + } + + #[test] + fn leaves_non_component_refs_untouched() { + let rewritten = transform_node(&json!({ "$ref": "#/$defs/Foo" })); + + assert_eq!(rewritten, json!({ "$ref": "#/$defs/Foo" })); + } + + #[test] + fn leaves_non_string_refs_untouched() { + let rewritten = transform_node(&json!({ "$ref": 42 })); + + assert_eq!(rewritten, json!({ "$ref": 42 })); + } + + #[test] + fn wraps_nullable_into_anyof_null() { + let wrapped = transform_node(&json!({ "type": "string", "nullable": true })); + + assert_eq!( + wrapped, + json!({ "anyOf": [{ "type": "string" }, { "type": "null" }] }) + ); + } + + #[test] + fn drops_nullable_false() { + let transformed = transform_node(&json!({ "type": "string", "nullable": false })); + + assert_eq!(transformed, json!({ "type": "string" })); + } + + #[test] + fn drops_draft_2019_recursive_keywords() { + let transformed = transform_node(&json!({ + "$recursiveAnchor": true, + "$recursiveRef": "#", + "type": "object" + })); + + assert_eq!(transformed, json!({ "type": "object" })); + } + + #[test] + fn drops_boolean_exclusive_bounds() { + let transformed = transform_node(&json!({ + "type": "number", + "minimum": 0, + "exclusiveMinimum": true + })); + + assert_eq!(transformed, json!({ "type": "number", "minimum": 0 })); + } + + #[test] + fn keeps_numeric_exclusive_bounds() { + let transformed = transform_node(&json!({ "type": "number", "exclusiveMinimum": 0 })); + + assert_eq!( + transformed, + json!({ "type": "number", "exclusiveMinimum": 0 }) + ); + } + + #[test] + fn deduplicates_required_entries() { + let transformed = transform_node(&json!({ + "type": "object", + "required": ["id", "name", "id"] + })); + + assert_eq!( + transformed, + json!({ "type": "object", "required": ["id", "name"] }) + ); + } + + #[test] + fn passes_scalars_through() { + assert_eq!(transform_node(&json!(7)), json!(7)); + } + + #[test] + fn transforms_refs_nested_in_arrays() { + let transformed = + transform_node(&json!({ "allOf": [{ "$ref": "#/components/schemas/Sub" }] })); + + assert_eq!(transformed, json!({ "allOf": [{ "$ref": "#/$defs/Sub" }] })); + } + + #[test] + fn builds_self_contained_schema_with_strict_targets() { + let components = json!({ + "Root": { + "type": "object", + "properties": { "child": { "$ref": "#/components/schemas/Child" } } + }, + "Child": { "type": "object" } + }); + + let schema = + strict_chat_completion_schema(&components, "Root", &["/$defs/Root", "/$defs/Child"]) + .unwrap(); + + assert_eq!(schema["$ref"], json!("#/$defs/Root")); + assert_eq!( + schema["$defs"]["Root"]["unevaluatedProperties"], + json!(false) + ); + assert_eq!( + schema["$defs"]["Child"]["unevaluatedProperties"], + json!(false) + ); + } + + #[test] + fn collects_a_shared_component_reached_through_two_paths_only_once() { + let components = json!({ + "Root": { + "type": "object", + "properties": { + "left": { "$ref": "#/components/schemas/Left" }, + "right": { "$ref": "#/components/schemas/Right" } + } + }, + "Left": { "properties": { "shared": { "$ref": "#/components/schemas/Shared" } } }, + "Right": { "properties": { "shared": { "$ref": "#/components/schemas/Shared" } } }, + "Shared": { "type": "object" } + }); + + let schema = strict_chat_completion_schema(&components, "Root", &[]).unwrap(); + + assert!(schema["$defs"]["Shared"].is_object()); + assert_eq!(schema["$defs"].as_object().unwrap().len(), 4); + } + + #[test] + fn rejects_unknown_root_schema() { + let error = strict_chat_completion_schema(&json!({}), "Root", &[]).unwrap_err(); + + assert!(error.to_string().contains("unknown component \"Root\"")); + } + + #[test] + fn rejects_dangling_component_ref() { + let components = json!({ "Root": { "$ref": "#/components/schemas/Missing" } }); + + let error = strict_chat_completion_schema(&components, "Root", &[]).unwrap_err(); + + assert!(error.to_string().contains("unknown component \"Missing\"")); + } + + #[test] + fn rejects_missing_strict_target() { + let components = json!({ "Root": { "type": "object" } }); + + let error = + strict_chat_completion_schema(&components, "Root", &["/$defs/Nope"]).unwrap_err(); + + assert!(error.to_string().contains("was not found")); + } + + #[test] + fn rejects_non_object_strict_target() { + let components = json!({ "Root": { "type": "object" } }); + + let error = strict_chat_completion_schema(&components, "Root", &["/$ref"]).unwrap_err(); + + assert!(error.to_string().contains("is not an object")); + } +} diff --git a/paddler_openai_response_format_validator/src/yaml_to_json_value.rs b/paddler_openai_response_format_validator/src/yaml_to_json_value.rs new file mode 100644 index 00000000..0a6359e8 --- /dev/null +++ b/paddler_openai_response_format_validator/src/yaml_to_json_value.rs @@ -0,0 +1,172 @@ +use anyhow::Result; +use anyhow::anyhow; +use anyhow::bail; +use serde_json::Map; +use serde_json::Number; +use serde_json::Value; +use yaml_rust2::yaml::Hash; +use yaml_rust2::yaml::Yaml; + +fn real_to_value(real: &str) -> Result { + let parsed: f64 = real + .parse() + .map_err(|error| anyhow!("could not parse YAML real {real:?}: {error}"))?; + + let number = + Number::from_f64(parsed).ok_or_else(|| anyhow!("YAML real {real:?} is not finite"))?; + + Ok(Value::Number(number)) +} + +fn hash_to_value(hash: &Hash) -> Result { + let mut object = Map::new(); + + for (key, value) in hash { + let Yaml::String(key) = key else { + bail!("YAML mapping keys must be strings, found {key:?}"); + }; + + object.insert(key.clone(), yaml_to_json_value(value)?); + } + + Ok(Value::Object(object)) +} + +pub fn yaml_to_json_value(yaml: &Yaml) -> Result { + match yaml { + Yaml::Null => Ok(Value::Null), + Yaml::Boolean(boolean) => Ok(Value::Bool(*boolean)), + Yaml::Integer(integer) => Ok(Value::Number(Number::from(*integer))), + Yaml::Real(real) => real_to_value(real), + Yaml::String(string) => Ok(Value::String(string.clone())), + Yaml::Array(array) => array + .iter() + .map(yaml_to_json_value) + .collect::>>() + .map(Value::Array), + Yaml::Hash(hash) => hash_to_value(hash), + Yaml::Alias(index) => bail!("YAML aliases are not supported (alias #{index})"), + Yaml::BadValue => bail!("encountered an invalid YAML node"), + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use yaml_rust2::yaml::Hash; + use yaml_rust2::yaml::Yaml; + + use super::yaml_to_json_value; + + #[test] + fn converts_null() { + assert_eq!(yaml_to_json_value(&Yaml::Null).unwrap(), json!(null)); + } + + #[test] + fn converts_boolean() { + assert_eq!( + yaml_to_json_value(&Yaml::Boolean(true)).unwrap(), + json!(true) + ); + } + + #[test] + fn converts_integer() { + assert_eq!(yaml_to_json_value(&Yaml::Integer(42)).unwrap(), json!(42)); + } + + #[test] + fn converts_string() { + assert_eq!( + yaml_to_json_value(&Yaml::String("hello".to_owned())).unwrap(), + json!("hello") + ); + } + + #[test] + fn converts_real() { + assert_eq!( + yaml_to_json_value(&Yaml::Real("1.5".to_owned())).unwrap(), + json!(1.5) + ); + } + + #[test] + fn rejects_unparseable_real() { + let error = yaml_to_json_value(&Yaml::Real("not-a-number".to_owned())).unwrap_err(); + + assert!(error.to_string().contains("could not parse YAML real")); + } + + #[test] + fn rejects_non_finite_real() { + let error = yaml_to_json_value(&Yaml::Real("inf".to_owned())).unwrap_err(); + + assert!(error.to_string().contains("not finite")); + } + + #[test] + fn converts_array() { + let array = Yaml::Array(vec![Yaml::Integer(1), Yaml::String("two".to_owned())]); + + assert_eq!(yaml_to_json_value(&array).unwrap(), json!([1, "two"])); + } + + #[test] + fn converts_hash_with_string_keys() { + let mut hash = Hash::new(); + hash.insert( + Yaml::String("name".to_owned()), + Yaml::String("paddler".to_owned()), + ); + + assert_eq!( + yaml_to_json_value(&Yaml::Hash(hash)).unwrap(), + json!({"name": "paddler"}) + ); + } + + #[test] + fn rejects_non_string_hash_keys() { + let mut hash = Hash::new(); + hash.insert(Yaml::Integer(1), Yaml::Null); + + let error = yaml_to_json_value(&Yaml::Hash(hash)).unwrap_err(); + + assert!(error.to_string().contains("mapping keys must be strings")); + } + + #[test] + fn propagates_errors_from_hash_values() { + let mut hash = Hash::new(); + hash.insert(Yaml::String("broken".to_owned()), Yaml::BadValue); + + let error = yaml_to_json_value(&Yaml::Hash(hash)).unwrap_err(); + + assert!(error.to_string().contains("invalid YAML node")); + } + + #[test] + fn propagates_errors_from_array_elements() { + let array = Yaml::Array(vec![Yaml::BadValue]); + + let error = yaml_to_json_value(&array).unwrap_err(); + + assert!(error.to_string().contains("invalid YAML node")); + } + + #[test] + fn rejects_alias() { + let error = yaml_to_json_value(&Yaml::Alias(7)).unwrap_err(); + + assert!(error.to_string().contains("aliases are not supported")); + } + + #[test] + fn rejects_bad_value() { + let error = yaml_to_json_value(&Yaml::BadValue).unwrap_err(); + + assert!(error.to_string().contains("invalid YAML node")); + } +} diff --git a/paddler_state_conversion/Cargo.toml b/paddler_state_conversion/Cargo.toml new file mode 100644 index 00000000..1456be5e --- /dev/null +++ b/paddler_state_conversion/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "paddler_state_conversion" +authors.workspace = true +description.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_state_conversion/src/converts_to_applicable_state.rs b/paddler_state_conversion/src/converts_to_applicable_state.rs new file mode 100644 index 00000000..1206b171 --- /dev/null +++ b/paddler_state_conversion/src/converts_to_applicable_state.rs @@ -0,0 +1,13 @@ +use anyhow::Result; +use async_trait::async_trait; + +#[async_trait] +pub trait ConvertsToApplicableState { + type ApplicableState; + type DesiredState; + + async fn to_applicable_state( + &self, + desired_state: Self::DesiredState, + ) -> Result; +} diff --git a/paddler_state_conversion/src/converts_to_desired_state.rs b/paddler_state_conversion/src/converts_to_desired_state.rs new file mode 100644 index 00000000..9cbea494 --- /dev/null +++ b/paddler_state_conversion/src/converts_to_desired_state.rs @@ -0,0 +1,6 @@ +pub trait ConvertsToDesiredState { + type DesiredState; + type Source; + + fn to_desired_state(&self, source: Self::Source) -> Self::DesiredState; +} diff --git a/paddler_state_conversion/src/lib.rs b/paddler_state_conversion/src/lib.rs new file mode 100644 index 00000000..a6a9157c --- /dev/null +++ b/paddler_state_conversion/src/lib.rs @@ -0,0 +1,2 @@ +pub mod converts_to_applicable_state; +pub mod converts_to_desired_state; diff --git a/paddler_test_cluster_harness/Cargo.toml b/paddler_test_cluster_harness/Cargo.toml new file mode 100644 index 00000000..e97271e6 --- /dev/null +++ b/paddler_test_cluster_harness/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "paddler_test_cluster_harness" +authors.workspace = true +description = "Shared balancer-API-facing test harness for Paddler integration tests" +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +anyhow = { workspace = true } +async-openai = { workspace = true } +async-stream = { workspace = true } +async-trait = { workspace = true } +base64 = { workspace = true } +futures-util = { workspace = true } +paddler_client = { workspace = true } +paddler_messaging = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tempfile = { workspace = true } +tokio = { workspace = true } +url = { workspace = true } + +[dev-dependencies] +http = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_tests/src/agent_config.rs b/paddler_test_cluster_harness/src/agent_config.rs similarity index 100% rename from paddler_tests/src/agent_config.rs rename to paddler_test_cluster_harness/src/agent_config.rs diff --git a/paddler_test_cluster_harness/src/agent_spawner.rs b/paddler_test_cluster_harness/src/agent_spawner.rs new file mode 100644 index 00000000..b06638a9 --- /dev/null +++ b/paddler_test_cluster_harness/src/agent_spawner.rs @@ -0,0 +1,8 @@ +use anyhow::Result; + +use crate::agent_config::AgentConfig; +use crate::managed_process::ManagedProcess; + +pub trait AgentSpawner: Send + Sync { + fn spawn(&self, config: &AgentConfig) -> Result>; +} diff --git a/paddler_tests/src/agents_status/assert_agent_count.rs b/paddler_test_cluster_harness/src/agents_status/assert_agent_count.rs similarity index 65% rename from paddler_tests/src/agents_status/assert_agent_count.rs rename to paddler_test_cluster_harness/src/agents_status/assert_agent_count.rs index dcb0df70..aca71ded 100644 --- a/paddler_tests/src/agents_status/assert_agent_count.rs +++ b/paddler_test_cluster_harness/src/agents_status/assert_agent_count.rs @@ -1,4 +1,4 @@ -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; pub fn assert_agent_count(expected_count: usize) -> impl Fn(&AgentControllerPoolSnapshot) -> bool { move |snapshot| snapshot.agents.len() == expected_count diff --git a/paddler_tests/src/agents_status/assert_slots_processing.rs b/paddler_test_cluster_harness/src/agents_status/assert_slots_processing.rs similarity index 80% rename from paddler_tests/src/agents_status/assert_slots_processing.rs rename to paddler_test_cluster_harness/src/agents_status/assert_slots_processing.rs index 797f817a..b91c485f 100644 --- a/paddler_tests/src/agents_status/assert_slots_processing.rs +++ b/paddler_test_cluster_harness/src/agents_status/assert_slots_processing.rs @@ -1,4 +1,4 @@ -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; pub fn assert_slots_processing( agent_id: &str, diff --git a/paddler_test_cluster_harness/src/agents_status/assert_slots_total_at_least.rs b/paddler_test_cluster_harness/src/agents_status/assert_slots_total_at_least.rs new file mode 100644 index 00000000..5a37df6a --- /dev/null +++ b/paddler_test_cluster_harness/src/agents_status/assert_slots_total_at_least.rs @@ -0,0 +1,68 @@ +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; + +pub fn assert_slots_total_at_least( + agent_id: &str, + expected_slots_total: i32, +) -> impl Fn(&AgentControllerPoolSnapshot) -> bool { + let agent_id = agent_id.to_owned(); + + move |snapshot| { + snapshot + .agents + .iter() + .any(|agent| agent.id == agent_id && agent.slots_total >= expected_slots_total) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; + use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; + + use super::assert_slots_total_at_least; + + fn snapshot_with(id: &str, slots_total: i32) -> AgentControllerPoolSnapshot { + AgentControllerPoolSnapshot { + agents: vec![AgentControllerSnapshot { + desired_slots_total: slots_total, + download_current: 0, + download_filename: None, + download_indeterminate: false, + download_total: 0, + id: id.to_owned(), + issues: BTreeSet::new(), + model_path: None, + name: None, + slots_processing: 0, + slots_total, + state_application_status: AgentStateApplicationStatus::Applied, + uses_chat_template_override: false, + }], + } + } + + #[test] + fn matches_when_the_named_agent_has_enough_slots() { + let predicate = assert_slots_total_at_least("agent-a", 4); + + assert!(predicate(&snapshot_with("agent-a", 4))); + assert!(predicate(&snapshot_with("agent-a", 6))); + } + + #[test] + fn rejects_when_slots_are_below_the_threshold() { + let predicate = assert_slots_total_at_least("agent-a", 4); + + assert!(!predicate(&snapshot_with("agent-a", 3))); + } + + #[test] + fn rejects_when_no_agent_matches_the_id() { + let predicate = assert_slots_total_at_least("agent-b", 1); + + assert!(!predicate(&snapshot_with("agent-a", 8))); + } +} diff --git a/paddler_tests/src/agents_status/mod.rs b/paddler_test_cluster_harness/src/agents_status/mod.rs similarity index 100% rename from paddler_tests/src/agents_status/mod.rs rename to paddler_test_cluster_harness/src/agents_status/mod.rs diff --git a/paddler_tests/src/agents_stream_watcher.rs b/paddler_test_cluster_harness/src/agents_stream_watcher.rs similarity index 88% rename from paddler_tests/src/agents_stream_watcher.rs rename to paddler_test_cluster_harness/src/agents_stream_watcher.rs index c77f5d08..19d0448a 100644 --- a/paddler_tests/src/agents_stream_watcher.rs +++ b/paddler_test_cluster_harness/src/agents_stream_watcher.rs @@ -6,8 +6,8 @@ use anyhow::anyhow; use anyhow::bail; use futures_util::Stream; use futures_util::StreamExt as _; -use paddler_client::ClientManagement; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_client::client_management::ClientManagement; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; pub struct AgentsStreamWatcher { stream: Pin> + Send>>, @@ -171,10 +171,10 @@ mod tests { use std::collections::BTreeSet; use futures_util::stream; - use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; - use paddler_types::agent_issue::AgentIssue; - use paddler_types::agent_issue_params::ModelPath; - use paddler_types::agent_state_application_status::AgentStateApplicationStatus; + use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; + use paddler_messaging::agent_issue::AgentIssue; + use paddler_messaging::agent_issue_params::model_path::ModelPath; + use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; use super::*; @@ -395,4 +395,39 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn wait_for_slots_ready_completes_once_observed_counts_match_expected() { + let snapshots = vec![ + AgentControllerPoolSnapshot { + agents: vec![snapshot_with_agent_and_slots("a", BTreeSet::new(), 1)], + }, + AgentControllerPoolSnapshot { + agents: vec![ + snapshot_with_agent_and_slots("a", BTreeSet::new(), 1), + snapshot_with_agent_and_slots("b", BTreeSet::new(), 2), + ], + }, + ]; + + let mut watcher = make_watcher(snapshots); + + watcher.wait_for_slots_ready(&[2, 1]).await.unwrap(); + } + + #[tokio::test] + async fn wait_for_slots_ready_errors_when_an_agent_reports_issues() { + let snapshots = vec![AgentControllerPoolSnapshot { + agents: vec![ + snapshot_with_agent_and_slots("a", unable_to_find_chat_template_issue(), 1), + snapshot_with_agent_and_slots("b", BTreeSet::new(), 2), + ], + }]; + + let mut watcher = make_watcher(snapshots); + + let error = watcher.wait_for_slots_ready(&[1, 2]).await.err().unwrap(); + + assert!(format!("{error:#}").contains("issues")); + } } diff --git a/paddler_tests/src/balancer_addresses.rs b/paddler_test_cluster_harness/src/balancer_addresses.rs similarity index 62% rename from paddler_tests/src/balancer_addresses.rs rename to paddler_test_cluster_harness/src/balancer_addresses.rs index 9831c56f..7c582060 100644 --- a/paddler_tests/src/balancer_addresses.rs +++ b/paddler_test_cluster_harness/src/balancer_addresses.rs @@ -60,3 +60,41 @@ impl BalancerAddresses { .with_context(|| format!("failed to build base URL for {address}")) } } + +#[cfg(test)] +mod tests { + use super::BalancerAddresses; + + #[test] + fn pick_reserves_three_distinct_loopback_ports() { + let addresses = BalancerAddresses::pick().unwrap(); + + for address in [ + addresses.inference, + addresses.management, + addresses.compat_openai, + ] { + assert!(address.ip().is_loopback()); + assert_ne!(address.port(), 0); + } + + assert_ne!(addresses.inference.port(), addresses.management.port()); + assert_ne!(addresses.inference.port(), addresses.compat_openai.port()); + assert_ne!(addresses.management.port(), addresses.compat_openai.port()); + } + + #[test] + fn builds_base_urls_for_each_service() { + let addresses = BalancerAddresses::pick().unwrap(); + + assert_eq!(addresses.inference_base_url().unwrap().scheme(), "http"); + assert_eq!( + addresses.management_base_url().unwrap().port(), + Some(addresses.management.port()) + ); + assert_eq!( + addresses.compat_openai_base_url().unwrap().port(), + Some(addresses.compat_openai.port()) + ); + } +} diff --git a/paddler_tests/src/buffered_requests_status/assert_count.rs b/paddler_test_cluster_harness/src/buffered_requests_status/assert_count.rs similarity index 65% rename from paddler_tests/src/buffered_requests_status/assert_count.rs rename to paddler_test_cluster_harness/src/buffered_requests_status/assert_count.rs index 5c5b6222..ef3c574f 100644 --- a/paddler_tests/src/buffered_requests_status/assert_count.rs +++ b/paddler_test_cluster_harness/src/buffered_requests_status/assert_count.rs @@ -1,4 +1,4 @@ -use paddler_types::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; pub fn assert_count(expected_count: i32) -> impl Fn(&BufferedRequestManagerSnapshot) -> bool { move |snapshot| snapshot.buffered_requests_current == expected_count diff --git a/paddler_tests/src/buffered_requests_status/mod.rs b/paddler_test_cluster_harness/src/buffered_requests_status/mod.rs similarity index 100% rename from paddler_tests/src/buffered_requests_status/mod.rs rename to paddler_test_cluster_harness/src/buffered_requests_status/mod.rs diff --git a/paddler_test_cluster_harness/src/buffered_requests_stream_watcher.rs b/paddler_test_cluster_harness/src/buffered_requests_stream_watcher.rs new file mode 100644 index 00000000..8759c536 --- /dev/null +++ b/paddler_test_cluster_harness/src/buffered_requests_stream_watcher.rs @@ -0,0 +1,105 @@ +use std::pin::Pin; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use futures_util::Stream; +use futures_util::StreamExt as _; +use paddler_client::client_management::ClientManagement; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; + +pub struct BufferedRequestsStreamWatcher { + stream: Pin> + Send>>, +} + +impl BufferedRequestsStreamWatcher { + pub async fn connect(management: &ClientManagement<'_>) -> Result { + let raw_stream = management + .get_buffered_requests_stream() + .await + .map_err(anyhow::Error::new) + .context("failed to open /api/v1/buffered_requests/stream")?; + + let stream = raw_stream.map(|item| item.map_err(anyhow::Error::new)); + + Ok(Self { + stream: Box::pin(stream), + }) + } + + #[must_use] + pub fn from_stream( + stream: Pin> + Send>>, + ) -> Self { + Self { stream } + } + + pub async fn until( + &mut self, + mut predicate: TPredicate, + ) -> Result + where + TPredicate: FnMut(&BufferedRequestManagerSnapshot) -> bool, + { + while let Some(item) = self.stream.next().await { + let snapshot = item.context("buffered requests stream yielded an error")?; + + if predicate(&snapshot) { + return Ok(snapshot); + } + } + + Err(anyhow!( + "buffered requests stream closed before predicate was satisfied" + )) + } +} + +#[cfg(test)] +mod tests { + use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; + + use super::BufferedRequestsStreamWatcher; + + fn watcher( + items: Vec>, + ) -> BufferedRequestsStreamWatcher { + BufferedRequestsStreamWatcher::from_stream(Box::pin(futures_util::stream::iter(items))) + } + + fn snapshot(buffered_requests_current: i32) -> BufferedRequestManagerSnapshot { + BufferedRequestManagerSnapshot { + buffered_requests_current, + } + } + + #[tokio::test] + async fn returns_the_first_snapshot_satisfying_the_predicate() { + let mut watcher = watcher(vec![Ok(snapshot(5)), Ok(snapshot(0))]); + + let matched = watcher + .until(|snapshot| snapshot.buffered_requests_current == 0) + .await + .unwrap(); + + assert_eq!(matched.buffered_requests_current, 0); + } + + #[tokio::test] + async fn errors_when_the_stream_closes_before_the_predicate_is_satisfied() { + let mut watcher = watcher(vec![Ok(snapshot(5))]); + + let error = watcher.until(|_| false).await.err().unwrap(); + + assert!(error.to_string().contains("closed before predicate")); + } + + #[tokio::test] + async fn propagates_a_stream_error() { + let mut watcher = watcher(vec![Err(anyhow::anyhow!("socket closed"))]); + + let error = watcher.until(|_| true).await.err().unwrap(); + + assert!(error.to_string().contains("yielded an error")); + } +} diff --git a/paddler_test_cluster_harness/src/cluster.rs b/paddler_test_cluster_harness/src/cluster.rs new file mode 100644 index 00000000..32901c5a --- /dev/null +++ b/paddler_test_cluster_harness/src/cluster.rs @@ -0,0 +1,298 @@ +use std::future::Future; + +use anyhow::Context as _; +use anyhow::Result; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_client::PaddlerClient; +use reqwest::Client; +use serde_json::Value; + +use crate::agent_config::AgentConfig; +use crate::agent_spawner::AgentSpawner; +use crate::agents_status::assert_agent_count::assert_agent_count; +use crate::agents_status::assert_slots_processing::assert_slots_processing; +use crate::agents_status::assert_slots_total_at_least::assert_slots_total_at_least; +use crate::agents_stream_watcher::AgentsStreamWatcher; +use crate::buffered_requests_status::assert_count::assert_count; +use crate::buffered_requests_stream_watcher::BufferedRequestsStreamWatcher; +use crate::collect_embedding_results::collect_embedding_results; +use crate::collect_generated_tokens::collect_generated_tokens; +use crate::collected_embedding_results::CollectedEmbeddingResults; +use crate::collected_generated_tokens::CollectedGeneratedTokens; +use crate::inference_http_client::InferenceHttpClient; +use crate::inference_message_stream::InferenceMessageStream; +use crate::openai_chat_completions_client::OpenAIChatCompletionsClient; +use crate::openai_responses_client::OpenAIResponsesClient; +use crate::running_agent::RunningAgent; +use crate::running_balancer::RunningBalancer; +use crate::wait_until_healthy::wait_until_healthy; + +pub struct Cluster { + pub agent_ids: Vec, + pub agents: Vec, + pub agents_watcher: AgentsStreamWatcher, + pub balancer: RunningBalancer, + pub buffered_requests_watcher: BufferedRequestsStreamWatcher, + pub paddler_client: PaddlerClient, + agent_spawner: Box, + inference_client: InferenceHttpClient, + openai_client: OpenAIChatCompletionsClient, + openai_responses_client: OpenAIResponsesClient, +} + +impl Cluster { + pub async fn connect( + balancer: RunningBalancer, + agent_spawner: Box, + desired_state: Option<&BalancerDesiredState>, + ) -> Result { + let management_base_url = balancer.addresses.management_base_url()?; + let inference_base_url = balancer.addresses.inference_base_url()?; + let openai_base_url = balancer.addresses.compat_openai_base_url()?; + + wait_until_healthy(&management_base_url, "health") + .await + .context("balancer did not become healthy")?; + + let paddler_client = PaddlerClient::new(inference_base_url.clone(), management_base_url, 1); + + if let Some(desired_state) = desired_state { + paddler_client + .management() + .put_balancer_desired_state(desired_state) + .await + .map_err(anyhow::Error::new) + .context("failed to PUT balancer desired state")?; + } + + let agents_watcher = AgentsStreamWatcher::connect(&paddler_client.management()).await?; + let buffered_requests_watcher = + BufferedRequestsStreamWatcher::connect(&paddler_client.management()).await?; + + let inference_client = InferenceHttpClient::new(Client::new(), inference_base_url); + let openai_client = OpenAIChatCompletionsClient::new(&openai_base_url)?; + let openai_responses_client = OpenAIResponsesClient::new(&openai_base_url)?; + + Ok(Self { + agent_ids: Vec::new(), + agents: Vec::new(), + agents_watcher, + balancer, + buffered_requests_watcher, + paddler_client, + agent_spawner, + inference_client, + openai_client, + openai_responses_client, + }) + } + + pub fn continue_from_raw_prompt( + &self, + params: &ContinueFromRawPromptParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + collect_generated_tokens( + inference_client + .post_continue_from_raw_prompt(¶ms) + .await?, + ) + .await + } + } + + pub fn continue_from_raw_prompt_stream( + &self, + params: &ContinueFromRawPromptParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + inference_client + .post_continue_from_raw_prompt(¶ms) + .await + } + } + + pub fn continue_from_conversation_history( + &self, + params: &ContinueFromConversationHistoryParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + collect_generated_tokens( + inference_client + .post_continue_from_conversation_history(¶ms) + .await?, + ) + .await + } + } + + pub fn continue_from_conversation_history_stream( + &self, + params: &ContinueFromConversationHistoryParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + inference_client + .post_continue_from_conversation_history(¶ms) + .await + } + } + + pub fn generate_embedding_batch( + &self, + params: &GenerateEmbeddingBatchParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + collect_embedding_results( + inference_client + .post_generate_embedding_batch(¶ms) + .await?, + ) + .await + } + } + + pub fn generate_embedding_batch_stream( + &self, + params: &GenerateEmbeddingBatchParams, + ) -> impl Future> + Send + use<> { + let inference_client = self.inference_client.clone(); + let params = params.clone(); + + async move { + inference_client + .post_generate_embedding_batch(¶ms) + .await + } + } + + pub fn openai_chat_completion_streaming( + &self, + body: &Value, + ) -> impl Future>> + Send + use<> { + let openai_client = self.openai_client.clone(); + let body = body.clone(); + + async move { openai_client.post_streaming(&body).await } + } + + pub fn openai_chat_completion_non_streaming( + &self, + body: &Value, + ) -> impl Future> + Send + use<> { + let openai_client = self.openai_client.clone(); + let body = body.clone(); + + async move { openai_client.post_non_streaming(&body).await } + } + + pub fn openai_responses_streaming( + &self, + body: &Value, + ) -> impl Future>> + Send + use<> { + let openai_responses_client = self.openai_responses_client.clone(); + let body = body.clone(); + + async move { openai_responses_client.post_streaming(&body).await } + } + + pub fn openai_responses_non_streaming( + &self, + body: &Value, + ) -> impl Future> + Send + use<> { + let openai_responses_client = self.openai_responses_client.clone(); + let body = body.clone(); + + async move { openai_responses_client.post_non_streaming(&body).await } + } + + pub async fn wait_for_agent_count( + &mut self, + expected_count: usize, + ) -> Result { + self.agents_watcher + .until(assert_agent_count(expected_count)) + .await + } + + pub async fn wait_for_agent_ready( + &mut self, + agent_name: &str, + expected_slot_count: i32, + ) -> Result { + self.agents_watcher + .wait_for_agent_ready(agent_name, expected_slot_count) + .await + } + + pub async fn wait_for_agents_ready(&mut self, expected_slot_counts: &[i32]) -> Result<()> { + self.agents_watcher + .wait_for_slots_ready(expected_slot_counts) + .await + } + + pub async fn wait_for_slots_processing( + &mut self, + agent_id: &str, + expected_slots_processing: i32, + ) -> Result { + self.agents_watcher + .until(assert_slots_processing(agent_id, expected_slots_processing)) + .await + } + + pub async fn wait_for_slots_total_at_least( + &mut self, + agent_id: &str, + expected_slots_total: i32, + ) -> Result { + self.agents_watcher + .until(assert_slots_total_at_least(agent_id, expected_slots_total)) + .await + } + + pub async fn wait_for_buffered_request_count( + &mut self, + expected_count: i32, + ) -> Result { + self.buffered_requests_watcher + .until(assert_count(expected_count)) + .await + } + + pub fn spawn_additional_agent(&mut self, config: &AgentConfig) -> Result<()> { + let process = self.agent_spawner.spawn(config)?; + + self.agents.push(RunningAgent::new(config.clone(), process)); + + Ok(()) + } + + pub async fn shutdown(self) -> Result<()> { + for agent in self.agents { + agent.shutdown().await?; + } + + self.balancer.shutdown().await + } +} diff --git a/paddler_tests/src/subprocess_cluster_params.rs b/paddler_test_cluster_harness/src/cluster_params.rs similarity index 81% rename from paddler_tests/src/subprocess_cluster_params.rs rename to paddler_test_cluster_harness/src/cluster_params.rs index 03fa0739..72483a0f 100644 --- a/paddler_tests/src/subprocess_cluster_params.rs +++ b/paddler_test_cluster_harness/src/cluster_params.rs @@ -1,10 +1,10 @@ use std::time::Duration; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; use crate::agent_config::AgentConfig; -pub struct SubprocessClusterParams { +pub struct ClusterParams { pub agents: Vec, pub buffered_request_timeout: Duration, pub desired_state: Option, @@ -16,14 +16,14 @@ pub struct SubprocessClusterParams { pub wait_for_slots_ready: bool, } -impl Default for SubprocessClusterParams { +impl Default for ClusterParams { fn default() -> Self { Self { agents: AgentConfig::uniform(1, 4), buffered_request_timeout: Duration::from_secs(10), desired_state: Some(BalancerDesiredState::default()), inference_cors_allowed_hosts: Vec::new(), - inference_item_timeout: Duration::from_secs(60), + inference_item_timeout: Duration::from_secs(30), management_cors_allowed_hosts: Vec::new(), max_buffered_requests: 10, state_database_url: "memory://".to_owned(), diff --git a/paddler_test_cluster_harness/src/collect_embedding_results.rs b/paddler_test_cluster_harness/src/collect_embedding_results.rs new file mode 100644 index 00000000..91e2038a --- /dev/null +++ b/paddler_test_cluster_harness/src/collect_embedding_results.rs @@ -0,0 +1,283 @@ +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use futures_util::StreamExt as _; +use paddler_messaging::embedding_result::EmbeddingResult; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_client::response::Response as InferenceResponse; + +use crate::collected_embedding_results::CollectedEmbeddingResults; +use crate::embedding_with_producer::EmbeddingWithProducer; +use crate::inference_message_stream::InferenceMessageStream; + +pub async fn collect_embedding_results( + mut stream: InferenceMessageStream, +) -> Result { + let mut embeddings: Vec = Vec::new(); + let mut embeddings_disabled = false; + let mut errors: Vec = Vec::new(); + let mut embedding_rejected_due_to_active_token_generation_count: usize = 0; + let mut no_embeddings_produced_count: usize = 0; + let mut oversized_documents = Vec::new(); + let mut saw_done = false; + let mut wire_errors = Vec::new(); + + while let Some(item) = stream.next().await { + let message = item.context("embedding stream yielded an error")?; + + match message { + InferenceMessage::Response(envelope) => { + let generated_by = envelope.generated_by.clone(); + + match envelope.response { + InferenceResponse::Embedding(EmbeddingResult::Done) => { + saw_done = true; + + break; + } + InferenceResponse::Embedding(EmbeddingResult::Embedding(embedding)) => { + embeddings.push(EmbeddingWithProducer { + embedding, + generated_by, + }); + } + InferenceResponse::Embedding(EmbeddingResult::DocumentExceedsBatchSize( + details, + )) => { + oversized_documents.push(details); + } + InferenceResponse::Embedding(EmbeddingResult::EmbeddingsDisabled) => { + embeddings_disabled = true; + } + InferenceResponse::Embedding(EmbeddingResult::Error(message)) => { + errors.push(message); + } + InferenceResponse::Embedding( + EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration, + ) => { + embedding_rejected_due_to_active_token_generation_count += 1; + } + InferenceResponse::Embedding(EmbeddingResult::NoEmbeddingsProduced) => { + no_embeddings_produced_count += 1; + } + InferenceResponse::GeneratedToken(_) => { + return Err(anyhow!( + "unexpected generated-token response on an embedding stream" + )); + } + InferenceResponse::Timeout => { + return Err(anyhow!("embedding request timed out on balancer")); + } + InferenceResponse::TooManyBufferedRequests => { + return Err(anyhow!( + "balancer rejected embedding request: too many buffered" + )); + } + } + } + InferenceMessage::Error(error_envelope) => { + wire_errors.push(error_envelope.error); + } + } + } + + Ok(CollectedEmbeddingResults { + embeddings, + embeddings_disabled, + errors, + embedding_rejected_due_to_active_token_generation_count, + no_embeddings_produced_count, + oversized_documents, + saw_done, + wire_errors, + }) +} + +#[cfg(test)] +mod tests { + use paddler_messaging::embedding::Embedding; + use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::inference_client::message::Message as InferenceMessage; + use paddler_messaging::inference_client::response::Response as InferenceResponse; + use paddler_messaging::jsonrpc::error::Error; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + use paddler_messaging::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; + use paddler_messaging::pooling_type::PoolingType; + + use super::EmbeddingResult; + use super::collect_embedding_results; + use crate::inference_message_stream::InferenceMessageStream; + + fn stream(items: Vec>) -> InferenceMessageStream { + Box::pin(futures_util::stream::iter(items)) + } + + fn embedding_message(result: EmbeddingResult) -> InferenceMessage { + InferenceMessage::Response(ResponseEnvelope { + generated_by: Some("agent-1".to_owned()), + request_id: "req".to_owned(), + response: InferenceResponse::Embedding(result), + }) + } + + fn sample_embedding() -> Embedding { + Embedding { + embedding: vec![0.1, 0.2], + normalization_method: EmbeddingNormalizationMethod::None, + pooling_type: PoolingType::Last, + source_document_id: "doc".to_owned(), + } + } + + #[tokio::test] + async fn collects_embeddings_until_done() { + let collected = collect_embedding_results(stream(vec![ + Ok(embedding_message(EmbeddingResult::Embedding( + sample_embedding(), + ))), + Ok(embedding_message(EmbeddingResult::Embedding( + sample_embedding(), + ))), + Ok(embedding_message(EmbeddingResult::Done)), + ])) + .await + .unwrap(); + + assert_eq!(collected.embeddings.len(), 2); + assert!(collected.saw_done); + assert_eq!( + collected.embeddings[0].generated_by.as_deref(), + Some("agent-1") + ); + } + + #[tokio::test] + async fn records_oversized_documents() { + let collected = collect_embedding_results(stream(vec![Ok(embedding_message( + EmbeddingResult::DocumentExceedsBatchSize(OversizedEmbeddingDocumentDetails { + document_tokens: 5000, + n_batch: 512, + source_document_id: "big".to_owned(), + }), + ))])) + .await + .unwrap(); + + assert_eq!(collected.oversized_documents.len(), 1); + assert_eq!(collected.oversized_documents[0].document_tokens, 5000); + } + + #[tokio::test] + async fn records_embeddings_disabled() { + let collected = collect_embedding_results(stream(vec![Ok(embedding_message( + EmbeddingResult::EmbeddingsDisabled, + ))])) + .await + .unwrap(); + + assert!(collected.embeddings_disabled); + } + + #[tokio::test] + async fn records_errors_and_rejections() { + let collected = collect_embedding_results(stream(vec![ + Ok(embedding_message(EmbeddingResult::Error("boom".to_owned()))), + Ok(embedding_message( + EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration, + )), + Ok(embedding_message(EmbeddingResult::NoEmbeddingsProduced)), + ])) + .await + .unwrap(); + + assert_eq!(collected.errors, vec!["boom".to_owned()]); + assert_eq!( + collected.embedding_rejected_due_to_active_token_generation_count, + 1 + ); + assert_eq!(collected.no_embeddings_produced_count, 1); + } + + #[tokio::test] + async fn rejects_a_generated_token_response() { + let error = collect_embedding_results(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::GeneratedToken(GeneratedTokenResult::ContentToken( + "x".to_owned(), + )), + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("unexpected generated-token")); + } + + #[tokio::test] + async fn rejects_a_timeout() { + let error = collect_embedding_results(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::Timeout, + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("timed out")); + } + + #[tokio::test] + async fn rejects_too_many_buffered_requests() { + let error = collect_embedding_results(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::TooManyBufferedRequests, + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("too many buffered")); + } + + #[tokio::test] + async fn records_wire_errors() { + let collected = + collect_embedding_results(stream(vec![Ok(InferenceMessage::Error(ErrorEnvelope { + request_id: "req".to_owned(), + error: Error { + code: -32000, + description: "wire failure".to_owned(), + }, + }))])) + .await + .unwrap(); + + assert_eq!(collected.wire_errors.len(), 1); + assert_eq!(collected.wire_errors[0].description, "wire failure"); + } + + #[tokio::test] + async fn propagates_a_stream_error() { + let error = collect_embedding_results(stream(vec![Err(anyhow::anyhow!("socket closed"))])) + .await + .err() + .unwrap(); + + assert!( + error + .to_string() + .contains("embedding stream yielded an error") + ); + } +} diff --git a/paddler_test_cluster_harness/src/collect_generated_tokens.rs b/paddler_test_cluster_harness/src/collect_generated_tokens.rs new file mode 100644 index 00000000..c921c092 --- /dev/null +++ b/paddler_test_cluster_harness/src/collect_generated_tokens.rs @@ -0,0 +1,207 @@ +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use futures_util::StreamExt as _; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_client::response::Response as InferenceResponse; +use paddler_messaging::streamable_result::StreamableResult as _; + +use crate::collected_generated_tokens::CollectedGeneratedTokens; +use crate::inference_message_stream::InferenceMessageStream; +use crate::token_result_with_producer::TokenResultWithProducer; + +pub async fn collect_generated_tokens( + mut stream: InferenceMessageStream, +) -> Result { + let mut text = String::new(); + let mut token_results: Vec = Vec::new(); + + while let Some(item) = stream.next().await { + let message = item.context("inference stream yielded an error")?; + + match message { + InferenceMessage::Response(envelope) => { + let generated_by = envelope.generated_by.clone(); + + match envelope.response { + InferenceResponse::GeneratedToken(token_result) => { + if let Some(token_text) = token_result.token_text() { + text.push_str(token_text); + } + + let is_done = token_result.is_done(); + + token_results.push(TokenResultWithProducer { + token_result, + generated_by, + }); + + if is_done { + break; + } + } + InferenceResponse::Embedding(_) => { + return Err(anyhow!( + "unexpected embedding response on a token-generation stream" + )); + } + InferenceResponse::Timeout => { + return Err(anyhow!("inference request timed out on balancer")); + } + InferenceResponse::TooManyBufferedRequests => { + return Err(anyhow!("balancer rejected request: too many buffered")); + } + } + } + InferenceMessage::Error(error_envelope) => { + return Err(anyhow!( + "inference stream returned JSON-RPC error code {} ({})", + error_envelope.error.code, + error_envelope.error.description + )); + } + } + } + + Ok(CollectedGeneratedTokens { + text, + token_results, + }) +} + +#[cfg(test)] +mod tests { + use paddler_messaging::embedding_result::EmbeddingResult; + use paddler_messaging::generated_token_result::GeneratedTokenResult; + use paddler_messaging::inference_client::message::Message as InferenceMessage; + use paddler_messaging::inference_client::response::Response as InferenceResponse; + use paddler_messaging::jsonrpc::error::Error; + use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; + use paddler_messaging::jsonrpc::response_envelope::ResponseEnvelope; + + use super::collect_generated_tokens; + use crate::inference_message_stream::InferenceMessageStream; + + fn stream(items: Vec>) -> InferenceMessageStream { + Box::pin(futures_util::stream::iter(items)) + } + + fn token(result: GeneratedTokenResult) -> InferenceMessage { + InferenceMessage::Response(ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::GeneratedToken(result), + }) + } + + #[tokio::test] + async fn accumulates_content_token_text() { + let collected = collect_generated_tokens(stream(vec![ + Ok(token(GeneratedTokenResult::ContentToken("hel".to_owned()))), + Ok(token(GeneratedTokenResult::ContentToken("lo".to_owned()))), + ])) + .await + .unwrap(); + + assert_eq!(collected.text, "hello"); + assert_eq!(collected.token_results.len(), 2); + } + + #[tokio::test] + async fn stops_after_a_terminal_token() { + let collected = collect_generated_tokens(stream(vec![ + Ok(token(GeneratedTokenResult::ContentToken("hi".to_owned()))), + Ok(token(GeneratedTokenResult::ImageDecodingFailed( + "dead".to_owned(), + ))), + Ok(token(GeneratedTokenResult::ContentToken( + "IGNORED".to_owned(), + ))), + ])) + .await + .unwrap(); + + assert_eq!(collected.token_results.len(), 2); + assert!(collected.text.starts_with("hi")); + assert!(!collected.text.contains("IGNORED")); + } + + #[tokio::test] + async fn rejects_an_embedding_response() { + let error = collect_generated_tokens(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::Embedding(EmbeddingResult::Done), + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("unexpected embedding response")); + } + + #[tokio::test] + async fn rejects_a_timeout() { + let error = collect_generated_tokens(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::Timeout, + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("timed out")); + } + + #[tokio::test] + async fn rejects_too_many_buffered_requests() { + let error = collect_generated_tokens(stream(vec![Ok(InferenceMessage::Response( + ResponseEnvelope { + generated_by: None, + request_id: "req".to_owned(), + response: InferenceResponse::TooManyBufferedRequests, + }, + ))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("too many buffered")); + } + + #[tokio::test] + async fn propagates_a_wire_error() { + let error = + collect_generated_tokens(stream(vec![Ok(InferenceMessage::Error(ErrorEnvelope { + request_id: "req".to_owned(), + error: Error { + code: -32001, + description: "rpc failure".to_owned(), + }, + }))])) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("JSON-RPC error code -32001")); + } + + #[tokio::test] + async fn propagates_a_stream_error() { + let error = collect_generated_tokens(stream(vec![Err(anyhow::anyhow!("socket closed"))])) + .await + .err() + .unwrap(); + + assert!( + error + .to_string() + .contains("inference stream yielded an error") + ); + } +} diff --git a/paddler_tests/src/collected_embedding_results.rs b/paddler_test_cluster_harness/src/collected_embedding_results.rs similarity index 74% rename from paddler_tests/src/collected_embedding_results.rs rename to paddler_test_cluster_harness/src/collected_embedding_results.rs index 4d785d94..10e2490d 100644 --- a/paddler_tests/src/collected_embedding_results.rs +++ b/paddler_test_cluster_harness/src/collected_embedding_results.rs @@ -1,5 +1,5 @@ -use paddler_types::jsonrpc::Error as JsonRpcError; -use paddler_types::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; +use paddler_messaging::jsonrpc::error::Error as JsonRpcError; +use paddler_messaging::oversized_embedding_document_details::OversizedEmbeddingDocumentDetails; use crate::embedding_with_producer::EmbeddingWithProducer; diff --git a/paddler_tests/src/collected_generated_tokens.rs b/paddler_test_cluster_harness/src/collected_generated_tokens.rs similarity index 100% rename from paddler_tests/src/collected_generated_tokens.rs rename to paddler_test_cluster_harness/src/collected_generated_tokens.rs diff --git a/paddler_tests/src/embedding_with_producer.rs b/paddler_test_cluster_harness/src/embedding_with_producer.rs similarity index 73% rename from paddler_tests/src/embedding_with_producer.rs rename to paddler_test_cluster_harness/src/embedding_with_producer.rs index cc3bf3d6..b9824481 100644 --- a/paddler_tests/src/embedding_with_producer.rs +++ b/paddler_test_cluster_harness/src/embedding_with_producer.rs @@ -1,4 +1,4 @@ -use paddler_types::embedding::Embedding; +use paddler_messaging::embedding::Embedding; #[derive(Debug)] pub struct EmbeddingWithProducer { diff --git a/paddler_test_cluster_harness/src/inference_http_client.rs b/paddler_test_cluster_harness/src/inference_http_client.rs new file mode 100644 index 00000000..09e5866c --- /dev/null +++ b/paddler_test_cluster_harness/src/inference_http_client.rs @@ -0,0 +1,117 @@ +use anyhow::Context as _; +use anyhow::Result; +use futures_util::StreamExt as _; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use reqwest::Client; +use url::Url; + +use crate::inference_message_stream::InferenceMessageStream; +use crate::ndjson_lines_from_response::ndjson_lines_from_response; + +#[derive(Clone)] +pub struct InferenceHttpClient { + http_client: Client, + inference_base_url: Url, +} + +impl InferenceHttpClient { + #[must_use] + pub const fn new(http_client: Client, inference_base_url: Url) -> Self { + Self { + http_client, + inference_base_url, + } + } + + pub async fn post_continue_from_raw_prompt( + &self, + params: &ContinueFromRawPromptParams, + ) -> Result { + self.post_streaming("api/v1/continue_from_raw_prompt", params) + .await + } + + pub async fn post_continue_from_conversation_history( + &self, + params: &ContinueFromConversationHistoryParams, + ) -> Result { + self.post_streaming("api/v1/continue_from_conversation_history", params) + .await + } + + pub async fn post_generate_embedding_batch( + &self, + params: &GenerateEmbeddingBatchParams, + ) -> Result { + self.post_streaming("api/v1/generate_embedding_batch", params) + .await + } + + async fn post_streaming( + &self, + relative_path: &str, + body: &TBody, + ) -> Result + where + TBody: serde::Serialize + Sync + ?Sized, + { + let request_url = self + .inference_base_url + .join(relative_path) + .with_context(|| format!("failed to build URL for {relative_path}"))?; + + let response = self + .http_client + .post(request_url) + .json(body) + .send() + .await + .with_context(|| format!("failed to POST {relative_path}"))? + .error_for_status() + .with_context(|| format!("non-success status on {relative_path}"))?; + + Ok(Box::pin(ndjson_lines_from_response(response).map( + |line_result| { + let line = line_result?; + + serde_json::from_str::(&line) + .with_context(|| format!("failed to parse NDJSON line: {line}")) + }, + ))) + } +} + +#[cfg(test)] +mod tests { + use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; + use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; + use reqwest::Client; + use url::Url; + + use super::InferenceHttpClient; + + fn empty_embedding_batch() -> GenerateEmbeddingBatchParams { + GenerateEmbeddingBatchParams { + input_batch: Vec::new(), + normalization_method: EmbeddingNormalizationMethod::None, + } + } + + #[tokio::test] + async fn errors_when_the_request_url_cannot_be_built() { + let base_url = Url::parse("data:text/plain,paddler").unwrap(); + let client = InferenceHttpClient::new(Client::new(), base_url); + + let error = client + .post_generate_embedding_batch(&empty_embedding_batch()) + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("failed to build URL")); + } +} diff --git a/paddler_tests/src/inference_message_stream.rs b/paddler_test_cluster_harness/src/inference_message_stream.rs similarity index 67% rename from paddler_tests/src/inference_message_stream.rs rename to paddler_test_cluster_harness/src/inference_message_stream.rs index ecae880d..3906f535 100644 --- a/paddler_tests/src/inference_message_stream.rs +++ b/paddler_test_cluster_harness/src/inference_message_stream.rs @@ -2,6 +2,6 @@ use std::pin::Pin; use anyhow::Result; use futures_util::Stream; -use paddler_types::inference_client::Message as InferenceMessage; +use paddler_messaging::inference_client::message::Message as InferenceMessage; pub type InferenceMessageStream = Pin> + Send>>; diff --git a/paddler_test_cluster_harness/src/lib.rs b/paddler_test_cluster_harness/src/lib.rs new file mode 100644 index 00000000..1b5ccac0 --- /dev/null +++ b/paddler_test_cluster_harness/src/lib.rs @@ -0,0 +1,32 @@ +pub mod agent_config; +pub mod agent_spawner; +pub mod agents_stream_watcher; +pub mod balancer_addresses; +pub mod buffered_requests_stream_watcher; +pub mod cluster; +pub mod cluster_params; +pub mod collect_embedding_results; +pub mod collect_generated_tokens; +pub mod collected_embedding_results; +pub mod collected_generated_tokens; +pub mod embedding_with_producer; +pub mod inference_message_stream; +pub mod load_test_image_data_uri; +pub mod managed_process; +#[cfg(any(target_os = "macos", target_os = "linux"))] +pub mod resource_snapshot; +#[cfg(any(target_os = "macos", target_os = "linux"))] +pub mod resource_snapshot_diff; +pub mod running_agent; +pub mod running_balancer; +pub mod state_database_file; +pub mod token_result_with_producer; + +mod agents_status; +mod buffered_requests_status; +mod inference_http_client; +mod ndjson_lines_from_response; +mod openai_chat_completions_client; +mod openai_config_from_base_url; +mod openai_responses_client; +mod wait_until_healthy; diff --git a/paddler_tests/src/load_test_image_data_uri.rs b/paddler_test_cluster_harness/src/load_test_image_data_uri.rs similarity index 61% rename from paddler_tests/src/load_test_image_data_uri.rs rename to paddler_test_cluster_harness/src/load_test_image_data_uri.rs index d2aec7cd..22e94a1c 100644 --- a/paddler_tests/src/load_test_image_data_uri.rs +++ b/paddler_test_cluster_harness/src/load_test_image_data_uri.rs @@ -14,3 +14,16 @@ pub fn load_test_image_data_uri() -> Result { Ok(format!("data:image/jpeg;base64,{encoded}")) } + +#[cfg(test)] +mod tests { + use super::load_test_image_data_uri; + + #[test] + fn encodes_the_fixture_as_a_jpeg_data_uri() { + let data_uri = load_test_image_data_uri().unwrap(); + + assert!(data_uri.starts_with("data:image/jpeg;base64,")); + assert!(data_uri.len() > "data:image/jpeg;base64,".len()); + } +} diff --git a/paddler_test_cluster_harness/src/managed_process.rs b/paddler_test_cluster_harness/src/managed_process.rs new file mode 100644 index 00000000..0d844d66 --- /dev/null +++ b/paddler_test_cluster_harness/src/managed_process.rs @@ -0,0 +1,7 @@ +use anyhow::Result; +use async_trait::async_trait; + +#[async_trait] +pub trait ManagedProcess: Send { + async fn shutdown(&mut self) -> Result<()>; +} diff --git a/paddler_test_cluster_harness/src/ndjson_lines_from_response.rs b/paddler_test_cluster_harness/src/ndjson_lines_from_response.rs new file mode 100644 index 00000000..3017bbb7 --- /dev/null +++ b/paddler_test_cluster_harness/src/ndjson_lines_from_response.rs @@ -0,0 +1,130 @@ +use anyhow::Context as _; +use anyhow::Result; +use async_stream::try_stream; +use futures_util::Stream; +use futures_util::StreamExt as _; + +pub fn ndjson_lines_from_response( + response: reqwest::Response, +) -> impl Stream> + Send { + try_stream! { + let mut bytes_stream = response.bytes_stream(); + let mut buffer: Vec = Vec::new(); + + while let Some(chunk_result) = bytes_stream.next().await { + let chunk = chunk_result.context("failed to read response chunk")?; + + buffer.extend_from_slice(&chunk); + + while let Some(newline_position) = buffer.iter().position(|byte| *byte == b'\n') { + let line_bytes: Vec = buffer.drain(..=newline_position).collect(); + let line_text = std::str::from_utf8(&line_bytes[..newline_position]) + .context("response stream produced non-UTF8 bytes")? + .trim(); + + if line_text.is_empty() { + continue; + } + + yield line_text.to_owned(); + } + } + + let trailing_text = std::str::from_utf8(&buffer) + .context("response stream produced trailing non-UTF8 bytes")? + .trim(); + + if !trailing_text.is_empty() { + yield trailing_text.to_owned(); + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Error as IoError; + + use futures_util::StreamExt as _; + + use super::ndjson_lines_from_response; + + fn response_from_chunks(chunks: Vec, IoError>>) -> reqwest::Response { + let byte_stream = futures_util::stream::iter(chunks); + let body = reqwest::Body::wrap_stream(byte_stream); + + reqwest::Response::from(http::Response::new(body)) + } + + async fn collect_lines(response: reqwest::Response) -> Vec> { + Box::pin(ndjson_lines_from_response(response)) + .collect() + .await + } + + #[tokio::test] + async fn splits_multiple_lines_on_newlines() { + let lines = collect_lines(response_from_chunks(vec![Ok(b"alpha\nbeta\n".to_vec())])).await; + + let collected: Vec = lines.into_iter().map(anyhow::Result::unwrap).collect(); + + assert_eq!(collected, vec!["alpha".to_owned(), "beta".to_owned()]); + } + + #[tokio::test] + async fn yields_a_trailing_line_without_a_terminating_newline() { + let lines = collect_lines(response_from_chunks(vec![Ok(b"alpha\nbeta".to_vec())])).await; + + let collected: Vec = lines.into_iter().map(anyhow::Result::unwrap).collect(); + + assert_eq!(collected, vec!["alpha".to_owned(), "beta".to_owned()]); + } + + #[tokio::test] + async fn skips_empty_and_whitespace_only_lines() { + let lines = collect_lines(response_from_chunks(vec![Ok(b"\n \nalpha\n".to_vec())])).await; + + let collected: Vec = lines.into_iter().map(anyhow::Result::unwrap).collect(); + + assert_eq!(collected, vec!["alpha".to_owned()]); + } + + #[tokio::test] + async fn buffers_a_line_split_across_chunks() { + let lines = collect_lines(response_from_chunks(vec![ + Ok(b"al".to_vec()), + Ok(b"pha\n".to_vec()), + ])) + .await; + + let collected: Vec = lines.into_iter().map(anyhow::Result::unwrap).collect(); + + assert_eq!(collected, vec!["alpha".to_owned()]); + } + + #[tokio::test] + async fn errors_on_a_non_utf8_line() { + let lines = collect_lines(response_from_chunks(vec![Ok(vec![0xff, 0xfe, b'\n'])])).await; + + let error = lines.into_iter().next().unwrap().err().unwrap(); + + assert!(error.to_string().contains("non-UTF8")); + } + + #[tokio::test] + async fn errors_on_non_utf8_trailing_bytes() { + let lines = collect_lines(response_from_chunks(vec![Ok(vec![0xff, 0xfe])])).await; + + let error = lines.into_iter().next().unwrap().err().unwrap(); + + assert!(error.to_string().contains("trailing non-UTF8")); + } + + #[tokio::test] + async fn errors_when_a_chunk_fails_to_read() { + let lines = collect_lines(response_from_chunks(vec![Err(IoError::other("boom"))])).await; + + let error = lines.into_iter().next().unwrap().err().unwrap(); + + assert!(error.to_string().contains("failed to read response chunk")); + } +} diff --git a/paddler_test_cluster_harness/src/openai_chat_completions_client.rs b/paddler_test_cluster_harness/src/openai_chat_completions_client.rs new file mode 100644 index 00000000..cccf769e --- /dev/null +++ b/paddler_test_cluster_harness/src/openai_chat_completions_client.rs @@ -0,0 +1,53 @@ +use anyhow::Context as _; +use anyhow::Result; +use async_openai::Client; +use async_openai::config::OpenAIConfig; +use futures_util::StreamExt as _; +use serde_json::Value; +use url::Url; + +use crate::openai_config_from_base_url::openai_config_from_base_url; + +#[derive(Clone)] +pub struct OpenAIChatCompletionsClient { + client: Client, +} + +impl OpenAIChatCompletionsClient { + pub fn new(openai_base_url: &Url) -> Result { + Ok(Self { + client: Client::with_config(openai_config_from_base_url(openai_base_url)?), + }) + } + + pub async fn post_streaming(&self, body: &Value) -> Result> { + let mut streaming_body = body.clone(); + + if let Some(object) = streaming_body.as_object_mut() { + object.insert("stream".to_owned(), Value::Bool(true)); + } + + let mut stream = self + .client + .chat() + .create_stream_byot::(streaming_body) + .await + .context("failed to start OpenAI streaming chat completion")?; + + let mut chunks: Vec = Vec::new(); + + while let Some(chunk) = stream.next().await { + chunks.push(chunk.context("OpenAI streaming chat completion chunk failed")?); + } + + Ok(chunks) + } + + pub async fn post_non_streaming(&self, body: &Value) -> Result { + self.client + .chat() + .create_byot::<&Value, Value>(body) + .await + .context("OpenAI non-streaming chat completion failed") + } +} diff --git a/paddler_test_cluster_harness/src/openai_config_from_base_url.rs b/paddler_test_cluster_harness/src/openai_config_from_base_url.rs new file mode 100644 index 00000000..f1986271 --- /dev/null +++ b/paddler_test_cluster_harness/src/openai_config_from_base_url.rs @@ -0,0 +1,41 @@ +use anyhow::Context as _; +use anyhow::Result; +use async_openai::config::OpenAIConfig; +use url::Url; + +pub fn openai_config_from_base_url(openai_base_url: &Url) -> Result { + let api_base = openai_base_url + .join("v1") + .context("failed to build the OpenAI /v1 base URL")?; + + Ok(OpenAIConfig::default() + .with_api_base(api_base.as_str().trim_end_matches('/')) + .with_api_key("paddler")) +} + +#[cfg(test)] +mod tests { + use url::Url; + + use super::openai_config_from_base_url; + + #[test] + fn builds_a_v1_api_base_from_the_root_url() { + let config = + openai_config_from_base_url(&Url::parse("http://127.0.0.1:8062/").unwrap()).unwrap(); + + assert_eq!( + async_openai::config::Config::api_base(&config), + "http://127.0.0.1:8062/v1" + ); + } + + #[test] + fn errors_for_an_unbuildable_base_url() { + let error = openai_config_from_base_url(&Url::parse("data:text/plain,paddler").unwrap()) + .err() + .unwrap(); + + assert!(error.to_string().contains("/v1 base URL")); + } +} diff --git a/paddler_test_cluster_harness/src/openai_responses_client.rs b/paddler_test_cluster_harness/src/openai_responses_client.rs new file mode 100644 index 00000000..a96aa8a7 --- /dev/null +++ b/paddler_test_cluster_harness/src/openai_responses_client.rs @@ -0,0 +1,53 @@ +use anyhow::Context as _; +use anyhow::Result; +use async_openai::Client; +use async_openai::config::OpenAIConfig; +use futures_util::StreamExt as _; +use serde_json::Value; +use url::Url; + +use crate::openai_config_from_base_url::openai_config_from_base_url; + +#[derive(Clone)] +pub struct OpenAIResponsesClient { + client: Client, +} + +impl OpenAIResponsesClient { + pub fn new(openai_base_url: &Url) -> Result { + Ok(Self { + client: Client::with_config(openai_config_from_base_url(openai_base_url)?), + }) + } + + pub async fn post_streaming(&self, body: &Value) -> Result> { + let mut streaming_body = body.clone(); + + if let Some(object) = streaming_body.as_object_mut() { + object.insert("stream".to_owned(), Value::Bool(true)); + } + + let mut stream = self + .client + .responses() + .create_stream_byot::(streaming_body) + .await + .context("failed to start OpenAI streaming response")?; + + let mut events: Vec = Vec::new(); + + while let Some(event) = stream.next().await { + events.push(event.context("OpenAI streaming response event failed")?); + } + + Ok(events) + } + + pub async fn post_non_streaming(&self, body: &Value) -> Result { + self.client + .responses() + .create_byot::<&Value, Value>(body) + .await + .context("OpenAI non-streaming response failed") + } +} diff --git a/paddler_tests/src/resource_snapshot.rs b/paddler_test_cluster_harness/src/resource_snapshot.rs similarity index 60% rename from paddler_tests/src/resource_snapshot.rs rename to paddler_test_cluster_harness/src/resource_snapshot.rs index d9e71732..43bdea65 100644 --- a/paddler_tests/src/resource_snapshot.rs +++ b/paddler_test_cluster_harness/src/resource_snapshot.rs @@ -52,3 +52,39 @@ const fn open_descriptors_directory_path() -> &'static str { const fn open_descriptors_directory_path() -> &'static str { "/proc/self/fd" } + +#[cfg(test)] +mod tests { + use super::ResourceSnapshot; + + #[test] + fn try_from_self_counts_the_processes_open_descriptors() { + let snapshot = ResourceSnapshot::try_from_self().unwrap(); + + assert!(snapshot.open_file_descriptor_count > 0); + } + + #[test] + fn diff_reports_growth() { + let later = ResourceSnapshot { + open_file_descriptor_count: 10, + }; + let earlier = ResourceSnapshot { + open_file_descriptor_count: 3, + }; + + assert_eq!(later.diff(&earlier).open_file_descriptors_grew_by, 7); + } + + #[test] + fn diff_saturates_when_descriptors_shrink() { + let later = ResourceSnapshot { + open_file_descriptor_count: 3, + }; + let earlier = ResourceSnapshot { + open_file_descriptor_count: 10, + }; + + assert_eq!(later.diff(&earlier).open_file_descriptors_grew_by, 0); + } +} diff --git a/paddler_tests/src/resource_snapshot_diff.rs b/paddler_test_cluster_harness/src/resource_snapshot_diff.rs similarity index 100% rename from paddler_tests/src/resource_snapshot_diff.rs rename to paddler_test_cluster_harness/src/resource_snapshot_diff.rs diff --git a/paddler_test_cluster_harness/src/running_agent.rs b/paddler_test_cluster_harness/src/running_agent.rs new file mode 100644 index 00000000..55b469d3 --- /dev/null +++ b/paddler_test_cluster_harness/src/running_agent.rs @@ -0,0 +1,20 @@ +use anyhow::Result; + +use crate::agent_config::AgentConfig; +use crate::managed_process::ManagedProcess; + +pub struct RunningAgent { + pub config: AgentConfig, + process: Box, +} + +impl RunningAgent { + #[must_use] + pub const fn new(config: AgentConfig, process: Box) -> Self { + Self { config, process } + } + + pub async fn shutdown(mut self) -> Result<()> { + self.process.shutdown().await + } +} diff --git a/paddler_test_cluster_harness/src/running_balancer.rs b/paddler_test_cluster_harness/src/running_balancer.rs new file mode 100644 index 00000000..1a25f1de --- /dev/null +++ b/paddler_test_cluster_harness/src/running_balancer.rs @@ -0,0 +1,20 @@ +use anyhow::Result; + +use crate::balancer_addresses::BalancerAddresses; +use crate::managed_process::ManagedProcess; + +pub struct RunningBalancer { + pub addresses: BalancerAddresses, + process: Box, +} + +impl RunningBalancer { + #[must_use] + pub const fn new(addresses: BalancerAddresses, process: Box) -> Self { + Self { addresses, process } + } + + pub async fn shutdown(mut self) -> Result<()> { + self.process.shutdown().await + } +} diff --git a/paddler_tests/src/state_database_file.rs b/paddler_test_cluster_harness/src/state_database_file.rs similarity index 63% rename from paddler_tests/src/state_database_file.rs rename to paddler_test_cluster_harness/src/state_database_file.rs index 0f76eb18..78b418bf 100644 --- a/paddler_tests/src/state_database_file.rs +++ b/paddler_test_cluster_harness/src/state_database_file.rs @@ -19,3 +19,17 @@ impl StateDatabaseFile { Ok(Self { _file: file, url }) } } + +#[cfg(test)] +mod tests { + use super::StateDatabaseFile; + + #[test] + fn new_builds_a_file_url_for_a_real_temp_file() { + let database = StateDatabaseFile::new().unwrap(); + + let path = database.url.strip_prefix("file://").unwrap(); + + assert!(std::path::Path::new(path).exists()); + } +} diff --git a/paddler_tests/src/token_result_with_producer.rs b/paddler_test_cluster_harness/src/token_result_with_producer.rs similarity index 66% rename from paddler_tests/src/token_result_with_producer.rs rename to paddler_test_cluster_harness/src/token_result_with_producer.rs index 6687eb4a..662f086b 100644 --- a/paddler_tests/src/token_result_with_producer.rs +++ b/paddler_test_cluster_harness/src/token_result_with_producer.rs @@ -1,4 +1,4 @@ -use paddler_types::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generated_token_result::GeneratedTokenResult; #[derive(Debug)] pub struct TokenResultWithProducer { diff --git a/paddler_tests/src/wait_until_healthy.rs b/paddler_test_cluster_harness/src/wait_until_healthy.rs similarity index 75% rename from paddler_tests/src/wait_until_healthy.rs rename to paddler_test_cluster_harness/src/wait_until_healthy.rs index 9595cf2f..cb94a67a 100644 --- a/paddler_tests/src/wait_until_healthy.rs +++ b/paddler_test_cluster_harness/src/wait_until_healthy.rs @@ -38,3 +38,22 @@ pub async fn wait_until_healthy(base_url: &Url, endpoint: &str) -> Result<()> { } } } + +#[cfg(test)] +mod tests { + use url::Url; + + use super::wait_until_healthy; + + #[tokio::test] + async fn fails_to_construct_the_probe_url_for_a_malformed_endpoint() { + let base_url = Url::parse("http://127.0.0.1:8080/").unwrap(); + + let error = wait_until_healthy(&base_url, "http://") + .await + .err() + .unwrap(); + + assert!(error.to_string().contains("failed to construct")); + } +} diff --git a/paddler_tests/Cargo.toml b/paddler_tests/Cargo.toml index f1006554..94eb61e5 100644 --- a/paddler_tests/Cargo.toml +++ b/paddler_tests/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "paddler_tests" authors.workspace = true -description = "Stable test harness and integration/model tests for Paddler" +description = "In-process integration tests for Paddler" edition.workspace = true homepage.workspace = true license.workspace = true @@ -10,26 +10,28 @@ version.workspace = true [features] default = [] -cuda = ["paddler/cuda"] -metal = ["paddler/metal"] -tests_that_use_compiled_paddler = [] -tests_that_use_in_process_cluster = [] +cuda = ["paddler_agent/cuda"] +metal = ["paddler_agent/metal"] tests_that_use_llms = [] -web_admin_panel = ["paddler/web_admin_panel", "paddler_bootstrap/web_admin_panel"] +web_admin_panel = ["paddler_balancer/web_admin_panel", "paddler_bootstrap/web_admin_panel"] [dependencies] anyhow = { workspace = true } async-stream = { workspace = true } +async-trait = { workspace = true } base64 = { workspace = true } futures-util = { workspace = true } hf-hub = { workspace = true } llama-cpp-bindings = { workspace = true } log = { workspace = true } nix = { workspace = true } -paddler = { workspace = true } +paddler_agent = { workspace = true } +paddler_balancer = { workspace = true } paddler_bootstrap = { workspace = true } paddler_client = { workspace = true } -paddler_types = { workspace = true } +paddler_messaging = { workspace = true } +paddler_test_cluster_harness = { workspace = true } +parking_lot = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -37,9 +39,11 @@ tempfile = { workspace = true } tokio = { workspace = true } tokio-tungstenite = { workspace = true } tokio-util = { workspace = true } +trzcina = { workspace = true } url = { workspace = true } [dev-dependencies] +paddler_openai_response_format_validator = { workspace = true } serial_test = { workspace = true } [lints] diff --git a/paddler_tests/src/agents_status/assert_slots_total_at_least.rs b/paddler_tests/src/agents_status/assert_slots_total_at_least.rs deleted file mode 100644 index fba157e5..00000000 --- a/paddler_tests/src/agents_status/assert_slots_total_at_least.rs +++ /dev/null @@ -1,15 +0,0 @@ -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; - -pub fn assert_slots_total_at_least( - agent_id: &str, - expected_slots_total: i32, -) -> impl Fn(&AgentControllerPoolSnapshot) -> bool { - let agent_id = agent_id.to_owned(); - - move |snapshot| { - snapshot - .agents - .iter() - .any(|agent| agent.id == agent_id && agent.slots_total >= expected_slots_total) - } -} diff --git a/paddler_tests/src/buffered_requests_stream_watcher.rs b/paddler_tests/src/buffered_requests_stream_watcher.rs deleted file mode 100644 index c782f3fb..00000000 --- a/paddler_tests/src/buffered_requests_stream_watcher.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::pin::Pin; - -use anyhow::Context as _; -use anyhow::Result; -use anyhow::anyhow; -use futures_util::Stream; -use futures_util::StreamExt as _; -use paddler_client::ClientManagement; -use paddler_types::buffered_request_manager_snapshot::BufferedRequestManagerSnapshot; - -pub struct BufferedRequestsStreamWatcher { - stream: Pin> + Send>>, -} - -impl BufferedRequestsStreamWatcher { - pub async fn connect(management: &ClientManagement<'_>) -> Result { - let raw_stream = management - .get_buffered_requests_stream() - .await - .map_err(anyhow::Error::new) - .context("failed to open /api/v1/buffered_requests/stream")?; - - let stream = raw_stream.map(|item| item.map_err(anyhow::Error::new)); - - Ok(Self { - stream: Box::pin(stream), - }) - } - - #[must_use] - pub fn from_stream( - stream: Pin> + Send>>, - ) -> Self { - Self { stream } - } - - pub async fn until( - &mut self, - mut predicate: TPredicate, - ) -> Result - where - TPredicate: FnMut(&BufferedRequestManagerSnapshot) -> bool, - { - while let Some(item) = self.stream.next().await { - let snapshot = item.context("buffered requests stream yielded an error")?; - - if predicate(&snapshot) { - return Ok(snapshot); - } - } - - Err(anyhow!( - "buffered requests stream closed before predicate was satisfied" - )) - } -} diff --git a/paddler_tests/src/cluster_completion.rs b/paddler_tests/src/cluster_completion.rs deleted file mode 100644 index f46c42e7..00000000 --- a/paddler_tests/src/cluster_completion.rs +++ /dev/null @@ -1,14 +0,0 @@ -use paddler_bootstrap::agent_runner::AgentRunner; -use paddler_bootstrap::balancer_runner::BalancerRunner; -use tokio::process::Child; - -pub enum ClusterCompletion { - InProcess { - agents: Vec, - balancer: BalancerRunner, - }, - Subprocess { - agents: Vec, - balancer: Child, - }, -} diff --git a/paddler_tests/src/cluster_handle.rs b/paddler_tests/src/cluster_handle.rs deleted file mode 100644 index 7374fadb..00000000 --- a/paddler_tests/src/cluster_handle.rs +++ /dev/null @@ -1,92 +0,0 @@ -use anyhow::Result; -use log::warn; -use paddler_client::PaddlerClient; -use tokio_util::sync::CancellationToken; - -use crate::agents_stream_watcher::AgentsStreamWatcher; -use crate::balancer_addresses::BalancerAddresses; -use crate::buffered_requests_stream_watcher::BufferedRequestsStreamWatcher; -use crate::cluster_completion::ClusterCompletion; -use crate::cluster_handle_params::ClusterHandleParams; -use crate::terminate_child::terminate_child; - -pub struct ClusterHandle { - pub addresses: BalancerAddresses, - pub agent_ids: Vec, - pub agents: AgentsStreamWatcher, - pub buffered_requests: BufferedRequestsStreamWatcher, - pub paddler_client: PaddlerClient, - pub cancel_token: CancellationToken, - completion: ClusterCompletion, -} - -impl ClusterHandle { - #[must_use] - pub fn new( - ClusterHandleParams { - addresses, - agent_ids, - agents, - buffered_requests, - cancel_token, - completion, - paddler_client, - }: ClusterHandleParams, - ) -> Self { - Self { - addresses, - agent_ids, - agents, - buffered_requests, - paddler_client, - cancel_token, - completion, - } - } - - pub async fn shutdown(mut self) -> Result<()> { - self.cancel_token.cancel(); - - match &mut self.completion { - ClusterCompletion::InProcess { agents, balancer } => { - for agent_runner in agents.iter_mut() { - agent_runner.wait_for_completion().await?; - } - - balancer.wait_for_completion().await?; - } - ClusterCompletion::Subprocess { agents, balancer } => { - for child in agents.iter_mut() { - terminate_child(child)?; - } - - terminate_child(balancer)?; - - for agent in agents.iter_mut() { - agent.wait().await?; - } - - balancer.wait().await?; - } - } - - Ok(()) - } -} - -impl Drop for ClusterHandle { - fn drop(&mut self) { - self.cancel_token.cancel(); - - if let ClusterCompletion::Subprocess { agents, balancer } = &mut self.completion { - for child in agents.iter_mut() { - if let Err(error) = terminate_child(child) { - warn!("ClusterHandle drop: failed to terminate agent subprocess: {error:#}"); - } - } - if let Err(error) = terminate_child(balancer) { - warn!("ClusterHandle drop: failed to terminate balancer subprocess: {error:#}"); - } - } - } -} diff --git a/paddler_tests/src/cluster_handle_params.rs b/paddler_tests/src/cluster_handle_params.rs deleted file mode 100644 index d0092cab..00000000 --- a/paddler_tests/src/cluster_handle_params.rs +++ /dev/null @@ -1,17 +0,0 @@ -use paddler_client::PaddlerClient; -use tokio_util::sync::CancellationToken; - -use crate::agents_stream_watcher::AgentsStreamWatcher; -use crate::balancer_addresses::BalancerAddresses; -use crate::buffered_requests_stream_watcher::BufferedRequestsStreamWatcher; -use crate::cluster_completion::ClusterCompletion; - -pub struct ClusterHandleParams { - pub addresses: BalancerAddresses, - pub agent_ids: Vec, - pub agents: AgentsStreamWatcher, - pub buffered_requests: BufferedRequestsStreamWatcher, - pub cancel_token: CancellationToken, - pub completion: ClusterCompletion, - pub paddler_client: PaddlerClient, -} diff --git a/paddler_tests/src/collect_embedding_results.rs b/paddler_tests/src/collect_embedding_results.rs deleted file mode 100644 index 0b846d05..00000000 --- a/paddler_tests/src/collect_embedding_results.rs +++ /dev/null @@ -1,94 +0,0 @@ -use anyhow::Context as _; -use anyhow::Result; -use anyhow::anyhow; -use futures_util::StreamExt as _; -use paddler_types::embedding_result::EmbeddingResult; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::inference_client::Response as InferenceResponse; - -use crate::collected_embedding_results::CollectedEmbeddingResults; -use crate::embedding_with_producer::EmbeddingWithProducer; -use crate::inference_message_stream::InferenceMessageStream; - -pub async fn collect_embedding_results( - mut stream: InferenceMessageStream, -) -> Result { - let mut embeddings: Vec = Vec::new(); - let mut embeddings_disabled = false; - let mut errors: Vec = Vec::new(); - let mut embedding_rejected_due_to_active_token_generation_count: usize = 0; - let mut no_embeddings_produced_count: usize = 0; - let mut oversized_documents = Vec::new(); - let mut saw_done = false; - let mut wire_errors = Vec::new(); - - while let Some(item) = stream.next().await { - let message = item.context("embedding stream yielded an error")?; - - match message { - InferenceMessage::Response(envelope) => { - let generated_by = envelope.generated_by.clone(); - - match envelope.response { - InferenceResponse::Embedding(EmbeddingResult::Done) => { - saw_done = true; - - break; - } - InferenceResponse::Embedding(EmbeddingResult::Embedding(embedding)) => { - embeddings.push(EmbeddingWithProducer { - embedding, - generated_by, - }); - } - InferenceResponse::Embedding(EmbeddingResult::DocumentExceedsBatchSize( - details, - )) => { - oversized_documents.push(details); - } - InferenceResponse::Embedding(EmbeddingResult::EmbeddingsDisabled) => { - embeddings_disabled = true; - } - InferenceResponse::Embedding(EmbeddingResult::Error(message)) => { - errors.push(message); - } - InferenceResponse::Embedding( - EmbeddingResult::EmbeddingRejectedDueToActiveTokenGeneration, - ) => { - embedding_rejected_due_to_active_token_generation_count += 1; - } - InferenceResponse::Embedding(EmbeddingResult::NoEmbeddingsProduced) => { - no_embeddings_produced_count += 1; - } - InferenceResponse::GeneratedToken(_) => { - return Err(anyhow!( - "unexpected generated-token response on an embedding stream" - )); - } - InferenceResponse::Timeout => { - return Err(anyhow!("embedding request timed out on balancer")); - } - InferenceResponse::TooManyBufferedRequests => { - return Err(anyhow!( - "balancer rejected embedding request: too many buffered" - )); - } - } - } - InferenceMessage::Error(error_envelope) => { - wire_errors.push(error_envelope.error); - } - } - } - - Ok(CollectedEmbeddingResults { - embeddings, - embeddings_disabled, - errors, - embedding_rejected_due_to_active_token_generation_count, - no_embeddings_produced_count, - oversized_documents, - saw_done, - wire_errors, - }) -} diff --git a/paddler_tests/src/collect_generated_tokens.rs b/paddler_tests/src/collect_generated_tokens.rs deleted file mode 100644 index baddceea..00000000 --- a/paddler_tests/src/collect_generated_tokens.rs +++ /dev/null @@ -1,70 +0,0 @@ -use anyhow::Context as _; -use anyhow::Result; -use anyhow::anyhow; -use futures_util::StreamExt as _; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::inference_client::Response as InferenceResponse; -use paddler_types::streamable_result::StreamableResult as _; - -use crate::collected_generated_tokens::CollectedGeneratedTokens; -use crate::inference_message_stream::InferenceMessageStream; -use crate::token_result_with_producer::TokenResultWithProducer; - -pub async fn collect_generated_tokens( - mut stream: InferenceMessageStream, -) -> Result { - let mut text = String::new(); - let mut token_results: Vec = Vec::new(); - - while let Some(item) = stream.next().await { - let message = item.context("inference stream yielded an error")?; - - match message { - InferenceMessage::Response(envelope) => { - let generated_by = envelope.generated_by.clone(); - - match envelope.response { - InferenceResponse::GeneratedToken(token_result) => { - if let Some(token_text) = token_result.token_text() { - text.push_str(token_text); - } - - let is_done = token_result.is_done(); - - token_results.push(TokenResultWithProducer { - token_result, - generated_by, - }); - - if is_done { - break; - } - } - InferenceResponse::Embedding(_) => { - return Err(anyhow!( - "unexpected embedding response on a token-generation stream" - )); - } - InferenceResponse::Timeout => { - return Err(anyhow!("inference request timed out on balancer")); - } - InferenceResponse::TooManyBufferedRequests => { - return Err(anyhow!("balancer rejected request: too many buffered")); - } - } - } - InferenceMessage::Error(error_envelope) => { - return Err(anyhow!( - "inference stream returned JSON-RPC error code {} ({})", - error_envelope.error.code, - error_envelope.error.description - )); - } - } - } - - Ok(CollectedGeneratedTokens { - text, - token_results, - }) -} diff --git a/paddler_tests/src/current_test_device.rs b/paddler_tests/src/current_test_device.rs deleted file mode 100644 index 837a8a30..00000000 --- a/paddler_tests/src/current_test_device.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::env; - -use anyhow::Result; -use anyhow::bail; - -use crate::parse_test_device_value::parse_test_device_value; -use crate::test_device::TestDevice; - -const PADDLER_TEST_DEVICE: &str = "PADDLER_TEST_DEVICE"; - -pub fn current_test_device() -> Result { - match env::var(PADDLER_TEST_DEVICE) { - Ok(value) => parse_test_device_value(Some(&value)), - Err(env::VarError::NotPresent) => parse_test_device_value(None), - Err(env::VarError::NotUnicode(value)) => bail!( - "{PADDLER_TEST_DEVICE} is set but is not valid UTF-8: {}", - value.to_string_lossy() - ), - } -} diff --git a/paddler_tests/src/in_process_agent.rs b/paddler_tests/src/in_process_agent.rs new file mode 100644 index 00000000..fd1fcf80 --- /dev/null +++ b/paddler_tests/src/in_process_agent.rs @@ -0,0 +1,24 @@ +use anyhow::Result; +use async_trait::async_trait; +use paddler_bootstrap::agent_runner::AgentRunner; + +use paddler_test_cluster_harness::managed_process::ManagedProcess; + +pub struct InProcessAgent { + runner: AgentRunner, +} + +impl InProcessAgent { + #[must_use] + pub const fn new(runner: AgentRunner) -> Self { + Self { runner } + } +} + +#[async_trait] +impl ManagedProcess for InProcessAgent { + async fn shutdown(&mut self) -> Result<()> { + self.runner.cancel(); + self.runner.wait_for_completion().await + } +} diff --git a/paddler_tests/src/in_process_agent_spawner.rs b/paddler_tests/src/in_process_agent_spawner.rs new file mode 100644 index 00000000..26d7ddb1 --- /dev/null +++ b/paddler_tests/src/in_process_agent_spawner.rs @@ -0,0 +1,33 @@ +use anyhow::Result; +use paddler_bootstrap::agent_runner::AgentRunner; +use paddler_bootstrap::agent_runner::AgentRunnerParams; +use tokio_util::sync::CancellationToken; + +use crate::in_process_agent::InProcessAgent; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::agent_spawner::AgentSpawner; +use paddler_test_cluster_harness::managed_process::ManagedProcess; + +pub struct InProcessAgentSpawner { + management_address: String, +} + +impl InProcessAgentSpawner { + #[must_use] + pub const fn new(management_address: String) -> Self { + Self { management_address } + } +} + +impl AgentSpawner for InProcessAgentSpawner { + fn spawn(&self, config: &AgentConfig) -> Result> { + let runner = AgentRunner::start(AgentRunnerParams { + agent_name: Some(config.name.clone()), + cancellation_token: CancellationToken::new(), + management_address: self.management_address.clone(), + slots: config.slot_count, + }); + + Ok(Box::new(InProcessAgent::new(runner))) + } +} diff --git a/paddler_tests/src/in_process_balancer.rs b/paddler_tests/src/in_process_balancer.rs new file mode 100644 index 00000000..e82a4b61 --- /dev/null +++ b/paddler_tests/src/in_process_balancer.rs @@ -0,0 +1,24 @@ +use anyhow::Result; +use async_trait::async_trait; +use paddler_bootstrap::balancer_runner::BalancerRunner; + +use paddler_test_cluster_harness::managed_process::ManagedProcess; + +pub struct InProcessBalancer { + runner: BalancerRunner, +} + +impl InProcessBalancer { + #[must_use] + pub const fn new(runner: BalancerRunner) -> Self { + Self { runner } + } +} + +#[async_trait] +impl ManagedProcess for InProcessBalancer { + async fn shutdown(&mut self) -> Result<()> { + self.runner.cancel(); + self.runner.wait_for_completion().await + } +} diff --git a/paddler_tests/src/in_process_cluster_params.rs b/paddler_tests/src/in_process_cluster_params.rs deleted file mode 100644 index 2585fe62..00000000 --- a/paddler_tests/src/in_process_cluster_params.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::time::Duration; - -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; - -pub struct InProcessClusterParams { - pub agent: Option, - pub buffered_request_timeout: Duration, - pub desired_state: BalancerDesiredState, - pub inference_cors_allowed_hosts: Vec, - pub inference_item_timeout: Duration, - pub management_cors_allowed_hosts: Vec, - pub max_buffered_requests: i32, - pub wait_for_slots_ready: bool, -} - -impl Default for InProcessClusterParams { - fn default() -> Self { - Self { - agent: Some(AgentConfig { - name: "test-agent".to_owned(), - slot_count: 4, - }), - buffered_request_timeout: Duration::from_secs(10), - desired_state: BalancerDesiredState::default(), - inference_cors_allowed_hosts: Vec::new(), - inference_item_timeout: Duration::from_secs(30), - management_cors_allowed_hosts: Vec::new(), - max_buffered_requests: 10, - wait_for_slots_ready: true, - } - } -} diff --git a/paddler_tests/src/inference_http_client.rs b/paddler_tests/src/inference_http_client.rs deleted file mode 100644 index a17436c6..00000000 --- a/paddler_tests/src/inference_http_client.rs +++ /dev/null @@ -1,123 +0,0 @@ -use anyhow::Context as _; -use anyhow::Result; -use async_stream::try_stream; -use futures_util::Stream; -use futures_util::StreamExt as _; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; -use url::Url; - -use crate::inference_message_stream::InferenceMessageStream; - -#[derive(Clone)] -pub struct InferenceHttpClient { - http_client: Client, - inference_base_url: Url, -} - -impl InferenceHttpClient { - #[must_use] - pub const fn new(http_client: Client, inference_base_url: Url) -> Self { - Self { - http_client, - inference_base_url, - } - } - - pub async fn post_continue_from_raw_prompt( - &self, - params: &ContinueFromRawPromptParams, - ) -> Result { - self.post_streaming("api/v1/continue_from_raw_prompt", params) - .await - } - - pub async fn post_continue_from_conversation_history( - &self, - params: &ContinueFromConversationHistoryParams, - ) -> Result { - self.post_streaming("api/v1/continue_from_conversation_history", params) - .await - } - - pub async fn post_generate_embedding_batch( - &self, - params: &GenerateEmbeddingBatchParams, - ) -> Result { - self.post_streaming("api/v1/generate_embedding_batch", params) - .await - } - - async fn post_streaming( - &self, - relative_path: &str, - body: &TBody, - ) -> Result - where - TBody: serde::Serialize + Sync + ?Sized, - { - let request_url = self - .inference_base_url - .join(relative_path) - .with_context(|| format!("failed to build URL for {relative_path}"))?; - - let response = self - .http_client - .post(request_url) - .json(body) - .send() - .await - .with_context(|| format!("failed to POST {relative_path}"))? - .error_for_status() - .with_context(|| format!("non-success status on {relative_path}"))?; - - Ok(Box::pin(inference_messages_from_response(response))) - } -} - -fn inference_messages_from_response( - response: reqwest::Response, -) -> impl Stream> + Send { - try_stream! { - let mut bytes_stream = response.bytes_stream(); - let mut buffer: Vec = Vec::new(); - - while let Some(chunk_result) = bytes_stream.next().await { - let chunk = chunk_result.context("failed to read inference response chunk")?; - - buffer.extend_from_slice(&chunk); - - while let Some(newline_position) = buffer.iter().position(|byte| *byte == b'\n') { - let line_bytes: Vec = buffer.drain(..=newline_position).collect(); - let line_without_newline = &line_bytes[..newline_position]; - let line_text = std::str::from_utf8(line_without_newline) - .context("inference stream produced non-UTF8 bytes")? - .trim(); - - if line_text.is_empty() { - continue; - } - - let message: InferenceMessage = serde_json::from_str(line_text) - .with_context(|| format!("failed to parse NDJSON line: {line_text}"))?; - - yield message; - } - } - - let trailing_text = std::str::from_utf8(&buffer) - .context("inference stream produced trailing non-UTF8 bytes")? - .trim(); - - if !trailing_text.is_empty() { - let message: InferenceMessage = serde_json::from_str(trailing_text) - .with_context(|| format!("failed to parse trailing NDJSON line: {trailing_text}"))?; - - yield message; - } - } -} diff --git a/paddler_tests/src/lib.rs b/paddler_tests/src/lib.rs index 6ef3f36f..7169c5d8 100644 --- a/paddler_tests/src/lib.rs +++ b/paddler_tests/src/lib.rs @@ -1,61 +1,23 @@ -pub mod agent_config; -pub mod agents_status; -pub mod agents_stream_watcher; -pub mod balancer_addresses; -pub mod buffered_requests_status; -pub mod buffered_requests_stream_watcher; -pub mod cluster_completion; -pub mod cluster_handle; -pub mod cluster_handle_params; -pub mod collect_embedding_results; -pub mod collect_generated_tokens; -pub mod collected_embedding_results; -pub mod collected_generated_tokens; -pub mod current_test_device; -pub mod embedding_with_producer; -pub mod in_process_cluster_params; -pub mod inference_http_client; -pub mod inference_message_stream; -pub mod load_test_image_data_uri; +pub mod in_process_agent; +pub mod in_process_agent_spawner; +pub mod in_process_balancer; pub mod local_http_fixture; pub mod make_agent_controller_without_remote_agent; -pub mod make_inference_parameters_deterministic; -pub mod ministral_3_in_process_cluster_params; +pub mod ministral_3_cluster_params; pub mod model_card; -pub mod openai_chat_completions_client; -pub mod paddler_command; -pub mod parse_test_device_value; pub mod qwen3_embedding_cluster_params; -#[cfg(any(target_os = "macos", target_os = "linux"))] -pub mod resource_snapshot; -#[cfg(any(target_os = "macos", target_os = "linux"))] -pub mod resource_snapshot_diff; -pub mod spawn_agent_subprocess; -pub mod spawn_agent_subprocess_params; -pub mod start_in_process_cluster; -pub mod start_in_process_cluster_with_deepseek_r1_distill_llama_8b; -pub mod start_in_process_cluster_with_gemma_4; -pub mod start_in_process_cluster_with_gemma_4_and_mmproj; -pub mod start_in_process_cluster_with_glm_4_7_flash; -pub mod start_in_process_cluster_with_ministral_3; -pub mod start_in_process_cluster_with_ministral_3_and_mmproj; -pub mod start_in_process_cluster_with_qwen2_5_vl; -pub mod start_in_process_cluster_with_qwen3; -pub mod start_in_process_cluster_with_qwen3_5; -pub mod start_in_process_cluster_with_qwen3_6; -pub mod start_in_process_cluster_with_qwen3_6_and_mmproj; -pub mod start_in_process_cluster_with_smolvlm2; -pub mod start_in_process_cluster_with_smolvlm2_and_n_batch; -pub mod start_in_process_embedding_cluster; -pub mod start_subprocess_cluster; -pub mod start_subprocess_cluster_with_qwen2_5_vl; -pub mod start_subprocess_cluster_with_qwen3; -pub mod start_subprocess_cluster_with_qwen3_embedding; -pub mod start_subprocess_cluster_with_smolvlm2; -pub mod state_database_file; -pub mod subprocess_cluster_lifecycle_in_dedicated_runtime; -pub mod subprocess_cluster_params; -pub mod terminate_child; -pub mod test_device; -pub mod token_result_with_producer; -pub mod wait_until_healthy; +pub mod start_cluster; +pub mod start_cluster_with_deepseek_r1_distill_llama_8b; +pub mod start_cluster_with_gemma_4; +pub mod start_cluster_with_gemma_4_and_mmproj; +pub mod start_cluster_with_glm_4_7_flash; +pub mod start_cluster_with_ministral_3; +pub mod start_cluster_with_ministral_3_and_mmproj; +pub mod start_cluster_with_qwen2_5_vl; +pub mod start_cluster_with_qwen3; +pub mod start_cluster_with_qwen3_5; +pub mod start_cluster_with_qwen3_6; +pub mod start_cluster_with_qwen3_6_and_mmproj; +pub mod start_cluster_with_smolvlm2; +pub mod start_cluster_with_smolvlm2_and_n_batch; +pub mod start_embedding_cluster; diff --git a/paddler_tests/src/make_agent_controller_without_remote_agent.rs b/paddler_tests/src/make_agent_controller_without_remote_agent.rs index 263f2ad6..908d590a 100644 --- a/paddler_tests/src/make_agent_controller_without_remote_agent.rs +++ b/paddler_tests/src/make_agent_controller_without_remote_agent.rs @@ -1,17 +1,17 @@ +use parking_lot::RwLock; use std::collections::BTreeSet; use std::sync::Arc; -use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; use std::sync::atomic::AtomicU64; -use paddler::atomic_value::AtomicValue; -use paddler::balancer::agent_controller::AgentController; -use paddler::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; -use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use paddler::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; -use paddler::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; +use paddler_balancer::agent_controller::AgentController; +use paddler_balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection; +use paddler_balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler_balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection; +use paddler_balancer::model_metadata_sender_collection::ModelMetadataSenderCollection; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::atomic_value::AtomicValue; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; diff --git a/paddler_tests/src/make_inference_parameters_deterministic.rs b/paddler_tests/src/make_inference_parameters_deterministic.rs deleted file mode 100644 index c9eaa2bd..00000000 --- a/paddler_tests/src/make_inference_parameters_deterministic.rs +++ /dev/null @@ -1,17 +0,0 @@ -use paddler_types::inference_parameters::InferenceParameters; - -#[must_use] -pub const fn make_inference_parameters_deterministic( - base: InferenceParameters, -) -> InferenceParameters { - InferenceParameters { - temperature: 0.0, - top_k: 1, - top_p: 1.0, - min_p: 0.0, - penalty_repeat: 1.0, - penalty_presence: 0.0, - penalty_frequency: 0.0, - ..base - } -} diff --git a/paddler_tests/src/ministral_3_cluster_params.rs b/paddler_tests/src/ministral_3_cluster_params.rs new file mode 100644 index 00000000..95ae7dd7 --- /dev/null +++ b/paddler_tests/src/ministral_3_cluster_params.rs @@ -0,0 +1,15 @@ +use paddler_test_cluster_harness::agent_config::AgentConfig; + +pub struct Ministral3ClusterParams { + pub agents: Vec, + pub deterministic_sampling: bool, +} + +impl Default for Ministral3ClusterParams { + fn default() -> Self { + Self { + agents: AgentConfig::uniform(1, 1), + deterministic_sampling: false, + } + } +} diff --git a/paddler_tests/src/ministral_3_in_process_cluster_params.rs b/paddler_tests/src/ministral_3_in_process_cluster_params.rs deleted file mode 100644 index 10aa8424..00000000 --- a/paddler_tests/src/ministral_3_in_process_cluster_params.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::agent_config::AgentConfig; - -pub struct Ministral3InProcessClusterParams { - pub agent: AgentConfig, - pub deterministic_sampling: bool, -} - -impl Default for Ministral3InProcessClusterParams { - fn default() -> Self { - Self { - agent: AgentConfig::single(1), - deterministic_sampling: false, - } - } -} diff --git a/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs b/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs index 993f116b..19498fa4 100644 --- a/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs +++ b/paddler_tests/src/model_card/deepseek_r1_distill_llama_8b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/gemma_4_e4b_it.rs b/paddler_tests/src/model_card/gemma_4_e4b_it.rs index 2959b2cc..f0bac2ba 100644 --- a/paddler_tests/src/model_card/gemma_4_e4b_it.rs +++ b/paddler_tests/src/model_card/gemma_4_e4b_it.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs b/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs index 083db911..1173111a 100644 --- a/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs +++ b/paddler_tests/src/model_card/gemma_4_e4b_it_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/glm_4_7_flash.rs b/paddler_tests/src/model_card/glm_4_7_flash.rs index 5d5bba3e..e1f9b281 100644 --- a/paddler_tests/src/model_card/glm_4_7_flash.rs +++ b/paddler_tests/src/model_card/glm_4_7_flash.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs b/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs index 0718b7fa..01ba9bd6 100644 --- a/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs +++ b/paddler_tests/src/model_card/ministral_3_14b_reasoning.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs b/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs index be0c5b76..235a6294 100644 --- a/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs +++ b/paddler_tests/src/model_card/ministral_3_14b_reasoning_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/mod.rs b/paddler_tests/src/model_card/mod.rs index dedd03a3..4f576435 100644 --- a/paddler_tests/src/model_card/mod.rs +++ b/paddler_tests/src/model_card/mod.rs @@ -16,7 +16,7 @@ pub mod qwen3_embedding_0_6b; pub mod smolvlm2_256m; pub mod smolvlm2_256m_mmproj; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; pub struct ModelCard { pub gpu_layer_count: u32, diff --git a/paddler_tests/src/model_card/nomic_embed_text_v1_5.rs b/paddler_tests/src/model_card/nomic_embed_text_v1_5.rs index 3ce3a876..6c0cc336 100644 --- a/paddler_tests/src/model_card/nomic_embed_text_v1_5.rs +++ b/paddler_tests/src/model_card/nomic_embed_text_v1_5.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen2_5_vl_3b.rs b/paddler_tests/src/model_card/qwen2_5_vl_3b.rs index d2225305..9b743892 100644 --- a/paddler_tests/src/model_card/qwen2_5_vl_3b.rs +++ b/paddler_tests/src/model_card/qwen2_5_vl_3b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen2_5_vl_3b_mmproj.rs b/paddler_tests/src/model_card/qwen2_5_vl_3b_mmproj.rs index f21b7848..cbc64713 100644 --- a/paddler_tests/src/model_card/qwen2_5_vl_3b_mmproj.rs +++ b/paddler_tests/src/model_card/qwen2_5_vl_3b_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_0_6b.rs b/paddler_tests/src/model_card/qwen3_0_6b.rs index 5a2e2263..4a421502 100644 --- a/paddler_tests/src/model_card/qwen3_0_6b.rs +++ b/paddler_tests/src/model_card/qwen3_0_6b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_5_0_8b.rs b/paddler_tests/src/model_card/qwen3_5_0_8b.rs index 4c8b2b66..f15925c8 100644 --- a/paddler_tests/src/model_card/qwen3_5_0_8b.rs +++ b/paddler_tests/src/model_card/qwen3_5_0_8b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_5_0_8b_mmproj.rs b/paddler_tests/src/model_card/qwen3_5_0_8b_mmproj.rs index d0d0c7f1..8e0e65e9 100644 --- a/paddler_tests/src/model_card/qwen3_5_0_8b_mmproj.rs +++ b/paddler_tests/src/model_card/qwen3_5_0_8b_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs b/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs index c75a5f8a..5ac2f150 100644 --- a/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs +++ b/paddler_tests/src/model_card/qwen3_6_35b_a3b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs b/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs index 5d6a5b55..28783a18 100644 --- a/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs +++ b/paddler_tests/src/model_card/qwen3_6_35b_a3b_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/qwen3_embedding_0_6b.rs b/paddler_tests/src/model_card/qwen3_embedding_0_6b.rs index 17786862..fcdf63b9 100644 --- a/paddler_tests/src/model_card/qwen3_embedding_0_6b.rs +++ b/paddler_tests/src/model_card/qwen3_embedding_0_6b.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/smolvlm2_256m.rs b/paddler_tests/src/model_card/smolvlm2_256m.rs index da0536ef..ce906a22 100644 --- a/paddler_tests/src/model_card/smolvlm2_256m.rs +++ b/paddler_tests/src/model_card/smolvlm2_256m.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/model_card/smolvlm2_256m_mmproj.rs b/paddler_tests/src/model_card/smolvlm2_256m_mmproj.rs index 52acd0e9..e858873e 100644 --- a/paddler_tests/src/model_card/smolvlm2_256m_mmproj.rs +++ b/paddler_tests/src/model_card/smolvlm2_256m_mmproj.rs @@ -1,4 +1,4 @@ -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; use crate::model_card::ModelCard; diff --git a/paddler_tests/src/openai_chat_completions_client.rs b/paddler_tests/src/openai_chat_completions_client.rs deleted file mode 100644 index c0571661..00000000 --- a/paddler_tests/src/openai_chat_completions_client.rs +++ /dev/null @@ -1,86 +0,0 @@ -use anyhow::Context as _; -use anyhow::Result; -use futures_util::StreamExt as _; -use reqwest::Client; -use serde_json::Value; -use url::Url; - -pub struct OpenAIChatCompletionsClient { - http_client: Client, - completions_url: Url, -} - -impl OpenAIChatCompletionsClient { - pub fn new(http_client: Client, openai_base_url: &Url) -> Result { - Ok(Self { - http_client, - completions_url: openai_base_url - .join("v1/chat/completions") - .context("failed to build /v1/chat/completions URL")?, - }) - } - - pub async fn post_streaming(&self, body: &Value) -> Result> { - let response = self - .http_client - .post(self.completions_url.clone()) - .json(body) - .send() - .await - .context("failed to POST OpenAI streaming chat completion")? - .error_for_status() - .context("non-success status from OpenAI streaming endpoint")?; - - let mut bytes_stream = response.bytes_stream(); - let mut buffer: Vec = Vec::new(); - let mut chunks: Vec = Vec::new(); - - while let Some(chunk_result) = bytes_stream.next().await { - let chunk = chunk_result.context("failed to read OpenAI streaming chunk")?; - - buffer.extend_from_slice(&chunk); - - while let Some(newline_position) = buffer.iter().position(|byte| *byte == b'\n') { - let line_bytes: Vec = buffer.drain(..=newline_position).collect(); - let line_text = std::str::from_utf8(&line_bytes[..newline_position]) - .context("OpenAI stream produced non-UTF8 bytes")? - .trim(); - - if line_text.is_empty() { - continue; - } - - chunks.push(serde_json::from_str(line_text).with_context(|| { - format!("failed to parse OpenAI streaming chunk: {line_text}") - })?); - } - } - - let trailing_text = std::str::from_utf8(&buffer) - .context("OpenAI stream produced trailing non-UTF8 bytes")? - .trim(); - - if !trailing_text.is_empty() { - chunks.push( - serde_json::from_str(trailing_text) - .with_context(|| format!("failed to parse trailing chunk: {trailing_text}"))?, - ); - } - - Ok(chunks) - } - - pub async fn post_non_streaming(&self, body: &Value) -> Result { - self.http_client - .post(self.completions_url.clone()) - .json(body) - .send() - .await - .context("failed to POST OpenAI non-streaming chat completion")? - .error_for_status() - .context("non-success status from OpenAI non-streaming endpoint")? - .json::() - .await - .context("failed to parse OpenAI non-streaming JSON response") - } -} diff --git a/paddler_tests/src/paddler_command.rs b/paddler_tests/src/paddler_command.rs deleted file mode 100644 index 24fad157..00000000 --- a/paddler_tests/src/paddler_command.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::env; -use std::sync::LazyLock; - -use tokio::process::Command; - -static PADDLER_BINARY_PATH: LazyLock = LazyLock::new(|| { - env::var("PADDLER_BINARY_PATH").unwrap_or_else(|_| "../target/debug/paddler".to_owned()) -}); - -pub fn paddler_command() -> Command { - let mut command = Command::new(PADDLER_BINARY_PATH.as_str()); - - if let Ok(profile_file) = env::var("LLVM_PROFILE_FILE") { - command.env("LLVM_PROFILE_FILE", profile_file); - } - - command -} diff --git a/paddler_tests/src/parse_test_device_value.rs b/paddler_tests/src/parse_test_device_value.rs deleted file mode 100644 index 2cff5825..00000000 --- a/paddler_tests/src/parse_test_device_value.rs +++ /dev/null @@ -1,25 +0,0 @@ -use anyhow::Result; -use anyhow::anyhow; - -use crate::test_device::TestDevice; - -pub fn parse_test_device_value(value: Option<&str>) -> Result { - match value { - None | Some("cpu") => Ok(TestDevice::Cpu), - #[cfg(feature = "cuda")] - Some("cuda") => Ok(TestDevice::Cuda), - #[cfg(not(feature = "cuda"))] - Some("cuda") => Err(anyhow!( - "PADDLER_TEST_DEVICE=cuda requires building with --features cuda; the cuda backend is not linked into this binary" - )), - #[cfg(feature = "metal")] - Some("metal") => Ok(TestDevice::Metal), - #[cfg(not(feature = "metal"))] - Some("metal") => Err(anyhow!( - "PADDLER_TEST_DEVICE=metal requires building with --features metal; the metal backend is not linked into this binary" - )), - Some(other) => Err(anyhow!( - "unrecognised PADDLER_TEST_DEVICE value {other:?}; expected one of cpu | cuda | metal" - )), - } -} diff --git a/paddler_tests/src/qwen3_embedding_cluster_params.rs b/paddler_tests/src/qwen3_embedding_cluster_params.rs index 64307c56..49ee2bc1 100644 --- a/paddler_tests/src/qwen3_embedding_cluster_params.rs +++ b/paddler_tests/src/qwen3_embedding_cluster_params.rs @@ -1,8 +1,8 @@ use std::time::Duration; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::inference_parameters::InferenceParameters; -use crate::agent_config::AgentConfig; +use paddler_test_cluster_harness::agent_config::AgentConfig; pub struct Qwen3EmbeddingClusterParams { pub agents: Vec, diff --git a/paddler_tests/src/start_cluster.rs b/paddler_tests/src/start_cluster.rs new file mode 100644 index 00000000..d3a2fb81 --- /dev/null +++ b/paddler_tests/src/start_cluster.rs @@ -0,0 +1,110 @@ +use std::str::FromStr as _; + +use anyhow::Context as _; +use anyhow::Result; +use paddler_balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; +use paddler_balancer::state_database_type::StateDatabaseType; +use paddler_bootstrap::balancer_runner::BalancerRunner; +use paddler_bootstrap::balancer_runner::BalancerRunnerParams; +use tokio_util::sync::CancellationToken; +use trzcina::ServiceShutdownOptions; + +use crate::in_process_agent_spawner::InProcessAgentSpawner; +use crate::in_process_balancer::InProcessBalancer; +use paddler_test_cluster_harness::balancer_addresses::BalancerAddresses; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::running_balancer::RunningBalancer; + +pub async fn start_cluster( + ClusterParams { + agents, + buffered_request_timeout, + desired_state, + inference_cors_allowed_hosts, + inference_item_timeout, + management_cors_allowed_hosts, + max_buffered_requests, + state_database_url, + wait_for_slots_ready, + }: ClusterParams, +) -> Result { + // Make every `log!` macro evaluate its argument expressions during integration + // tests; no logger is installed, so this exercises the logging instrumentation + // without emitting output. + log::set_max_level(log::LevelFilter::Trace); + + let addresses = BalancerAddresses::pick()?; + let management_address = addresses.management.to_string(); + let state_database_type = StateDatabaseType::from_str(&state_database_url) + .context("failed to parse state_database_url")?; + + let balancer_runner = BalancerRunner::start(BalancerRunnerParams { + buffered_request_timeout, + inference_service_configuration: InferenceServiceConfiguration { + addr: addresses.inference, + cors_allowed_hosts: inference_cors_allowed_hosts, + inference_item_timeout, + }, + management_service_configuration: ManagementServiceConfiguration { + addr: addresses.management, + cors_allowed_hosts: management_cors_allowed_hosts, + }, + max_buffered_requests, + openai_service_configuration: Some(OpenAIServiceConfiguration { + addr: addresses.compat_openai, + }), + cancellation_token: CancellationToken::new(), + shutdown_options: ServiceShutdownOptions::default(), + state_database_type, + statsd_prefix: "paddler_tests_".to_owned(), + statsd_service_configuration: None, + #[cfg(feature = "web_admin_panel")] + web_admin_panel_service_configuration: None, + }) + .await + .context("failed to start in-process BalancerRunner")?; + + let running_balancer = + RunningBalancer::new(addresses, Box::new(InProcessBalancer::new(balancer_runner))); + + let mut cluster = Cluster::connect( + running_balancer, + Box::new(InProcessAgentSpawner::new(management_address)), + desired_state.as_ref(), + ) + .await?; + + let expected_agent_count = agents.len(); + let mut last_ready_snapshot = None; + + for agent in &agents { + cluster.spawn_additional_agent(agent)?; + + if wait_for_slots_ready { + last_ready_snapshot = Some( + cluster + .wait_for_agent_ready(&agent.name, agent.slot_count) + .await?, + ); + } + } + + let registered_snapshot = match last_ready_snapshot { + Some(snapshot) => snapshot, + None => cluster + .wait_for_agent_count(expected_agent_count) + .await + .context("not all in-process agents registered")?, + }; + + cluster.agent_ids = registered_snapshot + .agents + .iter() + .map(|registered_agent| registered_agent.id.clone()) + .collect(); + + Ok(cluster) +} diff --git a/paddler_tests/src/start_cluster_with_deepseek_r1_distill_llama_8b.rs b/paddler_tests/src/start_cluster_with_deepseek_r1_distill_llama_8b.rs new file mode 100644 index 00000000..85dc86a5 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_deepseek_r1_distill_llama_8b.rs @@ -0,0 +1,37 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::deepseek_r1_distill_llama_8b::deepseek_r1_distill_llama_8b; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_deepseek_r1_distill_llama_8b( + agents: Vec, +) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = deepseek_r1_distill_llama_8b(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_gemma_4.rs b/paddler_tests/src/start_cluster_with_gemma_4.rs new file mode 100644 index 00000000..4bd362d7 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_gemma_4.rs @@ -0,0 +1,35 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_gemma_4(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = gemma_4_e4b_it(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_gemma_4_and_mmproj.rs b/paddler_tests/src/start_cluster_with_gemma_4_and_mmproj.rs new file mode 100644 index 00000000..eae7a115 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_gemma_4_and_mmproj.rs @@ -0,0 +1,40 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; +use crate::model_card::gemma_4_e4b_it_mmproj::gemma_4_e4b_it_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_gemma_4_and_mmproj(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = gemma_4_e4b_it(); + let ModelCard { + reference: mmproj_reference, + .. + } = gemma_4_e4b_it_mmproj(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_glm_4_7_flash.rs b/paddler_tests/src/start_cluster_with_glm_4_7_flash.rs new file mode 100644 index 00000000..d27dde1b --- /dev/null +++ b/paddler_tests/src/start_cluster_with_glm_4_7_flash.rs @@ -0,0 +1,35 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::glm_4_7_flash::glm_4_7_flash; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_glm_4_7_flash(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = glm_4_7_flash(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_ministral_3.rs b/paddler_tests/src/start_cluster_with_ministral_3.rs new file mode 100644 index 00000000..dd80108e --- /dev/null +++ b/paddler_tests/src/start_cluster_with_ministral_3.rs @@ -0,0 +1,49 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::ministral_3_cluster_params::Ministral3ClusterParams; +use crate::model_card::ModelCard; +use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_ministral_3( + Ministral3ClusterParams { + agents, + deterministic_sampling, + }: Ministral3ClusterParams, +) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = ministral_3_14b_reasoning(); + + let inference_parameters = if deterministic_sampling { + InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + } + } else { + InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + } + }; + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_ministral_3_and_mmproj.rs b/paddler_tests/src/start_cluster_with_ministral_3_and_mmproj.rs new file mode 100644 index 00000000..bf78cac7 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_ministral_3_and_mmproj.rs @@ -0,0 +1,42 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; +use crate::model_card::ministral_3_14b_reasoning_mmproj::ministral_3_14b_reasoning_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_ministral_3_and_mmproj( + agents: Vec, +) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = ministral_3_14b_reasoning(); + let ModelCard { + reference: mmproj_reference, + .. + } = ministral_3_14b_reasoning_mmproj(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_qwen2_5_vl.rs b/paddler_tests/src/start_cluster_with_qwen2_5_vl.rs new file mode 100644 index 00000000..77dc7ee0 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_qwen2_5_vl.rs @@ -0,0 +1,40 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::qwen2_5_vl_3b::qwen2_5_vl_3b; +use crate::model_card::qwen2_5_vl_3b_mmproj::qwen2_5_vl_3b_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_qwen2_5_vl(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = qwen2_5_vl_3b(); + let ModelCard { + reference: mmproj_reference, + .. + } = qwen2_5_vl_3b_mmproj(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_qwen3.rs b/paddler_tests/src/start_cluster_with_qwen3.rs new file mode 100644 index 00000000..d8abb9c9 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_qwen3.rs @@ -0,0 +1,35 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_0_6b::qwen3_0_6b; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_qwen3(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_qwen3_5.rs b/paddler_tests/src/start_cluster_with_qwen3_5.rs new file mode 100644 index 00000000..ebdafe22 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_qwen3_5.rs @@ -0,0 +1,50 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_5_0_8b::qwen3_5_0_8b; +use crate::model_card::qwen3_5_0_8b_mmproj::qwen3_5_0_8b_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_qwen3_5( + agents: Vec, + with_mmproj: bool, +) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = qwen3_5_0_8b(); + + let multimodal_projection = if with_mmproj { + let ModelCard { + reference: mmproj_reference, + .. + } = qwen3_5_0_8b_mmproj(); + + AgentDesiredModel::HuggingFace(mmproj_reference) + } else { + AgentDesiredModel::None + }; + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_qwen3_6.rs b/paddler_tests/src/start_cluster_with_qwen3_6.rs new file mode 100644 index 00000000..b5142618 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_qwen3_6.rs @@ -0,0 +1,35 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_qwen3_6(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_6_35b_a3b(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_qwen3_6_and_mmproj.rs b/paddler_tests/src/start_cluster_with_qwen3_6_and_mmproj.rs new file mode 100644 index 00000000..43782779 --- /dev/null +++ b/paddler_tests/src/start_cluster_with_qwen3_6_and_mmproj.rs @@ -0,0 +1,40 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; +use crate::model_card::qwen3_6_35b_a3b_mmproj::qwen3_6_35b_a3b_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_qwen3_6_and_mmproj(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = qwen3_6_35b_a3b(); + let ModelCard { + reference: mmproj_reference, + .. + } = qwen3_6_35b_a3b_mmproj(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_cluster_with_smolvlm2.rs b/paddler_tests/src/start_cluster_with_smolvlm2.rs new file mode 100644 index 00000000..095061da --- /dev/null +++ b/paddler_tests/src/start_cluster_with_smolvlm2.rs @@ -0,0 +1,40 @@ +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; + +use crate::model_card::ModelCard; +use crate::model_card::smolvlm2_256m::smolvlm2_256m; +use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; + +pub async fn start_cluster_with_smolvlm2(agents: Vec) -> Result { + let ModelCard { + gpu_layer_count, + reference: primary_reference, + } = smolvlm2_256m(); + let ModelCard { + reference: mmproj_reference, + .. + } = smolvlm2_256m_mmproj(); + + start_cluster(ClusterParams { + agents, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::deterministic() + }, + model: AgentDesiredModel::HuggingFace(primary_reference), + multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await +} diff --git a/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs b/paddler_tests/src/start_cluster_with_smolvlm2_and_n_batch.rs similarity index 50% rename from paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs rename to paddler_tests/src/start_cluster_with_smolvlm2_and_n_batch.rs index 88d324ef..7a0aeb90 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_smolvlm2.rs +++ b/paddler_tests/src/start_cluster_with_smolvlm2_and_n_batch.rs @@ -1,23 +1,20 @@ use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; use crate::model_card::ModelCard; use crate::model_card::smolvlm2_256m::smolvlm2_256m; use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; -use crate::start_subprocess_cluster::start_subprocess_cluster; -use crate::subprocess_cluster_params::SubprocessClusterParams; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; -pub async fn start_subprocess_cluster_with_smolvlm2( +pub async fn start_cluster_with_smolvlm2_and_n_batch( agents: Vec, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - + n_batch: usize, +) -> Result { let ModelCard { gpu_layer_count, reference: primary_reference, @@ -27,17 +24,23 @@ pub async fn start_subprocess_cluster_with_smolvlm2( .. } = smolvlm2_256m_mmproj(); - start_subprocess_cluster(SubprocessClusterParams { + let inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + n_batch, + ..InferenceParameters::deterministic() + }; + + start_cluster(ClusterParams { agents, desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters, model: AgentDesiredModel::HuggingFace(primary_reference), multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), use_chat_template_override: false, }), wait_for_slots_ready: true, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await } diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs b/paddler_tests/src/start_embedding_cluster.rs similarity index 55% rename from paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs rename to paddler_tests/src/start_embedding_cluster.rs index ab5e7069..a628df9d 100644 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen3_embedding.rs +++ b/paddler_tests/src/start_embedding_cluster.rs @@ -1,40 +1,34 @@ use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; use crate::model_card::ModelCard; use crate::model_card::qwen3_embedding_0_6b::qwen3_embedding_0_6b; use crate::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use crate::start_subprocess_cluster::start_subprocess_cluster; -use crate::subprocess_cluster_params::SubprocessClusterParams; +use crate::start_cluster::start_cluster; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; -pub async fn start_subprocess_cluster_with_qwen3_embedding( +pub async fn start_embedding_cluster( Qwen3EmbeddingClusterParams { agents, buffered_request_timeout, inference_parameters, max_buffered_requests, }: Qwen3EmbeddingClusterParams, -) -> Result { +) -> Result { let ModelCard { gpu_layer_count, reference, } = qwen3_embedding_0_6b(); - let test_device = current_test_device()?; - test_device.require_available()?; - let device_offload_parameters = - test_device.inference_parameters_for_full_offload(gpu_layer_count); - let inference_parameters_with_offload = InferenceParameters { - n_gpu_layers: device_offload_parameters.n_gpu_layers, + n_gpu_layers: gpu_layer_count, ..inference_parameters }; - start_subprocess_cluster(SubprocessClusterParams { + start_cluster(ClusterParams { agents, buffered_request_timeout, desired_state: Some(BalancerDesiredState { @@ -46,7 +40,7 @@ pub async fn start_subprocess_cluster_with_qwen3_embedding( }), max_buffered_requests, wait_for_slots_ready: true, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await } diff --git a/paddler_tests/src/start_in_process_cluster.rs b/paddler_tests/src/start_in_process_cluster.rs deleted file mode 100644 index 6e679c10..00000000 --- a/paddler_tests/src/start_in_process_cluster.rs +++ /dev/null @@ -1,126 +0,0 @@ -use anyhow::Context as _; -use anyhow::Result; -use paddler::balancer::compatibility::openai_service::configuration::Configuration as OpenAIServiceConfiguration; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::management_service::configuration::Configuration as ManagementServiceConfiguration; -use paddler::balancer::state_database_type::StateDatabaseType; -use paddler_bootstrap::agent_runner::AgentRunner; -use paddler_bootstrap::agent_runner::AgentRunnerParams; -use paddler_bootstrap::balancer_runner::BalancerRunner; -use paddler_bootstrap::balancer_runner::BalancerRunnerParams; -use paddler_client::PaddlerClient; -use tokio_util::sync::CancellationToken; - -use crate::agents_stream_watcher::AgentsStreamWatcher; -use crate::balancer_addresses::BalancerAddresses; -use crate::buffered_requests_stream_watcher::BufferedRequestsStreamWatcher; -use crate::cluster_completion::ClusterCompletion; -use crate::cluster_handle::ClusterHandle; -use crate::cluster_handle_params::ClusterHandleParams; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::wait_until_healthy::wait_until_healthy; - -pub async fn start_in_process_cluster( - InProcessClusterParams { - agent, - buffered_request_timeout, - desired_state, - inference_cors_allowed_hosts, - inference_item_timeout, - management_cors_allowed_hosts, - max_buffered_requests, - wait_for_slots_ready, - }: InProcessClusterParams, -) -> Result { - let addresses = BalancerAddresses::pick()?; - let cancel_token = CancellationToken::new(); - - let balancer = BalancerRunner::start(BalancerRunnerParams { - buffered_request_timeout, - inference_service_configuration: InferenceServiceConfiguration { - addr: addresses.inference, - cors_allowed_hosts: inference_cors_allowed_hosts, - inference_item_timeout, - }, - management_service_configuration: ManagementServiceConfiguration { - addr: addresses.management, - cors_allowed_hosts: management_cors_allowed_hosts, - }, - max_buffered_requests, - openai_service_configuration: Some(OpenAIServiceConfiguration { - addr: addresses.compat_openai, - }), - cancellation_token: cancel_token.clone(), - state_database_type: StateDatabaseType::Memory(Box::new(desired_state.clone())), - statsd_prefix: "paddler_tests_".to_owned(), - statsd_service_configuration: None, - #[cfg(feature = "web_admin_panel")] - web_admin_panel_service_configuration: None, - }) - .await - .context("failed to start in-process BalancerRunner")?; - - let management_base_url = addresses.management_base_url()?; - let inference_base_url = addresses.inference_base_url()?; - - wait_until_healthy(&management_base_url, "health") - .await - .context("balancer did not become healthy")?; - - let paddler_client = PaddlerClient::new(inference_base_url, management_base_url, 1); - - paddler_client - .management() - .put_balancer_desired_state(&desired_state) - .await - .map_err(anyhow::Error::new) - .context("failed to PUT balancer desired state")?; - - let mut agents_watcher = AgentsStreamWatcher::connect(&paddler_client.management()).await?; - let buffered_requests_watcher = - BufferedRequestsStreamWatcher::connect(&paddler_client.management()).await?; - - let expected_agent_count: usize = usize::from(agent.is_some()); - let mut agent_runners: Vec = Vec::with_capacity(expected_agent_count); - - if let Some(agent_config) = agent.as_ref() { - let agent_runner = AgentRunner::start(AgentRunnerParams { - agent_name: Some(agent_config.name.clone()), - management_address: addresses.management.to_string(), - cancellation_token: cancel_token.clone(), - slots: agent_config.slot_count, - }); - - agent_runners.push(agent_runner); - } - - let registered_snapshot = agents_watcher - .until(move |snapshot| snapshot.agents.len() >= expected_agent_count) - .await - .context("agent did not register")?; - - let agent_ids: Vec = registered_snapshot - .agents - .iter() - .map(|registered_agent| registered_agent.id.clone()) - .collect(); - - if wait_for_slots_ready && let Some(agent_config) = agent.as_ref() { - agents_watcher - .wait_for_slots_ready(&[agent_config.slot_count]) - .await?; - } - - Ok(ClusterHandle::new(ClusterHandleParams { - addresses, - agent_ids, - agents: agents_watcher, - buffered_requests: buffered_requests_watcher, - cancel_token, - completion: ClusterCompletion::InProcess { - agents: agent_runners, - balancer, - }, - paddler_client, - })) -} diff --git a/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs b/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs deleted file mode 100644 index f783b62b..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_deepseek_r1_distill_llama_8b.rs +++ /dev/null @@ -1,38 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::deepseek_r1_distill_llama_8b::deepseek_r1_distill_llama_8b; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_deepseek_r1_distill_llama_8b( - agent: AgentConfig, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = deepseek_r1_distill_llama_8b(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs b/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs deleted file mode 100644 index bc2762d5..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_gemma_4.rs +++ /dev/null @@ -1,36 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_gemma_4(agent: AgentConfig) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = gemma_4_e4b_it(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs deleted file mode 100644 index f4297d49..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_gemma_4_and_mmproj.rs +++ /dev/null @@ -1,43 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::gemma_4_e4b_it::gemma_4_e4b_it; -use crate::model_card::gemma_4_e4b_it_mmproj::gemma_4_e4b_it_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_gemma_4_and_mmproj( - agent: AgentConfig, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = gemma_4_e4b_it(); - let ModelCard { - reference: mmproj_reference, - .. - } = gemma_4_e4b_it_mmproj(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs b/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs deleted file mode 100644 index 7d055561..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_glm_4_7_flash.rs +++ /dev/null @@ -1,38 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::glm_4_7_flash::glm_4_7_flash; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_glm_4_7_flash( - agent: AgentConfig, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = glm_4_7_flash(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs b/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs deleted file mode 100644 index 179d72fe..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_ministral_3.rs +++ /dev/null @@ -1,49 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::make_inference_parameters_deterministic::make_inference_parameters_deterministic; -use crate::ministral_3_in_process_cluster_params::Ministral3InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_ministral_3( - Ministral3InProcessClusterParams { - agent, - deterministic_sampling, - }: Ministral3InProcessClusterParams, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = ministral_3_14b_reasoning(); - - let base_inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); - let inference_parameters = if deterministic_sampling { - make_inference_parameters_deterministic(base_inference_parameters) - } else { - base_inference_parameters - }; - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters, - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs deleted file mode 100644 index 515ac248..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_ministral_3_and_mmproj.rs +++ /dev/null @@ -1,43 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::ministral_3_14b_reasoning::ministral_3_14b_reasoning; -use crate::model_card::ministral_3_14b_reasoning_mmproj::ministral_3_14b_reasoning_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_ministral_3_and_mmproj( - agent: AgentConfig, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = ministral_3_14b_reasoning(); - let ModelCard { - reference: mmproj_reference, - .. - } = ministral_3_14b_reasoning_mmproj(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs b/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs deleted file mode 100644 index cb5dc0a9..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_qwen2_5_vl.rs +++ /dev/null @@ -1,41 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen2_5_vl_3b::qwen2_5_vl_3b; -use crate::model_card::qwen2_5_vl_3b_mmproj::qwen2_5_vl_3b_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_qwen2_5_vl(agent: AgentConfig) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = qwen2_5_vl_3b(); - let ModelCard { - reference: mmproj_reference, - .. - } = qwen2_5_vl_3b_mmproj(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3.rs deleted file mode 100644 index befceb8e..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3.rs +++ /dev/null @@ -1,36 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_0_6b::qwen3_0_6b; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_qwen3(agent: AgentConfig) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs deleted file mode 100644 index 3d9189bd..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3_5.rs +++ /dev/null @@ -1,51 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_5_0_8b::qwen3_5_0_8b; -use crate::model_card::qwen3_5_0_8b_mmproj::qwen3_5_0_8b_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_qwen3_5( - agent: AgentConfig, - with_mmproj: bool, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = qwen3_5_0_8b(); - - let multimodal_projection = if with_mmproj { - let ModelCard { - reference: mmproj_reference, - .. - } = qwen3_5_0_8b_mmproj(); - - AgentDesiredModel::HuggingFace(mmproj_reference) - } else { - AgentDesiredModel::None - }; - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs deleted file mode 100644 index 7f6e765b..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3_6.rs +++ /dev/null @@ -1,36 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_qwen3_6(agent: AgentConfig) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_6_35b_a3b(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs b/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs deleted file mode 100644 index 2d5c5dad..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_qwen3_6_and_mmproj.rs +++ /dev/null @@ -1,43 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_6_35b_a3b::qwen3_6_35b_a3b; -use crate::model_card::qwen3_6_35b_a3b_mmproj::qwen3_6_35b_a3b_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_qwen3_6_and_mmproj( - agent: AgentConfig, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = qwen3_6_35b_a3b(); - let ModelCard { - reference: mmproj_reference, - .. - } = qwen3_6_35b_a3b_mmproj(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs b/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs deleted file mode 100644 index b92fa31c..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_smolvlm2.rs +++ /dev/null @@ -1,41 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::smolvlm2_256m::smolvlm2_256m; -use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_smolvlm2(agent: AgentConfig) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = smolvlm2_256m(); - let ModelCard { - reference: mmproj_reference, - .. - } = smolvlm2_256m_mmproj(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs b/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs deleted file mode 100644 index b6bbd67a..00000000 --- a/paddler_tests/src/start_in_process_cluster_with_smolvlm2_and_n_batch.rs +++ /dev/null @@ -1,47 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::smolvlm2_256m::smolvlm2_256m; -use crate::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_cluster_with_smolvlm2_and_n_batch( - agent: AgentConfig, - n_batch: usize, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = smolvlm2_256m(); - let ModelCard { - reference: mmproj_reference, - .. - } = smolvlm2_256m_mmproj(); - - let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); - inference_parameters.n_batch = n_batch; - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters, - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }, - wait_for_slots_ready: true, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_in_process_embedding_cluster.rs b/paddler_tests/src/start_in_process_embedding_cluster.rs deleted file mode 100644 index 827dbf8c..00000000 --- a/paddler_tests/src/start_in_process_embedding_cluster.rs +++ /dev/null @@ -1,31 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::in_process_cluster_params::InProcessClusterParams; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_embedding_0_6b::qwen3_embedding_0_6b; -use crate::start_in_process_cluster::start_in_process_cluster; - -pub async fn start_in_process_embedding_cluster( - inference_parameters: InferenceParameters, - agent: AgentConfig, -) -> Result { - let ModelCard { reference, .. } = qwen3_embedding_0_6b(); - - start_in_process_cluster(InProcessClusterParams { - agent: Some(agent), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters, - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - ..InProcessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_subprocess_cluster.rs b/paddler_tests/src/start_subprocess_cluster.rs deleted file mode 100644 index b676e8c7..00000000 --- a/paddler_tests/src/start_subprocess_cluster.rs +++ /dev/null @@ -1,149 +0,0 @@ -use std::process::Stdio; - -use anyhow::Context as _; -use anyhow::Result; -use paddler_client::PaddlerClient; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; -use tokio::process::Child; -use tokio_util::sync::CancellationToken; - -use crate::agents_stream_watcher::AgentsStreamWatcher; -use crate::balancer_addresses::BalancerAddresses; -use crate::buffered_requests_stream_watcher::BufferedRequestsStreamWatcher; -use crate::cluster_completion::ClusterCompletion; -use crate::cluster_handle::ClusterHandle; -use crate::cluster_handle_params::ClusterHandleParams; -use crate::paddler_command::paddler_command; -use crate::subprocess_cluster_params::SubprocessClusterParams; -use crate::wait_until_healthy::wait_until_healthy; - -pub async fn start_subprocess_cluster( - SubprocessClusterParams { - agents, - buffered_request_timeout, - desired_state, - inference_cors_allowed_hosts, - inference_item_timeout, - management_cors_allowed_hosts, - max_buffered_requests, - state_database_url, - wait_for_slots_ready, - }: SubprocessClusterParams, -) -> Result { - let addresses = BalancerAddresses::pick()?; - - let mut balancer_command = paddler_command(); - - balancer_command - .arg("balancer") - .arg("--inference-addr") - .arg(addresses.inference.to_string()) - .arg("--management-addr") - .arg(addresses.management.to_string()) - .arg("--compat-openai-addr") - .arg(addresses.compat_openai.to_string()) - .arg("--state-database") - .arg(&state_database_url) - .arg("--max-buffered-requests") - .arg(max_buffered_requests.to_string()) - .arg("--buffered-request-timeout") - .arg(buffered_request_timeout.as_millis().to_string()) - .arg("--inference-item-timeout") - .arg(inference_item_timeout.as_millis().to_string()) - .stdout(Stdio::null()) - .stderr(Stdio::null()); - - for allowed_host in &inference_cors_allowed_hosts { - balancer_command - .arg("--inference-cors-allowed-host") - .arg(allowed_host); - } - - for allowed_host in &management_cors_allowed_hosts { - balancer_command - .arg("--management-cors-allowed-host") - .arg(allowed_host); - } - - let balancer = balancer_command - .spawn() - .context("failed to spawn paddler balancer subprocess")?; - - let management_base_url = addresses.management_base_url()?; - let inference_base_url = addresses.inference_base_url()?; - - wait_until_healthy(&management_base_url, "health") - .await - .context("subprocess balancer did not become healthy")?; - - let paddler_client = PaddlerClient::new(inference_base_url, management_base_url, 1); - - if let Some(desired_state) = desired_state.as_ref() { - paddler_client - .management() - .put_balancer_desired_state(desired_state) - .await - .map_err(anyhow::Error::new) - .context("failed to PUT desired state on subprocess balancer")?; - } - - let mut agents_watcher = AgentsStreamWatcher::connect(&paddler_client.management()).await?; - let buffered_requests_watcher = - BufferedRequestsStreamWatcher::connect(&paddler_client.management()).await?; - - let expected_agent_count = agents.len(); - let mut agent_children: Vec = Vec::with_capacity(expected_agent_count); - let mut last_ready_snapshot: Option = None; - - for agent in &agents { - let agent_child = paddler_command() - .arg("agent") - .arg("--management-addr") - .arg(addresses.management.to_string()) - .arg("--name") - .arg(&agent.name) - .arg("--slots") - .arg(agent.slot_count.to_string()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .context("failed to spawn paddler agent subprocess")?; - - agent_children.push(agent_child); - - if wait_for_slots_ready { - last_ready_snapshot = Some( - agents_watcher - .wait_for_agent_ready(&agent.name, agent.slot_count) - .await?, - ); - } - } - - let registered_snapshot = match last_ready_snapshot { - Some(snapshot) => snapshot, - None => agents_watcher - .until(move |snapshot| snapshot.agents.len() >= expected_agent_count) - .await - .context("not all subprocess agents registered")?, - }; - - let agent_ids: Vec = registered_snapshot - .agents - .iter() - .map(|registered_agent| registered_agent.id.clone()) - .collect(); - - Ok(ClusterHandle::new(ClusterHandleParams { - addresses, - agent_ids, - agents: agents_watcher, - buffered_requests: buffered_requests_watcher, - cancel_token: CancellationToken::new(), - completion: ClusterCompletion::Subprocess { - agents: agent_children, - balancer, - }, - paddler_client, - })) -} diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs b/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs deleted file mode 100644 index c898cc8e..00000000 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen2_5_vl.rs +++ /dev/null @@ -1,43 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::model_card::ModelCard; -use crate::model_card::qwen2_5_vl_3b::qwen2_5_vl_3b; -use crate::model_card::qwen2_5_vl_3b_mmproj::qwen2_5_vl_3b_mmproj; -use crate::start_subprocess_cluster::start_subprocess_cluster; -use crate::subprocess_cluster_params::SubprocessClusterParams; - -pub async fn start_subprocess_cluster_with_qwen2_5_vl( - agents: Vec, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = qwen2_5_vl_3b(); - let ModelCard { - reference: mmproj_reference, - .. - } = qwen2_5_vl_3b_mmproj(); - - start_subprocess_cluster(SubprocessClusterParams { - agents, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }), - wait_for_slots_ready: true, - ..SubprocessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs b/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs deleted file mode 100644 index 36db2eff..00000000 --- a/paddler_tests/src/start_subprocess_cluster_with_qwen3.rs +++ /dev/null @@ -1,38 +0,0 @@ -use anyhow::Result; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; - -use crate::agent_config::AgentConfig; -use crate::cluster_handle::ClusterHandle; -use crate::current_test_device::current_test_device; -use crate::model_card::ModelCard; -use crate::model_card::qwen3_0_6b::qwen3_0_6b; -use crate::start_subprocess_cluster::start_subprocess_cluster; -use crate::subprocess_cluster_params::SubprocessClusterParams; - -pub async fn start_subprocess_cluster_with_qwen3( - agents: Vec, -) -> Result { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - start_subprocess_cluster(SubprocessClusterParams { - agents, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }), - wait_for_slots_ready: true, - ..SubprocessClusterParams::default() - }) - .await -} diff --git a/paddler_tests/src/subprocess_cluster_lifecycle_in_dedicated_runtime.rs b/paddler_tests/src/subprocess_cluster_lifecycle_in_dedicated_runtime.rs deleted file mode 100644 index 1f65461c..00000000 --- a/paddler_tests/src/subprocess_cluster_lifecycle_in_dedicated_runtime.rs +++ /dev/null @@ -1,21 +0,0 @@ -use anyhow::Result; -use tokio::runtime::Builder; - -use crate::agent_config::AgentConfig; -use crate::start_subprocess_cluster::start_subprocess_cluster; -use crate::subprocess_cluster_params::SubprocessClusterParams; - -pub fn subprocess_cluster_lifecycle_in_dedicated_runtime() -> Result<()> { - let runtime = Builder::new_multi_thread().enable_all().build()?; - - runtime.block_on(async { - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: AgentConfig::uniform(1, 4), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - cluster.shutdown().await - }) -} diff --git a/paddler_tests/src/test_device.rs b/paddler_tests/src/test_device.rs deleted file mode 100644 index 4468ffae..00000000 --- a/paddler_tests/src/test_device.rs +++ /dev/null @@ -1,95 +0,0 @@ -use anyhow::Result; -#[cfg(any(feature = "cuda", feature = "metal"))] -use anyhow::bail; -#[cfg(any(feature = "cuda", feature = "metal"))] -use llama_cpp_bindings::llama_backend::LlamaBackend; -#[cfg(any(feature = "cuda", feature = "metal"))] -use llama_cpp_bindings::llama_backend_device::LlamaBackendDeviceType; -#[cfg(any(feature = "cuda", feature = "metal"))] -use llama_cpp_bindings::llama_backend_device::list_llama_ggml_backend_devices; -use paddler_types::inference_parameters::InferenceParameters; - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum TestDevice { - Cpu, - #[cfg(feature = "cuda")] - Cuda, - #[cfg(feature = "metal")] - Metal, -} - -impl TestDevice { - #[must_use] - pub const fn name(self) -> &'static str { - match self { - Self::Cpu => "cpu", - #[cfg(feature = "cuda")] - Self::Cuda => "cuda", - #[cfg(feature = "metal")] - Self::Metal => "metal", - } - } - - #[cfg_attr( - not(any(feature = "cuda", feature = "metal")), - expect( - clippy::missing_const_for_fn, - reason = "non-const branches appear under GPU feature flags" - ) - )] - pub fn require_available(self) -> Result<()> { - match self { - Self::Cpu => Ok(()), - #[cfg(feature = "cuda")] - Self::Cuda => require_backend_device("CUDA"), - #[cfg(feature = "metal")] - Self::Metal => require_backend_device("MTL"), - } - } - - #[must_use] - pub fn inference_parameters_for_full_offload( - self, - gpu_layer_count: u32, - ) -> InferenceParameters { - #[cfg(not(any(feature = "cuda", feature = "metal")))] - let _ = gpu_layer_count; - - match self { - Self::Cpu => InferenceParameters::default(), - #[cfg(feature = "cuda")] - Self::Cuda => InferenceParameters { - n_gpu_layers: gpu_layer_count, - ..InferenceParameters::default() - }, - #[cfg(feature = "metal")] - Self::Metal => InferenceParameters { - n_gpu_layers: gpu_layer_count, - ..InferenceParameters::default() - }, - } - } -} - -#[cfg(any(feature = "cuda", feature = "metal"))] -fn require_backend_device(backend_name: &str) -> Result<()> { - let backend = LlamaBackend::init()?; - - if !backend.supports_gpu_offload() { - bail!( - "binary built without GPU offload support; rebuild with --features cuda or --features metal" - ); - } - - drop(backend); - - let devices_found = list_llama_ggml_backend_devices().into_iter().any(|device| { - device.backend == backend_name && device.device_type == LlamaBackendDeviceType::Gpu - }); - - if devices_found { - Ok(()) - } else { - bail!("no {backend_name} GPU devices detected at runtime") - } -} diff --git a/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs b/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs index aed646c8..68351e60 100644 --- a/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs +++ b/paddler_tests/tests/agent_chunks_embedding_batch_larger_than_slot_count.rs @@ -3,31 +3,27 @@ use std::collections::BTreeSet; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_chunks_embedding_batch_larger_than_slot_count() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(4)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(4), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let input_batch: Vec = (0..12) .map(|index| EmbeddingInputDocument { content: format!("Document number {index}."), @@ -35,15 +31,13 @@ async fn agent_chunks_embedding_batch_larger_than_slot_count() -> Result<()> { }) .collect(); - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch, normalization_method: EmbeddingNormalizationMethod::None, }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 12); assert!(collected.saw_done); assert!(collected.errors.is_empty()); diff --git a/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs b/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs index abfddfe9..5f2ecdaa 100644 --- a/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs +++ b/paddler_tests/tests/agent_completes_generation_with_adequate_n_batch.rs @@ -1,22 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use std::fs; +use std::future::Future; use anyhow::Context as _; use anyhow::Result; use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_smolvlm2::start_in_process_cluster_with_smolvlm2; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_tests::start_cluster_with_smolvlm2::start_cluster_with_smolvlm2; fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result { let fixture_path = format!("{}/../fixtures/{fixture_name}", env!("CARGO_MANIFEST_DIR")); @@ -27,15 +26,16 @@ fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result Result<()> { +) -> Result> + Send + use<>> { let image_data_uri = load_fixture_as_data_uri(fixture_name, mime_type)?; + let fixture_name = fixture_name.to_owned(); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let generation = + cluster.continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -55,38 +55,37 @@ async fn drive_normal_image_fixture( max_tokens: 20, parse_tool_calls: false, tools: vec![], - }) - .await?; + }); - let collected = collect_generated_tokens(stream).await?; + Ok(async move { + let collected = generation.await?; - let saw_token = collected - .token_results - .iter() - .any(|result| result.token_result.is_token()); - - assert!( - saw_token, - "fixture {fixture_name} should produce at least one content/reasoning/tool-call token with adequate n_batch; got {:?}", - collected + let saw_token = collected .token_results .iter() - .map(|result| &result.token_result) - .collect::>(), - ); + .any(|result| result.token_result.is_token()); - Ok(()) + assert!( + saw_token, + "fixture {fixture_name} should produce at least one content/reasoning/tool-call token with adequate n_batch; got {:?}", + collected + .token_results + .iter() + .map(|result| &result.token_result) + .collect::>(), + ); + + Ok(()) + }) } #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_completes_generation_with_adequate_n_batch() -> Result<()> { - let cluster = start_in_process_cluster_with_smolvlm2(AgentConfig::single(1)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_smolvlm2(vec![AgentConfig::single(1)]).await?; - drive_normal_image_fixture(&inference_client, "sarnow.jpeg", "image/jpeg").await?; - drive_normal_image_fixture(&inference_client, "llamas.webp", "image/webp").await?; + drive_normal_image_fixture(&cluster, "sarnow.jpeg", "image/jpeg")?.await?; + drive_normal_image_fixture(&cluster, "llamas.webp", "image/webp")?.await?; cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_controller_applies_newer_status_snapshot_with_model_path_change.rs b/paddler_tests/tests/agent_controller_applies_newer_status_snapshot_with_model_path_change.rs index f1d83ef3..ef388423 100644 --- a/paddler_tests/tests/agent_controller_applies_newer_status_snapshot_with_model_path_change.rs +++ b/paddler_tests/tests/agent_controller_applies_newer_status_snapshot_with_model_path_change.rs @@ -1,6 +1,6 @@ -use paddler::balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; #[test] fn agent_controller_applies_newer_status_snapshot_with_model_path_change() { diff --git a/paddler_tests/tests/agent_controller_discards_status_snapshot_with_older_version.rs b/paddler_tests/tests/agent_controller_discards_status_snapshot_with_older_version.rs index 0a78fa16..1c7d86c5 100644 --- a/paddler_tests/tests/agent_controller_discards_status_snapshot_with_older_version.rs +++ b/paddler_tests/tests/agent_controller_discards_status_snapshot_with_older_version.rs @@ -1,6 +1,6 @@ -use paddler::balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; #[test] fn agent_controller_discards_status_snapshot_with_older_version() { diff --git a/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs b/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs index e9609871..cc387d25 100644 --- a/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs +++ b/paddler_tests/tests/agent_controller_pool_distributes_concurrent_dispatch_evenly_across_idle_agents.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::agent_controller_pool::AgentControllerPool; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; use tokio::sync::Barrier; diff --git a/paddler_tests/tests/agent_controller_pool_does_not_oversubscribe_under_concurrent_dispatch.rs b/paddler_tests/tests/agent_controller_pool_does_not_oversubscribe_under_concurrent_dispatch.rs index 9689c583..bcada9b0 100644 --- a/paddler_tests/tests/agent_controller_pool_does_not_oversubscribe_under_concurrent_dispatch.rs +++ b/paddler_tests/tests/agent_controller_pool_does_not_oversubscribe_under_concurrent_dispatch.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::agent_controller_pool::AgentControllerPool; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; use tokio::sync::Barrier; diff --git a/paddler_tests/tests/agent_controller_pool_notifies_subscribers_when_slot_guard_drops.rs b/paddler_tests/tests/agent_controller_pool_notifies_subscribers_when_slot_guard_drops.rs index b4204622..a15bd2d8 100644 --- a/paddler_tests/tests/agent_controller_pool_notifies_subscribers_when_slot_guard_drops.rs +++ b/paddler_tests/tests/agent_controller_pool_notifies_subscribers_when_slot_guard_drops.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; -use paddler::balancer::agent_controller_pool::AgentControllerPool; -use paddler::subscribes_to_updates::SubscribesToUpdates as _; +use paddler_balancer::agent_controller_pool::AgentControllerPool; +use paddler_messaging::subscribes_to_updates::SubscribesToUpdates as _; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; #[test] diff --git a/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs b/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs index 380a2805..a7932519 100644 --- a/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs +++ b/paddler_tests/tests/agent_controller_pool_re_selects_after_contended_claim.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use anyhow::Result; use anyhow::anyhow; -use paddler::balancer::agent_controller_pool::AgentControllerPool; +use paddler_balancer::agent_controller_pool::AgentControllerPool; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] diff --git a/paddler_tests/tests/agent_controller_slot_guard_decrements_slots_processing_on_drop.rs b/paddler_tests/tests/agent_controller_slot_guard_decrements_slots_processing_on_drop.rs index a2385d44..5cf15bc7 100644 --- a/paddler_tests/tests/agent_controller_slot_guard_decrements_slots_processing_on_drop.rs +++ b/paddler_tests/tests/agent_controller_slot_guard_decrements_slots_processing_on_drop.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use paddler::balancer::agent_controller_slot_guard::AgentControllerSlotGuard; +use paddler_balancer::agent_controller_slot_guard::AgentControllerSlotGuard; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; use tokio::sync::watch; diff --git a/paddler_tests/tests/agent_controller_status_snapshot_does_not_clobber_local_slots_processing.rs b/paddler_tests/tests/agent_controller_status_snapshot_does_not_clobber_local_slots_processing.rs index 31358f05..520fccdc 100644 --- a/paddler_tests/tests/agent_controller_status_snapshot_does_not_clobber_local_slots_processing.rs +++ b/paddler_tests/tests/agent_controller_status_snapshot_does_not_clobber_local_slots_processing.rs @@ -1,5 +1,5 @@ +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; #[test] fn agent_controller_status_snapshot_does_not_clobber_local_slots_processing() { diff --git a/paddler_tests/tests/agent_controller_status_snapshot_with_unchanged_values_reports_no_meaningful_changes.rs b/paddler_tests/tests/agent_controller_status_snapshot_with_unchanged_values_reports_no_meaningful_changes.rs index c0da6bbf..79eb4785 100644 --- a/paddler_tests/tests/agent_controller_status_snapshot_with_unchanged_values_reports_no_meaningful_changes.rs +++ b/paddler_tests/tests/agent_controller_status_snapshot_with_unchanged_values_reports_no_meaningful_changes.rs @@ -1,6 +1,6 @@ -use paddler::balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_balancer::agent_controller_update_result::AgentControllerUpdateResult; +use paddler_messaging::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot; #[test] fn agent_controller_status_snapshot_with_unchanged_values_reports_no_meaningful_changes() { diff --git a/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs b/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs index fe4ebffc..b6c814ce 100644 --- a/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs +++ b/paddler_tests/tests/agent_conversation_accepts_empty_tools_list.rs @@ -1,29 +1,20 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_accepts_empty_tools_list() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello".to_owned()), @@ -37,8 +28,6 @@ async fn agent_conversation_accepts_empty_tools_list() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs b/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs index a77054b9..4c6b180f 100644 --- a/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs +++ b/paddler_tests/tests/agent_conversation_history_respects_max_tokens.rs @@ -1,29 +1,20 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_history_respects_max_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello".to_owned()), @@ -37,8 +28,6 @@ async fn agent_conversation_history_respects_max_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs b/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs index f2cbf393..2a257a54 100644 --- a/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs +++ b/paddler_tests/tests/agent_conversation_with_function_tool_succeeds.rs @@ -1,33 +1,24 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_function_tool_succeeds() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let mut location_properties = Map::new(); @@ -36,8 +27,8 @@ async fn agent_conversation_with_function_tool_succeeds() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello".to_owned()), @@ -62,8 +53,6 @@ async fn agent_conversation_with_function_tool_succeeds() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - assert!( !collected.token_results.is_empty(), "should receive a response when a function tool is provided" diff --git a/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs b/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs index d38d1f83..d9cc4287 100644 --- a/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs +++ b/paddler_tests/tests/agent_conversation_with_gbnf_grammar_constrains_output.rs @@ -1,30 +1,21 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_gbnf_grammar_constrains_output() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -43,8 +34,6 @@ async fn agent_conversation_with_gbnf_grammar_constrains_output() -> Result<()> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let lower = collected.text.to_lowercase(); assert!( diff --git a/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs b/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs index 727eabcc..9a268987 100644 --- a/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs +++ b/paddler_tests/tests/agent_conversation_with_json_schema_grammar_returns_valid_json.rs @@ -1,30 +1,21 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_with_json_schema_grammar_returns_valid_json() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is 2+2?".to_owned()), @@ -40,8 +31,6 @@ async fn agent_conversation_with_json_schema_grammar_returns_valid_json() -> Res }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed: serde_json::Value = serde_json::from_str(&collected.text)?; assert!( diff --git a/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs b/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs index 4bc10a38..05ffe7e8 100644 --- a/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs +++ b/paddler_tests/tests/agent_conversation_without_grammar_field_succeeds.rs @@ -1,20 +1,18 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_conversation_without_grammar_field_succeeds() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_url = cluster + .balancer .addresses .inference_base_url()? .join("api/v1/continue_from_conversation_history")?; diff --git a/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs b/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs index eb1daf1a..c89ae6f0 100644 --- a/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs +++ b/paddler_tests/tests/agent_does_not_crash_on_oversized_image.rs @@ -1,23 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use std::fs; +use std::future::Future; use anyhow::Context as _; use anyhow::Result; use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_smolvlm2_and_n_batch::start_in_process_cluster_with_smolvlm2_and_n_batch; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_tests::start_cluster_with_smolvlm2_and_n_batch::start_cluster_with_smolvlm2_and_n_batch; fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result { let fixture_path = format!("{}/../fixtures/{fixture_name}", env!("CARGO_MANIFEST_DIR")); @@ -28,15 +27,16 @@ fn load_fixture_as_data_uri(fixture_name: &str, mime_type: &str) -> Result Result<()> { +) -> Result> + Send + use<>> { let image_data_uri = load_fixture_as_data_uri(fixture_name, mime_type)?; + let fixture_name = fixture_name.to_owned(); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let generation = + cluster.continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -56,41 +56,39 @@ async fn drive_oversized_image_fixture( max_tokens: 20, parse_tool_calls: false, tools: vec![], - }) - .await?; + }); - let collected = collect_generated_tokens(stream).await?; + Ok(async move { + let collected = generation.await?; - let saw_oversized = collected.token_results.iter().any(|result| { - matches!( - result.token_result, - GeneratedTokenResult::ImageExceedsBatchSize(_), - ) - }); + let saw_oversized = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageExceedsBatchSize(_), + ) + }); - assert!( - saw_oversized, - "fixture {fixture_name} must produce GeneratedTokenResult::ImageExceedsBatchSize when n_batch < image tokens; got {:?}", - collected - .token_results - .iter() - .map(|result| &result.token_result) - .collect::>(), - ); + assert!( + saw_oversized, + "fixture {fixture_name} must produce GeneratedTokenResult::ImageExceedsBatchSize when n_batch < image tokens; got {:?}", + collected + .token_results + .iter() + .map(|result| &result.token_result) + .collect::>(), + ); - Ok(()) + Ok(()) + }) } #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_does_not_crash_on_oversized_image() -> Result<()> { - let cluster = - start_in_process_cluster_with_smolvlm2_and_n_batch(AgentConfig::single(1), 32).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_smolvlm2_and_n_batch(vec![AgentConfig::single(1)], 32).await?; - drive_oversized_image_fixture(&inference_client, "sarnow.jpeg", "image/jpeg").await?; - drive_oversized_image_fixture(&inference_client, "llamas.webp", "image/webp").await?; + drive_oversized_image_fixture(&cluster, "sarnow.jpeg", "image/jpeg")?.await?; + drive_oversized_image_fixture(&cluster, "llamas.webp", "image/webp")?.await?; cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs b/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs index dbefb692..bf41dc96 100644 --- a/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs +++ b/paddler_tests/tests/agent_embedding_batch_distribution_independent_of_context_size.rs @@ -3,35 +3,31 @@ use std::collections::BTreeSet; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_embedding_batch_distribution_independent_of_context_size() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(4)], + inference_parameters: InferenceParameters { n_batch: 64, context_size: 512, enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(4), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: "This is the first document with enough content to contribute meaningfully to the batch size calculation".to_owned(), @@ -54,8 +50,6 @@ async fn agent_embedding_batch_distribution_independent_of_context_size() -> Res }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 4); assert!(collected.saw_done); assert!(collected.errors.is_empty()); diff --git a/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs b/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs index d09b1842..526e7484 100644 --- a/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs +++ b/paddler_tests/tests/agent_embedding_batch_returns_one_embedding_per_input_document.rs @@ -3,33 +3,29 @@ use std::collections::BTreeSet; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_embedding_batch_returns_one_embedding_per_input_document() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: "The quick brown fox jumps over the lazy dog".to_owned(), @@ -44,8 +40,6 @@ async fn agent_embedding_batch_returns_one_embedding_per_input_document() -> Res }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 2); assert!(collected.saw_done); assert!(collected.errors.is_empty()); diff --git a/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs b/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs index 96c04184..2ef3f0b9 100644 --- a/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs +++ b/paddler_tests/tests/agent_embedding_batch_with_all_oversized_documents_reports_error.rs @@ -1,39 +1,35 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; const N_BATCH: u32 = 64; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_embedding_batch_with_all_oversized_documents_reports_error() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { n_batch: N_BATCH as usize, context_size: 4096, enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let huge_content = "The quick brown fox jumps over the lazy dog. ".repeat(40); - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: huge_content.clone(), @@ -48,8 +44,6 @@ async fn agent_embedding_batch_with_all_oversized_documents_reports_error() -> R }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!( collected.embeddings.len(), 0, diff --git a/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs b/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs index 4c238f4a..ac34c7af 100644 --- a/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs +++ b/paddler_tests/tests/agent_embedding_document_exceeds_n_batch.rs @@ -1,39 +1,35 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; const N_BATCH: u32 = 64; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_embedding_document_exceeds_n_batch() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { n_batch: N_BATCH as usize, context_size: 4096, enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let huge_content = "The quick brown fox jumps over the lazy dog. ".repeat(40); - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: "ok".to_owned(), @@ -48,8 +44,6 @@ async fn agent_embedding_document_exceeds_n_batch() -> Result<()> { }) .await?; - let collected = collect_embedding_results(stream).await?; - assert!( collected.saw_done, "stream must terminate with Done even when one document is oversized", diff --git a/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs b/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs index d9eae9d6..c189d096 100644 --- a/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs +++ b/paddler_tests/tests/agent_embeddings_share_dimension_across_inputs_of_varying_length.rs @@ -1,33 +1,29 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_embeddings_share_dimension_across_inputs_of_varying_length() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: "Hello".to_owned(), @@ -46,8 +42,6 @@ async fn agent_embeddings_share_dimension_across_inputs_of_varying_length() -> R }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 3); assert!(collected.saw_done); diff --git a/paddler_tests/tests/agent_evicts_largest_sequence_under_kv_cache_pressure.rs b/paddler_tests/tests/agent_evicts_largest_sequence_under_kv_cache_pressure.rs new file mode 100644 index 00000000..072fb6ef --- /dev/null +++ b/paddler_tests/tests/agent_evicts_largest_sequence_under_kv_cache_pressure.rs @@ -0,0 +1,71 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::model_card::ModelCard; +use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; +use paddler_tests::start_cluster::start_cluster; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_evicts_largest_sequence_under_kv_cache_pressure() -> Result<()> { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + let inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + n_batch: 256, + context_size: 256, + temperature: 0.0, + ..InferenceParameters::default() + }; + + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }], + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await?; + + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 4096, + raw_prompt: "Write an exhaustive, never-ending encyclopedia entry that lists every fact about the natural world in extreme detail:".to_owned(), + }) + .await?; + + let evicted = collected.token_results.iter().any(|result| { + matches!( + &result.token_result, + GeneratedTokenResult::SamplerError(message) if message.contains("evicted") + ) + }); + + assert!( + evicted, + "the sole sequence must be evicted once its KV cache footprint exceeds the context size" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs b/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs deleted file mode 100644 index 5ad7e18c..00000000 --- a/paddler_tests/tests/agent_exits_cleanly_on_sigterm_during_multimodal_inference.rs +++ /dev/null @@ -1,137 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use anyhow::Context as _; -use anyhow::Result; -use futures_util::StreamExt as _; -use paddler_tests::agents_status::assert_agent_count::assert_agent_count; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen2_5_vl_3b::qwen2_5_vl_3b; -use paddler_tests::model_card::qwen2_5_vl_3b_mmproj::qwen2_5_vl_3b_mmproj; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_tests::terminate_child::terminate_child; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn agent_exits_cleanly_on_sigterm_during_multimodal_inference() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference: primary_reference, - } = qwen2_5_vl_3b(); - let ModelCard { - reference: mmproj_reference, - .. - } = qwen2_5_vl_3b_mmproj(); - - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: Vec::new(), - wait_for_slots_ready: false, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(primary_reference), - multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference), - use_chat_template_override: false, - }), - ..SubprocessClusterParams::default() - }) - .await?; - - let mut agent_child = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("multimodal-shutdown-agent".to_owned()), - slots: 2, - })?; - - let snapshot = cluster - .agents - .until(|snapshot| { - snapshot.agents.len() == 1 && snapshot.agents.iter().any(|agent| agent.slots_total >= 2) - }) - .await - .context("agent should register with slots ready")?; - - let agent_id = snapshot - .agents - .first() - .context("registered agent must be present in snapshot")? - .id - .clone(); - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let image_data_uri = load_test_image_data_uri()?; - - let mut stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new(vec![ConversationMessage { - content: ConversationMessageContent::Parts(vec![ - ConversationMessageContentPart::ImageUrl { - image_url: ImageUrl { - url: image_data_uri, - }, - }, - ConversationMessageContentPart::Text { - text: "Describe this image in detail".to_owned(), - }, - ]), - role: "user".to_owned(), - }]), - enable_thinking: false, - grammar: None, - max_tokens: 200, - parse_tool_calls: false, - tools: vec![], - }) - .await?; - - let _first_message = stream - .next() - .await - .ok_or_else(|| anyhow::anyhow!("multimodal stream must yield at least one message"))?; - - cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) - .await?; - - terminate_child(&mut agent_child)?; - let exit_status = agent_child.wait().await?; - - cluster.agents.until(assert_agent_count(0)).await?; - - drop(stream); - - cluster.shutdown().await?; - - assert!( - exit_status.code().is_some() || exit_status.success(), - "agent must exit cleanly during multimodal inference; got {exit_status:?}" - ); - - Ok(()) -} diff --git a/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs b/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs deleted file mode 100644 index 15a341da..00000000 --- a/paddler_tests/tests/agent_exits_cleanly_when_killed_during_generation.rs +++ /dev/null @@ -1,108 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use anyhow::Context as _; -use anyhow::Result; -use futures_util::StreamExt as _; -use paddler_tests::agents_status::assert_agent_count::assert_agent_count; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_tests::terminate_child::terminate_child; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn agent_exits_cleanly_when_killed_during_generation() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: Vec::new(), - wait_for_slots_ready: false, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }), - ..SubprocessClusterParams::default() - }) - .await?; - - let mut agent_child = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("graceful-shutdown-agent".to_owned()), - slots: 2, - })?; - - let snapshot = cluster - .agents - .until(|snapshot| { - snapshot.agents.len() == 1 && snapshot.agents.iter().any(|agent| agent.slots_total >= 2) - }) - .await - .context("agent must register with slots before generation starts")?; - - let agent_id = snapshot - .agents - .first() - .context("snapshot must contain registered agent")? - .id - .clone(); - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 1000, - raw_prompt: "Write a long story".to_owned(), - }) - .await?; - - let _first_message = stream - .next() - .await - .ok_or_else(|| anyhow::anyhow!("stream must yield at least one message"))?; - - cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) - .await?; - - terminate_child(&mut agent_child)?; - let exit_status = agent_child.wait().await?; - - cluster.agents.until(assert_agent_count(0)).await?; - - drop(stream); - - cluster.shutdown().await?; - - assert!( - exit_status.code().is_some() || exit_status.success(), - "agent must exit cleanly (no abnormal termination); got {exit_status:?}" - ); - - Ok(()) -} diff --git a/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs b/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs index 649c38f6..f8df30db 100644 --- a/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs +++ b/paddler_tests/tests/agent_grammar_with_thinking_returns_incompatible_error.rs @@ -1,31 +1,22 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_grammar_with_thinking_returns_incompatible_error() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is 2+2?".to_owned()), @@ -41,18 +32,14 @@ async fn agent_grammar_with_thinking_returns_incompatible_error() -> Result<()> }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - assert!( - collected.token_results.iter().any(|result| matches!( - result.token_result, - GeneratedTokenResult::GrammarIncompatibleWithThinking(_) - )), - "expected GrammarIncompatibleWithThinking error" - ); - } + if let Ok(collected) = outcome { + assert!( + collected.token_results.iter().any(|result| matches!( + result.token_result, + GeneratedTokenResult::GrammarIncompatibleWithThinking(_) + )), + "expected GrammarIncompatibleWithThinking error" + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs b/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs index fe6094d2..cebb40d8 100644 --- a/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs +++ b/paddler_tests/tests/agent_isolates_concurrent_embedding_requests_per_client.rs @@ -3,15 +3,13 @@ use std::collections::BTreeSet; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -19,41 +17,28 @@ async fn agent_isolates_concurrent_embedding_requests_per_client() -> Result<()> let client_count: usize = 4; let docs_per_client: usize = 3; - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(4)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(4), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_base_url = cluster.addresses.inference_base_url()?; - let client_tasks = (0..client_count).map(|client_index| { - let inference_base_url = inference_base_url.clone(); - - async move { - let inference_client = InferenceHttpClient::new(Client::new(), inference_base_url); - - let input_batch: Vec = (0..docs_per_client) - .map(|document_index| EmbeddingInputDocument { - content: format!( - "Content from client {client_index} document {document_index}." - ), - id: format!("client-{client_index}-doc-{document_index}"), - }) - .collect(); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }) - .await?; + let input_batch: Vec = (0..docs_per_client) + .map(|document_index| EmbeddingInputDocument { + content: format!("Content from client {client_index} document {document_index}."), + id: format!("client-{client_index}-doc-{document_index}"), + }) + .collect(); - collect_embedding_results(stream).await - } + cluster.generate_embedding_batch(&GenerateEmbeddingBatchParams { + input_batch, + normalization_method: EmbeddingNormalizationMethod::None, + }) }); let per_client_results = futures_util::future::join_all(client_tasks).await; diff --git a/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs b/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs index 153ee20d..ccf2ce93 100644 --- a/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs +++ b/paddler_tests/tests/agent_l2_normalized_embeddings_have_unit_norm.rs @@ -1,33 +1,29 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_l2_normalized_embeddings_have_unit_norm() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![EmbeddingInputDocument { content: "Testing L2 normalization on embeddings".to_owned(), id: "doc-l2".to_owned(), @@ -36,8 +32,6 @@ async fn agent_l2_normalized_embeddings_have_unit_norm() -> Result<()> { }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); diff --git a/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs b/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs index ece35f61..7abae77a 100644 --- a/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs +++ b/paddler_tests/tests/agent_openai_chat_completions_non_streaming_returns_text.rs @@ -1,20 +1,18 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_openai_chat_completions_non_streaming_returns_text() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let openai_url = cluster + .balancer .addresses .compat_openai_base_url()? .join("v1/chat/completions")?; @@ -26,7 +24,6 @@ async fn agent_openai_chat_completions_non_streaming_returns_text() -> Result<() "messages": [{"role": "user", "content": "Say hello"}], "max_completion_tokens": 200, "stream": false, - "chat_template_kwargs": {"enable_thinking": false}, })) .send() .await diff --git a/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs b/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs index a912ea54..a39bcf43 100644 --- a/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs +++ b/paddler_tests/tests/agent_openai_chat_completions_streaming_returns_chunks.rs @@ -1,20 +1,18 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_openai_chat_completions_streaming_returns_chunks() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let openai_url = cluster + .balancer .addresses .compat_openai_base_url()? .join("v1/chat/completions")?; @@ -37,16 +35,17 @@ async fn agent_openai_chat_completions_streaming_returns_chunks() -> Result<()> let chunks: Vec = body .lines() - .filter(|line| !line.trim().is_empty()) - .map(|line| { - let stripped = line.strip_prefix("data: ").unwrap_or(line); - - serde_json::from_str(stripped).context("each chunk should be valid JSON") - }) + .filter_map(|line| line.strip_prefix("data: ")) + .filter(|payload| *payload != "[DONE]") + .map(|payload| serde_json::from_str(payload).context("each chunk should be valid JSON")) .collect::>()?; assert!(!chunks.is_empty(), "should have received streaming chunks"); assert_eq!(chunks[0]["object"], "chat.completion.chunk"); + assert!( + body.trim_end().ends_with("data: [DONE]"), + "the OpenAI streaming response must terminate with the [DONE] sentinel: {body:?}" + ); cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs b/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs index 714d81c3..7f6f72aa 100644 --- a/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs +++ b/paddler_tests/tests/agent_pipeline_recognizes_duck_typed_tool_call_format_when_template_is_not_registered.rs @@ -8,17 +8,17 @@ use llama_cpp_bindings::ToolCallArguments; use llama_cpp_bindings::llama_backend::LlamaBackend; use llama_cpp_bindings::model::LlamaModel; use llama_cpp_bindings::model::params::LlamaModelParams; -use paddler::tool_call_event::ToolCallEvent; -use paddler::tool_call_pipeline::ToolCallPipeline; -use paddler::tool_call_validator::ToolCallValidator; +use paddler_agent::tool_call_event::ToolCallEvent; +use paddler_agent::tool_call_pipeline::ToolCallPipeline; +use paddler_agent::tool_call_validator::ToolCallValidator; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::deepseek_r1_distill_llama_8b::deepseek_r1_distill_llama_8b; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; const QWEN_XML_PAYLOAD: &str = "\n\ diff --git a/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs b/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs index 288ccaa8..2c76b97f 100644 --- a/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs +++ b/paddler_tests/tests/agent_raw_prompt_respects_max_tokens.rs @@ -1,34 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_respects_max_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 20, raw_prompt: "The capital of France is".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs b/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs index 696a0db0..b774ef09 100644 --- a/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs +++ b/paddler_tests/tests/agent_raw_prompt_with_gbnf_grammar_constrains_output.rs @@ -1,27 +1,18 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_with_gbnf_grammar_constrains_output() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: Some(GrammarConstraint::Gbnf { grammar: r#"root ::= "yes" | "no""#.to_owned(), root: "root".to_owned(), @@ -33,8 +24,6 @@ async fn agent_raw_prompt_with_gbnf_grammar_constrains_output() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - assert!( collected.text == "yes" || collected.text == "no", "GBNF grammar should constrain output to yes/no; got {:?}", diff --git a/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs b/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs index 44611f21..4901f4d2 100644 --- a/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs +++ b/paddler_tests/tests/agent_raw_prompt_without_grammar_field_succeeds.rs @@ -1,20 +1,18 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_raw_prompt_without_grammar_field_succeeds() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let inference_url = cluster + .balancer .addresses .inference_base_url()? .join("api/v1/continue_from_raw_prompt")?; diff --git a/paddler_tests/tests/agent_rejects_structurally_invalid_tool_schema.rs b/paddler_tests/tests/agent_rejects_structurally_invalid_tool_schema.rs new file mode 100644 index 00000000..9d09176b --- /dev/null +++ b/paddler_tests/tests/agent_rejects_structurally_invalid_tool_schema.rs @@ -0,0 +1,111 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::model_card::ModelCard; +use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; +use paddler_tests::start_cluster::start_cluster; +use serde_json::Map; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_rejects_structurally_invalid_tool_schema() -> Result<()> { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }], + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + temperature: 0.0, + ..InferenceParameters::default() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await?; + + // `{"type": 123}` is a structurally well-formed JSON object (so it survives + // request-parameter validation) but is not a valid JSON Schema: the `type` + // keyword must be a string or an array of strings. `jsonschema::validator_for` + // rejects it, so the agent's tool-call pipeline build reports the tool's schema + // as invalid and the scheduler emits `ToolSchemaInvalid` before any generation. + let mut invalid_properties = Map::new(); + invalid_properties.insert("location".to_owned(), serde_json::json!({ "type": 123 })); + + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris?".to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 64, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(invalid_properties), + required: None, + additional_properties: None, + }), + }, + })], + }) + .await?; + + let schema_invalid_message = collected + .token_results + .iter() + .find_map(|event| match &event.token_result { + GeneratedTokenResult::ToolSchemaInvalid(message) => Some(message.clone()), + _ => None, + }) + .ok_or_else(|| { + anyhow::anyhow!( + "expected a ToolSchemaInvalid event when a tool's JSON Schema is invalid; got:\n{}", + collected.text + ) + })?; + + assert!( + schema_invalid_message.contains("get_weather"), + "the schema-invalid message should name the offending tool; got: {schema_invalid_message}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs b/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs index 7b4b6835..4e0175f3 100644 --- a/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs +++ b/paddler_tests/tests/agent_rejects_tool_with_invalid_required_field_in_schema.rs @@ -1,38 +1,30 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_rejects_tool_with_invalid_required_field_in_schema() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let mut name_properties = Map::new(); name_properties.insert("name".to_owned(), serde_json::json!({"type": "string"})); - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello".to_owned()), diff --git a/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs b/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs index 5cb088f0..77046ec3 100644 --- a/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs +++ b/paddler_tests/tests/agent_releases_slot_when_websocket_client_disconnects.rs @@ -1,22 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_releases_slot_when_websocket_client_disconnects() -> Result<()> { - let mut cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; + let mut cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; let agent_id = cluster .agent_ids @@ -24,11 +18,8 @@ async fn agent_releases_slot_when_websocket_client_disconnects() -> Result<()> { .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 200, raw_prompt: "Write a long story about an explorer".to_owned(), @@ -41,16 +32,14 @@ async fn agent_releases_slot_when_websocket_client_disconnects() -> Result<()> { .context("stream must yield at least one message")?; cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) + .wait_for_slots_processing(&agent_id, 1) .await .context("agent should report slot in use")?; drop(stream); cluster - .agents - .until(assert_slots_processing(&agent_id, 0)) + .wait_for_slots_processing(&agent_id, 0) .await .context("agent should release slot after the client disconnects")?; diff --git a/paddler_tests/tests/agent_reports_grammar_initialization_failure_for_invalid_gbnf.rs b/paddler_tests/tests/agent_reports_grammar_initialization_failure_for_invalid_gbnf.rs new file mode 100644 index 00000000..2807eb38 --- /dev/null +++ b/paddler_tests/tests/agent_reports_grammar_initialization_failure_for_invalid_gbnf.rs @@ -0,0 +1,55 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_reports_grammar_initialization_failure_for_invalid_gbnf() -> Result<()> { + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + + // `root ::= "unterminated` is syntactically broken GBNF (the string literal is + // never closed). The `Gbnf` constraint is passed through verbatim, so the + // malformed grammar only fails when `llama.cpp` compiles it inside + // `GrammarSampler::into_llama_sampler`, exercising the agent's + // grammar-initialization-failure path. + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: Some(GrammarConstraint::Gbnf { + grammar: r#"root ::= "unterminated"#.to_owned(), + root: "root".to_owned(), + }), + max_tokens: 10, + raw_prompt: + "<|im_start|>user\nSay hi.<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + .to_owned(), + }) + .await?; + + let failure_message = collected + .token_results + .iter() + .find_map(|event| match &event.token_result { + GeneratedTokenResult::GrammarInitializationFailed(message) => Some(message.clone()), + _ => None, + }) + .ok_or_else(|| { + anyhow::anyhow!( + "expected a GrammarInitializationFailed event for malformed GBNF; got:\n{}", + collected.text + ) + })?; + + assert!( + failure_message.contains("grammar"), + "the failure message should mention the grammar; got: {failure_message}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots.rs similarity index 61% rename from paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs rename to paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots.rs index 5e5cd051..0a36b826 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_subprocess.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots.rs @@ -1,39 +1,37 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] -async fn agent_reports_slot_cannot_start_for_excessive_slots_subprocess() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - +async fn agent_reports_slot_cannot_start_for_excessive_slots() -> Result<()> { let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); + let inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }; - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: AgentConfig::uniform(1, 257), - wait_for_slots_ready: false, + let mut cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { + name: "test-agent".to_owned(), + slot_count: 257, + }], desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters, @@ -41,13 +39,14 @@ async fn agent_reports_slot_cannot_start_for_excessive_slots_subprocess() -> Res multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + wait_for_slots_ready: false, + ..ClusterParams::default() }) .await?; let snapshot = tokio::time::timeout( Duration::from_secs(10), - cluster.agents.until(|snapshot| { + cluster.agents_watcher.until(|snapshot| { snapshot.agents.iter().any(|agent| { agent .issues diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv.rs similarity index 57% rename from paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs rename to paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv.rs index 6ea643af..ef5fbfe3 100644 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_excessive_slots_in_process.rs +++ b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv.rs @@ -1,56 +1,56 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(all(feature = "tests_that_use_llms", feature = "metal"))] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::kv_cache_dtype::KvCacheDtype; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] -async fn agent_reports_slot_cannot_start_for_excessive_slots_in_process() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - +async fn agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv() -> Result<()> { let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); + let mut inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }; + + inference_parameters.k_cache_dtype = KvCacheDtype::Q80; + inference_parameters.v_cache_dtype = KvCacheDtype::Q40; - let mut cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let mut cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), - slot_count: 257, - }), - desired_state: BalancerDesiredState { + slot_count: 1, + }], + desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; let snapshot = tokio::time::timeout( Duration::from_secs(10), - cluster.agents.until(|snapshot| { + cluster.agents_watcher.until(|snapshot| { snapshot.agents.iter().any(|agent| { agent .issues diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs deleted file mode 100644 index 9289c53d..00000000 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process.rs +++ /dev/null @@ -1,91 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms", - feature = "metal" -))] - -use std::time::Duration; - -use anyhow::Context as _; -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_tests::test_device::TestDevice; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::kv_cache_dtype::KvCacheDtype; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_in_process() -> Result<()> -{ - let device = current_test_device()?; - - if !matches!(device, TestDevice::Metal) { - return Ok(()); - } - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); - - inference_parameters.k_cache_dtype = KvCacheDtype::Q8_0; - inference_parameters.v_cache_dtype = KvCacheDtype::Q4_0; - - let mut cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { - name: "test-agent".to_owned(), - slot_count: 1, - }), - desired_state: BalancerDesiredState { - chat_template_override: None, - inference_parameters, - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }, - wait_for_slots_ready: false, - ..InProcessClusterParams::default() - }) - .await?; - - let snapshot = tokio::time::timeout( - Duration::from_secs(10), - cluster.agents.until(|snapshot| { - snapshot.agents.iter().any(|agent| { - agent - .issues - .iter() - .any(|issue| matches!(issue, AgentIssue::SlotCannotStart(_))) - }) - }), - ) - .await - .context("agent did not report SlotCannotStart within 10s")??; - - let slot_cannot_start_count = snapshot - .agents - .iter() - .flat_map(|agent| agent.issues.iter()) - .filter(|issue| matches!(issue, AgentIssue::SlotCannotStart(params) if !params.error.is_empty())) - .count(); - - assert!( - slot_cannot_start_count > 0, - "expected at least one SlotCannotStart issue with non-empty error" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs b/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs deleted file mode 100644 index cbe534ab..00000000 --- a/paddler_tests/tests/agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess.rs +++ /dev/null @@ -1,88 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms", - feature = "metal" -))] - -use std::time::Duration; - -use anyhow::Context as _; -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_tests::test_device::TestDevice; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::kv_cache_dtype::KvCacheDtype; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn agent_reports_slot_cannot_start_for_metal_quantized_distinct_kv_subprocess() -> Result<()> -{ - let device = current_test_device()?; - - if !matches!(device, TestDevice::Metal) { - return Ok(()); - } - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); - - inference_parameters.k_cache_dtype = KvCacheDtype::Q8_0; - inference_parameters.v_cache_dtype = KvCacheDtype::Q4_0; - - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: AgentConfig::uniform(1, 1), - wait_for_slots_ready: false, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters, - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }), - ..SubprocessClusterParams::default() - }) - .await?; - - let snapshot = tokio::time::timeout( - Duration::from_secs(10), - cluster.agents.until(|snapshot| { - snapshot.agents.iter().any(|agent| { - agent - .issues - .iter() - .any(|issue| matches!(issue, AgentIssue::SlotCannotStart(_))) - }) - }), - ) - .await - .context("agent did not report SlotCannotStart within 10s")??; - - let slot_cannot_start_count = snapshot - .agents - .iter() - .flat_map(|agent| agent.issues.iter()) - .filter(|issue| matches!(issue, AgentIssue::SlotCannotStart(params) if !params.error.is_empty())) - .count(); - - assert!( - slot_cannot_start_count > 0, - "expected at least one SlotCannotStart issue with non-empty error" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/agent_reports_tool_call_validation_failure.rs b/paddler_tests/tests/agent_reports_tool_call_validation_failure.rs new file mode 100644 index 00000000..4ec16766 --- /dev/null +++ b/paddler_tests/tests/agent_reports_tool_call_validation_failure.rs @@ -0,0 +1,119 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::model_card::ModelCard; +use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; +use paddler_tests::start_cluster::start_cluster; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; +use serde_json::Map; +use serde_json::Value; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn agent_reports_tool_call_validation_failure() -> Result<()> { + let ModelCard { + gpu_layer_count, + reference, + } = qwen3_0_6b(); + + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { + name: "test-agent".to_owned(), + slot_count: 1, + }], + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + temperature: 0.0, + ..InferenceParameters::default() + }, + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + wait_for_slots_ready: true, + ..ClusterParams::default() + }) + .await?; + + let mut location_properties = Map::new(); + location_properties.insert( + "location".to_owned(), + serde_json::json!({"type": "integer", "description": "The city name"}), + ); + + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text( + "What is the weather in Paris? Use the get_weather tool to find out." + .to_owned(), + ), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 400, + parse_tool_calls: true, + tools: vec![Tool::Function(FunctionCall { + function: Function { + name: "get_weather".to_owned(), + description: "Get the current weather for a location".to_owned(), + parameters: Parameters::Schema(ValidatedParametersSchema { + schema_type: "object".to_owned(), + properties: Some(location_properties), + required: Some(vec!["location".to_owned()]), + additional_properties: Some(Value::Bool(false)), + }), + }, + })], + }) + .await?; + + let validation_failures: Vec<&Vec> = collected + .token_results + .iter() + .filter_map(|event| match &event.token_result { + GeneratedTokenResult::ToolCallValidationFailed(messages) => Some(messages), + _ => None, + }) + .collect(); + + assert!( + !validation_failures.is_empty(), + "expected at least one ToolCallValidationFailed event when the model emits a string \ + location against an integer-typed schema; got tokens:\n{}", + collected.text + ); + + let first_failure = validation_failures + .iter() + .flat_map(|messages| messages.iter()) + .next() + .ok_or_else(|| anyhow::anyhow!("no validation-failure messages in any event"))?; + + assert!( + first_failure.contains("get_weather"), + "validation-failure message should name the offending tool; got: {first_failure}" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs b/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs index cc61d8b2..7cd475a2 100644 --- a/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs +++ b/paddler_tests/tests/agent_returns_identical_embeddings_for_identical_documents.rs @@ -2,35 +2,31 @@ use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_identical_embeddings_for_identical_documents() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let repeated_content = "Deterministic embedding output test."; - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![ EmbeddingInputDocument { content: repeated_content.to_owned(), @@ -45,8 +41,6 @@ async fn agent_returns_identical_embeddings_for_identical_documents() -> Result< }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 2); assert!(collected.saw_done); diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs index e1e14b7e..a9a04a7e 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_invalid_base64.rs @@ -1,32 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_invalid_base64() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -49,22 +40,18 @@ async fn agent_returns_image_decoding_error_for_invalid_base64() -> Result<()> { }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - let saw_decoding_error = collected.token_results.iter().any(|result| { - matches!( - result.token_result, - GeneratedTokenResult::ImageDecodingFailed(_) - ) - }); - - assert!( - saw_decoding_error, - "invalid base64 must produce ImageDecodingFailed" - ); - } + if let Ok(collected) = outcome { + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); + + assert!( + saw_decoding_error, + "invalid base64 must produce ImageDecodingFailed" + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs index 55a4a84b..7e40cef4 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_malformed_data_uri.rs @@ -1,32 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_malformed_data_uri() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -49,22 +40,18 @@ async fn agent_returns_image_decoding_error_for_malformed_data_uri() -> Result<( }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - let saw_decoding_error = collected.token_results.iter().any(|result| { - matches!( - result.token_result, - GeneratedTokenResult::ImageDecodingFailed(_) - ) - }); - - assert!( - saw_decoding_error, - "malformed data URI must produce ImageDecodingFailed" - ); - } + if let Ok(collected) = outcome { + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); + + assert!( + saw_decoding_error, + "malformed data URI must produce ImageDecodingFailed" + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs b/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs index 8ba81670..927c4e46 100644 --- a/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs +++ b/paddler_tests/tests/agent_returns_image_decoding_error_for_remote_url.rs @@ -1,32 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_image_decoding_error_for_remote_url() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -49,22 +40,18 @@ async fn agent_returns_image_decoding_error_for_remote_url() -> Result<()> { }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - let saw_decoding_error = collected.token_results.iter().any(|result| { - matches!( - result.token_result, - GeneratedTokenResult::ImageDecodingFailed(_) - ) - }); - - assert!( - saw_decoding_error, - "remote URL must produce ImageDecodingFailed (only data URIs supported)" - ); - } + if let Ok(collected) = outcome { + let saw_decoding_error = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ImageDecodingFailed(_) + ) + }); + + assert!( + saw_decoding_error, + "remote URL must produce ImageDecodingFailed (only data URIs supported)" + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs b/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs index 31d0b471..608fb3a5 100644 --- a/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs +++ b/paddler_tests/tests/agent_returns_rms_normalized_embeddings_when_requested.rs @@ -1,33 +1,29 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_rms_normalized_embeddings_when_requested() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![EmbeddingInputDocument { content: "Testing RMS normalization on embeddings".to_owned(), id: "doc-rms".to_owned(), @@ -36,8 +32,6 @@ async fn agent_returns_rms_normalized_embeddings_when_requested() -> Result<()> }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); assert!(matches!( diff --git a/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs b/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs index 45f68685..125154c2 100644 --- a/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs +++ b/paddler_tests/tests/agent_returns_unnormalized_embeddings_when_requested.rs @@ -1,33 +1,29 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_embedding_cluster::start_in_process_embedding_cluster; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; +use paddler_tests::start_embedding_cluster::start_embedding_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_returns_unnormalized_embeddings_when_requested() -> Result<()> { - let cluster = start_in_process_embedding_cluster( - InferenceParameters { + let cluster = start_embedding_cluster(Qwen3EmbeddingClusterParams { + agents: vec![AgentConfig::single(1)], + inference_parameters: InferenceParameters { enable_embeddings: true, ..InferenceParameters::default() }, - AgentConfig::single(1), - ) + ..Qwen3EmbeddingClusterParams::default() + }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let collected = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![EmbeddingInputDocument { content: "Testing no normalization on embeddings".to_owned(), id: "doc-none".to_owned(), @@ -36,8 +32,6 @@ async fn agent_returns_unnormalized_embeddings_when_requested() -> Result<()> { }) .await?; - let collected = collect_embedding_results(stream).await?; - assert_eq!(collected.embeddings.len(), 1); assert!(collected.saw_done); assert!(matches!( diff --git a/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs b/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs index 7a3a2bd9..4849db80 100644 --- a/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs +++ b/paddler_tests/tests/agent_serves_four_concurrent_clients_streaming_tokens.rs @@ -1,41 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_serves_four_concurrent_clients_streaming_tokens() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 4)).await?; - - let inference_base_url = cluster.addresses.inference_base_url()?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 4)).await?; let prompts = ["The sky is", "Roses are", "Once upon", "In the year"]; let client_tasks = prompts.into_iter().map(|prompt| { - let inference_base_url = inference_base_url.clone(); - - async move { - let inference_client = InferenceHttpClient::new(Client::new(), inference_base_url); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 8, - raw_prompt: prompt.to_owned(), - }) - .await?; - - collect_generated_tokens(stream).await - } + cluster.continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 8, + raw_prompt: prompt.to_owned(), + }) }); let collected_results = futures_util::future::try_join_all(client_tasks).await?; diff --git a/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs b/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs index fdb10902..f08d4e8d 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_conversation_history_over_http.rs @@ -1,29 +1,20 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_conversation_history_over_http() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello".to_owned()), @@ -37,8 +28,6 @@ async fn agent_streams_tokens_from_conversation_history_over_http() -> Result<() }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs b/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs index d9074177..b7802a41 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_image_data_uri.rs @@ -1,34 +1,25 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_subprocess_cluster_with_smolvlm2::start_subprocess_cluster_with_smolvlm2; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_smolvlm2::start_cluster_with_smolvlm2; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_image_data_uri() -> Result<()> { - let cluster = start_subprocess_cluster_with_smolvlm2(AgentConfig::uniform(1, 4)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_smolvlm2(AgentConfig::uniform(1, 4)).await?; let image_data_uri = load_test_image_data_uri()?; - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -51,8 +42,6 @@ async fn agent_streams_tokens_from_image_data_uri() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let received_tokens = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs b/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs index e3baa1f2..ec05fb6a 100644 --- a/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs +++ b/paddler_tests/tests/agent_streams_tokens_from_raw_prompt.rs @@ -1,34 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_streams_tokens_from_raw_prompt() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "The capital of France is".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs b/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs index 7bf0facc..adb04811 100644 --- a/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs +++ b/paddler_tests/tests/agent_text_only_model_rejects_image_input.rs @@ -1,35 +1,26 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn agent_text_only_model_rejects_image_input() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let image_data_uri = load_test_image_data_uri()?; - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Parts(vec![ @@ -52,23 +43,19 @@ async fn agent_text_only_model_rejects_image_input() -> Result<()> { }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - let saw_rejection = collected.token_results.iter().any(|result| { - matches!( - result.token_result, - GeneratedTokenResult::ChatTemplateError(_) - | GeneratedTokenResult::MultimodalNotSupported(_) - ) - }); + if let Ok(collected) = outcome { + let saw_rejection = collected.token_results.iter().any(|result| { + matches!( + result.token_result, + GeneratedTokenResult::ChatTemplateError(_) + | GeneratedTokenResult::MultimodalNotSupported(_) + ) + }); - assert!( - saw_rejection, - "text-only model must reject image input with chat template or multimodal-not-supported error" - ); - } + assert!( + saw_rejection, + "text-only model must reject image input with chat template or multimodal-not-supported error" + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs b/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs deleted file mode 100644 index 7ed99c42..00000000 --- a/paddler_tests/tests/balancer_closes_management_websocket_on_sigterm.rs +++ /dev/null @@ -1,57 +0,0 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - -use anyhow::Result; -use anyhow::anyhow; -use anyhow::bail; -use futures_util::StreamExt as _; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use tokio_tungstenite::connect_async; -use tokio_tungstenite::tungstenite::protocol::Message; - -#[tokio::test(flavor = "multi_thread")] -async fn balancer_closes_management_websocket_on_sigterm() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: Vec::new(), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - let management_addr = cluster.addresses.management; - let ws_url = format!("ws://{management_addr}/api/v1/agent_socket/test_agent_shutdown_probe"); - let (mut ws_stream, _response) = connect_async(ws_url).await?; - - let first_frame = ws_stream - .next() - .await - .ok_or_else(|| anyhow!("WebSocket closed before yielding the version notification"))??; - - match first_frame { - Message::Text(_) => {} - other => bail!("expected initial Text frame, got {other:?}"), - } - - let observe_close = tokio::spawn(async move { - while let Some(item) = ws_stream.next().await { - match item { - Ok(Message::Close(_)) => return Ok::(true), - Ok(_) => {} - Err(_) => break, - } - } - - Ok(false) - }); - - cluster.shutdown().await?; - - let saw_close_frame = observe_close.await??; - - assert!( - saw_close_frame, - "WebSocket client must observe a Close frame after the balancer is SIGTERMed" - ); - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs b/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs index 8dca6351..7c41013a 100644 --- a/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs +++ b/paddler_tests/tests/balancer_completes_buffered_request_after_agent_joins.rs @@ -1,61 +1,50 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::buffered_requests_status::assert_count::assert_count; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_completes_buffered_request_after_agent_joins() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - buffered_request_timeout: Duration::from_secs(120), + buffered_request_timeout: Duration::from_mins(2), max_buffered_requests: 1, desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), @@ -63,15 +52,13 @@ async fn balancer_completes_buffered_request_after_agent_joins() -> Result<()> { .await?; cluster - .buffered_requests - .until(assert_count(1)) + .wait_for_buffered_request_count(1) .await .context("balancer should buffer the request before any agent joins")?; - let mut agent_child = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("buffered-agent".to_owned()), - slots: 4, + cluster.spawn_additional_agent(&AgentConfig { + name: "buffered-agent".to_owned(), + slot_count: 4, })?; let message = stream @@ -90,9 +77,6 @@ async fn balancer_completes_buffered_request_after_agent_joins() -> Result<()> { } } - agent_child.start_kill()?; - agent_child.wait().await?; - cluster.shutdown().await?; Ok(()) diff --git a/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs b/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs index 3cee6e2b..6fe22f36 100644 --- a/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs +++ b/paddler_tests/tests/balancer_completes_in_flight_inference_during_model_switch.rs @@ -1,36 +1,28 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Result; use anyhow::anyhow; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::inference_client::Message as InferenceMessage; -use paddler_types::inference_client::Response as InferenceResponse; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_client::response::Response as InferenceResponse; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_completes_in_flight_inference_during_model_switch() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; let expected_output = "the quick brown fox jumps over the lazy dog"; - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: Some(GrammarConstraint::Gbnf { grammar: format!("root ::= \"{expected_output}\""), root: "root".to_owned(), diff --git a/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs b/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs deleted file mode 100644 index baccd731..00000000 --- a/paddler_tests/tests/balancer_distributes_buffered_requests_across_two_agents.rs +++ /dev/null @@ -1,99 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::time::Duration; - -use anyhow::Result; -use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_distributes_buffered_requests_across_two_agents() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: vec![ - AgentConfig { - name: "distributed-agent-0".to_owned(), - slot_count: 2, - }, - AgentConfig { - name: "distributed-agent-1".to_owned(), - slot_count: 2, - }, - ], - wait_for_slots_ready: true, - buffered_request_timeout: Duration::from_secs(120), - max_buffered_requests: 10, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }), - ..SubprocessClusterParams::default() - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut streams = Vec::new(); - - for _ in 0..5 { - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 10, - raw_prompt: "Hello".to_owned(), - }) - .await?; - - streams.push(stream); - } - - let mut successful_responses = 0; - - for mut stream in streams { - if let Some(item) = stream.next().await { - match item? { - Message::Response(_) => successful_responses += 1, - Message::Error(envelope) => { - anyhow::bail!( - "expected success, got error {}: {}", - envelope.error.code, - envelope.error.description - ); - } - } - } - } - - assert_eq!(successful_responses, 5); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs deleted file mode 100644 index 9ccf57c4..00000000 --- a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents.rs +++ /dev/null @@ -1,71 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::collections::BTreeSet; - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_distributes_embedding_batch_across_agents() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(2, 4), - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - ..Qwen3EmbeddingClusterParams::default() - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let filler = "x".repeat(380); - let input_batch: Vec = (0..12) - .map(|index| EmbeddingInputDocument { - content: format!("Document number {index:02}: {filler}"), - id: format!("doc-{index}"), - }) - .collect(); - let params = GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }; - - let stream = inference_client - .post_generate_embedding_batch(¶ms) - .await?; - let collected = collect_embedding_results(stream).await?; - - assert_eq!(collected.embeddings.len(), 12); - assert!(collected.saw_done); - assert!(collected.errors.is_empty()); - - let producers: BTreeSet<&str> = collected - .embeddings - .iter() - .filter_map(|produced| produced.generated_by.as_deref()) - .collect(); - - assert!( - producers.len() >= 2, - "expected the embedding batch to be distributed across at least two agents, but only saw producers: {producers:?}" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs b/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs deleted file mode 100644 index a00c5ae6..00000000 --- a/paddler_tests/tests/balancer_distributes_embedding_batch_across_agents_with_uneven_slots.rs +++ /dev/null @@ -1,97 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::collections::BTreeSet; - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_distributes_embedding_batch_across_agents_with_uneven_slots() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: vec![ - AgentConfig { - name: "agent-fat".to_owned(), - slot_count: 4, - }, - AgentConfig { - name: "agent-thin-a".to_owned(), - slot_count: 1, - }, - AgentConfig { - name: "agent-medium".to_owned(), - slot_count: 2, - }, - AgentConfig { - name: "agent-thin-b".to_owned(), - slot_count: 1, - }, - ], - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - ..Qwen3EmbeddingClusterParams::default() - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let input_batch: Vec = (0..8) - .map(|index| EmbeddingInputDocument { - content: format!("Uneven-slot document number {index}."), - id: format!("doc-{index}"), - }) - .collect(); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }) - .await?; - - let collected = collect_embedding_results(stream).await?; - - assert_eq!(collected.embeddings.len(), 8); - assert!(collected.saw_done); - assert!(collected.errors.is_empty()); - - let returned_document_ids: BTreeSet = collected - .embeddings - .iter() - .map(|produced| produced.embedding.source_document_id.clone()) - .collect(); - let expected_document_ids: BTreeSet = - (0..8).map(|index| format!("doc-{index}")).collect(); - assert_eq!(returned_document_ids, expected_document_ids); - - let producers: BTreeSet<&str> = collected - .embeddings - .iter() - .filter_map(|produced| produced.generated_by.as_deref()) - .collect(); - - assert_eq!( - producers.len(), - 4, - "embedding batch must fan out across all agents even when slot counts are uneven, but only saw producers: {producers:?}", - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs b/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs deleted file mode 100644 index 9e114ab5..00000000 --- a/paddler_tests/tests/balancer_distributes_embedding_burst_evenly_across_agents.rs +++ /dev/null @@ -1,91 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::collections::BTreeSet; - -use std::time::Duration; - -use anyhow::Result; -use futures_util::future; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_distributes_embedding_burst_evenly_across_agents() -> Result<()> { - const AGENT_COUNT: usize = 4; - const SLOTS_PER_AGENT: i32 = 2; - const CONCURRENT_REQUESTS: usize = 8; - - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(AGENT_COUNT, SLOTS_PER_AGENT), - buffered_request_timeout: Duration::from_secs(60), - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - max_buffered_requests: 32, - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let collection_futures = (0..CONCURRENT_REQUESTS).map(|request_index| { - let inference_client = inference_client.clone(); - async move { - let input_batch: Vec = (0..4) - .map(|document_index| EmbeddingInputDocument { - content: format!( - "Burst request {request_index}, document {document_index}: \ - provide an embedding for evaluation." - ), - id: format!("req-{request_index}-doc-{document_index}"), - }) - .collect(); - - let stream = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }) - .await?; - - collect_embedding_results(stream).await - } - }); - - let collected_streams = future::try_join_all(collection_futures).await?; - - let producers_across_streams: BTreeSet<&str> = collected_streams - .iter() - .flat_map(|collected| collected.embeddings.iter()) - .filter_map(|produced| produced.generated_by.as_deref()) - .collect(); - - assert_eq!( - producers_across_streams.len(), - AGENT_COUNT, - "burst of {CONCURRENT_REQUESTS} embedding batches across {AGENT_COUNT} agents must reach every agent, but saw producers: {producers_across_streams:?}", - ); - - for collected in &collected_streams { - assert!(collected.saw_done); - assert!(collected.errors.is_empty()); - assert_eq!(collected.embeddings.len(), 4); - } - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs b/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs deleted file mode 100644 index afe3c64f..00000000 --- a/paddler_tests/tests/balancer_fans_out_embedding_batch_to_all_agents.rs +++ /dev/null @@ -1,74 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::collections::BTreeSet; - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_fans_out_embedding_batch_to_all_agents() -> Result<()> { - let agent_count: usize = 4; - - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(agent_count, 2), - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - ..Qwen3EmbeddingClusterParams::default() - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let filler = "x".repeat(380); - let input_batch: Vec = (0..16) - .map(|index| EmbeddingInputDocument { - content: format!("Document number {index:02}: {filler}"), - id: format!("doc-{index}"), - }) - .collect(); - let params = GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }; - - let stream = inference_client - .post_generate_embedding_batch(¶ms) - .await?; - let collected = collect_embedding_results(stream).await?; - - assert_eq!(collected.embeddings.len(), 16); - assert!(collected.saw_done); - assert!(collected.errors.is_empty()); - - let producers: BTreeSet<&str> = collected - .embeddings - .iter() - .filter_map(|produced| produced.generated_by.as_deref()) - .collect(); - - assert_eq!( - producers.len(), - agent_count, - "expected the embedding batch to fan out across every agent, but only saw producers: {producers:?}" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs b/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs index 7a734880..ffb60d67 100644 --- a/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs +++ b/paddler_tests/tests/balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks.rs @@ -5,16 +5,16 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; -use paddler::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; -use paddler::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; -use paddler::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::manages_senders_controller::ManagesSendersController; -use paddler::balancer::request_from_agent::forward_responses_stream; +use paddler_balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use paddler_balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use paddler_balancer::chunk_forwarding_session_controller::transform_result::TransformResult; +use paddler_balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::manages_senders_controller::ManagesSendersController; +use paddler_balancer::request_from_agent::forward_responses_stream; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::jsonrpc::ErrorEnvelope; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -42,7 +42,7 @@ async fn balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks() let agent_controller_clone = agent_controller.clone(); let request_id_clone = request_id.clone(); - let forward_handle: tokio::task::JoinHandle> = tokio::spawn(async move { + let forward_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move { forward_responses_stream::<_, EmbeddingSenderCollection>( agent_controller_clone, connection_close, @@ -50,16 +50,16 @@ async fn balancer_forwards_504_timeout_error_when_agent_stops_emitting_chunks() receive_response_controller, request_id_clone, session_controller, + CancellationToken::new(), ) - .await + .await; }); let forward_completed_within = inference_item_timeout * 10; tokio::time::timeout(forward_completed_within, forward_handle) .await .context("forward_responses_stream did not return within the 504-timeout budget")? - .context("forward_responses_stream task panicked")? - .context("forward_responses_stream returned an error")?; + .context("forward_responses_stream task panicked")?; let chunk = chunk_rx .recv() diff --git a/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs b/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs index 38c16a9f..0227440f 100644 --- a/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs +++ b/paddler_tests/tests/balancer_forwards_error_when_response_channel_closes_before_terminator.rs @@ -5,17 +5,17 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use anyhow::anyhow; -use paddler::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; -use paddler::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; -use paddler::balancer::chunk_forwarding_session_controller::transform_result::TransformResult; -use paddler::balancer::embedding_sender_collection::EmbeddingSenderCollection; -use paddler::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; -use paddler::balancer::manages_senders::ManagesSenders as _; -use paddler::balancer::manages_senders_controller::ManagesSendersController; -use paddler::balancer::request_from_agent::forward_responses_stream; +use paddler_balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController; +use paddler_balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer; +use paddler_balancer::chunk_forwarding_session_controller::transform_result::TransformResult; +use paddler_balancer::embedding_sender_collection::EmbeddingSenderCollection; +use paddler_balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration; +use paddler_balancer::manages_senders::ManagesSenders as _; +use paddler_balancer::manages_senders_controller::ManagesSendersController; +use paddler_balancer::request_from_agent::forward_responses_stream; +use paddler_messaging::inference_client::message::Message as OutgoingMessage; +use paddler_messaging::jsonrpc::error_envelope::ErrorEnvelope; use paddler_tests::make_agent_controller_without_remote_agent::make_agent_controller_without_remote_agent; -use paddler_types::inference_client::Message as OutgoingMessage; -use paddler_types::jsonrpc::ErrorEnvelope; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -43,7 +43,7 @@ async fn forward_responses_stream_emits_error_envelope_when_response_channel_clo let agent_controller_clone = agent_controller.clone(); let request_id_clone = request_id.clone(); - let forward_handle: tokio::task::JoinHandle> = tokio::spawn(async move { + let forward_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move { forward_responses_stream::<_, EmbeddingSenderCollection>( agent_controller_clone, connection_close, @@ -51,8 +51,9 @@ async fn forward_responses_stream_emits_error_envelope_when_response_channel_clo receive_response_controller, request_id_clone, session_controller, + CancellationToken::new(), ) - .await + .await; }); tokio::time::sleep(Duration::from_millis(50)).await; @@ -64,8 +65,7 @@ async fn forward_responses_stream_emits_error_envelope_when_response_channel_clo tokio::time::timeout(Duration::from_secs(5), forward_handle) .await .context("forward_responses_stream did not complete in time")? - .context("forward_responses_stream task panicked")? - .context("forward_responses_stream returned an error")?; + .context("forward_responses_stream task panicked")?; let chunk = chunk_rx .recv() diff --git a/paddler_tests/tests/balancer_inference_health_returns_ok.rs b/paddler_tests/tests/balancer_inference_health_returns_ok.rs index f25ff347..5732de82 100644 --- a/paddler_tests/tests/balancer_inference_health_returns_ok.rs +++ b/paddler_tests/tests/balancer_inference_health_returns_ok.rs @@ -1,20 +1,22 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn balancer_inference_health_returns_ok() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_health_url = cluster.addresses.inference_base_url()?.join("health")?; + let inference_health_url = cluster + .balancer + .addresses + .inference_base_url()? + .join("health")?; let response = reqwest::get(inference_health_url) .await diff --git a/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs b/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs index 68cc3b26..2a97c999 100644 --- a/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs +++ b/paddler_tests/tests/balancer_inference_service_replies_with_configured_cors_origin.rs @@ -1,24 +1,26 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; const ALLOWED_ORIGIN: &str = "http://example.com"; #[tokio::test(flavor = "multi_thread")] async fn balancer_inference_service_replies_with_configured_cors_origin() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), inference_cors_allowed_hosts: vec![ALLOWED_ORIGIN.to_owned()], wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; let http_client = reqwest::Client::new(); - let inference_health_url = cluster.addresses.inference_base_url()?.join("health")?; + let inference_health_url = cluster + .balancer + .addresses + .inference_base_url()? + .join("health")?; let response = http_client .request(reqwest::Method::OPTIONS, inference_health_url) diff --git a/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs b/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs index 003b4993..7485356e 100644 --- a/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs +++ b/paddler_tests/tests/balancer_management_service_replies_with_configured_cors_origin.rs @@ -1,24 +1,26 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; const ALLOWED_ORIGIN: &str = "http://example.com"; #[tokio::test(flavor = "multi_thread")] async fn balancer_management_service_replies_with_configured_cors_origin() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), management_cors_allowed_hosts: vec![ALLOWED_ORIGIN.to_owned()], wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; let http_client = reqwest::Client::new(); - let management_health_url = cluster.addresses.management_base_url()?.join("health")?; + let management_health_url = cluster + .balancer + .addresses + .management_base_url()? + .join("health")?; let response = http_client .request(reqwest::Method::OPTIONS, management_health_url) diff --git a/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs b/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs index 2a99f0ec..ab2ffef0 100644 --- a/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs +++ b/paddler_tests/tests/balancer_memory_storage_persists_desired_state.rs @@ -1,17 +1,14 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -26,12 +23,12 @@ async fn balancer_memory_storage_persists_desired_state() -> Result<()> { use_chat_template_override: false, }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: "memory://".to_owned(), desired_state: Some(desired_state.clone()), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs b/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs index 0c5876c5..8ea8c50b 100644 --- a/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs +++ b/paddler_tests/tests/balancer_openai_compat_health_returns_ok.rs @@ -1,20 +1,22 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn balancer_openai_compat_health_returns_ok() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let openai_health_url = cluster.addresses.compat_openai_base_url()?.join("health")?; + let openai_health_url = cluster + .balancer + .addresses + .compat_openai_base_url()? + .join("health")?; let response = reqwest::get(openai_health_url) .await diff --git a/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs b/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs index 7452f91e..48dcd8ab 100644 --- a/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs +++ b/paddler_tests/tests/balancer_persists_chat_template_override_across_restart.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::state_database_file::StateDatabaseFile; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::state_database_file::StateDatabaseFile; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -34,23 +31,23 @@ async fn balancer_persists_chat_template_override_across_restart() -> Result<()> use_chat_template_override: true, }; - let first_cluster = start_subprocess_cluster(SubprocessClusterParams { + let first_cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(desired_state.clone()), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; first_cluster.shutdown().await?; - let second_cluster = start_subprocess_cluster(SubprocessClusterParams { + let second_cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: None, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs b/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs index 278f3f25..b32f373d 100644 --- a/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs +++ b/paddler_tests/tests/balancer_persists_desired_state_across_restart.rs @@ -1,18 +1,15 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::state_database_file::StateDatabaseFile; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::state_database_file::StateDatabaseFile; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -29,23 +26,23 @@ async fn balancer_persists_desired_state_across_restart() -> Result<()> { use_chat_template_override: false, }; - let first_cluster = start_subprocess_cluster(SubprocessClusterParams { + let first_cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(desired_state.clone()), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; first_cluster.shutdown().await?; - let second_cluster = start_subprocess_cluster(SubprocessClusterParams { + let second_cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: None, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs b/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs index 38b2fb26..7b72144a 100644 --- a/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs +++ b/paddler_tests/tests/balancer_persists_huggingface_mmproj_in_desired_state.rs @@ -1,18 +1,15 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::smolvlm2_256m::smolvlm2_256m; use paddler_tests::model_card::smolvlm2_256m_mmproj::smolvlm2_256m_mmproj; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -26,7 +23,7 @@ async fn balancer_persists_huggingface_mmproj_in_desired_state() -> Result<()> { .. } = smolvlm2_256m_mmproj(); - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -36,7 +33,7 @@ async fn balancer_persists_huggingface_mmproj_in_desired_state() -> Result<()> { multimodal_projection: AgentDesiredModel::HuggingFace(mmproj_reference.clone()), use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs b/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs index db0ba63a..0faf4202 100644 --- a/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs +++ b/paddler_tests/tests/balancer_persists_local_mmproj_path_in_desired_state.rs @@ -1,17 +1,14 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::smolvlm2_256m::smolvlm2_256m; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -20,7 +17,7 @@ async fn balancer_persists_local_mmproj_path_in_desired_state() -> Result<()> { let local_mmproj_path = "/tmp/test-mmproj.gguf".to_owned(); - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -30,7 +27,7 @@ async fn balancer_persists_local_mmproj_path_in_desired_state() -> Result<()> { multimodal_projection: AgentDesiredModel::LocalToAgent(local_mmproj_path.clone()), use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs b/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs index f69a1133..d9dbe6bc 100644 --- a/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs +++ b/paddler_tests/tests/balancer_persists_model_switch_in_storage.rs @@ -1,18 +1,15 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::state_database_file::StateDatabaseFile; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::state_database_file::StateDatabaseFile; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -29,12 +26,12 @@ async fn balancer_persists_model_switch_in_storage() -> Result<()> { use_chat_template_override: false, }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, state_database_url: database.url.clone(), desired_state: Some(initial_state.clone()), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs b/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs index e5f83175..3522bb25 100644 --- a/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs +++ b/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs @@ -1,23 +1,20 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_persists_url_model_in_desired_state() -> Result<()> { let configured_url = "https://example.invalid/persisted-model.gguf".to_owned(); - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -29,7 +26,7 @@ async fn balancer_persists_url_model_in_desired_state() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs b/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs deleted file mode 100644 index d86bc932..00000000 --- a/paddler_tests/tests/balancer_registers_multiple_agents_over_time.rs +++ /dev/null @@ -1,53 +0,0 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - -use anyhow::Context as _; -use anyhow::Result; -use paddler_tests::agents_status::assert_agent_count::assert_agent_count; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; - -#[tokio::test(flavor = "multi_thread")] -async fn balancer_registers_multiple_agents_over_time() -> Result<()> { - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: Vec::new(), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - let mut first_agent = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("test-agent-1".to_owned()), - slots: 1, - })?; - - cluster - .agents - .until(assert_agent_count(1)) - .await - .context("first agent should register")?; - - let mut second_agent = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("test-agent-2".to_owned()), - slots: 1, - })?; - - cluster - .agents - .until(assert_agent_count(2)) - .await - .context("second agent should register")?; - - first_agent.start_kill()?; - first_agent.wait().await?; - - second_agent.start_kill()?; - second_agent.wait().await?; - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs index 709af576..da1fa6c1 100644 --- a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs +++ b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_for_invalid_jinja.rs @@ -1,27 +1,24 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_chat_template_does_not_compile_for_invalid_jinja() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -33,7 +30,7 @@ async fn balancer_reports_chat_template_does_not_compile_for_invalid_jinja() -> multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -44,7 +41,7 @@ async fn balancer_reports_chat_template_does_not_compile_for_invalid_jinja() -> .clone(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs index f067a482..e6f3d7b1 100644 --- a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs +++ b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs @@ -1,27 +1,24 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] -async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced( -) -> Result<()> { +async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced() +-> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); let invalid_template = ChatTemplate { @@ -31,7 +28,7 @@ async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_ content: "{% for message in messages %}{{ message.content }}{% endfor %}".to_owned(), }; - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -41,7 +38,7 @@ async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_ multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -53,7 +50,7 @@ async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_ let predicate_agent_id = agent_id.clone(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == predicate_agent_id @@ -85,7 +82,7 @@ async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_ let predicate_agent_id_for_recovery = agent_id; tokio::time::timeout( Duration::from_secs(3), - cluster.agents.until(move |snapshot| { + cluster.agents_watcher.until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == predicate_agent_id_for_recovery && agent diff --git a/paddler_tests/tests/balancer_reports_download_server_denied_access.rs b/paddler_tests/tests/balancer_reports_download_server_denied_access.rs index 39e03e25..12e3a743 100644 --- a/paddler_tests/tests/balancer_reports_download_server_denied_access.rs +++ b/paddler_tests/tests/balancer_reports_download_server_denied_access.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::local_http_fixture::LocalHttpFixture; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -21,7 +18,7 @@ async fn balancer_reports_download_server_denied_access() -> Result<()> { let fixture = LocalHttpFixture::start("HTTP/1.1 403 Forbidden", Vec::new()).await?; let model_url = fixture.url("/private.gguf"); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -33,7 +30,7 @@ async fn balancer_reports_download_server_denied_access() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -44,7 +41,7 @@ async fn balancer_reports_download_server_denied_access() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_download_server_errored.rs b/paddler_tests/tests/balancer_reports_download_server_errored.rs index 73fed463..cbcec8da 100644 --- a/paddler_tests/tests/balancer_reports_download_server_errored.rs +++ b/paddler_tests/tests/balancer_reports_download_server_errored.rs @@ -1,28 +1,24 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::local_http_fixture::LocalHttpFixture; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_download_server_errored() -> Result<()> { - let fixture = - LocalHttpFixture::start("HTTP/1.1 500 Internal Server Error", Vec::new()).await?; + let fixture = LocalHttpFixture::start("HTTP/1.1 500 Internal Server Error", Vec::new()).await?; let model_url = fixture.url("/broken.gguf"); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -34,7 +30,7 @@ async fn balancer_reports_download_server_errored() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -45,7 +41,7 @@ async fn balancer_reports_download_server_errored() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs b/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs index 8b542da7..b2e78520 100644 --- a/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs +++ b/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs @@ -1,25 +1,22 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_download_server_is_unreachable() -> Result<()> { let model_url = "http://127.0.0.1:1/model.gguf".to_owned(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -31,7 +28,7 @@ async fn balancer_reports_download_server_is_unreachable() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -42,7 +39,7 @@ async fn balancer_reports_download_server_is_unreachable() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs b/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs index e711d191..e4dd08c8 100644 --- a/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs +++ b/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs @@ -1,25 +1,22 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_download_url_is_malformed() -> Result<()> { let malformed_url = "not a valid url".to_owned(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -31,7 +28,7 @@ async fn balancer_reports_download_url_is_malformed() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -42,7 +39,7 @@ async fn balancer_reports_download_url_is_malformed() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs b/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs index 411452b1..aadb558e 100644 --- a/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs +++ b/paddler_tests/tests/balancer_reports_huggingface_model_does_not_exist.rs @@ -1,23 +1,20 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::huggingface_model_reference::HuggingFaceModelReference; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_huggingface_model_does_not_exist() -> Result<()> { - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -31,7 +28,7 @@ async fn balancer_reports_huggingface_model_does_not_exist() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -42,7 +39,7 @@ async fn balancer_reports_huggingface_model_does_not_exist() -> Result<()> { .clone(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs b/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs index 8e75911c..3feb5109 100644 --- a/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs +++ b/paddler_tests/tests/balancer_reports_mmproj_cannot_be_loaded_for_invalid_file.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -25,7 +22,7 @@ async fn balancer_reports_mmproj_cannot_be_loaded_for_invalid_file() -> Result<( let ModelCard { reference, .. } = qwen3_0_6b(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -35,7 +32,7 @@ async fn balancer_reports_mmproj_cannot_be_loaded_for_invalid_file() -> Result<( multimodal_projection: AgentDesiredModel::LocalToAgent(invalid_mmproj_path.to_owned()), use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -49,7 +46,7 @@ async fn balancer_reports_mmproj_cannot_be_loaded_for_invalid_file() -> Result<( let expected_path = invalid_mmproj_path.to_owned(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == watch_agent_id diff --git a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs index 0c74809d..dcbc28ec 100644 --- a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs +++ b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_corrupt_file.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::io::Write as _; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; use tempfile::NamedTempFile; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] @@ -29,7 +26,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_corrupt_file() -> Result<() .context("temp file path is not valid UTF-8")? .to_owned(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -39,7 +36,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_corrupt_file() -> Result<() multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -52,7 +49,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_corrupt_file() -> Result<() let expected_path = corrupt_model_path.clone(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs index 000db231..9848afdb 100644 --- a/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs +++ b/paddler_tests/tests/balancer_reports_model_cannot_be_loaded_for_invalid_gguf.rs @@ -1,24 +1,21 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_model_cannot_be_loaded_for_invalid_gguf() -> Result<()> { let invalid_gguf_path = concat!(env!("CARGO_MANIFEST_DIR"), "/../fixtures/invalid.gguf"); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -28,7 +25,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_invalid_gguf() -> Result<() multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -42,7 +39,7 @@ async fn balancer_reports_model_cannot_be_loaded_for_invalid_gguf() -> Result<() let expected_path = invalid_gguf_path.to_owned(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == watch_agent_id diff --git a/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs b/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs index 502761fc..725a627b 100644 --- a/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs +++ b/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::url_model_reference::UrlModelReference; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::local_http_fixture::LocalHttpFixture; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::url_model_reference::UrlModelReference; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -21,7 +18,7 @@ async fn balancer_reports_model_does_not_exist_at_url() -> Result<()> { let fixture = LocalHttpFixture::start("HTTP/1.1 404 Not Found", Vec::new()).await?; let model_url = fixture.url("/missing.gguf"); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -33,7 +30,7 @@ async fn balancer_reports_model_does_not_exist_at_url() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -44,7 +41,7 @@ async fn balancer_reports_model_does_not_exist_at_url() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs b/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs index f84aad13..13ccd6be 100644 --- a/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs +++ b/paddler_tests/tests/balancer_reports_model_file_does_not_exist.rs @@ -1,22 +1,19 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_model_file_does_not_exist() -> Result<()> { - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -26,7 +23,7 @@ async fn balancer_reports_model_file_does_not_exist() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -37,7 +34,7 @@ async fn balancer_reports_model_file_does_not_exist() -> Result<()> { .clone(); let snapshot = cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs b/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs index 261a88fe..4d8bfcd3 100644 --- a/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs +++ b/paddler_tests/tests/balancer_reports_multimodal_projection_cannot_be_loaded.rs @@ -1,26 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_multimodal_projection_cannot_be_loaded() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -32,7 +29,7 @@ async fn balancer_reports_multimodal_projection_cannot_be_loaded() -> Result<()> ), use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -43,7 +40,7 @@ async fn balancer_reports_multimodal_projection_cannot_be_loaded() -> Result<()> .clone(); cluster - .agents + .agents_watcher .until(move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == agent_id diff --git a/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs b/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs index 35a5c442..d12c6e57 100644 --- a/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs +++ b/paddler_tests/tests/balancer_reports_unable_to_find_chat_template_for_embedding_model.rs @@ -1,26 +1,23 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::nomic_embed_text_v1_5::nomic_embed_text_v1_5; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_reports_unable_to_find_chat_template_for_embedding_model() -> Result<()> { let ModelCard { reference, .. } = nomic_embed_text_v1_5(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -30,7 +27,7 @@ async fn balancer_reports_unable_to_find_chat_template_for_embedding_model() -> multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -42,7 +39,7 @@ async fn balancer_reports_unable_to_find_chat_template_for_embedding_model() -> let predicate_agent_id = agent_id.clone(); cluster - .agents + .agents_watcher .until_agent(&agent_id, move |snapshot| { snapshot.agents.iter().any(|agent| { agent.id == predicate_agent_id diff --git a/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs b/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs deleted file mode 100644 index 8978d19b..00000000 --- a/paddler_tests/tests/balancer_resolves_buffered_requests_after_agent_killed.rs +++ /dev/null @@ -1,107 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::time::Duration; - -use anyhow::Context as _; -use anyhow::Result; -use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::model_card::ModelCard; -use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn balancer_resolves_buffered_requests_after_agent_killed() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - - let ModelCard { - gpu_layer_count, - reference, - } = qwen3_0_6b(); - - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: vec![AgentConfig { - name: "removal-agent-primary".to_owned(), - slot_count: 2, - }], - wait_for_slots_ready: true, - buffered_request_timeout: Duration::from_secs(120), - max_buffered_requests: 10, - desired_state: Some(BalancerDesiredState { - chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), - model: AgentDesiredModel::HuggingFace(reference), - multimodal_projection: AgentDesiredModel::None, - use_chat_template_override: false, - }), - ..SubprocessClusterParams::default() - }) - .await?; - - let mut secondary_agent = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("removal-agent-secondary".to_owned()), - slots: 2, - })?; - - cluster - .agents - .until(|snapshot| snapshot.agents.len() == 2) - .await - .context("both agents should register")?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut streams = Vec::new(); - - for _ in 0..3 { - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 10, - raw_prompt: "Hello".to_owned(), - }) - .await?; - - streams.push(stream); - } - - secondary_agent.start_kill()?; - secondary_agent.wait().await?; - - let mut total_resolved = 0; - - for mut stream in streams { - if let Some(item) = stream.next().await { - match item { - Ok(Message::Response(_) | Message::Error(_)) | Err(_) => total_resolved += 1, - } - } - } - - assert_eq!( - total_resolved, 3, - "all 3 buffered requests must resolve after one agent is killed" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs b/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs index 3d661ad0..1b59adef 100644 --- a/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs +++ b/paddler_tests/tests/balancer_returns_503_when_request_buffering_disabled.rs @@ -1,45 +1,35 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_503_when_request_buffering_disabled() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), max_buffered_requests: 0, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let mut agent_child = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("buffer-disabled-agent".to_owned()), - slots: 2, + cluster.spawn_additional_agent(&AgentConfig { + name: "buffer-disabled-agent".to_owned(), + slot_count: 2, })?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), @@ -60,9 +50,6 @@ async fn balancer_returns_503_when_request_buffering_disabled() -> Result<()> { } } - agent_child.start_kill()?; - agent_child.wait().await?; - cluster.shutdown().await?; Ok(()) diff --git a/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs b/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs index 9e7555ea..9b94cebd 100644 --- a/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs +++ b/paddler_tests/tests/balancer_returns_504_when_inference_item_timeout_is_zero.rs @@ -1,58 +1,49 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_504_when_inference_item_timeout_is_zero() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 2), inference_item_timeout: Duration::ZERO, wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), diff --git a/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs b/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs index 9a3071b5..14352284 100644 --- a/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs +++ b/paddler_tests/tests/balancer_returns_504_when_no_agents_registered.rs @@ -1,31 +1,26 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_504_when_no_agents_registered() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), @@ -37,15 +32,12 @@ async fn balancer_returns_504_when_no_agents_registered() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), diff --git a/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs b/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs index 041910ce..d04eea6e 100644 --- a/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs +++ b/paddler_tests/tests/balancer_returns_504_when_no_model_configured.rs @@ -1,29 +1,22 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn balancer_returns_504_when_no_model_configured() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), diff --git a/paddler_tests/tests/balancer_serves_inference_over_websocket.rs b/paddler_tests/tests/balancer_serves_inference_over_websocket.rs new file mode 100644 index 00000000..a271ba55 --- /dev/null +++ b/paddler_tests/tests/balancer_serves_inference_over_websocket.rs @@ -0,0 +1,55 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use futures_util::StreamExt as _; +use paddler_messaging::inference_client::message::Message as InferenceMessage; +use paddler_messaging::inference_client::response::Response; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_serves_inference_over_websocket() -> Result<()> { + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 1)).await?; + + let mut stream = cluster + .paddler_client + .inference() + .continue_from_raw_prompt(ContinueFromRawPromptParams { + grammar: None, + max_tokens: 16, + raw_prompt: "The capital of France is".to_owned(), + }) + .await + .map_err(anyhow::Error::new)?; + + let mut token_count: usize = 0; + + while let Some(message_result) = stream.next().await { + match message_result.map_err(anyhow::Error::new)? { + InferenceMessage::Response(envelope) => match envelope.response { + Response::GeneratedToken(generated_token_result) => { + if generated_token_result.is_token() { + token_count += 1; + } + } + Response::Embedding(_) | Response::Timeout | Response::TooManyBufferedRequests => { + panic!("inference over websocket produced an unexpected response variant") + } + }, + InferenceMessage::Error(envelope) => { + panic!( + "inference over websocket failed: code {}, description {:?}", + envelope.error.code, envelope.error.description + ) + } + } + } + + assert!(token_count > 0); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs b/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs index e584be7d..53f999d2 100644 --- a/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs +++ b/paddler_tests/tests/balancer_serves_request_after_agent_with_capacity_registers.rs @@ -1,60 +1,50 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use std::time::Duration; use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_client::message::Message; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::spawn_agent_subprocess::spawn_agent_subprocess; -use paddler_tests::spawn_agent_subprocess_params::SpawnAgentSubprocessParams; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::inference_client::Message; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn balancer_serves_request_after_agent_with_capacity_registers() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + let mut cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, buffered_request_timeout: Duration::from_millis(50), max_buffered_requests: 10, desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut early_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut early_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), @@ -75,22 +65,21 @@ async fn balancer_serves_request_after_agent_with_capacity_registers() -> Result } } - let mut agent_child = spawn_agent_subprocess(SpawnAgentSubprocessParams { - management_addr: cluster.addresses.management, - name: Some("capacity-agent".to_owned()), - slots: 4, + cluster.spawn_additional_agent(&AgentConfig { + name: "capacity-agent".to_owned(), + slot_count: 4, })?; cluster - .agents + .agents_watcher .until(|snapshot| { snapshot.agents.len() == 1 && snapshot.agents.iter().any(|agent| agent.slots_total >= 4) }) .await .context("agent should register with 4 slots")?; - let mut later_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut later_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), @@ -112,9 +101,6 @@ async fn balancer_serves_request_after_agent_with_capacity_registers() -> Result Message::Response(_) => {} } - agent_child.start_kill()?; - agent_child.wait().await?; - cluster.shutdown().await?; Ok(()) diff --git a/paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs b/paddler_tests/tests/balancer_shutdown_with_open_sse_subscriber_completes_within_one_second.rs similarity index 72% rename from paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs rename to paddler_tests/tests/balancer_shutdown_with_open_sse_subscriber_completes_within_one_second.rs index 0fa01396..8a11fcf6 100644 --- a/paddler_tests/tests/balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second.rs +++ b/paddler_tests/tests/balancer_shutdown_with_open_sse_subscriber_completes_within_one_second.rs @@ -3,17 +3,16 @@ use std::time::Duration; use anyhow::Result; use anyhow::anyhow; use futures_util::StreamExt as _; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; use tokio::time::timeout; #[tokio::test(flavor = "multi_thread")] -async fn balancer_in_process_shutdown_with_open_sse_subscriber_completes_within_one_second() --> Result<()> { - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: None, +async fn balancer_shutdown_with_open_sse_subscriber_completes_within_one_second() -> Result<()> { + let cluster = start_cluster(ClusterParams { + agents: Vec::new(), wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs b/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs index ab8c85fa..61efabd8 100644 --- a/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs +++ b/paddler_tests/tests/chat_template_drains_in_flight_inference_before_swap.rs @@ -1,36 +1,27 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::collect_generated_tokens::collect_generated_tokens; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, @@ -43,17 +34,20 @@ async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { content: "PREFIX:{{ messages[0].content }}".to_owned(), }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(template_a.clone()), - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference.clone()), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -63,11 +57,8 @@ async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let in_flight_stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let in_flight_stream = cluster + .continue_from_conversation_history_stream(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("The capital of France is".to_owned()), @@ -83,7 +74,10 @@ async fn chat_template_drains_in_flight_inference_before_swap() -> Result<()> { let swap_state = BalancerDesiredState { chat_template_override: Some(template_b.clone()), - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, diff --git a/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs b/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs index 3f8cf97b..f46a9ac1 100644 --- a/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs +++ b/paddler_tests/tests/chat_template_override_applied_to_embedding_model.rs @@ -1,19 +1,16 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::nomic_embed_text_v1_5::nomic_embed_text_v1_5; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::inference_parameters::InferenceParameters; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] @@ -24,7 +21,7 @@ async fn chat_template_override_applied_to_embedding_model() -> Result<()> { content: "{{ messages[0].content }}".to_owned(), }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: false, desired_state: Some(BalancerDesiredState { @@ -34,7 +31,7 @@ async fn chat_template_override_applied_to_embedding_model() -> Result<()> { multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs b/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs index f759a776..50deddfb 100644 --- a/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs +++ b/paddler_tests/tests/chat_template_override_replaces_model_builtin.rs @@ -1,34 +1,24 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn chat_template_override_replaces_model_builtin() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, @@ -38,17 +28,20 @@ async fn chat_template_override_replaces_model_builtin() -> Result<()> { content: "{{ messages[0].content }}".to_owned(), }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(chat_template.clone()), - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -68,11 +61,8 @@ async fn chat_template_override_replaces_model_builtin() -> Result<()> { assert_eq!(retrieved, Some(chat_template)); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("The capital of France is".to_owned()), @@ -86,8 +76,6 @@ async fn chat_template_override_replaces_model_builtin() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let received_tokens = collected .token_results .iter() diff --git a/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs b/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs index fcc7fa72..ba228789 100644 --- a/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs +++ b/paddler_tests/tests/chat_template_swaps_between_inference_calls.rs @@ -1,30 +1,29 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] + +use std::future::Future; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::chat_template::ChatTemplate; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster::Cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::chat_template::ChatTemplate; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; - -async fn run_inference_after_template_swap(inference_client: &InferenceHttpClient) -> Result { - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { +use paddler_tests::start_cluster::start_cluster; + +fn run_inference_after_template_swap( + cluster: &Cluster, +) -> impl Future> + Send + use<> { + let generation = + cluster.continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("The capital of France is".to_owned()), @@ -35,24 +34,21 @@ async fn run_inference_after_template_swap(inference_client: &InferenceHttpClien max_tokens: 10, parse_tool_calls: false, tools: vec![], - }) - .await?; + }); - let collected = collect_generated_tokens(stream).await?; + async move { + let collected = generation.await?; - Ok(collected - .token_results - .iter() - .any(|result| result.token_result.is_token())) + Ok(collected + .token_results + .iter() + .any(|result| result.token_result.is_token())) + } } #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn chat_template_swaps_between_inference_calls() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, @@ -65,17 +61,20 @@ async fn chat_template_swaps_between_inference_calls() -> Result<()> { content: "PREFIX:{{ messages[0].content }}".to_owned(), }; - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: AgentConfig::uniform(1, 1), wait_for_slots_ready: true, desired_state: Some(BalancerDesiredState { chat_template_override: Some(template_a.clone()), - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference.clone()), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, }), - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -85,17 +84,17 @@ async fn chat_template_swaps_between_inference_calls() -> Result<()> { .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - assert!( - run_inference_after_template_swap(&inference_client).await?, + run_inference_after_template_swap(&cluster).await?, "first inference with template_a must produce tokens" ); let swap_state = BalancerDesiredState { chat_template_override: Some(template_b.clone()), - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: true, @@ -109,7 +108,7 @@ async fn chat_template_swaps_between_inference_calls() -> Result<()> { .map_err(anyhow::Error::new)?; assert!( - run_inference_after_template_swap(&inference_client).await?, + run_inference_after_template_swap(&cluster).await?, "inference after swap must produce tokens with template_b" ); diff --git a/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs b/paddler_tests/tests/cluster_lifecycle_under_concurrent_load_does_not_hang.rs similarity index 75% rename from paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs rename to paddler_tests/tests/cluster_lifecycle_under_concurrent_load_does_not_hang.rs index 6a9e5639..a71bee78 100644 --- a/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs +++ b/paddler_tests/tests/cluster_lifecycle_under_concurrent_load_does_not_hang.rs @@ -1,22 +1,20 @@ -#![cfg(feature = "tests_that_use_in_process_cluster")] - use std::time::Duration; use std::time::Instant; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; use tokio::task::JoinSet; use tokio::time::timeout; const CONCURRENT_LIFECYCLES: usize = 8; const LIFECYCLES_PER_TASK: usize = 5; -const TOTAL_BUDGET: Duration = Duration::from_secs(60); +const TOTAL_BUDGET: Duration = Duration::from_mins(1); #[tokio::test(flavor = "multi_thread")] -async fn in_process_cluster_lifecycle_under_concurrent_load_does_not_hang() -> Result<()> { +async fn cluster_lifecycle_under_concurrent_load_does_not_hang() -> Result<()> { let started_at = Instant::now(); let mut join_set: JoinSet> = JoinSet::new(); @@ -24,9 +22,9 @@ async fn in_process_cluster_lifecycle_under_concurrent_load_does_not_hang() -> R for _ in 0..CONCURRENT_LIFECYCLES { join_set.spawn(async move { for _ in 0..LIFECYCLES_PER_TASK { - let cluster = start_in_process_cluster(InProcessClusterParams { + let cluster = start_cluster(ClusterParams { wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; cluster.shutdown().await?; diff --git a/paddler_tests/tests/in_process_cluster_shutdown_completes_within_five_seconds.rs b/paddler_tests/tests/cluster_shutdown_completes_within_five_seconds.rs similarity index 64% rename from paddler_tests/tests/in_process_cluster_shutdown_completes_within_five_seconds.rs rename to paddler_tests/tests/cluster_shutdown_completes_within_five_seconds.rs index 796e755c..4015e877 100644 --- a/paddler_tests/tests/in_process_cluster_shutdown_completes_within_five_seconds.rs +++ b/paddler_tests/tests/cluster_shutdown_completes_within_five_seconds.rs @@ -2,17 +2,17 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; use tokio::time::timeout; const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5); #[tokio::test(flavor = "multi_thread")] -async fn in_process_cluster_shutdown_completes_within_five_seconds() -> Result<()> { - let cluster = start_in_process_cluster(InProcessClusterParams { +async fn cluster_shutdown_completes_within_five_seconds() -> Result<()> { + let cluster = start_cluster(ClusterParams { wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/in_process_cluster_shutdown_returns_fd_count_to_baseline.rs b/paddler_tests/tests/cluster_shutdown_returns_fd_count_to_baseline.rs similarity index 58% rename from paddler_tests/tests/in_process_cluster_shutdown_returns_fd_count_to_baseline.rs rename to paddler_tests/tests/cluster_shutdown_returns_fd_count_to_baseline.rs index e4669ac9..0d56cc33 100644 --- a/paddler_tests/tests/in_process_cluster_shutdown_returns_fd_count_to_baseline.rs +++ b/paddler_tests/tests/cluster_shutdown_returns_fd_count_to_baseline.rs @@ -1,17 +1,17 @@ #![cfg(any(target_os = "macos", target_os = "linux"))] use anyhow::Result; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::resource_snapshot::ResourceSnapshot; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::resource_snapshot::ResourceSnapshot; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] -async fn in_process_cluster_shutdown_returns_fd_count_to_baseline() -> Result<()> { +async fn cluster_shutdown_returns_fd_count_to_baseline() -> Result<()> { let before = ResourceSnapshot::try_from_self()?; - let cluster = start_in_process_cluster(InProcessClusterParams { + let cluster = start_cluster(ClusterParams { wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; cluster.shutdown().await?; diff --git a/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs b/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs index 971f38e6..5303b38b 100644 --- a/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs +++ b/paddler_tests/tests/continuous_batch_concurrent_conversation_history_requests_complete.rs @@ -1,17 +1,14 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; fn user_message(text: &str) -> ConversationMessage { ConversationMessage { @@ -23,38 +20,29 @@ fn user_message(text: &str) -> ConversationMessage { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_concurrent_conversation_history_requests_complete() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream_a = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new(vec![user_message("What is 2+2?")]), - enable_thinking: false, - grammar: None, - max_tokens: 20, - parse_tool_calls: false, - tools: vec![], - }) - .await?; - - let stream_b = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new(vec![user_message("Name a color")]), - enable_thinking: false, - grammar: None, - max_tokens: 20, - parse_tool_calls: false, - tools: vec![], - }) - .await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(2)]).await?; + let params_a = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![user_message("What is 2+2?")]), + enable_thinking: false, + grammar: None, + max_tokens: 20, + parse_tool_calls: false, + tools: vec![], + }; + let params_b = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![user_message("Name a color")]), + enable_thinking: false, + grammar: None, + max_tokens: 20, + parse_tool_calls: false, + tools: vec![], + }; let (results_a, results_b) = tokio::join!( - collect_generated_tokens(stream_a), - collect_generated_tokens(stream_b), + cluster.continue_from_conversation_history(¶ms_a), + cluster.continue_from_conversation_history(¶ms_b), ); let collected_a = results_a?; diff --git a/paddler_tests/tests/continuous_batch_distinct_output.rs b/paddler_tests/tests/continuous_batch_distinct_output.rs index 9c4d43ef..e104401e 100644 --- a/paddler_tests/tests/continuous_batch_distinct_output.rs +++ b/paddler_tests/tests/continuous_batch_distinct_output.rs @@ -1,26 +1,19 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn two_concurrent_prompts_produce_distinct_outputs() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, @@ -28,45 +21,39 @@ async fn two_concurrent_prompts_produce_distinct_outputs() -> Result<()> { let desired_state = BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }; - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 2, - }), - desired_state, + }], + desired_state: Some(desired_state), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream_a = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 20, - raw_prompt: "Count from one to ten in English: one, two,".to_owned(), - }) - .await?; - - let stream_b = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 20, - raw_prompt: "The capital of France is".to_owned(), - }) - .await?; - + let params_a = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 20, + raw_prompt: "Count from one to ten in English: one, two,".to_owned(), + }; + let params_b = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 20, + raw_prompt: "The capital of France is".to_owned(), + }; let (collected_a, collected_b) = tokio::join!( - collect_generated_tokens(stream_a), - collect_generated_tokens(stream_b), + cluster.continue_from_raw_prompt(¶ms_a), + cluster.continue_from_raw_prompt(¶ms_b), ); let collected_a = collected_a?; diff --git a/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs b/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs index 1d5e3115..41d513be 100644 --- a/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs +++ b/paddler_tests/tests/continuous_batch_evicts_long_sequence_under_kv_pressure.rs @@ -1,80 +1,67 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_evicts_long_sequence_under_kv_pressure() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); + let mut inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }; inference_parameters.n_batch = 256; inference_parameters.context_size = 256; inference_parameters.temperature = 0.0; - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 2, - }), - desired_state: BalancerDesiredState { + }], + desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let long_prompt = "Describe in great detail how the process of photosynthesis works in plants. Cover the light-dependent reactions, the Calvin cycle, the role of chlorophyll, the thylakoid membrane, and the stroma. Explain how water and carbon dioxide are converted to glucose and oxygen. Discuss the evolutionary history of this process and its importance throughout the biosphere, and then give a long essay response."; - let long_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 200, - raw_prompt: long_prompt.to_owned(), - }) - .await?; - - let short_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 20, - raw_prompt: "Hi".to_owned(), - }) - .await?; - + let long_params = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 200, + raw_prompt: long_prompt.to_owned(), + }; + let short_params = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 20, + raw_prompt: "Hi".to_owned(), + }; let (long_collected, short_collected) = tokio::join!( - collect_generated_tokens(long_stream), - collect_generated_tokens(short_stream), + cluster.continue_from_raw_prompt(&long_params), + cluster.continue_from_raw_prompt(&short_params), ); let long_collected = long_collected?; diff --git a/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs b/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs index e5fc6a90..c671d712 100644 --- a/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs +++ b/paddler_tests/tests/continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes.rs @@ -1,69 +1,60 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::kv_cache_dtype::KvCacheDtype; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::kv_cache_dtype::KvCacheDtype; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_generates_tokens_with_distinct_k_and_v_cache_dtypes() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let mut inference_parameters = device.inference_parameters_for_full_offload(gpu_layer_count); + let mut inference_parameters = InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }; - inference_parameters.k_cache_dtype = KvCacheDtype::Q8_0; + inference_parameters.k_cache_dtype = KvCacheDtype::Q80; inference_parameters.v_cache_dtype = KvCacheDtype::F16; - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), - desired_state: BalancerDesiredState { + }], + desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 8, raw_prompt: "Count from 1 to 3:".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs b/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs index 20227f34..79079e4b 100644 --- a/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs +++ b/paddler_tests/tests/continuous_batch_generates_tokens_with_partial_layer_offload.rs @@ -1,65 +1,55 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; const PARTIAL_GPU_LAYER_COUNT: u32 = 14; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_generates_tokens_with_partial_layer_offload() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { reference, .. } = qwen3_0_6b(); - let inference_parameters = - device.inference_parameters_for_full_offload(PARTIAL_GPU_LAYER_COUNT); + let inference_parameters = InferenceParameters { + n_gpu_layers: PARTIAL_GPU_LAYER_COUNT, + ..InferenceParameters::default() + }; - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), - desired_state: BalancerDesiredState { + }], + desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 16, raw_prompt: "Count from 1 to 5:".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs b/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs index a6f37ab7..8f4bd15c 100644 --- a/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs +++ b/paddler_tests/tests/continuous_batch_long_and_short_prompts_complete_concurrently.rs @@ -1,44 +1,32 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_long_and_short_prompts_complete_concurrently() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(2)]).await?; let long_prompt = "Photosynthesis is the process by which green plants and certain other organisms transform light energy into chemical energy. During photosynthesis in green plants, light energy is captured and used to convert water, carbon dioxide, and minerals into oxygen and energy-rich organic compounds. Explain the process in detail:".to_owned(); - let long_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 20, - raw_prompt: long_prompt, - }) - .await?; - - let short_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 20, - raw_prompt: "Hi".to_owned(), - }) - .await?; - + let long_params = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 20, + raw_prompt: long_prompt, + }; + let short_params = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 20, + raw_prompt: "Hi".to_owned(), + }; let (long_collected, short_collected) = tokio::join!( - collect_generated_tokens(long_stream), - collect_generated_tokens(short_stream), + cluster.continue_from_raw_prompt(&long_params), + cluster.continue_from_raw_prompt(&short_params), ); let long_collected = long_collected?; diff --git a/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs b/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs index 234d7f35..d4741a88 100644 --- a/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs +++ b/paddler_tests/tests/continuous_batch_plain_and_multimodal_run_concurrently.rs @@ -1,37 +1,23 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(4), true).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let plain_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 64, - raw_prompt: "Write a long poem about the sea.".to_owned(), - }) - .await?; + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(4)], true).await?; let image_data_uri = load_test_image_data_uri()?; @@ -63,21 +49,23 @@ async fn continuous_batch_plain_and_multimodal_run_concurrently() -> Result<()> }, ]); - let multimodal_stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: multimodal_conversation, - enable_thinking: false, - grammar: None, - max_tokens: 32, - parse_tool_calls: false, - tools: vec![], - }) - .await?; - + let plain_params = ContinueFromRawPromptParams { + grammar: None, + max_tokens: 64, + raw_prompt: "Write a long poem about the sea.".to_owned(), + }; + let multimodal_params = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: multimodal_conversation, + enable_thinking: false, + grammar: None, + max_tokens: 32, + parse_tool_calls: false, + tools: vec![], + }; let (plain_collected, multimodal_collected) = tokio::join!( - collect_generated_tokens(plain_stream), - collect_generated_tokens(multimodal_stream), + cluster.continue_from_raw_prompt(&plain_params), + cluster.continue_from_conversation_history(&multimodal_params), ); let plain_collected = plain_collected?; diff --git a/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs b/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs index 099e2baa..62b36bb1 100644 --- a/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs +++ b/paddler_tests/tests/continuous_batch_rejects_embedding_during_active_generation.rs @@ -2,27 +2,21 @@ use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::request_params::ContinueFromRawPromptParams; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::collect_generated_tokens::collect_generated_tokens; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_rejects_embedding_during_active_generation() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(2)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut generation_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut generation_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 50, raw_prompt: "Tell me a long story about a cat".to_owned(), @@ -34,8 +28,8 @@ async fn continuous_batch_rejects_embedding_during_active_generation() -> Result .await .ok_or_else(|| anyhow::anyhow!("generation stream must yield at least one message"))?; - let embedding_outcome = inference_client - .post_generate_embedding_batch(&GenerateEmbeddingBatchParams { + let embedding_outcome = cluster + .generate_embedding_batch(&GenerateEmbeddingBatchParams { input_batch: vec![EmbeddingInputDocument { content: "test".to_owned(), id: "doc1".to_owned(), @@ -44,15 +38,11 @@ async fn continuous_batch_rejects_embedding_during_active_generation() -> Result }) .await; - if let Ok(embedding_stream) = embedding_outcome { - let collected = collect_embedding_results(embedding_stream).await; - - if let Ok(collected) = collected { - assert!( - !collected.errors.is_empty() || collected.embeddings.is_empty(), - "embedding request must fail when text-only model is busy generating" - ); - } + if let Ok(collected) = embedding_outcome { + assert!( + !collected.errors.is_empty() || collected.embeddings.is_empty(), + "embedding request must fail when text-only model is busy generating" + ); } let _drained = collect_generated_tokens(generation_stream).await; diff --git a/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs b/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs index 2a64475a..ebac879b 100644 --- a/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs +++ b/paddler_tests/tests/continuous_batch_rejects_second_request_when_only_slot_busy.rs @@ -3,47 +3,42 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_rejects_second_request_when_only_slot_busy() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, } = qwen3_0_6b(); - let mut cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let mut cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), + }], max_buffered_requests: 0, - desired_state: BalancerDesiredState { + desired_state: Some(BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -53,11 +48,8 @@ async fn continuous_batch_rejects_second_request_when_only_slot_busy() -> Result .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut first_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut first_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 100, raw_prompt: "Tell me a long story about an explorer".to_owned(), @@ -70,27 +62,18 @@ async fn continuous_batch_rejects_second_request_when_only_slot_busy() -> Result .context("first stream must yield at least one message")?; cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) + .wait_for_slots_processing(&agent_id, 1) .await .context("first request should occupy the only slot")?; - let second_outcome = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let second_failed = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), }) - .await; - - let second_failed = match second_outcome { - Err(_) => true, - Ok(stream) => { - let collected = collect_generated_tokens(stream).await; - - collected.is_err() - } - }; + .await + .is_err(); assert!( second_failed, diff --git a/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs b/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs index f078eaab..f759a46a 100644 --- a/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs +++ b/paddler_tests/tests/continuous_batch_releases_slot_when_client_disconnects.rs @@ -3,17 +3,14 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_releases_slot_when_client_disconnects() -> Result<()> { - let mut cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let mut cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; let agent_id = cluster .agent_ids @@ -21,11 +18,8 @@ async fn continuous_batch_releases_slot_when_client_disconnects() -> Result<()> .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 500, raw_prompt: "Write a long story about an explorer".to_owned(), @@ -40,16 +34,14 @@ async fn continuous_batch_releases_slot_when_client_disconnects() -> Result<()> drop(first_message); cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) + .wait_for_slots_processing(&agent_id, 1) .await .context("first request should occupy the only slot")?; drop(stream); cluster - .agents - .until(assert_slots_processing(&agent_id, 0)) + .wait_for_slots_processing(&agent_id, 0) .await .context("slot should be released after the HTTP client disconnects")?; diff --git a/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs b/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs index cc3f9e2c..526026b9 100644 --- a/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs +++ b/paddler_tests/tests/continuous_batch_releases_slots_on_shutdown_with_active_request.rs @@ -2,22 +2,17 @@ use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_releases_slots_on_shutdown_with_active_request() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 500, raw_prompt: "Write a long essay".to_owned(), diff --git a/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs b/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs index 0c07db46..615a70e1 100644 --- a/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs +++ b/paddler_tests/tests/continuous_batch_reuses_slot_after_request_completes.rs @@ -1,33 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_reuses_slot_after_request_completes() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let first_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let first_collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello world".to_owned(), }) .await?; - let first_collected = collect_generated_tokens(first_stream).await?; - assert!(matches!( first_collected.token_results.last(), Some(TokenResultWithProducer { @@ -36,16 +28,14 @@ async fn continuous_batch_reuses_slot_after_request_completes() -> Result<()> { }) )); - let second_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let second_collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Goodbye world".to_owned(), }) .await?; - let second_collected = collect_generated_tokens(second_stream).await?; - assert!(matches!( second_collected.token_results.last(), Some(TokenResultWithProducer { diff --git a/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs b/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs index 7dd5cfff..6c95e494 100644 --- a/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs +++ b/paddler_tests/tests/continuous_batch_serves_four_concurrent_requests.rs @@ -1,44 +1,28 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_serves_four_concurrent_requests() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(4)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(4)]).await?; let prompts = ["The sky is", "Roses are", "Once upon", "In the year"]; - let stream_results = futures_util::future::try_join_all(prompts.into_iter().map(|prompt| { - let inference_client = &inference_client; - - async move { - inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { - grammar: None, - max_tokens: 8, - raw_prompt: prompt.to_owned(), - }) - .await - } + let collected_results = futures_util::future::try_join_all(prompts.into_iter().map(|prompt| { + cluster.continue_from_raw_prompt(&ContinueFromRawPromptParams { + grammar: None, + max_tokens: 8, + raw_prompt: prompt.to_owned(), + }) })) .await?; - let collect_tasks = stream_results.into_iter().map(collect_generated_tokens); - - let collected_results = futures_util::future::try_join_all(collect_tasks).await?; - assert_eq!(collected_results.len(), 4); for collected in &collected_results { diff --git a/paddler_tests/tests/continuous_batch_smoke.rs b/paddler_tests/tests/continuous_batch_smoke.rs index f91c3f9c..c2072237 100644 --- a/paddler_tests/tests/continuous_batch_smoke.rs +++ b/paddler_tests/tests/continuous_batch_smoke.rs @@ -2,30 +2,21 @@ use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_smoke_generates_tokens() -> Result<()> { - let device = current_test_device()?; - - device - .require_available() - .context("selected device is unavailable")?; - let ModelCard { gpu_layer_count, reference, @@ -33,29 +24,29 @@ async fn continuous_batch_smoke_generates_tokens() -> Result<()> { let desired_state = BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }; - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), - desired_state, + }], + desired_state: Some(desired_state), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await .context("failed to start in-process cluster with Qwen3 0.6B")?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 16, raw_prompt: "Count from 1 to 5:".to_owned(), @@ -63,19 +54,13 @@ async fn continuous_batch_smoke_generates_tokens() -> Result<()> { .await .context("failed to POST /api/v1/continue_from_raw_prompt")?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() .filter(|result| result.token_result.is_token()) .count(); - assert!( - token_count > 0, - "smoke test on {} produced no tokens", - device.name() - ); + assert!(token_count > 0, "smoke test produced no tokens"); assert!( matches!( diff --git a/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs b/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs index 307cd479..ff352cb8 100644 --- a/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs +++ b/paddler_tests/tests/continuous_batch_stop_signal_terminates_generation_before_max_tokens.rs @@ -3,17 +3,14 @@ use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stop_signal_terminates_generation_before_max_tokens() -> Result<()> { - let mut cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let mut cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; let agent_id = cluster .agent_ids @@ -21,11 +18,8 @@ async fn continuous_batch_stop_signal_terminates_generation_before_max_tokens() .context("cluster must have one registered agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 500, raw_prompt: "Write a very long story about a dragon".to_owned(), @@ -38,16 +32,14 @@ async fn continuous_batch_stop_signal_terminates_generation_before_max_tokens() .context("inference stream must yield at least one message")?; cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) + .wait_for_slots_processing(&agent_id, 1) .await .context("slot should be occupied while the request is in flight")?; drop(stream); cluster - .agents - .until(assert_slots_processing(&agent_id, 0)) + .wait_for_slots_processing(&agent_id, 0) .await .context("dropping the stream must terminate generation before max_tokens is reached")?; diff --git a/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs b/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs index 61423f02..32e03a43 100644 --- a/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs +++ b/paddler_tests/tests/continuous_batch_stops_at_max_tokens_boundary.rs @@ -1,33 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stops_at_max_tokens_boundary() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 5, raw_prompt: "Count from one to one hundred:".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs b/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs index 445db487..06382479 100644 --- a/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs +++ b/paddler_tests/tests/continuous_batch_stops_generation_when_stop_sender_dropped.rs @@ -2,25 +2,19 @@ use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_stops_generation_when_stop_sender_dropped() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(2)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let mut first_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let mut first_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 500, raw_prompt: "Write a long essay about photosynthesis".to_owned(), @@ -34,16 +28,14 @@ async fn continuous_batch_stops_generation_when_stop_sender_dropped() -> Result< drop(first_stream); - let second_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let second_collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 10, raw_prompt: "Hello".to_owned(), }) .await?; - let second_collected = collect_generated_tokens(second_stream).await?; - assert!(matches!( second_collected.token_results.last(), Some(TokenResultWithProducer { diff --git a/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs b/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs index 7de37fd7..a5d94a8e 100644 --- a/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs +++ b/paddler_tests/tests/continuous_batch_two_concurrent_multimodal_requests_produce_tokens.rs @@ -1,20 +1,17 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; fn build_multimodal_conversation(image_data_uri: &str) -> ConversationHistory { ConversationHistory::new(vec![ @@ -49,40 +46,31 @@ fn build_multimodal_conversation(image_data_uri: &str) -> ConversationHistory { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn continuous_batch_two_concurrent_multimodal_requests_produce_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(4), true).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(4)], true).await?; let image_data_uri = load_test_image_data_uri()?; - let stream_a = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: build_multimodal_conversation(&image_data_uri), - enable_thinking: false, - grammar: None, - max_tokens: 32, - parse_tool_calls: false, - tools: vec![], - }) - .await?; - - let stream_b = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: build_multimodal_conversation(&image_data_uri), - enable_thinking: false, - grammar: None, - max_tokens: 32, - parse_tool_calls: false, - tools: vec![], - }) - .await?; - + let params_a = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: build_multimodal_conversation(&image_data_uri), + enable_thinking: false, + grammar: None, + max_tokens: 32, + parse_tool_calls: false, + tools: vec![], + }; + let params_b = ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: build_multimodal_conversation(&image_data_uri), + enable_thinking: false, + grammar: None, + max_tokens: 32, + parse_tool_calls: false, + tools: vec![], + }; let (collected_a, collected_b) = tokio::join!( - collect_generated_tokens(stream_a), - collect_generated_tokens(stream_b), + cluster.continue_from_conversation_history(¶ms_a), + cluster.continue_from_conversation_history(¶ms_b), ); let collected_a = collected_a?; diff --git a/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs index 96fe5bbe..8f30c61c 100644 --- a/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_deepseek_r1_distill_llama_8b::start_in_process_cluster_with_deepseek_r1_distill_llama_8b; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_deepseek_r1_distill_llama_8b::start_cluster_with_deepseek_r1_distill_llama_8b; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens() -> Result<()> { let cluster = - start_in_process_cluster_with_deepseek_r1_distill_llama_8b(AgentConfig::single(1)).await?; + start_cluster_with_deepseek_r1_distill_llama_8b(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -38,8 +32,6 @@ async fn deepseek_r1_distill_llama_8b_internal_endpoint_emits_reasoning_tokens() }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs b/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs index f3245126..fc2ce011 100644 --- a/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs +++ b/paddler_tests/tests/endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters.rs @@ -1,17 +1,17 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::embedding_input_document::EmbeddingInputDocument; +use paddler_messaging::embedding_normalization_method::EmbeddingNormalizationMethod; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::generate_embedding_batch_params::GenerateEmbeddingBatchParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; +use paddler_tests::start_cluster::start_cluster; use reqwest::Client; use reqwest::StatusCode; @@ -20,24 +20,24 @@ use reqwest::StatusCode; async fn endpoint_rejects_embedding_request_when_embeddings_disabled_in_parameters() -> Result<()> { let ModelCard { reference, .. } = qwen3_0_6b(); - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), - desired_state: BalancerDesiredState { + }], + desired_state: Some(BalancerDesiredState { chat_template_override: None, inference_parameters: InferenceParameters::default(), model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, - }, + }), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; - let inference_base_url = cluster.addresses.inference_base_url()?; + let inference_base_url = cluster.balancer.addresses.inference_base_url()?; let request_url = inference_base_url.join("api/v1/generate_embedding_batch")?; let response = Client::new() diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs index 67a9e0e4..b3fd2e41 100644 --- a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_gemma_4::start_in_process_cluster_with_gemma_4; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_gemma_4::start_cluster_with_gemma_4; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn gemma4_internal_endpoint_emits_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_gemma_4(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_gemma_4(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -37,8 +31,6 @@ async fn gemma4_internal_endpoint_emits_reasoning_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs index 235c4679..e8339c35 100644 --- a/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_gemma_4_and_mmproj::start_in_process_cluster_with_gemma_4_and_mmproj; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_gemma_4_and_mmproj::start_cluster_with_gemma_4_and_mmproj; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { - let cluster = start_in_process_cluster_with_gemma_4_and_mmproj(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_gemma_4_and_mmproj(vec![AgentConfig::single(1)]).await?; let image_data_uri = load_test_image_data_uri()?; @@ -39,8 +33,8 @@ async fn gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request() -> role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -51,8 +45,6 @@ async fn gemma4_internal_endpoint_emits_reasoning_tokens_for_image_request() -> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs index 4b2dccdc..7690cc86 100644 --- a/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/gemma4_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_gemma_4::start_in_process_cluster_with_gemma_4; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_gemma_4::start_cluster_with_gemma_4; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn gemma4_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_gemma_4(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_gemma_4(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn gemma4_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn gemma4_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs index 513d0156..1a8f5ebb 100644 --- a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_glm_4_7_flash::start_in_process_cluster_with_glm_4_7_flash; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_glm_4_7_flash::start_cluster_with_glm_4_7_flash; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn glm_4_7_flash_internal_endpoint_emits_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_glm_4_7_flash(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_glm_4_7_flash(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -37,8 +31,6 @@ async fn glm_4_7_flash_internal_endpoint_emits_reasoning_tokens() -> Result<()> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs index eb5d9159..5312705a 100644 --- a/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_glm_4_7_flash::start_in_process_cluster_with_glm_4_7_flash; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_glm_4_7_flash::start_cluster_with_glm_4_7_flash; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_glm_4_7_flash(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_glm_4_7_flash(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event() -> Resul serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn glm_4_7_flash_internal_endpoint_emits_tool_call_parsed_event() -> Resul }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/harness_agents_watcher.rs b/paddler_tests/tests/harness_agents_watcher.rs index e5f17292..dd473329 100644 --- a/paddler_tests/tests/harness_agents_watcher.rs +++ b/paddler_tests/tests/harness_agents_watcher.rs @@ -3,13 +3,12 @@ use std::collections::BTreeSet; use anyhow::Result; use anyhow::anyhow; use futures_util::stream; -use paddler_tests::agents_status::assert_slots_total_at_least::assert_slots_total_at_least; -use paddler_tests::agents_stream_watcher::AgentsStreamWatcher; -use paddler_types::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; -use paddler_types::agent_controller_snapshot::AgentControllerSnapshot; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::agent_state_application_status::AgentStateApplicationStatus; +use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use paddler_messaging::agent_controller_snapshot::AgentControllerSnapshot; +use paddler_messaging::agent_issue::AgentIssue; +use paddler_messaging::agent_issue_params::model_path::ModelPath; +use paddler_messaging::agent_state_application_status::AgentStateApplicationStatus; +use paddler_test_cluster_harness::agents_stream_watcher::AgentsStreamWatcher; fn make_snapshot(agent_id: &str, slots_total: i32) -> AgentControllerPoolSnapshot { AgentControllerPoolSnapshot { @@ -42,7 +41,12 @@ async fn until_returns_first_snapshot_matching_predicate() -> Result<()> { let mut watcher = AgentsStreamWatcher::from_stream(Box::pin(fixture)); let snapshot = watcher - .until(assert_slots_total_at_least("agent-a", 1)) + .until(|snapshot| { + snapshot + .agents + .iter() + .any(|agent| agent.id == "agent-a" && agent.slots_total >= 1) + }) .await?; assert_eq!(snapshot.agents.len(), 1); @@ -81,7 +85,12 @@ async fn until_errors_when_stream_closes_before_match() { let mut watcher = AgentsStreamWatcher::from_stream(Box::pin(fixture)); let outcome = watcher - .until(assert_slots_total_at_least("agent-a", 10)) + .until(|snapshot| { + snapshot + .agents + .iter() + .any(|agent| agent.id == "agent-a" && agent.slots_total >= 10) + }) .await; assert!( diff --git a/paddler_tests/tests/harness_in_process_cluster_shutdown.rs b/paddler_tests/tests/harness_cluster_shutdown.rs similarity index 57% rename from paddler_tests/tests/harness_in_process_cluster_shutdown.rs rename to paddler_tests/tests/harness_cluster_shutdown.rs index 7b619834..ef440e77 100644 --- a/paddler_tests/tests/harness_in_process_cluster_shutdown.rs +++ b/paddler_tests/tests/harness_cluster_shutdown.rs @@ -1,13 +1,13 @@ use anyhow::Result; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn empty_cluster_starts_and_shuts_down_without_timeout() -> Result<()> { - let cluster = start_in_process_cluster(InProcessClusterParams { - agent: None, + let cluster = start_cluster(ClusterParams { + agents: Vec::new(), wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -18,9 +18,9 @@ async fn empty_cluster_starts_and_shuts_down_without_timeout() -> Result<()> { #[tokio::test(flavor = "multi_thread")] async fn single_agent_registers_and_shuts_down_without_timeout() -> Result<()> { - let cluster = start_in_process_cluster(InProcessClusterParams { + let cluster = start_cluster(ClusterParams { wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/harness_load_test_image_data_uri_returns_jpeg_base64.rs b/paddler_tests/tests/harness_load_test_image_data_uri_returns_jpeg_base64.rs index 0f7777ca..fe608984 100644 --- a/paddler_tests/tests/harness_load_test_image_data_uri_returns_jpeg_base64.rs +++ b/paddler_tests/tests/harness_load_test_image_data_uri_returns_jpeg_base64.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; #[test] fn harness_load_test_image_data_uri_returns_jpeg_base64() -> Result<()> { diff --git a/paddler_tests/tests/harness_parse_test_device_value_reads_cpu.rs b/paddler_tests/tests/harness_parse_test_device_value_reads_cpu.rs deleted file mode 100644 index a1ba8d9e..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_reads_cpu.rs +++ /dev/null @@ -1,10 +0,0 @@ -use anyhow::Result; -use paddler_tests::parse_test_device_value::parse_test_device_value; -use paddler_tests::test_device::TestDevice; - -#[test] -fn harness_parse_test_device_value_reads_cpu() -> Result<()> { - assert_eq!(parse_test_device_value(Some("cpu"))?, TestDevice::Cpu); - - Ok(()) -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_reads_cuda.rs b/paddler_tests/tests/harness_parse_test_device_value_reads_cuda.rs deleted file mode 100644 index 1627194b..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_reads_cuda.rs +++ /dev/null @@ -1,12 +0,0 @@ -#![cfg(feature = "cuda")] - -use anyhow::Result; -use paddler_tests::parse_test_device_value::parse_test_device_value; -use paddler_tests::test_device::TestDevice; - -#[test] -fn harness_parse_test_device_value_reads_cuda() -> Result<()> { - assert_eq!(parse_test_device_value(Some("cuda"))?, TestDevice::Cuda); - - Ok(()) -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_reads_metal.rs b/paddler_tests/tests/harness_parse_test_device_value_reads_metal.rs deleted file mode 100644 index 75360265..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_reads_metal.rs +++ /dev/null @@ -1,12 +0,0 @@ -#![cfg(feature = "metal")] - -use anyhow::Result; -use paddler_tests::parse_test_device_value::parse_test_device_value; -use paddler_tests::test_device::TestDevice; - -#[test] -fn harness_parse_test_device_value_reads_metal() -> Result<()> { - assert_eq!(parse_test_device_value(Some("metal"))?, TestDevice::Metal); - - Ok(()) -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_rejects_cuda_when_feature_is_not_linked.rs b/paddler_tests/tests/harness_parse_test_device_value_rejects_cuda_when_feature_is_not_linked.rs deleted file mode 100644 index 51c594f0..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_rejects_cuda_when_feature_is_not_linked.rs +++ /dev/null @@ -1,13 +0,0 @@ -#![cfg(not(feature = "cuda"))] - -use paddler_tests::parse_test_device_value::parse_test_device_value; - -#[test] -fn harness_parse_test_device_value_rejects_cuda_when_feature_is_not_linked() { - let result = parse_test_device_value(Some("cuda")); - - assert!( - result.is_err(), - "PADDLER_TEST_DEVICE=cuda must be rejected when the cuda backend is not linked" - ); -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_rejects_metal_when_feature_is_not_linked.rs b/paddler_tests/tests/harness_parse_test_device_value_rejects_metal_when_feature_is_not_linked.rs deleted file mode 100644 index e67b03b5..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_rejects_metal_when_feature_is_not_linked.rs +++ /dev/null @@ -1,13 +0,0 @@ -#![cfg(not(feature = "metal"))] - -use paddler_tests::parse_test_device_value::parse_test_device_value; - -#[test] -fn harness_parse_test_device_value_rejects_metal_when_feature_is_not_linked() { - let result = parse_test_device_value(Some("metal")); - - assert!( - result.is_err(), - "PADDLER_TEST_DEVICE=metal must be rejected when the metal backend is not linked" - ); -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_returns_cpu_for_none.rs b/paddler_tests/tests/harness_parse_test_device_value_returns_cpu_for_none.rs deleted file mode 100644 index ee49bb65..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_returns_cpu_for_none.rs +++ /dev/null @@ -1,10 +0,0 @@ -use anyhow::Result; -use paddler_tests::parse_test_device_value::parse_test_device_value; -use paddler_tests::test_device::TestDevice; - -#[test] -fn harness_parse_test_device_value_returns_cpu_for_none() -> Result<()> { - assert_eq!(parse_test_device_value(None)?, TestDevice::Cpu); - - Ok(()) -} diff --git a/paddler_tests/tests/harness_parse_test_device_value_returns_error_for_unknown_value.rs b/paddler_tests/tests/harness_parse_test_device_value_returns_error_for_unknown_value.rs deleted file mode 100644 index a802597e..00000000 --- a/paddler_tests/tests/harness_parse_test_device_value_returns_error_for_unknown_value.rs +++ /dev/null @@ -1,11 +0,0 @@ -use paddler_tests::parse_test_device_value::parse_test_device_value; - -#[test] -fn harness_parse_test_device_value_returns_error_for_unknown_value() { - let result = parse_test_device_value(Some("vulkan")); - - assert!( - result.is_err(), - "parse_test_device_value should reject unknown values" - ); -} diff --git a/paddler_tests/tests/harness_port_allocation_uniqueness.rs b/paddler_tests/tests/harness_port_allocation_uniqueness.rs index baafbf19..5fb281a2 100644 --- a/paddler_tests/tests/harness_port_allocation_uniqueness.rs +++ b/paddler_tests/tests/harness_port_allocation_uniqueness.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use anyhow::Result; -use paddler_tests::balancer_addresses::BalancerAddresses; +use paddler_test_cluster_harness::balancer_addresses::BalancerAddresses; #[tokio::test(flavor = "multi_thread")] async fn picks_three_distinct_ports_per_invocation() -> Result<()> { diff --git a/paddler_tests/tests/harness_state_database_file_builds_file_url.rs b/paddler_tests/tests/harness_state_database_file_builds_file_url.rs index c572be29..48cebb68 100644 --- a/paddler_tests/tests/harness_state_database_file_builds_file_url.rs +++ b/paddler_tests/tests/harness_state_database_file_builds_file_url.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use paddler_tests::state_database_file::StateDatabaseFile; +use paddler_test_cluster_harness::state_database_file::StateDatabaseFile; #[test] fn harness_state_database_file_builds_file_url() -> Result<()> { diff --git a/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs b/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs deleted file mode 100644 index c8a4e6a2..00000000 --- a/paddler_tests/tests/harness_subprocess_cluster_shutdown.rs +++ /dev/null @@ -1,36 +0,0 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; - -#[tokio::test(flavor = "multi_thread")] -async fn empty_subprocess_cluster_starts_and_exits_after_sigterm() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: Vec::new(), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - cluster.shutdown().await?; - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread")] -async fn single_subprocess_agent_registers_and_exits_after_sigterm() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: AgentConfig::uniform(1, 4), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - assert_eq!(cluster.agent_ids.len(), 1); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/harness_terminate_child_returns_clean_exit_status.rs b/paddler_tests/tests/harness_terminate_child_returns_clean_exit_status.rs deleted file mode 100644 index 29ca74a6..00000000 --- a/paddler_tests/tests/harness_terminate_child_returns_clean_exit_status.rs +++ /dev/null @@ -1,39 +0,0 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - -use std::process::Stdio; - -use anyhow::Result; -use paddler_tests::paddler_command::paddler_command; -use paddler_tests::terminate_child::terminate_child; - -#[tokio::test(flavor = "multi_thread")] -async fn harness_terminate_child_returns_clean_exit_status() -> Result<()> { - let mut child = paddler_command() - .arg("agent") - .arg("--management-addr") - .arg("127.0.0.1:1") - .arg("--name") - .arg("harness-terminate-child-test") - .arg("--slots") - .arg("1") - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn()?; - - terminate_child(&mut child)?; - - let exit_status = child.wait().await?; - - assert!( - !exit_status.success(), - "terminated process must not report success exit; got {exit_status:?}" - ); - - #[cfg(unix)] - assert!( - exit_status.code().is_none(), - "SIGTERM-terminated process must have no normal exit code on Unix; got {exit_status:?}" - ); - - Ok(()) -} diff --git a/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs b/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs index 3593057f..7a08c41e 100644 --- a/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs +++ b/paddler_tests/tests/management_agents_stream_yields_initial_snapshot.rs @@ -1,18 +1,15 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_agents_stream_yields_initial_snapshot() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let mut stream = cluster .paddler_client diff --git a/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs b/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs index 43028a6c..55df39d1 100644 --- a/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs +++ b/paddler_tests/tests/management_buffered_requests_stream_yields_initial_snapshot.rs @@ -1,17 +1,15 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; use futures_util::StreamExt as _; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn management_buffered_requests_stream_yields_initial_snapshot() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/management_health_endpoint_returns_ok.rs b/paddler_tests/tests/management_health_endpoint_returns_ok.rs index 367258f8..36882a41 100644 --- a/paddler_tests/tests/management_health_endpoint_returns_ok.rs +++ b/paddler_tests/tests/management_health_endpoint_returns_ok.rs @@ -1,16 +1,14 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn management_health_endpoint_returns_ok() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await .context("failed to start subprocess cluster")?; diff --git a/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs b/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs index 98ecc865..2e5e0e81 100644 --- a/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs +++ b/paddler_tests/tests/management_metrics_endpoint_exposes_prometheus_gauges.rs @@ -1,16 +1,14 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - use anyhow::Context as _; use anyhow::Result; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; #[tokio::test(flavor = "multi_thread")] async fn management_metrics_endpoint_exposes_prometheus_gauges() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { + let cluster = start_cluster(ClusterParams { agents: Vec::new(), wait_for_slots_ready: false, - ..SubprocessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs b/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs index 22de1ecd..c2a5dbad 100644 --- a/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs +++ b/paddler_tests/tests/management_reports_zero_download_progress_after_load_complete.rs @@ -1,17 +1,14 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_reports_zero_download_progress_after_load_complete() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let snapshot = cluster .paddler_client diff --git a/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs b/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs index bdd49fc7..75414e5d 100644 --- a/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs +++ b/paddler_tests/tests/management_returns_model_metadata_for_loaded_agent.rs @@ -1,17 +1,14 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] +#![cfg(feature = "tests_that_use_llms")] use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster_with_qwen3::start_subprocess_cluster_with_qwen3; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_returns_model_metadata_for_loaded_agent() -> Result<()> { - let cluster = start_subprocess_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; + let cluster = start_cluster_with_qwen3(AgentConfig::uniform(1, 2)).await?; let agent_id = cluster .agent_ids diff --git a/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_changes.rs b/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_changes.rs index 4c02687a..e51a5f79 100644 --- a/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_changes.rs +++ b/paddler_tests/tests/management_two_agents_stream_subscribers_receive_slot_usage_changes.rs @@ -2,27 +2,20 @@ use anyhow::Context as _; use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::agents_status::assert_slots_processing::assert_slots_processing; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::current_test_device::current_test_device; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::inference_http_client::InferenceHttpClient; +use paddler_messaging::agent_desired_model::AgentDesiredModel; +use paddler_messaging::balancer_desired_state::BalancerDesiredState; +use paddler_messaging::inference_parameters::InferenceParameters; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_test_cluster_harness::collect_generated_tokens::collect_generated_tokens; use paddler_tests::model_card::ModelCard; use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; -use paddler_types::agent_desired_model::AgentDesiredModel; -use paddler_types::balancer_desired_state::BalancerDesiredState; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_tests::start_cluster::start_cluster; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn management_two_agents_stream_subscribers_receive_slot_usage_changes() -> Result<()> { - let device = current_test_device()?; - - device.require_available()?; - let ModelCard { gpu_layer_count, reference, @@ -30,20 +23,23 @@ async fn management_two_agents_stream_subscribers_receive_slot_usage_changes() - let desired_state = BalancerDesiredState { chat_template_override: None, - inference_parameters: device.inference_parameters_for_full_offload(gpu_layer_count), + inference_parameters: InferenceParameters { + n_gpu_layers: gpu_layer_count, + ..InferenceParameters::default() + }, model: AgentDesiredModel::HuggingFace(reference), multimodal_projection: AgentDesiredModel::None, use_chat_template_override: false, }; - let mut cluster = start_in_process_cluster(InProcessClusterParams { - agent: Some(AgentConfig { + let mut cluster = start_cluster(ClusterParams { + agents: vec![AgentConfig { name: "test-agent".to_owned(), slot_count: 1, - }), - desired_state, + }], + desired_state: Some(desired_state), wait_for_slots_ready: true, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; @@ -53,11 +49,8 @@ async fn management_two_agents_stream_subscribers_receive_slot_usage_changes() - .context("cluster must have registered one agent")? .clone(); - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let token_stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let token_stream = cluster + .continue_from_raw_prompt_stream(&ContinueFromRawPromptParams { grammar: None, max_tokens: 8, raw_prompt: "Count to three".to_owned(), @@ -65,16 +58,14 @@ async fn management_two_agents_stream_subscribers_receive_slot_usage_changes() - .await?; cluster - .agents - .until(assert_slots_processing(&agent_id, 1)) + .wait_for_slots_processing(&agent_id, 1) .await .context("agents_stream must emit a snapshot showing slot usage")?; collect_generated_tokens(token_stream).await?; cluster - .agents - .until(assert_slots_processing(&agent_id, 0)) + .wait_for_slots_processing(&agent_id, 0) .await .context("agents_stream must emit a snapshot showing the slot was released")?; diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs index 2a62230e..c1e0155f 100644 --- a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens.rs @@ -1,29 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::ministral_3_in_process_cluster_params::Ministral3InProcessClusterParams; -use paddler_tests::start_in_process_cluster_with_ministral_3::start_in_process_cluster_with_ministral_3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_tests::ministral_3_cluster_params::Ministral3ClusterParams; +use paddler_tests::start_cluster_with_ministral_3::start_cluster_with_ministral_3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn mistral3_internal_endpoint_emits_reasoning_tokens() -> Result<()> { - let cluster = - start_in_process_cluster_with_ministral_3(Ministral3InProcessClusterParams::default()) - .await?; + let cluster = start_cluster_with_ministral_3(Ministral3ClusterParams::default()).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -39,8 +31,6 @@ async fn mistral3_internal_endpoint_emits_reasoning_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs index f04fc4de..96af7eaa 100644 --- a/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -1,28 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_ministral_3_and_mmproj::start_in_process_cluster_with_ministral_3_and_mmproj; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_ministral_3_and_mmproj::start_cluster_with_ministral_3_and_mmproj; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { - let cluster = - start_in_process_cluster_with_ministral_3_and_mmproj(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_ministral_3_and_mmproj(vec![AgentConfig::single(1)]).await?; let image_data_uri = load_test_image_data_uri()?; @@ -40,8 +33,8 @@ async fn mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request() - role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -52,8 +45,6 @@ async fn mistral3_internal_endpoint_emits_reasoning_tokens_for_image_request() - }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs index 25314bb5..172006c5 100644 --- a/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/mistral3_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,44 +1,38 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::ministral_3_in_process_cluster_params::Ministral3InProcessClusterParams; -use paddler_tests::start_in_process_cluster_with_ministral_3::start_in_process_cluster_with_ministral_3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_tests::ministral_3_cluster_params::Ministral3ClusterParams; +use paddler_tests::start_cluster_with_ministral_3::start_cluster_with_ministral_3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn mistral3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_ministral_3(Ministral3InProcessClusterParams { + let cluster = start_cluster_with_ministral_3(Ministral3ClusterParams { deterministic_sampling: true, - ..Ministral3InProcessClusterParams::default() + ..Ministral3ClusterParams::default() }) .await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - let mut location_properties = Map::new(); location_properties.insert( "location".to_owned(), serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -66,8 +60,6 @@ async fn mistral3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/openai_chat_completion_non_streaming_conforms_to_official_schema.rs b/paddler_tests/tests/openai_chat_completion_non_streaming_conforms_to_official_schema.rs new file mode 100644 index 00000000..85869c60 --- /dev/null +++ b/paddler_tests/tests/openai_chat_completion_non_streaming_conforms_to_official_schema.rs @@ -0,0 +1,33 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn openai_chat_completion_non_streaming_conforms_to_official_schema() -> Result<()> { + let validator = OpenAIValidator::new()?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let request = json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "Say hello."}], + "max_completion_tokens": 200, + "stream": false + }); + + validator.validate_chat_completion_request(&request)?; + + let response = cluster + .openai_chat_completion_non_streaming(&request) + .await?; + + validator.validate_chat_completion_response(&response)?; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/openai_chat_completion_streaming_conforms_to_official_schema.rs b/paddler_tests/tests/openai_chat_completion_streaming_conforms_to_official_schema.rs new file mode 100644 index 00000000..bf783c92 --- /dev/null +++ b/paddler_tests/tests/openai_chat_completion_streaming_conforms_to_official_schema.rs @@ -0,0 +1,36 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn openai_chat_completion_streaming_conforms_to_official_schema() -> Result<()> { + let validator = OpenAIValidator::new()?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let request = json!({ + "model": "qwen3-test", + "messages": [{"role": "user", "content": "Say hello."}], + "max_completion_tokens": 200, + "stream": true, + "stream_options": {"include_usage": true} + }); + + validator.validate_chat_completion_request(&request)?; + + let chunks = cluster.openai_chat_completion_streaming(&request).await?; + + assert!(!chunks.is_empty(), "expected at least one streaming chunk"); + + for chunk in &chunks { + validator.validate_chat_completion_stream_chunk(chunk)?; + } + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/openai_responses_non_streaming_conforms_to_official_schema.rs b/paddler_tests/tests/openai_responses_non_streaming_conforms_to_official_schema.rs new file mode 100644 index 00000000..a5acc3a1 --- /dev/null +++ b/paddler_tests/tests/openai_responses_non_streaming_conforms_to_official_schema.rs @@ -0,0 +1,31 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn openai_responses_non_streaming_conforms_to_official_schema() -> Result<()> { + let validator = OpenAIValidator::new()?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let request = json!({ + "model": "qwen3-test", + "input": "Say hello.", + "max_output_tokens": 200, + "stream": false + }); + + validator.validate_responses_request(&request)?; + + let response = cluster.openai_responses_non_streaming(&request).await?; + + validator.validate_responses_response(&response)?; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/openai_responses_streaming_conforms_to_official_schema.rs b/paddler_tests/tests/openai_responses_streaming_conforms_to_official_schema.rs new file mode 100644 index 00000000..08927d3c --- /dev/null +++ b/paddler_tests/tests/openai_responses_streaming_conforms_to_official_schema.rs @@ -0,0 +1,57 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_openai_response_format_validator::openai_validator::OpenAIValidator; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn openai_responses_streaming_conforms_to_official_schema() -> Result<()> { + let validator = OpenAIValidator::new()?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let request = json!({ + "model": "qwen3-test", + "input": "Say hello.", + "max_output_tokens": 200, + "stream": true + }); + + validator.validate_responses_request(&request)?; + + let events = cluster.openai_responses_streaming(&request).await?; + + assert!(!events.is_empty(), "expected at least one streaming event"); + + for event in &events { + validator.validate_responses_stream_event(event)?; + } + + assert_eq!( + events.first().and_then(|event| event["type"].as_str()), + Some("response.created"), + "the responses stream must begin with response.created" + ); + assert_eq!( + events.last().and_then(|event| event["type"].as_str()), + Some("response.completed"), + "the responses stream must terminate with response.completed" + ); + + let sequence_numbers: Vec = events + .iter() + .filter_map(|event| event["sequence_number"].as_u64()) + .collect(); + + assert_eq!( + sequence_numbers, + (0..sequence_numbers.len() as u64).collect::>(), + "sequence numbers must be a gapless run starting at 0" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs b/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs deleted file mode 100644 index 95c0effe..00000000 --- a/paddler_tests/tests/paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process.rs +++ /dev/null @@ -1,102 +0,0 @@ -#![cfg(all( - unix, - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::sync::Arc; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; - -use anyhow::Context as _; -use anyhow::Result; -use nix::sys::signal::Signal; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_embedding_results::collect_embedding_results; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::embedding_input_document::EmbeddingInputDocument; -use paddler_types::embedding_normalization_method::EmbeddingNormalizationMethod; -use paddler_types::inference_parameters::InferenceParameters; -use paddler_types::request_params::GenerateEmbeddingBatchParams; -use reqwest::Client; -use tokio::signal::unix::SignalKind; -use tokio::signal::unix::signal; -use tokio_util::sync::CancellationToken; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn paddler_subprocess_cluster_does_not_leak_sigurg_to_parent_process() -> Result<()> { - let observed_sigurg_count = Arc::new(AtomicUsize::new(0)); - let observer_shutdown = CancellationToken::new(); - - let observer_count = observed_sigurg_count.clone(); - let observer_token = observer_shutdown.clone(); - let mut sigurg_stream = signal(SignalKind::from_raw(Signal::SIGURG as i32)) - .context("failed to install SIGURG observer on the test process")?; - - let observer_handle = tokio::spawn(async move { - loop { - tokio::select! { - () = observer_token.cancelled() => break, - signal_event = sigurg_stream.recv() => match signal_event { - Some(()) => { - observer_count.fetch_add(1, Ordering::SeqCst); - } - None => break, - }, - } - } - }); - - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(2, 2), - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - ..Qwen3EmbeddingClusterParams::default() - }) - .await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let input_batch: Vec = (0..4) - .map(|document_index| EmbeddingInputDocument { - content: format!("SIGURG regression document number {document_index:02}"), - id: format!("doc-{document_index}"), - }) - .collect(); - let params = GenerateEmbeddingBatchParams { - input_batch, - normalization_method: EmbeddingNormalizationMethod::None, - }; - - let stream = inference_client - .post_generate_embedding_batch(¶ms) - .await?; - let collected = collect_embedding_results(stream).await?; - - assert_eq!(collected.embeddings.len(), 4); - assert!(collected.errors.is_empty()); - - cluster.shutdown().await?; - - observer_shutdown.cancel(); - observer_handle - .await - .context("SIGURG observer task panicked")?; - - let final_sigurg_count = observed_sigurg_count.load(Ordering::SeqCst); - - assert_eq!( - final_sigurg_count, 0, - "paddler subprocesses leaked {final_sigurg_count} SIGURG signals to the parent process; \ - this would kill bash test harness loops that rely on SIGURG's default ignore action being honored. \ - The observer ran throughout cluster startup, an embedding inference, and cluster shutdown." - ); - - Ok(()) -} diff --git a/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs b/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs index 0d2919f7..828fd7cb 100644 --- a/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs +++ b/paddler_tests/tests/qwen25vl_generates_tokens_from_image_input.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen2_5_vl::start_in_process_cluster_with_qwen2_5_vl; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen2_5_vl::start_cluster_with_qwen2_5_vl; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen2_5_vl(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen2_5_vl(vec![AgentConfig::single(1)]).await?; let image_data_uri = load_test_image_data_uri()?; @@ -40,8 +34,8 @@ async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -52,8 +46,6 @@ async fn qwen25vl_generates_tokens_from_image_input() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs b/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs index aee344cf..e4706913 100644 --- a/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs +++ b/paddler_tests/tests/qwen35_generates_tokens_for_long_system_and_user_prompt.rs @@ -1,17 +1,14 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; fn build_long_link_list() -> String { let mut lines: Vec = Vec::new(); @@ -53,10 +50,7 @@ fn build_long_link_list() -> String { #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let system_prompt = "You are a focused web crawler assistant. All elements on each page are collected automatically. Your only job is to decide which links to FOLLOW to discover more relevant pages.\n\nGiven a user's goal and the followable links extracted from a web page, decide which links are worth following to find more content matching the goal.\n\nRespond with JSON only:\n{\"follow\": [1, 3]}\n\nRules:\n- \"follow\": original indices of link elements worth following\n- Reject links that are clearly irrelevant to the goal\n- Prefer following PrimaryListing links on index/listing pages\n- Follow pagination links if more matching content is likely on subsequent pages\n- If no links are worth following, return {\"follow\": []}"; @@ -76,8 +70,8 @@ async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> }, ]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -88,8 +82,6 @@ async fn qwen35_generates_tokens_for_long_system_and_user_prompt() -> Result<()> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs b/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs index d9b53524..244c662a 100644 --- a/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs +++ b/paddler_tests/tests/qwen35_generation_stops_at_eog_before_max_tokens.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("hi".to_owned()), @@ -36,8 +30,6 @@ async fn qwen35_generation_stops_at_eog_before_max_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs index 2dbfe3e3..e90b7b9d 100644 --- a/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs +++ b/paddler_tests/tests/qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), true).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], true).await?; let image_data_uri = load_test_image_data_uri()?; @@ -39,8 +33,8 @@ async fn qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request() -> role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -51,8 +45,6 @@ async fn qwen35_internal_endpoint_emits_reasoning_tokens_for_image_request() -> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs index aa06133c..6a126f89 100644 --- a/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/qwen35_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn qwen35_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn qwen35_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs index e28c8746..f1307a84 100644 --- a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs +++ b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is two plus two?".to_owned()), @@ -35,8 +29,6 @@ async fn qwen35_internal_endpoint_with_thinking_disabled_emits_only_content_toke }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs index 4089d30f..8e89e0e8 100644 --- a/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -37,8 +31,6 @@ async fn qwen35_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs b/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs index fd9e2a62..ee21d4c9 100644 --- a/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs +++ b/paddler_tests/tests/qwen35_thinking_mode_stops_cleanly_before_max_tokens.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_thinking_mode_stops_cleanly_before_max_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is 2+2?".to_owned()), @@ -36,8 +30,6 @@ async fn qwen35_thinking_mode_stops_cleanly_before_max_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs b/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs index 6fb55f20..fea8a751 100644 --- a/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs +++ b/paddler_tests/tests/qwen35_thinking_multi_turn_conversation_stops_cleanly.rs @@ -1,25 +1,19 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let conversation_history = ConversationHistory::new(vec![ ConversationMessage { @@ -40,8 +34,8 @@ async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { }, ]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -52,8 +46,6 @@ async fn qwen35_thinking_multi_turn_conversation_stops_cleanly() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs b/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs index 54292ff7..03075c4d 100644 --- a/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs +++ b/paddler_tests/tests/qwen35_with_mmproj_generates_tokens_from_image.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), true).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], true).await?; let image_data_uri = load_test_image_data_uri()?; @@ -40,8 +34,8 @@ async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -52,8 +46,6 @@ async fn qwen35_with_mmproj_generates_tokens_from_image() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs b/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs index a92205eb..9c655033 100644 --- a/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs +++ b/paddler_tests/tests/qwen35_with_system_message_completes_with_thinking.rs @@ -1,25 +1,19 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let conversation_history = ConversationHistory::new(vec![ ConversationMessage { @@ -36,8 +30,8 @@ async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { }, ]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -48,8 +42,6 @@ async fn qwen35_with_system_message_completes_with_thinking() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs b/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs index 6c2d0560..8073b3a6 100644 --- a/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs +++ b/paddler_tests/tests/qwen35_with_system_message_completes_without_thinking.rs @@ -1,25 +1,19 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let conversation_history = ConversationHistory::new(vec![ ConversationMessage { @@ -36,8 +30,8 @@ async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { }, ]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -48,8 +42,6 @@ async fn qwen35_with_system_message_completes_without_thinking() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs b/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs index 7f2b6c28..dcebb6ad 100644 --- a/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs +++ b/paddler_tests/tests/qwen35_without_mmproj_rejects_image_with_multimodal_not_supported.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_5::start_in_process_cluster_with_qwen3_5; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_qwen3_5::start_cluster_with_qwen3_5; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_5(AgentConfig::single(1), false).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_5(vec![AgentConfig::single(1)], false).await?; let image_data_uri = load_test_image_data_uri()?; @@ -39,8 +33,8 @@ async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> role: "user".to_owned(), }]); - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -51,19 +45,15 @@ async fn qwen35_without_mmproj_rejects_image_with_multimodal_not_supported() -> }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - - if let Ok(collected) = collected { - assert!( - collected.token_results.iter().any(|result| matches!( - result.token_result, - GeneratedTokenResult::MultimodalNotSupported(_) - )), - "expected MultimodalNotSupported, got: {:?}", - collected.token_results - ); - } + if let Ok(collected) = collected { + assert!( + collected.token_results.iter().any(|result| matches!( + result.token_result, + GeneratedTokenResult::MultimodalNotSupported(_) + )), + "expected MultimodalNotSupported, got: {:?}", + collected.token_results + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs b/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs index c691ac7b..55d80ef4 100644 --- a/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs +++ b/paddler_tests/tests/qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_qwen3_6_and_mmproj::start_in_process_cluster_with_qwen3_6_and_mmproj; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_tests::start_cluster_with_qwen3_6_and_mmproj::start_cluster_with_qwen3_6_and_mmproj; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_6_and_mmproj(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_6_and_mmproj(vec![AgentConfig::single(1)]).await?; let image_data_uri = load_test_image_data_uri()?; @@ -39,8 +33,8 @@ async fn qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request() -> role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: true, @@ -51,8 +45,6 @@ async fn qwen36_internal_endpoint_emits_reasoning_tokens_for_image_request() -> }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs index 2e530298..b92a9ec2 100644 --- a/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/qwen36_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_6::start_cluster_with_qwen3_6; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen36_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3_6(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn qwen36_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn qwen36_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs index 65006fc7..578175a1 100644 --- a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs +++ b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_6::start_cluster_with_qwen3_6; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3_6(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is two plus two?".to_owned()), @@ -35,8 +29,6 @@ async fn qwen36_internal_endpoint_with_thinking_disabled_emits_only_content_toke }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs index b0161740..c69e7de0 100644 --- a/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3_6::start_in_process_cluster_with_qwen3_6; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3_6::start_cluster_with_qwen3_6; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3_6(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3_6(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -37,8 +31,6 @@ async fn qwen36_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs b/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs index 2f77f78d..cc866af8 100644 --- a/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs +++ b/paddler_tests/tests/qwen3_gbnf_grammar_constrains_output_to_yes_or_no.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_gbnf_grammar_constrains_output_to_yes_or_no() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: Some(GrammarConstraint::Gbnf { grammar: r#"root ::= "yes" | "no""#.to_owned(), root: "root".to_owned(), @@ -28,8 +22,6 @@ async fn qwen3_gbnf_grammar_constrains_output_to_yes_or_no() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - assert!( collected.text == "yes" || collected.text == "no", "expected 'yes' or 'no', got: {:?}", diff --git a/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs b/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs index afb3c816..be3a18a4 100644 --- a/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs +++ b/paddler_tests/tests/qwen3_generates_tokens_from_conversation_history.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_generates_tokens_from_conversation_history() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("hi".to_owned()), @@ -36,8 +30,6 @@ async fn qwen3_generates_tokens_from_conversation_history() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs b/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs index 994d4d04..a67d0900 100644 --- a/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs +++ b/paddler_tests/tests/qwen3_generates_tokens_from_raw_prompt.rs @@ -1,25 +1,19 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_generates_tokens_from_raw_prompt() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 30, raw_prompt: @@ -28,8 +22,6 @@ async fn qwen3_generates_tokens_from_raw_prompt() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs b/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs index 1d1ce2c0..90cc766f 100644 --- a/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs +++ b/paddler_tests/tests/qwen3_grammar_with_thinking_returns_incompatible_error.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_grammar_with_thinking_returns_incompatible_error() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let outcome = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let outcome = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("What is 2+2?".to_owned()), @@ -38,18 +32,15 @@ async fn qwen3_grammar_with_thinking_returns_incompatible_error() -> Result<()> }) .await; - if let Ok(stream) = outcome { - let collected = collect_generated_tokens(stream).await; - if let Ok(collected) = collected { - assert!( - collected.token_results.iter().any(|result| matches!( - result.token_result, - GeneratedTokenResult::GrammarIncompatibleWithThinking(_) - )), - "expected GrammarIncompatibleWithThinking, got: {:?}", - collected.token_results - ); - } + if let Ok(collected) = outcome { + assert!( + collected.token_results.iter().any(|result| matches!( + result.token_result, + GeneratedTokenResult::GrammarIncompatibleWithThinking(_) + )), + "expected GrammarIncompatibleWithThinking, got: {:?}", + collected.token_results + ); } cluster.shutdown().await?; diff --git a/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs b/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs index eac8d072..6952a49b 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_concurrent_requests_independent_usage.rs @@ -2,48 +2,40 @@ use anyhow::Result; use futures_util::future; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::generation_summary::GenerationSummary; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::generation_summary::GenerationSummary; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_concurrent_requests_keep_independent_usage() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(2)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(2)]).await?; let prompts = ["Say hi.", "Count to three."]; let futures = prompts.iter().map(|prompt| { - let client = inference_client.clone(); let prompt = (*prompt).to_owned(); - async move { - let stream = client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { - add_generation_prompt: true, - conversation_history: ConversationHistory::new(vec![ConversationMessage { - content: ConversationMessageContent::Text(prompt), - role: "user".to_owned(), - }]), - enable_thinking: false, - grammar: None, - max_tokens: 30, - parse_tool_calls: false, - tools: vec![], - }) - .await?; + let generation = + cluster.continue_from_conversation_history(&ContinueFromConversationHistoryParams { + add_generation_prompt: true, + conversation_history: ConversationHistory::new(vec![ConversationMessage { + content: ConversationMessageContent::Text(prompt), + role: "user".to_owned(), + }]), + enable_thinking: false, + grammar: None, + max_tokens: 30, + parse_tool_calls: false, + tools: vec![], + }); - let collected = collect_generated_tokens(stream).await?; + async move { + let collected = generation.await?; let last = collected .token_results diff --git a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs index d36384e0..bf6bc467 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_parsed_event.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn qwen3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn qwen3_internal_endpoint_emits_tool_call_parsed_event() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed_events: Vec<&Vec> = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs index 683e970a..18907b77 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_emits_tool_call_tokens.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_emits_tool_call_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn qwen3_internal_endpoint_emits_tool_call_tokens() -> Result<()> { serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn qwen3_internal_endpoint_emits_tool_call_tokens() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let tool_call_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs b/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs index 66119a7c..a5c7774c 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_max_tokens_usage_matches.rs @@ -1,29 +1,23 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; const MAX_TOKENS: i32 = 20; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_max_tokens_usage_matches_streamed_count() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Tell me a long story.".to_owned()), @@ -37,8 +31,6 @@ async fn qwen3_internal_endpoint_max_tokens_usage_matches_streamed_count() -> Re }) .await?; - let collected = collect_generated_tokens(stream).await?; - let streamed_token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs b/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs index 52254469..cc8be358 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_pure_content_usage.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_pure_content_usage_breakdown() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello.".to_owned()), @@ -35,8 +29,6 @@ async fn qwen3_internal_endpoint_pure_content_usage_breakdown() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let last = collected .token_results .last() diff --git a/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs index 2280be84..7b8fd6d1 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens.rs @@ -1,31 +1,25 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use paddler_types::request_params::continue_from_conversation_history_params::tool::Tool; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; -use paddler_types::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::Tool; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::FunctionCall; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::function::Function; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters::Parameters; +use paddler_messaging::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; use serde_json::Map; use serde_json::Value; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; let mut location_properties = Map::new(); location_properties.insert( @@ -33,8 +27,8 @@ async fn qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens() serde_json::json!({"type": "string", "description": "The city name"}), ); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -62,8 +56,6 @@ async fn qwen3_internal_endpoint_tools_without_parse_flag_emit_only_raw_tokens() }) .await?; - let collected = collect_generated_tokens(stream).await?; - for event in &collected.token_results { match &event.token_result { GeneratedTokenResult::ToolCallParsed(_) diff --git a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs index 270611b4..1bf12a2d 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text("Say hello.".to_owned()), @@ -35,8 +29,6 @@ async fn qwen3_internal_endpoint_with_thinking_disabled_emits_no_reasoning_token }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs index 23bb7f11..00aa6b61 100644 --- a/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs +++ b/paddler_tests/tests/qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens.rs @@ -1,27 +1,21 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history: ConversationHistory::new(vec![ConversationMessage { content: ConversationMessageContent::Text( @@ -37,8 +31,6 @@ async fn qwen3_internal_endpoint_with_thinking_enabled_emits_reasoning_tokens() }) .await?; - let collected = collect_generated_tokens(stream).await?; - let reasoning_count = collected .token_results .iter() diff --git a/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs b/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs index ef71c2b7..9c34f9f6 100644 --- a/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs +++ b/paddler_tests/tests/qwen3_json_schema_grammar_returns_valid_json.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::grammar_constraint::GrammarConstraint; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::grammar_constraint::GrammarConstraint; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_json_schema_grammar_returns_valid_json() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: Some(GrammarConstraint::JsonSchema { schema: r#"{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}"#.to_owned(), }), @@ -27,8 +21,6 @@ async fn qwen3_json_schema_grammar_returns_valid_json() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let parsed: serde_json::Value = serde_json::from_str(&collected.text)?; assert!( diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs b/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs index 9117d69f..7ec9d3dd 100644 --- a/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs +++ b/paddler_tests/tests/qwen3_openai_non_streaming_emits_tool_calls_for_function_tool.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_non_streaming_emits_tool_calls_for_function_tool() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; - - let response = openai_client - .post_non_streaming(&json!({ + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let response = cluster + .openai_chat_completion_non_streaming(&json!({ "model": "qwen3-test", "messages": [{ "role": "user", diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs b/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs index 09793da3..096caf69 100644 --- a/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs +++ b/paddler_tests/tests/qwen3_openai_non_streaming_returns_usage.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_non_streaming_returns_usage() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; - - let response = openai_client - .post_non_streaming(&json!({ + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let response = cluster + .openai_chat_completion_non_streaming(&json!({ "model": "qwen3-test", "messages": [{"role": "user", "content": "Say hi briefly."}], "max_completion_tokens": 600 diff --git a/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs b/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs index 366064bb..bd9e39a5 100644 --- a/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs +++ b/paddler_tests/tests/qwen3_openai_non_streaming_usage_with_tool_calls.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_non_streaming_usage_with_tool_calls() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let response = openai_client - .post_non_streaming(&json!({ + let response = cluster + .openai_chat_completion_non_streaming(&json!({ "model": "qwen3-test", "messages": [{ "role": "user", diff --git a/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs b/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs index dd8dc261..8d33aa89 100644 --- a/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs +++ b/paddler_tests/tests/qwen3_openai_streaming_emits_tool_calls_for_function_tool.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_streaming_emits_tool_calls_for_function_tool() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let chunks = openai_client - .post_streaming(&json!({ + let chunks = cluster + .openai_chat_completion_streaming(&json!({ "model": "qwen3-test", "messages": [{ "role": "user", diff --git a/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs b/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs index fd2afac4..882554bf 100644 --- a/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs +++ b/paddler_tests/tests/qwen3_openai_streaming_emits_usage_when_requested.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_streaming_emits_usage_when_requested() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let chunks = openai_client - .post_streaming(&json!({ + let chunks = cluster + .openai_chat_completion_streaming(&json!({ "model": "qwen3-test", "messages": [{"role": "user", "content": "Say hi briefly."}], "stream": true, diff --git a/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs b/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs index 1e7e7d4c..30a53f8a 100644 --- a/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs +++ b/paddler_tests/tests/qwen3_openai_streaming_omits_usage_when_not_requested.rs @@ -1,23 +1,17 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_streaming_omits_usage_when_not_requested() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; - - let chunks = openai_client - .post_streaming(&json!({ + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let chunks = cluster + .openai_chat_completion_streaming(&json!({ "model": "qwen3-test", "messages": [{"role": "user", "content": "Say hi briefly."}], "stream": true, diff --git a/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs b/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs deleted file mode 100644 index 2fd50824..00000000 --- a/paddler_tests/tests/qwen3_openai_streaming_routes_reasoning_to_reasoning_content.rs +++ /dev/null @@ -1,47 +0,0 @@ -#![cfg(feature = "tests_that_use_llms")] - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; -use serde_json::Value; -use serde_json::json; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn qwen3_openai_streaming_routes_reasoning_to_reasoning_content() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; - - let chunks = openai_client - .post_streaming(&json!({ - "model": "qwen3-test", - "messages": [{"role": "user", "content": "What is two plus two? Think step by step."}], - "stream": true, - "max_completion_tokens": 600 - })) - .await?; - - let reasoning_chunks = chunks - .iter() - .filter(|chunk| { - chunk - .pointer("/choices/0/delta/reasoning_content") - .and_then(Value::as_str) - .is_some() - }) - .count(); - - assert!( - reasoning_chunks > 0, - "expected at least one delta.reasoning_content chunk; got {reasoning_chunks}" - ); - - cluster.shutdown().await?; - - Ok(()) -} diff --git a/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs b/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs index 227586f8..692a9071 100644 --- a/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs +++ b/paddler_tests/tests/qwen3_openai_streaming_usage_breakdown_with_thinking.rs @@ -1,24 +1,18 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::openai_chat_completions_client::OpenAIChatCompletionsClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use reqwest::Client; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; use serde_json::Value; use serde_json::json; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_openai_streaming_usage_breakdown_with_thinking() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; - let openai_client = OpenAIChatCompletionsClient::new( - Client::new(), - &cluster.addresses.compat_openai_base_url()?, - )?; - - let chunks = openai_client - .post_streaming(&json!({ + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let chunks = cluster + .openai_chat_completion_streaming(&json!({ "model": "qwen3-test", "messages": [{ "role": "user", @@ -26,8 +20,7 @@ async fn qwen3_openai_streaming_usage_breakdown_with_thinking() -> Result<()> { }], "stream": true, "stream_options": {"include_usage": true}, - "max_completion_tokens": 200, - "chat_template_kwargs": {"enable_thinking": true} + "max_completion_tokens": 200 })) .await?; diff --git a/paddler_tests/tests/qwen3_responses_non_streaming_returns_text_and_usage.rs b/paddler_tests/tests/qwen3_responses_non_streaming_returns_text_and_usage.rs new file mode 100644 index 00000000..bc33f203 --- /dev/null +++ b/paddler_tests/tests/qwen3_responses_non_streaming_returns_text_and_usage.rs @@ -0,0 +1,67 @@ +#![cfg(feature = "tests_that_use_llms")] + +use anyhow::Result; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; +use serde_json::Value; +use serde_json::json; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn qwen3_responses_non_streaming_returns_text_and_usage() -> Result<()> { + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; + + let response = cluster + .openai_responses_non_streaming(&json!({ + "model": "qwen3-test", + "input": "Say hi briefly.", + "max_output_tokens": 600 + })) + .await?; + + assert_eq!( + response.get("status").and_then(Value::as_str), + Some("completed"), + "responses status must be completed: {response}" + ); + + let usage = response + .get("usage") + .ok_or_else(|| anyhow::anyhow!("responses response missing usage: {response}"))?; + + let input_tokens = usage + .get("input_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.input_tokens missing"))?; + let output_tokens = usage + .get("output_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.output_tokens missing"))?; + let total_tokens = usage + .get("total_tokens") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("usage.total_tokens missing"))?; + + assert!(input_tokens > 0); + assert!(output_tokens > 0); + assert_eq!(total_tokens, input_tokens + output_tokens); + + let message_text = response + .get("output") + .and_then(Value::as_array) + .ok_or_else(|| anyhow::anyhow!("responses response missing output array"))? + .iter() + .find(|item| item.get("type").and_then(Value::as_str) == Some("message")) + .and_then(|message| message.pointer("/content/0/text")) + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("responses output has no message text: {response}"))?; + + assert!( + !message_text.is_empty(), + "responses message text must not be empty" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs b/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs index 783ffcdc..50c5249e 100644 --- a/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs +++ b/paddler_tests/tests/qwen3_without_grammar_generates_unconstrained_output.rs @@ -1,31 +1,23 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::start_in_process_cluster_with_qwen3::start_in_process_cluster_with_qwen3; -use paddler_types::request_params::ContinueFromRawPromptParams; -use reqwest::Client; +use paddler_messaging::request_params::continue_from_raw_prompt_params::ContinueFromRawPromptParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_tests::start_cluster_with_qwen3::start_cluster_with_qwen3; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn qwen3_without_grammar_generates_unconstrained_output() -> Result<()> { - let cluster = start_in_process_cluster_with_qwen3(AgentConfig::single(1)).await?; + let cluster = start_cluster_with_qwen3(vec![AgentConfig::single(1)]).await?; - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); - - let stream = inference_client - .post_continue_from_raw_prompt(&ContinueFromRawPromptParams { + let collected = cluster + .continue_from_raw_prompt(&ContinueFromRawPromptParams { grammar: None, max_tokens: 20, raw_prompt: "<|im_start|>user\nSay hello<|im_end|>\n<|im_start|>assistant\n\n\n\n\n".to_owned(), }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/single_agent_in_process_cluster_shutdown_completes_within_five_seconds.rs b/paddler_tests/tests/single_agent_cluster_shutdown_completes_within_five_seconds.rs similarity index 66% rename from paddler_tests/tests/single_agent_in_process_cluster_shutdown_completes_within_five_seconds.rs rename to paddler_tests/tests/single_agent_cluster_shutdown_completes_within_five_seconds.rs index 5c6e7779..b72d9580 100644 --- a/paddler_tests/tests/single_agent_in_process_cluster_shutdown_completes_within_five_seconds.rs +++ b/paddler_tests/tests/single_agent_cluster_shutdown_completes_within_five_seconds.rs @@ -2,17 +2,17 @@ use std::time::Duration; use anyhow::Context as _; use anyhow::Result; -use paddler_tests::in_process_cluster_params::InProcessClusterParams; -use paddler_tests::start_in_process_cluster::start_in_process_cluster; +use paddler_test_cluster_harness::cluster_params::ClusterParams; +use paddler_tests::start_cluster::start_cluster; use tokio::time::timeout; const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5); #[tokio::test(flavor = "multi_thread")] -async fn single_agent_in_process_cluster_shutdown_completes_within_five_seconds() -> Result<()> { - let cluster = start_in_process_cluster(InProcessClusterParams { +async fn single_agent_cluster_shutdown_completes_within_five_seconds() -> Result<()> { + let cluster = start_cluster(ClusterParams { wait_for_slots_ready: false, - ..InProcessClusterParams::default() + ..ClusterParams::default() }) .await?; diff --git a/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs b/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs index bbfe0466..c56fc593 100644 --- a/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs +++ b/paddler_tests/tests/smolvlm2_generates_tokens_from_image_input.rs @@ -1,28 +1,22 @@ #![cfg(feature = "tests_that_use_llms")] use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::collect_generated_tokens::collect_generated_tokens; -use paddler_tests::inference_http_client::InferenceHttpClient; -use paddler_tests::load_test_image_data_uri::load_test_image_data_uri; -use paddler_tests::start_in_process_cluster_with_smolvlm2::start_in_process_cluster_with_smolvlm2; -use paddler_tests::token_result_with_producer::TokenResultWithProducer; -use paddler_types::conversation_history::ConversationHistory; -use paddler_types::conversation_message::ConversationMessage; -use paddler_types::conversation_message_content::ConversationMessageContent; -use paddler_types::conversation_message_content_part::ConversationMessageContentPart; -use paddler_types::generated_token_result::GeneratedTokenResult; -use paddler_types::image_url::ImageUrl; -use paddler_types::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; -use reqwest::Client; +use paddler_messaging::conversation_history::ConversationHistory; +use paddler_messaging::conversation_message::ConversationMessage; +use paddler_messaging::conversation_message_content::ConversationMessageContent; +use paddler_messaging::conversation_message_content_part::ConversationMessageContentPart; +use paddler_messaging::generated_token_result::GeneratedTokenResult; +use paddler_messaging::image_url::ImageUrl; +use paddler_messaging::request_params::continue_from_conversation_history_params::ContinueFromConversationHistoryParams; +use paddler_test_cluster_harness::agent_config::AgentConfig; +use paddler_test_cluster_harness::load_test_image_data_uri::load_test_image_data_uri; +use paddler_test_cluster_harness::token_result_with_producer::TokenResultWithProducer; +use paddler_tests::start_cluster_with_smolvlm2::start_cluster_with_smolvlm2; #[serial_test::file_serial(model_load, path => "../target/model_load.lock")] #[tokio::test(flavor = "multi_thread")] async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { - let cluster = start_in_process_cluster_with_smolvlm2(AgentConfig::single(1)).await?; - - let inference_client = - InferenceHttpClient::new(Client::new(), cluster.addresses.inference_base_url()?); + let cluster = start_cluster_with_smolvlm2(vec![AgentConfig::single(1)]).await?; let image_data_uri = load_test_image_data_uri()?; @@ -40,8 +34,8 @@ async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { role: "user".to_owned(), }]); - let stream = inference_client - .post_continue_from_conversation_history(&ContinueFromConversationHistoryParams { + let collected = cluster + .continue_from_conversation_history(&ContinueFromConversationHistoryParams { add_generation_prompt: true, conversation_history, enable_thinking: false, @@ -52,8 +46,6 @@ async fn smolvlm2_generates_tokens_from_image_input() -> Result<()> { }) .await?; - let collected = collect_generated_tokens(stream).await?; - let token_count = collected .token_results .iter() diff --git a/paddler_tests/tests/subprocess_cluster_shutdown_completes_within_five_seconds.rs b/paddler_tests/tests/subprocess_cluster_shutdown_completes_within_five_seconds.rs deleted file mode 100644 index 22206941..00000000 --- a/paddler_tests/tests/subprocess_cluster_shutdown_completes_within_five_seconds.rs +++ /dev/null @@ -1,36 +0,0 @@ -#![cfg(feature = "tests_that_use_compiled_paddler")] - -use std::time::Duration; - -use anyhow::Context as _; -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; -use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; -use tokio::time::timeout; - -const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5); - -#[tokio::test(flavor = "multi_thread")] -async fn subprocess_cluster_shutdown_completes_within_five_seconds() -> Result<()> { - let cluster = start_subprocess_cluster(SubprocessClusterParams { - agents: AgentConfig::uniform(1, 4), - wait_for_slots_ready: false, - ..SubprocessClusterParams::default() - }) - .await?; - - assert_eq!(cluster.agent_ids.len(), 1); - - timeout(SHUTDOWN_BUDGET, cluster.shutdown()) - .await - .with_context(|| { - format!( - "subprocess cluster shutdown did not complete within {SHUTDOWN_BUDGET:?}; \ - SIGTERM was sent but at least one child paddler process did not exit in time, \ - or balancer service drain did not return promptly" - ) - })??; - - Ok(()) -} diff --git a/paddler_tests/tests/subprocess_cluster_shutdown_returns_fd_count_to_baseline.rs b/paddler_tests/tests/subprocess_cluster_shutdown_returns_fd_count_to_baseline.rs deleted file mode 100644 index f7f4e759..00000000 --- a/paddler_tests/tests/subprocess_cluster_shutdown_returns_fd_count_to_baseline.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - any(target_os = "macos", target_os = "linux") -))] - -use anyhow::Result; -use paddler_tests::resource_snapshot::ResourceSnapshot; -use paddler_tests::subprocess_cluster_lifecycle_in_dedicated_runtime::subprocess_cluster_lifecycle_in_dedicated_runtime; - -#[test] -fn subprocess_cluster_shutdown_returns_fd_count_to_baseline() -> Result<()> { - subprocess_cluster_lifecycle_in_dedicated_runtime()?; - - let before = ResourceSnapshot::try_from_self()?; - - subprocess_cluster_lifecycle_in_dedicated_runtime()?; - - let after = ResourceSnapshot::try_from_self()?; - let diff = after.diff(&before); - - assert_eq!( - diff.open_file_descriptors_grew_by, - 0, - "subprocess cluster lifecycle leaked file descriptors across a complete tokio runtime lifecycle: {summary}", - summary = diff.pretty_summary(), - ); - - Ok(()) -} diff --git a/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs b/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs deleted file mode 100644 index 6057a314..00000000 --- a/paddler_tests/tests/subprocess_cluster_starts_four_agents_within_sequential_spawn_budget.rs +++ /dev/null @@ -1,60 +0,0 @@ -#![cfg(all( - feature = "tests_that_use_compiled_paddler", - feature = "tests_that_use_llms" -))] - -use std::time::Duration; -use std::time::Instant; - -use anyhow::Result; -use paddler_tests::agent_config::AgentConfig; -use paddler_tests::qwen3_embedding_cluster_params::Qwen3EmbeddingClusterParams; -use paddler_tests::start_subprocess_cluster_with_qwen3_embedding::start_subprocess_cluster_with_qwen3_embedding; -use paddler_types::inference_parameters::InferenceParameters; - -#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] -#[tokio::test(flavor = "multi_thread")] -async fn subprocess_cluster_starts_four_agents_within_sequential_spawn_budget() -> Result<()> { - let agent_count: usize = 4; - let single_agent_init_budget = Duration::from_secs(8); - let cluster_overhead_budget = Duration::from_secs(8); - #[expect( - clippy::cast_possible_truncation, - reason = "agent_count is a fixed test constant that fits in u32" - )] - let cluster_startup_budget = - single_agent_init_budget * (agent_count as u32) + cluster_overhead_budget; - - let cluster_startup_started_at = Instant::now(); - - let cluster = start_subprocess_cluster_with_qwen3_embedding(Qwen3EmbeddingClusterParams { - agents: AgentConfig::uniform(agent_count, 2), - inference_parameters: InferenceParameters { - enable_embeddings: true, - ..InferenceParameters::default() - }, - ..Qwen3EmbeddingClusterParams::default() - }) - .await?; - - let cluster_startup_elapsed = cluster_startup_started_at.elapsed(); - - assert_eq!( - cluster.agent_ids.len(), - agent_count, - "expected {agent_count} agents to register; got {actual}", - actual = cluster.agent_ids.len(), - ); - - cluster.shutdown().await?; - - assert!( - cluster_startup_elapsed <= cluster_startup_budget, - "cluster startup took {cluster_startup_elapsed:?}, expected within {cluster_startup_budget:?}. \ - Under concurrent agent spawn on Metal, kernel-compile contention can starve a single agent \ - for 60-120s. Sequential spawn isolates each agent's Metal init and keeps total startup \ - within {single_agent_init_budget:?} per agent plus {cluster_overhead_budget:?} of overhead." - ); - - Ok(()) -} diff --git a/paddler_types/src/agent_issue_params/mod.rs b/paddler_types/src/agent_issue_params/mod.rs deleted file mode 100644 index e6f6b7b0..00000000 --- a/paddler_types/src/agent_issue_params/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod chat_template_does_not_compile_params; -mod hugging_face_download_lock; -mod model_path; -mod slot_cannot_start_params; - -pub use self::chat_template_does_not_compile_params::ChatTemplateDoesNotCompileParams; -pub use self::hugging_face_download_lock::HuggingFaceDownloadLock; -pub use self::model_path::ModelPath; -pub use self::slot_cannot_start_params::SlotCannotStartParams; diff --git a/paddler_types/src/balancer_desired_state.rs b/paddler_types/src/balancer_desired_state.rs deleted file mode 100644 index 4fe79b88..00000000 --- a/paddler_types/src/balancer_desired_state.rs +++ /dev/null @@ -1,33 +0,0 @@ -use serde::Deserialize; -use serde::Serialize; - -use crate::agent_desired_model::AgentDesiredModel; -use crate::agent_desired_state::AgentDesiredState; -use crate::chat_template::ChatTemplate; -use crate::inference_parameters::InferenceParameters; - -#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] -#[serde(deny_unknown_fields)] -pub struct BalancerDesiredState { - pub chat_template_override: Option, - pub inference_parameters: InferenceParameters, - pub model: AgentDesiredModel, - pub multimodal_projection: AgentDesiredModel, - pub use_chat_template_override: bool, -} - -impl BalancerDesiredState { - #[must_use] - pub fn to_agent_desired_state(&self) -> AgentDesiredState { - AgentDesiredState { - chat_template_override: if self.use_chat_template_override { - self.chat_template_override.clone() - } else { - None - }, - inference_parameters: self.inference_parameters.clone(), - model: self.model.clone(), - multimodal_projection: self.multimodal_projection.clone(), - } - } -} diff --git a/paddler_types/src/embedding.rs b/paddler_types/src/embedding.rs deleted file mode 100644 index 4aa6b685..00000000 --- a/paddler_types/src/embedding.rs +++ /dev/null @@ -1,147 +0,0 @@ -use anyhow::Result; -use anyhow::anyhow; -use serde::Deserialize; -use serde::Serialize; - -use crate::embedding_normalization_method::EmbeddingNormalizationMethod; -use crate::normalization::l2; -use crate::normalization::rms_norm; -use crate::pooling_type::PoolingType; - -#[derive(Debug, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -pub struct Embedding { - pub embedding: Vec, - pub normalization_method: EmbeddingNormalizationMethod, - pub pooling_type: PoolingType, - pub source_document_id: String, -} - -impl Embedding { - pub fn normalize(self, normalization_method: &EmbeddingNormalizationMethod) -> Result { - if !self - .normalization_method - .can_transform_to(normalization_method) - { - return Err(anyhow!( - "Cannot transform from {:?} to {normalization_method:?}", - self.normalization_method - )); - } - - if !self - .normalization_method - .needs_transformation_to(normalization_method) - { - return Ok(self); - } - - Ok(Self { - embedding: match normalization_method { - EmbeddingNormalizationMethod::None => self.embedding, - EmbeddingNormalizationMethod::L2 => l2(&self.embedding), - EmbeddingNormalizationMethod::RmsNorm { epsilon } => { - rms_norm(&self.embedding, *epsilon) - } - }, - normalization_method: normalization_method.clone(), - pooling_type: self.pooling_type.clone(), - source_document_id: self.source_document_id, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn make_embedding(values: Vec, method: EmbeddingNormalizationMethod) -> Embedding { - Embedding { - embedding: values, - normalization_method: method, - pooling_type: PoolingType::Mean, - source_document_id: "test".to_owned(), - } - } - - #[test] - fn test_normalize_from_none_to_l2() -> Result<()> { - let embedding = make_embedding(vec![3.0, 4.0], EmbeddingNormalizationMethod::None); - let result = embedding.normalize(&EmbeddingNormalizationMethod::L2)?; - - assert_eq!(result.embedding, vec![0.6, 0.8]); - assert!(matches!( - result.normalization_method, - EmbeddingNormalizationMethod::L2 - )); - - Ok(()) - } - - #[test] - fn test_normalize_from_none_to_rms_norm() -> Result<()> { - let embedding = - make_embedding(vec![2.0, 2.0, 2.0, 2.0], EmbeddingNormalizationMethod::None); - let result = - embedding.normalize(&EmbeddingNormalizationMethod::RmsNorm { epsilon: 0.0 })?; - - for val in &result.embedding { - assert!((val - 1.0).abs() < 1e-6); - } - - Ok(()) - } - - #[test] - fn test_normalize_none_to_none_is_noop() -> Result<()> { - let embedding = make_embedding(vec![1.0, 2.0, 3.0], EmbeddingNormalizationMethod::None); - let result = embedding.normalize(&EmbeddingNormalizationMethod::None)?; - - assert_eq!(result.embedding, vec![1.0, 2.0, 3.0]); - - Ok(()) - } - - #[test] - fn test_normalize_rejects_l2_to_rms_norm() { - let embedding = make_embedding(vec![0.6, 0.8], EmbeddingNormalizationMethod::L2); - let result = embedding.normalize(&EmbeddingNormalizationMethod::RmsNorm { epsilon: 1e-6 }); - - assert!(result.is_err()); - } - - #[test] - fn test_normalize_rejects_l2_to_none() { - let embedding = make_embedding(vec![0.6, 0.8], EmbeddingNormalizationMethod::L2); - let result = embedding.normalize(&EmbeddingNormalizationMethod::None); - - assert!(result.is_err()); - } - - #[test] - fn test_normalize_rejects_rms_norm_to_l2() { - let embedding = make_embedding( - vec![1.0, 1.0], - EmbeddingNormalizationMethod::RmsNorm { epsilon: 1e-6 }, - ); - let result = embedding.normalize(&EmbeddingNormalizationMethod::L2); - - assert!(result.is_err()); - } - - #[test] - fn test_normalize_preserves_metadata() -> Result<()> { - let embedding = Embedding { - embedding: vec![3.0, 4.0], - normalization_method: EmbeddingNormalizationMethod::None, - pooling_type: PoolingType::Cls, - source_document_id: "doc-42".to_owned(), - }; - let result = embedding.normalize(&EmbeddingNormalizationMethod::L2)?; - - assert!(matches!(result.pooling_type, PoolingType::Cls)); - assert_eq!(result.source_document_id, "doc-42"); - - Ok(()) - } -} diff --git a/paddler_types/src/inference_client/mod.rs b/paddler_types/src/inference_client/mod.rs deleted file mode 100644 index 5ad2cd7f..00000000 --- a/paddler_types/src/inference_client/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod message; -mod response; - -pub use self::message::Message; -pub use self::response::Response; diff --git a/paddler_types/src/inference_server/mod.rs b/paddler_types/src/inference_server/mod.rs deleted file mode 100644 index 07a444d3..00000000 --- a/paddler_types/src/inference_server/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod message; -mod request; - -pub use self::message::Message; -pub use self::request::Request; diff --git a/paddler_types/src/jsonrpc/mod.rs b/paddler_types/src/jsonrpc/mod.rs deleted file mode 100644 index 844a1882..00000000 --- a/paddler_types/src/jsonrpc/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod error; -mod error_envelope; -mod request_envelope; -mod response_envelope; - -pub use self::error::Error; -pub use self::error_envelope::ErrorEnvelope; -pub use self::request_envelope::RequestEnvelope; -pub use self::response_envelope::ResponseEnvelope; diff --git a/paddler_types/src/kv_cache_dtype.rs b/paddler_types/src/kv_cache_dtype.rs deleted file mode 100644 index cead40ff..00000000 --- a/paddler_types/src/kv_cache_dtype.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::Deserialize; -use serde::Serialize; - -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] -#[expect( - non_camel_case_types, - reason = "variant names mirror ggml type identifiers (e.g. GGML_TYPE_IQ4_NL) for parity with llama.cpp's --cache-type-k/-v" -)] -pub enum KvCacheDtype { - F32, - F16, - BF16, - Q8_0, - Q4_0, - Q4_1, - IQ4_NL, - Q5_0, - Q5_1, -} diff --git a/paddler_types/src/normalization/mod.rs b/paddler_types/src/normalization/mod.rs deleted file mode 100644 index c36cb72e..00000000 --- a/paddler_types/src/normalization/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod l2; -mod rms_norm; - -pub use l2::l2; -pub use rms_norm::rms_norm; diff --git a/paddler_types/src/raw_tool_call_tokens.rs b/paddler_types/src/raw_tool_call_tokens.rs deleted file mode 100644 index 88f39ebf..00000000 --- a/paddler_types/src/raw_tool_call_tokens.rs +++ /dev/null @@ -1,25 +0,0 @@ -use serde::Deserialize; -use serde::Serialize; - -#[derive(Debug, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -pub struct RawToolCallTokens { - pub text: String, - pub ffi_error_message: String, -} - -#[cfg(test)] -mod tests { - use super::RawToolCallTokens; - - #[test] - fn carries_text_and_ffi_error_message() { - let tokens = RawToolCallTokens { - text: "raw payload".to_owned(), - ffi_error_message: "parser bailed".to_owned(), - }; - - assert_eq!(tokens.text, "raw payload"); - assert_eq!(tokens.ffi_error_message, "parser bailed"); - } -} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs deleted file mode 100644 index 546e87ca..00000000 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters.rs +++ /dev/null @@ -1,30 +0,0 @@ -use anyhow::Result; -use serde::Deserialize; -use serde::Serialize; - -use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema; -use crate::validates::Validates; -use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::validated_parameters_schema::ValidatedParametersSchema; - -#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] -#[serde(untagged)] -pub enum Parameters { - #[default] - Empty, - Schema(TParametersSchema), -} - -impl Parameters { - pub const fn is_empty(&self) -> bool { - matches!(self, Self::Empty) - } -} - -impl Validates> for Parameters { - fn validate(self) -> Result> { - match self { - Self::Empty => Ok(Parameters::Empty), - Self::Schema(schema) => Ok(Parameters::Schema(schema.validate()?)), - } - } -} diff --git a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs b/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs deleted file mode 100644 index 023a6680..00000000 --- a/paddler_types/src/request_params/continue_from_conversation_history_params/tool/tool_params/function_call/parameters_schema/raw_parameters_schema.rs +++ /dev/null @@ -1,103 +0,0 @@ -use anyhow::Result; -use anyhow::anyhow; -use serde::Deserialize; -use serde::Serialize; -use serde_json::Map; -use serde_json::Value; - -use super::validated_parameters_schema::ValidatedParametersSchema; -use crate::validates::Validates; - -#[derive(Default, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -pub struct RawParametersSchema { - #[serde(rename = "type")] - pub schema_type: String, - pub properties: Option>, - pub required: Option>, - #[serde(rename = "additionalProperties")] - pub additional_properties: Option, -} - -impl Validates for RawParametersSchema { - fn validate(self) -> Result { - if let (Some(required), Some(properties)) = (&self.required, &self.properties) { - for field in required { - if !properties.contains_key(field) { - return Err(anyhow!("Required field '{field}' not found in properties")); - } - } - } - - Ok(ValidatedParametersSchema { - schema_type: self.schema_type, - properties: self.properties, - required: self.required, - additional_properties: self.additional_properties, - }) - } -} - -#[cfg(test)] -mod tests { - use serde_json::from_value; - use serde_json::json; - - use super::*; - - #[test] - fn test_deserialize_with_valid_properties() -> Result<()> { - let input = json!({ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer", "minimum": 0} - }, - "required": ["name"], - "additionalProperties": false - }); - - let raw_schema: RawParametersSchema = from_value(input)?; - let schema: ValidatedParametersSchema = raw_schema.validate()?; - - assert_eq!(schema.schema_type, "object"); - assert!(schema.properties.is_some()); - - let properties = schema - .properties - .as_ref() - .ok_or_else(|| anyhow!("expected properties"))?; - - assert_eq!(properties.len(), 2); - assert_eq!(schema.required, Some(vec!["name".to_owned()])); - assert_eq!(schema.additional_properties, Some(json!(false))); - - Ok(()) - } - - #[test] - fn test_deserialize_required_field_not_in_properties() -> Result<()> { - let input = json!({ - "type": "object", - "properties": { - "name": {"type": "string"} - }, - "required": ["name", "missing_field"] - }); - - let raw_schema: RawParametersSchema = from_value(input)?; - let result: Result = raw_schema.validate(); - - assert!(result.is_err()); - - if let Err(error) = &result { - assert!( - error - .to_string() - .contains("Required field 'missing_field' not found in properties") - ); - } - - Ok(()) - } -} diff --git a/paddler_types/src/request_params/mod.rs b/paddler_types/src/request_params/mod.rs deleted file mode 100644 index dda6195e..00000000 --- a/paddler_types/src/request_params/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod continue_from_conversation_history_params; -mod continue_from_raw_prompt_params; -mod generate_embedding_batch_params; - -pub use continue_from_raw_prompt_params::ContinueFromRawPromptParams; -pub use generate_embedding_batch_params::ChunkEvenlyWithCapError; -pub use generate_embedding_batch_params::GenerateEmbeddingBatchParams; diff --git a/resources/ts/components/AgentIssues.tsx b/resources/ts/components/AgentIssues.tsx index c9c5a162..e214f694 100644 --- a/resources/ts/components/AgentIssues.tsx +++ b/resources/ts/components/AgentIssues.tsx @@ -233,8 +233,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

- Paddler will keep re-checking, but the same malformed URL - will keep failing the same way. + Paddler will keep re-checking, but the same malformed URL will + keep failing the same way.

What can you do?{" "}

@@ -256,8 +256,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

- Paddler will keep re-checking, but the same 404 will keep - firing until the remote server publishes the file at that URL. + Paddler will keep re-checking, but the same 404 will keep firing + until the remote server publishes the file at that URL.

What can you do?{" "}

@@ -278,14 +278,14 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

- Paddler will keep re-checking; if the server starts - accepting the request, the next attempt will succeed. + Paddler will keep re-checking; if the server starts accepting + the request, the next attempt will succeed.

What can you do?{" "}

- Confirm the URL is correct and reachable without auth. If it's - a private model, switch to a URL that doesn't require - credentials, or use the HuggingFace integration instead. + Confirm the URL is correct and reachable without auth. If it's a + private model, switch to a URL that doesn't require credentials, + or use the HuggingFace integration instead.

); @@ -300,8 +300,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

- Paddler will keep re-checking. If the server starts - answering normally, the next attempt will succeed. + Paddler will keep re-checking. If the server starts answering + normally, the next attempt will succeed.

What can you do?{" "}

@@ -318,20 +318,19 @@ export function AgentIssues({ issues }: { issues: Array }) { return (

  • - Download was interrupted:{" "} - {issue.DownloadInterrupted.model_path} + Download was interrupted: {issue.DownloadInterrupted.model_path} What will Paddler do?{" "}

    - Paddler will keep re-checking. The next attempt resumes - from the bytes already on disk if the server supports Range - requests; otherwise it starts fresh. + Paddler will keep re-checking. The next attempt resumes from the + bytes already on disk if the server supports Range requests; + otherwise it starts fresh.

    What can you do?{" "}

    - Often transient — check network stability and whether the - remote server is being restarted or rate-limiting. No action - needed if it clears on its own. + Often transient — check network stability and whether the remote + server is being restarted or rate-limiting. No action needed if + it clears on its own.

  • ); @@ -346,8 +345,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

    - Paddler will keep re-checking; if the network comes back, - the next attempt will succeed. + Paddler will keep re-checking; if the network comes back, the + next attempt will succeed.

    What can you do?{" "}

    @@ -374,8 +373,8 @@ export function AgentIssues({ issues }: { issues: Array }) {

    The server responded with a 4xx status, meaning the request was rejected (for example bad URL, throttling, or unsupported - method). Verify the model URL is correct and that the host - isn't rate-limiting the agent. + method). Verify the model URL is correct and that the host isn't + rate-limiting the agent.

    ); @@ -412,8 +411,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

    - Paddler will keep re-checking; the moment write permission - is restored, the next attempt will succeed. + Paddler will keep re-checking; the moment write permission is + restored, the next attempt will succeed.

    What can you do?{" "}

    @@ -435,8 +434,8 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

    - Paddler will keep re-checking; the moment space is - available, the next attempt will succeed. + Paddler will keep re-checking; the moment space is available, + the next attempt will succeed.

    What can you do?{" "}

    Free space on the disk that hosts the cache directory.

    @@ -453,14 +452,14 @@ export function AgentIssues({ issues }: { issues: Array }) { What will Paddler do?{" "}

    - Paddler will keep re-checking; the cache will be rebuilt - on the next attempt. + Paddler will keep re-checking; the cache will be rebuilt on the + next attempt.

    What can you do?{" "}

    If the issue persists, manually clear the{" "} - downloaded-models subdirectory of the cache and - let Paddler rebuild it. + downloaded-models subdirectory of the cache and let + Paddler rebuild it.

    ); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index c0b9b8ec..1c638ee2 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.93.0" +channel = "1.95.0" components = ["clippy", "rust-analyzer", "rustfmt"] diff --git a/tsconfig.json b/tsconfig.json index 3c5e36d8..61732104 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -40,6 +40,6 @@ }, "include": [ "jarmuz/**/*", - "resources/ts/**/*", + "resources/ts/**/*" ] } diff --git a/vendor/openai/openai-openapi b/vendor/openai/openai-openapi new file mode 160000 index 00000000..5162af98 --- /dev/null +++ b/vendor/openai/openai-openapi @@ -0,0 +1 @@ +Subproject commit 5162af98d3147432c14680df789e8e12d4891e6b