From e013ec80dd58491d66fe17418d5cbd5608756473 Mon Sep 17 00:00:00 2001
From: Sam Herring
Date: Tue, 23 Dec 2025 11:47:43 -0800
Subject: [PATCH 1/5] Adding ModelSource type and LoadModel message type to
 protocol with tests, inference node changes for model and checkpoint
 reloading, gateway node changes to handle LoadModel messages, adding
 LoadModel request broadcasting to gateway node, updating test network file
 with the protocol changes, adding justfile command for testing, and fixing
 model routing for idle nodes and adding a delay to free memory before reload

---
 .../inference-node/src/bin/gateway-node.rs    | 29 ++++++++++++++++++-
 .../inference-only/inference-node/src/main.rs |  1 +
 shared/inference/src/protocol.rs              |  1 -
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 201df83e0..27ef1c3cd 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -101,6 +101,24 @@ fn default_temperature() -> Option<f64> {
 fn default_top_p() -> Option<f64> {
     Some(1.0)
 }
+#[derive(serde::Deserialize)]
+struct LoadModelRequest {
+    model_name: String,
+    #[serde(default = "default_model_source_type")]
+    source_type: String, // "huggingface" or "local"
+    #[serde(default)]
+    source_path: Option<String>,
+}
+
+fn default_model_source_type() -> String {
+    "huggingface".to_string()
+}
+
+#[derive(serde::Serialize)]
+struct LoadModelResponse {
+    success: bool,
+    message: String,
+}
 
 #[derive(serde::Deserialize, Debug, Clone)]
 #[serde(tag = "source_type", rename_all = "lowercase")]
@@ -165,7 +183,7 @@ async fn handle_inference(
     info!(
         "Routing request to node: {} (model: {})",
         target_peer_id.fmt_short(),
-        node_model_name
+        node.model_name.as_ref().unwrap_or(&"unknown".to_string())
     );
     drop(nodes);
@@ -551,6 +569,15 @@ async fn run_gateway() -> Result<()> {
                 }
             }
 
+                Some(gossip_msg) = gossip_rx.recv() => {
+                    info!("Broadcasting gossip message: {:?}", gossip_msg);
+                    if let Err(e) = network.broadcast(&gossip_msg) {
+                        error!("Failed to broadcast gossip message: {:#}", e);
+                    } else {
+                        info!("Successfully broadcasted gossip message");
+                    }
+                }
+
                 event = network.poll_next() => {
                     match event {
                         Ok(Some(NetworkEvent::MessageReceived((peer_id, msg)))) => {
diff --git a/architectures/inference-only/inference-node/src/main.rs b/architectures/inference-only/inference-node/src/main.rs
index fe286cfe4..385fec2db 100644
--- a/architectures/inference-only/inference-node/src/main.rs
+++ b/architectures/inference-only/inference-node/src/main.rs
@@ -244,6 +244,7 @@ async fn main() -> Result<()> {
         ModelLoadState::Loaded(name) => Some(name.clone()),
         _ => None,
     };
+
     let availability_msg = InferenceGossipMessage::NodeAvailable {
         model_name: model_name_for_broadcast.clone(),
         checkpoint_id: None,
diff --git a/shared/inference/src/protocol.rs b/shared/inference/src/protocol.rs
index 97ac7a225..8fbd7c401 100644
--- a/shared/inference/src/protocol.rs
+++ b/shared/inference/src/protocol.rs
@@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
 pub enum ModelSource {
     HuggingFace(String),
     Local(String),
-    // See test case below for additional future source types
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
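
A note on the wire shape of the new LoadModel message: with the plain serde
derives shown in protocol.rs (no tag attribute is visible in the patch
context), the enum gets serde's default externally tagged representation. A
runnable sketch with local mirrors of the two types — illustrative only, and
the real InferenceGossipMessage has more variants:

    use serde::{Deserialize, Serialize};

    // Local mirrors of the protocol types above, for a self-contained sketch.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    enum ModelSource {
        HuggingFace(String),
        Local(String),
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    enum InferenceGossipMessage {
        LoadModel {
            model_name: String,
            model_source: ModelSource,
        },
    }

    fn main() -> serde_json::Result<()> {
        // Broadcast form in this patch: every listening node loads the model.
        let msg = InferenceGossipMessage::LoadModel {
            model_name: "gpt2".to_string(),
            model_source: ModelSource::HuggingFace("gpt2".to_string()),
        };
        let json = serde_json::to_string(&msg)?;
        // Externally tagged by default:
        assert_eq!(
            json,
            r#"{"LoadModel":{"model_name":"gpt2","model_source":{"HuggingFace":"gpt2"}}}"#
        );
        let back: InferenceGossipMessage = serde_json::from_str(&json)?;
        assert_eq!(back, msg);
        Ok(())
    }
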
From e4b679705dc389bd4370d7fd69407aa015b282 Mon Sep 17 00:00:00 2001
From: Sam Herring
Date: Tue, 17 Feb 2026 11:18:54 -0800
Subject: [PATCH 2/5] Addressing PR feedback: adding enums for source type,
 removing unnecessary async calls, and removing manual drops, gateway node
 changes to allow for model assignment by node and tracking, Justfile updates
 for model assignment testing

---
 .../inference-node/src/bin/gateway-node.rs    | 290 ++++++++++++++++--
 .../inference-only/inference-node/src/main.rs |  14 +-
 justfile                                      |  77 +++--
 3 files changed, 324 insertions(+), 57 deletions(-)

diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 27ef1c3cd..458caf88b 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -33,6 +33,42 @@ use tokio::{
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, info, warn};
 
+/// Default path for storing model assignments
+const ASSIGNMENTS_FILE: &str = "/tmp/psyche-gateway-assignments.json";
+
+/// Load model assignments from disk
+fn load_assignments(path: &str) -> HashMap<EndpointId, String> {
+    match fs::read_to_string(path) {
+        Ok(contents) => match serde_json::from_str::<HashMap<EndpointId, String>>(&contents) {
+            Ok(assignments) => {
+                info!(
+                    "Loaded {} model assignments from {}",
+                    assignments.len(),
+                    path
+                );
+                assignments
+            }
+            Err(e) => {
+                warn!("Failed to parse assignments file: {:#}", e);
+                HashMap::new()
+            }
+        },
+        Err(_) => {
+            info!("No assignments file found at {}, starting fresh", path);
+            HashMap::new()
+        }
+    }
+}
+
+/// Save model assignments to disk
+fn save_assignments(path: &str, assignments: &HashMap<EndpointId, String>) -> Result<()> {
+    let json =
+        serde_json::to_string_pretty(assignments).context("Failed to serialize assignments")?;
+    fs::write(path, json).context("Failed to write assignments file")?;
+    debug!("Saved {} model assignments to {}", assignments.len(), path);
+    Ok(())
+}
+
 #[derive(Parser, Debug)]
 struct Args {
     #[arg(long, default_value = "0.0.0.0:8000")]
@@ -67,6 +103,7 @@ struct InferenceNodeInfo {
 struct GatewayState {
     available_nodes: RwLock<HashMap<EndpointId, InferenceNodeInfo>>,
     pending_requests: RwLock<HashMap<String, oneshot::Sender<InferenceMessage>>>,
+    model_assignments: RwLock<HashMap<EndpointId, String>>, // node_id -> assigned model name
     network_tx: mpsc::Sender<(EndpointId, InferenceMessage)>,
     gossip_tx: mpsc::Sender<InferenceGossipMessage>,
     endpoint_addr: EndpointAddr,
@@ -101,13 +138,35 @@ fn default_temperature() -> Option<f64> {
 fn default_top_p() -> Option<f64> {
     Some(1.0)
 }
+
+#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[serde(rename_all = "lowercase")]
+enum ModelSourceType {
+    #[default]
+    HuggingFace,
+    Local,
+}
+
 #[derive(serde::Deserialize)]
-struct LoadModelRequest {
+struct AssignModelsRequest {
+    assignments: Vec<ModelAssignmentSpec>,
+}
+
+#[derive(serde::Deserialize)]
+struct ModelAssignmentSpec {
     model_name: String,
-    #[serde(default = "default_model_source_type")]
-    source_type: String, // "huggingface" or "local"
+    #[serde(default)]
+    source_type: ModelSourceType,
     #[serde(default)]
     source_path: Option<String>,
+    num_nodes: usize,
+}
+
+#[derive(serde::Serialize)]
+struct AssignmentInfo {
+    node_id: String,
+    model_name: String,
+    status: String, // "loading", "loaded", "idle", "offline"
 }
 
 fn default_model_source_type() -> String {
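
For reference, this is the request body the new /admin/assign-models endpoint
accepts, per the derives above: source_type defaults to "huggingface" and
source_path is optional. A runnable sketch with the two request types mirrored
locally (model names and paths here are examples, not values from the patch):

    use serde::Deserialize;

    #[derive(Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
    #[serde(rename_all = "lowercase")]
    enum ModelSourceType {
        #[default]
        HuggingFace,
        Local,
    }

    #[derive(Deserialize)]
    struct ModelAssignmentSpec {
        model_name: String,
        #[serde(default)]
        source_type: ModelSourceType,
        #[serde(default)]
        source_path: Option<String>,
        num_nodes: usize,
    }

    #[derive(Deserialize)]
    struct AssignModelsRequest {
        assignments: Vec<ModelAssignmentSpec>,
    }

    fn main() {
        let body = r#"{
            "assignments": [
                { "model_name": "gpt2", "num_nodes": 2 },
                { "model_name": "my-local-model", "source_type": "local",
                  "source_path": "/models/my-local-model", "num_nodes": 1 }
            ]
        }"#;
        let req: AssignModelsRequest = serde_json::from_str(body).unwrap();
        assert_eq!(req.assignments.len(), 2);
        // Omitted source_type falls back to the HuggingFace default.
        assert_eq!(req.assignments[0].source_type, ModelSourceType::HuggingFace);
        assert_eq!(req.assignments[1].num_nodes, 1);
    }
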
@@ -162,6 +221,40 @@ async fn handle_inference(
     Json(req): Json<ChatCompletionRequest>,
 ) -> Result<Json<ChatCompletionResponse>, AppError> {
     let nodes = state.available_nodes.read().await;
+    let assignments = state.model_assignments.read().await;
+
+    // Determine requested model
+    let requested_model = req.model.as_deref();
+
+    // Find suitable nodes:
+    // 1. If model specified: prefer nodes assigned to that model with it loaded
+    // 2. If no model specified: use any node with a model loaded
+    let suitable_nodes: Vec<_> = if let Some(model) = requested_model {
+        // Prefer nodes assigned to the requested model that have it loaded
+        let assigned_and_loaded: Vec<_> = nodes
+            .values()
+            .filter(|n| {
+                assignments
+                    .get(&n.peer_id)
+                    .map(|assigned| assigned == model)
+                    .unwrap_or(false)
+                    && n.model_name.as_deref() == Some(model)
+            })
+            .collect();
+
+        if !assigned_and_loaded.is_empty() {
+            assigned_and_loaded
+        } else {
+            // Fallback: any node with the requested model loaded
+            nodes
+                .values()
+                .filter(|n| n.model_name.as_deref() == Some(model))
+                .collect()
+        }
+    } else {
+        // No model specified - use any node with a model loaded
+        nodes.values().filter(|n| n.model_name.is_some()).collect()
+    };
 
     let nodes_with_model: Vec<(EndpointId, String)> = nodes
         .values()
@@ -181,11 +274,16 @@ async fn handle_inference(
     let model_name = req.model.clone().unwrap_or_else(|| node_model_name.clone());
 
     info!(
-        "Routing request to node: {} (model: {})",
+        "Routing request to node: {} (model: {}, assigned: {})",
         target_peer_id.fmt_short(),
-        node.model_name.as_ref().unwrap_or(&"unknown".to_string())
+        node.model_name.as_deref().unwrap_or("unknown"),
+        assignments
+            .get(&target_peer_id)
+            .map(|s| s.as_str())
+            .unwrap_or("none")
     );
     drop(nodes);
+    drop(assignments);
 
     let messages: Vec<ChatMessage> = req
         .messages
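
The routing preference above reduces to: first a node assigned to the
requested model that is already serving it, then any node serving it, else
none. A self-contained sketch of that order (u64 stands in for EndpointId;
illustrative only, not part of the patch):

    use std::collections::HashMap;

    /// Selection order used by handle_inference above, as a pure function.
    fn pick_node(
        model: &str,
        loaded: &HashMap<u64, Option<String>>, // node id -> currently loaded model
        assigned: &HashMap<u64, String>,       // node id -> assigned model
    ) -> Option<u64> {
        let has_model =
            |id: &u64| loaded.get(id).and_then(|m| m.as_deref()) == Some(model);
        // 1) a node assigned to this model that is already serving it ...
        assigned
            .iter()
            .filter(|(_, m)| m.as_str() == model)
            .map(|(id, _)| id)
            .find(|&id| has_model(id))
            // 2) ... otherwise any node that has the model loaded.
            .or_else(|| loaded.keys().find(|&id| has_model(id)))
            .copied()
    }

    fn main() {
        let loaded: HashMap<u64, Option<String>> =
            [(1, None), (2, Some("gpt2".into()))].into();
        let assigned: HashMap<u64, String> = HashMap::new();
        // No assignment matches, so we fall back to the node serving gpt2.
        assert_eq!(pick_node("gpt2", &loaded, &assigned), Some(2));
    }
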
@@ -252,42 +350,112 @@
 }
 
 #[axum::debug_handler]
-async fn handle_load_model(
+async fn handle_assign_models(
     State(state): State<Arc<GatewayState>>,
-    Json(req): Json<LoadModelRequest>,
+    Json(req): Json<AssignModelsRequest>,
 ) -> Result<String, AppError> {
     use psyche_inference::ModelSource;
 
     info!(
-        "Admin API: Received LoadModel request for model: {} (source: {:?})",
-        req.model_name, req.source
+        "Admin API: Received assign-models request with {} specs",
+        req.assignments.len()
     );
 
-    let model_source = match req.source {
-        LoadModelSource::HuggingFace { source_path } => {
-            let path = source_path.unwrap_or_else(|| req.model_name.clone());
-            ModelSource::HuggingFace(path)
+    let mut assigned_count = 0;
+    let mut total_requested = 0;
+
+    for spec in req.assignments {
+        total_requested += spec.num_nodes;
+
+        info!(
+            "Assigning {} nodes to model: {}",
+            spec.num_nodes, spec.model_name
+        );
+
+        // Get available nodes
+        let nodes = state.available_nodes.read().await;
+        let assignments = state.model_assignments.read().await;
+
+        // Find idle nodes (not currently assigned)
+        let idle_nodes: Vec<EndpointId> = nodes
+            .keys()
+            .filter(|node_id| !assignments.contains_key(*node_id))
+            .copied()
+            .take(spec.num_nodes)
+            .collect();
+
+        if idle_nodes.len() < spec.num_nodes {
+            warn!(
+                "Only {} idle nodes available, requested {}",
+                idle_nodes.len(),
+                spec.num_nodes
+            );
         }
-        LoadModelSource::Local { source_path } => ModelSource::Local(source_path),
-    };
 
-    let load_msg = InferenceGossipMessage::LoadModel {
-        model_name: req.model_name.clone(),
-        model_source,
-    };
+        drop(nodes);
+        drop(assignments);
 
-    state.gossip_tx.send(load_msg).await.map_err(|e| {
-        error!("Failed to broadcast LoadModel message: {:#}", e);
-        AppError::InternalError
-    })?;
+        // Build model source
+        let model_source = match spec.source_type {
+            ModelSourceType::HuggingFace => {
+                let path = spec.source_path.unwrap_or_else(|| spec.model_name.clone());
+                ModelSource::HuggingFace(path)
+            }
+            ModelSourceType::Local => {
+                let path = spec.source_path.ok_or_else(|| {
+                    AppError::BadRequest("source_path is required for local models".to_string())
+                })?;
+                ModelSource::Local(path)
+            }
+        };
+
+        // Assign and send LoadModel to each selected node
+        for node_id in idle_nodes {
+            // Update assignments map
+            state
+                .model_assignments
+                .write()
+                .await
+                .insert(node_id, spec.model_name.clone());
+
+            // Broadcast LoadModel to the specific node
+            let load_msg = InferenceGossipMessage::LoadModel {
+                model_name: spec.model_name.clone(),
+                model_source: model_source.clone(),
+            };
+
+            if let Err(e) = state.gossip_tx.send(load_msg).await {
+                error!(
+                    "Failed to send LoadModel to node {}: {:#}",
+                    node_id.fmt_short(),
+                    e
+                );
+            } else {
+                info!(
+                    "Sent LoadModel to node {} for model {}",
+                    node_id.fmt_short(),
+                    spec.model_name
+                );
+                assigned_count += 1;
+            }
+        }
+    }
+
+    // Persist assignments to disk
+    let assignments = state.model_assignments.read().await;
+    if let Err(e) = save_assignments(ASSIGNMENTS_FILE, &assignments) {
+        error!("Failed to save assignments: {:#}", e);
+    }
+    drop(assignments);
 
     info!(
-        "Successfully broadcasted LoadModel message for: {}",
-        req.model_name
+        "Assignment complete: {} nodes assigned out of {} requested",
+        assigned_count, total_requested
     );
+
     Ok(format!(
-        "LoadModel broadcast sent for model: {}",
-        req.model_name
+        "Assigned {} nodes out of {} requested",
+        assigned_count, total_requested
     ))
 }
@@ -298,6 +466,63 @@ async fn handle_bootstrap(State(state): State<Arc<GatewayState>>) -> Json
 }
 
+async fn handle_get_assignments(
+    State(state): State<Arc<GatewayState>>,
+) -> Json<Vec<AssignmentInfo>> {
+    let assignments = state.model_assignments.read().await;
+    let nodes = state.available_nodes.read().await;
+
+    let mut result = Vec::new();
+
+    for (node_id, assigned_model) in assignments.iter() {
+        let status = match nodes.get(node_id) {
+            None => {
+                info!(
+                    "Node {} not in available_nodes (offline)",
+                    node_id.fmt_short()
+                );
+                "offline".to_string()
+            }
+            Some(node_info) => match &node_info.model_name {
+                None => {
+                    info!(
+                        "Node {} has no model loaded (assigned: {})",
+                        node_id.fmt_short(),
+                        assigned_model
+                    );
+                    "idle".to_string()
+                }
+                Some(current_model) if current_model == assigned_model => {
+                    info!(
+                        "Node {} loaded correct model: {}",
+                        node_id.fmt_short(),
+                        current_model
+                    );
+                    "loaded".to_string()
+                }
+                Some(current_model) => {
+                    info!(
+                        "Node {} has model '{}' but assigned model is '{}'",
+                        node_id.fmt_short(),
+                        current_model,
+                        assigned_model
+                    );
+                    "loading".to_string() // Has different model, probably loading
+                }
+            },
+        };
+
+        result.push(AssignmentInfo {
+            node_id: node_id.to_string(),
+            model_name: assigned_model.clone(),
+            status,
+        });
+    }
+
+    Json(result)
+}
 
 #[derive(Debug)]
@@ -305,6 +530,7 @@ enum AppError {
     NoNodesAvailable,
     Timeout,
     InternalError,
+    BadRequest(String),
 }
 
 impl IntoResponse for AppError {
@@ -322,6 +548,7 @@ impl IntoResponse for AppError {
                 StatusCode::INTERNAL_SERVER_ERROR,
                 "Internal server error".to_string(),
             ),
+            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
         };
         (status, message).into_response()
     }
@@ -466,10 +693,13 @@ async fn run_gateway() -> Result<()> {
     let (network_tx, mut network_rx) = mpsc::channel::<(EndpointId, InferenceMessage)>(100);
     let (gossip_tx, mut gossip_rx) = mpsc::channel::<InferenceGossipMessage>(100);
 
     let endpoint_addr = network.router().endpoint().addr();
+    // Load persisted model assignments
+    let model_assignments = load_assignments(ASSIGNMENTS_FILE);
 
     let state = Arc::new(GatewayState {
         available_nodes: RwLock::new(HashMap::new()),
         pending_requests: RwLock::new(HashMap::new()),
+        model_assignments: RwLock::new(model_assignments),
         network_tx,
         gossip_tx,
         endpoint_addr,
@@ -635,8 +865,12 @@
     let app = Router::new()
         .route("/v1/chat/completions", post(handle_inference))
-        .route("/admin/load-model", post(handle_load_model))
         .route("/bootstrap", get(handle_bootstrap))
+        .route("/admin/assign-models", 
post(handle_assign_models)) + .route( + "/admin/assignments", + axum::routing::get(handle_get_assignments), + ) .with_state(state.clone()); let listener = tokio::net::TcpListener::bind(&args.listen_addr) diff --git a/architectures/inference-only/inference-node/src/main.rs b/architectures/inference-only/inference-node/src/main.rs index 385fec2db..f3486f2a8 100644 --- a/architectures/inference-only/inference-node/src/main.rs +++ b/architectures/inference-only/inference-node/src/main.rs @@ -357,14 +357,19 @@ async fn main() -> Result<()> { _ => true, }; - if should_load { - *model_state.write().await = ModelLoadState::Loading(requested_model.clone()); + let model_already_loaded = { + let current = current_model_name.read().await; + current.as_ref() == Some(&requested_model) + }; + if model_already_loaded { + info!("Model {} already loaded, skipping", requested_model); + } else { info!("Loading new model: {} (background task)", requested_model); // Spawn background task to avoid blocking the event loop // Model loading can take 10-60+ seconds, so we don't want to block heartbeats let inference_node_shared_clone = inference_node_shared.clone(); - let model_state_clone = model_state.clone(); + let current_model_name_clone = current_model_name.clone(); let requested_model_clone = requested_model.clone(); tokio::spawn(async move { @@ -399,8 +404,9 @@ async fn main() -> Result<()> { match load_result { Ok(new_node) => { + // update model name first, then node, to maintain consistency + *current_model_name_clone.write().await = Some(requested_model_clone.clone()); *inference_node_shared_clone.write().await = Some(new_node); - *model_state_clone.write().await = ModelLoadState::Loaded(requested_model_clone.clone()); info!("Successfully loaded model: {}", requested_model_clone); // Note: NodeAvailable will be broadcast on next heartbeat (every 30s) diff --git a/justfile b/justfile index 476805662..524c59f59 100644 --- a/justfile +++ b/justfile @@ -272,8 +272,8 @@ test-inference prompt="Hello, world!" max_tokens="50": test-inference-e2e model="gpt2" prompt="Hello, world!": ./scripts/test-inference-e2e.sh "{{ model }}" "{{ prompt }}" -# Test dynamic model loading with multiple nodes (gateway + 2 inference nodes) -test-model-loading initial_model="gpt2": +# Test model assignment system with multiple nodes and models (gateway + 3 inference nodes) +test-model-assignment: #!/usr/bin/env bash set -euo pipefail @@ -283,7 +283,7 @@ test-model-loading initial_model="gpt2": exit 1 fi - SESSION="psyche-model-loading" + SESSION="psyche-model-assignment" GATEWAY_PEER_FILE="/tmp/psyche-gateway-peer.json" # Clean up old peer file @@ -319,16 +319,21 @@ test-model-loading initial_model="gpt2": sleep 2 echo "Gateway ready" - # Start inference node 1 with initial model - echo "Starting inference node 1 (with model: {{ initial_model }})..." + # Start inference node 1 in idle mode + echo "Starting inference node 1 (idle mode)..." 
tmux new-window -t $SESSION -n node1 - tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --model-name {{ initial_model }} --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.35" C-m + tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.5" C-m - # Start inference node 2 without model (idle mode) - echo "Starting inference node 2 (idle mode - no initial model)..." + # Start inference node 2 in idle mode + echo "Starting inference node 2 (idle mode)..." tmux new-window -t $SESSION -n node2 tmux send-keys -t $SESSION:node2 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.35" C-m + # Start inference node 3 in idle mode + echo "Starting inference node 3 (idle mode)..." + tmux new-window -t $SESSION -n node3 + tmux send-keys -t $SESSION:node3 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.5" C-m + sleep 5 echo "" echo "All nodes started" @@ -338,43 +343,65 @@ test-model-loading initial_model="gpt2": tmux new-window -t $SESSION -n test tmux send-keys -t $SESSION:test "cat << 'EOF'" C-m tmux send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m - tmux send-keys -t $SESSION:test " Dynamic Model Loading Test" C-m + tmux send-keys -t $SESSION:test " Model Assignment System Test" C-m tmux send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m tmux send-keys -t $SESSION:test "" C-m tmux send-keys -t $SESSION:test "Status:" C-m tmux send-keys -t $SESSION:test " • Gateway: running on http://127.0.0.1:8000" C-m - tmux send-keys -t $SESSION:test " • Node 1: {{ initial_model }}" C-m + tmux send-keys -t $SESSION:test " • Node 1: idle (no model)" C-m tmux send-keys -t $SESSION:test " • Node 2: idle (no model)" C-m + tmux send-keys -t $SESSION:test " • Node 3: idle (no model)" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 1: Send inference request with current model" C-m + tmux send-keys -t $SESSION:test "Test 1: View current assignments" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m - tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m - tmux send-keys -t $SESSION:test " -d '{\"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Expected: Empty array (no assignments yet)" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 2: Load new model on all nodes" C-m + tmux send-keys -t $SESSION:test "Test 2: Assign models to nodes (2 nodes to gpt2, 1 to llama)" C-m tmux send-keys -t $SESSION:test 
"────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/load-model \\\\" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/assign-models \\\\" C-m tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m - tmux send-keys -t $SESSION:test " -d '{\"model_name\": \"gpt2\", \"source_type\": \"huggingface\"}'" C-m + tmux send-keys -t $SESSION:test " -d '{" C-m + tmux send-keys -t $SESSION:test " \"assignments\": [" C-m + tmux send-keys -t $SESSION:test " {\"model_name\": \"gpt2\", \"num_nodes\": 2, \"source_type\": \"huggingface\"}," C-m + tmux send-keys -t $SESSION:test " {\"model_name\": \"meta-llama/Llama-3.2-1B-Instruct\", \"num_nodes\": 1, \"source_type\": \"huggingface\"}" C-m + tmux send-keys -t $SESSION:test " ]" C-m + tmux send-keys -t $SESSION:test " }'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Expected: Gateway assigns 2 nodes to gpt2, 1 to llama" C-m + tmux send-keys -t $SESSION:test "Watch node windows for LoadModel messages and loading progress" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 3: View assignments (wait ~10s for models to load)" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Expected: Both nodes reload with new model" C-m + tmux send-keys -t $SESSION:test "Expected: Shows node_id, model_name, status for each assignment" C-m + tmux send-keys -t $SESSION:test "Status values: 'loading', 'loaded', 'idle', 'offline'" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 3: Send inference with new model" C-m + tmux send-keys -t $SESSION:test "Test 4: Send inference request to gpt2 nodes" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m - tmux send-keys -t $SESSION:test "(Use same command as Test 1)" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m + tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m + tmux send-keys -t $SESSION:test " -d '{\"model\": \"gpt2\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 5: Send inference request to llama node" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/v1/chat/completions \\\\" C-m + tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m + tmux send-keys -t $SESSION:test " -d '{\"model\": \"meta-llama/Llama-3.2-1B-Instruct\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}], \"max_tokens\": 50}'" C-m tmux send-keys -t $SESSION:test "" C-m tmux send-keys -t $SESSION:test "Navigation:" C-m - tmux send-keys -t $SESSION:test " • Switch windows: Ctrl-b then 0/1/2/3" C-m - tmux send-keys -t $SESSION:test " 0=gateway, 1=node1, 2=node2, 3=test" C-m + tmux send-keys -t $SESSION:test " • Switch windows: Ctrl-b then 0/1/2/3/4" C-m + tmux send-keys -t $SESSION:test " 0=gateway, 1=node1, 2=node2, 3=node3, 4=test" C-m tmux send-keys -t $SESSION:test " • 
Exit tmux: Ctrl-b then d" C-m
-    tmux send-keys -t $SESSION:test "  • Kill session: tmux kill-session -t psyche-model-loading" C-m
+    tmux send-keys -t $SESSION:test "  • Kill session: tmux kill-session -t psyche-model-assignment" C-m
     tmux send-keys -t $SESSION:test "═══════════════════════════════════════════════════════════════" C-m
     tmux send-keys -t $SESSION:test "EOF" C-m
 
     # Attach to session
-    echo "Starting multi-node test in tmux session '$SESSION'"
-    echo "Windows: gateway, node1, node2, test"
+    echo "Starting model assignment test in tmux session '$SESSION'"
+    echo "Windows: gateway, node1, node2, node3, test"
     echo ""
     echo "To attach: tmux attach -t $SESSION"
     echo "To kill: tmux kill-session -t $SESSION"
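
A quick reference for consumers of the /admin/assignments endpoint this patch
adds: the response is a JSON array of AssignmentInfo rows. A runnable sketch
with a local stand-in struct (field set copied from the patch; the node id
value is hypothetical):

    use serde::Serialize;

    // Mirrors AssignmentInfo from the gateway above.
    #[derive(Serialize)]
    struct Row {
        node_id: String,
        model_name: String,
        status: String, // "loading" | "loaded" | "idle" | "offline"
    }

    fn main() {
        let rows = vec![Row {
            node_id: "ab12cd34ef56".to_string(), // hypothetical node id
            model_name: "gpt2".to_string(),
            status: "loaded".to_string(),
        }];
        // Prints: [{"node_id":"ab12cd34ef56","model_name":"gpt2","status":"loaded"}]
        println!("{}", serde_json::to_string(&rows).unwrap());
    }
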
From ccd94ef5e463f40e74827e3d84c8f9413497cc40 Mon Sep 17 00:00:00 2001
From: Sam Herring
Date: Thu, 19 Feb 2026 11:47:30 -0800
Subject: [PATCH 3/5] Adding target by node id for LoadModel messages,
 updating endpoint to display full node status and updating justfile

---
 .../inference-node/src/bin/gateway-node.rs    | 187 ++++++++----------
 .../inference-node/src/bin/test-network.rs    |   7 +-
 .../inference-only/inference-node/src/main.rs |  29 +--
 justfile                                      |  36 ++--
 shared/inference/src/protocol.rs              |   5 +
 5 files changed, 128 insertions(+), 136 deletions(-)

diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 458caf88b..4a0f12c02 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -33,10 +33,8 @@ use tokio::{
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, info, warn};
 
-/// Default path for storing model assignments
 const ASSIGNMENTS_FILE: &str = "/tmp/psyche-gateway-assignments.json";
 
-/// Load model assignments from disk
 fn load_assignments(path: &str) -> HashMap<EndpointId, String> {
     match fs::read_to_string(path) {
         Ok(contents) => match serde_json::from_str::<HashMap<EndpointId, String>>(&contents) {
@@ -60,7 +58,6 @@ fn load_assignments(path: &str) -> HashMap<EndpointId, String> {
     }
 }
 
-/// Save model assignments to disk
 fn save_assignments(path: &str, assignments: &HashMap<EndpointId, String>) -> Result<()> {
     let json =
         serde_json::to_string_pretty(assignments).context("Failed to serialize assignments")?;
@@ -169,35 +166,6 @@ struct AssignmentInfo {
     status: String, // "loading", "loaded", "idle", "offline"
 }
 
-fn default_model_source_type() -> String {
-    "huggingface".to_string()
-}
-
-#[derive(serde::Serialize)]
-struct LoadModelResponse {
-    success: bool,
-    message: String,
-}
-
-#[derive(serde::Deserialize, Debug, Clone)]
-#[serde(tag = "source_type", rename_all = "lowercase")]
-enum LoadModelSource {
-    #[serde(rename = "huggingface")]
-    HuggingFace {
-        source_path: Option<String>,
-    },
-    Local {
-        source_path: String,
-    },
-}
-
-#[derive(serde::Deserialize)]
-struct LoadModelRequest {
-    model_name: String,
-    #[serde(flatten)]
-    source: LoadModelSource,
-}
-
 #[derive(serde::Serialize)]
 struct ChatCompletionChoice {
     index: usize,
@@ -223,52 +191,53 @@ async fn handle_inference(
     let nodes = state.available_nodes.read().await;
     let assignments = state.model_assignments.read().await;
 
-    // Determine requested model
     let requested_model = req.model.as_deref();
 
-    // Find suitable nodes:
-    // 1. If model specified: prefer nodes assigned to that model with it loaded
-    // 2. If no model specified: use any node with a model loaded
-    let suitable_nodes: Vec<_> = if let Some(model) = requested_model {
-        // Prefer nodes assigned to the requested model that have it loaded
+    let suitable_nodes: Vec<(EndpointId, String)> = if let Some(model) = requested_model {
         let assigned_and_loaded: Vec<_> = nodes
             .values()
-            .filter(|n| {
-                assignments
+            .filter_map(|n| {
+                if assignments
                     .get(&n.peer_id)
                     .map(|assigned| assigned == model)
                     .unwrap_or(false)
                     && n.model_name.as_deref() == Some(model)
+                {
+                    Some((n.peer_id, n.model_name.clone()?))
+                } else {
+                    None
+                }
             })
             .collect();
 
         if !assigned_and_loaded.is_empty() {
             assigned_and_loaded
         } else {
-            // Fallback: any node with the requested model loaded
             nodes
                 .values()
-                .filter(|n| n.model_name.as_deref() == Some(model))
+                .filter_map(|n| {
+                    if n.model_name.as_deref() == Some(model) {
+                        Some((n.peer_id, n.model_name.clone()?))
+                    } else {
+                        None
+                    }
+                })
                 .collect()
         }
     } else {
-        // No model specified - use any node with a model loaded
-        nodes.values().filter(|n| n.model_name.is_some()).collect()
+        nodes
+            .values()
+            .filter_map(|n| Some((n.peer_id, n.model_name.clone()?)))
+            .collect()
     };
 
-    let nodes_with_model: Vec<(EndpointId, String)> = nodes
-        .values()
-        .filter_map(|n| Some((n.peer_id, n.model_name.clone()?)))
-        .collect();
-
-    if nodes_with_model.is_empty() {
-        // No nodes have models loaded yet
+    if suitable_nodes.is_empty() {
         return Err(AppError::NoNodesAvailable);
     }
 
     // Select first available node with a model
     // TODO: Add load balancing and model-specific routing in the future
-    let (target_peer_id, node_model_name) = &nodes_with_model[0];
+    let (target_peer_id, node_model_name) = &suitable_nodes[0];
     let target_peer_id = *target_peer_id;
 
     let model_name = req.model.clone().unwrap_or_else(|| node_model_name.clone());
@@ -276,7 +245,7 @@ async fn handle_inference(
     info!(
         "Routing request to node: {} (model: {}, assigned: {})",
         target_peer_id.fmt_short(),
-        node.model_name.as_deref().unwrap_or("unknown"),
+        node_model_name,
         assignments
             .get(&target_peer_id)
             .map(|s| s.as_str())
@@ -372,11 +341,9 @@ async fn handle_assign_models(
             spec.num_nodes, spec.model_name
         );
 
-        // Get available nodes
         let nodes = state.available_nodes.read().await;
         let assignments = state.model_assignments.read().await;
 
-        // Find idle nodes (not currently assigned)
         let idle_nodes: Vec<EndpointId> = nodes
             .keys()
             .filter(|node_id| !assignments.contains_key(*node_id))
             .copied()
             .take(spec.num_nodes)
             .collect();
@@ -395,7 +362,6 @@ async fn handle_assign_models(
         drop(nodes);
         drop(assignments);
 
-        // Build model source
         let model_source = match spec.source_type {
             ModelSourceType::HuggingFace => {
                 let path = spec.source_path.unwrap_or_else(|| spec.model_name.clone());
                 ModelSource::HuggingFace(path)
             }
@@ -409,17 +375,15 @@ async fn handle_assign_models(
-        // Assign and send LoadModel to each selected node
         for node_id in idle_nodes {
-            // Update assignments map
             state
                 .model_assignments
                 .write()
                 .await
                 .insert(node_id, spec.model_name.clone());
 
-            // Broadcast LoadModel to the specific node
             let load_msg = InferenceGossipMessage::LoadModel {
+                target_node_id: Some(node_id),
                 model_name: spec.model_name.clone(),
                 model_source: model_source.clone(),
             };
@@ -441,7 +405,6 @@ async fn handle_assign_models(
         }
     }
 
-    // Persist assignments to disk
     let assignments = state.model_assignments.read().await;
     if let Err(e) = save_assignments(ASSIGNMENTS_FILE, &assignments) {
         error!("Failed to save assignments: {:#}", e);
     }
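
On the receiving side (main.rs below), the targeting rule added here is a
single check: a LoadModel applies when target_node_id is None (broadcast) or
names this node. A self-contained sketch, with u64 standing in for the iroh
node id type:

    /// A LoadModel message applies to this node when it is a broadcast
    /// (target is None) or is explicitly addressed to this node's id.
    fn load_model_applies(my_id: u64, target: Option<u64>) -> bool {
        target.map_or(true, |t| t == my_id)
    }

    fn main() {
        assert!(load_model_applies(7, None));     // broadcast: every node loads
        assert!(load_model_applies(7, Some(7)));  // addressed to us: load
        assert!(!load_model_applies(7, Some(8))); // addressed elsewhere: ignore
    }
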
@@ -477,51 +440,69 @@ async fn handle_get_assignments(
     let mut result = Vec::new();
 
-    for (node_id, assigned_model) in assignments.iter() {
-        let status = match nodes.get(node_id) {
+    for (node_id, node_info) in nodes.iter() {
+        let (assigned_model, status) = match assignments.get(node_id) {
             None => {
-                info!(
-                    "Node {} not in available_nodes (offline)",
-                    node_id.fmt_short()
-                );
-                "offline".to_string()
+                let status = if node_info.model_name.is_some() {
+                    "unassigned_with_model".to_string()
+                } else {
+                    "unassigned".to_string()
+                };
+                (None, status)
+            }
+            Some(assigned_model) => {
+                let status = match &node_info.model_name {
+                    None => {
+                        info!(
+                            "Node {} has no model loaded (assigned: {})",
+                            node_id.fmt_short(),
+                            assigned_model
+                        );
+                        "idle".to_string()
+                    }
+                    Some(current_model) if current_model == assigned_model => {
+                        info!(
+                            "Node {} loaded correct model: {}",
+                            node_id.fmt_short(),
+                            current_model
+                        );
+                        "loaded".to_string()
+                    }
+                    Some(current_model) => {
+                        info!(
+                            "Node {} has model '{}' but assigned model is '{}'",
+                            node_id.fmt_short(),
+                            current_model,
+                            assigned_model
+                        );
+                        "loading".to_string()
+                    }
+                };
+                (Some(assigned_model.clone()), status)
             }
-            Some(node_info) => match &node_info.model_name {
-                None => {
-                    info!(
-                        "Node {} has no model loaded (assigned: {})",
-                        node_id.fmt_short(),
-                        assigned_model
-                    );
-                    "idle".to_string()
-                }
-                Some(current_model) if current_model == assigned_model => {
-                    info!(
-                        "Node {} loaded correct model: {}",
-                        node_id.fmt_short(),
-                        current_model
-                    );
-                    "loaded".to_string()
-                }
-                Some(current_model) => {
-                    info!(
-                        "Node {} has model '{}' but assigned model is '{}'",
-                        node_id.fmt_short(),
-                        current_model,
-                        assigned_model
-                    );
-                    "loading".to_string() // Has different model, probably loading
-                }
-            },
         };
 
         result.push(AssignmentInfo {
             node_id: node_id.to_string(),
-            model_name: assigned_model.clone(),
+            model_name: assigned_model.unwrap_or_else(|| "".to_string()),
             status,
         });
     }
 
+    for (node_id, assigned_model) in assignments.iter() {
+        if !nodes.contains_key(node_id) {
+            info!(
+                "Node {} not in available_nodes (offline)",
+                node_id.fmt_short()
+            );
+            result.push(AssignmentInfo {
+                node_id: node_id.to_string(),
+                model_name: assigned_model.clone(),
+                status: "offline".to_string(),
+            });
+        }
+    }
+
     Json(result)
 }
 
 #[derive(Debug)]
@@ -692,8 +673,11 @@ async fn run_gateway() -> Result<()> {
     let (network_tx, mut network_rx) = mpsc::channel::<(EndpointId, InferenceMessage)>(100);
     let (gossip_tx, mut gossip_rx) = mpsc::channel::<InferenceGossipMessage>(100);
 
+<<<<<<< HEAD
     let endpoint_addr = network.router().endpoint().addr();
+    // Load persisted model assignments
+=======
+>>>>>>> 78e40a8a2 (Adding target by node id for LoadModel messages, updating endpoint to display full node status and updating justfile)
     let model_assignments = load_assignments(ASSIGNMENTS_FILE);
 
     let state = Arc::new(GatewayState {
@@ -713,7 +697,11 @@ async fn run_gateway() -> Result<()> {
     let cancel = cancel.clone();
     tokio::spawn(async move {
         let mut task_set = tokio::task::JoinSet::new();
+<<<<<<< HEAD
 
+=======
+
+>>>>>>> 78e40a8a2 (Adding target by node id for LoadModel messages, updating endpoint to display full node status and updating justfile)
         let mut cleanup_interval = tokio::time::interval(Duration::from_secs(15));
         cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
@@ -799,15 +787,6 @@ async fn run_gateway() -> Result<()> {
                 }
             }
 
-            Some(gossip_msg) = gossip_rx.recv() => {
-                info!("Broadcasting gossip message: {:?}", gossip_msg);
-                if let Err(e) = network.broadcast(&gossip_msg) {
-                    error!("Failed to broadcast gossip message: {:#}", e);
-                } else {
-                    info!("Successfully broadcasted gossip message");
-                }
-            }
-
             event = 
network.poll_next() => { match event { Ok(Some(NetworkEvent::MessageReceived((peer_id, msg)))) => { diff --git a/architectures/inference-only/inference-node/src/bin/test-network.rs b/architectures/inference-only/inference-node/src/bin/test-network.rs index f92786fe2..d01a31a4a 100644 --- a/architectures/inference-only/inference-node/src/bin/test-network.rs +++ b/architectures/inference-only/inference-node/src/bin/test-network.rs @@ -159,9 +159,10 @@ async fn main() -> Result<()> { InferenceGossipMessage::NodeUnavailable => { info!("Peer {} left the network", peer_id.fmt_short()); } - InferenceGossipMessage::LoadModel { model_name, model_source } => { - info!("LoadModel request from {}: {} ({:?})", - peer_id.fmt_short(), model_name, model_source); + InferenceGossipMessage::LoadModel { target_node_id, model_name, model_source } => { + let target_str = target_node_id.map(|id| id.fmt_short().to_string()).unwrap_or_else(|| "all".to_string()); + info!("LoadModel request from {} (target: {}): {} ({:?})", + peer_id.fmt_short(), target_str, model_name, model_source); } InferenceGossipMessage::ReloadCheckpoint { checkpoint_id, checkpoint_source } => { info!("Checkpoint reload request from {}: {} ({})", diff --git a/architectures/inference-only/inference-node/src/main.rs b/architectures/inference-only/inference-node/src/main.rs index f3486f2a8..a18b3b69a 100644 --- a/architectures/inference-only/inference-node/src/main.rs +++ b/architectures/inference-only/inference-node/src/main.rs @@ -336,7 +336,16 @@ async fn main() -> Result<()> { InferenceGossipMessage::NodeUnavailable => { info!("Peer {} is no longer available", peer_id.fmt_short()); } - InferenceGossipMessage::LoadModel { model_name: requested_model, model_source } => { + InferenceGossipMessage::LoadModel { target_node_id, model_name: requested_model, model_source } => { + let my_node_id = network.endpoint_id(); + if let Some(target) = target_node_id { + if target != my_node_id { + debug!("LoadModel not for us (target: {}, me: {}), ignoring", + target.fmt_short(), my_node_id.fmt_short()); + continue; + } + } + info!("Received LoadModel request from {}: model={}, source={:?}", peer_id.fmt_short(), requested_model, model_source); @@ -357,19 +366,13 @@ async fn main() -> Result<()> { _ => true, }; - let model_already_loaded = { - let current = current_model_name.read().await; - current.as_ref() == Some(&requested_model) - }; - if model_already_loaded { - info!("Model {} already loaded, skipping", requested_model); - } else { - info!("Loading new model: {} (background task)", requested_model); + if should_load { + *model_state.write().await = ModelLoadState::Loading(requested_model.clone()); // Spawn background task to avoid blocking the event loop // Model loading can take 10-60+ seconds, so we don't want to block heartbeats let inference_node_shared_clone = inference_node_shared.clone(); - let current_model_name_clone = current_model_name.clone(); + let model_state_clone = model_state.clone(); let requested_model_clone = requested_model.clone(); tokio::spawn(async move { @@ -405,16 +408,14 @@ async fn main() -> Result<()> { match load_result { Ok(new_node) => { // update model name first, then node, to maintain consistency - *current_model_name_clone.write().await = Some(requested_model_clone.clone()); + *model_state_clone.write().await = ModelLoadState::Loaded(requested_model_clone.clone()); *inference_node_shared_clone.write().await = Some(new_node); info!("Successfully loaded model: {}", requested_model_clone); - // Note: NodeAvailable will 
be broadcast on next heartbeat (every 30s) - // or the node can be manually queried to verify the model is loaded } Err(e) => { error!("Failed to load model {}: {:#}", requested_model_clone, e); - // Set back to Idle on failure + // set back to Idle on failure *model_state_clone.write().await = ModelLoadState::Idle; } } diff --git a/justfile b/justfile index 524c59f59..0029cf33f 100644 --- a/justfile +++ b/justfile @@ -285,9 +285,11 @@ test-model-assignment: SESSION="psyche-model-assignment" GATEWAY_PEER_FILE="/tmp/psyche-gateway-peer.json" + ASSIGNMENTS_FILE="/tmp/psyche-gateway-assignments.json" - # Clean up old peer file + # Clean up old files rm -f "$GATEWAY_PEER_FILE" + rm -f "$ASSIGNMENTS_FILE" # Kill existing session if it exists tmux kill-session -t $SESSION 2>/dev/null || true @@ -322,17 +324,17 @@ test-model-assignment: # Start inference node 1 in idle mode echo "Starting inference node 1 (idle mode)..." tmux new-window -t $SESSION -n node1 - tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.5" C-m + tmux send-keys -t $SESSION:node1 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m # Start inference node 2 in idle mode echo "Starting inference node 2 (idle mode)..." tmux new-window -t $SESSION -n node2 - tmux send-keys -t $SESSION:node2 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.35" C-m + tmux send-keys -t $SESSION:node2 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m # Start inference node 3 in idle mode echo "Starting inference node 3 (idle mode)..." 
tmux new-window -t $SESSION -n node3 - tmux send-keys -t $SESSION:node3 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.5" C-m + tmux send-keys -t $SESSION:node3 "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --discovery-mode local --relay-kind n0 --tensor-parallel-size 1 --gpu-memory-utilization 0.25" C-m sleep 5 echo "" @@ -356,23 +358,27 @@ test-model-assignment: tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Expected: Empty array (no assignments yet)" C-m + tmux send-keys -t $SESSION:test "Expected: Shows all 3 nodes with status 'unassigned'" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 2: Assign models to nodes (2 nodes to gpt2, 1 to llama)" C-m + tmux send-keys -t $SESSION:test "Test 2a: Assign 2 nodes to gpt2 (batch 1)" C-m tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/assign-models \\\\" C-m tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m - tmux send-keys -t $SESSION:test " -d '{" C-m - tmux send-keys -t $SESSION:test " \"assignments\": [" C-m - tmux send-keys -t $SESSION:test " {\"model_name\": \"gpt2\", \"num_nodes\": 2, \"source_type\": \"huggingface\"}," C-m - tmux send-keys -t $SESSION:test " {\"model_name\": \"meta-llama/Llama-3.2-1B-Instruct\", \"num_nodes\": 1, \"source_type\": \"huggingface\"}" C-m - tmux send-keys -t $SESSION:test " ]" C-m - tmux send-keys -t $SESSION:test " }'" C-m + tmux send-keys -t $SESSION:test " -d '{\"assignments\": [{\"model_name\": \"gpt2\", \"num_nodes\": 2, \"source_type\": \"huggingface\"}]}'" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Expected: Gateway assigns 2 nodes to gpt2, 1 to llama" C-m - tmux send-keys -t $SESSION:test "Watch node windows for LoadModel messages and loading progress" C-m + tmux send-keys -t $SESSION:test "Wait ~15s for gpt2 models to load, then check status:" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Test 2b: Assign 1 node to llama (batch 2)" C-m + tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m + tmux send-keys -t $SESSION:test "curl -X POST http://127.0.0.1:8000/admin/assign-models \\\\" C-m + tmux send-keys -t $SESSION:test " -H 'Content-Type: application/json' \\\\" C-m + tmux send-keys -t $SESSION:test " -d '{\"assignments\": [{\"model_name\": \"meta-llama/Llama-3.2-1B-Instruct\", \"num_nodes\": 1, \"source_type\": \"huggingface\"}]}'" C-m + tmux send-keys -t $SESSION:test "" C-m + tmux send-keys -t $SESSION:test "Wait ~15s for llama to load, then check status:" C-m + tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m tmux send-keys -t $SESSION:test "" C-m - tmux send-keys -t $SESSION:test "Test 3: View assignments (wait ~10s for models to load)" C-m + tmux send-keys -t $SESSION:test "Test 3: View final assignments" C-m 
     tmux send-keys -t $SESSION:test "────────────────────────────────────────────────────────────────" C-m
     tmux send-keys -t $SESSION:test "curl http://127.0.0.1:8000/admin/assignments | jq" C-m
     tmux send-keys -t $SESSION:test "" C-m
diff --git a/shared/inference/src/protocol.rs b/shared/inference/src/protocol.rs
index 8fbd7c401..cd6bb95f8 100644
--- a/shared/inference/src/protocol.rs
+++ b/shared/inference/src/protocol.rs
@@ -1,5 +1,6 @@
 //! Protocol types for inference requests and responses
 
+use iroh::PublicKey;
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
@@ -18,6 +19,7 @@ pub enum InferenceGossipMessage {
     },
     NodeUnavailable,
     LoadModel {
+        target_node_id: Option<PublicKey>, // None = all nodes, Some(id) = specific node
         model_name: String,
         model_source: ModelSource,
     },
@@ -226,6 +228,7 @@ mod tests {
     #[test]
     fn test_load_model_message_serialization() {
         let msg = InferenceGossipMessage::LoadModel {
+            target_node_id: None,
            model_name: "gpt2".to_string(),
             model_source: ModelSource::HuggingFace("gpt2".to_string()),
         };
@@ -235,9 +238,11 @@ mod tests {
 
         match parsed {
             InferenceGossipMessage::LoadModel {
+                target_node_id,
                 model_name,
                 model_source,
             } => {
+                assert_eq!(target_node_id, None);
                 assert_eq!(model_name, "gpt2");
                 assert_eq!(model_source, ModelSource::HuggingFace("gpt2".to_string()));
             }

From 07577830c8523f8eff2480dbd11775fbb268cb40 Mon Sep 17 00:00:00 2001
From: nightwing
Date: Thu, 5 Mar 2026 15:57:16 +0000
Subject: [PATCH 4/5] Formatting

---
 .../inference-only/inference-node/src/bin/gateway-node.rs | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 4a0f12c02..86969302e 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -673,11 +673,8 @@ async fn run_gateway() -> Result<()> {
     let (network_tx, mut network_rx) = mpsc::channel::<(EndpointId, InferenceMessage)>(100);
     let (gossip_tx, mut gossip_rx) = mpsc::channel::<InferenceGossipMessage>(100);
 
-<<<<<<< HEAD
     let endpoint_addr = network.router().endpoint().addr();
     // Load persisted model assignments
-=======
->>>>>>> 78e40a8a2 (Adding target by node id for LoadModel messages, updating endpoint to display full node status and updating justfile)
     let model_assignments = load_assignments(ASSIGNMENTS_FILE);
 
     let state = Arc::new(GatewayState {
@@ -697,11 +694,6 @@ async fn run_gateway() -> Result<()> {
     let cancel = cancel.clone();
     tokio::spawn(async move {
         let mut task_set = tokio::task::JoinSet::new();
-<<<<<<< HEAD
-
-=======
-
->>>>>>> 78e40a8a2 (Adding target by node id for LoadModel messages, updating endpoint to display full node status and updating justfile)
         let mut cleanup_interval = tokio::time::interval(Duration::from_secs(15));
         cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

From cb710ed52bd18fa53de601fccaaf62e3a5835c78 Mon Sep 17 00:00:00 2001
From: Sam Herring
Date: Mon, 9 Mar 2026 17:54:11 -0700
Subject: [PATCH 5/5] Changing AssignmentInfo status to AssignmentStatus enum
 and removing node id from assignments when we have a stale node

---
 .../inference-node/src/bin/gateway-node.rs | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)
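
One wire-format note on the enum introduced below: #[serde(rename_all =
"lowercase")] lowercases the whole variant name without inserting separators,
so UnassignedWithModel serializes as "unassignedwithmodel", not the previous
"unassigned_with_model" string. A runnable sketch (variant list trimmed for
brevity):

    #[derive(serde::Serialize)]
    #[serde(rename_all = "lowercase")]
    enum AssignmentStatus {
        Unassigned,
        UnassignedWithModel,
        Loaded,
    }

    fn main() {
        // "lowercase" adds no '_' separators between words.
        assert_eq!(
            serde_json::to_string(&AssignmentStatus::UnassignedWithModel).unwrap(),
            r#""unassignedwithmodel""#
        );
        assert_eq!(
            serde_json::to_string(&AssignmentStatus::Loaded).unwrap(),
            r#""loaded""#
        );
    }
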
diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
index 86969302e..0a3418b29 100644
--- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs
+++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs
@@ -159,11 +159,22 @@ struct ModelAssignmentSpec {
     num_nodes: usize,
 }
 
+#[derive(serde::Serialize)]
+#[serde(rename_all = "lowercase")]
+enum AssignmentStatus {
+    Unassigned,
+    UnassignedWithModel,
+    Loading,
+    Loaded,
+    Idle,
+    Offline,
+}
+
 #[derive(serde::Serialize)]
 struct AssignmentInfo {
     node_id: String,
     model_name: String,
-    status: String, // "loading", "loaded", "idle", "offline"
+    status: AssignmentStatus,
 }
 
 #[derive(serde::Serialize)]
@@ -429,7 +440,6 @@ async fn handle_bootstrap(State(state): State<Arc<GatewayState>>) -> Json
                 let status = if node_info.model_name.is_some() {
-                    "unassigned_with_model".to_string()
+                    AssignmentStatus::UnassignedWithModel
                 } else {
-                    "unassigned".to_string()
+                    AssignmentStatus::Unassigned
                 };
                 (None, status)
             }
@@ -458,7 +468,7 @@ async fn handle_get_assignments(
                         node_id.fmt_short(),
                         assigned_model
                     );
-                    "idle".to_string()
+                    AssignmentStatus::Idle
                 }
                 Some(current_model) if current_model == assigned_model => {
                     info!(
                         "Node {} loaded correct model: {}",
                         node_id.fmt_short(),
                         current_model
                     );
-                    "loaded".to_string()
+                    AssignmentStatus::Loaded
                 }
                 Some(current_model) => {
                     info!(
                         "Node {} has model '{}' but assigned model is '{}'",
                         node_id.fmt_short(),
                         current_model,
                         assigned_model
                     );
-                    "loading".to_string()
+                    AssignmentStatus::Loading
                 }
             };
             (Some(assigned_model.clone()), status)
@@ -498,7 +508,7 @@ async fn handle_get_assignments(
             result.push(AssignmentInfo {
                 node_id: node_id.to_string(),
                 model_name: assigned_model.clone(),
-                status: "offline".to_string(),
+                status: AssignmentStatus::Offline,
             });
         }
     }
@@ -726,6 +736,7 @@ async fn run_gateway() -> Result<()> {
                     for (node_id, age) in stale_nodes {
                         warn!("Removing stale node {} (no heartbeat for {:?})", node_id.fmt_short(), age);
                         nodes.remove(&node_id);
+                        state.model_assignments.write().await.remove(&node_id);
                     }
                 }
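
The final hunk above ties assignment lifetime to node liveness: when the
15-second cleanup sweep evicts a stale node, its model assignment is dropped
with it, so a later assign-models call can reuse the slot. The bookkeeping,
reduced to a self-contained sketch (u64 stands in for the node id; the
staleness threshold here is a placeholder, not the gateway's actual value):

    use std::collections::HashMap;
    use std::time::{Duration, Instant};

    fn sweep_stale(
        nodes: &mut HashMap<u64, Instant>,      // node id -> last heartbeat
        assignments: &mut HashMap<u64, String>, // node id -> assigned model
        max_age: Duration,
    ) {
        let stale: Vec<u64> = nodes
            .iter()
            .filter(|(_, seen)| seen.elapsed() > max_age)
            .map(|(id, _)| *id)
            .collect();
        for id in stale {
            // Drop both the liveness record and the model assignment,
            // freeing the slot for the next assign-models request.
            nodes.remove(&id);
            assignments.remove(&id);
        }
    }
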