diff --git a/paddler_agent/src/model_source/url.rs b/paddler_agent/src/model_source/url.rs index 68d234fa..e5a2e88e 100644 --- a/paddler_agent/src/model_source/url.rs +++ b/paddler_agent/src/model_source/url.rs @@ -190,6 +190,13 @@ async fn resolve_url_into_cache( } }; + /* Remove any non-file entry at the cache file path before continuing; if that fails, reset the download state, register the classified cache I/O issue, and bail out with the I/O error. */ if let Err(io_error) = cached.remove_invalid_cache_entry().await { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(classify_cache_io_error(url_string, &io_error)); + + return Err(anyhow::Error::new(io_error)); + } + let basename = cached .cache_file_path .file_name() @@ -775,4 +782,61 @@ mod tests { )); assert_eq!(tokio::fs::read(&expected_path).await.unwrap(), body); } + + /* A directory created at the cache file path must not block resolution: the call is expected to succeed and the path must afterwards hold the served body. */ #[tokio::test] + async fn resolve_recovers_from_stale_directory_then_downloads() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let url_string = format!("http://127.0.0.1:{port}/model.gguf"); + let body = b"recovered model bytes".to_vec(); + let server = tokio::spawn(serve_single_ok_response(listener, body.clone())); + + let cached = CachedDownloadedModel::new(&cache_dir, &url_string).unwrap(); + let expected_path = cached.cache_file_path.clone(); + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::create_dir(&expected_path).await.unwrap(); + + let resolution = resolve_url_into_cache(&url_string, &cache_dir, fresh_status()) + .await + .unwrap(); + + server.await.unwrap(); + + assert!(matches!( + resolution, + DesiredModelResolution::Resolved(resolved_path) if resolved_path == expected_path + )); + assert_eq!(tokio::fs::read(&expected_path).await.unwrap(), body); + } + + /* The listener is dropped right after its port is read, so the URL presumably points at a closed port (confirm nothing else can grab it); with a directory at the cache file path the call must fail and record DownloadServerIsUnreachable. */ #[tokio::test] + async fn resolve_reports_unreachable_when_dead_url_with_stale_directory() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + let url_string = format!("http://127.0.0.1:{port}/model.gguf"); + + let cached = CachedDownloadedModel::new(&cache_dir, &url_string).unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::create_dir(&cached.cache_file_path) + .await + .unwrap(); + + let status = fresh_status(); + let result = resolve_url_into_cache(&url_string, &cache_dir, status.clone()).await; + + assert!(result.is_err(), "a dead URL must produce an Err"); + assert!( + status.has_issue(&AgentIssue::DownloadServerIsUnreachable(ModelPath { + model_path: url_string.clone(), + })), + "recovery must clear the stale dir so the dead-URL download reports unreachable" + ); + } } diff --git a/paddler_cache_dir/src/cached_downloaded_model.rs b/paddler_cache_dir/src/cached_downloaded_model.rs index 2fbd586f..2d565cc0 100644 --- a/paddler_cache_dir/src/cached_downloaded_model.rs +++ b/paddler_cache_dir/src/cached_downloaded_model.rs @@ -1,4 +1,5 @@ use std::fmt::Write as _; +use std::io; use std::path::PathBuf; use anyhow::Result; @@ -44,11 +45,33 @@ impl CachedDownloadedModel { }) } - pub async fn is_cached(&self) -> Result { - fs::try_exists(&self.cache_file_path).await + /* Reports whether the cache file path is a regular file according to fs::metadata: Ok(false) for a missing path or for a non-file entry such as a directory, Err for any other I/O error. */ pub async fn is_cached(&self) -> Result { + match fs::metadata(&self.cache_file_path).await { + Ok(metadata) => Ok(metadata.is_file()), + Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(false), + Err(error) => Err(error), + } + } + + /* Clears the cache file path unless it is missing or already a regular file (both are Ok no-ops): a directory is removed with remove_dir_all, anything else with remove_file. Metadata errors other than NotFound are returned unchanged. NOTE(review): the check uses symlink_metadata while is_cached uses metadata, so a symlink is presumably classified here as the link itself rather than its target - confirm. */ pub async fn remove_invalid_cache_entry(&self) -> Result<(), io::Error> { + let metadata = match fs::symlink_metadata(&self.cache_file_path).await { + Ok(metadata) => metadata, + Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(()), + Err(error) => return Err(error), + }; + + if metadata.is_file() { + return Ok(()); + } + + if metadata.is_dir() { + fs::remove_dir_all(&self.cache_file_path).await + } else { + 
fs::remove_file(&self.cache_file_path).await + } } - pub async fn ensure_cache_subdir_exists(&self) -> Result<(), std::io::Error> { + /* Creates the cache subdirectory via create_dir_all and returns its I/O result. */ pub async fn ensure_cache_subdir_exists(&self) -> Result<(), io::Error> { fs::create_dir_all(&self.cache_subdir).await } @@ -202,6 +225,84 @@ mod tests { assert!(cached.is_cached().await.unwrap()); } + /* A directory sitting at the cache file path must not be reported as a cached model. */ #[tokio::test] + async fn is_cached_returns_false_when_path_is_a_directory() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/leftover.gguf").unwrap(); + + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::create_dir(&cached.cache_file_path) + .await + .unwrap(); + + assert!( + !cached.is_cached().await.unwrap(), + "a directory occupying the cache file path is not a cached model" + ); + } + + /* The stale directory is made non-empty (it contains "inner"), so the removal has to be recursive to succeed. */ #[tokio::test] + async fn remove_invalid_cache_entry_removes_stale_directory() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/stale.gguf").unwrap(); + + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::create_dir(&cached.cache_file_path) + .await + .unwrap(); + tokio::fs::write(cached.cache_file_path.join("inner"), b"leftover") + .await + .unwrap(); + + cached.remove_invalid_cache_entry().await.unwrap(); + + assert!( + !tokio::fs::try_exists(&cached.cache_file_path) + .await + .unwrap() + ); + } + + /* A regular file at the cache file path must survive the call with its contents intact. */ #[tokio::test] + async fn remove_invalid_cache_entry_keeps_regular_file() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/real.gguf").unwrap(); + + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::write(&cached.cache_file_path, b"real model bytes") + .await + .unwrap(); + + cached.remove_invalid_cache_entry().await.unwrap(); + + 
assert_eq!( + tokio::fs::read(&cached.cache_file_path).await.unwrap(), + b"real model bytes" + ); + } + + /* Nothing is created on disk before the call (not even the cache subdirectory), so this covers the missing-path case. */ #[tokio::test] + async fn remove_invalid_cache_entry_is_noop_when_absent() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/missing.gguf").unwrap(); + + cached.remove_invalid_cache_entry().await.unwrap(); + + assert!( + !tokio::fs::try_exists(&cached.cache_file_path) + .await + .unwrap() + ); + } + #[tokio::test] async fn try_acquire_download_lock_succeeds_when_uncontested() { let directory = TempDir::new().unwrap(); diff --git a/paddler_test_cluster_harness/Cargo.toml b/paddler_test_cluster_harness/Cargo.toml index e97271e6..0b63e5b3 100644 --- a/paddler_test_cluster_harness/Cargo.toml +++ b/paddler_test_cluster_harness/Cargo.toml @@ -26,6 +26,7 @@ url = { workspace = true } [dev-dependencies] http = { workspace = true } +tokio = { workspace = true, features = ["test-util"] } [lints] workspace = true diff --git a/paddler_test_cluster_harness/src/agents_stream_watcher.rs b/paddler_test_cluster_harness/src/agents_stream_watcher.rs index 19d0448a..30f47dcf 100644 --- a/paddler_test_cluster_harness/src/agents_stream_watcher.rs +++ b/paddler_test_cluster_harness/src/agents_stream_watcher.rs @@ -1,4 +1,5 @@ use std::pin::Pin; +use std::time::Duration; use anyhow::Context as _; use anyhow::Result; @@ -8,6 +9,9 @@ use futures_util::Stream; use futures_util::StreamExt as _; use paddler_client::client_management::ClientManagement; use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot; +use tokio::time::timeout; + +const UNTIL_TIMEOUT: Duration = Duration::from_secs(10); pub struct AgentsStreamWatcher { stream: Pin> + Send>>, @@ -35,6 +39,16 @@ impl AgentsStreamWatcher { Self { stream } } + /* Awaits the next stream item for at most UNTIL_TIMEOUT: Err when the timeout elapses or the item is an error (with added context), Ok(None) when the stream ends, Ok(Some(snapshot)) otherwise. NOTE(review): the timeout restarts on every call, so until and until_agent are bounded per snapshot rather than in total - a stream that keeps yielding non-matching snapshots is never cut off, although the error text speaks of the predicate not being satisfied; confirm this is intended. */ async fn next_snapshot(&mut self) -> Result> { + match timeout(UNTIL_TIMEOUT, self.stream.next()).await { + Err(_elapsed) => 
Err(anyhow!( + "agents stream did not satisfy the predicate within {UNTIL_TIMEOUT:?}" + )), + Ok(None) => Ok(None), + Ok(Some(item)) => Ok(Some(item.context("agents stream yielded an error")?)), + } + } + pub async fn until( &mut self, mut predicate: TPredicate, @@ -42,9 +56,7 @@ impl AgentsStreamWatcher { where TPredicate: FnMut(&AgentControllerPoolSnapshot) -> bool, { - while let Some(item) = self.stream.next().await { - let snapshot = item.context("agents stream yielded an error")?; - + /* Timeout and stream errors from next_snapshot leave the loop through the ? operator. */ while let Some(snapshot) = self.next_snapshot().await? { if predicate(&snapshot) { return Ok(snapshot); } @@ -63,9 +75,7 @@ impl AgentsStreamWatcher { where TPredicate: FnMut(&AgentControllerPoolSnapshot) -> bool, { - while let Some(item) = self.stream.next().await { - let snapshot = item.context("agents stream yielded an error")?; - + /* Timeout and stream errors from next_snapshot leave the loop through the ? operator. */ while let Some(snapshot) = self.next_snapshot().await? { let agent_present = snapshot .agents .iter() @@ -219,6 +229,26 @@ mod tests { AgentsStreamWatcher::from_stream(Box::pin(stream::iter(snapshots.into_iter().map(Ok)))) } + /* The pending stream never yields an item, so only the timeout can end the wait. start_paused relies on tokio's test-util feature; the paused clock presumably lets the 10 s limit elapse without real waiting - confirm. */ #[tokio::test(start_paused = true)] + async fn until_times_out_when_predicate_is_never_satisfied() -> Result<()> { + let mut watcher = AgentsStreamWatcher::from_stream(Box::pin(stream::pending::< + Result, + >())); + + let error = watcher + .until(|_snapshot| false) + .await + .err() + .context("until must time out when no snapshot ever arrives")?; + + assert!( + format!("{error:#}").contains("within"), + "timeout error must mention the elapsed bound, got: {error:#}" + ); + + Ok(()) + } + #[tokio::test] async fn until_agent_returns_ok_when_predicate_matches_with_agent_present() -> Result<()> { let agent_id = "agent-x";