From 7851d6bb1256639c869a7a00a83120d2f3f20f79 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Sun, 20 Apr 2025 00:11:40 +0200 Subject: [PATCH] Implemented more graceful shutdown --- Cargo.lock | 13 ++++++++-- Cargo.toml | 1 + src/ldap.rs | 25 +++++++++++++++--- src/lib.rs | 1 + src/main.rs | 63 ++++++++++++++++++++++++++------------------- src/ssh/handler.rs | 6 +++-- src/ssh/mod.rs | 44 ++++++++++++++++++++++++------- src/ssh/renderer.rs | 24 ++++++++++++++--- src/web/mod.rs | 59 +++++++++++++++++++++++++++++++++++++++--- 9 files changed, 184 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 31b7d25..d08f087 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1107,6 +1107,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -1473,7 +1479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -1682,7 +1688,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -2875,6 +2881,7 @@ dependencies = [ "russh", "thiserror 2.0.12", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "unicode-width 0.2.0", @@ -3205,6 +3212,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 7af66c8..4345403 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ reqwest = { version = "0.12.15", features = ["rustls-tls"] } russh = "0.51.1" thiserror = "2.0.12" tokio = { version = "1.44.2", features = ["full"] } +tokio-util = { version = "0.7.14", features = ["rt"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } unicode-width = "0.2.0" diff --git a/src/ldap.rs b/src/ldap.rs index f854e12..b4599d3 100644 --- a/src/ldap.rs +++ b/src/ldap.rs @@ -1,5 +1,9 @@ use ldap3::{LdapConnAsync, SearchEntry}; use russh::keys::PublicKey; +use tokio::select; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use tracing::{debug, warn}; #[derive(Debug, Clone)] pub struct Ldap { @@ -20,7 +24,9 @@ pub enum LdapError { } impl Ldap { - pub async fn start_from_env() -> Result { + pub async fn start_from_env( + token: CancellationToken, + ) -> Result<(Ldap, JoinHandle<()>), LdapError> { let address = std::env::var("LDAP_ADDRESS") .map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?; let base = std::env::var("LDAP_BASE") @@ -41,11 +47,24 @@ impl Ldap { )?; let (conn, mut ldap) = LdapConnAsync::new(&address).await?; - ldap3::drive!(conn); + let handle = tokio::spawn(async move { + select! { + res = conn.drive() => { + if let Err(err) = res { + warn!("LDAP connection error: {}", err); + } else { + debug!("LDAP drive has stopped, this should not happen?"); + } + } + _ = token.cancelled() => { + debug!("Graceful shutdown"); + } + } + }); ldap.simple_bind(&bind_dn, &password).await?.success()?; - Ok(Self { base, ldap }) + Ok((Self { base, ldap }, handle)) } pub async fn get_ssh_keys( diff --git a/src/lib.rs b/src/lib.rs index f20a9a6..4727c70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![feature(let_chains)] #![feature(iter_intersperse)] +#![feature(future_join)] mod helper; mod io; pub mod ldap; diff --git a/src/main.rs b/src/main.rs index bd5bf16..fab6daa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ +#![feature(future_join)] +use std::future::join; use std::net::SocketAddr; use std::path::Path; +use std::time::Duration; use color_eyre::eyre::Context; use dotenvy::dotenv; -use hyper::server::conn::http1::{self}; -use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use siranga::VERSION; use siranga::ldap::Ldap; @@ -12,11 +13,22 @@ use siranga::ssh::Server; use siranga::tunnel::Registry; use siranga::web::{ForwardAuth, Service}; use tokio::net::TcpListener; +use tokio::select; +use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; use tracing_subscriber::EnvFilter; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; +async fn shutdown_task(token: CancellationToken) { + tokio::signal::ctrl_c() + .await + .expect("Failed to listen for ctrl-c event"); + info!("Starting graceful shutdown"); + token.cancel(); + tokio::time::sleep(Duration::from_secs(5)).await; +} + #[tokio::main] async fn main() -> color_eyre::Result<()> { color_eyre::install()?; @@ -59,35 +71,32 @@ async fn main() -> color_eyre::Result<()> { std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}")); let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?; - let ldap = Ldap::start_from_env().await?; let registry = Registry::new(domain); - let mut ssh = Server::new(ldap, registry.clone()); - let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); - tokio::spawn(async move { ssh.run(key, addr).await }); - info!("SSH is available on {addr}"); + + let token = CancellationToken::new(); + + let (ldap, ldap_handle) = Ldap::start_from_env(token.clone()).await?; + + let ssh = Server::new(ldap, registry.clone(), token.clone()); + let ssh_addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); + let ssh_task = ssh.run(key, ssh_addr); + info!("SSH is available on {ssh_addr}"); let auth = ForwardAuth::new(authz_address); let service = Service::new(registry, auth); - let addr = SocketAddr::from(([0, 0, 0, 0], http_port)); - let listener = TcpListener::bind(addr).await?; - info!("HTTP is available on {addr}"); + let http_addr = SocketAddr::from(([0, 0, 0, 0], http_port)); + let http_listener = TcpListener::bind(http_addr).await?; + let http_task = service.serve(http_listener, token.clone()); + info!("HTTP is available on {http_addr}"); - // TODO: Graceful shutdown - loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); + select! { + _ = join!(ldap_handle, ssh_task, http_task) => { + info!("Shutdown gracefully"); + } + _ = shutdown_task(token.clone()) => { + error!("Failed to shut down gracefully"); + } + }; - let service = service.clone(); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .serve_connection(io, service) - .with_upgrades() - .await - { - error!("Failed to serve connection: {err:?}"); - } - }); - } + Ok(()) } diff --git a/src/ssh/handler.rs b/src/ssh/handler.rs index 1a36dca..0ed3515 100644 --- a/src/ssh/handler.rs +++ b/src/ssh/handler.rs @@ -8,8 +8,10 @@ use ratatui::{Terminal, TerminalOptions, Viewport}; use russh::ChannelId; use russh::keys::ssh_key::PublicKey; use russh::server::{Auth, Msg, Session}; +use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; +use super::renderer::Renderer; use crate::VERSION; use crate::io::{Input, TerminalHandle}; use crate::ldap::{Ldap, LdapError}; @@ -62,7 +64,7 @@ pub struct Handler { } impl Handler { - pub fn new(ldap: Ldap, registry: Registry) -> Self { + pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self { Self { ldap, registry, @@ -70,7 +72,7 @@ impl Handler { user: None, pty_channel: None, - renderer: Default::default(), + renderer: Renderer::new(token), selected: None, rename_input: None, } diff --git a/src/ssh/mod.rs b/src/ssh/mod.rs index 6b22bd7..01ab52a 100644 --- a/src/ssh/mod.rs +++ b/src/ssh/mod.rs @@ -11,7 +11,9 @@ use russh::MethodKind; use russh::keys::PrivateKey; use russh::server::Server as _; use tokio::net::ToSocketAddrs; -use tracing::{debug, warn}; +use tokio::select; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, warn}; use crate::ldap::Ldap; use crate::tunnel::Registry; @@ -19,18 +21,30 @@ use crate::tunnel::Registry; pub struct Server { ldap: Ldap, registry: Registry, + token: CancellationToken, +} + +async fn graceful_shutdown(token: CancellationToken) { + token.cancelled().await; + let duration = 1; + // All pty sessions will close once the token is cancelled, but to properly allow the sessions + // to close the ssh server still needs to be driven, so we let it run a little bit longer. + // TODO: Figure out a way to wait for all connections to be closed, would require also closing + // non-pty sessions somehow + debug!("Waiting for {duration}s before stopping"); + tokio::time::sleep(Duration::from_secs(duration)).await; } impl Server { - pub fn new(ldap: Ldap, registry: Registry) -> Self { - Server { ldap, registry } + pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self { + Server { + ldap, + registry, + token, + } } - pub fn run( - &mut self, - key: PrivateKey, - addr: impl ToSocketAddrs + Send + std::fmt::Debug, - ) -> impl Future> + Send { + pub async fn run(mut self, key: PrivateKey, addr: impl ToSocketAddrs + Send + std::fmt::Debug) { let config = russh::server::Config { inactivity_timeout: Some(Duration::from_secs(3600)), auth_rejection_time: Duration::from_secs(1), @@ -47,7 +61,17 @@ impl Server { debug!(?addr, "Running ssh"); - async move { self.run_on_address(config, addr).await } + let token = self.token.clone(); + select! { + res = self.run_on_address(config, addr) => { + if let Err(err) = res { + error!("SSH Server error: {err}"); + } + } + _ = graceful_shutdown(token) => { + debug!("Graceful shutdown"); + } + } } } @@ -55,7 +79,7 @@ impl russh::server::Server for Server { type Handler = Handler; fn new_client(&mut self, _peer_addr: Option) -> Self::Handler { - Handler::new(self.ldap.clone(), self.registry.clone()) + Handler::new(self.ldap.clone(), self.registry.clone(), self.token.clone()) } fn handle_session_error(&mut self, error: ::Error) { diff --git a/src/ssh/renderer.rs b/src/ssh/renderer.rs index ff97f24..52aec18 100644 --- a/src/ssh/renderer.rs +++ b/src/ssh/renderer.rs @@ -14,7 +14,8 @@ use ratatui::widgets::{ use ratatui::{Frame, Terminal}; use tokio::select; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::error; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error}; use unicode_width::UnicodeWidthStr; use crate::VERSION; @@ -36,6 +37,8 @@ struct RendererInner { rows: Vec, input: Option, rx: UnboundedReceiver, + + token: CancellationToken, } fn compute_widths(rows: &Vec>>) -> Vec { @@ -75,12 +78,13 @@ fn compute_column_skip( } impl RendererInner { - fn new(rx: UnboundedReceiver) -> Self { + fn new(rx: UnboundedReceiver, token: CancellationToken) -> Self { Self { state: Default::default(), rows: Default::default(), input: None, rx, + token, } } @@ -303,6 +307,10 @@ impl RendererInner { self.render(frame); })?; } + _ = self.token.cancelled() => { + debug!("Graceful shutdown"); + break; + } } } @@ -310,16 +318,24 @@ impl RendererInner { } } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] pub struct Renderer { tx: Option>, + token: CancellationToken, } impl Renderer { + pub fn new(token: CancellationToken) -> Self { + Self { + tx: Default::default(), + token, + } + } + pub fn start(&mut self, terminal: Terminal>) { let (tx, rx) = unbounded_channel(); - let mut inner = RendererInner::new(rx); + let mut inner = RendererInner::new(rx, self.token.clone()); tokio::spawn(async move { if let Err(err) = inner.start(terminal).await { diff --git a/src/web/mod.rs b/src/web/mod.rs index 58797a8..b96288b 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,6 +1,7 @@ mod auth; mod response; +use std::future::join; use std::ops::Deref; use std::pin::Pin; @@ -12,8 +13,14 @@ use http_body_util::{BodyExt as _, Empty}; use hyper::body::Incoming; use hyper::client::conn::http1::Builder; use hyper::header::{self, HOST}; +use hyper::server::conn::http1; use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use hyper_util::server::graceful::GracefulShutdown; use response::response; +use tokio::net::TcpListener; +use tokio::select; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, trace, warn}; use crate::tunnel::{Registry, TunnelAccess}; @@ -28,6 +35,49 @@ impl Service { pub fn new(registry: Registry, auth: ForwardAuth) -> Self { Self { registry, auth } } + + pub async fn handle_connection( + &self, + listener: &TcpListener, + graceful_shutdown: &GracefulShutdown, + ) -> std::io::Result<()> { + let (stream, _) = listener.accept().await?; + + let io = TokioIo::new(stream); + let connection = http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .serve_connection(io, self.clone()); + + let connection = graceful_shutdown.watch(connection); + + tokio::spawn(async move { + if let Err(err) = connection.await { + error!("Failed to serve connection: {err:?}"); + } + }); + + Ok(()) + } + + pub async fn serve(self, listener: TcpListener, token: CancellationToken) { + let graceful_shutdown = GracefulShutdown::new(); + loop { + select! { + res = self.handle_connection(&listener, &graceful_shutdown) => { + if let Err(err) = res { + error!("Failed to accept connection: {err}") + } + } + _ = token.cancelled() => { + debug!("Graceful shutdown"); + break; + } + } + } + + graceful_shutdown.shutdown().await; + } } impl hyper::service::Service> for Service { @@ -135,14 +185,15 @@ impl hyper::service::Service> for Service { .handshake(io) .await?; - tokio::spawn(async move { + let conn = async { if let Err(err) = conn.await { warn!(runnel = authority, "Connection failed: {err}"); } - }); + }; - let resp = sender.send_request(req).await?; - Ok(resp.map(|b| b.boxed())) + let (resp, _) = join!(sender.send_request(req), conn).await; + + Ok(resp?.map(|b| b.boxed())) }) } }