diff --git a/src/web/mod.rs b/src/web/mod.rs index b96288b..30e5bd6 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,7 +1,6 @@ mod auth; mod response; -use std::future::join; use std::ops::Deref; use std::pin::Pin; @@ -11,16 +10,14 @@ use bytes::Bytes; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt as _, Empty}; use hyper::body::Incoming; -use hyper::client::conn::http1::Builder; -use hyper::header::{self, HOST}; -use hyper::server::conn::http1; -use hyper::{Request, Response, StatusCode}; +use hyper::header::{self, HOST, UPGRADE}; +use hyper::{Request, Response, StatusCode, client, server}; use hyper_util::rt::TokioIo; -use hyper_util::server::graceful::GracefulShutdown; use response::response; use tokio::net::TcpListener; use tokio::select; use tokio_util::sync::CancellationToken; +use tokio_util::task::TaskTracker; use tracing::{debug, error, trace, warn}; use crate::tunnel::{Registry, TunnelAccess}; @@ -29,29 +26,53 @@ use crate::tunnel::{Registry, TunnelAccess}; pub struct Service { registry: Registry, auth: ForwardAuth, + task_tracker: TaskTracker, +} + +pub fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} + +fn copy_request_parts(req: Request) -> (Request, Request>) { + let (parts, body) = req.into_parts(); + let req = Request::from_parts(parts.clone(), body); + let forwarded_req = Request::from_parts(parts, empty()); + + (req, forwarded_req) +} + +fn copy_response_parts( + resp: Response, +) -> (Response, Response>) { + let (parts, body) = resp.into_parts(); + let resp = Response::from_parts(parts.clone(), body); + let forwarded_resp = Response::from_parts(parts, empty()); + + (resp, forwarded_resp) } impl Service { pub fn new(registry: Registry, auth: ForwardAuth) -> Self { - Self { registry, auth } + Self { + registry, + auth, + task_tracker: Default::default(), + } } - pub async fn handle_connection( - &self, - listener: &TcpListener, - graceful_shutdown: &GracefulShutdown, - ) -> std::io::Result<()> { + pub async fn handle_connection(&self, listener: &TcpListener) -> std::io::Result<()> { let (stream, _) = listener.accept().await?; let io = TokioIo::new(stream); - let connection = http1::Builder::new() + let connection = server::conn::http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(io, self.clone()); + .serve_connection(io, self.clone()) + .with_upgrades(); - let connection = graceful_shutdown.watch(connection); - - tokio::spawn(async move { + self.task_tracker.spawn(async move { if let Err(err) = connection.await { error!("Failed to serve connection: {err:?}"); } @@ -61,22 +82,27 @@ impl Service { } pub async fn serve(self, listener: TcpListener, token: CancellationToken) { - let graceful_shutdown = GracefulShutdown::new(); loop { select! { - res = self.handle_connection(&listener, &graceful_shutdown) => { + res = self.handle_connection(&listener) => { if let Err(err) = res { error!("Failed to accept connection: {err}") } } _ = token.cancelled() => { - debug!("Graceful shutdown"); break; } } } - graceful_shutdown.shutdown().await; + debug!( + "Waiting for {} connections to close", + self.task_tracker.len() + ); + self.task_tracker.close(); + self.task_tracker.wait().await; + + debug!("Graceful shutdown"); } } @@ -109,10 +135,9 @@ impl hyper::service::Service> for Service { debug!(authority, "Tunnel request"); - let registry = self.registry.clone(); - let auth = self.auth.clone(); + let s = self.clone(); Box::pin(async move { - let Some(entry) = registry.get(&authority).await else { + let Some(entry) = s.registry.get(&authority).await else { debug!(tunnel = authority, "Unknown tunnel"); let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); @@ -120,7 +145,7 @@ impl hyper::service::Service> for Service { }; if !entry.is_public().await { - let user = match auth.check(req.method(), req.headers()).await { + let user = match s.auth.check(req.method(), req.headers()).await { Ok(AuthStatus::Authenticated(user)) => user, Ok(AuthStatus::Unauthenticated(location)) => { let resp = Response::builder() @@ -179,21 +204,73 @@ impl hyper::service::Service> for Service { } }; - let (mut sender, conn) = Builder::new() + let (mut sender, conn) = client::conn::http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) .handshake(io) .await?; - let conn = async { + let conn = conn.with_upgrades(); + s.task_tracker.spawn(async move { if let Err(err) = conn.await { warn!(runnel = authority, "Connection failed: {err}"); } - }; + }); - let (resp, _) = join!(sender.send_request(req), conn).await; + let (mut req, forwarded_req) = copy_request_parts(req); - Ok(resp?.map(|b| b.boxed())) + let resp = sender.send_request(forwarded_req).await?; + + if req.headers().contains_key(UPGRADE) + && req.headers().get(UPGRADE) == resp.headers().get(UPGRADE) + { + let (mut resp, forwarded_resp) = copy_response_parts(resp); + + debug!("UPGRADE established"); + match hyper::upgrade::on(&mut resp).await { + Ok(upgraded_resp) => { + s.task_tracker.spawn(async move { + match hyper::upgrade::on(&mut req).await { + Ok(upgraded_req) => { + let mut upgraded_req = TokioIo::new(upgraded_req); + let mut upgraded_resp = TokioIo::new(upgraded_resp); + + match tokio::io::copy_bidirectional( + &mut upgraded_req, + &mut upgraded_resp, + ) + .await + { + Ok((rx, tx)) => { + debug!( + "Received {rx} bytes and send {tx} bytes over upgraded tunnel" + ); + } + Err(err) => { + // Likely due to channel being closed + // TODO: Show warning if not channel closed, otherwise ignore + debug!("Upgraded connection error: {err:?}"); + } + } + } + Err(err) => { + error!("Failed to upgrade: {err}"); + } + } + }); + + return Ok(forwarded_resp.map(|b| b.boxed())); + } + Err(err) => { + error!("Failed to upgrade req: {err}"); + return Ok(response(StatusCode::BAD_REQUEST, "Failed to upgrade")); + } + } + } + + trace!("{resp:#?}"); + + Ok(resp.map(|b| b.boxed())) }) } }