Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions client/src/h1/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@ use std::{error, io};

use xitca_http::h1::proto::error::ProtoError;

#[derive(Debug)]
pub enum UnexpectedStateError {
RemainingData,
ConnectionClosed,
}

#[derive(Debug)]
pub enum Error {
Std(Box<dyn error::Error + Send + Sync>),
Io(io::Error),
Proto(ProtoError),
UnexpectedState(UnexpectedStateError),
}

impl From<Box<dyn error::Error + Send + Sync>> for Error {
Expand All @@ -26,3 +33,9 @@ impl From<ProtoError> for Error {
Self::Proto(e)
}
}

impl From<UnexpectedStateError> for Error {
fn from(e: UnexpectedStateError) -> Self {
Self::UnexpectedState(e)
}
}
2 changes: 1 addition & 1 deletion client/src/h1/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ mod error;
pub(crate) mod body;
pub(crate) mod proto;

pub use self::error::Error;
pub use self::error::{Error, UnexpectedStateError};
17 changes: 16 additions & 1 deletion client/src/h1/proto/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
body::BodyError,
bytes::{Bytes, BytesMut},
date::DateTimeHandle,
h1::Error,
h1::{error::UnexpectedStateError, Error},
http::{
Method, Request, Response, StatusCode,
header::{EXPECT, HOST, HeaderValue},
Expand All @@ -29,6 +29,21 @@ where
B: Stream<Item = Result<Bytes, E>> + Unpin,
BodyError: From<E>,
{
// try to read if there is any remaining data or if the connection is closed
match stream.read(&mut [0; 1]) {
Ok(n) => {
if n > 0 {
return Err(Error::from(UnexpectedStateError::RemainingData));
}

return Err(Error::from(UnexpectedStateError::ConnectionClosed));
}
// if the stream is not ready to read, it's in correct state.
Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
// other errors are considered as not in correct state, we should close the connection here
Err(io) => return Err(Error::from(io)),
}

let mut buf = BytesMut::new();

if !req.headers().contains_key(HOST) {
Expand Down
2 changes: 2 additions & 0 deletions client/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! middleware offer extended functionality to http client.

mod redirect;
mod retry_closed_connection;

mod async_fn;
#[cfg(feature = "compress")]
Expand All @@ -11,3 +12,4 @@ pub use decompress::Decompress;

pub(crate) use async_fn::AsyncFn;
pub use redirect::FollowRedirect;
pub use retry_closed_connection::RetryClosedConnection;
62 changes: 62 additions & 0 deletions client/src/middleware/retry_closed_connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use crate::{
error::Error,
response::Response,
service::{Service, ServiceRequest},
};

/// middleware for retrying closed connection
pub struct RetryClosedConnection<S, const MAX_COUNT: usize = 3> {
service: S,
}

impl<S> RetryClosedConnection<S> {
/// construct retry closed connection middleware for client.
///
/// # Examples:
/// ```rust
/// # use xitca_client::{ClientBuilder, middleware::RetryClosedConnection};
/// let builder = ClientBuilder::new()
/// .middleware(RetryClosedConnection::new);
/// ```
pub fn new(service: S) -> Self {
Self { service }
}
}

impl<S, const MAX: usize> RetryClosedConnection<S, MAX> {
/// set max retry count for request. when max value is reached the request will return the most recent errror.
pub fn max<const MAX2: usize>(self) -> RetryClosedConnection<S, MAX2> {
RetryClosedConnection { service: self.service }
}
}

impl<'r, 'c, S, const MAX: usize> Service<ServiceRequest<'r, 'c>> for RetryClosedConnection<S, MAX>
where
S: for<'r2, 'c2> Service<ServiceRequest<'r2, 'c2>, Response = Response, Error = Error> + Send + Sync,
{
type Response = Response;
type Error = Error;

async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
let ServiceRequest { req, client, timeout } = req;
let mut count = 0;

loop {
let res = self.service.call(ServiceRequest { req, client, timeout }).await;

if count == MAX {
return res;
}

match res {
#[cfg(feature = "http1")]
Err(Error::H1(crate::h1::Error::UnexpectedState(
crate::h1::UnexpectedStateError::ConnectionClosed,
))) => (),
res => return res,
}

count += 1;
}
}
}
46 changes: 41 additions & 5 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::{
error, fmt, fs,
future::Future,
io,
net::SocketAddr,
net::TcpListener,
net::{SocketAddr, TcpListener, ToSocketAddrs},
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -36,9 +35,18 @@ where
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
{
let lst = TcpListener::bind("127.0.0.1:0")?;
test_server_with_addr(service, "127.0.0.1:0")
}

let addr = lst.local_addr()?;
pub fn test_server_with_addr<T, Req, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
A: ToSocketAddrs,
{
let lst = TcpListener::bind(addr)?;
let local_addr = lst.local_addr()?;

let handle = Builder::new()
.worker_threads(1)
Expand All @@ -47,7 +55,35 @@ where
.listen::<_, _, _, Req>("test_server", lst, service)
.build();

Ok(TestServerHandle { addr, handle })
Ok(TestServerHandle {
addr: local_addr,
handle,
})
}

/// A specialized http/1 server on top of [test_server]
pub fn test_h1_server_with_addr<T, B, E, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Request<RequestExt<h1::RequestBody>>, Response = HResponse<B>> + 'static,
<T::Response as Service<Request<RequestExt<h1::RequestBody>>>>::Error: fmt::Debug,
T::Error: error::Error + 'static,
B: Stream<Item = Result<Bytes, E>> + 'static,
E: fmt::Debug + 'static,
A: ToSocketAddrs,
{
#[cfg(not(feature = "io-uring"))]
{
test_server_with_addr::<_, (TcpStream, SocketAddr), A>(service.enclosed(HttpServiceBuilder::h1()), addr)
}

#[cfg(feature = "io-uring")]
{
test_server_with_addr::<_, (xitca_io::net::io_uring::TcpStream, SocketAddr), A>(
service.enclosed(HttpServiceBuilder::h1().io_uring()),
addr,
)
}
}

/// A specialized http/1 server on top of [test_server]
Expand Down
39 changes: 37 additions & 2 deletions test/tests/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
time::Duration,
};

use xitca_client::Client;
use xitca_client::{middleware::RetryClosedConnection, Client};
use xitca_http::{
body::{BoxBody, ResponseBody},
bytes::{Bytes, BytesMut},
Expand All @@ -16,7 +16,7 @@ use xitca_http::{
},
};
use xitca_service::fn_service;
use xitca_test::{test_h1_server, Error};
use xitca_test::{test_h1_server, test_h1_server_with_addr, Error};

#[tokio::test]
async fn h1_get() -> Result<(), Error> {
Expand Down Expand Up @@ -270,6 +270,41 @@ async fn h1_keepalive() -> Result<(), Error> {
Ok(())
}

#[tokio::test]
async fn h1_get_connection_closed_by_server() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
let ip_port = handle.ip_port_string();

let server_url = format!("http://{}/", ip_port);

let c = Client::builder()
.middleware(RetryClosedConnection::new)
.set_pool_capacity(1)
.finish();

let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(false);
handle.await?;

let mut handle = test_h1_server_with_addr(fn_service(crate::handle), ip_port)?;
let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;

assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(false);
handle.await?;

Ok(())
}

async fn handle(req: Request<RequestExt<h1::RequestBody>>) -> Result<Response<ResponseBody>, Error> {
match (req.method(), req.uri().path()) {
(&Method::GET, "/") | (&Method::HEAD, "/") => Ok(Response::new(Bytes::from("GET Response").into())),
Expand Down
Loading