Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 112 additions & 10 deletions crates/anyedge-adapter-axum/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@ use anyedge_core::proxy::ProxyHandle;
use axum::body::Body as AxumBody;
use axum::extract::connect_info::ConnectInfo;
use axum::http::Request;
use http::header::CONTENT_TYPE;
use http::HeaderValue;

use crate::context::AxumRequestContext;
use crate::proxy::AxumProxyClient;

/// Convert an Axum/Hyper request into an AnyEdge core request while preserving streaming bodies
/// and exposing connection metadata through `AxumRequestContext`.
pub fn into_core_request(request: Request<AxumBody>) -> CoreRequest {
pub async fn into_core_request(request: Request<AxumBody>) -> Result<CoreRequest, String> {
let (parts, body) = request.into_parts();
let stream = body.into_data_stream();
let body = Body::from_stream(stream);

let body = match parts.headers.get(CONTENT_TYPE) {
Some(value) if is_json_content_type(value) => {
let bytes = axum::body::to_bytes(body, usize::MAX)
.await
.map_err(|e| format!("Failed to convert body into bytes: {e}"))?;
Body::from_bytes(bytes)
}
_ => {
let stream = body.into_data_stream();
Body::from_stream(stream)
}
};

let mut core_request = CoreRequest::from_parts(parts, body);

if let Some(remote_addr) = core_request
Expand All @@ -38,7 +52,29 @@ pub fn into_core_request(request: Request<AxumBody>) -> CoreRequest {
.extensions_mut()
.insert(ProxyHandle::with_client(AxumProxyClient::default()));

core_request
Ok(core_request)
}

fn is_json_content_type(value: &HeaderValue) -> bool {
let Ok(raw) = value.to_str() else {
return false;
};

let media_type = raw.split(';').next().map(str::trim).unwrap_or("");
if media_type.eq_ignore_ascii_case("application/json") {
return true;
}

let Some((ty, subtype)) = media_type.split_once('/') else {
return false;
};

if !ty.eq_ignore_ascii_case("application") {
return false;
}

let subtype = subtype.trim();
subtype.len() >= 5 && subtype[subtype.len() - 5..].eq_ignore_ascii_case("+json")
}

#[cfg(test)]
Expand All @@ -47,8 +83,8 @@ mod tests {
use anyedge_core::body::Body;
use anyedge_core::http::Method;

#[test]
fn converts_request_and_records_connect_info() {
#[tokio::test]
async fn converts_request_and_records_connect_info() {
let mut request = Request::builder()
.method(Method::POST)
.uri("/demo")
Expand All @@ -59,7 +95,9 @@ mod tests {
.extensions_mut()
.insert(ConnectInfo::<SocketAddr>("127.0.0.1:4000".parse().unwrap()));

let core_request = into_core_request(request);
let core_request = into_core_request(request)
.await
.expect("request conversion");
assert_eq!(core_request.method(), &Method::POST);
assert_eq!(core_request.uri().path(), "/demo");
assert_eq!(core_request.headers()["x-test"], "1");
Expand All @@ -76,15 +114,79 @@ mod tests {
.is_none());
}

#[test]
fn missing_connect_info_is_handled_gracefully() {
#[tokio::test]
async fn missing_connect_info_is_handled_gracefully() {
let request = Request::builder()
.method(Method::GET)
.uri("/demo")
.body(AxumBody::empty())
.expect("request");

let core_request = into_core_request(request);
let core_request = into_core_request(request)
.await
.expect("request conversion");
assert!(AxumRequestContext::get(&core_request).is_none());
}

#[tokio::test]
async fn json_content_type_buffers_body() {
let json_payload = r#"{"name":"test"}"#;
let request = Request::builder()
.method(Method::POST)
.uri("/api/test")
.header("content-type", "application/json")
.body(AxumBody::from(json_payload))
.expect("request");

let core_request = into_core_request(request)
.await
.expect("request conversion");
assert_eq!(core_request.method(), &Method::POST);

match core_request.body() {
Body::Once(bytes) => {
assert_eq!(bytes.as_ref(), json_payload.as_bytes());
}
Body::Stream(_) => panic!("JSON body should be buffered, not streaming"),
}
}

#[tokio::test]
async fn non_json_content_type_streams_body() {
let request = Request::builder()
.method(Method::POST)
.uri("/upload")
.header("content-type", "application/octet-stream")
.body(AxumBody::from("binary data"))
.expect("request");

let core_request = into_core_request(request)
.await
.expect("request conversion");

assert!(matches!(core_request.body(), Body::Stream(_)));
}

#[test]
fn test_is_json_content_type() {
assert!(is_json_content_type(&HeaderValue::from_static(
"application/json"
)));
assert!(is_json_content_type(&HeaderValue::from_static(
"application/json; charset=utf-8"
)));
assert!(is_json_content_type(&HeaderValue::from_static(
"application/vnd.api+json"
)));
assert!(is_json_content_type(&HeaderValue::from_static(
"APPLICATION/VND.CUSTOM+JSON; CHARSET=UTF-8"
)));

assert!(!is_json_content_type(&HeaderValue::from_static(
"text/json"
)));
assert!(!is_json_content_type(&HeaderValue::from_static(
"application/json+xml"
)));
}
}
12 changes: 11 additions & 1 deletion crates/anyedge-adapter-axum/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use axum::body::Body;
use axum::body::Body as AxumBody;
use axum::http::{Request, Response};
use http::StatusCode;
use tokio::{runtime::Handle, task};
use tower::Service;

Expand Down Expand Up @@ -37,7 +39,15 @@ impl Service<Request<AxumBody>> for AnyEdgeAxumService {
fn call(&mut self, request: Request<AxumBody>) -> Self::Future {
let router = self.router.clone();
Box::pin(async move {
let core_request = into_core_request(request);
let core_request = match into_core_request(request).await {
Ok(req) => req,
Err(e) => {
let mut err_response = Response::new(Body::from(e.to_string()));
*err_response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;

return Ok(err_response);
}
};
let core_response = task::block_in_place(move || {
Handle::current().block_on(router.oneshot(core_request))
});
Expand Down