Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dropshot/examples/request-headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ async fn main() -> Result<(), String> {
async fn example_api_get_header_generic(
rqctx: Arc<RequestContext<()>>,
) -> Result<HttpResponseOk<String>, HttpError> {
let request = rqctx.request.lock().await;
// Note that clients can provide multiple values for a header. See
// http::HeaderMap for ways to get all of them.
let header = request.headers().get("demo-header");
let header = rqctx.headers.get("demo-header");
Ok(HttpResponseOk(format!("value for header: {:?}", header)))
}
95 changes: 73 additions & 22 deletions dropshot/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ use crate::websocket::WEBSOCKET_PARAM_SENTINEL;

use async_trait::async_trait;
use bytes::Bytes;
use futures::lock::Mutex;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use http::Uri;
use http::Version;
use hyper::Body;
use hyper::Request;
use hyper::Response;
Expand Down Expand Up @@ -84,21 +87,18 @@ pub type HttpHandlerResult = Result<Response<Body>, HttpError>;
/**
* Handle for various interfaces useful during request processing.
*/
/*
* TODO-cleanup What's the right way to package up "request"? The only time we
* need it to be mutable is when we're reading the body (e.g., as part of the
* JSON extractor). In order to support that, we wrap it in something that
* supports interior mutability. It also needs to be thread-safe, since we're
* using async/await. That brings us to Arc<Mutex<...>>, but it seems like
* overkill since it will only really be used by one thread at a time (at all,
* let alone mutably) and there will never be contention on the Mutex.
*/
#[derive(Debug)]
pub struct RequestContext<Context: ServerContext> {
/** shared server state */
pub server: Arc<DropshotState<Context>>,
/** HTTP request details */
pub request: Arc<Mutex<Request<Body>>>,
/// The request's method
pub method: Method,
/// The request's URI
pub uri: Uri,
/// The request's version
pub version: Version,
/// The request's headers
pub headers: HeaderMap<HeaderValue>,
/** HTTP request routing variables */
pub path_variables: VariableSet,
/** expected request body mime type */
Expand Down Expand Up @@ -188,11 +188,19 @@ impl<T: 'static + ServerContext> RequestContextArgument
pub trait Extractor: Send + Sync + Sized {
/**
* Construct an instance of this type from a `RequestContext`.
*
* `request` is `Some()` if and only if `requires_body` returns true.
* For a particular request, if more than one extractor requires the body,
* the request handler will panic.
*/
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
request: Option<Request<Body>>,
) -> Result<Self, HttpError>;

/// Returns true if this extractor requires the request body.
fn requires_body() -> bool;

fn metadata(
body_content_type: ApiEndpointBodyContentType,
) -> ExtractorMetadata;
Expand All @@ -216,10 +224,26 @@ macro_rules! impl_extractor_for_tuple {
#[async_trait]
impl< $($T: Extractor + 'static,)* > Extractor for ($($T,)*)
{
async fn from_request<Context: ServerContext>(_rqctx: Arc<RequestContext<Context>>)
// unused_mut and unused_variables are for the zero-element case.
#[allow(unused_mut)]
#[allow(unused_variables)]
async fn from_request<Context: ServerContext>(
_rqctx: Arc<RequestContext<Context>>,
mut body: Option<Request<Body>>,
)
-> Result<( $($T,)* ), HttpError>
{
futures::try_join!($($T::from_request(Arc::clone(&_rqctx)),)*)
let requires_body = [$($T::requires_body(),)*];
if requires_body.iter().filter(|x| **x).count() > 1 {
panic!("multiple extractors use body");
}
futures::try_join!(
$($T::from_request(
Arc::clone(&_rqctx),
if $T::requires_body() {
Some(body.take().expect("extractor uses body, but it is unavailable"))
} else { None },
),)*)
}

fn metadata(_body_content_type: ApiEndpointBodyContentType) -> ExtractorMetadata {
Expand All @@ -240,6 +264,11 @@ macro_rules! impl_extractor_for_tuple {
)*
ExtractorMetadata { extension_mode, parameters }
}

fn requires_body() -> bool {
let uses_body = [$($T::requires_body(),)*];
uses_body.iter().any(|x| *x)
}
}
}}

Expand Down Expand Up @@ -417,6 +446,7 @@ pub trait RouteHandler<Context: ServerContext>: Debug + Send + Sync {
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
request: Request<Body>,
) -> HttpHandlerResult;
}

Expand Down Expand Up @@ -483,6 +513,7 @@ where
async fn handle_request(
&self,
rqctx_raw: RequestContext<Context>,
request: Request<Body>,
) -> HttpHandlerResult {
/*
* This is where the magic happens: in the code below, `funcparams` has
Expand All @@ -503,7 +534,8 @@ where
* resolved statically.
*/
let rqctx = Arc::new(rqctx_raw);
let funcparams = Extractor::from_request(Arc::clone(&rqctx)).await?;
let funcparams =
Extractor::from_request(Arc::clone(&rqctx), Some(request)).await?;
let future = self.handler.handle_request(rqctx, funcparams);
future.await
}
Expand Down Expand Up @@ -580,12 +612,12 @@ impl<QueryType: DeserializeOwned + JsonSchema + Send + Sync> Query<QueryType> {
* it as an instance of `QueryType`.
*/
fn http_request_load_query<QueryType>(
request: &Request<Body>,
uri: &Uri,
) -> Result<Query<QueryType>, HttpError>
where
QueryType: DeserializeOwned + JsonSchema + Send + Sync,
{
let raw_query_string = request.uri().query().unwrap_or("");
let raw_query_string = uri.query().unwrap_or("");
/*
* TODO-correctness: are query strings defined to be urlencoded in this way?
*/
Expand Down Expand Up @@ -613,16 +645,20 @@ where
{
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
_request: Option<Request<Body>>,
) -> Result<Query<QueryType>, HttpError> {
let request = rqctx.request.lock().await;
http_request_load_query(&request)
http_request_load_query(&rqctx.uri)
}

fn metadata(
_body_content_type: ApiEndpointBodyContentType,
) -> ExtractorMetadata {
get_metadata::<QueryType>(&ApiEndpointParameterLocation::Query)
}

fn requires_body() -> bool {
false
}
}

/*
Expand Down Expand Up @@ -661,6 +697,7 @@ where
{
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
_request: Option<Request<Body>>,
) -> Result<Path<PathType>, HttpError> {
let params: PathType = http_extract_path_params(&rqctx.path_variables)?;
Ok(Path { inner: params })
Expand All @@ -671,6 +708,10 @@ where
) -> ExtractorMetadata {
get_metadata::<PathType>(&ApiEndpointParameterLocation::Path)
}

fn requires_body() -> bool {
false
}
}

/**
Expand Down Expand Up @@ -966,12 +1007,12 @@ impl<BodyType: JsonSchema + DeserializeOwned + Send + Sync>
*/
async fn http_request_load_body<Context: ServerContext, BodyType>(
rqctx: Arc<RequestContext<Context>>,
mut request: Request<Body>,
) -> Result<TypedBody<BodyType>, HttpError>
where
BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
let server = &rqctx.server;
let mut request = rqctx.request.lock().await;
let body = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
Expand Down Expand Up @@ -1044,8 +1085,9 @@ where
{
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
request: Option<Request<Body>>,
) -> Result<TypedBody<BodyType>, HttpError> {
http_request_load_body(rqctx).await
http_request_load_body(rqctx, request.unwrap()).await
}

fn metadata(content_type: ApiEndpointBodyContentType) -> ExtractorMetadata {
Expand All @@ -1063,6 +1105,10 @@ where
parameters: vec![body],
}
}

fn requires_body() -> bool {
true
}
}

/*
Expand Down Expand Up @@ -1107,9 +1153,10 @@ impl UntypedBody {
impl Extractor for UntypedBody {
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
request: Option<Request<Body>>,
) -> Result<UntypedBody, HttpError> {
let server = &rqctx.server;
let mut request = rqctx.request.lock().await;
let mut request = request.unwrap();
let body_bytes = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
Expand Down Expand Up @@ -1141,6 +1188,10 @@ impl Extractor for UntypedBody {
extension_mode: ExtensionMode::None,
}
}

fn requires_body() -> bool {
true
}
}

/*
Expand Down
8 changes: 6 additions & 2 deletions dropshot/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,13 +794,17 @@ async fn http_request_handle<C: ServerContext>(
server.router.lookup_route(&method, uri.path().into())?;
let rqctx = RequestContext {
server: Arc::clone(&server),
request: Arc::new(Mutex::new(request)),
method: request.method().clone(),
uri: request.uri().clone(),
version: request.version(),
headers: request.headers().clone(),
path_variables: lookup_result.variables,
body_content_type: lookup_result.body_content_type,
request_id: request_id.to_string(),
log: request_log,
};
let mut response = lookup_result.handler.handle_request(rqctx).await?;
let mut response =
lookup_result.handler.handle_request(rqctx, request).await?;
response.headers_mut().insert(
HEADER_REQUEST_ID,
http::header::HeaderValue::from_str(&request_id).unwrap(),
Expand Down
74 changes: 40 additions & 34 deletions dropshot/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use crate::{
RequestContext, ServerContext,
};
use async_trait::async_trait;
use http::header;
use http::Response;
use http::StatusCode;
use http::{header, Request};
use hyper::upgrade::OnUpgrade;
use hyper::Body;
use schemars::JsonSchema;
Expand Down Expand Up @@ -97,8 +97,9 @@ fn derive_accept_key(request_key: &[u8]) -> String {
impl Extractor for WebsocketUpgrade {
async fn from_request<Context: ServerContext>(
rqctx: Arc<RequestContext<Context>>,
request: Option<Request<Body>>,
) -> Result<Self, HttpError> {
let request = &mut *rqctx.request.lock().await;
let request = request.unwrap();

if !request
.headers()
Expand Down Expand Up @@ -181,6 +182,10 @@ impl Extractor for WebsocketUpgrade {
extension_mode: ExtensionMode::Websocket,
}
}

fn requires_body() -> bool {
true
}
}

impl WebsocketUpgrade {
Expand Down Expand Up @@ -310,7 +315,6 @@ mod tests {
use crate::router::HttpRouter;
use crate::server::{DropshotState, ServerConfig};
use crate::{Extractor, HttpError, RequestContext, WebsocketUpgrade};
use futures::lock::Mutex;
use http::Request;
use hyper::Body;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
Expand All @@ -320,39 +324,41 @@ mod tests {

async fn ws_upg_from_mock_rqctx() -> Result<WebsocketUpgrade, HttpError> {
let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!());
let fut = WebsocketUpgrade::from_request(Arc::new(RequestContext {
server: Arc::new(DropshotState {
private: (),
config: ServerConfig {
request_body_max_bytes: 0,
page_max_nitems: NonZeroU32::new(1).unwrap(),
page_default_nitems: NonZeroU32::new(1).unwrap(),
},
router: HttpRouter::new(),
let request = Request::builder()
.header(http::header::CONNECTION, "Upgrade")
.header(http::header::UPGRADE, "websocket")
.header(http::header::SEC_WEBSOCKET_VERSION, "13")
.header(http::header::SEC_WEBSOCKET_KEY, "aGFjayB0aGUgcGxhbmV0IQ==")
.body(Body::empty())
.unwrap();
let fut = WebsocketUpgrade::from_request(
Arc::new(RequestContext {
server: Arc::new(DropshotState {
private: (),
config: ServerConfig {
request_body_max_bytes: 0,
page_max_nitems: NonZeroU32::new(1).unwrap(),
page_default_nitems: NonZeroU32::new(1).unwrap(),
},
router: HttpRouter::new(),
log: log.clone(),
local_addr: SocketAddr::new(
IpAddr::V6(Ipv6Addr::LOCALHOST),
8080,
),
tls_acceptor: None,
}),
method: request.method().clone(),
uri: request.uri().clone(),
version: request.version(),
headers: request.headers().clone(),
path_variables: Default::default(),
body_content_type: Default::default(),
request_id: "".to_string(),
log: log.clone(),
local_addr: SocketAddr::new(
IpAddr::V6(Ipv6Addr::LOCALHOST),
8080,
),
tls_acceptor: None,
}),
request: Arc::new(Mutex::new(
Request::builder()
.header(http::header::CONNECTION, "Upgrade")
.header(http::header::UPGRADE, "websocket")
.header(http::header::SEC_WEBSOCKET_VERSION, "13")
.header(
http::header::SEC_WEBSOCKET_KEY,
"aGFjayB0aGUgcGxhbmV0IQ==",
)
.body(Body::empty())
.unwrap(),
)),
path_variables: Default::default(),
body_content_type: Default::default(),
request_id: "".to_string(),
log: log.clone(),
}));
Some(request),
);
tokio::time::timeout(Duration::from_secs(1), fut)
.await
.expect("Deadlocked in WebsocketUpgrade constructor")
Expand Down