Improved error handling

This commit is contained in:
Dreaded_X 2025-04-08 15:16:47 +02:00
parent 3937e06660
commit 31532493cb
Signed by: Dreaded_X
GPG Key ID: 5A0CBFE3C3377FAA
9 changed files with 146 additions and 89 deletions

1
Cargo.lock generated
View File

@ -2909,6 +2909,7 @@ dependencies = [
"rand 0.8.5", "rand 0.8.5",
"reqwest", "reqwest",
"russh", "russh",
"thiserror 2.0.12",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",

View File

@ -17,6 +17,7 @@ indexmap = "2.9.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.12.15", features = ["rustls-tls"] } reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1" russh = "0.51.1"
thiserror = "2.0.12"
tokio = { version = "1.44.1", features = ["full"] } tokio = { version = "1.44.1", features = ["full"] }
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }

View File

@ -1,22 +1,50 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use hyper::{ use hyper::{
HeaderMap, Response, HeaderMap, StatusCode,
header::{self, HeaderValue}, header::{self, HeaderName, HeaderValue, ToStrError},
}; };
use reqwest::redirect::Policy; use reqwest::redirect::Policy;
use tracing::debug; use tracing::{debug, error};
pub enum AuthStatus {
Authenticated(String),
Unauthenticated(Response<BoxBody<Bytes, hyper::Error>>),
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ForwardAuth { pub struct ForwardAuth {
address: String, address: String,
} }
#[derive(Debug)]
pub struct User {
username: String,
}
impl User {
pub fn is(&self, username: impl AsRef<str>) -> bool {
self.username.eq(username.as_ref())
}
}
#[derive(Debug)]
pub enum AuthStatus {
// Contains the value of the location header that will redirect the user to the login page
Unauthenticated(HeaderValue),
Authenticated(User),
Unauthorized,
}
const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user");
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Http error: {0}")]
Http(#[from] hyper::http::Error),
#[error("Header '{0}' is missing from auth endpoint response")]
MissingHeader(HeaderName),
#[error("Header '{0}' received from auth endpoint is invalid: {1}")]
InvalidHeader(HeaderName, ToStrError),
#[error("Unexpected response from auth endpoint: {0:?}")]
UnexpectedResponse(reqwest::Response),
}
impl ForwardAuth { impl ForwardAuth {
pub fn new(endpoint: impl Into<String>) -> Self { pub fn new(endpoint: impl Into<String>) -> Self {
Self { Self {
@ -24,11 +52,13 @@ impl ForwardAuth {
} }
} }
pub async fn check_auth(&self, headers: &HeaderMap<HeaderValue>) -> AuthStatus { pub async fn check_auth(
&self,
headers: &HeaderMap<HeaderValue>,
) -> Result<AuthStatus, AuthError> {
let client = reqwest::ClientBuilder::new() let client = reqwest::ClientBuilder::new()
.redirect(Policy::none()) .redirect(Policy::none())
.build() .build()?;
.unwrap();
let headers = headers let headers = headers
.clone() .clone()
@ -45,42 +75,33 @@ impl ForwardAuth {
}) })
.collect(); .collect();
debug!("{headers:#?}"); let resp = client.get(&self.address).headers(headers).send().await?;
let resp = client
.get(&self.address)
.headers(headers)
.send()
.await
.unwrap();
let status_code = resp.status(); let status_code = resp.status();
if !status_code.is_success() { if status_code == StatusCode::FOUND {
debug!("{:#?}", resp.headers()); let location = resp
let location = resp.headers().get(header::LOCATION).unwrap().clone();
let body = resp.bytes().await.unwrap();
let resp = Response::builder()
.status(status_code)
.header(header::LOCATION, location)
.body(Full::new(body))
.unwrap()
.map(|b| b.map_err(|never| match never {}).boxed());
return AuthStatus::Unauthenticated(resp);
}
debug!("{:#?}", resp.headers());
let user = resp
.headers() .headers()
.get("remote-user") .get(header::LOCATION)
.unwrap() .cloned()
.ok_or(AuthError::MissingHeader(header::LOCATION))?;
return Ok(AuthStatus::Unauthenticated(location));
} else if status_code == StatusCode::FORBIDDEN {
return Ok(AuthStatus::Unauthorized);
} else if !status_code.is_success() {
return Err(AuthError::UnexpectedResponse(resp));
}
let username = resp
.headers()
.get(REMOTE_USER)
.ok_or(AuthError::MissingHeader(REMOTE_USER))?
.to_str() .to_str()
.unwrap() .map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))?
.to_owned(); .to_owned();
debug!("{}", resp.text().await.unwrap());
debug!("Logged in as user: {user}"); debug!("Connected user is: {username}");
AuthStatus::Authenticated(user) Ok(AuthStatus::Authenticated(User { username }))
} }
} }

View File

@ -18,12 +18,9 @@ fn main() -> color_eyre::Result<()> {
color_eyre::install()?; color_eyre::install()?;
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519) let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?;
.expect("algorithm should be supported");
let key = key let key = key.to_openssh(LineEnding::LF)?;
.to_openssh(LineEnding::LF)
.expect("encodig should not fail");
args.output args.output
.write(key.as_bytes()) .write(key.as_bytes())

14
src/helper.rs Normal file
View File

@ -0,0 +1,14 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use hyper::{Response, StatusCode};
pub fn response(
status_code: StatusCode,
body: impl Into<String>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status_code)
.body(Full::new(Bytes::from(body.into())))
.expect("all configuration should be valid")
.map(|b| b.map_err(|never| match never {}).boxed())
}

View File

@ -2,5 +2,6 @@
#![feature(let_chains)] #![feature(let_chains)]
pub mod animals; pub mod animals;
pub mod auth; pub mod auth;
pub mod helper;
pub mod ssh; pub mod ssh;
pub mod tunnel; pub mod tunnel;

View File

@ -1,29 +1,30 @@
use std::{net::SocketAddr, path::Path}; use std::{net::SocketAddr, path::Path};
use color_eyre::eyre::Context;
use dotenvy::dotenv; use dotenvy::dotenv;
use hyper::server::conn::http1::{self}; use hyper::server::conn::http1::{self};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{info, warn}; use tracing::{error, info};
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
use tunnel_rs::{ssh::Server, tunnel::Tunnels}; use tunnel_rs::{ssh::Server, tunnel::Tunnels};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
dotenv().ok(); dotenv().ok();
let env_filter = EnvFilter::try_from_default_env() let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
.or_else(|_| EnvFilter::try_new("info"))
.expect("Fallback should be valid");
let logger = tracing_subscriber::fmt::layer().compact(); let logger = tracing_subscriber::fmt::layer().compact();
Registry::default().with(logger).with(env_filter).init(); Registry::default().with(logger).with(env_filter).init();
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") { let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
russh::keys::PrivateKey::read_openssh_file(Path::new(&path)).unwrap() russh::keys::PrivateKey::read_openssh_file(Path::new(&path))
.wrap_err_with(|| format!("failed to read ssh key: {path}"))?
} else { } else {
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap() russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?
}; };
let port = 3000; let port = 3000;
@ -38,12 +39,12 @@ async fn main() {
info!("SSH is available on {addr}"); info!("SSH is available on {addr}");
let addr = SocketAddr::from(([0, 0, 0, 0], port)); let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await.unwrap(); let listener = TcpListener::bind(addr).await?;
info!("HTTP is available on {addr}"); info!("HTTP is available on {addr}");
// TODO: Graceful shutdown // TODO: Graceful shutdown
loop { loop {
let (stream, _) = listener.accept().await.unwrap(); let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let tunnels = tunnels.clone(); let tunnels = tunnels.clone();
@ -55,7 +56,7 @@ async fn main() {
.with_upgrades() .with_upgrades()
.await .await
{ {
warn!("Failed to serve connection: {err:?}"); error!("Failed to serve connection: {err:?}");
} }
}); });
} }

View File

@ -143,13 +143,11 @@ impl russh::server::Handler for Handler {
} }
Err(err) => { Err(err) => {
trace!("Sending error/help message and disconnecting"); trace!("Sending error/help message and disconnecting");
session session.disconnect(
.disconnect(
russh::Disconnect::ByApplication, russh::Disconnect::ByApplication,
&format!("\n\r{err}"), &format!("\n\r{err}"),
"EN", "EN",
) )?;
.unwrap();
session.channel_failure(channel) session.channel_failure(channel)
} }

View File

@ -1,13 +1,16 @@
use bytes::Bytes; use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; use http_body_util::{BodyExt, Empty, combinators::BoxBody};
use hyper::{ use hyper::{
Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST, Request, Response, StatusCode,
body::Incoming,
client::conn::http1::Builder,
header::{self, HOST},
service::Service, service::Service,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use indexmap::IndexMap; use indexmap::IndexMap;
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc}; use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
use tracing::{debug, trace, warn}; use tracing::{debug, error, trace, warn};
use russh::{ use russh::{
Channel, Channel,
@ -17,10 +20,8 @@ use tokio::sync::RwLock;
use crate::{ use crate::{
animals::get_animal_name, animals::get_animal_name,
auth::{ auth::{AuthStatus, ForwardAuth},
AuthStatus::{Authenticated, Unauthenticated}, helper::response,
ForwardAuth,
},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -120,18 +121,7 @@ impl Service<Request<Incoming>> for Tunnels {
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future { fn call(&self, req: Request<Incoming>) -> Self::Future {
fn response( trace!("{:#?}", req);
status_code: StatusCode,
body: impl Into<String>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status_code)
.body(Full::new(Bytes::from(body.into())))
.unwrap()
.map(|b| b.map_err(|never| match never {}).boxed())
}
trace!(?req);
let Some(authority) = req let Some(authority) = req
.uri() .uri()
@ -141,15 +131,18 @@ impl Service<Request<Incoming>> for Tunnels {
.or_else(|| { .or_else(|| {
req.headers() req.headers()
.get(HOST) .get(HOST)
.map(|h| h.to_str().unwrap().to_owned()) .and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
}) })
else { else {
let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header"); let resp = response(
StatusCode::BAD_REQUEST,
"Missing or invalid authority or host header",
);
return Box::pin(async { Ok(resp) }); return Box::pin(async { Ok(resp) });
}; };
debug!(tunnel = authority, "Request"); debug!(tunnel = authority, "Tunnel request");
let s = self.clone(); let s = self.clone();
Box::pin(async move { Box::pin(async move {
@ -163,13 +156,43 @@ impl Service<Request<Incoming>> for Tunnels {
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() { if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
let user = match s.forward_auth.check_auth(req.headers()).await { let user = match s.forward_auth.check_auth(req.headers()).await {
Authenticated(user) => user, Ok(AuthStatus::Authenticated(user)) => user,
Unauthenticated(response) => return Ok(response), Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, location)
.body(
Empty::new()
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
.map_err(|never| match never {})
.boxed(),
)
.expect("configuration should be valid");
return Ok(resp);
}
Ok(AuthStatus::Unauthorized) => {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
Err(err) => {
error!("Unexpected error during authentication: {err}");
let resp = response(
StatusCode::FORBIDDEN,
"Unexpected error during authentication",
);
return Ok(resp);
}
}; };
trace!("Tunnel owned by {owner} is getting accessed by {user}"); trace!("Tunnel owned by {owner} is getting accessed by {user:?}");
if !user.eq(owner) { if !user.is(owner) {
let resp = response( let resp = response(
StatusCode::FORBIDDEN, StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel", "You do not have permission to access this tunnel",
@ -202,7 +225,7 @@ impl Service<Request<Incoming>> for Tunnels {
} }
}); });
let resp = sender.send_request(req).await.unwrap(); let resp = sender.send_request(req).await?;
Ok(resp.map(|b| b.boxed())) Ok(resp.map(|b| b.boxed()))
}) })
} }