Improved error handling
This commit is contained in:
parent
3937e06660
commit
31532493cb
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2909,6 +2909,7 @@ dependencies = [
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"russh",
|
"russh",
|
||||||
|
"thiserror 2.0.12",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
|
|
@ -17,6 +17,7 @@ indexmap = "2.9.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
|
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
|
||||||
russh = "0.51.1"
|
russh = "0.51.1"
|
||||||
|
thiserror = "2.0.12"
|
||||||
tokio = { version = "1.44.1", features = ["full"] }
|
tokio = { version = "1.44.1", features = ["full"] }
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
|
||||||
|
|
101
src/auth.rs
101
src/auth.rs
|
@ -1,22 +1,50 @@
|
||||||
use bytes::Bytes;
|
|
||||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
|
||||||
use hyper::{
|
use hyper::{
|
||||||
HeaderMap, Response,
|
HeaderMap, StatusCode,
|
||||||
header::{self, HeaderValue},
|
header::{self, HeaderName, HeaderValue, ToStrError},
|
||||||
};
|
};
|
||||||
use reqwest::redirect::Policy;
|
use reqwest::redirect::Policy;
|
||||||
use tracing::debug;
|
use tracing::{debug, error};
|
||||||
|
|
||||||
pub enum AuthStatus {
|
|
||||||
Authenticated(String),
|
|
||||||
Unauthenticated(Response<BoxBody<Bytes, hyper::Error>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ForwardAuth {
|
pub struct ForwardAuth {
|
||||||
address: String,
|
address: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct User {
|
||||||
|
username: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl User {
|
||||||
|
pub fn is(&self, username: impl AsRef<str>) -> bool {
|
||||||
|
self.username.eq(username.as_ref())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum AuthStatus {
|
||||||
|
// Contains the value of the location header that will redirect the user to the login page
|
||||||
|
Unauthenticated(HeaderValue),
|
||||||
|
Authenticated(User),
|
||||||
|
Unauthorized,
|
||||||
|
}
|
||||||
|
|
||||||
|
const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user");
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum AuthError {
|
||||||
|
#[error("Reqwest error: {0}")]
|
||||||
|
Reqwest(#[from] reqwest::Error),
|
||||||
|
#[error("Http error: {0}")]
|
||||||
|
Http(#[from] hyper::http::Error),
|
||||||
|
#[error("Header '{0}' is missing from auth endpoint response")]
|
||||||
|
MissingHeader(HeaderName),
|
||||||
|
#[error("Header '{0}' received from auth endpoint is invalid: {1}")]
|
||||||
|
InvalidHeader(HeaderName, ToStrError),
|
||||||
|
#[error("Unexpected response from auth endpoint: {0:?}")]
|
||||||
|
UnexpectedResponse(reqwest::Response),
|
||||||
|
}
|
||||||
|
|
||||||
impl ForwardAuth {
|
impl ForwardAuth {
|
||||||
pub fn new(endpoint: impl Into<String>) -> Self {
|
pub fn new(endpoint: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -24,11 +52,13 @@ impl ForwardAuth {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn check_auth(&self, headers: &HeaderMap<HeaderValue>) -> AuthStatus {
|
pub async fn check_auth(
|
||||||
|
&self,
|
||||||
|
headers: &HeaderMap<HeaderValue>,
|
||||||
|
) -> Result<AuthStatus, AuthError> {
|
||||||
let client = reqwest::ClientBuilder::new()
|
let client = reqwest::ClientBuilder::new()
|
||||||
.redirect(Policy::none())
|
.redirect(Policy::none())
|
||||||
.build()
|
.build()?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let headers = headers
|
let headers = headers
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -45,42 +75,33 @@ impl ForwardAuth {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
debug!("{headers:#?}");
|
let resp = client.get(&self.address).headers(headers).send().await?;
|
||||||
|
|
||||||
let resp = client
|
|
||||||
.get(&self.address)
|
|
||||||
.headers(headers)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let status_code = resp.status();
|
let status_code = resp.status();
|
||||||
if !status_code.is_success() {
|
if status_code == StatusCode::FOUND {
|
||||||
debug!("{:#?}", resp.headers());
|
let location = resp
|
||||||
let location = resp.headers().get(header::LOCATION).unwrap().clone();
|
.headers()
|
||||||
let body = resp.bytes().await.unwrap();
|
.get(header::LOCATION)
|
||||||
let resp = Response::builder()
|
.cloned()
|
||||||
.status(status_code)
|
.ok_or(AuthError::MissingHeader(header::LOCATION))?;
|
||||||
.header(header::LOCATION, location)
|
|
||||||
.body(Full::new(body))
|
|
||||||
.unwrap()
|
|
||||||
.map(|b| b.map_err(|never| match never {}).boxed());
|
|
||||||
|
|
||||||
return AuthStatus::Unauthenticated(resp);
|
return Ok(AuthStatus::Unauthenticated(location));
|
||||||
|
} else if status_code == StatusCode::FORBIDDEN {
|
||||||
|
return Ok(AuthStatus::Unauthorized);
|
||||||
|
} else if !status_code.is_success() {
|
||||||
|
return Err(AuthError::UnexpectedResponse(resp));
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("{:#?}", resp.headers());
|
let username = resp
|
||||||
let user = resp
|
|
||||||
.headers()
|
.headers()
|
||||||
.get("remote-user")
|
.get(REMOTE_USER)
|
||||||
.unwrap()
|
.ok_or(AuthError::MissingHeader(REMOTE_USER))?
|
||||||
.to_str()
|
.to_str()
|
||||||
.unwrap()
|
.map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))?
|
||||||
.to_owned();
|
.to_owned();
|
||||||
debug!("{}", resp.text().await.unwrap());
|
|
||||||
|
|
||||||
debug!("Logged in as user: {user}");
|
debug!("Connected user is: {username}");
|
||||||
|
|
||||||
AuthStatus::Authenticated(user)
|
Ok(AuthStatus::Authenticated(User { username }))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,12 +18,9 @@ fn main() -> color_eyre::Result<()> {
|
||||||
|
|
||||||
color_eyre::install()?;
|
color_eyre::install()?;
|
||||||
|
|
||||||
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)
|
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?;
|
||||||
.expect("algorithm should be supported");
|
|
||||||
|
|
||||||
let key = key
|
let key = key.to_openssh(LineEnding::LF)?;
|
||||||
.to_openssh(LineEnding::LF)
|
|
||||||
.expect("encodig should not fail");
|
|
||||||
|
|
||||||
args.output
|
args.output
|
||||||
.write(key.as_bytes())
|
.write(key.as_bytes())
|
||||||
|
|
14
src/helper.rs
Normal file
14
src/helper.rs
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
use bytes::Bytes;
|
||||||
|
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
||||||
|
use hyper::{Response, StatusCode};
|
||||||
|
|
||||||
|
pub fn response(
|
||||||
|
status_code: StatusCode,
|
||||||
|
body: impl Into<String>,
|
||||||
|
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||||
|
Response::builder()
|
||||||
|
.status(status_code)
|
||||||
|
.body(Full::new(Bytes::from(body.into())))
|
||||||
|
.expect("all configuration should be valid")
|
||||||
|
.map(|b| b.map_err(|never| match never {}).boxed())
|
||||||
|
}
|
|
@ -2,5 +2,6 @@
|
||||||
#![feature(let_chains)]
|
#![feature(let_chains)]
|
||||||
pub mod animals;
|
pub mod animals;
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
pub mod helper;
|
||||||
pub mod ssh;
|
pub mod ssh;
|
||||||
pub mod tunnel;
|
pub mod tunnel;
|
||||||
|
|
21
src/main.rs
21
src/main.rs
|
@ -1,29 +1,30 @@
|
||||||
use std::{net::SocketAddr, path::Path};
|
use std::{net::SocketAddr, path::Path};
|
||||||
|
|
||||||
|
use color_eyre::eyre::Context;
|
||||||
use dotenvy::dotenv;
|
use dotenvy::dotenv;
|
||||||
use hyper::server::conn::http1::{self};
|
use hyper::server::conn::http1::{self};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tracing::{info, warn};
|
use tracing::{error, info};
|
||||||
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
|
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
|
||||||
use tunnel_rs::{ssh::Server, tunnel::Tunnels};
|
use tunnel_rs::{ssh::Server, tunnel::Tunnels};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() -> color_eyre::Result<()> {
|
||||||
|
color_eyre::install()?;
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
|
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
|
||||||
.or_else(|_| EnvFilter::try_new("info"))
|
|
||||||
.expect("Fallback should be valid");
|
|
||||||
|
|
||||||
let logger = tracing_subscriber::fmt::layer().compact();
|
let logger = tracing_subscriber::fmt::layer().compact();
|
||||||
Registry::default().with(logger).with(env_filter).init();
|
Registry::default().with(logger).with(env_filter).init();
|
||||||
|
|
||||||
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
|
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
|
||||||
russh::keys::PrivateKey::read_openssh_file(Path::new(&path)).unwrap()
|
russh::keys::PrivateKey::read_openssh_file(Path::new(&path))
|
||||||
|
.wrap_err_with(|| format!("failed to read ssh key: {path}"))?
|
||||||
} else {
|
} else {
|
||||||
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap()
|
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let port = 3000;
|
let port = 3000;
|
||||||
|
@ -38,12 +39,12 @@ async fn main() {
|
||||||
info!("SSH is available on {addr}");
|
info!("SSH is available on {addr}");
|
||||||
|
|
||||||
let addr = SocketAddr::from(([0, 0, 0, 0], port));
|
let addr = SocketAddr::from(([0, 0, 0, 0], port));
|
||||||
let listener = TcpListener::bind(addr).await.unwrap();
|
let listener = TcpListener::bind(addr).await?;
|
||||||
info!("HTTP is available on {addr}");
|
info!("HTTP is available on {addr}");
|
||||||
|
|
||||||
// TODO: Graceful shutdown
|
// TODO: Graceful shutdown
|
||||||
loop {
|
loop {
|
||||||
let (stream, _) = listener.accept().await.unwrap();
|
let (stream, _) = listener.accept().await?;
|
||||||
let io = TokioIo::new(stream);
|
let io = TokioIo::new(stream);
|
||||||
|
|
||||||
let tunnels = tunnels.clone();
|
let tunnels = tunnels.clone();
|
||||||
|
@ -55,7 +56,7 @@ async fn main() {
|
||||||
.with_upgrades()
|
.with_upgrades()
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
warn!("Failed to serve connection: {err:?}");
|
error!("Failed to serve connection: {err:?}");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -143,13 +143,11 @@ impl russh::server::Handler for Handler {
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
trace!("Sending error/help message and disconnecting");
|
trace!("Sending error/help message and disconnecting");
|
||||||
session
|
session.disconnect(
|
||||||
.disconnect(
|
|
||||||
russh::Disconnect::ByApplication,
|
russh::Disconnect::ByApplication,
|
||||||
&format!("\n\r{err}"),
|
&format!("\n\r{err}"),
|
||||||
"EN",
|
"EN",
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
session.channel_failure(channel)
|
session.channel_failure(channel)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
|
||||||
use hyper::{
|
use hyper::{
|
||||||
Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST,
|
Request, Response, StatusCode,
|
||||||
|
body::Incoming,
|
||||||
|
client::conn::http1::Builder,
|
||||||
|
header::{self, HOST},
|
||||||
service::Service,
|
service::Service,
|
||||||
};
|
};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
|
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, error, trace, warn};
|
||||||
|
|
||||||
use russh::{
|
use russh::{
|
||||||
Channel,
|
Channel,
|
||||||
|
@ -17,10 +20,8 @@ use tokio::sync::RwLock;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
animals::get_animal_name,
|
animals::get_animal_name,
|
||||||
auth::{
|
auth::{AuthStatus, ForwardAuth},
|
||||||
AuthStatus::{Authenticated, Unauthenticated},
|
helper::response,
|
||||||
ForwardAuth,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -120,18 +121,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||||
fn response(
|
trace!("{:#?}", req);
|
||||||
status_code: StatusCode,
|
|
||||||
body: impl Into<String>,
|
|
||||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
|
||||||
Response::builder()
|
|
||||||
.status(status_code)
|
|
||||||
.body(Full::new(Bytes::from(body.into())))
|
|
||||||
.unwrap()
|
|
||||||
.map(|b| b.map_err(|never| match never {}).boxed())
|
|
||||||
}
|
|
||||||
|
|
||||||
trace!(?req);
|
|
||||||
|
|
||||||
let Some(authority) = req
|
let Some(authority) = req
|
||||||
.uri()
|
.uri()
|
||||||
|
@ -141,15 +131,18 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
req.headers()
|
req.headers()
|
||||||
.get(HOST)
|
.get(HOST)
|
||||||
.map(|h| h.to_str().unwrap().to_owned())
|
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
|
||||||
})
|
})
|
||||||
else {
|
else {
|
||||||
let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header");
|
let resp = response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Missing or invalid authority or host header",
|
||||||
|
);
|
||||||
|
|
||||||
return Box::pin(async { Ok(resp) });
|
return Box::pin(async { Ok(resp) });
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!(tunnel = authority, "Request");
|
debug!(tunnel = authority, "Tunnel request");
|
||||||
|
|
||||||
let s = self.clone();
|
let s = self.clone();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
|
@ -163,13 +156,43 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||||
|
|
||||||
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
|
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
|
||||||
let user = match s.forward_auth.check_auth(req.headers()).await {
|
let user = match s.forward_auth.check_auth(req.headers()).await {
|
||||||
Authenticated(user) => user,
|
Ok(AuthStatus::Authenticated(user)) => user,
|
||||||
Unauthenticated(response) => return Ok(response),
|
Ok(AuthStatus::Unauthenticated(location)) => {
|
||||||
|
let resp = Response::builder()
|
||||||
|
.status(StatusCode::FOUND)
|
||||||
|
.header(header::LOCATION, location)
|
||||||
|
.body(
|
||||||
|
Empty::new()
|
||||||
|
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
|
||||||
|
.map_err(|never| match never {})
|
||||||
|
.boxed(),
|
||||||
|
)
|
||||||
|
.expect("configuration should be valid");
|
||||||
|
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
Ok(AuthStatus::Unauthorized) => {
|
||||||
|
let resp = response(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"You do not have permission to access this tunnel",
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
error!("Unexpected error during authentication: {err}");
|
||||||
|
let resp = response(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"Unexpected error during authentication",
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("Tunnel owned by {owner} is getting accessed by {user}");
|
trace!("Tunnel owned by {owner} is getting accessed by {user:?}");
|
||||||
|
|
||||||
if !user.eq(owner) {
|
if !user.is(owner) {
|
||||||
let resp = response(
|
let resp = response(
|
||||||
StatusCode::FORBIDDEN,
|
StatusCode::FORBIDDEN,
|
||||||
"You do not have permission to access this tunnel",
|
"You do not have permission to access this tunnel",
|
||||||
|
@ -202,7 +225,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let resp = sender.send_request(req).await.unwrap();
|
let resp = sender.send_request(req).await?;
|
||||||
Ok(resp.map(|b| b.boxed()))
|
Ok(resp.map(|b| b.boxed()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user