Massive refactor
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 6m12s

This commit is contained in:
Dreaded_X 2025-04-16 01:53:24 +02:00
parent f75726b93a
commit 3ada40d4ae
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
7 changed files with 363 additions and 306 deletions

View File

@ -15,7 +15,7 @@ use crate::{
io::TerminalHandle,
ldap::LdapError,
tui::Renderer,
tunnel::{Tunnel, TunnelAccess, Tunnels},
tunnel::{Registry, Tunnel, TunnelAccess},
};
#[derive(Debug, thiserror::Error)]
@ -31,7 +31,7 @@ pub enum HandlerError {
pub struct Handler {
ldap: Ldap,
all_tunnels: Tunnels,
registry: Registry,
tunnels: Vec<Tunnel>,
user: Option<String>,
@ -45,10 +45,10 @@ pub struct Handler {
}
impl Handler {
pub fn new(ldap: Ldap, all_tunnels: Tunnels) -> Self {
pub fn new(ldap: Ldap, registry: Registry) -> Self {
Self {
ldap,
all_tunnels,
registry,
tunnels: Default::default(),
user: None,
pty_channel: None,
@ -136,7 +136,7 @@ impl Handler {
&& let Some(tunnel) = self.tunnels.get_mut(selected)
&& let Some(buffer) = self.rename_buffer.take()
{
*tunnel = self.all_tunnels.rename_tunnel(tunnel.clone(), buffer).await;
tunnel.set_name(buffer).await;
} else {
warn!("Trying to rename invalid tunnel");
}
@ -177,7 +177,7 @@ impl Handler {
return Ok(false);
};
*tunnel = self.all_tunnels.retry_tunnel(tunnel.clone()).await;
tunnel.retry().await;
}
Input::Char('r') => {
if self.selected.is_some() {
@ -195,8 +195,7 @@ impl Handler {
return Ok(false);
}
let tunnel = self.tunnels.remove(selected);
self.all_tunnels.remove_tunnel(tunnel).await;
self.tunnels.remove(selected);
if self.tunnels.is_empty() {
self.selected = None;
@ -359,9 +358,13 @@ impl russh::server::Handler for Handler {
return Err(russh::Error::Inconsistent.into());
};
let tunnel = self
.all_tunnels
.create_tunnel(session.handle(), address, *port, user)
let tunnel = Tunnel::create(
&mut self.registry,
session.handle(),
address,
*port,
TunnelAccess::Private(user),
)
.await;
self.tunnels.push(tunnel);
@ -421,16 +424,3 @@ impl russh::server::Handler for Handler {
Ok(())
}
}
impl Drop for Handler {
fn drop(&mut self) {
let tunnels = self.tunnels.clone();
let mut all_tunnels = self.all_tunnels.clone();
tokio::spawn(async move {
for tunnel in tunnels {
all_tunnels.remove_tunnel(tunnel).await;
}
});
}
}

View File

@ -1,6 +1,6 @@
#![feature(let_chains)]
mod animals;
mod auth;
pub mod auth;
mod cli;
mod handler;
mod helper;
@ -16,4 +16,5 @@ mod wrapper;
pub use ldap::Ldap;
pub use server::Server;
pub use tunnel::{Tunnel, Tunnels};
pub use tunnel::Registry;
pub use tunnel::Tunnel;

View File

@ -7,8 +7,8 @@ use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use tokio::net::TcpListener;
use tracing::{error, info, warn};
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
use tunnel_rs::{Ldap, Server, Tunnels};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
use tunnel_rs::{Ldap, Registry, Server, auth::ForwardAuth};
#[tokio::main]
async fn main() -> color_eyre::Result<()> {
@ -18,7 +18,10 @@ async fn main() -> color_eyre::Result<()> {
let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
let logger = tracing_subscriber::fmt::layer().compact();
Registry::default().with(logger).with(env_filter).init();
tracing_subscriber::Registry::default()
.with(logger)
.with(env_filter)
.init();
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
russh::keys::PrivateKey::read_openssh_file(Path::new(&path))
@ -41,7 +44,8 @@ async fn main() -> color_eyre::Result<()> {
let ldap = Ldap::start_from_env().await?;
let tunnels = Tunnels::new(domain, authz_address);
let auth = ForwardAuth::new(authz_address);
let tunnels = Registry::new(domain, auth);
let mut ssh = Server::new(ldap, tunnels.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port));
tokio::spawn(async move { ssh.run(key, addr).await });

View File

@ -4,20 +4,16 @@ use russh::{MethodKind, keys::PrivateKey, server::Server as _};
use tokio::net::ToSocketAddrs;
use tracing::{debug, warn};
use crate::{Ldap, handler::Handler, tunnel::Tunnels};
use crate::{Ldap, handler::Handler, tunnel::Registry};
pub struct Server {
ldap: Ldap,
tunnels: Tunnels,
registry: Registry,
}
impl Server {
pub fn new(ldap: Ldap, tunnels: Tunnels) -> Self {
Server { ldap, tunnels }
}
pub fn tunnels(&self) -> Tunnels {
self.tunnels.clone()
pub fn new(ldap: Ldap, registry: Registry) -> Self {
Server { ldap, registry }
}
pub fn run(
@ -49,7 +45,7 @@ impl russh::server::Server for Server {
type Handler = Handler;
fn new_client(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
Handler::new(self.ldap.clone(), self.tunnels.clone())
Handler::new(self.ldap.clone(), self.registry.clone())
}
fn handle_session_error(&mut self, error: <Self::Handler as russh::server::Handler>::Error) {

View File

@ -1,32 +1,15 @@
use bytes::Bytes;
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
use hyper::{
Request, Response, StatusCode,
body::Incoming,
client::conn::http1::Builder,
header::{self, HOST},
service::Service,
};
use std::{
collections::{HashMap, hash_map::Entry},
ops::Deref,
pin::Pin,
sync::Arc,
};
use tracing::{debug, error, trace, warn};
use registry::RegistryEntry;
use std::sync::Arc;
use tracing::trace;
use russh::server::Handle;
use tokio::sync::RwLock;
use crate::{
animals::get_animal_name,
auth::{AuthStatus, ForwardAuth},
helper::response,
stats::Stats,
wrapper::Wrapper,
};
use crate::{stats::Stats, wrapper::Wrapper};
mod registry;
pub mod tui;
pub use registry::Registry;
#[derive(Debug, Clone)]
pub enum TunnelAccess {
@ -36,262 +19,84 @@ pub enum TunnelAccess {
}
#[derive(Debug, Clone)]
pub struct Tunnel {
pub struct TunnelInner {
handle: Handle,
name: String,
address: String,
domain: Option<String>,
internal_address: String,
port: u32,
access: Arc<RwLock<TunnelAccess>>,
stats: Arc<Stats>,
}
impl Tunnel {
pub async fn open_tunnel(&self) -> Result<Wrapper, russh::Error> {
trace!(tunnel = self.name, "Opening tunnel");
impl TunnelInner {
pub async fn open(&self) -> Result<Wrapper, russh::Error> {
trace!("Opening tunnel");
self.stats.add_connection();
let channel = self
.handle
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
.channel_open_forwarded_tcpip(
&self.internal_address,
self.port,
&self.internal_address,
self.port,
)
.await?;
Ok(Wrapper::new(channel.into_stream(), self.stats.clone()))
}
}
#[derive(Debug)]
pub struct Tunnel {
inner: TunnelInner,
registry: Registry,
registry_entry: RegistryEntry,
}
impl Tunnel {
pub async fn create(
registry: &mut Registry,
handle: Handle,
internal_address: impl Into<String>,
port: u32,
access: TunnelAccess,
) -> Self {
let mut tunnel = Self {
inner: TunnelInner {
handle,
internal_address: internal_address.into(),
port,
access: Arc::new(RwLock::new(access)),
stats: Default::default(),
},
registry: registry.clone(),
registry_entry: RegistryEntry::new(registry.clone()),
};
registry.register(&mut tunnel).await;
tunnel
}
pub async fn set_access(&self, access: TunnelAccess) {
*self.access.write().await = access;
*self.inner.access.write().await = access;
}
pub async fn is_public(&self) -> bool {
matches!(*self.access.read().await, TunnelAccess::Public)
matches!(*self.inner.access.read().await, TunnelAccess::Public)
}
pub fn get_address(&self) -> Option<String> {
self.domain
.clone()
.map(|domain| format!("{}.{domain}", self.name))
}
}
#[derive(Debug, Clone)]
pub struct Tunnels {
tunnels: Arc<RwLock<HashMap<String, Tunnel>>>,
domain: String,
forward_auth: ForwardAuth,
}
impl Tunnels {
pub fn new(domain: impl Into<String>, endpoint: impl Into<String>) -> Self {
Self {
tunnels: Arc::new(RwLock::new(HashMap::new())),
domain: domain.into(),
forward_auth: ForwardAuth::new(endpoint),
}
}
async fn generate_tunnel_name(&mut self, mut tunnel: Tunnel) -> Tunnel {
// NOTE: It is technically possible to become stuck in this loop.
// However, that really only becomes a concern if a (very) high
// number of tunnels is open at the same time.
tunnel.domain = Some(self.domain.clone());
loop {
tunnel.name = get_animal_name().into();
if !self
.tunnels
.read()
.await
.contains_key(&tunnel.get_address().expect("domain is set"))
{
break;
}
trace!(tunnel = tunnel.name, "Already in use, picking new name");
}
tunnel
}
pub async fn create_tunnel(
&mut self,
handle: Handle,
name: impl Into<String>,
port: u32,
user: impl Into<String>,
) -> Tunnel {
let address = name.into();
let mut tunnel = Tunnel {
handle,
name: address.clone(),
address,
domain: Some(self.domain.clone()),
port,
access: Arc::new(RwLock::new(TunnelAccess::Private(user.into()))),
stats: Default::default(),
};
if tunnel.name == "localhost" {
tunnel = self.generate_tunnel_name(tunnel).await;
};
self.add_tunnel(tunnel).await
}
async fn add_tunnel(&mut self, mut tunnel: Tunnel) -> Tunnel {
let address = tunnel.get_address().expect("domain is set");
if let Entry::Vacant(e) = self.tunnels.write().await.entry(address) {
trace!(tunnel = tunnel.name, "Adding tunnel");
e.insert(tunnel.clone());
} else {
trace!("Address already in use");
tunnel.domain = None
}
tunnel
}
pub async fn remove_tunnel(&mut self, mut tunnel: Tunnel) -> Tunnel {
let mut all_tunnels = self.tunnels.write().await;
if let Some(address) = tunnel.get_address() {
trace!(tunnel.name, "Removing tunnel");
all_tunnels.remove(&address);
}
tunnel.domain = None;
tunnel
}
pub async fn retry_tunnel(&mut self, tunnel: Tunnel) -> Tunnel {
let mut tunnel = self.remove_tunnel(tunnel).await;
tunnel.domain = Some(self.domain.clone());
self.add_tunnel(tunnel).await
}
pub async fn rename_tunnel(&mut self, tunnel: Tunnel, name: impl Into<String>) -> Tunnel {
let mut tunnel = self.remove_tunnel(tunnel).await;
let name: String = name.into();
if name.is_empty() {
tunnel = self.generate_tunnel_name(tunnel).await;
} else {
tunnel.domain = Some(self.domain.clone());
tunnel.name = name;
}
self.add_tunnel(tunnel).await
}
}
impl Service<Request<Incoming>> for Tunnels {
type Response = Response<BoxBody<Bytes, hyper::Error>>;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
trace!("{:#?}", req);
let Some(authority) = req
.uri()
.authority()
.as_ref()
.map(|a| a.to_string())
.or_else(|| {
req.headers()
.get(HOST)
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
})
else {
let resp = response(
StatusCode::BAD_REQUEST,
"Missing or invalid authority or host header",
);
return Box::pin(async { Ok(resp) });
};
debug!(tunnel = authority, "Tunnel request");
let s = self.clone();
Box::pin(async move {
let tunnels = s.tunnels.read().await;
let Some(tunnel) = tunnels.get(&authority) else {
debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
return Ok(resp);
};
if !matches!(tunnel.access.read().await.deref(), TunnelAccess::Public) {
let user = match s.forward_auth.check_auth(req.method(), req.headers()).await {
Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, location)
.body(
Empty::new()
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
.map_err(|never| match never {})
.boxed(),
)
.expect("configuration should be valid");
return Ok(resp);
}
Ok(AuthStatus::Unauthorized) => {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
Err(err) => {
error!("Unexpected error during authentication: {err}");
let resp = response(
StatusCode::FORBIDDEN,
"Unexpected error during authentication",
);
return Ok(resp);
}
};
trace!("Tunnel is getting accessed by {user:?}");
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
if !user.is(owner) {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
}
}
let io = match tunnel.open_tunnel().await {
Ok(io) => io,
Err(err) => {
warn!(tunnel = authority, "Failed to open tunnel: {err}");
let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel");
return Ok(resp);
}
};
let (mut sender, conn) = Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(io)
.await?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}");
}
});
let resp = sender.send_request(req).await?;
Ok(resp.map(|b| b.boxed()))
})
pub fn get_address(&self) -> Option<&String> {
self.registry_entry.get_address()
}
pub async fn set_name(&mut self, name: impl Into<String>) {
let mut registry = self.registry.clone();
registry.rename(self, name).await;
}
pub async fn retry(&mut self) {
let mut registry = self.registry.clone();
registry.register(self).await;
}
}

261
src/tunnel/registry.rs Normal file
View File

@ -0,0 +1,261 @@
use std::{
collections::{HashMap, hash_map::Entry},
ops::Deref,
pin::Pin,
sync::Arc,
};
use bytes::Bytes;
use http_body_util::{BodyExt as _, Empty, combinators::BoxBody};
use hyper::{
Request, Response, StatusCode,
body::Incoming,
client::conn::http1::Builder,
header::{self, HOST},
service::Service,
};
use tokio::sync::RwLock;
use tracing::{debug, error, trace, warn};
use crate::{
Tunnel,
animals::get_animal_name,
auth::{AuthStatus, ForwardAuth},
helper::response,
tunnel::TunnelAccess,
};
use super::TunnelInner;
#[derive(Debug)]
pub struct RegistryEntry {
registry: Registry,
name: String,
address: Option<String>,
}
impl RegistryEntry {
pub fn new(registry: Registry) -> Self {
Self {
registry,
name: Default::default(),
address: Default::default(),
}
}
pub fn get_address(&self) -> Option<&String> {
self.address.as_ref()
}
pub fn get_name(&self) -> &str {
&self.name
}
}
impl Drop for RegistryEntry {
fn drop(&mut self) {
trace!(
name = self.name,
address = self.address,
"Dropping registry entry"
);
if let Some(address) = self.address.take() {
let registry = self.registry.clone();
tokio::spawn(async move {
registry.tunnels.write().await.remove(&address);
});
}
}
}
#[derive(Debug, Clone)]
pub struct Registry {
tunnels: Arc<RwLock<HashMap<String, TunnelInner>>>,
domain: String,
auth: ForwardAuth,
}
impl Registry {
pub fn new(domain: impl Into<String>, auth: ForwardAuth) -> Self {
Self {
tunnels: Arc::new(RwLock::new(HashMap::new())),
domain: domain.into(),
auth,
}
}
fn address(&self, name: impl AsRef<str>) -> String {
format!("{}.{}", name.as_ref(), self.domain)
}
async fn generate_tunnel_name(&self) -> String {
// NOTE: It is technically possible to become stuck in this loop.
// However, that really only becomes a concern if a (very) high
// number of tunnels is open at the same time.
loop {
let name = get_animal_name();
if !self.tunnels.read().await.contains_key(&self.address(name)) {
break name.into();
}
trace!(name, "Already in use, picking new name");
}
}
pub(super) async fn register(&mut self, tunnel: &mut Tunnel) {
if tunnel.registry_entry.name.is_empty() {
if tunnel.inner.internal_address == "localhost" {
tunnel.registry_entry.name = self.generate_tunnel_name().await;
} else {
tunnel.registry_entry.name = tunnel.inner.internal_address.clone();
}
}
trace!(
name = tunnel.registry_entry.name,
"Attempting to register tunnel"
);
if tunnel.registry_entry.address.is_some() {
trace!(name = tunnel.registry_entry.name, "Already registered");
return;
}
let address = self.address(&tunnel.registry_entry.name);
if let Entry::Vacant(e) = self.tunnels.write().await.entry(address.clone()) {
tunnel.registry_entry.address = Some(address);
e.insert(tunnel.inner.clone());
} else {
trace!(name = tunnel.registry_entry.name, "Address already in use");
tunnel.registry_entry.address = None;
}
}
pub(super) async fn rename(&mut self, tunnel: &mut Tunnel, name: impl Into<String>) {
trace!(name = tunnel.registry_entry.name, "Renaming tunnel");
if let Some(address) = tunnel.registry_entry.address.take() {
self.tunnels.write().await.remove(&address);
}
tunnel.registry_entry.name = name.into();
self.register(tunnel).await;
}
}
impl Service<Request<Incoming>> for Registry {
type Response = Response<BoxBody<Bytes, hyper::Error>>;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
trace!("{:#?}", req);
let Some(authority) = req
.uri()
.authority()
.as_ref()
.map(|a| a.to_string())
.or_else(|| {
req.headers()
.get(HOST)
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
})
else {
let resp = response(
StatusCode::BAD_REQUEST,
"Missing or invalid authority or host header",
);
return Box::pin(async { Ok(resp) });
};
debug!(authority, "Tunnel request");
let s = self.clone();
Box::pin(async move {
let Some(entry) = s.tunnels.read().await.get(&authority).cloned() else {
debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
return Ok(resp);
};
if !matches!(entry.access.read().await.deref(), TunnelAccess::Public) {
let user = match s.auth.check_auth(req.method(), req.headers()).await {
Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, location)
.body(
Empty::new()
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
.map_err(|never| match never {})
.boxed(),
)
.expect("configuration should be valid");
return Ok(resp);
}
Ok(AuthStatus::Unauthorized) => {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
Err(err) => {
error!("Unexpected error during authentication: {err}");
let resp = response(
StatusCode::FORBIDDEN,
"Unexpected error during authentication",
);
return Ok(resp);
}
};
trace!("Tunnel is getting accessed by {user:?}");
if let TunnelAccess::Private(owner) = entry.access.read().await.deref() {
if !user.is(owner) {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
}
}
let io = match entry.open().await {
Ok(io) => io,
Err(err) => {
warn!(tunnel = authority, "Failed to open tunnel: {err}");
let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel");
return Ok(resp);
}
};
let (mut sender, conn) = Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(io)
.await?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(runnel = authority, "Connection failed: {err}");
}
});
let resp = sender.send_request(req).await?;
Ok(resp.map(|b| b.boxed()))
})
}
}

View File

@ -18,7 +18,7 @@ pub fn header() -> Vec<Span<'static>> {
}
pub async fn to_row(tunnel: &Tunnel) -> Vec<Span<'static>> {
let access = match tunnel.access.read().await.deref() {
let access = match tunnel.inner.access.read().await.deref() {
TunnelAccess::Private(owner) => owner.clone().yellow(),
TunnelAccess::Protected => "PROTECTED".blue(),
TunnelAccess::Public => "PUBLIC".green(),
@ -30,12 +30,12 @@ pub async fn to_row(tunnel: &Tunnel) -> Vec<Span<'static>> {
.unwrap_or("FAILED".red());
vec![
tunnel.name.clone().into(),
tunnel.registry_entry.get_name().to_owned().into(),
access,
tunnel.port.to_string().into(),
tunnel.inner.port.to_string().into(),
address,
tunnel.stats.connections().to_string().into(),
tunnel.stats.rx().to_string().into(),
tunnel.stats.tx().to_string().into(),
tunnel.inner.stats.connections().to_string().into(),
tunnel.inner.stats.rx().to_string().into(),
tunnel.inner.stats.tx().to_string().into(),
]
}