diff --git a/README.md b/README.md index 0fd3729..2b46a7a 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,8 @@ use gostd::vendor - [x] 完成bytes模块。(version >=0.2.8) - [x] 完成mime::multipart模块。(version >=0.3.1) - [x] 修复windos10平台编译失败的bug。(version>=0.3.18) - +- [x] 对net/http模块,自定义错误处理上的优化,引入bytes提高性能,api会有参数类型的变化(version>=0.4.1) +- [ ] 对net/http模块,增加异步编程的支持,独立一个模块支持异步编程,原来模块保持不变。 # 独立发布包 独立发布gostd_time,代码等价于 use gostd::time 。 diff --git a/gostd/Cargo.toml b/gostd/Cargo.toml index 0877332..a34cdb0 100644 --- a/gostd/Cargo.toml +++ b/gostd/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gostd" license = "MIT" -version = "0.3.21" +version = "0.4.1" edition = "2018" authors = ["wandercn"] description = "gostd is the go standard library implementation in rust. gostd 是Go标准库的rust实现" @@ -26,3 +26,6 @@ gostd_time = { version = "^1.0", optional = false, path = "../time" } rand = "0.8.5" log = "0.4" lazy_static = "1.4.0" +bytes = "1" +thiserror ="2.0" +anyhow= "1.0" diff --git a/gostd/src/net/http/mod.rs b/gostd/src/net/http/mod.rs index 14a7906..8bf17eb 100644 --- a/gostd/src/net/http/mod.rs +++ b/gostd/src/net/http/mod.rs @@ -202,6 +202,8 @@ use crate::net::url; use crate::strings; use crate::time; use std::collections::HashMap; +use std::collections::HashSet; +use std::convert::TryInto; use std::io::Error; /// Get issues a GET to the specified URL. If the response is one of the following redirect codes, Get follows the redirect,up to a maximum of 10 redirects: /// ```text @@ -223,21 +225,21 @@ use std::io::Error; /// ``` /// use gostd::net::http; /// -/// fn main() -> Result<(), std::io::Error> { +/// fn main() -> Result<(), Box> { /// let url = "https://petstore.swagger.io/v2/pet/findByStatus?status=available"; /// let response = http::Get(url)?;/// /// println!( /// "{}", -/// String::from_utf8(response.Body.expect("return body error")).unwrap() +/// String::from_utf8(response.Body.expect("return body error").to_vec()).unwrap() /// ); /// Ok(()) /// } /// ``` -pub fn Get(url: &str) -> HttpResult { +pub fn Get(url: &str) -> HttpResult { Client::New().Get(url) } -pub fn Head(url: &str) -> HttpResult { +pub fn Head(url: &str) -> HttpResult { Client::New().Head(url) } @@ -252,39 +254,39 @@ pub fn Head(url: &str) -> HttpResult { /// /// ``` /// use gostd::net::http; -/// fn main() -> Result<(), std::io::Error> { +/// fn main() -> Result<(), Box> { /// let url = "https://petstore.swagger.io/v2/pet"; /// let postbody = r#"{"id":0,"category":{"id":0,"name":"string"},"name":"doggie","photoUrls":["string"],"tags":[{"id":0,"name":"string"}],"status":"available"}"# /// .as_bytes() /// .to_vec(); -/// let response = http::Post(url, "application/json", Some(postbody))?; +/// let response = http::Post(url, "application/json", Some(postbody.into()))?; /// /// println!( /// "{}", -/// String::from_utf8(response.Body.expect("return body error")).unwrap() +/// String::from_utf8(response.Body.expect("return body error").to_vec()).unwrap() /// ); /// /// Ok(()) /// } /// /// ``` -pub fn Post(url: &str, contentType: &str, body: Option>) -> HttpResult { +pub fn Post(url: &str, contentType: &str, body: Option) -> HttpResult { Client::New().Post(url, contentType, body) } -pub fn PostForm(url: &str, data: url::Values) -> HttpResult { +pub fn PostForm(url: &str, data: url::Values) -> HttpResult { Client::New().PostForm(url, data) } -pub fn Patch(url: &str, body: Option>) -> HttpResult { +pub fn Patch(url: &str, body: Option) -> HttpResult { Client::New().Patch(url, body) } -pub fn Put(url: &str, body: Option>) -> HttpResult { +pub fn Put(url: &str, body: Option) -> HttpResult { Client::New().Put(url, body) } -pub fn Delete(url: &str) -> HttpResult { +pub fn Delete(url: &str) -> HttpResult { Client::New().Delete(url) } @@ -294,8 +296,9 @@ pub struct Client { Jar: Box, Timeout: time::Duration, } +use anyhow::Result; -pub type HttpResult = Result; +pub type HttpResult = Result; impl Default for Client { fn default() -> Self { Self { @@ -310,46 +313,51 @@ impl Client { Self::default() } - pub fn Get(&mut self, url: &str) -> HttpResult { + pub fn Get(&mut self, url: &str) -> HttpResult { let mut req = Request::New(Method::Get, url, None)?; self.Do(&mut req) } - pub fn Post(&mut self, url: &str, contentType: &str, body: Option>) -> HttpResult { + pub fn Post( + &mut self, + url: &str, + contentType: &str, + body: Option, + ) -> HttpResult { let mut req = Request::New(Method::Post, url, body)?; req.Header.Set("Content-Type", contentType); self.Do(&mut req) } - pub fn PostForm(&mut self, url: &str, data: url::Values) -> HttpResult { + pub fn PostForm(&mut self, url: &str, data: url::Values) -> HttpResult { self.Post( url, "application/x-www-form-urlencoded", - Some(data.Encode().as_bytes().to_vec()), + Some(data.Encode().into_bytes().into()), ) } - pub fn Head(&mut self, url: &str) -> HttpResult { + pub fn Head(&mut self, url: &str) -> HttpResult { let mut req = Request::New(Method::Head, url, None)?; self.Do(&mut req) } - pub fn Patch(&mut self, url: &str, body: Option>) -> HttpResult { + pub fn Patch(&mut self, url: &str, body: Option) -> HttpResult { let mut req = Request::New(Method::Patch, url, body)?; self.Do(&mut req) } - pub fn Put(&mut self, url: &str, body: Option>) -> HttpResult { + pub fn Put(&mut self, url: &str, body: Option) -> HttpResult { let mut req = Request::New(Method::Put, url, body)?; self.Do(&mut req) } - pub fn Delete(&mut self, url: &str) -> HttpResult { + pub fn Delete(&mut self, url: &str) -> HttpResult { let mut req = Request::New(Method::Delete, url, None)?; self.Do(&mut req) } - pub fn Do(&mut self, req: &mut Request) -> HttpResult { + pub fn Do(&mut self, req: &mut Request) -> HttpResult { self.done(req) } @@ -357,12 +365,12 @@ impl Client { &mut self, req: &mut Request, deadline: time::Time, - ) -> Result<(Response, fn() -> bool), Error> { + ) -> HttpResult<(Response, fn() -> bool)> { let (resp, didTimeout) = send(req, self.transport(), deadline)?; Ok((resp, didTimeout)) } - fn done(&mut self, req: &mut Request) -> HttpResult { + fn done(&mut self, req: &mut Request) -> HttpResult { let deadline = self.deadline(); let (resp, didTimeout) = self.send(req, deadline)?; Ok(resp) @@ -384,7 +392,7 @@ fn send( ireq: &mut Request, mut rt: Box, deadline: time::Time, -) -> Result<(Response, fn() -> bool), Error> { +) -> HttpResult<(Response, fn() -> bool)> { let mut resp = Response::default(); fn didTimeout() -> bool { return false; @@ -425,7 +433,7 @@ fn redirectBehavior(reqMethod: &str, resp: &Response, ireq: &Request) -> (String } pub trait RoundTripper { - fn RoundTrip(&mut self, r: &Request) -> Result; + fn RoundTrip(&mut self, r: &Request) -> HttpResult; } fn refererForURL(lastReq: &url::URL, newReq: &url::URL) -> String { @@ -449,7 +457,7 @@ pub struct Request { ProtoMajor: int, ProtoMinor: int, pub Header: Header, - pub Body: Option>, + pub Body: Option, // GetBody func() (io.ReadCloser, error) ContentLength: int64, TransferEncoding: Vec, @@ -468,7 +476,7 @@ pub struct Request { } impl Request { - pub fn New(method: Method, url: &str, body: Option>) -> Result { + pub fn New(method: Method, url: &str, body: Option) -> Result { let mut u = url::Parse(url)?; u.Host = removeEmptyPort(u.Host.as_str()).to_string(); @@ -493,6 +501,7 @@ impl Request { }; if let Some(buf) = body { req.ContentLength = len!(buf) as i64; + req.Body = Some(buf); } if strings::HasPrefix(url, "https://") { @@ -548,7 +557,7 @@ pub struct Response { pub Header: Header, pub ContentLength: int64, pub TransferEncoding: Vec, - pub Body: Option>, + pub Body: Option, pub Close: bool, pub Uncompressed: bool, pub Trailer: Header, @@ -839,12 +848,12 @@ struct Transport { use std::net; use std::sync::mpsc; impl RoundTripper for Transport { - fn RoundTrip(&mut self, req: &Request) -> HttpResult { + fn RoundTrip(&mut self, req: &Request) -> HttpResult { self.roundTrip(req) } } impl Transport { - fn roundTrip(&mut self, req: &Request) -> HttpResult { + fn roundTrip(&mut self, req: &Request) -> HttpResult { let treq = &mut transportRequest { Req: req.clone(), extra: None, @@ -981,6 +990,7 @@ struct persistConn { reused: bool, } +use bytes::Bytes; use rustls::ClientConnection; use rustls::StreamOwned; use std::convert::TryFrom; @@ -990,8 +1000,9 @@ use std::net::Shutdown; use std::net::TcpStream; use std::rc::Rc; use std::sync::Arc; +use webpki_roots::TLS_SERVER_ROOTS; impl persistConn { - fn roundTrip(&mut self, req: &mut transportRequest, mut conn: TcpConn) -> HttpResult { + fn roundTrip(&mut self, req: &mut transportRequest, mut conn: TcpConn) -> HttpResult { self.numExpectedResponses += 1; let mut requestedGzip = false; if !self.t.DisableCompression @@ -1014,7 +1025,7 @@ impl persistConn { let r = req.Req.Write()?; if req.Req.isTLS { - let mut tlsConn = getTLSConn(req.Req.Host.as_str(), conn); + let mut tlsConn = getTLSConn(req.Req.Host.as_str(), conn)?; tlsConn.write(r.as_slice())?; let mut reader = BufReader::new(tlsConn); let resp = ReadResponse(reader, &req.Req)?; @@ -1027,63 +1038,83 @@ impl persistConn { } } } - +use bytes::{Buf, BytesMut}; use rustls::pki_types::ServerName; +use rustls::{ClientConfig, RootCertStore}; use std::io::ErrorKind; +use thiserror::Error; +#[derive(Error, Debug)] +pub enum HTTPConnectError { + #[error("DNS name conversion failed: {0}")] + DnsNameConversion(#[from] rustls::pki_types::InvalidDnsNameError), + + #[error("Failed to connect to server: {0}")] + ConnectionFailure(String), + + #[error("TLS handshake failed: {0}")] + TlsHandshakeFailure(#[from] rustls::Error), + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), +} + +impl From for HTTPConnectError { + fn from(err: String) -> Self { + HTTPConnectError::ConnectionFailure(err) + } +} +impl From<&str> for HTTPConnectError { + fn from(err: &str) -> Self { + HTTPConnectError::ConnectionFailure(err.to_string()) + } +} -fn getTLSConn(dnsName: &str, socket: TcpConn) -> StreamOwned { - let mut clientRootCert = - rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); +fn get_tls_config() -> Arc { + let mut clientRootCert = RootCertStore::from_iter(TLS_SERVER_ROOTS.iter().cloned()); - let tlsconfig = rustls::ClientConfig::builder() - .with_root_certificates(clientRootCert) - .with_no_client_auth(); - let serverName = ServerName::try_from(dnsName).expect("url error").to_owned(); - let mut tlsClient = ClientConnection::new(Arc::new(tlsconfig), serverName).unwrap(); + Arc::new( + ClientConfig::builder() + .with_root_certificates(clientRootCert) + .with_no_client_auth(), + ) +} + +fn getTLSConn( + dnsName: &str, + socket: TcpConn, +) -> HttpResult> { + let tlsconfig = get_tls_config(); + let serverName = ServerName::try_from(dnsName.to_owned())?; + let mut tlsClient = ClientConnection::new(tlsconfig, serverName)?; let mut tlsConn = StreamOwned::new(tlsClient, socket); - tlsConn + Ok(tlsConn) } -pub fn ReadResponse(mut r: impl BufRead, req: &Request) -> HttpResult { +pub fn ReadResponse(mut r: impl BufRead, req: &Request) -> HttpResult { let mut resp = Response::default(); resp.Request = req.clone(); // parse status line。 let mut line = String::new(); r.read_line(&mut line)?; - let i = strings::IndexByte(line.as_str(), b' '); - if i == -1 { - return Err(Error::new(ErrorKind::Other, "malformed HTTP response")); - } - resp.Proto = line.get(..i as usize).unwrap().to_string(); - resp.Status = - strings::TrimLeft(&line.as_str()[i as usize + 1..len!(line) - 2], " ").to_string(); - let mut statusCode = resp.Status.as_str(); - let i = strings::IndexByte(resp.Status.as_str(), b' '); - if i != -1 { - statusCode = &resp.Status.as_str()[..i as usize]; - } - if len!(statusCode) != 3 { - return Err(Error::new(ErrorKind::Other, "malformed HTTP status code")); - } - // map_err 重新转换Err类型到io::Error - resp.StatusCode = statusCode - .parse::() - .map_err(|err| Error::new(ErrorKind::Other, err))?; - - if resp.StatusCode < 0 { - return Err(Error::new(ErrorKind::Other, "malformed HTTP status code")); - } - - let vers = ParseHTTPVersion(resp.Proto.as_str()); + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 3 { + return Err(Error::new(ErrorKind::Other, "malformed HTTP response").into()); + } + resp.Proto = parts[0].to_string(); + resp.Status = parts[1..].join(" "); + resp.StatusCode = parts[1] + .parse::() + .map_err(|_| Error::new(ErrorKind::Other, "malformed HTTP status code"))?; + let vers = ParseHTTPVersion(&resp.Proto); let ok = vers.2; if !ok { - return Err(Error::new(ErrorKind::Other, "malformed HTTP version")); + return Err(Error::new(ErrorKind::Other, "malformed HTTP version").into()); } resp.ProtoMajor = vers.0; resp.ProtoMinor = vers.1; // 1. 获取response的header部分,到第一个 '\r\b'独立行为header的结束。 - let mut headPart: Vec = vec![]; + let mut headPart = BytesMut::new(); let mut head_line = String::new(); while r.read_line(&mut head_line).is_ok() { if head_line.as_bytes() == b"\r\n" { @@ -1094,14 +1125,13 @@ pub fn ReadResponse(mut r: impl BufRead, req: &Request) -> HttpResult { } // parse headPart - resp.Header = Header::NewWithHashMap(parseHeader(headPart)); + resp.Header = Header::NewWithHashMap(parseHeader(&headPart)?); fixPragmaCacheControl(&mut resp.Header); - let mut bodyPart: Vec = vec![]; // set Body if resp.Header.Get("Transfer-Encoding").as_str() == "chunked" { // 2.chunked方式传输方式。获取body数据。 - resp.Body.replace(parseChunkedBody(r)?); + resp.Body = Some(parseChunkedBody(r)?); } else { // 3. 除chunked外的其他传输方式,都有Content-Length字段,根据长度获取body let ln: usize = resp @@ -1113,23 +1143,21 @@ pub fn ReadResponse(mut r: impl BufRead, req: &Request) -> HttpResult { let mut buf = vec![0; ln]; // 生成固定长度的数组,用于读取定长数据; r.read_exact(&mut buf)?; - bodyPart.extend_from_slice(&buf); - resp.Body = Some(bodyPart); + resp.Body = Some(BytesMut::from(&buf[..])); } - resp.ContentLength = len!(&resp.Body.as_ref().unwrap()) as i64; + resp.ContentLength = resp.Body.as_ref().map_or(0, |b| b.len() as i64); Ok(resp) } // chunk数据是以16位数据长度 7acc\r\n独立行开头+ [data] 下一行以\r\n结尾数据段形式,所以数据的结尾用0\r\n表示。 -fn parseChunkedBody(mut r: impl BufRead) -> Result, Error> { - let mut body: Vec = Vec::new(); +fn parseChunkedBody(mut r: impl BufRead) -> Result { + let mut body = BytesMut::new(); let mut size_buf = vec![]; while r.read_until(b'\n', &mut size_buf).is_ok() { // 校验开头行是\r\n结尾的chuank size行 - if size_buf.as_slice().ends_with(b"\r\n") { + if size_buf.ends_with(b"\r\n") { // 删除尾部的\r\n,只保留表示大小的字符串 - size_buf.pop(); // remove "\n" - size_buf.pop(); // remove "\r" + size_buf.truncate(size_buf.len() - 2); // Remove "\r\n" // 16进制chunk大小字符串 let size_str = std::str::from_utf8(&size_buf).map_err(|e| { @@ -1166,52 +1194,37 @@ pub type MIMEHeader = HashMap>; fn fixPragmaCacheControl(header: &mut Header) { if let Some(hp) = header.0.get("Pragma") { - if len!(hp) > 0 && &hp[0] == "no-cache" { - if header.0.get("Cache-Control").is_none() { - header.Set("Cache-Control", "no-cache"); - } + if len!(hp) > 0 && &hp[0] == "no-cache" && header.0.get("Cache-Control").is_none() { + header.Set("Cache-Control", "no-cache"); } } } -fn parseHeader(headPart: Vec) -> MIMEHeader { +fn parseHeader(headPart: &[u8]) -> Result { let mut m: MIMEHeader = HashMap::new(); - let lines = std::str::from_utf8(headPart.as_slice()).unwrap(); - - for kv in lines.split("\r\n").into_iter() { - let mut i = strings::IndexByte(kv, b':'); - if i < 0 { - continue; - } + let lines = std::str::from_utf8(headPart).map_err(|e| { + Error::new( + ErrorKind::InvalidData, + format!("invalid UTF-8 sequence: {}", e), + ) + })?; - let key = canonicalMIMEHeaderKey(kv.get(..i as usize).unwrap()); - if key == "".to_string() { - continue; - } - i += 1; - while (uint!(i) < len!(kv.as_bytes()) - && (kv.as_bytes()[i as usize] == b' ' || kv.as_bytes()[i as usize] == b'\t')) - { - i += 1; - } - let mut vv = Vec::::new(); - let value = strings::TrimFunc(string(&kv.as_bytes()[i as usize..]).trim(), |x| { - x == '\"' as u32 - }) - .to_string(); - - if let Some(mut v) = m.get(&key) { - vv = v.to_owned(); - vv.push(value); - m.insert(key, vv.to_owned()); - } else { - if len!(value) > 0 { - vv.push(value); - m.insert(key, vv.clone()); + for kv in lines.split("\r\n") { + if let Some((key, value)) = kv.split_once(':') { + let key = canonicalMIMEHeaderKey(key); + if key.is_empty() { + continue; } + + let value = value + .trim_start_matches(|c: char| c == ' ' || c == '\t') + .trim_matches('"') + .to_string(); + + m.entry(key).or_insert_with(Vec::new).push(value); } } - m + Ok(m) } fn startIndexOfBody(response: &Vec) -> Option { @@ -1230,142 +1243,56 @@ fn startIndexOfBody(response: &Vec) -> Option { } fn validHeaderFieldByte(b: byte) -> bool { - let isTokenTable: HashMap = [ - ('!', true), - ('#', true), - ('$', true), - ('%', true), - ('&', true), - ('\'', true), - ('*', true), - ('+', true), - ('.', true), - ('0', true), - ('1', true), - ('2', true), - ('3', true), - ('4', true), - ('5', true), - ('6', true), - ('7', true), - ('8', true), - ('9', true), - ('A', true), - ('B', true), - ('C', true), - ('D', true), - ('E', true), - ('F', true), - ('G', true), - ('H', true), - ('I', true), - ('J', true), - ('K', true), - ('L', true), - ('M', true), - ('N', true), - ('O', true), - ('P', true), - ('Q', true), - ('R', true), - ('S', true), - ('T', true), - ('U', true), - ('W', true), - ('V', true), - ('X', true), - ('Y', true), - ('Z', true), - ('^', true), - ('_', true), - ('`', true), - ('a', true), - ('b', true), - ('c', true), - ('d', true), - ('e', true), - ('f', true), - ('g', true), - ('h', true), - ('i', true), - ('j', true), - ('k', true), - ('l', true), - ('m', true), - ('n', true), - ('o', true), - ('p', true), - ('q', true), - ('r', true), - ('s', true), - ('t', true), - ('u', true), - ('v', true), - ('w', true), - ('x', true), - ('y', true), - ('z', true), - ('|', true), - ('~', true), + let isTokenTable: HashSet = [ + '!', '#', '$', '%', '&', '\'', '*', '+', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', + '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', + 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', + 'y', 'z', '|', '~', ] .iter() .cloned() .collect(); - return (int!(b) < int!(len!(isTokenTable))) && isTokenTable.get(&(b as char)).is_some(); + isTokenTable.contains(&(b as char)) } // Header KE规范化 content-length|CONTENT-LENGTH => Content-Length const toLower: byte = (b'a' - b'A'); -fn canonicalMIMEHeaderKey(a: &str) -> String { +fn canonicalMIMEHeaderKey(s: &str) -> String { let mut upper = true; - let mut new = String::new(); - for (i, &c) in a.as_bytes().iter().enumerate() { - let mut c1 = c; - if upper && b'a' <= c && c <= b'z' { - c1 -= toLower; - } else if !upper && b'A' <= c && c <= b'Z' { - c1 += toLower; - } - upper = (c1 == b'-'); - new.push(c1 as char); + let mut new = String::with_capacity(s.len()); + for &byte in s.as_bytes() { + let c = if upper && byte >= b'a' && byte <= b'z' { + byte - toLower + } else if !upper && byte >= b'A' && byte <= b'Z' { + byte + toLower + } else { + byte + }; + upper = c == b'-'; + new.push(c as char); } - new.clone() + new } pub fn ParseHTTPVersion(vers: &str) -> (int, int, bool) { - let Big = 1000000; - match vers { - "HTTP/1.1" => return (1, 1, true), - "HTTP/1.0" => return (1, 0, true), - _ => { - if !strings::HasPrefix(vers, "HTTP/") { - return (0, 0, false); - } + let big: int = 1_000_000; - let dot = strings::Index(vers, "."); + if !vers.starts_with("HTTP/") { + return (0, 0, false); + } - if dot < 0 { - return (0, 0, false); - } - let mut major = 0; - let mut minor = 0; + let version_part = &vers[5..]; + let parts: Vec<&str> = version_part.split('.').collect(); - if let Ok(mj) = vers.get(5..dot as usize).unwrap().parse::() { - major = mj; - if major < 0 || major > Big { - return (0, 0, false); - } - } else { - return (0, 0, false); - } + if parts.len() != 2 { + return (0, 0, false); + } - if let Ok(mi) = vers.get(dot as usize + 1..).unwrap().parse::() { - minor = mi; - return (0, 0, false); - } else { - return (0, 0, false); - } - return (major, minor, true); + match (parts[0].parse::(), parts[1].parse::()) { + (Ok(major), Ok(minor)) if major >= 0 && major <= big && minor >= 0 && minor <= big => { + (major, minor, true) } + _ => (0, 0, false), } }