Added support for upgrade requests
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 5m31s

This commit is contained in:
Dreaded_X 2025-04-20 22:05:37 +02:00
parent c65b4d725d
commit 453718b936
Signed by: Dreaded_X
GPG Key ID: 5A0CBFE3C3377FAA

View File

@ -1,7 +1,6 @@
mod auth; mod auth;
mod response; mod response;
use std::future::join;
use std::ops::Deref; use std::ops::Deref;
use std::pin::Pin; use std::pin::Pin;
@ -11,16 +10,14 @@ use bytes::Bytes;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt as _, Empty}; use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::client::conn::http1::Builder; use hyper::header::{self, HOST, UPGRADE};
use hyper::header::{self, HOST}; use hyper::{Request, Response, StatusCode, client, server};
use hyper::server::conn::http1;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use hyper_util::server::graceful::GracefulShutdown;
use response::response; use response::response;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::select; use tokio::select;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
use crate::tunnel::{Registry, TunnelAccess}; use crate::tunnel::{Registry, TunnelAccess};
@ -29,29 +26,53 @@ use crate::tunnel::{Registry, TunnelAccess};
pub struct Service { pub struct Service {
registry: Registry, registry: Registry,
auth: ForwardAuth, auth: ForwardAuth,
task_tracker: TaskTracker,
}
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
fn copy_request_parts<T>(req: Request<T>) -> (Request<T>, Request<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts.clone(), body);
let forwarded_req = Request::from_parts(parts, empty());
(req, forwarded_req)
}
fn copy_response_parts<T>(
resp: Response<T>,
) -> (Response<T>, Response<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = resp.into_parts();
let resp = Response::from_parts(parts.clone(), body);
let forwarded_resp = Response::from_parts(parts, empty());
(resp, forwarded_resp)
} }
impl Service { impl Service {
pub fn new(registry: Registry, auth: ForwardAuth) -> Self { pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
Self { registry, auth } Self {
registry,
auth,
task_tracker: Default::default(),
}
} }
pub async fn handle_connection( pub async fn handle_connection(&self, listener: &TcpListener) -> std::io::Result<()> {
&self,
listener: &TcpListener,
graceful_shutdown: &GracefulShutdown,
) -> std::io::Result<()> {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let connection = http1::Builder::new() let connection = server::conn::http1::Builder::new()
.preserve_header_case(true) .preserve_header_case(true)
.title_case_headers(true) .title_case_headers(true)
.serve_connection(io, self.clone()); .serve_connection(io, self.clone())
.with_upgrades();
let connection = graceful_shutdown.watch(connection); self.task_tracker.spawn(async move {
tokio::spawn(async move {
if let Err(err) = connection.await { if let Err(err) = connection.await {
error!("Failed to serve connection: {err:?}"); error!("Failed to serve connection: {err:?}");
} }
@ -61,22 +82,27 @@ impl Service {
} }
pub async fn serve(self, listener: TcpListener, token: CancellationToken) { pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
let graceful_shutdown = GracefulShutdown::new();
loop { loop {
select! { select! {
res = self.handle_connection(&listener, &graceful_shutdown) => { res = self.handle_connection(&listener) => {
if let Err(err) = res { if let Err(err) = res {
error!("Failed to accept connection: {err}") error!("Failed to accept connection: {err}")
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
debug!("Graceful shutdown");
break; break;
} }
} }
} }
graceful_shutdown.shutdown().await; debug!(
"Waiting for {} connections to close",
self.task_tracker.len()
);
self.task_tracker.close();
self.task_tracker.wait().await;
debug!("Graceful shutdown");
} }
} }
@ -109,10 +135,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
debug!(authority, "Tunnel request"); debug!(authority, "Tunnel request");
let registry = self.registry.clone(); let s = self.clone();
let auth = self.auth.clone();
Box::pin(async move { Box::pin(async move {
let Some(entry) = registry.get(&authority).await else { let Some(entry) = s.registry.get(&authority).await else {
debug!(tunnel = authority, "Unknown tunnel"); debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
@ -120,7 +145,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}; };
if !entry.is_public().await { if !entry.is_public().await {
let user = match auth.check(req.method(), req.headers()).await { let user = match s.auth.check(req.method(), req.headers()).await {
Ok(AuthStatus::Authenticated(user)) => user, Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => { Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder() let resp = Response::builder()
@ -179,21 +204,73 @@ impl hyper::service::Service<Request<Incoming>> for Service {
} }
}; };
let (mut sender, conn) = Builder::new() let (mut sender, conn) = client::conn::http1::Builder::new()
.preserve_header_case(true) .preserve_header_case(true)
.title_case_headers(true) .title_case_headers(true)
.handshake(io) .handshake(io)
.await?; .await?;
let conn = async { let conn = conn.with_upgrades();
s.task_tracker.spawn(async move {
if let Err(err) = conn.await { if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}"); warn!(runnel = authority, "Connection failed: {err}");
} }
}; });
let (resp, _) = join!(sender.send_request(req), conn).await; let (mut req, forwarded_req) = copy_request_parts(req);
Ok(resp?.map(|b| b.boxed())) let resp = sender.send_request(forwarded_req).await?;
if req.headers().contains_key(UPGRADE)
&& req.headers().get(UPGRADE) == resp.headers().get(UPGRADE)
{
let (mut resp, forwarded_resp) = copy_response_parts(resp);
debug!("UPGRADE established");
match hyper::upgrade::on(&mut resp).await {
Ok(upgraded_resp) => {
s.task_tracker.spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded_req) => {
let mut upgraded_req = TokioIo::new(upgraded_req);
let mut upgraded_resp = TokioIo::new(upgraded_resp);
match tokio::io::copy_bidirectional(
&mut upgraded_req,
&mut upgraded_resp,
)
.await
{
Ok((rx, tx)) => {
debug!(
"Received {rx} bytes and send {tx} bytes over upgraded tunnel"
);
}
Err(err) => {
// Likely due to channel being closed
// TODO: Show warning if not channel closed, otherwise ignore
debug!("Upgraded connection error: {err:?}");
}
}
}
Err(err) => {
error!("Failed to upgrade: {err}");
}
}
});
return Ok(forwarded_resp.map(|b| b.boxed()));
}
Err(err) => {
error!("Failed to upgrade req: {err}");
return Ok(response(StatusCode::BAD_REQUEST, "Failed to upgrade"));
}
}
}
trace!("{resp:#?}");
Ok(resp.map(|b| b.boxed()))
}) })
} }
} }