Implemented more graceful shutdown
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 7m27s

This commit is contained in:
Dreaded_X 2025-04-20 00:11:40 +02:00
parent 49fd6d8a3a
commit 7851d6bb12
Signed by: Dreaded_X
GPG Key ID: 5A0CBFE3C3377FAA
9 changed files with 184 additions and 52 deletions

13
Cargo.lock generated
View File

@ -1107,6 +1107,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.2" version = "0.15.2"
@ -1473,7 +1479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown 0.15.2",
] ]
[[package]] [[package]]
@ -1682,7 +1688,7 @@ version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [ dependencies = [
"hashbrown", "hashbrown 0.15.2",
] ]
[[package]] [[package]]
@ -2875,6 +2881,7 @@ dependencies = [
"russh", "russh",
"thiserror 2.0.12", "thiserror 2.0.12",
"tokio", "tokio",
"tokio-util",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"unicode-width 0.2.0", "unicode-width 0.2.0",
@ -3205,6 +3212,8 @@ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]

View File

@ -24,6 +24,7 @@ reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1" russh = "0.51.1"
thiserror = "2.0.12" thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] } tokio = { version = "1.44.2", features = ["full"] }
tokio-util = { version = "0.7.14", features = ["rt"] }
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
unicode-width = "0.2.0" unicode-width = "0.2.0"

View File

@ -1,5 +1,9 @@
use ldap3::{LdapConnAsync, SearchEntry}; use ldap3::{LdapConnAsync, SearchEntry};
use russh::keys::PublicKey; use russh::keys::PublicKey;
use tokio::select;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Ldap { pub struct Ldap {
@ -20,7 +24,9 @@ pub enum LdapError {
} }
impl Ldap { impl Ldap {
pub async fn start_from_env() -> Result<Ldap, LdapError> { pub async fn start_from_env(
token: CancellationToken,
) -> Result<(Ldap, JoinHandle<()>), LdapError> {
let address = std::env::var("LDAP_ADDRESS") let address = std::env::var("LDAP_ADDRESS")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?; .map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?;
let base = std::env::var("LDAP_BASE") let base = std::env::var("LDAP_BASE")
@ -41,11 +47,24 @@ impl Ldap {
)?; )?;
let (conn, mut ldap) = LdapConnAsync::new(&address).await?; let (conn, mut ldap) = LdapConnAsync::new(&address).await?;
ldap3::drive!(conn); let handle = tokio::spawn(async move {
select! {
res = conn.drive() => {
if let Err(err) = res {
warn!("LDAP connection error: {}", err);
} else {
debug!("LDAP drive has stopped, this should not happen?");
}
}
_ = token.cancelled() => {
debug!("Graceful shutdown");
}
}
});
ldap.simple_bind(&bind_dn, &password).await?.success()?; ldap.simple_bind(&bind_dn, &password).await?.success()?;
Ok(Self { base, ldap }) Ok((Self { base, ldap }, handle))
} }
pub async fn get_ssh_keys( pub async fn get_ssh_keys(

View File

@ -1,5 +1,6 @@
#![feature(let_chains)] #![feature(let_chains)]
#![feature(iter_intersperse)] #![feature(iter_intersperse)]
#![feature(future_join)]
mod helper; mod helper;
mod io; mod io;
pub mod ldap; pub mod ldap;

View File

@ -1,10 +1,11 @@
#![feature(future_join)]
use std::future::join;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::Path;
use std::time::Duration;
use color_eyre::eyre::Context; use color_eyre::eyre::Context;
use dotenvy::dotenv; use dotenvy::dotenv;
use hyper::server::conn::http1::{self};
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use siranga::VERSION; use siranga::VERSION;
use siranga::ldap::Ldap; use siranga::ldap::Ldap;
@ -12,11 +13,22 @@ use siranga::ssh::Server;
use siranga::tunnel::Registry; use siranga::tunnel::Registry;
use siranga::web::{ForwardAuth, Service}; use siranga::web::{ForwardAuth, Service};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
async fn shutdown_task(token: CancellationToken) {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen for ctrl-c event");
info!("Starting graceful shutdown");
token.cancel();
tokio::time::sleep(Duration::from_secs(5)).await;
}
#[tokio::main] #[tokio::main]
async fn main() -> color_eyre::Result<()> { async fn main() -> color_eyre::Result<()> {
color_eyre::install()?; color_eyre::install()?;
@ -59,35 +71,32 @@ async fn main() -> color_eyre::Result<()> {
std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}")); std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}"));
let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?; let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?;
let ldap = Ldap::start_from_env().await?;
let registry = Registry::new(domain); let registry = Registry::new(domain);
let mut ssh = Server::new(ldap, registry.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); let token = CancellationToken::new();
tokio::spawn(async move { ssh.run(key, addr).await });
info!("SSH is available on {addr}"); let (ldap, ldap_handle) = Ldap::start_from_env(token.clone()).await?;
let ssh = Server::new(ldap, registry.clone(), token.clone());
let ssh_addr = SocketAddr::from(([0, 0, 0, 0], ssh_port));
let ssh_task = ssh.run(key, ssh_addr);
info!("SSH is available on {ssh_addr}");
let auth = ForwardAuth::new(authz_address); let auth = ForwardAuth::new(authz_address);
let service = Service::new(registry, auth); let service = Service::new(registry, auth);
let addr = SocketAddr::from(([0, 0, 0, 0], http_port)); let http_addr = SocketAddr::from(([0, 0, 0, 0], http_port));
let listener = TcpListener::bind(addr).await?; let http_listener = TcpListener::bind(http_addr).await?;
info!("HTTP is available on {addr}"); let http_task = service.serve(http_listener, token.clone());
info!("HTTP is available on {http_addr}");
// TODO: Graceful shutdown select! {
loop { _ = join!(ldap_handle, ssh_task, http_task) => {
let (stream, _) = listener.accept().await?; info!("Shutdown gracefully");
let io = TokioIo::new(stream); }
_ = shutdown_task(token.clone()) => {
error!("Failed to shut down gracefully");
}
};
let service = service.clone(); Ok(())
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, service)
.with_upgrades()
.await
{
error!("Failed to serve connection: {err:?}");
}
});
}
} }

View File

@ -8,8 +8,10 @@ use ratatui::{Terminal, TerminalOptions, Viewport};
use russh::ChannelId; use russh::ChannelId;
use russh::keys::ssh_key::PublicKey; use russh::keys::ssh_key::PublicKey;
use russh::server::{Auth, Msg, Session}; use russh::server::{Auth, Msg, Session};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use super::renderer::Renderer;
use crate::VERSION; use crate::VERSION;
use crate::io::{Input, TerminalHandle}; use crate::io::{Input, TerminalHandle};
use crate::ldap::{Ldap, LdapError}; use crate::ldap::{Ldap, LdapError};
@ -62,7 +64,7 @@ pub struct Handler {
} }
impl Handler { impl Handler {
pub fn new(ldap: Ldap, registry: Registry) -> Self { pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Self { Self {
ldap, ldap,
registry, registry,
@ -70,7 +72,7 @@ impl Handler {
user: None, user: None,
pty_channel: None, pty_channel: None,
renderer: Default::default(), renderer: Renderer::new(token),
selected: None, selected: None,
rename_input: None, rename_input: None,
} }

View File

@ -11,7 +11,9 @@ use russh::MethodKind;
use russh::keys::PrivateKey; use russh::keys::PrivateKey;
use russh::server::Server as _; use russh::server::Server as _;
use tokio::net::ToSocketAddrs; use tokio::net::ToSocketAddrs;
use tracing::{debug, warn}; use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::ldap::Ldap; use crate::ldap::Ldap;
use crate::tunnel::Registry; use crate::tunnel::Registry;
@ -19,18 +21,30 @@ use crate::tunnel::Registry;
pub struct Server { pub struct Server {
ldap: Ldap, ldap: Ldap,
registry: Registry, registry: Registry,
token: CancellationToken,
}
async fn graceful_shutdown(token: CancellationToken) {
token.cancelled().await;
let duration = 1;
// All pty sessions will close once the token is cancelled, but to properly allow the sessions
// to close the ssh server still needs to be driven, so we let it run a little bit longer.
// TODO: Figure out a way to wait for all connections to be closed, would require also closing
// non-pty sessions somehow
debug!("Waiting for {duration}s before stopping");
tokio::time::sleep(Duration::from_secs(duration)).await;
} }
impl Server { impl Server {
pub fn new(ldap: Ldap, registry: Registry) -> Self { pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Server { ldap, registry } Server {
ldap,
registry,
token,
}
} }
pub fn run( pub async fn run(mut self, key: PrivateKey, addr: impl ToSocketAddrs + Send + std::fmt::Debug) {
&mut self,
key: PrivateKey,
addr: impl ToSocketAddrs + Send + std::fmt::Debug,
) -> impl Future<Output = Result<(), std::io::Error>> + Send {
let config = russh::server::Config { let config = russh::server::Config {
inactivity_timeout: Some(Duration::from_secs(3600)), inactivity_timeout: Some(Duration::from_secs(3600)),
auth_rejection_time: Duration::from_secs(1), auth_rejection_time: Duration::from_secs(1),
@ -47,7 +61,17 @@ impl Server {
debug!(?addr, "Running ssh"); debug!(?addr, "Running ssh");
async move { self.run_on_address(config, addr).await } let token = self.token.clone();
select! {
res = self.run_on_address(config, addr) => {
if let Err(err) = res {
error!("SSH Server error: {err}");
}
}
_ = graceful_shutdown(token) => {
debug!("Graceful shutdown");
}
}
} }
} }
@ -55,7 +79,7 @@ impl russh::server::Server for Server {
type Handler = Handler; type Handler = Handler;
fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler { fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
Handler::new(self.ldap.clone(), self.registry.clone()) Handler::new(self.ldap.clone(), self.registry.clone(), self.token.clone())
} }
fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) { fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) {

View File

@ -14,7 +14,8 @@ use ratatui::widgets::{
use ratatui::{Frame, Terminal}; use ratatui::{Frame, Terminal};
use tokio::select; use tokio::select;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tracing::error; use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use unicode_width::UnicodeWidthStr; use unicode_width::UnicodeWidthStr;
use crate::VERSION; use crate::VERSION;
@ -36,6 +37,8 @@ struct RendererInner {
rows: Vec<TunnelRow>, rows: Vec<TunnelRow>,
input: Option<String>, input: Option<String>,
rx: UnboundedReceiver<Message>, rx: UnboundedReceiver<Message>,
token: CancellationToken,
} }
fn compute_widths(rows: &Vec<Vec<Span<'static>>>) -> Vec<u16> { fn compute_widths(rows: &Vec<Vec<Span<'static>>>) -> Vec<u16> {
@ -75,12 +78,13 @@ fn compute_column_skip(
} }
impl RendererInner { impl RendererInner {
fn new(rx: UnboundedReceiver<Message>) -> Self { fn new(rx: UnboundedReceiver<Message>, token: CancellationToken) -> Self {
Self { Self {
state: Default::default(), state: Default::default(),
rows: Default::default(), rows: Default::default(),
input: None, input: None,
rx, rx,
token,
} }
} }
@ -303,6 +307,10 @@ impl RendererInner {
self.render(frame); self.render(frame);
})?; })?;
} }
_ = self.token.cancelled() => {
debug!("Graceful shutdown");
break;
}
} }
} }
@ -310,16 +318,24 @@ impl RendererInner {
} }
} }
#[derive(Debug, Default, Clone)] #[derive(Debug, Clone)]
pub struct Renderer { pub struct Renderer {
tx: Option<UnboundedSender<Message>>, tx: Option<UnboundedSender<Message>>,
token: CancellationToken,
} }
impl Renderer { impl Renderer {
pub fn new(token: CancellationToken) -> Self {
Self {
tx: Default::default(),
token,
}
}
pub fn start(&mut self, terminal: Terminal<CrosstermBackend<TerminalHandle>>) { pub fn start(&mut self, terminal: Terminal<CrosstermBackend<TerminalHandle>>) {
let (tx, rx) = unbounded_channel(); let (tx, rx) = unbounded_channel();
let mut inner = RendererInner::new(rx); let mut inner = RendererInner::new(rx, self.token.clone());
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = inner.start(terminal).await { if let Err(err) = inner.start(terminal).await {

View File

@ -1,6 +1,7 @@
mod auth; mod auth;
mod response; mod response;
use std::future::join;
use std::ops::Deref; use std::ops::Deref;
use std::pin::Pin; use std::pin::Pin;
@ -12,8 +13,14 @@ use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::client::conn::http1::Builder; use hyper::client::conn::http1::Builder;
use hyper::header::{self, HOST}; use hyper::header::{self, HOST};
use hyper::server::conn::http1;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use hyper_util::server::graceful::GracefulShutdown;
use response::response; use response::response;
use tokio::net::TcpListener;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
use crate::tunnel::{Registry, TunnelAccess}; use crate::tunnel::{Registry, TunnelAccess};
@ -28,6 +35,49 @@ impl Service {
pub fn new(registry: Registry, auth: ForwardAuth) -> Self { pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
Self { registry, auth } Self { registry, auth }
} }
pub async fn handle_connection(
&self,
listener: &TcpListener,
graceful_shutdown: &GracefulShutdown,
) -> std::io::Result<()> {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let connection = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, self.clone());
let connection = graceful_shutdown.watch(connection);
tokio::spawn(async move {
if let Err(err) = connection.await {
error!("Failed to serve connection: {err:?}");
}
});
Ok(())
}
pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
let graceful_shutdown = GracefulShutdown::new();
loop {
select! {
res = self.handle_connection(&listener, &graceful_shutdown) => {
if let Err(err) = res {
error!("Failed to accept connection: {err}")
}
}
_ = token.cancelled() => {
debug!("Graceful shutdown");
break;
}
}
}
graceful_shutdown.shutdown().await;
}
} }
impl hyper::service::Service<Request<Incoming>> for Service { impl hyper::service::Service<Request<Incoming>> for Service {
@ -135,14 +185,15 @@ impl hyper::service::Service<Request<Incoming>> for Service {
.handshake(io) .handshake(io)
.await?; .await?;
tokio::spawn(async move { let conn = async {
if let Err(err) = conn.await { if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}"); warn!(runnel = authority, "Connection failed: {err}");
} }
}); };
let resp = sender.send_request(req).await?; let (resp, _) = join!(sender.send_request(req), conn).await;
Ok(resp.map(|b| b.boxed()))
Ok(resp?.map(|b| b.boxed()))
}) })
} }
} }