14 Commits

Author SHA1 Message Date
5a7652f3a4 Make ldap search filter configurable
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m8s
kustomization/siranga/3850ce12 reconciliation succeeded
2025-04-22 00:42:56 +02:00
95ad229077 Get bind_dn from secret 2025-04-22 00:42:33 +02:00
e9673211c1 Added liveness probe
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 5m57s
2025-04-21 03:22:34 +02:00
f0bf60c78a Use named container ports 2025-04-21 02:36:59 +02:00
8cafe2b3ca Added support for upgrade requests
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m9s
2025-04-21 02:21:22 +02:00
ed7770f792 Fixed spelling of shutdown during forceful shutdown 2025-04-21 02:21:22 +02:00
dc1f75aee3 Close any remaining connections once the tui exits 2025-04-21 02:21:22 +02:00
0fe043acb5 Revert "Use store instead of fetch_add for atomics"
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m1s
This reverts commit d4bd0ef1ca.
2025-04-21 02:21:18 +02:00
878df8da40 Start graceful shutdown on SIGTERM
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 5m32s
2025-04-20 00:58:18 +02:00
27f6119905 Second ctrl-c forces application to stop directly
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 5m21s
2025-04-20 00:26:23 +02:00
c7b0cfc888 Gracefully shutdown if LDAP connection is lost 2025-04-20 00:24:32 +02:00
7851d6bb12 Implemented more graceful shutdown
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 7m27s
2025-04-20 00:14:24 +02:00
49fd6d8a3a Added suggestion to enable quiet mode for ssh client 2025-04-20 00:14:23 +02:00
ca742fe332 Updated authelia acl helper 2025-04-20 00:14:20 +02:00
12 changed files with 497 additions and 71 deletions

120
Cargo.lock generated
View File

@@ -173,6 +173,60 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "axum"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de45108900e1f9b9242f7f2e254aa3e2c029c921c258fe9e6b4217eeebd54288"
dependencies = [
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.71" version = "0.3.71"
@@ -1107,6 +1161,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.2" version = "0.15.2"
@@ -1473,7 +1533,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown 0.15.2",
] ]
[[package]] [[package]]
@@ -1624,6 +1684,16 @@ dependencies = [
"url", "url",
] ]
[[package]]
name = "leon"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42a865ffec5587961f5afc6d365bccb304f4feaa1928f4fe94c91c9d210d7310"
dependencies = [
"miette",
"thiserror 2.0.12",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.172" version = "0.2.172"
@@ -1682,7 +1752,7 @@ version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [ dependencies = [
"hashbrown", "hashbrown 0.15.2",
] ]
[[package]] [[package]]
@@ -1694,6 +1764,12 @@ dependencies = [
"regex-automata 0.1.10", "regex-automata 0.1.10",
] ]
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]] [[package]]
name = "md5" name = "md5"
version = "0.7.0" version = "0.7.0"
@@ -1706,6 +1782,29 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "miette"
version = "7.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a955165f87b37fd1862df2a59547ac542c77ef6d17c666f619d1ad22dd89484"
dependencies = [
"cfg-if",
"miette-derive",
"thiserror 1.0.69",
"unicode-width 0.1.14",
]
[[package]]
name = "miette-derive"
version = "7.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf45bf44ab49be92fd1227a3be6fc6f617f1a337c06af54981048574d8783147"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.17" version = "0.3.17"
@@ -2763,6 +2862,16 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
@@ -2856,6 +2965,7 @@ dependencies = [
name = "siranga" name = "siranga"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"axum",
"bytes", "bytes",
"clap", "clap",
"clio", "clio",
@@ -2868,6 +2978,7 @@ dependencies = [
"hyper", "hyper",
"hyper-util", "hyper-util",
"ldap3", "ldap3",
"leon",
"pin-project-lite", "pin-project-lite",
"rand 0.8.5", "rand 0.8.5",
"ratatui", "ratatui",
@@ -2875,6 +2986,7 @@ dependencies = [
"russh", "russh",
"thiserror 2.0.12", "thiserror 2.0.12",
"tokio", "tokio",
"tokio-util",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"unicode-width 0.2.0", "unicode-width 0.2.0",
@@ -3205,6 +3317,8 @@ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"futures-util",
"hashbrown 0.14.5",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]
@@ -3222,6 +3336,7 @@ dependencies = [
"tokio", "tokio",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
@@ -3242,6 +3357,7 @@ version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [ dependencies = [
"log",
"pin-project-lite", "pin-project-lite",
"tracing-attributes", "tracing-attributes",
"tracing-core", "tracing-core",

View File

@@ -5,6 +5,7 @@ default-run = "siranga"
license = "AGPL-3.0-only" license = "AGPL-3.0-only"
[dependencies] [dependencies]
axum = "0.8.3"
bytes = "1.10.1" bytes = "1.10.1"
clap = { version = "4.5.35", features = ["derive"] } clap = { version = "4.5.35", features = ["derive"] }
clio = { version = "0.3.5", features = ["clap-parse"] } clio = { version = "0.3.5", features = ["clap-parse"] }
@@ -17,6 +18,7 @@ http-body-util = { version = "0.1.3", features = ["full"] }
hyper = { version = "1.6.0", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] }
hyper-util = { version = "0.1.11", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] }
ldap3 = "0.11.5" ldap3 = "0.11.5"
leon = "3.0.2"
pin-project-lite = "0.2.16" pin-project-lite = "0.2.16"
rand = "0.8.5" rand = "0.8.5"
ratatui = { version = "0.29.0", features = ["unstable-backend-writer"] } ratatui = { version = "0.29.0", features = ["unstable-backend-writer"] }
@@ -24,6 +26,7 @@ reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1" russh = "0.51.1"
thiserror = "2.0.12" thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] } tokio = { version = "1.44.2", features = ["full"] }
tokio-util = { version = "0.7.14", features = ["rt"] }
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
unicode-width = "0.2.0" unicode-width = "0.2.0"

View File

@@ -31,8 +31,12 @@ spec:
cpu: 50m cpu: 50m
memory: 100Mi memory: 100Mi
ports: ports:
- containerPort: 3000 - name: ssh
- containerPort: 2222 containerPort: 2222
- name: http
containerPort: 3000
- name: metrics
containerPort: 4000
volumeMounts: volumeMounts:
- name: credentials - name: credentials
readOnly: true readOnly: true
@@ -51,12 +55,23 @@ spec:
value: ldap://lldap.lldap.svc.cluster.local:3890 value: ldap://lldap.lldap.svc.cluster.local:3890
- name: LDAP_BASE - name: LDAP_BASE
value: ou=people,dc=huizinga,dc=dev value: ou=people,dc=huizinga,dc=dev
- name: LDAP_SEARCH_FILTER
value: (uid={username})
- name: LDAP_BIND_DN - name: LDAP_BIND_DN
value: uid=siranga.siranga,ou=people,dc=huizinga,dc=dev valueFrom:
secretKeyRef:
name: siranga-lldap-credentials
key: bind_dn
- name: LDAP_PASSWORD_FILE - name: LDAP_PASSWORD_FILE
value: /secrets/credentials/password value: /secrets/credentials/password
- name: PRIVATE_KEY_FILE - name: PRIVATE_KEY_FILE
value: /secrets/key/private.pem value: /secrets/key/private.pem
livenessProbe:
httpGet:
path: /health
port: metrics
initialDelaySeconds: 3
periodSeconds: 3
volumes: volumes:
- name: credentials - name: credentials
secret: secret:

View File

@@ -6,6 +6,7 @@ spec:
ports: ports:
- name: http - name: http
port: 3000 port: 3000
targetPort: http
selector: selector:
app: siranga app: siranga
--- ---
@@ -20,6 +21,6 @@ spec:
ports: ports:
- name: ssh - name: ssh
port: 22 port: 22
targetPort: 2222 targetPort: ssh
selector: selector:
app: siranga app: siranga

View File

@@ -19,15 +19,15 @@ pub struct Stats {
impl Stats { impl Stats {
pub fn add_connection(&self) { pub fn add_connection(&self) {
self.connections.store(1, Ordering::Relaxed); self.connections.fetch_add(1, Ordering::Relaxed);
} }
pub fn add_rx_bytes(&self, n: usize) { pub fn add_rx_bytes(&self, n: usize) {
self.rx.store(n, Ordering::Relaxed); self.rx.fetch_add(n, Ordering::Relaxed);
} }
pub fn add_tx_bytes(&self, n: usize) { pub fn add_tx_bytes(&self, n: usize) {
self.tx.store(n, Ordering::Relaxed); self.tx.fetch_add(n, Ordering::Relaxed);
} }
pub fn connections(&self) -> usize { pub fn connections(&self) -> usize {

View File

@@ -1,10 +1,16 @@
use ldap3::{LdapConnAsync, SearchEntry}; use ldap3::{LdapConnAsync, SearchEntry};
use leon::{Template, vals};
use russh::keys::PublicKey; use russh::keys::PublicKey;
use tokio::select;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Ldap { pub struct Ldap {
base: String, base: String,
ldap: ldap3::Ldap, ldap: ldap3::Ldap,
search_filter: String,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -17,16 +23,24 @@ pub enum LdapError {
MissingEnvironmentVariable(&'static str), MissingEnvironmentVariable(&'static str),
#[error("Could not read password file: {0}")] #[error("Could not read password file: {0}")]
CouldNotReadPasswordFile(#[from] std::io::Error), CouldNotReadPasswordFile(#[from] std::io::Error),
#[error("Failed to parse search filter: {0}")]
FailedToParseSearchFilter(#[from] leon::ParseError),
#[error("Failed to render search filter: {0}")]
FailedToRenderSearchFilter(#[from] leon::RenderError),
} }
impl Ldap { impl Ldap {
pub async fn start_from_env() -> Result<Ldap, LdapError> { pub async fn start_from_env(
token: CancellationToken,
) -> Result<(Ldap, JoinHandle<()>), LdapError> {
let address = std::env::var("LDAP_ADDRESS") let address = std::env::var("LDAP_ADDRESS")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?; .map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_ADDRESS"))?;
let base = std::env::var("LDAP_BASE") let base = std::env::var("LDAP_BASE")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_BASE"))?; .map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_BASE"))?;
let bind_dn = std::env::var("LDAP_BIND_DN") let bind_dn = std::env::var("LDAP_BIND_DN")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_BIND_DN"))?; .map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_BIND_DN"))?;
let search_filter = std::env::var("LDAP_SEARCH_FILTER")
.map_err(|_| LdapError::MissingEnvironmentVariable("LDAP_SEARCH_FILTER"))?;
let password = std::env::var("LDAP_PASSWORD_FILE").map_or_else( let password = std::env::var("LDAP_PASSWORD_FILE").map_or_else(
|_| { |_| {
@@ -41,24 +55,57 @@ impl Ldap {
)?; )?;
let (conn, mut ldap) = LdapConnAsync::new(&address).await?; let (conn, mut ldap) = LdapConnAsync::new(&address).await?;
ldap3::drive!(conn); let handle = tokio::spawn(async move {
select! {
res = conn.drive() => {
if let Err(err) = res {
error!("LDAP connection error: {}", err);
} else {
error!("LDAP connection lost");
token.cancel();
}
}
_ = token.cancelled() => {
debug!("Graceful shutdown");
}
}
});
ldap.simple_bind(&bind_dn, &password).await?.success()?; ldap.simple_bind(&bind_dn, &password).await?.success()?;
Ok(Self { base, ldap }) Ok((
Self {
base,
ldap,
search_filter,
},
handle,
))
} }
pub async fn get_ssh_keys( pub async fn get_ssh_keys(
&mut self, &mut self,
user: impl AsRef<str>, user: impl AsRef<str>,
) -> Result<Vec<PublicKey>, LdapError> { ) -> Result<Vec<PublicKey>, LdapError> {
let search_filter = Template::parse(&self.search_filter)?;
let search_filter = search_filter.render(&&vals(|key| {
if key == "username" {
Some(user.as_ref().to_string().into())
} else {
None
}
}))?;
debug!("search_filter = {search_filter}");
Ok(self Ok(self
.ldap .ldap
.search( .search(
&self.base, &self.base,
ldap3::Scope::Subtree, ldap3::Scope::Subtree,
// TODO: Make this not hardcoded // TODO: Make this not hardcoded
&format!("(uid={})", user.as_ref()), &search_filter,
vec!["sshkeys"], vec!["sshkeys"],
) )
.await? .await?

View File

@@ -1,5 +1,6 @@
#![feature(let_chains)] #![feature(let_chains)]
#![feature(iter_intersperse)] #![feature(iter_intersperse)]
#![feature(future_join)]
mod helper; mod helper;
mod io; mod io;
pub mod ldap; pub mod ldap;

View File

@@ -1,10 +1,13 @@
#![feature(future_join)]
use std::future::join;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::Path;
use std::time::Duration;
use axum::routing::get;
use axum::{Json, Router};
use color_eyre::eyre::Context; use color_eyre::eyre::Context;
use dotenvy::dotenv; use dotenvy::dotenv;
use hyper::server::conn::http1::{self};
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use siranga::VERSION; use siranga::VERSION;
use siranga::ldap::Ldap; use siranga::ldap::Ldap;
@@ -12,11 +15,51 @@ use siranga::ssh::Server;
use siranga::tunnel::Registry; use siranga::tunnel::Registry;
use siranga::web::{ForwardAuth, Service}; use siranga::web::{ForwardAuth, Service};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{error, info, warn}; use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
#[cfg(unix)]
async fn sigterm() {
use tokio::signal::unix::SignalKind;
let mut sigterm =
tokio::signal::unix::signal(SignalKind::terminate()).expect("should be able to initialize");
sigterm.recv().await;
}
#[cfg(not(unix))]
async fn sigterm() {
std::future::pending::<()>().await;
}
async fn shutdown_task(token: CancellationToken) {
select! {
_ = tokio::signal::ctrl_c() => {
debug!("Received SIGINT");
}
_ = sigterm() => {
debug!("Received SIGTERM");
}
_ = token.cancelled() => {
debug!("Application called for graceful shutdown");
}
}
info!("Starting graceful shutdown");
token.cancel();
select! {
_ = tokio::time::sleep(Duration::from_secs(5)) => {}
_ = tokio::signal::ctrl_c() => {}
}
}
async fn axum_graceful_shutdown(token: CancellationToken) {
token.cancelled().await;
}
#[tokio::main] #[tokio::main]
async fn main() -> color_eyre::Result<()> { async fn main() -> color_eyre::Result<()> {
color_eyre::install()?; color_eyre::install()?;
@@ -48,46 +91,56 @@ async fn main() -> color_eyre::Result<()> {
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)? russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?
}; };
let http_port = std::env::var("HTTP_PORT")
.map(|port| port.parse().wrap_err_with(|| format!("HTTP_PORT={port}")))
.unwrap_or(Ok(3000))?;
let ssh_port = std::env::var("SSH_PORT") let ssh_port = std::env::var("SSH_PORT")
.map(|port| port.parse().wrap_err_with(|| format!("SSH_PORT={port}"))) .map(|port| port.parse().wrap_err_with(|| format!("SSH_PORT={port}")))
.unwrap_or(Ok(2222))?; .unwrap_or(Ok(2222))?;
let http_port = std::env::var("HTTP_PORT")
.map(|port| port.parse().wrap_err_with(|| format!("HTTP_PORT={port}")))
.unwrap_or(Ok(3000))?;
let metrics_port = std::env::var("METRICS_PORT")
.map(|port| {
port.parse()
.wrap_err_with(|| format!("METRICS_PORT={port}"))
})
.unwrap_or(Ok(4000))?;
let domain = let domain =
std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}")); std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{http_port}"));
let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?; let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?;
let ldap = Ldap::start_from_env().await?;
let registry = Registry::new(domain); let registry = Registry::new(domain);
let mut ssh = Server::new(ldap, registry.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); let token = CancellationToken::new();
tokio::spawn(async move { ssh.run(key, addr).await });
info!("SSH is available on {addr}"); let (ldap, ldap_handle) = Ldap::start_from_env(token.clone()).await?;
let ssh = Server::new(ldap, registry.clone(), token.clone());
let ssh_addr = SocketAddr::from(([0, 0, 0, 0], ssh_port));
let ssh_task = ssh.run(key, ssh_addr);
info!("SSH is available on {ssh_addr}");
let auth = ForwardAuth::new(authz_address); let auth = ForwardAuth::new(authz_address);
let service = Service::new(registry, auth); let service = Service::new(registry, auth);
let addr = SocketAddr::from(([0, 0, 0, 0], http_port)); let http_addr = SocketAddr::from(([0, 0, 0, 0], http_port));
let listener = TcpListener::bind(addr).await?; let http_listener = TcpListener::bind(http_addr).await?;
info!("HTTP is available on {addr}"); let http_task = service.serve(http_listener, token.clone());
info!("HTTP is available on {http_addr}");
// TODO: Graceful shutdown let metrics_app = Router::new().route("/health", get(async || Json("healthy")));
loop { let metrics_addr = SocketAddr::from(([0, 0, 0, 0], metrics_port));
let (stream, _) = listener.accept().await?; let metrics_listener = TcpListener::bind(metrics_addr).await?;
let io = TokioIo::new(stream); let metrics = axum::serve(metrics_listener, metrics_app)
.with_graceful_shutdown(axum_graceful_shutdown(token.clone()));
info!("Metrics are available on {http_addr}");
let service = service.clone(); select! {
tokio::spawn(async move { _ = join!(ldap_handle, ssh_task, http_task, metrics.into_future()) => {
if let Err(err) = http1::Builder::new() info!("Shutdown gracefully");
.preserve_header_case(true) }
.title_case_headers(true) _ = shutdown_task(token.clone()) => {
.serve_connection(io, service) error!("Failed to shutdown gracefully");
.with_upgrades() }
.await };
{
error!("Failed to serve connection: {err:?}"); Ok(())
}
});
}
} }

View File

@@ -8,8 +8,10 @@ use ratatui::{Terminal, TerminalOptions, Viewport};
use russh::ChannelId; use russh::ChannelId;
use russh::keys::ssh_key::PublicKey; use russh::keys::ssh_key::PublicKey;
use russh::server::{Auth, Msg, Session}; use russh::server::{Auth, Msg, Session};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use super::renderer::Renderer;
use crate::VERSION; use crate::VERSION;
use crate::io::{Input, TerminalHandle}; use crate::io::{Input, TerminalHandle};
use crate::ldap::{Ldap, LdapError}; use crate::ldap::{Ldap, LdapError};
@@ -62,7 +64,7 @@ pub struct Handler {
} }
impl Handler { impl Handler {
pub fn new(ldap: Ldap, registry: Registry) -> Self { pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Self { Self {
ldap, ldap,
registry, registry,
@@ -70,7 +72,7 @@ impl Handler {
user: None, user: None,
pty_channel: None, pty_channel: None,
renderer: Default::default(), renderer: Renderer::new(token),
selected: None, selected: None,
rename_input: None, rename_input: None,
} }
@@ -328,6 +330,26 @@ impl russh::server::Handler for Handler {
Ok(session.channel_success(channel)?) Ok(session.channel_success(channel)?)
} }
async fn channel_close(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(pty_channel) = self.pty_channel
&& pty_channel == channel
{
debug!("Pty channel closed");
session.disconnect(
russh::Disconnect::ByApplication,
"Remaining active connections have been closed",
"EN",
)?;
}
Ok(())
}
async fn tcpip_forward( async fn tcpip_forward(
&mut self, &mut self,
address: &str, address: &str,

View File

@@ -11,7 +11,9 @@ use russh::MethodKind;
use russh::keys::PrivateKey; use russh::keys::PrivateKey;
use russh::server::Server as _; use russh::server::Server as _;
use tokio::net::ToSocketAddrs; use tokio::net::ToSocketAddrs;
use tracing::{debug, warn}; use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::ldap::Ldap; use crate::ldap::Ldap;
use crate::tunnel::Registry; use crate::tunnel::Registry;
@@ -19,18 +21,30 @@ use crate::tunnel::Registry;
pub struct Server { pub struct Server {
ldap: Ldap, ldap: Ldap,
registry: Registry, registry: Registry,
token: CancellationToken,
}
async fn graceful_shutdown(token: CancellationToken) {
token.cancelled().await;
let duration = 1;
// All pty sessions will close once the token is cancelled, but to properly allow the sessions
// to close the ssh server still needs to be driven, so we let it run a little bit longer.
// TODO: Figure out a way to wait for all connections to be closed, would require also closing
// non-pty sessions somehow
debug!("Waiting for {duration}s before stopping");
tokio::time::sleep(Duration::from_secs(duration)).await;
} }
impl Server { impl Server {
pub fn new(ldap: Ldap, registry: Registry) -> Self { pub fn new(ldap: Ldap, registry: Registry, token: CancellationToken) -> Self {
Server { ldap, registry } Server {
ldap,
registry,
token,
}
} }
pub fn run( pub async fn run(mut self, key: PrivateKey, addr: impl ToSocketAddrs + Send + std::fmt::Debug) {
&mut self,
key: PrivateKey,
addr: impl ToSocketAddrs + Send + std::fmt::Debug,
) -> impl Future<Output = Result<(), std::io::Error>> + Send {
let config = russh::server::Config { let config = russh::server::Config {
inactivity_timeout: Some(Duration::from_secs(3600)), inactivity_timeout: Some(Duration::from_secs(3600)),
auth_rejection_time: Duration::from_secs(1), auth_rejection_time: Duration::from_secs(1),
@@ -47,7 +61,17 @@ impl Server {
debug!(?addr, "Running ssh"); debug!(?addr, "Running ssh");
async move { self.run_on_address(config, addr).await } let token = self.token.clone();
select! {
res = self.run_on_address(config, addr) => {
if let Err(err) = res {
error!("SSH Server error: {err}");
}
}
_ = graceful_shutdown(token) => {
debug!("Graceful shutdown");
}
}
} }
} }
@@ -55,7 +79,7 @@ impl russh::server::Server for Server {
type Handler = Handler; type Handler = Handler;
fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler { fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
Handler::new(self.ldap.clone(), self.registry.clone()) Handler::new(self.ldap.clone(), self.registry.clone(), self.token.clone())
} }
fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) { fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) {

View File

@@ -14,7 +14,8 @@ use ratatui::widgets::{
use ratatui::{Frame, Terminal}; use ratatui::{Frame, Terminal};
use tokio::select; use tokio::select;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tracing::error; use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use unicode_width::UnicodeWidthStr; use unicode_width::UnicodeWidthStr;
use crate::VERSION; use crate::VERSION;
@@ -36,6 +37,8 @@ struct RendererInner {
rows: Vec<TunnelRow>, rows: Vec<TunnelRow>,
input: Option<String>, input: Option<String>,
rx: UnboundedReceiver<Message>, rx: UnboundedReceiver<Message>,
token: CancellationToken,
} }
fn compute_widths(rows: &Vec<Vec<Span<'static>>>) -> Vec<u16> { fn compute_widths(rows: &Vec<Vec<Span<'static>>>) -> Vec<u16> {
@@ -75,12 +78,13 @@ fn compute_column_skip(
} }
impl RendererInner { impl RendererInner {
fn new(rx: UnboundedReceiver<Message>) -> Self { fn new(rx: UnboundedReceiver<Message>, token: CancellationToken) -> Self {
Self { Self {
state: Default::default(), state: Default::default(),
rows: Default::default(), rows: Default::default(),
input: None, input: None,
rx, rx,
token,
} }
} }
@@ -303,6 +307,10 @@ impl RendererInner {
self.render(frame); self.render(frame);
})?; })?;
} }
_ = self.token.cancelled() => {
debug!("Graceful shutdown");
break;
}
} }
} }
@@ -310,16 +318,24 @@ impl RendererInner {
} }
} }
#[derive(Debug, Default, Clone)] #[derive(Debug, Clone)]
pub struct Renderer { pub struct Renderer {
tx: Option<UnboundedSender<Message>>, tx: Option<UnboundedSender<Message>>,
token: CancellationToken,
} }
impl Renderer { impl Renderer {
pub fn new(token: CancellationToken) -> Self {
Self {
tx: Default::default(),
token,
}
}
pub fn start(&mut self, terminal: Terminal<CrosstermBackend<TerminalHandle>>) { pub fn start(&mut self, terminal: Terminal<CrosstermBackend<TerminalHandle>>) {
let (tx, rx) = unbounded_channel(); let (tx, rx) = unbounded_channel();
let mut inner = RendererInner::new(rx); let mut inner = RendererInner::new(rx, self.token.clone());
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = inner.start(terminal).await { if let Err(err) = inner.start(terminal).await {

View File

@@ -10,10 +10,14 @@ use bytes::Bytes;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt as _, Empty}; use http_body_util::{BodyExt as _, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::client::conn::http1::Builder; use hyper::header::{self, HOST, UPGRADE};
use hyper::header::{self, HOST}; use hyper::{Request, Response, StatusCode, client, server};
use hyper::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo;
use response::response; use response::response;
use tokio::net::TcpListener;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
use crate::tunnel::{Registry, TunnelAccess}; use crate::tunnel::{Registry, TunnelAccess};
@@ -22,11 +26,83 @@ use crate::tunnel::{Registry, TunnelAccess};
pub struct Service { pub struct Service {
registry: Registry, registry: Registry,
auth: ForwardAuth, auth: ForwardAuth,
task_tracker: TaskTracker,
}
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
fn copy_request_parts<T>(req: Request<T>) -> (Request<T>, Request<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts.clone(), body);
let forwarded_req = Request::from_parts(parts, empty());
(req, forwarded_req)
}
fn copy_response_parts<T>(
resp: Response<T>,
) -> (Response<T>, Response<BoxBody<Bytes, hyper::Error>>) {
let (parts, body) = resp.into_parts();
let resp = Response::from_parts(parts.clone(), body);
let forwarded_resp = Response::from_parts(parts, empty());
(resp, forwarded_resp)
} }
impl Service { impl Service {
pub fn new(registry: Registry, auth: ForwardAuth) -> Self { pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
Self { registry, auth } Self {
registry,
auth,
task_tracker: Default::default(),
}
}
pub async fn handle_connection(&self, listener: &TcpListener) -> std::io::Result<()> {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let connection = server::conn::http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, self.clone())
.with_upgrades();
self.task_tracker.spawn(async move {
if let Err(err) = connection.await {
error!("Failed to serve connection: {err:?}");
}
});
Ok(())
}
pub async fn serve(self, listener: TcpListener, token: CancellationToken) {
loop {
select! {
res = self.handle_connection(&listener) => {
if let Err(err) = res {
error!("Failed to accept connection: {err}")
}
}
_ = token.cancelled() => {
break;
}
}
}
debug!(
"Waiting for {} connections to close",
self.task_tracker.len()
);
self.task_tracker.close();
self.task_tracker.wait().await;
debug!("Graceful shutdown");
} }
} }
@@ -59,10 +135,9 @@ impl hyper::service::Service<Request<Incoming>> for Service {
debug!(authority, "Tunnel request"); debug!(authority, "Tunnel request");
let registry = self.registry.clone(); let s = self.clone();
let auth = self.auth.clone();
Box::pin(async move { Box::pin(async move {
let Some(entry) = registry.get(&authority).await else { let Some(entry) = s.registry.get(&authority).await else {
debug!(tunnel = authority, "Unknown tunnel"); debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
@@ -70,7 +145,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
}; };
if !entry.is_public().await { if !entry.is_public().await {
let user = match auth.check(req.method(), req.headers()).await { let user = match s.auth.check(req.method(), req.headers()).await {
Ok(AuthStatus::Authenticated(user)) => user, Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => { Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder() let resp = Response::builder()
@@ -129,19 +204,72 @@ impl hyper::service::Service<Request<Incoming>> for Service {
} }
}; };
let (mut sender, conn) = Builder::new() let (mut sender, conn) = client::conn::http1::Builder::new()
.preserve_header_case(true) .preserve_header_case(true)
.title_case_headers(true) .title_case_headers(true)
.handshake(io) .handshake(io)
.await?; .await?;
tokio::spawn(async move { let conn = conn.with_upgrades();
s.task_tracker.spawn(async move {
if let Err(err) = conn.await { if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}"); warn!(runnel = authority, "Connection failed: {err}");
} }
}); });
let resp = sender.send_request(req).await?; let (mut req, forwarded_req) = copy_request_parts(req);
let resp = sender.send_request(forwarded_req).await?;
if req.headers().contains_key(UPGRADE)
&& req.headers().get(UPGRADE) == resp.headers().get(UPGRADE)
{
let (mut resp, forwarded_resp) = copy_response_parts(resp);
debug!("UPGRADE established");
match hyper::upgrade::on(&mut resp).await {
Ok(upgraded_resp) => {
s.task_tracker.spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded_req) => {
let mut upgraded_req = TokioIo::new(upgraded_req);
let mut upgraded_resp = TokioIo::new(upgraded_resp);
match tokio::io::copy_bidirectional(
&mut upgraded_req,
&mut upgraded_resp,
)
.await
{
Ok((rx, tx)) => {
debug!(
"Received {rx} bytes and send {tx} bytes over upgraded tunnel"
);
}
Err(err) => {
// Likely due to channel being closed
// TODO: Show warning if not channel closed, otherwise ignore
debug!("Upgraded connection error: {err:?}");
}
}
}
Err(err) => {
error!("Failed to upgrade: {err}");
}
}
});
return Ok(forwarded_resp.map(|b| b.boxed()));
}
Err(err) => {
error!("Failed to upgrade req: {err}");
return Ok(response(StatusCode::BAD_REQUEST, "Failed to upgrade"));
}
}
}
trace!("{resp:#?}");
Ok(resp.map(|b| b.boxed())) Ok(resp.map(|b| b.boxed()))
}) })
} }