Implemented more graceful shutdown
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 7m27s

This commit is contained in:
Dreaded_X 2025-04-20 00:11:40 +02:00
parent 49fd6d8a3a
commit 7851d6bb12
Signed by: Dreaded_X
GPG Key ID: 5A0CBFE3C3377FAA
9 changed files with 184 additions and 52 deletions

13
Cargo.lock generated
View File

@ -1107,6 +1107,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.2"
@ -1473,7 +1479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.15.2",
]
[[package]]
@ -1682,7 +1688,7 @@ version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
"hashbrown",
"hashbrown 0.15.2",
]
[[package]]
@ -2875,6 +2881,7 @@ dependencies = [
"russh",
"thiserror 2.0.12",
"tokio",
"tokio-util",
"tracing",
"tracing-subscriber",
"unicode-width 0.2.0",
@ -3205,6 +3212,8 @@ dependencies = [
"bytes",
"futures-core",
"futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite",
"tokio",
]

View File

@ -24,6 +24,7 @@ reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1"
thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] }
tokio-util = { version = "0.7.14", features = ["rt"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
unicode-width = "0.2.0"

View File

@ -1,5 +1,9 @@
use ldap3::{LdapConnAsync, SearchEntry};
use russh::keys::PublicKey;
use tokio::select;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
#[derive(Debug, Clone)]
pub struct Ldap {
@ -20,7 +24,9 @@ pub enum LdapError {
}
impl Ldap {
pub async fn start_from_env() -> Result<Ldap, LdapError> {
pub async fn start_from_env(
token: CancellationToken,
) -> Result<(Ldap, JoinHandle<()>), LdapError> {
let address = std::env::var("LDAP_ADDRESS")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?;
let base = std::env::var("LDAP_BASE")
@ -41,11 +47,24 @@ impl Ldap {
)?;
let (conn, mut ldap) = LdapConnAsync::new(&address).await?;
ldap3::drive!(conn);
let handle = tokio::spawn(async move {
select! {
res = conn.drive() => {
if let Err(err) = res {
warn!("LDAP connection error: {}", err);
} else {
debug!("LDAP drive has stopped, this should not happen?");
}
}
_ = token.cancelled() => {
debug!("Graceful shutdown");
}
}
});
ldap.simple_bind(&bind_dn, &password).await?.success()?;
Ok(Self { base, ldap })
Ok((Self { base, ldap }, handle))
}
pub async fn get_ssh_keys(

View File

@ -1,5 +1,6 @@
#![feature(let_chains)]
#![feature(iter_intersperse)]
#![feature(future_join)]
mod helper;
mod io;
pub mod ldap;

View File

@ -1,10 +1,11 @@
#![feature(future_join)]
use std::future::join;
use std::net::SocketAddr;
use std::path::Path;
use std::time::Duration;
use color_eyre::eyre::Context;
use dotenvy::dotenv;
use hyper::server::conn::http1::{self};
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use siranga::VERSION;
use siranga::ldap::Ldap;
@ -12,11 +13,22 @@ use siranga::ssh::Server;
use siranga::tunnel::Registry;
use siranga::web::{ForwardAuth, Service};
use tokio::net::TcpListener;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
async fn shutdown_task(token: CancellationToken) {
tokio::signal::ctrl_c()
.await
.expect("Failed to listen for ctrl-c event");
info!("Starting graceful shutdown");
token.cancel();
tokio::time::sleep(Duration::from_secs(5)).await;
}
#[tokio::main]
async fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
@ -59,35 +71,32 @@ async fn main() -> color_eyre::Result<()> {
std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}"));
let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?;
let ldap = Ldap::start_from_env().await?;
let registry = Registry::new(domain);
let mut ssh = Server::new(ldap, registry.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port));
tokio::spawn(async move { ssh.run(key, addr).await });
info!("SSH is available on {addr}");
let token = CancellationToken::new();
let (ldap, ldap_handle) = Ldap::start_from_env(token.clone()).await?;
let ssh = Server::new(ldap, registry.clone(), token.clone());
let ssh_addr = SocketAddr::from(([0, 0, 0, 0], ssh_port));
let ssh_task = ssh.run(key, ssh_addr);
info!("SSH is available on {ssh_addr}");
let auth = ForwardAuth::new(authz_address);
let service = Service::new(registry, auth);
let addr = SocketAddr::from(([0, 0, 0, 0], http_port));
let listener = TcpListener::bind(addr).await?;
info!("HTTP is available on {addr}");
let http_addr = SocketAddr::from(([0, 0, 0, 0], http_port));
let http_listener = TcpListener::bind(http_addr).await?;
let http_task = service.serve(http_listener, token.clone());
info!("HTTP is available on {http_addr}");
// TODO: Graceful shutdown
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
select! {
_ = join!(ldap_handle, ssh_task, http_task) => {
info!("Shutdown gracefully");
}
_ = shutdown_task(token.clone()) => {
error!("Failed to shut down gracefully");
}
};
let service = service.clone();
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, service)
.with_upgrades()
.await
{
error!("Failed to serve connection: {err:?}");
}
});
}
Ok(())
}

View File

@ -8,8 +8,10 @@ use ratatui::{Terminal, TerminalOptions, Viewport};
use russh::ChannelId;
use russh::keys::ssh_key::PublicKey;
use russh::server::{Auth, Msg, Session};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
use super::renderer::Renderer;
use crate::VERSION;
use crate::io::{Input, TerminalHandle};
use crate::ldap::{Ldap, LdapError};
@ -62,7 +64,7 @@ pub struct Handler {
}
impl Handler {
pub fn new(ldap: Ldap, registry: Registry) -> Self {
pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Self {
ldap,
registry,
@ -70,7 +72,7 @@ impl Handler {
user: None,
pty_channel: None,
renderer: Default::default(),
renderer: Renderer::new(token),
selected: None,
rename_input: None,
}

View File

@ -11,7 +11,9 @@ use russh::MethodKind;
use russh::keys::PrivateKey;
use russh::server::Server as _;
use tokio::net::ToSocketAddrs;
use tracing::{debug, warn};
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::ldap::Ldap;
use crate::tunnel::Registry;
@ -19,18 +21,30 @@ use crate::tunnel::Registry;
pub struct Server {
ldap: Ldap,
registry: Registry,
token: CancellationToken,
}
async fn graceful_shutdown(token: CancellationToken) {
token.cancelled().await;
let duration = 1;
// All pty sessions will close once the token is cancelled, but to properly allow the sessions
// to close the ssh server still needs to be driven, so we let it run a little bit longer.
// TODO: Figure out a way to wait for all connections to be closed, would require also closing
// non-pty sessions somehow
debug!("Waiting for {duration}s before stopping");
tokio::time::sleep(Duration::from_secs(duration)).await;
}
impl Server {
pub fn new(ldap: Ldap, registry: Registry) -> Self {
Server { ldap, registry }
pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Server {
ldap,
registry,
token,
}
}
pub fn run(
&mut self,
key: PrivateKey,
addr: impl ToSocketAddrs + Send + std::fmt::Debug,
) -> impl Future<Output = Result<(), std::io::Error>> + Send {
pub async fn run(mut self, key: PrivateKey, addr: impl ToSocketAddrs + Send + std::fmt::Debug) {
let config = russh::server::Config {
inactivity_timeout: Some(Duration::from_secs(3600)),
auth_rejection_time: Duration::from_secs(1),
@ -47,7 +61,17 @@ impl Server {
debug!(?addr, "Running ssh");
async move { self.run_on_address(config, addr).await }
let token = self.token.clone();
select! {
res = self.run_on_address(config, addr) => {
if let Err(err) = res {
error!("SSH Server error: {err}");
}
}
_ = graceful_shutdown(token) => {
debug!("Graceful shutdown");
}
}
}
}
@ -55,7 +79,7 @@ impl russh::server::Server for Server {
type Handler = Handler;
fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
Handler::new(self.ldap.clone(), self.registry.clone())
Handler::new(self.ldap.clone(), self.registry.clone(), self.token.clone())
}
fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) {

View File

@ -14,7 +14,8 @@ use ratatui::widgets::{
use ratatui::{Frame, Terminal};
use tokio::select;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tracing::error;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use unicode_width::UnicodeWidthStr;
use crate::VERSION;
@ -36,6 +37,8 @@ struct RendererInner {
rows: Vec<TunnelRow>,
input: Option<String>,
rx: UnboundedReceiver<Message>,
token: CancellationToken,
}
fn compute_widths(rows: &Vec<Vec<Span<'static>>>) -> Vec<u16> {
@ -75,12 +78,13 @@ fn compute_column_skip(
}
impl RendererInner {
fn new(rx: UnboundedReceiver<Message>) -> Self {
fn new(rx: UnboundedReceiver<Message>, token: CancellationToken) -> Self {
Self {
state: Default::default(),
rows: Default::default(),
input: None,
rx,
token,
}
}
@ -303,6 +307,10 @@ impl RendererInner {
self.render(frame);
})?;
}
_ = self.token.cancelled() => {
debug!("Graceful shutdown");
break;
}
}
}
@ -310,16 +318,24 @@ impl RendererInner {
}
}
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub struct Renderer {
tx: Option<UnboundedSender<Message>>,
token: CancellationToken,
}
impl Renderer {
pub fn new(token: CancellationToken) -> Self {
Self {
tx: Default::default(),
token,
}
}
pub fn start(&mut self, terminal: Terminal<CrosstermBackend<TerminalHandle>>) {
let (tx, rx) = unbounded_channel();
let mut inner = RendererInner::new(rx);
let mut inner = RendererInner::new(rx, self.token.clone());
tokio::spawn(async move {
if let Err(err) = inner.start(terminal).await {

View File

@ -1,6 +1,7 @@
mod auth;
mod response;
use std::future::join;
use std::ops::Deref;
use std::pin::Pin;
@ -12,8 +13,14 @@ use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming;
use hyper::client::conn::http1::Builder;
use hyper::header::{self, HOST};
use hyper::server::conn::http1;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use hyper_util::server::graceful::GracefulShutdown;
use response::response;
use tokio::net::TcpListener;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, trace, warn};
use crate::tunnel::{Registry, TunnelAccess};
@ -28,6 +35,49 @@ impl Service {
pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
Self { registry, auth }
}
pub async fn handle_connection(
&self,
listener: &TcpListener,
graceful_shutdown: &GracefulShutdown,
) -> std::io::Result<()> {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let connection = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, self.clone());
let connection = graceful_shutdown.watch(connection);
tokio::spawn(async move {
if let Err(err) = connection.await {
error!("Failed to serve connection: {err:?}");
}
});
Ok(())
}
pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
let graceful_shutdown = GracefulShutdown::new();
loop {
select! {
res = self.handle_connection(&listener, &graceful_shutdown) => {
if let Err(err) = res {
error!("Failed to accept connection: {err}")
}
}
_ = token.cancelled() => {
debug!("Graceful shutdown");
break;
}
}
}
graceful_shutdown.shutdown().await;
}
}
impl hyper::service::Service<Request<Incoming>> for Service {
@ -135,14 +185,15 @@ impl hyper::service::Service<Request<Incoming>> for Service {
.handshake(io)
.await?;
tokio::spawn(async move {
let conn = async {
if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}");
}
});
};
let resp = sender.send_request(req).await?;
Ok(resp.map(|b| b.boxed()))
let (resp, _) = join!(sender.send_request(req), conn).await;
Ok(resp?.map(|b| b.boxed()))
})
}
}