From 3ada40d4ae2cf05458ca9baa3c2cd948b7e481cc Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Wed, 16 Apr 2025 01:53:24 +0200 Subject: [PATCH] Massive refactor --- src/handler.rs | 40 ++--- src/lib.rs | 5 +- src/main.rs | 12 +- src/server.rs | 14 +- src/tunnel.rs | 325 +++++++++-------------------------------- src/tunnel/registry.rs | 261 +++++++++++++++++++++++++++++++++ src/tunnel/tui.rs | 12 +- 7 files changed, 363 insertions(+), 306 deletions(-) create mode 100644 src/tunnel/registry.rs diff --git a/src/handler.rs b/src/handler.rs index 03b8d2f..6a34f0d 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -15,7 +15,7 @@ use crate::{ io::TerminalHandle, ldap::LdapError, tui::Renderer, - tunnel::{Tunnel, TunnelAccess, Tunnels}, + tunnel::{Registry, Tunnel, TunnelAccess}, }; #[derive(Debug, thiserror::Error)] @@ -31,7 +31,7 @@ pub enum HandlerError { pub struct Handler { ldap: Ldap, - all_tunnels: Tunnels, + registry: Registry, tunnels: Vec, user: Option, @@ -45,10 +45,10 @@ pub struct Handler { } impl Handler { - pub fn new(ldap: Ldap, all_tunnels: Tunnels) -> Self { + pub fn new(ldap: Ldap, registry: Registry) -> Self { Self { ldap, - all_tunnels, + registry, tunnels: Default::default(), user: None, pty_channel: None, @@ -136,7 +136,7 @@ impl Handler { && let Some(tunnel) = self.tunnels.get_mut(selected) && let Some(buffer) = self.rename_buffer.take() { - *tunnel = self.all_tunnels.rename_tunnel(tunnel.clone(), buffer).await; + tunnel.set_name(buffer).await; } else { warn!("Trying to rename invalid tunnel"); } @@ -177,7 +177,7 @@ impl Handler { return Ok(false); }; - *tunnel = self.all_tunnels.retry_tunnel(tunnel.clone()).await; + tunnel.retry().await; } Input::Char('r') => { if self.selected.is_some() { @@ -195,8 +195,7 @@ impl Handler { return Ok(false); } - let tunnel = self.tunnels.remove(selected); - self.all_tunnels.remove_tunnel(tunnel).await; + self.tunnels.remove(selected); if self.tunnels.is_empty() { self.selected = None; @@ -359,10 +358,14 @@ impl russh::server::Handler for Handler { return Err(russh::Error::Inconsistent.into()); }; - let tunnel = self - .all_tunnels - .create_tunnel(session.handle(), address, *port, user) - .await; + let tunnel = Tunnel::create( + &mut self.registry, + session.handle(), + address, + *port, + TunnelAccess::Private(user), + ) + .await; self.tunnels.push(tunnel); @@ -421,16 +424,3 @@ impl russh::server::Handler for Handler { Ok(()) } } - -impl Drop for Handler { - fn drop(&mut self) { - let tunnels = self.tunnels.clone(); - let mut all_tunnels = self.all_tunnels.clone(); - - tokio::spawn(async move { - for tunnel in tunnels { - all_tunnels.remove_tunnel(tunnel).await; - } - }); - } -} diff --git a/src/lib.rs b/src/lib.rs index 13073e5..81599ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![feature(let_chains)] mod animals; -mod auth; +pub mod auth; mod cli; mod handler; mod helper; @@ -16,4 +16,5 @@ mod wrapper; pub use ldap::Ldap; pub use server::Server; -pub use tunnel::{Tunnel, Tunnels}; +pub use tunnel::Registry; +pub use tunnel::Tunnel; diff --git a/src/main.rs b/src/main.rs index 9ef2142..65ca9b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,8 +7,8 @@ use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use tokio::net::TcpListener; use tracing::{error, info, warn}; -use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; -use tunnel_rs::{Ldap, Server, Tunnels}; +use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; +use tunnel_rs::{Ldap, Registry, Server, auth::ForwardAuth}; #[tokio::main] async fn main() -> color_eyre::Result<()> { @@ -18,7 +18,10 @@ async fn main() -> color_eyre::Result<()> { let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?; let logger = tracing_subscriber::fmt::layer().compact(); - Registry::default().with(logger).with(env_filter).init(); + tracing_subscriber::Registry::default() + .with(logger) + .with(env_filter) + .init(); let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") { russh::keys::PrivateKey::read_openssh_file(Path::new(&path)) @@ -41,7 +44,8 @@ async fn main() -> color_eyre::Result<()> { let ldap = Ldap::start_from_env().await?; - let tunnels = Tunnels::new(domain, authz_address); + let auth = ForwardAuth::new(authz_address); + let tunnels = Registry::new(domain, auth); let mut ssh = Server::new(ldap, tunnels.clone()); let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); tokio::spawn(async move { ssh.run(key, addr).await }); diff --git a/src/server.rs b/src/server.rs index eee3c8f..c5deaca 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,20 +4,16 @@ use russh::{MethodKind, keys::PrivateKey, server::Server as _}; use tokio::net::ToSocketAddrs; use tracing::{debug, warn}; -use crate::{Ldap, handler::Handler, tunnel::Tunnels}; +use crate::{Ldap, handler::Handler, tunnel::Registry}; pub struct Server { ldap: Ldap, - tunnels: Tunnels, + registry: Registry, } impl Server { - pub fn new(ldap: Ldap, tunnels: Tunnels) -> Self { - Server { ldap, tunnels } - } - - pub fn tunnels(&self) -> Tunnels { - self.tunnels.clone() + pub fn new(ldap: Ldap, registry: Registry) -> Self { + Server { ldap, registry } } pub fn run( @@ -49,7 +45,7 @@ impl russh::server::Server for Server { type Handler = Handler; fn new_client(&mut self, _peer_addr: Option) -> Self::Handler { - Handler::new(self.ldap.clone(), self.tunnels.clone()) + Handler::new(self.ldap.clone(), self.registry.clone()) } fn handle_session_error(&mut self, error: ::Error) { diff --git a/src/tunnel.rs b/src/tunnel.rs index 5113478..f782e80 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,32 +1,15 @@ -use bytes::Bytes; -use http_body_util::{BodyExt, Empty, combinators::BoxBody}; -use hyper::{ - Request, Response, StatusCode, - body::Incoming, - client::conn::http1::Builder, - header::{self, HOST}, - service::Service, -}; -use std::{ - collections::{HashMap, hash_map::Entry}, - ops::Deref, - pin::Pin, - sync::Arc, -}; -use tracing::{debug, error, trace, warn}; +use registry::RegistryEntry; +use std::sync::Arc; +use tracing::trace; use russh::server::Handle; use tokio::sync::RwLock; -use crate::{ - animals::get_animal_name, - auth::{AuthStatus, ForwardAuth}, - helper::response, - stats::Stats, - wrapper::Wrapper, -}; +use crate::{stats::Stats, wrapper::Wrapper}; +mod registry; pub mod tui; +pub use registry::Registry; #[derive(Debug, Clone)] pub enum TunnelAccess { @@ -36,262 +19,84 @@ pub enum TunnelAccess { } #[derive(Debug, Clone)] -pub struct Tunnel { +pub struct TunnelInner { handle: Handle, - name: String, - address: String, - domain: Option, + internal_address: String, port: u32, access: Arc>, stats: Arc, } -impl Tunnel { - pub async fn open_tunnel(&self) -> Result { - trace!(tunnel = self.name, "Opening tunnel"); +impl TunnelInner { + pub async fn open(&self) -> Result { + trace!("Opening tunnel"); self.stats.add_connection(); let channel = self .handle - .channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port) + .channel_open_forwarded_tcpip( + &self.internal_address, + self.port, + &self.internal_address, + self.port, + ) .await?; Ok(Wrapper::new(channel.into_stream(), self.stats.clone())) } +} + +#[derive(Debug)] +pub struct Tunnel { + inner: TunnelInner, + + registry: Registry, + registry_entry: RegistryEntry, +} + +impl Tunnel { + pub async fn create( + registry: &mut Registry, + handle: Handle, + internal_address: impl Into, + port: u32, + access: TunnelAccess, + ) -> Self { + let mut tunnel = Self { + inner: TunnelInner { + handle, + internal_address: internal_address.into(), + port, + access: Arc::new(RwLock::new(access)), + stats: Default::default(), + }, + registry: registry.clone(), + registry_entry: RegistryEntry::new(registry.clone()), + }; + + registry.register(&mut tunnel).await; + + tunnel + } pub async fn set_access(&self, access: TunnelAccess) { - *self.access.write().await = access; + *self.inner.access.write().await = access; } pub async fn is_public(&self) -> bool { - matches!(*self.access.read().await, TunnelAccess::Public) + matches!(*self.inner.access.read().await, TunnelAccess::Public) } - pub fn get_address(&self) -> Option { - self.domain - .clone() - .map(|domain| format!("{}.{domain}", self.name)) - } -} - -#[derive(Debug, Clone)] -pub struct Tunnels { - tunnels: Arc>>, - domain: String, - forward_auth: ForwardAuth, -} - -impl Tunnels { - pub fn new(domain: impl Into, endpoint: impl Into) -> Self { - Self { - tunnels: Arc::new(RwLock::new(HashMap::new())), - domain: domain.into(), - forward_auth: ForwardAuth::new(endpoint), - } - } - - async fn generate_tunnel_name(&mut self, mut tunnel: Tunnel) -> Tunnel { - // NOTE: It is technically possible to become stuck in this loop. - // However, that really only becomes a concern if a (very) high - // number of tunnels is open at the same time. - tunnel.domain = Some(self.domain.clone()); - loop { - tunnel.name = get_animal_name().into(); - if !self - .tunnels - .read() - .await - .contains_key(&tunnel.get_address().expect("domain is set")) - { - break; - } - trace!(tunnel = tunnel.name, "Already in use, picking new name"); - } - - tunnel - } - - pub async fn create_tunnel( - &mut self, - handle: Handle, - name: impl Into, - port: u32, - user: impl Into, - ) -> Tunnel { - let address = name.into(); - let mut tunnel = Tunnel { - handle, - name: address.clone(), - address, - domain: Some(self.domain.clone()), - port, - access: Arc::new(RwLock::new(TunnelAccess::Private(user.into()))), - stats: Default::default(), - }; - - if tunnel.name == "localhost" { - tunnel = self.generate_tunnel_name(tunnel).await; - }; - - self.add_tunnel(tunnel).await - } - - async fn add_tunnel(&mut self, mut tunnel: Tunnel) -> Tunnel { - let address = tunnel.get_address().expect("domain is set"); - if let Entry::Vacant(e) = self.tunnels.write().await.entry(address) { - trace!(tunnel = tunnel.name, "Adding tunnel"); - e.insert(tunnel.clone()); - } else { - trace!("Address already in use"); - tunnel.domain = None - } - - tunnel - } - - pub async fn remove_tunnel(&mut self, mut tunnel: Tunnel) -> Tunnel { - let mut all_tunnels = self.tunnels.write().await; - if let Some(address) = tunnel.get_address() { - trace!(tunnel.name, "Removing tunnel"); - all_tunnels.remove(&address); - } - tunnel.domain = None; - tunnel - } - - pub async fn retry_tunnel(&mut self, tunnel: Tunnel) -> Tunnel { - let mut tunnel = self.remove_tunnel(tunnel).await; - tunnel.domain = Some(self.domain.clone()); - - self.add_tunnel(tunnel).await - } - - pub async fn rename_tunnel(&mut self, tunnel: Tunnel, name: impl Into) -> Tunnel { - let mut tunnel = self.remove_tunnel(tunnel).await; - let name: String = name.into(); - if name.is_empty() { - tunnel = self.generate_tunnel_name(tunnel).await; - } else { - tunnel.domain = Some(self.domain.clone()); - tunnel.name = name; - } - - self.add_tunnel(tunnel).await - } -} - -impl Service> for Tunnels { - type Response = Response>; - type Error = hyper::Error; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - trace!("{:#?}", req); - - let Some(authority) = req - .uri() - .authority() - .as_ref() - .map(|a| a.to_string()) - .or_else(|| { - req.headers() - .get(HOST) - .and_then(|h| h.to_str().ok().map(|s| s.to_owned())) - }) - else { - let resp = response( - StatusCode::BAD_REQUEST, - "Missing or invalid authority or host header", - ); - - return Box::pin(async { Ok(resp) }); - }; - - debug!(tunnel = authority, "Tunnel request"); - - let s = self.clone(); - Box::pin(async move { - let tunnels = s.tunnels.read().await; - let Some(tunnel) = tunnels.get(&authority) else { - debug!(tunnel = authority, "Unknown tunnel"); - let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); - - return Ok(resp); - }; - - if !matches!(tunnel.access.read().await.deref(), TunnelAccess::Public) { - let user = match s.forward_auth.check_auth(req.method(), req.headers()).await { - Ok(AuthStatus::Authenticated(user)) => user, - Ok(AuthStatus::Unauthenticated(location)) => { - let resp = Response::builder() - .status(StatusCode::FOUND) - .header(header::LOCATION, location) - .body( - Empty::new() - // NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error - .map_err(|never| match never {}) - .boxed(), - ) - .expect("configuration should be valid"); - - return Ok(resp); - } - Ok(AuthStatus::Unauthorized) => { - let resp = response( - StatusCode::FORBIDDEN, - "You do not have permission to access this tunnel", - ); - - return Ok(resp); - } - Err(err) => { - error!("Unexpected error during authentication: {err}"); - let resp = response( - StatusCode::FORBIDDEN, - "Unexpected error during authentication", - ); - - return Ok(resp); - } - }; - - trace!("Tunnel is getting accessed by {user:?}"); - - if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() { - if !user.is(owner) { - let resp = response( - StatusCode::FORBIDDEN, - "You do not have permission to access this tunnel", - ); - - return Ok(resp); - } - } - } - - let io = match tunnel.open_tunnel().await { - Ok(io) => io, - Err(err) => { - warn!(tunnel = authority, "Failed to open tunnel: {err}"); - let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); - - return Ok(resp); - } - }; - - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - - tokio::spawn(async move { - if let Err(err) = conn.await { - warn!(runnel = authority, "Connection failed: {err}"); - } - }); - - let resp = sender.send_request(req).await?; - Ok(resp.map(|b| b.boxed())) - }) + pub fn get_address(&self) -> Option<&String> { + self.registry_entry.get_address() + } + + pub async fn set_name(&mut self, name: impl Into) { + let mut registry = self.registry.clone(); + registry.rename(self, name).await; + } + + pub async fn retry(&mut self) { + let mut registry = self.registry.clone(); + registry.register(self).await; } } diff --git a/src/tunnel/registry.rs b/src/tunnel/registry.rs new file mode 100644 index 0000000..e62b3fa --- /dev/null +++ b/src/tunnel/registry.rs @@ -0,0 +1,261 @@ +use std::{ + collections::{HashMap, hash_map::Entry}, + ops::Deref, + pin::Pin, + sync::Arc, +}; + +use bytes::Bytes; +use http_body_util::{BodyExt as _, Empty, combinators::BoxBody}; +use hyper::{ + Request, Response, StatusCode, + body::Incoming, + client::conn::http1::Builder, + header::{self, HOST}, + service::Service, +}; +use tokio::sync::RwLock; +use tracing::{debug, error, trace, warn}; + +use crate::{ + Tunnel, + animals::get_animal_name, + auth::{AuthStatus, ForwardAuth}, + helper::response, + tunnel::TunnelAccess, +}; + +use super::TunnelInner; + +#[derive(Debug)] +pub struct RegistryEntry { + registry: Registry, + name: String, + address: Option, +} + +impl RegistryEntry { + pub fn new(registry: Registry) -> Self { + Self { + registry, + name: Default::default(), + address: Default::default(), + } + } + + pub fn get_address(&self) -> Option<&String> { + self.address.as_ref() + } + + pub fn get_name(&self) -> &str { + &self.name + } +} + +impl Drop for RegistryEntry { + fn drop(&mut self) { + trace!( + name = self.name, + address = self.address, + "Dropping registry entry" + ); + + if let Some(address) = self.address.take() { + let registry = self.registry.clone(); + tokio::spawn(async move { + registry.tunnels.write().await.remove(&address); + }); + } + } +} + +#[derive(Debug, Clone)] +pub struct Registry { + tunnels: Arc>>, + domain: String, + auth: ForwardAuth, +} + +impl Registry { + pub fn new(domain: impl Into, auth: ForwardAuth) -> Self { + Self { + tunnels: Arc::new(RwLock::new(HashMap::new())), + domain: domain.into(), + auth, + } + } + + fn address(&self, name: impl AsRef) -> String { + format!("{}.{}", name.as_ref(), self.domain) + } + + async fn generate_tunnel_name(&self) -> String { + // NOTE: It is technically possible to become stuck in this loop. + // However, that really only becomes a concern if a (very) high + // number of tunnels is open at the same time. + loop { + let name = get_animal_name(); + if !self.tunnels.read().await.contains_key(&self.address(name)) { + break name.into(); + } + trace!(name, "Already in use, picking new name"); + } + } + + pub(super) async fn register(&mut self, tunnel: &mut Tunnel) { + if tunnel.registry_entry.name.is_empty() { + if tunnel.inner.internal_address == "localhost" { + tunnel.registry_entry.name = self.generate_tunnel_name().await; + } else { + tunnel.registry_entry.name = tunnel.inner.internal_address.clone(); + } + } + + trace!( + name = tunnel.registry_entry.name, + "Attempting to register tunnel" + ); + + if tunnel.registry_entry.address.is_some() { + trace!(name = tunnel.registry_entry.name, "Already registered"); + return; + } + + let address = self.address(&tunnel.registry_entry.name); + + if let Entry::Vacant(e) = self.tunnels.write().await.entry(address.clone()) { + tunnel.registry_entry.address = Some(address); + e.insert(tunnel.inner.clone()); + } else { + trace!(name = tunnel.registry_entry.name, "Address already in use"); + tunnel.registry_entry.address = None; + } + } + + pub(super) async fn rename(&mut self, tunnel: &mut Tunnel, name: impl Into) { + trace!(name = tunnel.registry_entry.name, "Renaming tunnel"); + + if let Some(address) = tunnel.registry_entry.address.take() { + self.tunnels.write().await.remove(&address); + } + + tunnel.registry_entry.name = name.into(); + self.register(tunnel).await; + } +} + +impl Service> for Registry { + type Response = Response>; + type Error = hyper::Error; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + trace!("{:#?}", req); + + let Some(authority) = req + .uri() + .authority() + .as_ref() + .map(|a| a.to_string()) + .or_else(|| { + req.headers() + .get(HOST) + .and_then(|h| h.to_str().ok().map(|s| s.to_owned())) + }) + else { + let resp = response( + StatusCode::BAD_REQUEST, + "Missing or invalid authority or host header", + ); + + return Box::pin(async { Ok(resp) }); + }; + + debug!(authority, "Tunnel request"); + + let s = self.clone(); + Box::pin(async move { + let Some(entry) = s.tunnels.read().await.get(&authority).cloned() else { + debug!(tunnel = authority, "Unknown tunnel"); + let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); + + return Ok(resp); + }; + + if !matches!(entry.access.read().await.deref(), TunnelAccess::Public) { + let user = match s.auth.check_auth(req.method(), req.headers()).await { + Ok(AuthStatus::Authenticated(user)) => user, + Ok(AuthStatus::Unauthenticated(location)) => { + let resp = Response::builder() + .status(StatusCode::FOUND) + .header(header::LOCATION, location) + .body( + Empty::new() + // NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error + .map_err(|never| match never {}) + .boxed(), + ) + .expect("configuration should be valid"); + + return Ok(resp); + } + Ok(AuthStatus::Unauthorized) => { + let resp = response( + StatusCode::FORBIDDEN, + "You do not have permission to access this tunnel", + ); + + return Ok(resp); + } + Err(err) => { + error!("Unexpected error during authentication: {err}"); + let resp = response( + StatusCode::FORBIDDEN, + "Unexpected error during authentication", + ); + + return Ok(resp); + } + }; + + trace!("Tunnel is getting accessed by {user:?}"); + + if let TunnelAccess::Private(owner) = entry.access.read().await.deref() { + if !user.is(owner) { + let resp = response( + StatusCode::FORBIDDEN, + "You do not have permission to access this tunnel", + ); + + return Ok(resp); + } + } + } + + let io = match entry.open().await { + Ok(io) => io, + Err(err) => { + warn!(tunnel = authority, "Failed to open tunnel: {err}"); + let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); + + return Ok(resp); + } + }; + + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!(runnel = authority, "Connection failed: {err}"); + } + }); + + let resp = sender.send_request(req).await?; + Ok(resp.map(|b| b.boxed())) + }) + } +} diff --git a/src/tunnel/tui.rs b/src/tunnel/tui.rs index 6c60806..a3771c3 100644 --- a/src/tunnel/tui.rs +++ b/src/tunnel/tui.rs @@ -18,7 +18,7 @@ pub fn header() -> Vec> { } pub async fn to_row(tunnel: &Tunnel) -> Vec> { - let access = match tunnel.access.read().await.deref() { + let access = match tunnel.inner.access.read().await.deref() { TunnelAccess::Private(owner) => owner.clone().yellow(), TunnelAccess::Protected => "PROTECTED".blue(), TunnelAccess::Public => "PUBLIC".green(), @@ -30,12 +30,12 @@ pub async fn to_row(tunnel: &Tunnel) -> Vec> { .unwrap_or("FAILED".red()); vec![ - tunnel.name.clone().into(), + tunnel.registry_entry.get_name().to_owned().into(), access, - tunnel.port.to_string().into(), + tunnel.inner.port.to_string().into(), address, - tunnel.stats.connections().to_string().into(), - tunnel.stats.rx().to_string().into(), - tunnel.stats.tx().to_string().into(), + tunnel.inner.stats.connections().to_string().into(), + tunnel.inner.stats.rx().to_string().into(), + tunnel.inner.stats.tx().to_string().into(), ] }