From 31532493cbb6c657496fc70483242674c31f9cc5 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Tue, 8 Apr 2025 15:16:47 +0200 Subject: [PATCH] Improved error handling --- Cargo.lock | 1 + Cargo.toml | 1 + src/auth.rs | 101 ++++++++++++++++++++++++---------------- src/bin/generate_key.rs | 7 +-- src/helper.rs | 14 ++++++ src/lib.rs | 1 + src/main.rs | 21 +++++---- src/ssh.rs | 12 ++--- src/tunnel.rs | 77 +++++++++++++++++++----------- 9 files changed, 146 insertions(+), 89 deletions(-) create mode 100644 src/helper.rs diff --git a/Cargo.lock b/Cargo.lock index 004de8d..fd58291 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2909,6 +2909,7 @@ dependencies = [ "rand 0.8.5", "reqwest", "russh", + "thiserror 2.0.12", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index b52e631..a0c2162 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ indexmap = "2.9.0" rand = "0.8.5" reqwest = { version = "0.12.15", features = ["rustls-tls"] } russh = "0.51.1" +thiserror = "2.0.12" tokio = { version = "1.44.1", features = ["full"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } diff --git a/src/auth.rs b/src/auth.rs index be77262..840d8f8 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,22 +1,50 @@ -use bytes::Bytes; -use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; use hyper::{ - HeaderMap, Response, - header::{self, HeaderValue}, + HeaderMap, StatusCode, + header::{self, HeaderName, HeaderValue, ToStrError}, }; use reqwest::redirect::Policy; -use tracing::debug; - -pub enum AuthStatus { - Authenticated(String), - Unauthenticated(Response>), -} +use tracing::{debug, error}; #[derive(Debug, Clone)] pub struct ForwardAuth { address: String, } +#[derive(Debug)] +pub struct User { + username: String, +} + +impl User { + pub fn is(&self, username: impl AsRef) -> bool { + self.username.eq(username.as_ref()) + } +} + +#[derive(Debug)] +pub enum AuthStatus { + // Contains the value of the location header that will redirect the user to the login page + Unauthenticated(HeaderValue), + Authenticated(User), + Unauthorized, +} + +const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user"); + +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("Http error: {0}")] + Http(#[from] hyper::http::Error), + #[error("Header '{0}' is missing from auth endpoint response")] + MissingHeader(HeaderName), + #[error("Header '{0}' received from auth endpoint is invalid: {1}")] + InvalidHeader(HeaderName, ToStrError), + #[error("Unexpected response from auth endpoint: {0:?}")] + UnexpectedResponse(reqwest::Response), +} + impl ForwardAuth { pub fn new(endpoint: impl Into) -> Self { Self { @@ -24,11 +52,13 @@ impl ForwardAuth { } } - pub async fn check_auth(&self, headers: &HeaderMap) -> AuthStatus { + pub async fn check_auth( + &self, + headers: &HeaderMap, + ) -> Result { let client = reqwest::ClientBuilder::new() .redirect(Policy::none()) - .build() - .unwrap(); + .build()?; let headers = headers .clone() @@ -45,42 +75,33 @@ impl ForwardAuth { }) .collect(); - debug!("{headers:#?}"); - - let resp = client - .get(&self.address) - .headers(headers) - .send() - .await - .unwrap(); + let resp = client.get(&self.address).headers(headers).send().await?; let status_code = resp.status(); - if !status_code.is_success() { - debug!("{:#?}", resp.headers()); - let location = resp.headers().get(header::LOCATION).unwrap().clone(); - let body = resp.bytes().await.unwrap(); - let resp = Response::builder() - .status(status_code) - .header(header::LOCATION, location) - .body(Full::new(body)) - .unwrap() - .map(|b| b.map_err(|never| match never {}).boxed()); + if status_code == StatusCode::FOUND { + let location = resp + .headers() + .get(header::LOCATION) + .cloned() + .ok_or(AuthError::MissingHeader(header::LOCATION))?; - return AuthStatus::Unauthenticated(resp); + return Ok(AuthStatus::Unauthenticated(location)); + } else if status_code == StatusCode::FORBIDDEN { + return Ok(AuthStatus::Unauthorized); + } else if !status_code.is_success() { + return Err(AuthError::UnexpectedResponse(resp)); } - debug!("{:#?}", resp.headers()); - let user = resp + let username = resp .headers() - .get("remote-user") - .unwrap() + .get(REMOTE_USER) + .ok_or(AuthError::MissingHeader(REMOTE_USER))? .to_str() - .unwrap() + .map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))? .to_owned(); - debug!("{}", resp.text().await.unwrap()); - debug!("Logged in as user: {user}"); + debug!("Connected user is: {username}"); - AuthStatus::Authenticated(user) + Ok(AuthStatus::Authenticated(User { username })) } } diff --git a/src/bin/generate_key.rs b/src/bin/generate_key.rs index 30913d1..f12c3d9 100644 --- a/src/bin/generate_key.rs +++ b/src/bin/generate_key.rs @@ -18,12 +18,9 @@ fn main() -> color_eyre::Result<()> { color_eyre::install()?; - let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519) - .expect("algorithm should be supported"); + let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?; - let key = key - .to_openssh(LineEnding::LF) - .expect("encodig should not fail"); + let key = key.to_openssh(LineEnding::LF)?; args.output .write(key.as_bytes()) diff --git a/src/helper.rs b/src/helper.rs new file mode 100644 index 0000000..f470d12 --- /dev/null +++ b/src/helper.rs @@ -0,0 +1,14 @@ +use bytes::Bytes; +use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; +use hyper::{Response, StatusCode}; + +pub fn response( + status_code: StatusCode, + body: impl Into, +) -> Response> { + Response::builder() + .status(status_code) + .body(Full::new(Bytes::from(body.into()))) + .expect("all configuration should be valid") + .map(|b| b.map_err(|never| match never {}).boxed()) +} diff --git a/src/lib.rs b/src/lib.rs index 8cf8fad..76eb4ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,5 +2,6 @@ #![feature(let_chains)] pub mod animals; pub mod auth; +pub mod helper; pub mod ssh; pub mod tunnel; diff --git a/src/main.rs b/src/main.rs index 1401340..6e1f248 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,29 +1,30 @@ use std::{net::SocketAddr, path::Path}; +use color_eyre::eyre::Context; use dotenvy::dotenv; use hyper::server::conn::http1::{self}; use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use tokio::net::TcpListener; -use tracing::{info, warn}; +use tracing::{error, info}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use tunnel_rs::{ssh::Server, tunnel::Tunnels}; #[tokio::main] -async fn main() { +async fn main() -> color_eyre::Result<()> { + color_eyre::install()?; dotenv().ok(); - let env_filter = EnvFilter::try_from_default_env() - .or_else(|_| EnvFilter::try_new("info")) - .expect("Fallback should be valid"); + let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?; let logger = tracing_subscriber::fmt::layer().compact(); Registry::default().with(logger).with(env_filter).init(); let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") { - russh::keys::PrivateKey::read_openssh_file(Path::new(&path)).unwrap() + russh::keys::PrivateKey::read_openssh_file(Path::new(&path)) + .wrap_err_with(|| format!("failed to read ssh key: {path}"))? } else { - russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap() + russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)? }; let port = 3000; @@ -38,12 +39,12 @@ async fn main() { info!("SSH is available on {addr}"); let addr = SocketAddr::from(([0, 0, 0, 0], port)); - let listener = TcpListener::bind(addr).await.unwrap(); + let listener = TcpListener::bind(addr).await?; info!("HTTP is available on {addr}"); // TODO: Graceful shutdown loop { - let (stream, _) = listener.accept().await.unwrap(); + let (stream, _) = listener.accept().await?; let io = TokioIo::new(stream); let tunnels = tunnels.clone(); @@ -55,7 +56,7 @@ async fn main() { .with_upgrades() .await { - warn!("Failed to serve connection: {err:?}"); + error!("Failed to serve connection: {err:?}"); } }); } diff --git a/src/ssh.rs b/src/ssh.rs index bdca984..eeadcdc 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -143,13 +143,11 @@ impl russh::server::Handler for Handler { } Err(err) => { trace!("Sending error/help message and disconnecting"); - session - .disconnect( - russh::Disconnect::ByApplication, - &format!("\n\r{err}"), - "EN", - ) - .unwrap(); + session.disconnect( + russh::Disconnect::ByApplication, + &format!("\n\r{err}"), + "EN", + )?; session.channel_failure(channel) } diff --git a/src/tunnel.rs b/src/tunnel.rs index 358fdbb..f19b7c7 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,13 +1,16 @@ use bytes::Bytes; -use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; +use http_body_util::{BodyExt, Empty, combinators::BoxBody}; use hyper::{ - Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST, + Request, Response, StatusCode, + body::Incoming, + client::conn::http1::Builder, + header::{self, HOST}, service::Service, }; use hyper_util::rt::TokioIo; use indexmap::IndexMap; use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc}; -use tracing::{debug, trace, warn}; +use tracing::{debug, error, trace, warn}; use russh::{ Channel, @@ -17,10 +20,8 @@ use tokio::sync::RwLock; use crate::{ animals::get_animal_name, - auth::{ - AuthStatus::{Authenticated, Unauthenticated}, - ForwardAuth, - }, + auth::{AuthStatus, ForwardAuth}, + helper::response, }; #[derive(Debug, Clone)] @@ -120,18 +121,7 @@ impl Service> for Tunnels { type Future = Pin> + Send>>; fn call(&self, req: Request) -> Self::Future { - fn response( - status_code: StatusCode, - body: impl Into, - ) -> Response> { - Response::builder() - .status(status_code) - .body(Full::new(Bytes::from(body.into()))) - .unwrap() - .map(|b| b.map_err(|never| match never {}).boxed()) - } - - trace!(?req); + trace!("{:#?}", req); let Some(authority) = req .uri() @@ -141,15 +131,18 @@ impl Service> for Tunnels { .or_else(|| { req.headers() .get(HOST) - .map(|h| h.to_str().unwrap().to_owned()) + .and_then(|h| h.to_str().ok().map(|s| s.to_owned())) }) else { - let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header"); + let resp = response( + StatusCode::BAD_REQUEST, + "Missing or invalid authority or host header", + ); return Box::pin(async { Ok(resp) }); }; - debug!(tunnel = authority, "Request"); + debug!(tunnel = authority, "Tunnel request"); let s = self.clone(); Box::pin(async move { @@ -163,13 +156,43 @@ impl Service> for Tunnels { if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() { let user = match s.forward_auth.check_auth(req.headers()).await { - Authenticated(user) => user, - Unauthenticated(response) => return Ok(response), + Ok(AuthStatus::Authenticated(user)) => user, + Ok(AuthStatus::Unauthenticated(location)) => { + let resp = Response::builder() + .status(StatusCode::FOUND) + .header(header::LOCATION, location) + .body( + Empty::new() + // NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error + .map_err(|never| match never {}) + .boxed(), + ) + .expect("configuration should be valid"); + + return Ok(resp); + } + Ok(AuthStatus::Unauthorized) => { + let resp = response( + StatusCode::FORBIDDEN, + "You do not have permission to access this tunnel", + ); + + return Ok(resp); + } + Err(err) => { + error!("Unexpected error during authentication: {err}"); + let resp = response( + StatusCode::FORBIDDEN, + "Unexpected error during authentication", + ); + + return Ok(resp); + } }; - trace!("Tunnel owned by {owner} is getting accessed by {user}"); + trace!("Tunnel owned by {owner} is getting accessed by {user:?}"); - if !user.eq(owner) { + if !user.is(owner) { let resp = response( StatusCode::FORBIDDEN, "You do not have permission to access this tunnel", @@ -202,7 +225,7 @@ impl Service> for Tunnels { } }); - let resp = sender.send_request(req).await.unwrap(); + let resp = sender.send_request(req).await?; Ok(resp.map(|b| b.boxed())) }) }