Store tunnel access in Arc<RwLock> and keep a copy in the handler

This should make it easier to access a tunnel in the handler and to
modify its access
This commit is contained in:
Dreaded_X 2025-04-08 02:01:18 +02:00
parent 12398b8e5a
commit 5213aee232
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
4 changed files with 27 additions and 29 deletions

7
Cargo.lock generated
View File

@ -1192,9 +1192,9 @@ dependencies = [
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.8.0" version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown", "hashbrown",
@ -1800,7 +1800,7 @@ dependencies = [
"once_cell", "once_cell",
"socket2", "socket2",
"tracing", "tracing",
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@ -2788,6 +2788,7 @@ dependencies = [
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-util", "hyper-util",
"indexmap",
"rand 0.8.5", "rand 0.8.5",
"reqwest", "reqwest",
"russh", "russh",

View File

@ -11,6 +11,7 @@ dotenvy = "0.15.7"
http-body-util = { version = "0.1.3", features = ["full"] } http-body-util = { version = "0.1.3", features = ["full"] }
hyper = { version = "1.6.0", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] }
hyper-util = { version = "0.1.11", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] }
indexmap = "2.9.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.12.15", features = ["rustls-tls"] } reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1" russh = "0.51.1"

View File

@ -1,6 +1,7 @@
use std::{collections::HashSet, iter::once, net::SocketAddr, sync::Arc, time::Duration}; use std::{iter::once, net::SocketAddr, sync::Arc, time::Duration};
use clap::Parser; use clap::Parser;
use indexmap::IndexMap;
use russh::{ use russh::{
ChannelId, ChannelId,
keys::PrivateKey, keys::PrivateKey,
@ -19,7 +20,7 @@ pub struct Handler {
rx: Option<UnboundedReceiver<Vec<u8>>>, rx: Option<UnboundedReceiver<Vec<u8>>>,
all_tunnels: Tunnels, all_tunnels: Tunnels,
tunnels: HashSet<String>, tunnels: IndexMap<String, Tunnel>,
access: Option<TunnelAccess>, access: Option<TunnelAccess>,
} }
@ -36,8 +37,8 @@ impl Handler {
async fn set_access(&mut self, access: TunnelAccess) { async fn set_access(&mut self, access: TunnelAccess) {
self.access = Some(access.clone()); self.access = Some(access.clone());
for tunnel in &self.tunnels { for (_address, tunnel) in &self.tunnels {
self.all_tunnels.set_access(tunnel, access.clone()).await; tunnel.set_access(access.clone()).await;
} }
} }
} }
@ -156,7 +157,7 @@ impl russh::server::Handler for Handler {
}; };
let tunnel = Tunnel::new(session.handle(), address, *port, access); let tunnel = Tunnel::new(session.handle(), address, *port, access);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else { let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else {
self.sendln(format!("FAILED: ({address} already in use)")); self.sendln(format!("FAILED: ({address} already in use)"));
return Ok(false); return Ok(false);
}; };
@ -164,7 +165,7 @@ impl russh::server::Handler for Handler {
// NOTE: The port we receive might not be the port that is getting forwarded from the // NOTE: The port we receive might not be the port that is getting forwarded from the
// client, we could include it in the message we send // client, we could include it in the message we send
self.sendln(format!("http://{address}")); self.sendln(format!("http://{address}"));
self.tunnels.insert(address); self.tunnels.insert(address, tunnel);
Ok(true) Ok(true)
} }
@ -176,7 +177,7 @@ impl Drop for Handler {
let mut all_tunnels = self.all_tunnels.clone(); let mut all_tunnels = self.all_tunnels.clone();
tokio::spawn(async move { tokio::spawn(async move {
all_tunnels.remove_tunnels(tunnels.clone()).await; all_tunnels.remove_tunnels(&tunnels).await;
}); });
} }
} }
@ -227,7 +228,7 @@ impl russh::server::Server for Server {
tx, tx,
rx: Some(rx), rx: Some(rx),
all_tunnels: self.tunnels.clone(), all_tunnels: self.tunnels.clone(),
tunnels: HashSet::new(), tunnels: IndexMap::new(),
access: None, access: None,
} }
} }

View File

@ -5,11 +5,8 @@ use hyper::{
service::Service, service::Service,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use std::{ use indexmap::IndexMap;
collections::{HashMap, HashSet}, use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
pin::Pin,
sync::Arc,
};
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use russh::{ use russh::{
@ -37,7 +34,7 @@ pub struct Tunnel {
handle: Handle, handle: Handle,
address: String, address: String,
port: u32, port: u32,
access: TunnelAccess, access: Arc<RwLock<TunnelAccess>>,
} }
impl Tunnel { impl Tunnel {
@ -51,7 +48,7 @@ impl Tunnel {
handle, handle,
address: address.into(), address: address.into(),
port, port,
access, access: Arc::new(RwLock::new(access)),
} }
} }
@ -61,6 +58,10 @@ impl Tunnel {
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port) .channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
.await .await
} }
pub async fn set_access(&self, access: TunnelAccess) {
*self.access.write().await = access;
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -104,19 +105,13 @@ impl Tunnels {
Some(address) Some(address)
} }
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) { pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Tunnel>) {
let mut all_tunnels = self.tunnels.write().await; let mut all_tunnels = self.tunnels.write().await;
for tunnel in tunnels { for (address, _tunnel) in tunnels {
trace!(tunnel, "Removing tunnel"); trace!(address, "Removing tunnel");
all_tunnels.remove(&tunnel); all_tunnels.remove(address);
} }
} }
pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) {
if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) {
tunnel.access = access;
};
}
} }
impl Service<Request<Incoming>> for Tunnels { impl Service<Request<Incoming>> for Tunnels {
@ -166,7 +161,7 @@ impl Service<Request<Incoming>> for Tunnels {
return Ok(resp); return Ok(resp);
}; };
if let TunnelAccess::Private(owner) = &tunnel.access { if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
let user = match s.forward_auth.check_auth(req.headers()).await { let user = match s.forward_auth.check_auth(req.headers()).await {
Authenticated(user) => user, Authenticated(user) => user,
Unauthenticated(response) => return Ok(response), Unauthenticated(response) => return Ok(response),