Skip to content
Open
155 changes: 153 additions & 2 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub struct DomainServices {
pub usage_service: Arc<dyn services::usage::UsageServiceTrait + Send + Sync>,
pub user_service: Arc<dyn services::user::UserServiceTrait + Send + Sync>,
pub files_service: Arc<dyn services::files::FileServiceTrait + Send + Sync>,
pub vector_store_service:
Arc<dyn services::vector_stores::VectorStoreServiceTrait + Send + Sync>,
pub metrics_service: Arc<dyn services::metrics::MetricsServiceTrait>,
}

Expand Down Expand Up @@ -418,6 +420,28 @@ pub async fn init_domain_services_with_pool(
let web_search_provider =
Arc::new(services::responses::tools::brave::BraveWebSearchProvider::new());

// Initialize RAG service client (if configured)
let rag_service: Option<Arc<dyn services::rag::RagServiceTrait>> =
if let Some(ref rag_config) = config.rag_service {
match services::rag::RagServiceClient::new(
rag_config.app_id.clone(),
rag_config.gateway_subdomain.clone(),
rag_config.auth_token_file.as_deref(),
rag_config.timeout_seconds,
) {
Ok(rag_client) => {
tracing::info!("RAG service client initialized successfully");
Some(Arc::new(rag_client))
}
Err(e) => {
tracing::warn!(error = %e, "Failed to initialize RAG service client");
None
}
}
} else {
None
};

// Create session repository for user service
let session_repo = Arc::new(database::SessionRepository::new(database.pool().clone()))
as Arc<dyn services::auth::SessionRepository>;
Expand Down Expand Up @@ -449,10 +473,32 @@ pub async fn init_domain_services_with_pool(
)) as Arc<dyn services::files::FileRepositoryTrait>;

let files_service = Arc::new(services::files::FileServiceImpl::new(
file_repository,
file_repository.clone(),
s3_storage,
)) as Arc<dyn services::files::FileServiceTrait + Send + Sync>;

// Create vector store ref repository and service (thin proxy to RAG)
let vs_ref_repo = Arc::new(database::PgVectorStoreRefRepository::new(
database.pool().clone(),
)) as Arc<dyn services::vector_stores::VectorStoreRefRepository>;

let vector_store_service: Arc<
dyn services::vector_stores::VectorStoreServiceTrait + Send + Sync,
> = if let Some(ref rag) = rag_service {
Arc::new(services::vector_stores::VectorStoreServiceImpl::new(
vs_ref_repo,
file_repository.clone(),
rag.clone(),
))
} else {
tracing::warn!("RAG service not configured — vector store endpoints will return errors");
Arc::new(services::vector_stores::VectorStoreServiceImpl::new(
vs_ref_repo,
file_repository.clone(),
Arc::new(services::rag::NotConfiguredRagService),
))
};

let response_service = Arc::new(services::ResponseService::new(
response_repo,
response_items_repo.clone(),
Expand All @@ -478,6 +524,7 @@ pub async fn init_domain_services_with_pool(
usage_service,
user_service,
files_service,
vector_store_service,
metrics_service,
}
}
Expand Down Expand Up @@ -513,14 +560,20 @@ pub async fn init_domain_services_with_mcp_factory(
let web_search_provider =
Arc::new(services::responses::tools::brave::BraveWebSearchProvider::new());

// Preserve file_search_provider from the base domain services
let file_search_provider = domain_services
.response_service
.file_search_provider
.clone();

let response_service = Arc::new(services::ResponseService::with_mcp_client_factory(
response_repo,
response_items_repo,
inference_provider_pool,
domain_services.conversation_service.clone(),
domain_services.completion_service.clone(),
Some(web_search_provider),
None,
file_search_provider,
domain_services.files_service.clone(), // Reuse files_service from base
organization_service,
mcp_client_factory,
Expand All @@ -530,6 +583,45 @@ pub async fn init_domain_services_with_mcp_factory(
domain_services
}

/// Initialize domain services with a custom RAG service (for testing vector stores)
#[allow(clippy::too_many_arguments)]
pub async fn init_domain_services_with_rag(
database: Arc<Database>,
config: &ApiConfig,
organization_service: Arc<dyn services::organization::OrganizationServiceTrait + Send + Sync>,
inference_provider_pool: Arc<services::inference_provider_pool::InferenceProviderPool>,
metrics_service: Arc<dyn services::metrics::MetricsServiceTrait>,
rag_service: Arc<dyn services::rag::RagServiceTrait>,
) -> DomainServices {
// Get the base domain services (uses NotConfiguredRagService since config.rag_service is None)
let mut domain_services = init_domain_services_with_pool(
database.clone(),
config,
organization_service,
inference_provider_pool,
metrics_service,
)
.await;

// Replace vector_store_service with one using the injected RAG service
let vs_ref_repo = Arc::new(database::PgVectorStoreRefRepository::new(
database.pool().clone(),
)) as Arc<dyn services::vector_stores::VectorStoreRefRepository>;

let file_repository = Arc::new(database::repositories::FileRepository::new(
database.pool().clone(),
)) as Arc<dyn services::files::FileRepositoryTrait>;

domain_services.vector_store_service =
Arc::new(services::vector_stores::VectorStoreServiceImpl::new(
vs_ref_repo,
file_repository,
rag_service,
));

domain_services
}

/// Initialize inference provider pool
pub async fn init_inference_providers(
config: &ApiConfig,
Expand Down Expand Up @@ -662,6 +754,7 @@ pub fn build_app_with_config(
usage_service: domain_services.usage_service.clone(),
user_service: domain_services.user_service.clone(),
files_service: domain_services.files_service.clone(),
vector_store_service: domain_services.vector_store_service.clone(),
inference_provider_pool: domain_services.inference_provider_pool.clone(),
metrics_service: domain_services.metrics_service.clone(),
config: config.clone(),
Expand Down Expand Up @@ -745,6 +838,9 @@ pub fn build_app_with_config(
let files_routes =
build_files_routes(app_state.clone(), &auth_components.auth_state_middleware);

let vector_store_routes =
build_vector_store_routes(app_state.clone(), &auth_components.auth_state_middleware);

let billing_routes = build_billing_routes(
domain_services.usage_service.clone(),
&auth_components.auth_state_middleware,
Expand Down Expand Up @@ -793,6 +889,7 @@ pub fn build_app_with_config(
.merge(invitation_routes)
.merge(auth_vpc_routes)
.merge(files_routes)
.merge(vector_store_routes)
.merge(billing_routes)
.merge(usage_recording_routes)
.merge(health_routes),
Expand Down Expand Up @@ -1119,6 +1216,58 @@ pub fn build_workspace_routes(app_state: AppState, auth_state_middleware: &AuthS
))
}

/// Build vector store routes
pub fn build_vector_store_routes(app_state: AppState, auth_state_middleware: &AuthState) -> Router {
use crate::routes::vector_stores::*;

Router::new()
.route(
"/vector_stores",
post(create_vector_store).get(list_vector_stores),
)
.route(
"/vector_stores/{vector_store_id}",
get(get_vector_store)
.post(modify_vector_store)
.delete(delete_vector_store),
)
.route(
"/vector_stores/{vector_store_id}/search",
post(search_vector_store),
)
.route(
"/vector_stores/{vector_store_id}/files",
post(create_vector_store_file).get(list_vector_store_files),
)
.route(
"/vector_stores/{vector_store_id}/files/{file_id}",
get(get_vector_store_file)
.post(update_vector_store_file)
.delete(delete_vector_store_file),
)
.route(
"/vector_stores/{vector_store_id}/file_batches",
post(create_vector_store_file_batch),
)
.route(
"/vector_stores/{vector_store_id}/file_batches/{batch_id}",
get(get_vector_store_file_batch),
)
.route(
"/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
post(cancel_vector_store_file_batch),
)
.route(
"/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
get(list_vector_store_file_batch_files),
)
.with_state(app_state)
.layer(from_fn_with_state(
auth_state_middleware.clone(),
auth_middleware_with_api_key,
))
}

/// Build file upload routes
pub fn build_files_routes(app_state: AppState, auth_state_middleware: &AuthState) -> Router {
use crate::routes::files::MAX_FILE_SIZE;
Expand Down Expand Up @@ -1442,6 +1591,7 @@ mod tests {
},
cors: config::CorsConfig::default(),
external_providers: config::ExternalProvidersConfig::default(),
rag_service: None,
};

// Initialize services
Expand Down Expand Up @@ -1545,6 +1695,7 @@ mod tests {
},
cors: config::CorsConfig::default(),
external_providers: config::ExternalProvidersConfig::default(),
rag_service: None,
};

let auth_components = init_auth_services(database.clone(), &config);
Expand Down
Loading