Tunnels can now by default only be accessed by the owner, a flag is provided to make the tunnel public

This commit is contained in:
Dreaded_X 2025-04-07 17:05:21 +02:00
parent 9a1eeb9b69
commit 01457a185f
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
7 changed files with 1229 additions and 62 deletions

1106
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,6 +12,7 @@ http-body-util = { version = "0.1.3", features = ["full"] }
hyper = { version = "1.6.0", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] }
hyper-util = { version = "0.1.11", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] }
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1" russh = "0.51.1"
tokio = { version = "1.44.1", features = ["full"] } tokio = { version = "1.44.1", features = ["full"] }
tracing = "0.1.41" tracing = "0.1.41"

86
src/auth.rs Normal file
View File

@ -0,0 +1,86 @@
use bytes::Bytes;
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
use hyper::{
HeaderMap, Response,
header::{self, HeaderValue},
};
use reqwest::redirect::Policy;
use tracing::debug;
pub enum AuthStatus {
Authenticated(String),
Unauthenticated(Response<BoxBody<Bytes, hyper::Error>>),
}
#[derive(Debug, Clone)]
pub struct ForwardAuth {
address: String,
}
impl ForwardAuth {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
address: endpoint.into(),
}
}
pub async fn check_auth(&self, headers: &HeaderMap<HeaderValue>) -> AuthStatus {
let client = reqwest::ClientBuilder::new()
.redirect(Policy::none())
.build()
.unwrap();
let headers = headers
.clone()
.into_iter()
.filter_map(|(key, value)| {
if let Some(key) = key
&& key != header::CONTENT_LENGTH
&& key != header::HOST
{
Some((key, value))
} else {
None
}
})
.collect();
debug!("{headers:#?}");
let resp = client
.get(&self.address)
.headers(headers)
.send()
.await
.unwrap();
let status_code = resp.status();
if !status_code.is_success() {
debug!("{:#?}", resp.headers());
let location = resp.headers().get(header::LOCATION).unwrap().clone();
let body = resp.bytes().await.unwrap();
let resp = Response::builder()
.status(status_code)
.header(header::LOCATION, location)
.body(Full::new(body))
.unwrap()
.map(|b| b.map_err(|never| match never {}).boxed());
return AuthStatus::Unauthenticated(resp);
}
debug!("{:#?}", resp.headers());
let user = resp
.headers()
.get("remote-user")
.unwrap()
.to_str()
.unwrap()
.to_owned();
debug!("{}", resp.text().await.unwrap());
debug!("Logged in as user: {user}");
AuthStatus::Authenticated(user)
}
}

View File

@ -1,4 +1,6 @@
#![feature(impl_trait_in_fn_trait_return)] #![feature(impl_trait_in_fn_trait_return)]
#![feature(let_chains)]
pub mod animals; pub mod animals;
pub mod auth;
pub mod ssh; pub mod ssh;
pub mod tunnel; pub mod tunnel;

View File

@ -7,7 +7,7 @@ use rand::rngs::OsRng;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{info, warn}; use tracing::{info, warn};
use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt};
use tunnel_rs::ssh::Server; use tunnel_rs::{ssh::Server, tunnel::Tunnels};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -28,9 +28,11 @@ async fn main() {
let port = 3000; let port = 3000;
let domain = std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{port}")); let domain = std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{port}"));
let authz_address = std::env::var("AUTHZ_ENDPOINT")
.unwrap_or("http://localhost:9091/api/authz/forward-auth".into());
let mut ssh = Server::new(domain); let tunnels = Tunnels::new(domain, authz_address);
let tunnels = ssh.tunnels(); let mut ssh = Server::new(tunnels.clone());
let addr = SocketAddr::from(([0, 0, 0, 0], 2222)); let addr = SocketAddr::from(([0, 0, 0, 0], 2222));
tokio::spawn(async move { ssh.run(key, addr).await }); tokio::spawn(async move { ssh.run(key, addr).await });
info!("SSH is available on {addr}"); info!("SSH is available on {addr}");

View File

@ -12,7 +12,7 @@ use tokio::{
}; };
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::tunnel::{Tunnel, Tunnels}; use crate::tunnel::{Tunnel, TunnelAccess, Tunnels};
pub struct Handler { pub struct Handler {
tx: UnboundedSender<Vec<u8>>, tx: UnboundedSender<Vec<u8>>,
@ -20,6 +20,8 @@ pub struct Handler {
all_tunnels: Tunnels, all_tunnels: Tunnels,
tunnels: HashSet<String>, tunnels: HashSet<String>,
access: Option<TunnelAccess>,
} }
impl Handler { impl Handler {
@ -30,12 +32,22 @@ impl Handler {
fn sendln(&self, data: impl AsRef<str>) { fn sendln(&self, data: impl AsRef<str>) {
self.send(format!("{}\n\r", data.as_ref())); self.send(format!("{}\n\r", data.as_ref()));
} }
async fn set_access(&mut self, access: TunnelAccess) {
self.access = Some(access.clone());
for tunnel in &self.tunnels {
self.all_tunnels.set_access(tunnel, access.clone()).await;
}
}
} }
/// Quickly create http tunnels for development /// Quickly create http tunnels for development
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
struct Args { struct Args {
#[arg(short, long)]
public: bool,
} }
impl russh::server::Handler for Handler { impl russh::server::Handler for Handler {
@ -76,6 +88,8 @@ impl russh::server::Handler for Handler {
) -> Result<Auth, Self::Error> { ) -> Result<Auth, Self::Error> {
debug!("Login from {user}"); debug!("Login from {user}");
self.set_access(TunnelAccess::Private(user.into())).await;
// TODO: Get ssh keys associated with user from ldap // TODO: Get ssh keys associated with user from ldap
Ok(Auth::Accept) Ok(Auth::Accept)
} }
@ -108,6 +122,10 @@ impl russh::server::Handler for Handler {
match Args::try_parse_from(cmd) { match Args::try_parse_from(cmd) {
Ok(args) => { Ok(args) => {
debug!("{args:?}"); debug!("{args:?}");
if args.public {
trace!("Making tunnels public");
self.set_access(TunnelAccess::Public).await;
}
} }
Err(err) => { Err(err) => {
self.send(format!("{err}")); self.send(format!("{err}"));
@ -128,7 +146,11 @@ impl russh::server::Handler for Handler {
) -> Result<bool, Self::Error> { ) -> Result<bool, Self::Error> {
trace!(address, port, "tcpip_forward"); trace!(address, port, "tcpip_forward");
let tunnel = Tunnel::new(session.handle(), address, *port); let Some(access) = self.access.clone() else {
return Err(russh::Error::Inconsistent);
};
let tunnel = Tunnel::new(session.handle(), address, *port, access);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else { let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
self.sendln(format!("FAILED: ({address} already in use)")); self.sendln(format!("FAILED: ({address} already in use)"));
return Ok(false); return Ok(false);
@ -159,10 +181,8 @@ pub struct Server {
} }
impl Server { impl Server {
pub fn new(domain: impl Into<String>) -> Self { pub fn new(tunnels: Tunnels) -> Self {
Server { Server { tunnels }
tunnels: Tunnels::new(domain),
}
} }
pub fn tunnels(&self) -> Tunnels { pub fn tunnels(&self) -> Tunnels {
@ -203,6 +223,7 @@ impl russh::server::Server for Server {
rx: Some(rx), rx: Some(rx),
all_tunnels: self.tunnels.clone(), all_tunnels: self.tunnels.clone(),
tunnels: HashSet::new(), tunnels: HashSet::new(),
access: None,
} }
} }

View File

@ -18,21 +18,40 @@ use russh::{
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::animals::get_animal_name; use crate::{
animals::get_animal_name,
auth::{
AuthStatus::{Authenticated, Unauthenticated},
ForwardAuth,
},
};
#[derive(Debug, Clone)]
pub enum TunnelAccess {
Private(String),
Public,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Tunnel { pub struct Tunnel {
handle: Handle, handle: Handle,
address: String, address: String,
port: u32, port: u32,
access: TunnelAccess,
} }
impl Tunnel { impl Tunnel {
pub fn new(handle: Handle, address: impl Into<String>, port: u32) -> Self { pub fn new(
handle: Handle,
address: impl Into<String>,
port: u32,
access: TunnelAccess,
) -> Self {
Self { Self {
handle, handle,
address: address.into(), address: address.into(),
port, port,
access,
} }
} }
@ -48,13 +67,15 @@ impl Tunnel {
pub struct Tunnels { pub struct Tunnels {
tunnels: Arc<RwLock<HashMap<String, Tunnel>>>, tunnels: Arc<RwLock<HashMap<String, Tunnel>>>,
domain: String, domain: String,
forward_auth: ForwardAuth,
} }
impl Tunnels { impl Tunnels {
pub fn new(domain: impl Into<String>) -> Self { pub fn new(domain: impl Into<String>, endpoint: impl Into<String>) -> Self {
Self { Self {
tunnels: Arc::new(RwLock::new(HashMap::new())), tunnels: Arc::new(RwLock::new(HashMap::new())),
domain: domain.into(), domain: domain.into(),
forward_auth: ForwardAuth::new(endpoint),
} }
} }
@ -90,6 +111,12 @@ impl Tunnels {
all_tunnels.remove(&tunnel); all_tunnels.remove(&tunnel);
} }
} }
pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) {
if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) {
tunnel.access = access;
};
}
} }
impl Service<Request<Incoming>> for Tunnels { impl Service<Request<Incoming>> for Tunnels {
@ -129,9 +156,9 @@ impl Service<Request<Incoming>> for Tunnels {
debug!(tunnel = authority, "Request"); debug!(tunnel = authority, "Request");
let tunnels = self.tunnels.clone(); let s = self.clone();
Box::pin(async move { Box::pin(async move {
let tunnels = tunnels.read().await; let tunnels = s.tunnels.read().await;
let Some(tunnel) = tunnels.get(&authority) else { let Some(tunnel) = tunnels.get(&authority) else {
debug!(tunnel = authority, "Unknown tunnel"); debug!(tunnel = authority, "Unknown tunnel");
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
@ -139,6 +166,24 @@ impl Service<Request<Incoming>> for Tunnels {
return Ok(resp); return Ok(resp);
}; };
if let TunnelAccess::Private(owner) = &tunnel.access {
let user = match s.forward_auth.check_auth(req.headers()).await {
Authenticated(user) => user,
Unauthenticated(response) => return Ok(response),
};
trace!("Tunnel owned by {owner} is getting accessed by {user}");
if !user.eq(owner) {
let resp = response(
StatusCode::FORBIDDEN,
"You do not have permission to access this tunnel",
);
return Ok(resp);
}
}
let channel = match tunnel.open_tunnel().await { let channel = match tunnel.open_tunnel().await {
Ok(channel) => channel, Ok(channel) => channel,
Err(err) => { Err(err) => {