diff --git a/crates/anyedge-adapter-axum/src/request.rs b/crates/anyedge-adapter-axum/src/request.rs index 8787d87..563252a 100644 --- a/crates/anyedge-adapter-axum/src/request.rs +++ b/crates/anyedge-adapter-axum/src/request.rs @@ -6,16 +6,30 @@ use anyedge_core::proxy::ProxyHandle; use axum::body::Body as AxumBody; use axum::extract::connect_info::ConnectInfo; use axum::http::Request; +use http::header::CONTENT_TYPE; +use http::HeaderValue; use crate::context::AxumRequestContext; use crate::proxy::AxumProxyClient; /// Convert an Axum/Hyper request into an AnyEdge core request while preserving streaming bodies /// and exposing connection metadata through `AxumRequestContext`. -pub fn into_core_request(request: Request) -> CoreRequest { +pub async fn into_core_request(request: Request) -> Result { let (parts, body) = request.into_parts(); - let stream = body.into_data_stream(); - let body = Body::from_stream(stream); + + let body = match parts.headers.get(CONTENT_TYPE) { + Some(value) if is_json_content_type(value) => { + let bytes = axum::body::to_bytes(body, usize::MAX) + .await + .map_err(|e| format!("Failed to convert body into bytes: {e}"))?; + Body::from_bytes(bytes) + } + _ => { + let stream = body.into_data_stream(); + Body::from_stream(stream) + } + }; + let mut core_request = CoreRequest::from_parts(parts, body); if let Some(remote_addr) = core_request @@ -38,7 +52,29 @@ pub fn into_core_request(request: Request) -> CoreRequest { .extensions_mut() .insert(ProxyHandle::with_client(AxumProxyClient::default())); - core_request + Ok(core_request) +} + +fn is_json_content_type(value: &HeaderValue) -> bool { + let Ok(raw) = value.to_str() else { + return false; + }; + + let media_type = raw.split(';').next().map(str::trim).unwrap_or(""); + if media_type.eq_ignore_ascii_case("application/json") { + return true; + } + + let Some((ty, subtype)) = media_type.split_once('/') else { + return false; + }; + + if !ty.eq_ignore_ascii_case("application") { + return false; + } + + let subtype = subtype.trim(); + subtype.len() >= 5 && subtype[subtype.len() - 5..].eq_ignore_ascii_case("+json") } #[cfg(test)] @@ -47,8 +83,8 @@ mod tests { use anyedge_core::body::Body; use anyedge_core::http::Method; - #[test] - fn converts_request_and_records_connect_info() { + #[tokio::test] + async fn converts_request_and_records_connect_info() { let mut request = Request::builder() .method(Method::POST) .uri("/demo") @@ -59,7 +95,9 @@ mod tests { .extensions_mut() .insert(ConnectInfo::("127.0.0.1:4000".parse().unwrap())); - let core_request = into_core_request(request); + let core_request = into_core_request(request) + .await + .expect("request conversion"); assert_eq!(core_request.method(), &Method::POST); assert_eq!(core_request.uri().path(), "/demo"); assert_eq!(core_request.headers()["x-test"], "1"); @@ -76,15 +114,79 @@ mod tests { .is_none()); } - #[test] - fn missing_connect_info_is_handled_gracefully() { + #[tokio::test] + async fn missing_connect_info_is_handled_gracefully() { let request = Request::builder() .method(Method::GET) .uri("/demo") .body(AxumBody::empty()) .expect("request"); - let core_request = into_core_request(request); + let core_request = into_core_request(request) + .await + .expect("request conversion"); assert!(AxumRequestContext::get(&core_request).is_none()); } + + #[tokio::test] + async fn json_content_type_buffers_body() { + let json_payload = r#"{"name":"test"}"#; + let request = Request::builder() + .method(Method::POST) + .uri("/api/test") + .header("content-type", "application/json") + .body(AxumBody::from(json_payload)) + .expect("request"); + + let core_request = into_core_request(request) + .await + .expect("request conversion"); + assert_eq!(core_request.method(), &Method::POST); + + match core_request.body() { + Body::Once(bytes) => { + assert_eq!(bytes.as_ref(), json_payload.as_bytes()); + } + Body::Stream(_) => panic!("JSON body should be buffered, not streaming"), + } + } + + #[tokio::test] + async fn non_json_content_type_streams_body() { + let request = Request::builder() + .method(Method::POST) + .uri("/upload") + .header("content-type", "application/octet-stream") + .body(AxumBody::from("binary data")) + .expect("request"); + + let core_request = into_core_request(request) + .await + .expect("request conversion"); + + assert!(matches!(core_request.body(), Body::Stream(_))); + } + + #[test] + fn test_is_json_content_type() { + assert!(is_json_content_type(&HeaderValue::from_static( + "application/json" + ))); + assert!(is_json_content_type(&HeaderValue::from_static( + "application/json; charset=utf-8" + ))); + assert!(is_json_content_type(&HeaderValue::from_static( + "application/vnd.api+json" + ))); + assert!(is_json_content_type(&HeaderValue::from_static( + "APPLICATION/VND.CUSTOM+JSON; CHARSET=UTF-8" + ))); + + assert!(!is_json_content_type(&HeaderValue::from_static( + "text/json" + ))); + assert!(!is_json_content_type(&HeaderValue::from_static( + "application/json+xml" + ))); + } } diff --git a/crates/anyedge-adapter-axum/src/service.rs b/crates/anyedge-adapter-axum/src/service.rs index 77ea0e1..43cfa55 100644 --- a/crates/anyedge-adapter-axum/src/service.rs +++ b/crates/anyedge-adapter-axum/src/service.rs @@ -3,8 +3,10 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use axum::body::Body; use axum::body::Body as AxumBody; use axum::http::{Request, Response}; +use http::StatusCode; use tokio::{runtime::Handle, task}; use tower::Service; @@ -37,7 +39,15 @@ impl Service> for AnyEdgeAxumService { fn call(&mut self, request: Request) -> Self::Future { let router = self.router.clone(); Box::pin(async move { - let core_request = into_core_request(request); + let core_request = match into_core_request(request).await { + Ok(req) => req, + Err(e) => { + let mut err_response = Response::new(Body::from(e.to_string())); + *err_response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + + return Ok(err_response); + } + }; let core_response = task::block_in_place(move || { Handle::current().block_on(router.oneshot(core_request)) });