Improved error handling

This commit is contained in:
Dreaded_X 2025-04-08 15:16:47 +02:00
parent 3937e06660
commit 31532493cb
Signed by: Dreaded_X
GPG Key ID: 5A0CBFE3C3377FAA
9 changed files with 146 additions and 89 deletions

1
Cargo.lock generated
View File

@ -2909,6 +2909,7 @@ dependencies = [
"rand 0.8.5",
"reqwest",
"russh",
"thiserror 2.0.12",
"tokio",
"tracing",
"tracing-subscriber",

View File

@ -17,6 +17,7 @@ indexmap = "2.9.0"
rand = "0.8.5"
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1"
thiserror = "2.0.12"
tokio = { version = "1.44.1", features = ["full"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }

View File

@ -1,22 +1,50 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use hyper::{
HeaderMap, Response,
header::{self, HeaderValue},
HeaderMap, StatusCode,
header::{self, HeaderName, HeaderValue, ToStrError},
};
use reqwest::redirect::Policy;
use tracing::debug;
pub enum AuthStatus {
Authenticated(String),
Unauthenticated(Response<BoxBody<Bytes, hyper::Error>>),
}
use tracing::{debug, error};
#[derive(Debug, Clone)]
pub struct ForwardAuth {
address: String,
}
#[derive(Debug)]
pub struct User {
username: String,
}
impl User {
pub fn is(&self, username: impl AsRef<str>) -> bool {
self.username.eq(username.as_ref())
}
}
#[derive(Debug)]
pub enum AuthStatus {
// Contains the value of the location header that will redirect the user to the login page
Unauthenticated(HeaderValue),
Authenticated(User),
Unauthorized,
}
const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user");
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Http error: {0}")]
Http(#[from] hyper::http::Error),
#[error("Header '{0}' is missing from auth endpoint response")]
MissingHeader(HeaderName),
#[error("Header '{0}' received from auth endpoint is invalid: {1}")]
InvalidHeader(HeaderName, ToStrError),
#[error("Unexpected response from auth endpoint: {0:?}")]
UnexpectedResponse(reqwest::Response),
}
impl ForwardAuth {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
@ -24,11 +52,13 @@ impl ForwardAuth {
}
}
pub async fn check_auth(&self, headers: &HeaderMap<HeaderValue>) -> AuthStatus {
pub async fn check_auth(
&self,
headers: &HeaderMap<HeaderValue>,
) -> Result<AuthStatus, AuthError> {
let client = reqwest::ClientBuilder::new()
.redirect(Policy::none())
.build()
.unwrap();
.build()?;
let headers = headers
.clone()
@ -45,42 +75,33 @@ impl ForwardAuth {
})
.collect();
debug!("{headers:#?}");
let resp = client
.get(&self.address)
.headers(headers)
.send()
.await
.unwrap();
let resp = client.get(&self.address).headers(headers).send().await?;
let status_code = resp.status();
if !status_code.is_success() {
debug!("{:#?}", resp.headers());
let location = resp.headers().get(header::LOCATION).unwrap().clone();
let body = resp.bytes().await.unwrap();
let resp = Response::builder()
.status(status_code)
.header(header::LOCATION, location)
.body(Full::new(body))
.unwrap()
.map(|b| b.map_err(|never| match never {}).boxed());
if status_code == StatusCode::FOUND {
let location = resp
.headers()
.get(header::LOCATION)
.cloned()
.ok_or(AuthError::MissingHeader(header::LOCATION))?;
return AuthStatus::Unauthenticated(resp);
return Ok(AuthStatus::Unauthenticated(location));
} else if status_code == StatusCode::FORBIDDEN {
return Ok(AuthStatus::Unauthorized);
} else if !status_code.is_success() {
return Err(AuthError::UnexpectedResponse(resp));
}
debug!("{:#?}", resp.headers());
let user = resp
let username = resp
.headers()
.get("remote-user")
.unwrap()
.get(REMOTE_USER)
.ok_or(AuthError::MissingHeader(REMOTE_USER))?
.to_str()
.unwrap()
.map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))?
.to_owned();
debug!("{}", resp.text().await.unwrap());
debug!("Logged in as user: {user}");
debug!("Connected user is: {username}");
AuthStatus::Authenticated(user)
Ok(AuthStatus::Authenticated(User { username }))
}
}

View File

@ -18,12 +18,9 @@ fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)
.expect("algorithm should be supported");
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?;
let key = key
.to_openssh(LineEnding::LF)
.expect("encodig should not fail");
let key = key.to_openssh(LineEnding::LF)?;
args.output
.write(key.as_bytes())

14
src/helper.rs Normal file
View File

@ -0,0 +1,14 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use hyper::{Response, StatusCode};
pub fn response(
status_code: StatusCode,
body: impl Into<String>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status_code)
.body(Full::new(Bytes::from(body.into())))
.expect("all configuration should be valid")
.map(|b| b.map_err(|never| match never {}).boxed())
}

View File

@ -2,5 +2,6 @@
#![feature(let_chains)]
pub mod animals;
pub mod auth;
pub mod helper;
pub mod ssh;
pub mod tunnel;

View File

@ -1,29 +1,30 @@
use std::{net::SocketAddr, path::Path};
use color_eyre::eyre::Context;
use dotenvy::dotenv;
use hyper::server::conn::http1::{self};
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use tokio::net::TcpListener;
use tracing::{info, warn};
use tracing::{error, info};
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
use tunnel_rs::{ssh::Server, tunnel::Tunnels};
#[tokio::main]
async fn main() {
async fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
dotenv().ok();
let env_filter = EnvFilter::try_from_default_env()
.or_else(|_| EnvFilter::try_new("info"))
.expect("Fallback should be valid");
let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
let logger = tracing_subscriber::fmt::layer().compact();
Registry::default().with(logger).with(env_filter).init();
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
russh::keys::PrivateKey::read_openssh_file(Path::new(&path)).unwrap()
russh::keys::PrivateKey::read_openssh_file(Path::new(&path))
.wrap_err_with(|| format!("failed to read ssh key: {path}"))?
} else {
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap()
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?
};
let port = 3000;
@ -38,12 +39,12 @@ async fn main() {
info!("SSH is available on {addr}");
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await.unwrap();
let listener = TcpListener::bind(addr).await?;
info!("HTTP is available on {addr}");
// TODO: Graceful shutdown
loop {
let (stream, _) = listener.accept().await.unwrap();
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let tunnels = tunnels.clone();
@ -55,7 +56,7 @@ async fn main() {
.with_upgrades()
.await
{
warn!("Failed to serve connection: {err:?}");
error!("Failed to serve connection: {err:?}");
}
});
}

View File

@ -143,13 +143,11 @@ impl russh::server::Handler for Handler {
}
Err(err) => {
trace!("Sending error/help message and disconnecting");
session
.disconnect(
russh::Disconnect::ByApplication,
&format!("\n\r{err}"),
"EN",
)
.unwrap();
session.disconnect(
russh::Disconnect::ByApplication,
&format!("\n\r{err}"),
"EN",
)?;
session.channel_failure(channel)
}

View File

@ -1,13 +1,16 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
use hyper::{
Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST,
Request, Response, StatusCode,
body::Incoming,
client::conn::http1::Builder,
header::{self, HOST},
service::Service,
};
use hyper_util::rt::TokioIo;
use indexmap::IndexMap;
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
use tracing::{debug, trace, warn};
use tracing::{debug, error, trace, warn};
use russh::{
Channel,
@ -17,10 +20,8 @@ use tokio::sync::RwLock;
use crate::{
animals::get_animal_name,
auth::{
AuthStatus::{Authenticated, Unauthenticated},
ForwardAuth,
},
auth::{AuthStatus, ForwardAuth},
helper::response,
};
#[derive(Debug, Clone)]
@ -120,18 +121,7 @@ impl Service<Request<Incoming>> for Tunnels {
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
fn response(
status_code: StatusCode,
body: impl Into<String>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status_code)
.body(Full::new(Bytes::from(body.into())))
.unwrap()
.map(|b| b.map_err(|never| match never {}).boxed())
}
trace!(?req);
trace!("{:#?}", req);
let Some(authority) = req
.uri()
@ -141,15 +131,18 @@ impl Service<Request<Incoming>> for Tunnels {
.or_else(|| {
req.headers()
.get(HOST)
.map(|h| h.to_str().unwrap().to_owned())
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
})
else {
let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header");
let resp = response(
StatusCode::BAD_REQUEST,
"Missing or invalid authority or host header",
);
return Box::pin(async { Ok(resp) });
};
debug!(tunnel = authority, "Request");
debug!(tunnel = authority, "Tunnel request");
let s = self.clone();
Box::pin(async move {
@ -163,13 +156,43 @@ impl Service<Request<Incoming>> for Tunnels {
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
let user = match s.forward_auth.check_auth(req.headers()).await {
Authenticated(user) => user,
Unauthenticated(response) => return Ok(response),
Ok(AuthStatus::Authenticated(user)) => user,
Ok(AuthStatus::Unauthenticated(location)) => {
let resp = Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, location)
.body(
Empty::new()
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
.map_err(|never| match never {})
.boxed(),
)
.expect("configuration should be valid");
return Ok(resp);
}
Ok(AuthStatus::Unauthorized) => {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
Err(err) => {
error!("Unexpected error during authentication: {err}");
let resp = response(
StatusCode::FORBIDDEN,
"Unexpected error during authentication",
);
return Ok(resp);
}
};
trace!("Tunnel owned by {owner} is getting accessed by {user}");
trace!("Tunnel owned by {owner} is getting accessed by {user:?}");
if !user.eq(owner) {
if !user.is(owner) {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
@ -202,7 +225,7 @@ impl Service<Request<Incoming>> for Tunnels {
}
});
let resp = sender.send_request(req).await.unwrap();
let resp = sender.send_request(req).await?;
Ok(resp.map(|b| b.boxed()))
})
}