simple_request/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3
4use std::sync::Arc;
5
6use tokio::sync::Mutex;
7
8use tower_service::Service as TowerService;
9#[cfg(feature = "tls")]
10use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
11use hyper::{Uri, header::HeaderValue, body::Bytes, client::conn::http1::SendRequest};
12use hyper_util::{
13  rt::tokio::TokioExecutor,
14  client::legacy::{Client as HyperClient, connect::HttpConnector},
15};
16pub use hyper;
17
18mod request;
19pub use request::*;
20
21mod response;
22pub use response::*;
23
24#[derive(Debug)]
25pub enum Error {
26  InvalidUri,
27  MissingHost,
28  InconsistentHost,
29  ConnectionError(Box<dyn Send + Sync + std::error::Error>),
30  Hyper(hyper::Error),
31  HyperUtil(hyper_util::client::legacy::Error),
32}
33
34#[cfg(not(feature = "tls"))]
35type Connector = HttpConnector;
36#[cfg(feature = "tls")]
37type Connector = HttpsConnector<HttpConnector>;
38
39#[derive(Clone, Debug)]
40enum Connection {
41  ConnectionPool(HyperClient<Connector, Full<Bytes>>),
42  Connection {
43    connector: Connector,
44    host: Uri,
45    connection: Arc<Mutex<Option<SendRequest<Full<Bytes>>>>>,
46  },
47}
48
49#[derive(Clone, Debug)]
50pub struct Client {
51  connection: Connection,
52}
53
54impl Client {
55  fn connector() -> Connector {
56    let mut res = HttpConnector::new();
57    res.set_keepalive(Some(core::time::Duration::from_secs(60)));
58    res.set_nodelay(true);
59    res.set_reuse_address(true);
60    #[cfg(feature = "tls")]
61    res.enforce_http(false);
62    #[cfg(feature = "tls")]
63    let res = HttpsConnectorBuilder::new()
64      .with_native_roots()
65      .expect("couldn't fetch system's SSL roots")
66      .https_or_http()
67      .enable_http1()
68      .wrap_connector(res);
69    res
70  }
71
72  pub fn with_connection_pool() -> Client {
73    Client {
74      connection: Connection::ConnectionPool(
75        HyperClient::builder(TokioExecutor::new())
76          .pool_idle_timeout(core::time::Duration::from_secs(60))
77          .build(Self::connector()),
78      ),
79    }
80  }
81
82  pub fn without_connection_pool(host: &str) -> Result<Client, Error> {
83    Ok(Client {
84      connection: Connection::Connection {
85        connector: Self::connector(),
86        host: {
87          let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?;
88          if uri.host().is_none() {
89            Err(Error::MissingHost)?;
90          };
91          uri
92        },
93        connection: Arc::new(Mutex::new(None)),
94      },
95    })
96  }
97
98  pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response<'_>, Error> {
99    let request: Request = request.into();
100    let mut request = request.0;
101    if let Some(header_host) = request.headers().get(hyper::header::HOST) {
102      match &self.connection {
103        Connection::ConnectionPool(_) => {}
104        Connection::Connection { host, .. } => {
105          if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() {
106            Err(Error::InconsistentHost)?;
107          }
108        }
109      }
110    } else {
111      let host = match &self.connection {
112        Connection::ConnectionPool(_) => {
113          request.uri().host().ok_or(Error::MissingHost)?.to_string()
114        }
115        Connection::Connection { host, .. } => {
116          let host_str = host.host().unwrap();
117          if let Some(uri_host) = request.uri().host() {
118            if host_str != uri_host {
119              Err(Error::InconsistentHost)?;
120            }
121          }
122          host_str.to_string()
123        }
124      };
125      request
126        .headers_mut()
127        .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
128    }
129
130    let response = match &self.connection {
131      Connection::ConnectionPool(client) => {
132        client.request(request).await.map_err(Error::HyperUtil)?
133      }
134      Connection::Connection { connector, host, connection } => {
135        let mut connection_lock = connection.lock().await;
136
137        // If there's not a connection...
138        if connection_lock.is_none() {
139          let call_res = connector.clone().call(host.clone()).await;
140          #[cfg(not(feature = "tls"))]
141          let call_res = call_res.map_err(|e| Error::ConnectionError(format!("{e:?}").into()));
142          #[cfg(feature = "tls")]
143          let call_res = call_res.map_err(Error::ConnectionError);
144          let (requester, connection) =
145            hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?;
146          // This will die when we drop the requester, so we don't need to track an AbortHandle
147          // for it
148          tokio::spawn(connection);
149          *connection_lock = Some(requester);
150        }
151
152        let connection = connection_lock.as_mut().unwrap();
153        let mut err = connection.ready().await.err();
154        if err.is_none() {
155          // Send the request
156          let res = connection.send_request(request).await;
157          if let Ok(res) = res {
158            return Ok(Response(res, self));
159          }
160          err = res.err();
161        }
162        // Since this connection has been put into an error state, drop it
163        *connection_lock = None;
164        Err(Error::Hyper(err.unwrap()))?
165      }
166    };
167
168    Ok(Response(response, self))
169  }
170}