Compare commits

...

4 Commits

Author SHA1 Message Date
8cafe2b3ca
Added support for upgrade requests
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m9s
2025-04-21 02:21:22 +02:00
ed7770f792
Fixed spelling of shutdown during forceful shutdown 2025-04-21 02:21:22 +02:00
dc1f75aee3
Close any remaining connections once the tui exits 2025-04-21 02:21:22 +02:00
0fe043acb5
Revert "Use store instead of fetch_add for atomics"
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m1s
This reverts commit d4bd0ef1ca.
2025-04-21 02:21:18 +02:00
4 changed files with 131 additions and 34 deletions

View File

@ -19,15 +19,15 @@ pub struct Stats {
impl Stats { impl Stats {
pub fn add_connection(&self) { pub fn add_connection(&self) {
self.connections.store(1, Ordering::Relaxed); self.connections.fetch_add(1, Ordering::Relaxed);
} }
pub fn add_rx_bytes(&self, n: usize) { pub fn add_rx_bytes(&self, n: usize) {
self.rx.store(n, Ordering::Relaxed); self.rx.fetch_add(n, Ordering::Relaxed);
} }
pub fn add_tx_bytes(&self, n: usize) { pub fn add_tx_bytes(&self, n: usize) {
self.tx.store(n, Ordering::Relaxed); self.tx.fetch_add(n, Ordering::Relaxed);
} }
pub fn connections(&self) -> usize { pub fn connections(&self) -> usize {

View File

@ -119,7 +119,7 @@ async fn main() -> color_eyre::Result<()> {
info!("Shutdown gracefully"); info!("Shutdown gracefully");
} }
_ = shutdown_task(token.clone()) => { _ = shutdown_task(token.clone()) => {
error!("Failed to shut down gracefully"); error!("Failed to shutdown gracefully");
} }
}; };

View File

@ -330,6 +330,26 @@ impl russh::server::Handler for Handler {
Ok(session.channel_success(channel)?) Ok(session.channel_success(channel)?)
} }
async fn channel_close(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(pty_channel) = self.pty_channel
&& pty_channel == channel
{
debug!("Pty channel closed");
session.disconnect(
russh::Disconnect::ByApplication,
"Remaining active connections have been closed",
"EN",
)?;
}
Ok(())
}
async fn tcpip_forward( async fn tcpip_forward(
&mut self, &mut self,
address: &str, address: &str,

View File

@ -1,7 +1,6 @@
mod auth; mod auth;
mod response; mod response;
use std::future::join;
use std::ops::Deref; use std::ops::Deref;
use std::pin::Pin; use std::pin::Pin;
@ -11,16 +10,14 @@ use bytes::Bytes;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt as _, Empty}; use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::client::conn::http1::Builder; use hyper::header::{self, HOST, UPGRADE};
use hyper::header::{self, HOST}; use hyper::{Request, Response, StatusCode, client, server};
use hyper::server::conn::http1;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use hyper_util::server::graceful::GracefulShutdown;
use response::response; use response::response;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::select; use tokio::select;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
use crate::tunnel::{Registry, TunnelAccess}; use crate::tunnel::{Registry, TunnelAccess};
@ -29,29 +26,53 @@ use crate::tunnel::{Registry, TunnelAccess};
pub struct Service { pub struct Service {
registry: Registry, registry: Registry,
auth: ForwardAuth, auth: ForwardAuth,
task_tracker: TaskTracker,
}
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
fn copy_request_parts<T>(req: Request<T>) -> (Request<T>, Request<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts.clone(), body);
let forwarded_req = Request::from_parts(parts, empty());
(req, forwarded_req)
}
fn copy_response_parts<T>(
resp: Response<T>,
) -> (Response<T>, Response<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = resp.into_parts();
let resp = Response::from_parts(parts.clone(), body);
let forwarded_resp = Response::from_parts(parts, empty());
(resp, forwarded_resp)
} }
impl Service { impl Service {
pub fn new(registry: Registry, auth: ForwardAuth) -> Self { pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
Self { registry, auth } Self {
registry,
auth,
task_tracker: Default::default(),
}
} }
pub async fn handle_connection( pub async fn handle_connection(&self, listener: &TcpListener) -> std::io::Result<()> {
&self,
listener: &TcpListener,
graceful_shutdown: &GracefulShutdown,
) -> std::io::Result<()> {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let connection = http1::Builder::new() let connection = server::conn::http1::Builder::new()
.preserve_header_case(true) .preserve_header_case(true)
.title_case_headers(true) .title_case_headers(true)
.serve_connection(io, self.clone()); .serve_connection(io, self.clone())
.with_upgrades();
let connection = graceful_shutdown.watch(connection); self.task_tracker.spawn(async move {
tokio::spawn(async move {
if let Err(err) = connection.await { if let Err(err) = connection.await {
error!("Failed to serve connection: {err:?}"); error!("Failed to serve connection: {err:?}");
} }
@ -61,22 +82,27 @@ impl Service {
} }
pub async fn serve(self, listener: TcpListener, token: CancellationToken) { pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
let graceful_shutdown = GracefulShutdown::new();
loop { loop {
select! { select! {
res = self.handle_connection(&listener, &graceful_shutdown) => { res = self.handle_connection(&listener) => {
if let Err(err) = res { if let Err(err) = res {
error!("Failed to accept connection: {err}") error!("Failed to accept connection: {err}")
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
debug!("Graceful shutdown");
break; break;
} }
} }
} }
graceful_shutdown.shutdown().await; debug!(
"Waiting for {} connections to close",
self.task_tracker.len()
);
self.task_tracker.close();
self.task_tracker.wait().await;
debug!("Graceful shutdown");
} }
} }
@ -109,10 +135,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
debug!(authority, "Tunnel request"); debug!(authority, "Tunnel request");
let registry = self.registry.clone(); let s = self.clone();
let auth = self.auth.clone();
Box::pin(async move { Box::pin(async move {
let Some(entry) = registry.get(&authority).await else { let Some(entry) = s.registry.get(&authority).await else {
debug!(tunnel = authority, "Unknown tunnel"); debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
@ -120,7 +145,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}; };
if !entry.is_public().await { if !entry.is_public().await {
let user = match auth.check(req.method(), req.headers()).await { let user = match s.auth.check(req.method(), req.headers()).await {
Ok(AuthStatus::Authenticated(user)) => user, Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => { Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder() let resp = Response::builder()
@ -179,21 +204,73 @@ impl hyper::service::Service<Request<Incoming>> for Service {
} }
}; };
let (mut sender, conn) = Builder::new() let (mut sender, conn) = client::conn::http1::Builder::new()
.preserve_header_case(true) .preserve_header_case(true)
.title_case_headers(true) .title_case_headers(true)
.handshake(io) .handshake(io)
.await?; .await?;
let conn = async { let conn = conn.with_upgrades();
s.task_tracker.spawn(async move {
if let Err(err) = conn.await { if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}"); warn!(runnel = authority, "Connection failed: {err}");
} }
}; });
let (resp, _) = join!(sender.send_request(req), conn).await; let (mut req, forwarded_req) = copy_request_parts(req);
Ok(resp?.map(|b| b.boxed())) let resp = sender.send_request(forwarded_req).await?;
if req.headers().contains_key(UPGRADE)
&& req.headers().get(UPGRADE) == resp.headers().get(UPGRADE)
{
let (mut resp, forwarded_resp) = copy_response_parts(resp);
debug!("UPGRADE established");
match hyper::upgrade::on(&mut resp).await {
Ok(upgraded_resp) => {
s.task_tracker.spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded_req) => {
let mut upgraded_req = TokioIo::new(upgraded_req);
let mut upgraded_resp = TokioIo::new(upgraded_resp);
match tokio::io::copy_bidirectional(
&mut upgraded_req,
&mut upgraded_resp,
)
.await
{
Ok((rx, tx)) => {
debug!(
"Received {rx} bytes and send {tx} bytes over upgraded tunnel"
);
}
Err(err) => {
// Likely due to channel being closed
// TODO: Show warning if not channel closed, otherwise ignore
debug!("Upgraded connection error: {err:?}");
}
}
}
Err(err) => {
error!("Failed to upgrade: {err}");
}
}
});
return Ok(forwarded_resp.map(|b| b.boxed()));
}
Err(err) => {
error!("Failed to upgrade req: {err}");
return Ok(response(StatusCode::BAD_REQUEST, "Failed to upgrade"));
}
}
}
trace!("{resp:#?}");
Ok(resp.map(|b| b.boxed()))
}) })
} }
} }