From a8e3fd9d2a608e1648100dd5602e3e9554460c33 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Tue, 15 Apr 2025 16:20:42 +0200 Subject: [PATCH] Added tunnel stats --- Cargo.lock | 1 + Cargo.toml | 1 + src/lib.rs | 2 + src/tunnel.rs | 50 +++++++++++++++++------ src/tunnel/tui.rs | 6 +++ src/units.rs | 71 ++++++++++++++++++++++++++++++++ src/wrapper.rs | 101 ++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 220 insertions(+), 12 deletions(-) create mode 100644 src/units.rs create mode 100644 src/wrapper.rs diff --git a/Cargo.lock b/Cargo.lock index 50e7612..bd8ec10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3294,6 +3294,7 @@ dependencies = [ "hyper", "hyper-util", "ldap3", + "pin-project-lite", "rand 0.8.5", "ratatui", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 1e4a9e0..b26b138 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ http-body-util = { version = "0.1.3", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] } ldap3 = "0.11.5" +pin-project-lite = "0.2.16" rand = "0.8.5" ratatui = { version = "0.29.0", features = ["unstable-backend-writer"] } reqwest = { version = "0.12.15", features = ["rustls-tls"] } diff --git a/src/lib.rs b/src/lib.rs index 4e05d1a..fbd71f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ mod ldap; mod server; mod tui; mod tunnel; +mod units; +mod wrapper; pub use ldap::Ldap; pub use server::Server; diff --git a/src/tunnel.rs b/src/tunnel.rs index 6f9a55b..382cec1 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -7,25 +7,26 @@ use hyper::{ header::{self, HOST}, service::Service, }; -use hyper_util::rt::TokioIo; use std::{ collections::{HashMap, hash_map::Entry}, ops::Deref, pin::Pin, - sync::Arc, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, }; use tracing::{debug, error, trace, warn}; -use russh::{ - Channel, - server::{Handle, Msg}, -}; +use russh::server::Handle; use tokio::sync::RwLock; use crate::{ animals::get_animal_name, auth::{AuthStatus, ForwardAuth}, helper::response, + units::Unit, + wrapper::Wrapper, }; pub mod tui; @@ -45,14 +46,25 @@ pub struct Tunnel { domain: Option, port: u32, access: Arc>, + connection_count: Arc, + bytes_rx: Arc, + bytes_tx: Arc, } impl Tunnel { - pub async fn open_tunnel(&self) -> Result, russh::Error> { + pub async fn open_tunnel(&self) -> Result { trace!(tunnel = self.name, "Opening tunnel"); - self.handle + self.connection_count.fetch_add(1, Ordering::Relaxed); + let channel = self + .handle .channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port) - .await + .await?; + + Ok(Wrapper::new( + channel.into_stream(), + self.bytes_rx.clone(), + self.bytes_tx.clone(), + )) } pub async fn set_access(&self, access: TunnelAccess) { @@ -68,6 +80,18 @@ impl Tunnel { .clone() .map(|domain| format!("{}.{domain}", self.name)) } + + pub fn get_connections(&self) -> usize { + self.connection_count.load(Ordering::Relaxed) + } + + pub fn get_rx_string(&self) -> String { + Unit::new(self.bytes_rx.load(Ordering::Relaxed), "B").to_string() + } + + pub fn get_tx_string(&self) -> String { + Unit::new(self.bytes_tx.load(Ordering::Relaxed), "B").to_string() + } } #[derive(Debug, Clone)] @@ -122,6 +146,9 @@ impl Tunnels { domain: Some(self.domain.clone()), port, access: Arc::new(RwLock::new(TunnelAccess::Private(user.into()))), + connection_count: Default::default(), + bytes_rx: Default::default(), + bytes_tx: Default::default(), }; if tunnel.name == "localhost" { @@ -264,8 +291,8 @@ impl Service> for Tunnels { } } - let channel = match tunnel.open_tunnel().await { - Ok(channel) => channel, + let io = match tunnel.open_tunnel().await { + Ok(io) => io, Err(err) => { warn!(tunnel = authority, "Failed to open tunnel: {err}"); let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); @@ -273,7 +300,6 @@ impl Service> for Tunnels { return Ok(resp); } }; - let io = TokioIo::new(channel.into_stream()); let (mut sender, conn) = Builder::new() .preserve_header_case(true) diff --git a/src/tunnel/tui.rs b/src/tunnel/tui.rs index 295f051..d086a91 100644 --- a/src/tunnel/tui.rs +++ b/src/tunnel/tui.rs @@ -11,6 +11,9 @@ pub fn header() -> Vec> { "Access".into(), "Port".into(), "Address".into(), + "Conn".into(), + "Rx".into(), + "Tx".into(), ] } @@ -31,5 +34,8 @@ pub async fn to_row(tunnel: &Tunnel) -> Vec> { access, tunnel.port.to_string().into(), address, + tunnel.get_connections().to_string().into(), + tunnel.get_rx_string().into(), + tunnel.get_tx_string().into(), ] } diff --git a/src/units.rs b/src/units.rs new file mode 100644 index 0000000..5e7bb6b --- /dev/null +++ b/src/units.rs @@ -0,0 +1,71 @@ +use std::fmt; + +pub struct Unit { + value: usize, + prefix: UnitPrefix, + unit: String, +} + +impl Unit { + pub fn new(mut value: usize, unit: impl Into) -> Self { + let mut prefix = UnitPrefix::None; + + while value > 10000 { + value /= 1000; + prefix = prefix.next(); + } + + Self { + value, + prefix, + unit: unit.into(), + } + } +} + +impl fmt::Display for Unit { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {}{}", self.value, self.prefix, self.unit) + } +} + +enum UnitPrefix { + None, + Kilo, + Mega, + Giga, + Tera, + Peta, + Exa, + Impossible, +} + +impl UnitPrefix { + fn next(self) -> Self { + match self { + UnitPrefix::None => UnitPrefix::Kilo, + UnitPrefix::Kilo => UnitPrefix::Mega, + UnitPrefix::Mega => UnitPrefix::Giga, + UnitPrefix::Giga => UnitPrefix::Tera, + UnitPrefix::Tera => UnitPrefix::Peta, + UnitPrefix::Peta => UnitPrefix::Exa, + UnitPrefix::Exa | UnitPrefix::Impossible => UnitPrefix::Impossible, + } + } +} + +impl fmt::Display for UnitPrefix { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let prefix = match self { + UnitPrefix::None => "", + UnitPrefix::Kilo => "k", + UnitPrefix::Mega => "M", + UnitPrefix::Giga => "G", + UnitPrefix::Tera => "T", + UnitPrefix::Peta => "P", + UnitPrefix::Exa => "E", + UnitPrefix::Impossible => "x", + }; + f.write_str(prefix) + } +} diff --git a/src/wrapper.rs b/src/wrapper.rs new file mode 100644 index 0000000..219772d --- /dev/null +++ b/src/wrapper.rs @@ -0,0 +1,101 @@ +use std::{ + pin::Pin, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + task::{Context, Poll}, +}; + +use pin_project_lite::pin_project; +use russh::{ChannelStream, server::Msg}; + +pin_project! { + pub struct Wrapper { + #[pin] + inner: ChannelStream, + bytes_rx: Arc, + bytes_tx: Arc + } +} + +impl Wrapper { + pub fn new( + inner: ChannelStream, + bytes_rx: Arc, + bytes_tx: Arc, + ) -> Self { + Self { + inner, + bytes_rx, + bytes_tx, + } + } +} + +impl hyper::rt::Read for Wrapper { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let project = self.project(); + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(project.inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + project.bytes_tx.fetch_add(n, Ordering::Relaxed); + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for Wrapper { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let project = self.project(); + tokio::io::AsyncWrite::poll_write(project.inner, cx, buf).map(|res| { + res.inspect(|n| { + project.bytes_rx.fetch_add(*n, Ordering::Relaxed); + }) + }) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let project = self.project(); + tokio::io::AsyncWrite::poll_write_vectored(project.inner, cx, bufs).map(|res| { + res.inspect(|n| { + project.bytes_rx.fetch_add(*n, Ordering::Relaxed); + }) + }) + } +}