diff --git a/dropshot/examples/request-headers.rs b/dropshot/examples/request-headers.rs index 969511c07..0e3dc4869 100644 --- a/dropshot/examples/request-headers.rs +++ b/dropshot/examples/request-headers.rs @@ -48,9 +48,8 @@ async fn main() -> Result<(), String> { async fn example_api_get_header_generic( rqctx: Arc>, ) -> Result, HttpError> { - let request = rqctx.request.lock().await; // Note that clients can provide multiple values for a header. See // http::HeaderMap for ways to get all of them. - let header = request.headers().get("demo-header"); + let header = rqctx.headers.get("demo-header"); Ok(HttpResponseOk(format!("value for header: {:?}", header))) } diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index c05abc3ab..da0abf004 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -54,9 +54,12 @@ use crate::websocket::WEBSOCKET_PARAM_SENTINEL; use async_trait::async_trait; use bytes::Bytes; -use futures::lock::Mutex; use http::HeaderMap; +use http::HeaderValue; +use http::Method; use http::StatusCode; +use http::Uri; +use http::Version; use hyper::Body; use hyper::Request; use hyper::Response; @@ -84,21 +87,18 @@ pub type HttpHandlerResult = Result, HttpError>; /** * Handle for various interfaces useful during request processing. */ -/* - * TODO-cleanup What's the right way to package up "request"? The only time we - * need it to be mutable is when we're reading the body (e.g., as part of the - * JSON extractor). In order to support that, we wrap it in something that - * supports interior mutability. It also needs to be thread-safe, since we're - * using async/await. That brings us to Arc>, but it seems like - * overkill since it will only really be used by one thread at a time (at all, - * let alone mutably) and there will never be contention on the Mutex. - */ #[derive(Debug)] pub struct RequestContext { /** shared server state */ pub server: Arc>, - /** HTTP request details */ - pub request: Arc>>, + /// The request's method + pub method: Method, + /// The request's URI + pub uri: Uri, + /// The request's version + pub version: Version, + /// The request's headers + pub headers: HeaderMap, /** HTTP request routing variables */ pub path_variables: VariableSet, /** expected request body mime type */ @@ -188,11 +188,19 @@ impl RequestContextArgument pub trait Extractor: Send + Sync + Sized { /** * Construct an instance of this type from a `RequestContext`. + * + * `request` is `Some()` if and only if `requires_body` returns true. + * For a particular request, if more than one extractor requires the body, + * the request handler will panic. */ async fn from_request( rqctx: Arc>, + request: Option>, ) -> Result; + /// Returns true if this extractor requires the request body. + fn requires_body() -> bool; + fn metadata( body_content_type: ApiEndpointBodyContentType, ) -> ExtractorMetadata; @@ -216,10 +224,26 @@ macro_rules! impl_extractor_for_tuple { #[async_trait] impl< $($T: Extractor + 'static,)* > Extractor for ($($T,)*) { - async fn from_request(_rqctx: Arc>) + // unused_mut and unused_variables are for the zero-element case. + #[allow(unused_mut)] + #[allow(unused_variables)] + async fn from_request( + _rqctx: Arc>, + mut body: Option>, + ) -> Result<( $($T,)* ), HttpError> { - futures::try_join!($($T::from_request(Arc::clone(&_rqctx)),)*) + let requires_body = [$($T::requires_body(),)*]; + if requires_body.iter().filter(|x| **x).count() > 1 { + panic!("multiple extractors use body"); + } + futures::try_join!( + $($T::from_request( + Arc::clone(&_rqctx), + if $T::requires_body() { + Some(body.take().expect("extractor uses body, but it is unavailable")) + } else { None }, + ),)*) } fn metadata(_body_content_type: ApiEndpointBodyContentType) -> ExtractorMetadata { @@ -240,6 +264,11 @@ macro_rules! impl_extractor_for_tuple { )* ExtractorMetadata { extension_mode, parameters } } + + fn requires_body() -> bool { + let uses_body = [$($T::requires_body(),)*]; + uses_body.iter().any(|x| *x) + } } }} @@ -417,6 +446,7 @@ pub trait RouteHandler: Debug + Send + Sync { async fn handle_request( &self, rqctx: RequestContext, + request: Request, ) -> HttpHandlerResult; } @@ -483,6 +513,7 @@ where async fn handle_request( &self, rqctx_raw: RequestContext, + request: Request, ) -> HttpHandlerResult { /* * This is where the magic happens: in the code below, `funcparams` has @@ -503,7 +534,8 @@ where * resolved statically. */ let rqctx = Arc::new(rqctx_raw); - let funcparams = Extractor::from_request(Arc::clone(&rqctx)).await?; + let funcparams = + Extractor::from_request(Arc::clone(&rqctx), Some(request)).await?; let future = self.handler.handle_request(rqctx, funcparams); future.await } @@ -580,12 +612,12 @@ impl Query { * it as an instance of `QueryType`. */ fn http_request_load_query( - request: &Request, + uri: &Uri, ) -> Result, HttpError> where QueryType: DeserializeOwned + JsonSchema + Send + Sync, { - let raw_query_string = request.uri().query().unwrap_or(""); + let raw_query_string = uri.query().unwrap_or(""); /* * TODO-correctness: are query strings defined to be urlencoded in this way? */ @@ -613,9 +645,9 @@ where { async fn from_request( rqctx: Arc>, + _request: Option>, ) -> Result, HttpError> { - let request = rqctx.request.lock().await; - http_request_load_query(&request) + http_request_load_query(&rqctx.uri) } fn metadata( @@ -623,6 +655,10 @@ where ) -> ExtractorMetadata { get_metadata::(&ApiEndpointParameterLocation::Query) } + + fn requires_body() -> bool { + false + } } /* @@ -661,6 +697,7 @@ where { async fn from_request( rqctx: Arc>, + _request: Option>, ) -> Result, HttpError> { let params: PathType = http_extract_path_params(&rqctx.path_variables)?; Ok(Path { inner: params }) @@ -671,6 +708,10 @@ where ) -> ExtractorMetadata { get_metadata::(&ApiEndpointParameterLocation::Path) } + + fn requires_body() -> bool { + false + } } /** @@ -966,12 +1007,12 @@ impl */ async fn http_request_load_body( rqctx: Arc>, + mut request: Request, ) -> Result, HttpError> where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { let server = &rqctx.server; - let mut request = rqctx.request.lock().await; let body = http_read_body( request.body_mut(), server.config.request_body_max_bytes, @@ -1044,8 +1085,9 @@ where { async fn from_request( rqctx: Arc>, + request: Option>, ) -> Result, HttpError> { - http_request_load_body(rqctx).await + http_request_load_body(rqctx, request.unwrap()).await } fn metadata(content_type: ApiEndpointBodyContentType) -> ExtractorMetadata { @@ -1063,6 +1105,10 @@ where parameters: vec![body], } } + + fn requires_body() -> bool { + true + } } /* @@ -1107,9 +1153,10 @@ impl UntypedBody { impl Extractor for UntypedBody { async fn from_request( rqctx: Arc>, + request: Option>, ) -> Result { let server = &rqctx.server; - let mut request = rqctx.request.lock().await; + let mut request = request.unwrap(); let body_bytes = http_read_body( request.body_mut(), server.config.request_body_max_bytes, @@ -1141,6 +1188,10 @@ impl Extractor for UntypedBody { extension_mode: ExtensionMode::None, } } + + fn requires_body() -> bool { + true + } } /* diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 506a51a52..bb06c8916 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -794,13 +794,17 @@ async fn http_request_handle( server.router.lookup_route(&method, uri.path().into())?; let rqctx = RequestContext { server: Arc::clone(&server), - request: Arc::new(Mutex::new(request)), + method: request.method().clone(), + uri: request.uri().clone(), + version: request.version(), + headers: request.headers().clone(), path_variables: lookup_result.variables, body_content_type: lookup_result.body_content_type, request_id: request_id.to_string(), log: request_log, }; - let mut response = lookup_result.handler.handle_request(rqctx).await?; + let mut response = + lookup_result.handler.handle_request(rqctx, request).await?; response.headers_mut().insert( HEADER_REQUEST_ID, http::header::HeaderValue::from_str(&request_id).unwrap(), diff --git a/dropshot/src/websocket.rs b/dropshot/src/websocket.rs index 1055660b9..af8d9589f 100644 --- a/dropshot/src/websocket.rs +++ b/dropshot/src/websocket.rs @@ -14,9 +14,9 @@ use crate::{ RequestContext, ServerContext, }; use async_trait::async_trait; -use http::header; use http::Response; use http::StatusCode; +use http::{header, Request}; use hyper::upgrade::OnUpgrade; use hyper::Body; use schemars::JsonSchema; @@ -97,8 +97,9 @@ fn derive_accept_key(request_key: &[u8]) -> String { impl Extractor for WebsocketUpgrade { async fn from_request( rqctx: Arc>, + request: Option>, ) -> Result { - let request = &mut *rqctx.request.lock().await; + let request = request.unwrap(); if !request .headers() @@ -181,6 +182,10 @@ impl Extractor for WebsocketUpgrade { extension_mode: ExtensionMode::Websocket, } } + + fn requires_body() -> bool { + true + } } impl WebsocketUpgrade { @@ -310,7 +315,6 @@ mod tests { use crate::router::HttpRouter; use crate::server::{DropshotState, ServerConfig}; use crate::{Extractor, HttpError, RequestContext, WebsocketUpgrade}; - use futures::lock::Mutex; use http::Request; use hyper::Body; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; @@ -320,39 +324,41 @@ mod tests { async fn ws_upg_from_mock_rqctx() -> Result { let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!()); - let fut = WebsocketUpgrade::from_request(Arc::new(RequestContext { - server: Arc::new(DropshotState { - private: (), - config: ServerConfig { - request_body_max_bytes: 0, - page_max_nitems: NonZeroU32::new(1).unwrap(), - page_default_nitems: NonZeroU32::new(1).unwrap(), - }, - router: HttpRouter::new(), + let request = Request::builder() + .header(http::header::CONNECTION, "Upgrade") + .header(http::header::UPGRADE, "websocket") + .header(http::header::SEC_WEBSOCKET_VERSION, "13") + .header(http::header::SEC_WEBSOCKET_KEY, "aGFjayB0aGUgcGxhbmV0IQ==") + .body(Body::empty()) + .unwrap(); + let fut = WebsocketUpgrade::from_request( + Arc::new(RequestContext { + server: Arc::new(DropshotState { + private: (), + config: ServerConfig { + request_body_max_bytes: 0, + page_max_nitems: NonZeroU32::new(1).unwrap(), + page_default_nitems: NonZeroU32::new(1).unwrap(), + }, + router: HttpRouter::new(), + log: log.clone(), + local_addr: SocketAddr::new( + IpAddr::V6(Ipv6Addr::LOCALHOST), + 8080, + ), + tls_acceptor: None, + }), + method: request.method().clone(), + uri: request.uri().clone(), + version: request.version(), + headers: request.headers().clone(), + path_variables: Default::default(), + body_content_type: Default::default(), + request_id: "".to_string(), log: log.clone(), - local_addr: SocketAddr::new( - IpAddr::V6(Ipv6Addr::LOCALHOST), - 8080, - ), - tls_acceptor: None, }), - request: Arc::new(Mutex::new( - Request::builder() - .header(http::header::CONNECTION, "Upgrade") - .header(http::header::UPGRADE, "websocket") - .header(http::header::SEC_WEBSOCKET_VERSION, "13") - .header( - http::header::SEC_WEBSOCKET_KEY, - "aGFjayB0aGUgcGxhbmV0IQ==", - ) - .body(Body::empty()) - .unwrap(), - )), - path_variables: Default::default(), - body_content_type: Default::default(), - request_id: "".to_string(), - log: log.clone(), - })); + Some(request), + ); tokio::time::timeout(Duration::from_secs(1), fut) .await .expect("Deadlocked in WebsocketUpgrade constructor")