diff --git a/components/ads-client/src/client.rs b/components/ads-client/src/client.rs index adf0b82e44..b4ba056c33 100644 --- a/components/ads-client/src/client.rs +++ b/components/ads-client/src/client.rs @@ -4,28 +4,21 @@ */ use std::collections::HashMap; -use std::time::Duration; use crate::client::ad_response::{AdImage, AdResponse, AdResponseValue, AdSpoc, AdTile}; -use crate::client::config::AdsClientConfig; use crate::error::{RecordClickError, RecordImpressionError, ReportAdError, RequestAdsError}; -use crate::http_cache::{HttpCache, RequestCachePolicy}; +use crate::http_cache::{HttpCacheError, RequestCachePolicy}; use crate::mars::MARSClient; use crate::telemetry::Telemetry; use ad_request::{AdPlacementRequest, AdRequest}; -use context_id::{ContextIDComponent, DefaultContextIdCallback}; +use context_id::ContextIDComponent; use url::Url; -use uuid::Uuid; - -use crate::http_cache::{ByteSize, HttpCacheError}; pub mod ad_request; pub mod ad_response; +pub mod builder; pub mod config; -const DEFAULT_TTL_SECONDS: u64 = 300; -const DEFAULT_MAX_CACHE_SIZE_MIB: u64 = 10; - pub struct AdsClient where T: Clone + Telemetry, @@ -39,57 +32,17 @@ impl AdsClient where T: Clone + Telemetry, { - pub fn new(client_config: AdsClientConfig) -> Self { - let context_id = Uuid::new_v4().to_string(); - let context_id_component = ContextIDComponent::new( - &context_id, - 0, - cfg!(test), - Box::new(DefaultContextIdCallback), - ); - let telemetry = client_config.telemetry; - - // Configure the cache if a path is provided. - // Defaults for ttl and cache size are also set if unspecified. - if let Some(cache_cfg) = client_config.cache_config { - let default_cache_ttl = Duration::from_secs( - cache_cfg - .default_cache_ttl_seconds - .unwrap_or(DEFAULT_TTL_SECONDS), - ); - let max_cache_size = - ByteSize::mib(cache_cfg.max_size_mib.unwrap_or(DEFAULT_MAX_CACHE_SIZE_MIB)); - - let http_cache = match HttpCache::builder(cache_cfg.db_path) - .max_size(max_cache_size) - .default_ttl(default_cache_ttl) - .build() - { - Ok(cache) => Some(cache), - Err(e) => { - telemetry.record(&e); - None - } - }; - - let client = MARSClient::new(client_config.environment, http_cache, telemetry.clone()); - let client = Self { - context_id_component, - client, - telemetry: telemetry.clone(), - }; - telemetry.record(&ClientOperationEvent::New); - return client; - } - - let client = MARSClient::new(client_config.environment, None, telemetry.clone()); - let client = Self { - context_id_component, - client, - telemetry: telemetry.clone(), - }; + pub fn new( + client: MARSClient, + context_id_component: ContextIDComponent, + telemetry: T, + ) -> Self { telemetry.record(&ClientOperationEvent::New); - client + Self { + client, + context_id_component, + telemetry, + } } fn request_ads( @@ -221,6 +174,8 @@ pub enum ClientOperationEvent { #[cfg(test)] mod tests { + use context_id::DefaultContextIdCallback; + use crate::{ client::config::Environment, ffi::telemetry::MozAdsTelemetryWrapper, @@ -241,21 +196,20 @@ mod tests { false, Box::new(DefaultContextIdCallback), ); - AdsClient { - context_id_component, - client, - telemetry: MozAdsTelemetryWrapper::noop(), - } + AdsClient::new(client, context_id_component, MozAdsTelemetryWrapper::noop()) } #[test] fn test_get_context_id() { - let config = AdsClientConfig { - environment: Environment::Test, - cache_config: None, - telemetry: MozAdsTelemetryWrapper::noop(), - }; - let client = AdsClient::new(config); + let telemetry = MozAdsTelemetryWrapper::noop(); + let mars_client = MARSClient::new(Environment::Test, None, telemetry.clone()); + let context_id_component = ContextIDComponent::new( + &uuid::Uuid::new_v4().to_string(), + 0, + false, + Box::new(DefaultContextIdCallback), + ); + let client = AdsClient::new(mars_client, context_id_component, telemetry); let context_id = client.get_context_id().unwrap(); assert!(!context_id.is_empty()); } @@ -329,6 +283,8 @@ mod tests { #[test] #[ignore = "Cache invalidation temporarily disabled - will be re-enabled behind Nimbus experiment"] fn test_record_click_invalidates_cache() { + use crate::http_cache::HttpCache; + viaduct_dev::init_backend_dev(); let cache = HttpCache::builder("test_record_click_invalidates_cache") .build() diff --git a/components/ads-client/src/client/builder.rs b/components/ads-client/src/client/builder.rs new file mode 100644 index 0000000000..b044d949b2 --- /dev/null +++ b/components/ads-client/src/client/builder.rs @@ -0,0 +1,329 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public +* License, v. 2.0. If a copy of the MPL was not distributed with this +* file, You can obtain one at http://mozilla.org/MPL/2.0/. +*/ + +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use context_id::{ContextIDComponent, ContextIdCallback}; +use uuid::Uuid; + +use crate::error::BuildError; +use crate::http_cache::{ByteSize, HttpCache}; +use crate::mars::MARSClient; +use crate::telemetry::Telemetry; + +use super::config::Environment; +use super::AdsClient; + +const DEFAULT_TTL_SECONDS: u64 = 300; +const DEFAULT_MAX_CACHE_SIZE_MIB: u64 = 10; +const DATA_DIR_NAME: &str = "ads-client"; + +pub struct AdsClientBuilder { + data_dir: Option, + environment: Environment, + cache_ttl_seconds: Option, + cache_max_size_mib: Option, + legacy_cache_db_path: Option, + telemetry: T, +} + +impl AdsClientBuilder +where + T: Clone + Telemetry, +{ + pub fn new(telemetry: T) -> Self { + Self { + data_dir: None, + environment: Environment::default(), + cache_ttl_seconds: None, + cache_max_size_mib: None, + legacy_cache_db_path: None, + telemetry, + } + } + + pub fn data_dir(mut self, dir: String) -> Self { + self.data_dir = Some(dir); + self + } + + pub fn environment(mut self, env: Environment) -> Self { + self.environment = env; + self + } + + pub fn cache_ttl_seconds(mut self, ttl: u64) -> Self { + self.cache_ttl_seconds = Some(ttl); + self + } + + pub fn cache_max_size_mib(mut self, size: u64) -> Self { + self.cache_max_size_mib = Some(size); + self + } + + pub fn legacy_cache_db_path(mut self, path: String) -> Self { + self.legacy_cache_db_path = Some(path); + self + } + + pub fn build(self) -> Result, BuildError> { + let data_dir = resolve_data_dir( + self.data_dir.as_deref(), + self.legacy_cache_db_path.as_deref(), + ) + .ok_or(BuildError::NoDataDir)?; + + std::fs::create_dir_all(&data_dir).map_err(|e| BuildError::CreateDataDir { + path: data_dir.display().to_string(), + reason: e.to_string(), + })?; + + let data_dir_str = data_dir.to_string_lossy().to_string(); + + // Read persisted context_id or generate a fresh one + let (context_id, context_id_ts, callback): (String, i64, Box) = + match read_persisted_context_id(&data_dir_str) { + Some((id, ts)) => (id, ts, Box::new(FileContextIdCallback::new(&data_dir_str))), + None => ( + Uuid::new_v4().to_string(), + 0, + Box::new(FileContextIdCallback::new(&data_dir_str)), + ), + }; + + let context_id_component = + ContextIDComponent::new(&context_id, context_id_ts, cfg!(test), callback); + + // Build cache at {data_dir}/cache.db + let cache_db_path = data_dir.join("cache.db"); + let ttl = Duration::from_secs(self.cache_ttl_seconds.unwrap_or(DEFAULT_TTL_SECONDS)); + let max_size = ByteSize::mib( + self.cache_max_size_mib + .unwrap_or(DEFAULT_MAX_CACHE_SIZE_MIB), + ); + + let http_cache = match HttpCache::builder(cache_db_path) + .default_ttl(ttl) + .max_size(max_size) + .build() + { + Ok(cache) => Some(cache), + Err(e) => { + self.telemetry.record(&e); + None + } + }; + + let client = MARSClient::new(self.environment, http_cache, self.telemetry.clone()); + + Ok(AdsClient::new(client, context_id_component, self.telemetry)) + } +} + +/// Resolves the data directory following a fallback chain: +/// 1. Explicit `data_dir` if provided +/// 2. Parent directory of `legacy_cache_db_path` if provided +/// 3. `$HOME` env var +/// 4. `$APPDATA` or `$LOCALAPPDATA` (Windows) +/// +/// Always appends an `ads-client/` subdirectory for isolation. +fn resolve_data_dir( + explicit_data_dir: Option<&str>, + legacy_cache_db_path: Option<&str>, +) -> Option { + // 1. Explicit data_dir + if let Some(dir) = explicit_data_dir { + let p = PathBuf::from(dir); + if !p.as_os_str().is_empty() { + return Some(p.join(DATA_DIR_NAME)); + } + } + + // 2. Parent of legacy cache db_path + if let Some(db_path) = legacy_cache_db_path { + let p = Path::new(db_path); + if let Some(parent) = p.parent() { + if !parent.as_os_str().is_empty() { + return Some(parent.join(DATA_DIR_NAME)); + } + } + } + + // 3. $HOME + if let Ok(home) = std::env::var("HOME") { + let p = PathBuf::from(&home); + if !p.as_os_str().is_empty() { + return Some(p.join(DATA_DIR_NAME)); + } + } + + // 4. $APPDATA / $LOCALAPPDATA (Windows) + for var in &["APPDATA", "LOCALAPPDATA"] { + if let Ok(val) = std::env::var(var) { + let p = PathBuf::from(&val); + if !p.as_os_str().is_empty() { + return Some(p.join(DATA_DIR_NAME)); + } + } + } + + None +} + +/// A file-based implementation of [ContextIdCallback] that persists +/// the context_id and creation timestamp to `{data_dir}/context_id`. +struct FileContextIdCallback { + file_path: PathBuf, +} + +impl FileContextIdCallback { + fn new(data_dir: &str) -> Self { + Self { + file_path: PathBuf::from(data_dir).join("context_id"), + } + } +} + +impl ContextIdCallback for FileContextIdCallback { + fn persist(&self, context_id: String, creation_date: i64) { + let content = format!("{}\n{}", context_id, creation_date); + if let Err(e) = std::fs::write(&self.file_path, content) { + eprintln!( + "Failed to persist context_id to {}: {}", + self.file_path.display(), + e + ); + } + } + + fn rotated(&self, _old_context_id: String) { + // No-op: the subsequent persist() call updates the file. + } +} + +/// Reads a previously persisted context_id and creation timestamp from file. +/// Returns `None` if the file is missing, corrupt, or contains an invalid UUID. +fn read_persisted_context_id(data_dir: &str) -> Option<(String, i64)> { + let file_path = PathBuf::from(data_dir).join("context_id"); + let content = std::fs::read_to_string(&file_path).ok()?; + let mut lines = content.lines(); + let context_id = lines.next()?.trim().to_string(); + let timestamp_str = lines.next()?.trim(); + let timestamp = timestamp_str.parse::().ok()?; + + // Validate UUID format + if Uuid::parse_str(&context_id).is_err() { + return None; + } + + Some((context_id, timestamp)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_resolve_data_dir_explicit() { + let result = resolve_data_dir(Some("/tmp/myapp"), None); + assert_eq!(result, Some(PathBuf::from("/tmp/myapp/ads-client"))); + } + + #[test] + fn test_resolve_data_dir_from_legacy_db_path() { + let result = resolve_data_dir(None, Some("/data/app/cache.db")); + assert_eq!(result, Some(PathBuf::from("/data/app/ads-client"))); + } + + #[test] + fn test_resolve_data_dir_explicit_takes_priority() { + let result = resolve_data_dir(Some("/explicit"), Some("/legacy/cache.db")); + assert_eq!(result, Some(PathBuf::from("/explicit/ads-client"))); + } + + #[test] + fn test_resolve_data_dir_empty_explicit_falls_through() { + let result = resolve_data_dir(Some(""), Some("/legacy/cache.db")); + assert_eq!(result, Some(PathBuf::from("/legacy/ads-client"))); + } + + #[test] + fn test_resolve_data_dir_home_fallback() { + // $HOME is typically set in test environments + let result = resolve_data_dir(None, None); + if std::env::var("HOME").is_ok() { + assert!(result.is_some()); + let path = result.unwrap(); + assert!(path.ends_with("ads-client")); + } + } + + #[test] + fn test_read_persisted_context_id_valid() { + let dir = std::env::temp_dir().join("ads-client-test-read-valid"); + std::fs::create_dir_all(&dir).unwrap(); + let file_path = dir.join("context_id"); + let uuid = Uuid::new_v4().to_string(); + let ts = 1745859061i64; + std::fs::write(&file_path, format!("{}\n{}", uuid, ts)).unwrap(); + + let result = read_persisted_context_id(dir.to_str().unwrap()); + assert_eq!(result, Some((uuid, ts))); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_read_persisted_context_id_missing_file() { + let result = read_persisted_context_id("/nonexistent/path"); + assert_eq!(result, None); + } + + #[test] + fn test_read_persisted_context_id_invalid_uuid() { + let dir = std::env::temp_dir().join("ads-client-test-read-invalid-uuid"); + std::fs::create_dir_all(&dir).unwrap(); + let file_path = dir.join("context_id"); + std::fs::write(&file_path, "not-a-uuid\n12345").unwrap(); + + let result = read_persisted_context_id(dir.to_str().unwrap()); + assert_eq!(result, None); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_read_persisted_context_id_missing_timestamp() { + let dir = std::env::temp_dir().join("ads-client-test-read-no-ts"); + std::fs::create_dir_all(&dir).unwrap(); + let file_path = dir.join("context_id"); + let uuid = Uuid::new_v4().to_string(); + std::fs::write(&file_path, uuid).unwrap(); + + let result = read_persisted_context_id(dir.to_str().unwrap()); + assert_eq!(result, None); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_file_context_id_callback_roundtrip() { + let dir = std::env::temp_dir().join("ads-client-test-callback-roundtrip"); + std::fs::create_dir_all(&dir).unwrap(); + let dir_str = dir.to_str().unwrap(); + + let callback = FileContextIdCallback::new(dir_str); + let uuid = Uuid::new_v4().to_string(); + let ts = 1745859061i64; + callback.persist(uuid.clone(), ts); + + let result = read_persisted_context_id(dir_str); + assert_eq!(result, Some((uuid, ts))); + + std::fs::remove_dir_all(&dir).ok(); + } +} diff --git a/components/ads-client/src/client/config.rs b/components/ads-client/src/client/config.rs index 759c0b907b..f31f17e90b 100644 --- a/components/ads-client/src/client/config.rs +++ b/components/ads-client/src/client/config.rs @@ -6,23 +6,12 @@ use once_cell::sync::Lazy; use url::Url; -use crate::telemetry::Telemetry; - static MARS_API_ENDPOINT_PROD: Lazy = Lazy::new(|| Url::parse("https://ads.mozilla.org/v1/").expect("hardcoded URL must be valid")); static MARS_API_ENDPOINT_STAGING: Lazy = Lazy::new(|| Url::parse("https://ads.allizom.org/v1/").expect("hardcoded URL must be valid")); -pub struct AdsClientConfig -where - T: Telemetry, -{ - pub environment: Environment, - pub cache_config: Option, - pub telemetry: T, -} - #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub enum Environment { #[default] @@ -43,13 +32,6 @@ impl Environment { } } -#[derive(Clone, Debug)] -pub struct AdsCacheConfig { - pub db_path: String, - pub default_cache_ttl_seconds: Option, - pub max_size_mib: Option, -} - #[cfg(test)] mod tests { use url::Host; diff --git a/components/ads-client/src/error.rs b/components/ads-client/src/error.rs index a041286186..b8cf2b5b4d 100644 --- a/components/ads-client/src/error.rs +++ b/components/ads-client/src/error.rs @@ -8,6 +8,9 @@ use viaduct::Response; #[derive(Debug, thiserror::Error)] pub enum ComponentError { + #[error("Error building ads client: {0}")] + Build(#[from] BuildError), + #[error("Error requesting ads: {0}")] RequestAds(#[from] RequestAdsError), @@ -21,6 +24,17 @@ pub enum ComponentError { ReportAd(#[from] ReportAdError), } +#[derive(Debug, thiserror::Error)] +pub enum BuildError { + #[error( + "Could not resolve a data directory. Provide a data_dir or cache_config with db_path." + )] + NoDataDir, + + #[error("Failed to create data directory at {path}: {reason}")] + CreateDataDir { path: String, reason: String }, +} + #[derive(Debug, thiserror::Error)] pub enum RequestAdsError { #[error("Error building ad requests from configs: {0}")] diff --git a/components/ads-client/src/ffi.rs b/components/ads-client/src/ffi.rs index d290bda5c9..acd0309d50 100644 --- a/components/ads-client/src/ffi.rs +++ b/components/ads-client/src/ffi.rs @@ -11,13 +11,13 @@ use crate::client::ad_request::{AdContentCategory, AdPlacementRequest, IABConten use crate::client::ad_response::{ AdCallbacks, AdImage, AdSpoc, AdTile, SpocFrequencyCaps, SpocRanking, }; -use crate::client::config::{AdsCacheConfig, AdsClientConfig, Environment}; -use crate::client::AdsClient; +use crate::client::builder::AdsClientBuilder; +use crate::client::config::Environment; use crate::error::ComponentError; use crate::ffi::telemetry::MozAdsTelemetryWrapper; use crate::http_cache::{CacheMode, RequestCachePolicy}; use crate::MozAdsClient; -use error_support::{ErrorHandling, GetErrorHandling}; +use error_support::{handle_error, ErrorHandling, GetErrorHandling}; use parking_lot::Mutex; use url::Url; @@ -96,6 +96,7 @@ pub struct MozAdsClientBuilder(Mutex); #[derive(Default)] struct MozAdsClientBuilderInner { + data_dir: Option, environment: Option, cache_config: Option, telemetry: Option>, @@ -114,6 +115,11 @@ impl MozAdsClientBuilder { Self::default() } + pub fn data_dir(self: Arc, data_dir: String) -> Arc { + self.0.lock().data_dir = Some(data_dir); + self + } + pub fn environment(self: Arc, environment: MozAdsEnvironment) -> Arc { self.0.lock().environment = Some(environment); self @@ -129,21 +135,42 @@ impl MozAdsClientBuilder { self } - pub fn build(&self) -> MozAdsClient { + #[handle_error(ComponentError)] + pub fn try_build(&self) -> AdsClientApiResult { let inner = self.0.lock(); - let client_config = AdsClientConfig { - environment: inner.environment.unwrap_or_default().into(), - cache_config: inner.cache_config.clone().map(Into::into), - telemetry: inner - .telemetry - .clone() - .map(MozAdsTelemetryWrapper::new) - .unwrap_or_else(MozAdsTelemetryWrapper::noop), - }; - let client = AdsClient::new(client_config); - MozAdsClient { - inner: Mutex::new(client), + let telemetry = inner + .telemetry + .clone() + .map(MozAdsTelemetryWrapper::new) + .unwrap_or_else(MozAdsTelemetryWrapper::noop); + + let mut builder = AdsClientBuilder::new(telemetry) + .environment(inner.environment.unwrap_or_default().into()); + + if let Some(ref dir) = inner.data_dir { + builder = builder.data_dir(dir.clone()); + } + if let Some(ref cfg) = inner.cache_config { + builder = builder.legacy_cache_db_path(cfg.db_path.clone()); + if let Some(ttl) = cfg.default_cache_ttl_seconds { + builder = builder.cache_ttl_seconds(ttl); + } + if let Some(size) = cfg.max_size_mib { + builder = builder.cache_max_size_mib(size); + } } + + let client = builder.build()?; + Ok(MozAdsClient { + inner: Mutex::new(client), + }) + } + + /// Deprecated: Use `try_build()` instead, which returns a Result with + /// proper error handling. + pub fn build(&self) -> MozAdsClient { + self.try_build() + .expect("AdsClient build failed. Use try_build() for proper error handling.") } } @@ -418,16 +445,6 @@ impl From> for RequestCachePolicy { } } -impl From for AdsCacheConfig { - fn from(config: MozAdsCacheConfig) -> Self { - Self { - db_path: config.db_path, - default_cache_ttl_seconds: config.default_cache_ttl_seconds, - max_size_mib: config.max_size_mib, - } - } -} - impl From<&MozAdsPlacementRequest> for AdPlacementRequest { fn from(request: &MozAdsPlacementRequest) -> Self { Self {