Improved error handling
This commit is contained in:
parent
3937e06660
commit
31532493cb
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2909,6 +2909,7 @@ dependencies = [
|
|||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"russh",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
|
|
@ -17,6 +17,7 @@ indexmap = "2.9.0"
|
|||
rand = "0.8.5"
|
||||
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
|
||||
russh = "0.51.1"
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1.44.1", features = ["full"] }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
|
||||
|
|
107
src/auth.rs
107
src/auth.rs
|
@ -1,22 +1,50 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
||||
use hyper::{
|
||||
HeaderMap, Response,
|
||||
header::{self, HeaderValue},
|
||||
HeaderMap, StatusCode,
|
||||
header::{self, HeaderName, HeaderValue, ToStrError},
|
||||
};
|
||||
use reqwest::redirect::Policy;
|
||||
use tracing::debug;
|
||||
|
||||
pub enum AuthStatus {
|
||||
Authenticated(String),
|
||||
Unauthenticated(Response<BoxBody<Bytes, hyper::Error>>),
|
||||
}
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ForwardAuth {
|
||||
address: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct User {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn is(&self, username: impl AsRef<str>) -> bool {
|
||||
self.username.eq(username.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthStatus {
|
||||
// Contains the value of the location header that will redirect the user to the login page
|
||||
Unauthenticated(HeaderValue),
|
||||
Authenticated(User),
|
||||
Unauthorized,
|
||||
}
|
||||
|
||||
const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user");
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("Reqwest error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("Http error: {0}")]
|
||||
Http(#[from] hyper::http::Error),
|
||||
#[error("Header '{0}' is missing from auth endpoint response")]
|
||||
MissingHeader(HeaderName),
|
||||
#[error("Header '{0}' received from auth endpoint is invalid: {1}")]
|
||||
InvalidHeader(HeaderName, ToStrError),
|
||||
#[error("Unexpected response from auth endpoint: {0:?}")]
|
||||
UnexpectedResponse(reqwest::Response),
|
||||
}
|
||||
|
||||
impl ForwardAuth {
|
||||
pub fn new(endpoint: impl Into<String>) -> Self {
|
||||
Self {
|
||||
|
@ -24,11 +52,13 @@ impl ForwardAuth {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn check_auth(&self, headers: &HeaderMap<HeaderValue>) -> AuthStatus {
|
||||
pub async fn check_auth(
|
||||
&self,
|
||||
headers: &HeaderMap<HeaderValue>,
|
||||
) -> Result<AuthStatus, AuthError> {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.redirect(Policy::none())
|
||||
.build()
|
||||
.unwrap();
|
||||
.build()?;
|
||||
|
||||
let headers = headers
|
||||
.clone()
|
||||
|
@ -45,42 +75,33 @@ impl ForwardAuth {
|
|||
})
|
||||
.collect();
|
||||
|
||||
debug!("{headers:#?}");
|
||||
|
||||
let resp = client
|
||||
.get(&self.address)
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let resp = client.get(&self.address).headers(headers).send().await?;
|
||||
|
||||
let status_code = resp.status();
|
||||
if !status_code.is_success() {
|
||||
debug!("{:#?}", resp.headers());
|
||||
let location = resp.headers().get(header::LOCATION).unwrap().clone();
|
||||
let body = resp.bytes().await.unwrap();
|
||||
let resp = Response::builder()
|
||||
.status(status_code)
|
||||
.header(header::LOCATION, location)
|
||||
.body(Full::new(body))
|
||||
.unwrap()
|
||||
.map(|b| b.map_err(|never| match never {}).boxed());
|
||||
|
||||
return AuthStatus::Unauthenticated(resp);
|
||||
}
|
||||
|
||||
debug!("{:#?}", resp.headers());
|
||||
let user = resp
|
||||
if status_code == StatusCode::FOUND {
|
||||
let location = resp
|
||||
.headers()
|
||||
.get("remote-user")
|
||||
.unwrap()
|
||||
.get(header::LOCATION)
|
||||
.cloned()
|
||||
.ok_or(AuthError::MissingHeader(header::LOCATION))?;
|
||||
|
||||
return Ok(AuthStatus::Unauthenticated(location));
|
||||
} else if status_code == StatusCode::FORBIDDEN {
|
||||
return Ok(AuthStatus::Unauthorized);
|
||||
} else if !status_code.is_success() {
|
||||
return Err(AuthError::UnexpectedResponse(resp));
|
||||
}
|
||||
|
||||
let username = resp
|
||||
.headers()
|
||||
.get(REMOTE_USER)
|
||||
.ok_or(AuthError::MissingHeader(REMOTE_USER))?
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))?
|
||||
.to_owned();
|
||||
debug!("{}", resp.text().await.unwrap());
|
||||
|
||||
debug!("Logged in as user: {user}");
|
||||
debug!("Connected user is: {username}");
|
||||
|
||||
AuthStatus::Authenticated(user)
|
||||
Ok(AuthStatus::Authenticated(User { username }))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,12 +18,9 @@ fn main() -> color_eyre::Result<()> {
|
|||
|
||||
color_eyre::install()?;
|
||||
|
||||
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)
|
||||
.expect("algorithm should be supported");
|
||||
let key = russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?;
|
||||
|
||||
let key = key
|
||||
.to_openssh(LineEnding::LF)
|
||||
.expect("encodig should not fail");
|
||||
let key = key.to_openssh(LineEnding::LF)?;
|
||||
|
||||
args.output
|
||||
.write(key.as_bytes())
|
||||
|
|
14
src/helper.rs
Normal file
14
src/helper.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
||||
use hyper::{Response, StatusCode};
|
||||
|
||||
pub fn response(
|
||||
status_code: StatusCode,
|
||||
body: impl Into<String>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status_code)
|
||||
.body(Full::new(Bytes::from(body.into())))
|
||||
.expect("all configuration should be valid")
|
||||
.map(|b| b.map_err(|never| match never {}).boxed())
|
||||
}
|
|
@ -2,5 +2,6 @@
|
|||
#![feature(let_chains)]
|
||||
pub mod animals;
|
||||
pub mod auth;
|
||||
pub mod helper;
|
||||
pub mod ssh;
|
||||
pub mod tunnel;
|
||||
|
|
21
src/main.rs
21
src/main.rs
|
@ -1,29 +1,30 @@
|
|||
use std::{net::SocketAddr, path::Path};
|
||||
|
||||
use color_eyre::eyre::Context;
|
||||
use dotenvy::dotenv;
|
||||
use hyper::server::conn::http1::{self};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rand::rngs::OsRng;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use tunnel_rs::{ssh::Server, tunnel::Tunnels};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
async fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
dotenv().ok();
|
||||
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.or_else(|_| EnvFilter::try_new("info"))
|
||||
.expect("Fallback should be valid");
|
||||
let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
|
||||
|
||||
let logger = tracing_subscriber::fmt::layer().compact();
|
||||
Registry::default().with(logger).with(env_filter).init();
|
||||
|
||||
let key = if let Ok(path) = std::env::var("PRIVATE_KEY_FILE") {
|
||||
russh::keys::PrivateKey::read_openssh_file(Path::new(&path)).unwrap()
|
||||
russh::keys::PrivateKey::read_openssh_file(Path::new(&path))
|
||||
.wrap_err_with(|| format!("failed to read ssh key: {path}"))?
|
||||
} else {
|
||||
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap()
|
||||
russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)?
|
||||
};
|
||||
|
||||
let port = 3000;
|
||||
|
@ -38,12 +39,12 @@ async fn main() {
|
|||
info!("SSH is available on {addr}");
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], port));
|
||||
let listener = TcpListener::bind(addr).await.unwrap();
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
info!("HTTP is available on {addr}");
|
||||
|
||||
// TODO: Graceful shutdown
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let tunnels = tunnels.clone();
|
||||
|
@ -55,7 +56,7 @@ async fn main() {
|
|||
.with_upgrades()
|
||||
.await
|
||||
{
|
||||
warn!("Failed to serve connection: {err:?}");
|
||||
error!("Failed to serve connection: {err:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -143,13 +143,11 @@ impl russh::server::Handler for Handler {
|
|||
}
|
||||
Err(err) => {
|
||||
trace!("Sending error/help message and disconnecting");
|
||||
session
|
||||
.disconnect(
|
||||
session.disconnect(
|
||||
russh::Disconnect::ByApplication,
|
||||
&format!("\n\r{err}"),
|
||||
"EN",
|
||||
)
|
||||
.unwrap();
|
||||
)?;
|
||||
|
||||
session.channel_failure(channel)
|
||||
}
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
||||
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
|
||||
use hyper::{
|
||||
Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST,
|
||||
Request, Response, StatusCode,
|
||||
body::Incoming,
|
||||
client::conn::http1::Builder,
|
||||
header::{self, HOST},
|
||||
service::Service,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use indexmap::IndexMap;
|
||||
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
|
||||
use tracing::{debug, trace, warn};
|
||||
use tracing::{debug, error, trace, warn};
|
||||
|
||||
use russh::{
|
||||
Channel,
|
||||
|
@ -17,10 +20,8 @@ use tokio::sync::RwLock;
|
|||
|
||||
use crate::{
|
||||
animals::get_animal_name,
|
||||
auth::{
|
||||
AuthStatus::{Authenticated, Unauthenticated},
|
||||
ForwardAuth,
|
||||
},
|
||||
auth::{AuthStatus, ForwardAuth},
|
||||
helper::response,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -120,18 +121,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
|||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||
fn response(
|
||||
status_code: StatusCode,
|
||||
body: impl Into<String>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status_code)
|
||||
.body(Full::new(Bytes::from(body.into())))
|
||||
.unwrap()
|
||||
.map(|b| b.map_err(|never| match never {}).boxed())
|
||||
}
|
||||
|
||||
trace!(?req);
|
||||
trace!("{:#?}", req);
|
||||
|
||||
let Some(authority) = req
|
||||
.uri()
|
||||
|
@ -141,15 +131,18 @@ impl Service<Request<Incoming>> for Tunnels {
|
|||
.or_else(|| {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.map(|h| h.to_str().unwrap().to_owned())
|
||||
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
|
||||
})
|
||||
else {
|
||||
let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header");
|
||||
let resp = response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing or invalid authority or host header",
|
||||
);
|
||||
|
||||
return Box::pin(async { Ok(resp) });
|
||||
};
|
||||
|
||||
debug!(tunnel = authority, "Request");
|
||||
debug!(tunnel = authority, "Tunnel request");
|
||||
|
||||
let s = self.clone();
|
||||
Box::pin(async move {
|
||||
|
@ -163,13 +156,43 @@ impl Service<Request<Incoming>> for Tunnels {
|
|||
|
||||
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
|
||||
let user = match s.forward_auth.check_auth(req.headers()).await {
|
||||
Authenticated(user) => user,
|
||||
Unauthenticated(response) => return Ok(response),
|
||||
Ok(AuthStatus::Authenticated(user)) => user,
|
||||
Ok(AuthStatus::Unauthenticated(location)) => {
|
||||
let resp = Response::builder()
|
||||
.status(StatusCode::FOUND)
|
||||
.header(header::LOCATION, location)
|
||||
.body(
|
||||
Empty::new()
|
||||
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
|
||||
.map_err(|never| match never {})
|
||||
.boxed(),
|
||||
)
|
||||
.expect("configuration should be valid");
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
Ok(AuthStatus::Unauthorized) => {
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"You do not have permission to access this tunnel",
|
||||
);
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Unexpected error during authentication: {err}");
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Unexpected error during authentication",
|
||||
);
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Tunnel owned by {owner} is getting accessed by {user}");
|
||||
trace!("Tunnel owned by {owner} is getting accessed by {user:?}");
|
||||
|
||||
if !user.eq(owner) {
|
||||
if !user.is(owner) {
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"You do not have permission to access this tunnel",
|
||||
|
@ -202,7 +225,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
|||
}
|
||||
});
|
||||
|
||||
let resp = sender.send_request(req).await.unwrap();
|
||||
let resp = sender.send_request(req).await?;
|
||||
Ok(resp.map(|b| b.boxed()))
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user