diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 201df83e0..0a3418b29 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -33,6 +33,39 @@ use tokio::{
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, info, warn};
 
+const ASSIGNMENTS_FILE: &str = "/tmp/psyche-gateway-assignments.json";
+
+fn load_assignments(path: &str) -> HashMap<EndpointId, String> {
+    match fs::read_to_string(path) {
+        Ok(contents) => match serde_json::from_str::<HashMap<EndpointId, String>>(&contents) {
+            Ok(assignments) => {
+                info!(
+                    "Loaded {} model assignments from {}",
+                    assignments.len(),
+                    path
+                );
+                assignments
+            }
+            Err(e) => {
+                warn!("Failed to parse assignments file: {:#}", e);
+                HashMap::new()
+            }
+        },
+        Err(_) => {
+            info!("No assignments file found at {}, starting fresh", path);
+            HashMap::new()
+        }
+    }
+}
+
+fn save_assignments(path: &str, assignments: &HashMap<EndpointId, String>) -> Result<()> {
+    let json =
+        serde_json::to_string_pretty(assignments).context("Failed to serialize assignments")?;
+    fs::write(path, json).context("Failed to write assignments file")?;
+    debug!("Saved {} model assignments to {}", assignments.len(), path);
+    Ok(())
+}
+
 #[derive(Parser, Debug)]
 struct Args {
     #[arg(long, default_value = "0.0.0.0:8000")]
@@ -67,6 +100,7 @@ struct InferenceNodeInfo {
 struct GatewayState {
     available_nodes: RwLock<HashMap<EndpointId, InferenceNodeInfo>>,
     pending_requests: RwLock<HashMap<String, oneshot::Sender<InferenceMessage>>>,
+    model_assignments: RwLock<HashMap<EndpointId, String>>, // node_id -> assigned model name
     network_tx: mpsc::Sender<(EndpointId, InferenceMessage)>,
     gossip_tx: mpsc::Sender<InferenceGossipMessage>,
     endpoint_addr: EndpointAddr,
@@ -102,23 +136,45 @@ fn default_top_p() -> Option<f64> {
     Some(1.0)
 }
 
-#[derive(serde::Deserialize, Debug, Clone)]
-#[serde(tag = "source_type", rename_all = "lowercase")]
-enum LoadModelSource {
-    #[serde(rename = "huggingface")]
-    HuggingFace {
-        source_path: Option<String>,
-    },
-    Local {
-        source_path: String,
-    },
+#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[serde(rename_all = "lowercase")]
+enum ModelSourceType {
+    #[default]
+    HuggingFace,
+    Local,
+}
+
+#[derive(serde::Deserialize)]
+struct AssignModelsRequest {
+    assignments: Vec<ModelAssignmentSpec>,
 }
 
 #[derive(serde::Deserialize)]
-struct LoadModelRequest {
+struct ModelAssignmentSpec {
+    model_name: String,
+    #[serde(default)]
+    source_type: ModelSourceType,
+    #[serde(default)]
+    source_path: Option<String>,
+    num_nodes: usize,
+}
+
+#[derive(serde::Serialize)]
+#[serde(rename_all = "lowercase")]
+enum AssignmentStatus {
+    Unassigned,
+    UnassignedWithModel,
+    Loading,
+    Loaded,
+    Idle,
+    Offline,
+}
+
+#[derive(serde::Serialize)]
+struct AssignmentInfo {
+    node_id: String,
     model_name: String,
-    #[serde(flatten)]
-    source: LoadModelSource,
+    status: AssignmentStatus,
 }
 
 #[derive(serde::Serialize)]
@@ -144,30 +200,70 @@ async fn handle_inference(
     Json(req): Json<ChatCompletionRequest>,
 ) -> Result<Json<ChatCompletionResponse>, AppError> {
     let nodes = state.available_nodes.read().await;
+    let assignments = state.model_assignments.read().await;
+
+    let requested_model = req.model.as_deref();
+
+    let suitable_nodes: Vec<(EndpointId, String)> = if let Some(model) = requested_model {
+        let assigned_and_loaded: Vec<_> = nodes
+            .values()
+            .filter_map(|n| {
+                if assignments
+                    .get(&n.peer_id)
+                    .map(|assigned| assigned == model)
+                    .unwrap_or(false)
+                    && n.model_name.as_deref() == Some(model)
+                {
+                    Some((n.peer_id, n.model_name.clone()?))
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        if !assigned_and_loaded.is_empty() {
+            assigned_and_loaded
+        } else {
+            nodes
+                .values()
+                .filter_map(|n| {
+                    if n.model_name.as_deref() == Some(model) {
+                        Some((n.peer_id, n.model_name.clone()?))
+                    } else {
+                        None
+                    }
+                })
+                .collect()
+        }
+    } else {
+        nodes
+            .values()
+            .filter_map(|n| Some((n.peer_id, n.model_name.clone()?)))
+            .collect()
+    };
 
-    let nodes_with_model: Vec<(EndpointId, String)> = nodes
-        .values()
-        .filter_map(|n| Some((n.peer_id, n.model_name.clone()?)))
-        .collect();
-
-    if nodes_with_model.is_empty() {
-        // No nodes have models loaded yet
+    if suitable_nodes.is_empty() {
         return Err(AppError::NoNodesAvailable);
     }
 
     // Select first available node with a model
     // TODO: Add load balancing and model-specific routing in the future
-    let (target_peer_id, node_model_name) = &nodes_with_model[0];
+    let (target_peer_id, node_model_name) = &suitable_nodes[0];
     let target_peer_id = *target_peer_id;
     let model_name = req.model.clone().unwrap_or_else(|| node_model_name.clone());
 
     info!(
-        "Routing request to node: {} (model: {})",
+        "Routing request to node: {} (model: {}, assigned: {})",
         target_peer_id.fmt_short(),
-        node_model_name
+        node_model_name,
+        assignments
+            .get(&target_peer_id)
+            .map(|s| s.as_str())
+            .unwrap_or("none")
     );
 
     drop(nodes);
+    drop(assignments);
 
     let messages: Vec<ChatMessage> = req
         .messages
@@ -234,42 +330,106 @@ async fn handle_inference(
 }
 
 #[axum::debug_handler]
-async fn handle_load_model(
+async fn handle_assign_models(
     State(state): State<Arc<GatewayState>>,
-    Json(req): Json<LoadModelRequest>,
+    Json(req): Json<AssignModelsRequest>,
 ) -> Result<String, AppError> {
     use psyche_inference::ModelSource;
 
     info!(
-        "Admin API: Received LoadModel request for model: {} (source: {:?})",
-        req.model_name, req.source
+        "Admin API: Received assign-models request with {} specs",
+        req.assignments.len()
     );
 
-    let model_source = match req.source {
-        LoadModelSource::HuggingFace { source_path } => {
-            let path = source_path.unwrap_or_else(|| req.model_name.clone());
-            ModelSource::HuggingFace(path)
+    let mut assigned_count = 0;
+    let mut total_requested = 0;
+
+    for spec in req.assignments {
+        total_requested += spec.num_nodes;
+
+        info!(
+            "Assigning {} nodes to model: {}",
+            spec.num_nodes, spec.model_name
+        );
+
+        let nodes = state.available_nodes.read().await;
+        let assignments = state.model_assignments.read().await;
+
+        let idle_nodes: Vec<EndpointId> = nodes
+            .keys()
+            .filter(|node_id| !assignments.contains_key(*node_id))
+            .copied()
+            .take(spec.num_nodes)
+            .collect();
+
+        if idle_nodes.len() < spec.num_nodes {
+            warn!(
+                "Only {} idle nodes available, requested {}",
+                idle_nodes.len(),
+                spec.num_nodes
+            );
         }
-        LoadModelSource::Local { source_path } => ModelSource::Local(source_path),
-    };
 
-    let load_msg = InferenceGossipMessage::LoadModel {
-        model_name: req.model_name.clone(),
-        model_source,
-    };
+        drop(nodes);
+        drop(assignments);
 
-    state.gossip_tx.send(load_msg).await.map_err(|e| {
-        error!("Failed to broadcast LoadModel message: {:#}", e);
-        AppError::InternalError
-    })?;
+        let model_source = match spec.source_type {
+            ModelSourceType::HuggingFace => {
+                let path = spec.source_path.unwrap_or_else(|| spec.model_name.clone());
+                ModelSource::HuggingFace(path)
+            }
+            ModelSourceType::Local => {
+                let path = spec.source_path.ok_or_else(|| {
+                    AppError::BadRequest("source_path is required for local models".to_string())
+                })?;
+                ModelSource::Local(path)
+            }
+        };
+
+        for node_id in idle_nodes {
+            state
+                .model_assignments
+                .write()
+                .await
+                .insert(node_id, spec.model_name.clone());
+
+            let load_msg = InferenceGossipMessage::LoadModel {
+                target_node_id: Some(node_id),
+                model_name: spec.model_name.clone(),
+                model_source: model_source.clone(),
+            };
+
+            if let Err(e) = state.gossip_tx.send(load_msg).await {
+                error!(
+                    "Failed to send LoadModel to node {}: {:#}",
+                    node_id.fmt_short(),
+                    e
+                );
+            } else {
+                info!(
+                    "Sent LoadModel to node {} for model {}",
+                    node_id.fmt_short(),
+                    spec.model_name
+                );
+                assigned_count += 1;
+            }
+        }
+    }
+
+    let assignments = state.model_assignments.read().await;
+    if let Err(e) = save_assignments(ASSIGNMENTS_FILE, &assignments) {
+        error!("Failed to save assignments: {:#}", e);
+    }
+    drop(assignments);
 
     info!(
-        "Successfully broadcasted LoadModel message for: {}",
-        req.model_name
+        "Assignment complete: {} nodes assigned out of {} requested",
+        assigned_count, total_requested
     );
+
     Ok(format!(
-        "LoadModel broadcast sent for model: {}",
-        req.model_name
+        "Assigned {} nodes out of {} requested",
+        assigned_count, total_requested
     ))
 }
 
@@ -282,11 +442,86 @@ async fn handle_bootstrap(State(state): State<Arc<GatewayState>>)
 }
 
+async fn handle_get_assignments(
+    State(state): State<Arc<GatewayState>>,
+) -> Json<Vec<AssignmentInfo>> {
+    let assignments = state.model_assignments.read().await;
+    let nodes = state.available_nodes.read().await;
+
+    let mut result = Vec::new();
+
+    for (node_id, node_info) in nodes.iter() {
+        let (assigned_model, status) = match assignments.get(node_id) {
+            None => {
+                let status = if node_info.model_name.is_some() {
+                    AssignmentStatus::UnassignedWithModel
+                } else {
+                    AssignmentStatus::Unassigned
+                };
+                (None, status)
+            }
+            Some(assigned_model) => {
+                let status = match &node_info.model_name {
+                    None => {
+                        info!(
+                            "Node {} has no model loaded (assigned: {})",
+                            node_id.fmt_short(),
+                            assigned_model
+                        );
+                        AssignmentStatus::Idle
+                    }
+                    Some(current_model) if current_model == assigned_model => {
+                        info!(
+                            "Node {} loaded correct model: {}",
+                            node_id.fmt_short(),
+                            current_model
+                        );
+                        AssignmentStatus::Loaded
+                    }
+                    Some(current_model) => {
+                        info!(
+                            "Node {} has model '{}' but assigned model is '{}'",
+                            node_id.fmt_short(),
+                            current_model,
+                            assigned_model
+                        );
+                        AssignmentStatus::Loading
+                    }
+                };
+                (Some(assigned_model.clone()), status)
+            }
+        };
+
+        result.push(AssignmentInfo {
+            node_id: node_id.to_string(),
+            model_name: assigned_model.unwrap_or_else(|| "".to_string()),
+            status,
+        });
+    }
+
+    for (node_id, assigned_model) in assignments.iter() {
+        if !nodes.contains_key(node_id) {
+            info!(
+                "Node {} not in available_nodes (offline)",
+                node_id.fmt_short()
+            );
+            result.push(AssignmentInfo {
+                node_id: node_id.to_string(),
+                model_name: assigned_model.clone(),
+                status: AssignmentStatus::Offline,
+            });
+        }
+    }
+
+    Json(result)
+}
+
 #[derive(Debug)]
 enum AppError {
     NoNodesAvailable,
     Timeout,
     InternalError,
+    BadRequest(String),
 }
 
 impl IntoResponse for AppError {
@@ -304,6 +539,7 @@ impl IntoResponse for AppError {
                 StatusCode::INTERNAL_SERVER_ERROR,
                 "Internal server error".to_string(),
             ),
+            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
         };
         (status, message).into_response()
     }
@@ -448,10 +684,13 @@ async fn run_gateway() -> Result<()> {
     let (gossip_tx, mut gossip_rx) = mpsc::channel::<InferenceGossipMessage>(100);
     let endpoint_addr = network.router().endpoint().addr();
 
+    // Load persisted model assignments
+    let model_assignments = load_assignments(ASSIGNMENTS_FILE);
     let state = Arc::new(GatewayState {
         available_nodes: RwLock::new(HashMap::new()),
         pending_requests: RwLock::new(HashMap::new()),
+        model_assignments: RwLock::new(model_assignments),
         network_tx,
         gossip_tx,
         endpoint_addr,
@@ -465,7 +704,6 @@ async fn run_gateway() -> Result<()> {
     let cancel = cancel.clone();
     tokio::spawn(async move {
         let mut task_set = tokio::task::JoinSet::new();
-
         let mut cleanup_interval = tokio::time::interval(Duration::from_secs(15));
         cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
 
@@ -498,6 +736,7 @@ async fn run_gateway() -> Result<()> {
                 for (node_id, age) in stale_nodes {
                     warn!("Removing stale node {} (no heartbeat for {:?})", node_id.fmt_short(), age);
                     nodes.remove(&node_id);
+                    state.model_assignments.write().await.remove(&node_id);
                 }
             }
 
@@ -608,8 +847,12 @@ async fn run_gateway() -> Result<()> {
 
     let app = Router::new()
         .route("/v1/chat/completions", post(handle_inference))
-        .route("/admin/load-model", post(handle_load_model))
        .route("/bootstrap", get(handle_bootstrap))
+        .route("/admin/assign-models", post(handle_assign_models))
+        .route(
+            "/admin/assignments",
+            axum::routing::get(handle_get_assignments),
+        )
         .with_state(state.clone());
 
     let listener = tokio::net::TcpListener::bind(&args.listen_addr)
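For reference, a minimal client-side sketch of the JSON body `handle_assign_models` deserializes into `AssignModelsRequest`, built with `serde_json::json!`. The model names, the local path, and the gateway URL below are placeholders, not part of this diff:

```rust
// Sketch: constructing an /admin/assign-models body matching
// AssignModelsRequest / ModelAssignmentSpec above. POST it to the
// gateway's listen address (default 0.0.0.0:8000 per Args).
use serde_json::json;

fn main() {
    let body = json!({
        "assignments": [
            // source_type defaults to "huggingface", and source_path
            // falls back to model_name when omitted.
            { "model_name": "gpt2", "num_nodes": 2 },
            // "local" sources must carry a source_path, otherwise the
            // handler answers 400 BadRequest.
            {
                "model_name": "my-local-model",           // placeholder
                "source_type": "local",
                "source_path": "/models/my-local-model",  // placeholder
                "num_nodes": 1
            }
        ]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```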
diff --git a/architectures/inference-only/inference-node/src/bin/test-network.rs b/architectures/inference-only/inference-node/src/bin/test-network.rs
index f92786fe2..d01a31a4a 100644
--- a/architectures/inference-only/inference-node/src/bin/test-network.rs
+++ b/architectures/inference-only/inference-node/src/bin/test-network.rs
@@ -159,9 +159,10 @@ async fn main() -> Result<()> {
                 InferenceGossipMessage::NodeUnavailable => {
                     info!("Peer {} left the network", peer_id.fmt_short());
                 }
-                InferenceGossipMessage::LoadModel { model_name, model_source } => {
-                    info!("LoadModel request from {}: {} ({:?})",
-                        peer_id.fmt_short(), model_name, model_source);
+                InferenceGossipMessage::LoadModel { target_node_id, model_name, model_source } => {
+                    let target_str = target_node_id.map(|id| id.fmt_short().to_string()).unwrap_or_else(|| "all".to_string());
+                    info!("LoadModel request from {} (target: {}): {} ({:?})",
+                        peer_id.fmt_short(), target_str, model_name, model_source);
                 }
                 InferenceGossipMessage::ReloadCheckpoint { checkpoint_id, checkpoint_source } => {
                     info!("Checkpoint reload request from {}: {} ({})",
diff --git a/architectures/inference-only/inference-node/src/main.rs b/architectures/inference-only/inference-node/src/main.rs
index fe286cfe4..a18b3b69a 100644
--- a/architectures/inference-only/inference-node/src/main.rs
+++ b/architectures/inference-only/inference-node/src/main.rs
@@ -244,6 +244,7 @@ async fn main() -> Result<()> {
         ModelLoadState::Loaded(name) => Some(name.clone()),
         _ => None,
     };
+
     let availability_msg = InferenceGossipMessage::NodeAvailable {
         model_name: model_name_for_broadcast.clone(),
         checkpoint_id: None,
@@ -335,7 +336,16 @@ async fn main() -> Result<()> {
             InferenceGossipMessage::NodeUnavailable => {
                 info!("Peer {} is no longer available", peer_id.fmt_short());
             }
-            InferenceGossipMessage::LoadModel { model_name: requested_model, model_source } => {
+            InferenceGossipMessage::LoadModel { target_node_id, model_name: requested_model, model_source } => {
+                let my_node_id = network.endpoint_id();
+                if let Some(target) = target_node_id {
+                    if target != my_node_id {
+                        debug!("LoadModel not for us (target: {}, me: {}), ignoring",
+                            target.fmt_short(), my_node_id.fmt_short());
+                        continue;
+                    }
+                }
+
                 info!("Received LoadModel request from {}: model={}, source={:?}",
                     peer_id.fmt_short(), requested_model, model_source);
@@ -358,7 +368,6 @@ async fn main() -> Result<()> {
                 if should_load {
                     *model_state.write().await = ModelLoadState::Loading(requested_model.clone());
-
                     info!("Loading new model: {} (background task)", requested_model);
 
                     // Spawn background task to avoid blocking the event loop
                     // Model loading can take 10-60+ seconds, so we don't want to block heartbeats
@@ -398,16 +407,15 @@ async fn main() -> Result<()> {
                         match load_result {
                             Ok(new_node) => {
-                                *inference_node_shared_clone.write().await = Some(new_node);
+                                // Update the model state first, then the node handle, so the two stay consistent
                                 *model_state_clone.write().await = ModelLoadState::Loaded(requested_model_clone.clone());
+                                *inference_node_shared_clone.write().await = Some(new_node);
                                 info!("Successfully loaded model: {}", requested_model_clone);
-                                // Note: NodeAvailable will be broadcast on next heartbeat (every 30s)
-                                // or the node can be manually queried to verify the model is loaded
                             }
                             Err(e) => {
                                 error!("Failed to load model {}: {:#}", requested_model_clone, e);
-                                // Set back to Idle on failure
+                                // Set back to Idle on failure so the node can accept a new assignment
                                 *model_state_clone.write().await = ModelLoadState::Idle;
                             }
                         }
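The node-side gate above reduces to a small pure rule; here is a sketch of it, with `NodeId` standing in for iroh's `PublicKey` (a byte-array stand-in chosen for the example, not the real type's constructor):

```rust
// Sketch of the targeting rule the LoadModel handler applies:
// a broadcast (None) addresses every node, Some(id) exactly one.
type NodeId = [u8; 32]; // stand-in for iroh::PublicKey

fn is_for_me(target: Option<NodeId>, me: NodeId) -> bool {
    target.map_or(true, |t| t == me)
}

fn main() {
    let me = [1u8; 32];
    let other = [2u8; 32];
    assert!(is_for_me(None, me)); // broadcast: handle it
    assert!(is_for_me(Some(me), me)); // targeted at us: handle it
    assert!(!is_for_me(Some(other), me)); // targeted elsewhere: ignore
    println!("targeting rule ok");
}
```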
diff --git a/justfile b/justfile
index 476805662..0029cf33f 100644
--- a/justfile
+++ b/justfile
@@ -272,8 +272,8 @@ test-inference prompt="Hello, world!" max_tokens="50":
 test-inference-e2e model="gpt2" prompt="Hello, world!":
     ./scripts/test-inference-e2e.sh "{{ model }}" "{{ prompt }}"
 
-# Test dynamic model loading with multiple nodes (gateway + 2 inference nodes)
-test-model-loading initial_model="gpt2":
+# Test model assignment system with multiple nodes and models (gateway + 3 inference nodes)
+test-model-assignment:
    #!/usr/bin/env bash
    set -euo pipefail
 
@@ -283,11 +283,13 @@ test-model-loading initial_model="gpt2":
        exit 1
    fi
 
-    SESSION="psyche-model-loading"
+    SESSION="psyche-model-assignment"
     GATEWAY_PEER_FILE="/tmp/psyche-gateway-peer.json"
+    ASSIGNMENTS_FILE="/tmp/psyche-gateway-assignments.json"
 
-    # Clean up old peer file
+    # Clean up old files
     rm -f "$GATEWAY_PEER_FILE"
+    rm -f "$ASSIGNMENTS_FILE"
 
     # Kill existing session if it exists
     tmux kill-session -t $SESSION 2>/dev/null || true
@@ -319,15 +321,20 @@ test-model-loading initial_model="gpt2":
     sleep 2
     echo "Gateway ready"
 
-    # Start inference node 1 with initial model
-    echo "Starting inference node 1 (with model: {{ initial_model }})..."
+    # Start inference node 1 in idle mode
+    echo "Starting inference node 1 (idle mode)..."
     tmux new-window -t $SESSION -n node1
-    tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --model-name {{ initial_model }} --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.35" C-m
+    tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m
 
-    # Start inference node 2 without model (idle mode)
-    echo "Starting inference node 2 (idle mode - no initial model)..."
+    # Start inference node 2 in idle mode
+    echo "Starting inference node 2 (idle mode)..."
tmux new-window -t $SESSION -n node2 - tmux send-keys -t $SESSION:node2 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.35" C-m + tmux send-keys -t $SESSION:node2 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m + + # Start inference node 3 in idle mode + echo "Starting inference node 3 (idle mode)..." + tmux new-window -t $SESSION -n node3 + tmux send-keys -t $SESSION:node3 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m sleep 5 echo "" @@ -338,43 +345,69 @@ test-model-loading initial_model="gpt2": tmux new-window -t $SESSION -n test tmux send-keys -t $SESSION:test "cat << 'EOF'" C-m tmux send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m - tmux send-keys -t $SESSION:test " Dynamic Model Loading Test" C-m + tmux send-keys -t $SESSION:test " Model Assignment System Test" C-m tmux send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m tmux send-keys -t $SESSION:test "" C-m tmux send-keys -t $SESSION:test "Status:" C-m tmux send-keys -t $SESSION:test " • Gateway: running on http://127.0.0.1:8000" C-m - tmux send-keys -t $SESSION:test " • Node 1: {{ initial_model }}" C-m + tmux send-keys -t $SESSION:test " • Node 1: idle (no model)" C-m tmux send-keys -t $SESSION:test " • Node 2: idle (no model)" C-m + tmux send-keys -t $SESSION:test " • Node 3: idle (no model)" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 1: Send inference request with current model" C-m + tmux send-keys -t $SESSION:test "Test 1: View current assignments" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Expected: Shows all 3 nodes with status 'unassigned'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 2a: Assign 2 nodes to gpt2 (batch 1)" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/assign-models \\\\" C-m tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m - tmux send-keys -t $SESSION:test " -d '{\"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m + tmux send-keys -t $SESSION:test " -d '{\"assignments\": [{\"model_name\": \"gpt2\", \"num_nodes\": 2, \"source_type\": \"huggingface\"}]}'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Wait ~15s for gpt2 models to load, then check status:" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 2: Load new model on all nodes" C-m + tmux send-keys -t 
$SESSION:test "Test 2b: Assign 1 node to llama (batch 2)" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/load-model \\\\" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/assign-models \\\\" C-m tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m - tmux send-keys -t $SESSION:test " -d '{\"model_name\": \"gpt2\", \"source_type\": \"huggingface\"}'" C-m + tmux send-keys -t $SESSION:test " -d '{\"assignments\": [{\"model_name\": \"meta-llama/Llama-3.2-1B-Instruct\", \"num_nodes\": 1, \"source_type\": \"huggingface\"}]}'" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Expected: Both nodes reload with new model" C-m + tmux send-keys -t $SESSION:test "Wait ~15s for llama to load, then check status:" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 3: Send inference with new model" C-m + tmux send-keys -t $SESSION:test "Test 3: View final assignments" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "(Use same command as Test 1)" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Expected: Shows node_id, model_name, status for each assignment" C-m + tmux send-keys -t $SESSION:test "Status values: 'loading', 'loaded', 'idle', 'offline'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 4: Send inference request to gpt2 nodes" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m + tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m + tmux send-keys -t $SESSION:test " -d '{\"model\": \"gpt2\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 5: Send inference request to llama node" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m + tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m + tmux send-keys -t $SESSION:test " -d '{\"model\": \"meta-llama/Llama-3.2-1B-Instruct\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m tmux send-keys -t $SESSION:test "" C-m tmux send-keys -t $SESSION:test "Navigation:" C-m - tmux send-keys -t $SESSION:test " • Switch windows: Ctrl-b then 0/1/2/3" C-m - tmux send-keys -t $SESSION:test " 0=gateway, 1=node1, 2=node2, 3=test" C-m + tmux send-keys -t $SESSION:test " • Switch windows: Ctrl-b then 0/1/2/3/4" C-m + tmux send-keys -t $SESSION:test " 0=gateway, 1=node1, 2=node2, 3=node3, 4=test" C-m tmux send-keys -t $SESSION:test " • Exit tmux: Ctrl-b then d" C-m - tmux send-keys -t $SESSION:test " • Kill session: tmux kill-session -t psyche-model-loading" C-m + tmux send-keys -t $SESSION:test " • Kill session: tmux kill-session -t psyche-model-assignment" C-m tmux 
send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m
    tmux send-keys -t $SESSION:test "EOF" C-m

    # Attach to session
-    echo "Starting multi-node test in tmux session '$SESSION'"
-    echo "Windows: gateway, node1, node2, test"
+    echo "Starting model assignment test in tmux session '$SESSION'"
+    echo "Windows: gateway, node1, node2, node3, test"
     echo ""
     echo "To attach: tmux attach -t $SESSION"
     echo "To kill: tmux kill-session -t $SESSION"
diff --git a/shared/inference/src/protocol.rs b/shared/inference/src/protocol.rs
index 97ac7a225..cd6bb95f8 100644
--- a/shared/inference/src/protocol.rs
+++ b/shared/inference/src/protocol.rs
@@ -1,12 +1,12 @@
 //! Protocol types for inference requests and responses
 
+use iroh::PublicKey;
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub enum ModelSource {
     HuggingFace(String),
     Local(String),
-    // See test case below for additional future source types
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -19,6 +19,7 @@ pub enum InferenceGossipMessage {
     },
     NodeUnavailable,
     LoadModel {
+        target_node_id: Option<PublicKey>, // None = all nodes, Some(id) = specific node
         model_name: String,
         model_source: ModelSource,
     },
@@ -227,6 +228,7 @@ mod tests {
     #[test]
     fn test_load_model_message_serialization() {
         let msg = InferenceGossipMessage::LoadModel {
+            target_node_id: None,
             model_name: "gpt2".to_string(),
             model_source: ModelSource::HuggingFace("gpt2".to_string()),
         };
@@ -236,9 +238,11 @@ mod tests {
 
         match parsed {
             InferenceGossipMessage::LoadModel {
+                target_node_id,
                 model_name,
                 model_source,
             } => {
+                assert_eq!(target_node_id, None);
                 assert_eq!(model_name, "gpt2");
                 assert_eq!(model_source, ModelSource::HuggingFace("gpt2".to_string()));
             }
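On the consuming side, the `/admin/assignments` payload can be mirrored with a `Deserialize` twin of `AssignmentInfo`. The gateway only derives `Serialize`, so this client-side twin is an assumption for illustration, and the node IDs below are placeholders:

```rust
// Sketch: parsing the /admin/assignments response. Field and variant
// names mirror AssignmentInfo / AssignmentStatus in gateway-node.rs;
// rename_all = "lowercase" matches the gateway's serializer.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
enum AssignmentStatus {
    Unassigned,
    UnassignedWithModel,
    Loading,
    Loaded,
    Idle,
    Offline,
}

#[derive(Deserialize, Debug)]
struct AssignmentInfo {
    node_id: String,
    model_name: String,
    status: AssignmentStatus,
}

fn main() {
    // Placeholder node IDs; real ones are iroh public keys.
    let body = r#"[
        {"node_id": "aaaa", "model_name": "gpt2", "status": "loaded"},
        {"node_id": "bbbb", "model_name": "gpt2", "status": "loading"},
        {"node_id": "cccc", "model_name": "", "status": "unassigned"}
    ]"#;
    let infos: Vec<AssignmentInfo> = serde_json::from_str(body).unwrap();
    assert!(matches!(infos[0].status, AssignmentStatus::Loaded));
    println!("parsed {} assignment rows", infos.len());
}
```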