Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions paddler_agent/src/model_source/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ async fn resolve_url_into_cache(
}
};

if let Err(io_error) = cached.remove_invalid_cache_entry().await {
slot_aggregated_status.reset_download();
slot_aggregated_status.register_issue(classify_cache_io_error(url_string, &io_error));

return Err(anyhow::Error::new(io_error));
}

let basename = cached
.cache_file_path
.file_name()
Expand Down Expand Up @@ -775,4 +782,61 @@ mod tests {
));
assert_eq!(tokio::fs::read(&expected_path).await.unwrap(), body);
}

/// A directory left at the cache file path must not block resolution: the
/// resolver is expected to end up with the downloaded bytes stored at that
/// exact path.
#[tokio::test]
async fn resolve_recovers_from_stale_directory_then_downloads() {
    // Keep the temp dir binding alive for the whole test so the cache root
    // is not deleted while the resolver is still using it.
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());

    // Port 0 lets the OS choose a free port; the listener is handed to the
    // spawned server task that serves `served_bytes`.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let url_string = format!("http://127.0.0.1:{port}/model.gguf");
    let served_bytes = b"recovered model bytes".to_vec();
    let server_task = tokio::spawn(serve_single_ok_response(listener, served_bytes.clone()));

    // Plant a directory exactly where the cached model file should live.
    let stale_entry = CachedDownloadedModel::new(&cache_dir, &url_string).unwrap();
    stale_entry.ensure_cache_subdir_exists().await.unwrap();
    tokio::fs::create_dir(&stale_entry.cache_file_path)
        .await
        .unwrap();
    let expected_path = stale_entry.cache_file_path.clone();

    let outcome = resolve_url_into_cache(&url_string, &cache_dir, fresh_status()).await;
    let resolution = outcome.unwrap();

    server_task.await.unwrap();

    assert!(matches!(
        resolution,
        DesiredModelResolution::Resolved(landed_at) if landed_at == expected_path
    ));

    let downloaded = tokio::fs::read(&expected_path).await.unwrap();
    assert_eq!(downloaded, served_bytes);
}

/// With a stale directory at the cache path AND nothing listening on the
/// URL's port, resolution must fail and the status must carry
/// `DownloadServerIsUnreachable` for that URL.
#[tokio::test]
async fn resolve_reports_unreachable_when_dead_url_with_stale_directory() {
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());

    // Bind only to learn a free port, then drop the listener so nothing is
    // accepting connections on it when the resolver tries to download.
    let port = {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        port
    };
    let url_string = format!("http://127.0.0.1:{port}/model.gguf");

    // Plant a directory exactly where the cached model file should live.
    let stale_entry = CachedDownloadedModel::new(&cache_dir, &url_string).unwrap();
    stale_entry.ensure_cache_subdir_exists().await.unwrap();
    tokio::fs::create_dir(&stale_entry.cache_file_path)
        .await
        .unwrap();

    let status = fresh_status();
    let result = resolve_url_into_cache(&url_string, &cache_dir, status.clone()).await;

    assert!(result.is_err(), "a dead URL must produce an Err");

    let expected_issue = AgentIssue::DownloadServerIsUnreachable(ModelPath {
        model_path: url_string.clone(),
    });
    assert!(
        status.has_issue(&expected_issue),
        "recovery must clear the stale dir so the dead-URL download reports unreachable"
    );
}
}
107 changes: 104 additions & 3 deletions paddler_cache_dir/src/cached_downloaded_model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt::Write as _;
use std::io;
use std::path::PathBuf;

use anyhow::Result;
Expand Down Expand Up @@ -44,11 +45,33 @@ impl CachedDownloadedModel {
})
}

pub async fn is_cached(&self) -> Result<bool, std::io::Error> {
fs::try_exists(&self.cache_file_path).await
/// Reports whether `cache_file_path` currently refers to a regular file.
///
/// # Errors
///
/// A `NotFound` error from the metadata lookup is mapped to `Ok(false)`;
/// every other I/O error is returned to the caller unchanged.
pub async fn is_cached(&self) -> Result<bool, io::Error> {
    fs::metadata(&self.cache_file_path)
        .await
        .map(|metadata| metadata.is_file())
        .or_else(|error| {
            // A missing entry simply means "not cached", not a failure.
            if error.kind() == io::ErrorKind::NotFound {
                Ok(false)
            } else {
                Err(error)
            }
        })
}

/// Removes whatever sits at `cache_file_path` unless it is a regular file.
///
/// * Nothing at the path: returns `Ok(())` without touching the filesystem.
/// * Regular file: left in place.
/// * Directory: removed together with its contents.
/// * Anything else: removed with `remove_file`.
///
/// The entry is inspected with `symlink_metadata`, which describes the
/// entry itself rather than the target of a symlink.
///
/// # Errors
///
/// Returns any I/O error from the metadata lookup (other than `NotFound`)
/// or from the removal call.
pub async fn remove_invalid_cache_entry(&self) -> Result<(), io::Error> {
    let entry_path = &self.cache_file_path;

    let metadata = match fs::symlink_metadata(entry_path).await {
        Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(()),
        Err(error) => return Err(error),
        Ok(metadata) => metadata,
    };

    if metadata.is_file() {
        Ok(())
    } else if metadata.is_dir() {
        fs::remove_dir_all(entry_path).await
    } else {
        fs::remove_file(entry_path).await
    }
}

pub async fn ensure_cache_subdir_exists(&self) -> Result<(), std::io::Error> {
/// Creates the directory at `self.cache_subdir` by calling `create_dir_all`,
/// which also creates any missing parent directories.
///
/// # Errors
///
/// Returns the I/O error reported by `create_dir_all`.
pub async fn ensure_cache_subdir_exists(&self) -> Result<(), io::Error> {
fs::create_dir_all(&self.cache_subdir).await
}

Expand Down Expand Up @@ -202,6 +225,84 @@ mod tests {
assert!(cached.is_cached().await.unwrap());
}

/// `is_cached` must answer `false` when a directory, not a file, occupies
/// the cache file path.
#[tokio::test]
async fn is_cached_returns_false_when_path_is_a_directory() {
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());
    let model =
        CachedDownloadedModel::new(&cache_dir, "https://host.example/leftover.gguf").unwrap();

    // Put a directory exactly where the model file would be stored.
    model.ensure_cache_subdir_exists().await.unwrap();
    tokio::fs::create_dir(&model.cache_file_path).await.unwrap();

    let reported_cached = model.is_cached().await.unwrap();

    assert!(
        !reported_cached,
        "a directory occupying the cache file path is not a cached model"
    );
}

/// A non-empty directory at the cache file path must be gone after
/// `remove_invalid_cache_entry` returns.
#[tokio::test]
async fn remove_invalid_cache_entry_removes_stale_directory() {
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());
    let model =
        CachedDownloadedModel::new(&cache_dir, "https://host.example/stale.gguf").unwrap();

    // Build a stale directory that contains a file, so the removal has to
    // delete contents as well as the directory itself.
    model.ensure_cache_subdir_exists().await.unwrap();
    tokio::fs::create_dir(&model.cache_file_path).await.unwrap();
    let leftover_file = model.cache_file_path.join("inner");
    tokio::fs::write(leftover_file, b"leftover").await.unwrap();

    model.remove_invalid_cache_entry().await.unwrap();

    let still_present = tokio::fs::try_exists(&model.cache_file_path)
        .await
        .unwrap();
    assert!(!still_present);
}

/// A regular file at the cache file path must keep its contents across a
/// call to `remove_invalid_cache_entry`.
#[tokio::test]
async fn remove_invalid_cache_entry_keeps_regular_file() {
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());
    let model =
        CachedDownloadedModel::new(&cache_dir, "https://host.example/real.gguf").unwrap();

    let payload = b"real model bytes";
    model.ensure_cache_subdir_exists().await.unwrap();
    tokio::fs::write(&model.cache_file_path, payload)
        .await
        .unwrap();

    model.remove_invalid_cache_entry().await.unwrap();

    let contents = tokio::fs::read(&model.cache_file_path).await.unwrap();
    assert_eq!(contents, payload);
}

/// With nothing at the cache file path, `remove_invalid_cache_entry` must
/// succeed and must not create anything there.
#[tokio::test]
async fn remove_invalid_cache_entry_is_noop_when_absent() {
    let temp_root = TempDir::new().unwrap();
    let cache_dir = cache_dir_at(temp_root.path());
    let model =
        CachedDownloadedModel::new(&cache_dir, "https://host.example/missing.gguf").unwrap();

    // No setup on disk: the entry does not exist before the call.
    model.remove_invalid_cache_entry().await.unwrap();

    let exists_afterwards = tokio::fs::try_exists(&model.cache_file_path)
        .await
        .unwrap();
    assert!(!exists_afterwards);
}

#[tokio::test]
async fn try_acquire_download_lock_succeeds_when_uncontested() {
let directory = TempDir::new().unwrap();
Expand Down
1 change: 1 addition & 0 deletions paddler_test_cluster_harness/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ url = { workspace = true }

[dev-dependencies]
http = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }

[lints]
workspace = true
42 changes: 36 additions & 6 deletions paddler_test_cluster_harness/src/agents_stream_watcher.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::pin::Pin;
use std::time::Duration;

use anyhow::Context as _;
use anyhow::Result;
Expand All @@ -8,6 +9,9 @@ use futures_util::Stream;
use futures_util::StreamExt as _;
use paddler_client::client_management::ClientManagement;
use paddler_messaging::agent_controller_pool_snapshot::AgentControllerPoolSnapshot;
use tokio::time::timeout;

const UNTIL_TIMEOUT: Duration = Duration::from_secs(10);

pub struct AgentsStreamWatcher {
stream: Pin<Box<dyn Stream<Item = Result<AgentControllerPoolSnapshot>> + Send>>,
Expand Down Expand Up @@ -35,16 +39,24 @@ impl AgentsStreamWatcher {
Self { stream }
}

/// Waits for the next item of the agents stream, for at most `UNTIL_TIMEOUT`.
///
/// The bound applies to this single wait; each call starts a fresh timer.
///
/// Returns `Ok(Some(snapshot))` for a successful item and `Ok(None)` when
/// the stream has ended.
///
/// # Errors
///
/// Fails if no item arrives within `UNTIL_TIMEOUT`, or if the stream yields
/// an error item (which is wrapped with additional context).
async fn next_snapshot(&mut self) -> Result<Option<AgentControllerPoolSnapshot>> {
    let Ok(next_item) = timeout(UNTIL_TIMEOUT, self.stream.next()).await else {
        return Err(anyhow!(
            "agents stream did not satisfy the predicate within {UNTIL_TIMEOUT:?}"
        ));
    };

    // `Option<Result<_>>` -> `Result<Option<_>>`: end-of-stream stays `Ok(None)`.
    next_item
        .map(|item| item.context("agents stream yielded an error"))
        .transpose()
}

pub async fn until<TPredicate>(
&mut self,
mut predicate: TPredicate,
) -> Result<AgentControllerPoolSnapshot>
where
TPredicate: FnMut(&AgentControllerPoolSnapshot) -> bool,
{
while let Some(item) = self.stream.next().await {
let snapshot = item.context("agents stream yielded an error")?;

while let Some(snapshot) = self.next_snapshot().await? {
if predicate(&snapshot) {
return Ok(snapshot);
}
Expand All @@ -63,9 +75,7 @@ impl AgentsStreamWatcher {
where
TPredicate: FnMut(&AgentControllerPoolSnapshot) -> bool,
{
while let Some(item) = self.stream.next().await {
let snapshot = item.context("agents stream yielded an error")?;

while let Some(snapshot) = self.next_snapshot().await? {
let agent_present = snapshot
.agents
.iter()
Expand Down Expand Up @@ -219,6 +229,26 @@ mod tests {
AgentsStreamWatcher::from_stream(Box::pin(stream::iter(snapshots.into_iter().map(Ok))))
}

/// `until` must return an error, instead of waiting forever, when the
/// underlying stream never produces a snapshot. The runtime starts with
/// time paused (`start_paused = true`).
#[tokio::test(start_paused = true)]
async fn until_times_out_when_predicate_is_never_satisfied() -> Result<()> {
    let never_yields = stream::pending::<Result<AgentControllerPoolSnapshot>>();
    let mut watcher = AgentsStreamWatcher::from_stream(Box::pin(never_yields));

    let outcome = watcher.until(|_snapshot| false).await;
    let error = outcome
        .err()
        .context("until must time out when no snapshot ever arrives")?;

    let rendered = format!("{error:#}");
    assert!(
        rendered.contains("within"),
        "timeout error must mention the elapsed bound, got: {error:#}"
    );

    Ok(())
}

#[tokio::test]
async fn until_agent_returns_ok_when_predicate_matches_with_agent_present() -> Result<()> {
let agent_id = "agent-x";
Expand Down
Loading