Added support for upgrade requests
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m9s
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m9s
This commit is contained in:
parent
ed7770f792
commit
8cafe2b3ca
137
src/web/mod.rs
137
src/web/mod.rs
|
@ -1,7 +1,6 @@
|
||||||
mod auth;
|
mod auth;
|
||||||
mod response;
|
mod response;
|
||||||
|
|
||||||
use std::future::join;
|
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
@ -11,16 +10,14 @@ use bytes::Bytes;
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::{BodyExt as _, Empty};
|
use http_body_util::{BodyExt as _, Empty};
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
use hyper::client::conn::http1::Builder;
|
use hyper::header::{self, HOST, UPGRADE};
|
||||||
use hyper::header::{self, HOST};
|
use hyper::{Request, Response, StatusCode, client, server};
|
||||||
use hyper::server::conn::http1;
|
|
||||||
use hyper::{Request, Response, StatusCode};
|
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use hyper_util::server::graceful::GracefulShutdown;
|
|
||||||
use response::response;
|
use response::response;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tokio_util::task::TaskTracker;
|
||||||
use tracing::{debug, error, trace, warn};
|
use tracing::{debug, error, trace, warn};
|
||||||
|
|
||||||
use crate::tunnel::{Registry, TunnelAccess};
|
use crate::tunnel::{Registry, TunnelAccess};
|
||||||
|
@ -29,29 +26,53 @@ use crate::tunnel::{Registry, TunnelAccess};
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
registry: Registry,
|
registry: Registry,
|
||||||
auth: ForwardAuth,
|
auth: ForwardAuth,
|
||||||
|
task_tracker: TaskTracker,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||||
|
Empty::<Bytes>::new()
|
||||||
|
.map_err(|never| match never {})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_request_parts<T>(req: Request<T>) -> (Request<T>, Request<BoxBody<Bytes, hyper::Error>>) {
|
||||||
|
let (parts, body) = req.into_parts();
|
||||||
|
let req = Request::from_parts(parts.clone(), body);
|
||||||
|
let forwarded_req = Request::from_parts(parts, empty());
|
||||||
|
|
||||||
|
(req, forwarded_req)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy_response_parts<T>(
|
||||||
|
resp: Response<T>,
|
||||||
|
) -> (Response<T>, Response<BoxBody<Bytes, hyper::Error>>) {
|
||||||
|
let (parts, body) = resp.into_parts();
|
||||||
|
let resp = Response::from_parts(parts.clone(), body);
|
||||||
|
let forwarded_resp = Response::from_parts(parts, empty());
|
||||||
|
|
||||||
|
(resp, forwarded_resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
|
pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
|
||||||
Self { registry, auth }
|
Self {
|
||||||
|
registry,
|
||||||
|
auth,
|
||||||
|
task_tracker: Default::default(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_connection(
|
pub async fn handle_connection(&self, listener: &TcpListener) -> std::io::Result<()> {
|
||||||
&self,
|
|
||||||
listener: &TcpListener,
|
|
||||||
graceful_shutdown: &GracefulShutdown,
|
|
||||||
) -> std::io::Result<()> {
|
|
||||||
let (stream, _) = listener.accept().await?;
|
let (stream, _) = listener.accept().await?;
|
||||||
|
|
||||||
let io = TokioIo::new(stream);
|
let io = TokioIo::new(stream);
|
||||||
let connection = http1::Builder::new()
|
let connection = server::conn::http1::Builder::new()
|
||||||
.preserve_header_case(true)
|
.preserve_header_case(true)
|
||||||
.title_case_headers(true)
|
.title_case_headers(true)
|
||||||
.serve_connection(io, self.clone());
|
.serve_connection(io, self.clone())
|
||||||
|
.with_upgrades();
|
||||||
|
|
||||||
let connection = graceful_shutdown.watch(connection);
|
self.task_tracker.spawn(async move {
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(err) = connection.await {
|
if let Err(err) = connection.await {
|
||||||
error!("Failed to serve connection: {err:?}");
|
error!("Failed to serve connection: {err:?}");
|
||||||
}
|
}
|
||||||
|
@ -61,22 +82,27 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
|
pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
|
||||||
let graceful_shutdown = GracefulShutdown::new();
|
|
||||||
loop {
|
loop {
|
||||||
select! {
|
select! {
|
||||||
res = self.handle_connection(&listener, &graceful_shutdown) => {
|
res = self.handle_connection(&listener) => {
|
||||||
if let Err(err) = res {
|
if let Err(err) = res {
|
||||||
error!("Failed to accept connection: {err}")
|
error!("Failed to accept connection: {err}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = token.cancelled() => {
|
_ = token.cancelled() => {
|
||||||
debug!("Graceful shutdown");
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graceful_shutdown.shutdown().await;
|
debug!(
|
||||||
|
"Waiting for {} connections to close",
|
||||||
|
self.task_tracker.len()
|
||||||
|
);
|
||||||
|
self.task_tracker.close();
|
||||||
|
self.task_tracker.wait().await;
|
||||||
|
|
||||||
|
debug!("Graceful shutdown");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,10 +135,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||||
|
|
||||||
debug!(authority, "Tunnel request");
|
debug!(authority, "Tunnel request");
|
||||||
|
|
||||||
let registry = self.registry.clone();
|
let s = self.clone();
|
||||||
let auth = self.auth.clone();
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let Some(entry) = registry.get(&authority).await else {
|
let Some(entry) = s.registry.get(&authority).await else {
|
||||||
debug!(tunnel = authority, "Unknown tunnel");
|
debug!(tunnel = authority, "Unknown tunnel");
|
||||||
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
|
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
|
||||||
|
|
||||||
|
@ -120,7 +145,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||||
};
|
};
|
||||||
|
|
||||||
if !entry.is_public().await {
|
if !entry.is_public().await {
|
||||||
let user = match auth.check(req.method(), req.headers()).await {
|
let user = match s.auth.check(req.method(), req.headers()).await {
|
||||||
Ok(AuthStatus::Authenticated(user)) => user,
|
Ok(AuthStatus::Authenticated(user)) => user,
|
||||||
Ok(AuthStatus::Unauthenticated(location)) => {
|
Ok(AuthStatus::Unauthenticated(location)) => {
|
||||||
let resp = Response::builder()
|
let resp = Response::builder()
|
||||||
|
@ -179,21 +204,73 @@ impl hyper::service::Service<Request<Incoming>> for Service {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let (mut sender, conn) = Builder::new()
|
let (mut sender, conn) = client::conn::http1::Builder::new()
|
||||||
.preserve_header_case(true)
|
.preserve_header_case(true)
|
||||||
.title_case_headers(true)
|
.title_case_headers(true)
|
||||||
.handshake(io)
|
.handshake(io)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let conn = async {
|
let conn = conn.with_upgrades();
|
||||||
|
s.task_tracker.spawn(async move {
|
||||||
if let Err(err) = conn.await {
|
if let Err(err) = conn.await {
|
||||||
warn!(runnel = authority, "Connection failed: {err}");
|
warn!(runnel = authority, "Connection failed: {err}");
|
||||||
}
|
}
|
||||||
};
|
});
|
||||||
|
|
||||||
let (resp, _) = join!(sender.send_request(req), conn).await;
|
let (mut req, forwarded_req) = copy_request_parts(req);
|
||||||
|
|
||||||
Ok(resp?.map(|b| b.boxed()))
|
let resp = sender.send_request(forwarded_req).await?;
|
||||||
|
|
||||||
|
if req.headers().contains_key(UPGRADE)
|
||||||
|
&& req.headers().get(UPGRADE) == resp.headers().get(UPGRADE)
|
||||||
|
{
|
||||||
|
let (mut resp, forwarded_resp) = copy_response_parts(resp);
|
||||||
|
|
||||||
|
debug!("UPGRADE established");
|
||||||
|
match hyper::upgrade::on(&mut resp).await {
|
||||||
|
Ok(upgraded_resp) => {
|
||||||
|
s.task_tracker.spawn(async move {
|
||||||
|
match hyper::upgrade::on(&mut req).await {
|
||||||
|
Ok(upgraded_req) => {
|
||||||
|
let mut upgraded_req = TokioIo::new(upgraded_req);
|
||||||
|
let mut upgraded_resp = TokioIo::new(upgraded_resp);
|
||||||
|
|
||||||
|
match tokio::io::copy_bidirectional(
|
||||||
|
&mut upgraded_req,
|
||||||
|
&mut upgraded_resp,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok((rx, tx)) => {
|
||||||
|
debug!(
|
||||||
|
"Received {rx} bytes and send {tx} bytes over upgraded tunnel"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
// Likely due to channel being closed
|
||||||
|
// TODO: Show warning if not channel closed, otherwise ignore
|
||||||
|
debug!("Upgraded connection error: {err:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
error!("Failed to upgrade: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return Ok(forwarded_resp.map(|b| b.boxed()));
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
error!("Failed to upgrade req: {err}");
|
||||||
|
return Ok(response(StatusCode::BAD_REQUEST, "Failed to upgrade"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trace!("{resp:#?}");
|
||||||
|
|
||||||
|
Ok(resp.map(|b| b.boxed()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user