diff --git a/src/auth.rs b/src/auth.rs index ab45dc1..e4d7d63 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -53,7 +53,7 @@ impl ForwardAuth { } } - pub async fn check_auth( + pub async fn check( &self, methods: &Method, headers: &HeaderMap, diff --git a/src/lib.rs b/src/lib.rs index 81599ba..149c98d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,9 +12,11 @@ mod stats; mod tui; mod tunnel; mod units; +mod web; mod wrapper; pub use ldap::Ldap; pub use server::Server; pub use tunnel::Registry; pub use tunnel::Tunnel; +pub use web::Service; diff --git a/src/main.rs b/src/main.rs index 65ca9b0..d5759cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use rand::rngs::OsRng; use tokio::net::TcpListener; use tracing::{error, info, warn}; use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; -use tunnel_rs::{Ldap, Registry, Server, auth::ForwardAuth}; +use tunnel_rs::{Ldap, Registry, Server, Service, auth::ForwardAuth}; #[tokio::main] async fn main() -> color_eyre::Result<()> { @@ -43,14 +43,14 @@ async fn main() -> color_eyre::Result<()> { let authz_address = std::env::var("AUTHZ_ENDPOINT").wrap_err("AUTHZ_ENDPOINT is not set")?; let ldap = Ldap::start_from_env().await?; - - let auth = ForwardAuth::new(authz_address); - let tunnels = Registry::new(domain, auth); - let mut ssh = Server::new(ldap, tunnels.clone()); + let registry = Registry::new(domain); + let mut ssh = Server::new(ldap, registry.clone()); let addr = SocketAddr::from(([0, 0, 0, 0], ssh_port)); tokio::spawn(async move { ssh.run(key, addr).await }); info!("SSH is available on {addr}"); + let auth = ForwardAuth::new(authz_address); + let service = Service::new(registry, auth); let addr = SocketAddr::from(([0, 0, 0, 0], http_port)); let listener = TcpListener::bind(addr).await?; info!("HTTP is available on {addr}"); @@ -60,12 +60,12 @@ async fn main() -> color_eyre::Result<()> { let (stream, _) = listener.accept().await?; let io = TokioIo::new(stream); - let tunnels = tunnels.clone(); + let service = service.clone(); tokio::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(io, tunnels) + .serve_connection(io, service) .with_upgrades() .await { diff --git a/src/tunnel.rs b/src/tunnel/mod.rs similarity index 90% rename from src/tunnel.rs rename to src/tunnel/mod.rs index f782e80..abdd835 100644 --- a/src/tunnel.rs +++ b/src/tunnel/mod.rs @@ -1,14 +1,15 @@ +mod registry; +mod tui; + use registry::RegistryEntry; use std::sync::Arc; use tracing::trace; use russh::server::Handle; -use tokio::sync::RwLock; +use tokio::sync::{RwLock, RwLockReadGuard}; use crate::{stats::Stats, wrapper::Wrapper}; -mod registry; -pub mod tui; pub use registry::Registry; #[derive(Debug, Clone)] @@ -43,6 +44,14 @@ impl TunnelInner { Ok(Wrapper::new(channel.into_stream(), self.stats.clone())) } + + pub async fn is_public(&self) -> bool { + matches!(*self.access.read().await, TunnelAccess::Public) + } + + pub async fn get_access(&self) -> RwLockReadGuard<'_, TunnelAccess> { + self.access.read().await + } } #[derive(Debug)] @@ -82,10 +91,6 @@ impl Tunnel { *self.inner.access.write().await = access; } - pub async fn is_public(&self) -> bool { - matches!(*self.inner.access.read().await, TunnelAccess::Public) - } - pub fn get_address(&self) -> Option<&String> { self.registry_entry.get_address() } diff --git a/src/tunnel/registry.rs b/src/tunnel/registry.rs index e62b3fa..f7f8f01 100644 --- a/src/tunnel/registry.rs +++ b/src/tunnel/registry.rs @@ -1,29 +1,12 @@ use std::{ collections::{HashMap, hash_map::Entry}, - ops::Deref, - pin::Pin, sync::Arc, }; -use bytes::Bytes; -use http_body_util::{BodyExt as _, Empty, combinators::BoxBody}; -use hyper::{ - Request, Response, StatusCode, - body::Incoming, - client::conn::http1::Builder, - header::{self, HOST}, - service::Service, -}; use tokio::sync::RwLock; -use tracing::{debug, error, trace, warn}; +use tracing::trace; -use crate::{ - Tunnel, - animals::get_animal_name, - auth::{AuthStatus, ForwardAuth}, - helper::response, - tunnel::TunnelAccess, -}; +use crate::{Tunnel, animals::get_animal_name}; use super::TunnelInner; @@ -73,15 +56,13 @@ impl Drop for RegistryEntry { pub struct Registry { tunnels: Arc>>, domain: String, - auth: ForwardAuth, } impl Registry { - pub fn new(domain: impl Into, auth: ForwardAuth) -> Self { + pub fn new(domain: impl Into) -> Self { Self { tunnels: Arc::new(RwLock::new(HashMap::new())), domain: domain.into(), - auth, } } @@ -142,120 +123,8 @@ impl Registry { tunnel.registry_entry.name = name.into(); self.register(tunnel).await; } -} -impl Service> for Registry { - type Response = Response>; - type Error = hyper::Error; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - trace!("{:#?}", req); - - let Some(authority) = req - .uri() - .authority() - .as_ref() - .map(|a| a.to_string()) - .or_else(|| { - req.headers() - .get(HOST) - .and_then(|h| h.to_str().ok().map(|s| s.to_owned())) - }) - else { - let resp = response( - StatusCode::BAD_REQUEST, - "Missing or invalid authority or host header", - ); - - return Box::pin(async { Ok(resp) }); - }; - - debug!(authority, "Tunnel request"); - - let s = self.clone(); - Box::pin(async move { - let Some(entry) = s.tunnels.read().await.get(&authority).cloned() else { - debug!(tunnel = authority, "Unknown tunnel"); - let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); - - return Ok(resp); - }; - - if !matches!(entry.access.read().await.deref(), TunnelAccess::Public) { - let user = match s.auth.check_auth(req.method(), req.headers()).await { - Ok(AuthStatus::Authenticated(user)) => user, - Ok(AuthStatus::Unauthenticated(location)) => { - let resp = Response::builder() - .status(StatusCode::FOUND) - .header(header::LOCATION, location) - .body( - Empty::new() - // NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error - .map_err(|never| match never {}) - .boxed(), - ) - .expect("configuration should be valid"); - - return Ok(resp); - } - Ok(AuthStatus::Unauthorized) => { - let resp = response( - StatusCode::FORBIDDEN, - "You do not have permission to access this tunnel", - ); - - return Ok(resp); - } - Err(err) => { - error!("Unexpected error during authentication: {err}"); - let resp = response( - StatusCode::FORBIDDEN, - "Unexpected error during authentication", - ); - - return Ok(resp); - } - }; - - trace!("Tunnel is getting accessed by {user:?}"); - - if let TunnelAccess::Private(owner) = entry.access.read().await.deref() { - if !user.is(owner) { - let resp = response( - StatusCode::FORBIDDEN, - "You do not have permission to access this tunnel", - ); - - return Ok(resp); - } - } - } - - let io = match entry.open().await { - Ok(io) => io, - Err(err) => { - warn!(tunnel = authority, "Failed to open tunnel: {err}"); - let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); - - return Ok(resp); - } - }; - - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - - tokio::spawn(async move { - if let Err(err) = conn.await { - warn!(runnel = authority, "Connection failed: {err}"); - } - }); - - let resp = sender.send_request(req).await?; - Ok(resp.map(|b| b.boxed())) - }) + pub async fn get(&self, address: &str) -> Option { + self.tunnels.read().await.get(address).cloned() } } diff --git a/src/web.rs b/src/web.rs new file mode 100644 index 0000000..2cbed38 --- /dev/null +++ b/src/web.rs @@ -0,0 +1,147 @@ +use crate::Registry; +use std::{ops::Deref, pin::Pin}; + +use bytes::Bytes; +use http_body_util::{BodyExt as _, Empty, combinators::BoxBody}; +use hyper::{ + Request, Response, StatusCode, + body::Incoming, + client::conn::http1::Builder, + header::{self, HOST}, +}; +use tracing::{debug, error, trace, warn}; + +use crate::{ + auth::{AuthStatus, ForwardAuth}, + helper::response, + tunnel::TunnelAccess, +}; + +#[derive(Debug, Clone)] +pub struct Service { + registry: Registry, + auth: ForwardAuth, +} + +impl Service { + pub fn new(registry: Registry, auth: ForwardAuth) -> Self { + Self { registry, auth } + } +} + +impl hyper::service::Service> for Service { + type Response = Response>; + type Error = hyper::Error; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + trace!("{:#?}", req); + + let Some(authority) = req + .uri() + .authority() + .as_ref() + .map(|a| a.to_string()) + .or_else(|| { + req.headers() + .get(HOST) + .and_then(|h| h.to_str().ok().map(|s| s.to_owned())) + }) + else { + let resp = response( + StatusCode::BAD_REQUEST, + "Missing or invalid authority or host header", + ); + + return Box::pin(async { Ok(resp) }); + }; + + debug!(authority, "Tunnel request"); + + let registry = self.registry.clone(); + let auth = self.auth.clone(); + Box::pin(async move { + let Some(entry) = registry.get(&authority).await else { + debug!(tunnel = authority, "Unknown tunnel"); + let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); + + return Ok(resp); + }; + + if !entry.is_public().await { + let user = match auth.check(req.method(), req.headers()).await { + Ok(AuthStatus::Authenticated(user)) => user, + Ok(AuthStatus::Unauthenticated(location)) => { + let resp = Response::builder() + .status(StatusCode::FOUND) + .header(header::LOCATION, location) + .body( + Empty::new() + // NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error + .map_err(|never| match never {}) + .boxed(), + ) + .expect("configuration should be valid"); + + return Ok(resp); + } + Ok(AuthStatus::Unauthorized) => { + let resp = response( + StatusCode::FORBIDDEN, + "You do not have permission to access this tunnel", + ); + + return Ok(resp); + } + Err(err) => { + error!("Unexpected error during authentication: {err}"); + let resp = response( + StatusCode::FORBIDDEN, + "Unexpected error during authentication", + ); + + return Ok(resp); + } + }; + + trace!("Tunnel is getting accessed by {user:?}"); + + if let TunnelAccess::Private(owner) = entry.get_access().await.deref() { + if !user.is(owner) { + let resp = response( + StatusCode::FORBIDDEN, + "You do not have permission to access this tunnel", + ); + + return Ok(resp); + } + } + } + + let io = match entry.open().await { + Ok(io) => io, + Err(err) => { + warn!(tunnel = authority, "Failed to open tunnel: {err}"); + let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); + + return Ok(resp); + } + }; + + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!(runnel = authority, "Connection failed: {err}"); + } + }); + + let resp = sender.send_request(req).await?; + Ok(resp.map(|b| b.boxed())) + }) + } +}