Store tunnel access in Arc<RwLock> and keep a copy in the handler

This should make it easier to access a tunnel in the handler and to
modify its access
This commit is contained in:
Dreaded_X 2025-04-08 02:01:18 +02:00
parent 12398b8e5a
commit 5213aee232
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
4 changed files with 27 additions and 29 deletions

7
Cargo.lock generated
View File

@ -1192,9 +1192,9 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.8.0"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [
"equivalent",
"hashbrown",
@ -1800,7 +1800,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -2788,6 +2788,7 @@ dependencies = [
"http-body-util",
"hyper",
"hyper-util",
"indexmap",
"rand 0.8.5",
"reqwest",
"russh",

View File

@ -11,6 +11,7 @@ dotenvy = "0.15.7"
http-body-util = { version = "0.1.3", features = ["full"] }
hyper = { version = "1.6.0", features = ["full"] }
hyper-util = { version = "0.1.11", features = ["full"] }
indexmap = "2.9.0"
rand = "0.8.5"
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
russh = "0.51.1"

View File

@ -1,6 +1,7 @@
use std::{collections::HashSet, iter::once, net::SocketAddr, sync::Arc, time::Duration};
use std::{iter::once, net::SocketAddr, sync::Arc, time::Duration};
use clap::Parser;
use indexmap::IndexMap;
use russh::{
ChannelId,
keys::PrivateKey,
@ -19,7 +20,7 @@ pub struct Handler {
rx: Option<UnboundedReceiver<Vec<u8>>>,
all_tunnels: Tunnels,
tunnels: HashSet<String>,
tunnels: IndexMap<String, Tunnel>,
access: Option<TunnelAccess>,
}
@ -36,8 +37,8 @@ impl Handler {
async fn set_access(&mut self, access: TunnelAccess) {
self.access = Some(access.clone());
for tunnel in &self.tunnels {
self.all_tunnels.set_access(tunnel, access.clone()).await;
for (_address, tunnel) in &self.tunnels {
tunnel.set_access(access.clone()).await;
}
}
}
@ -156,7 +157,7 @@ impl russh::server::Handler for Handler {
};
let tunnel = Tunnel::new(session.handle(), address, *port, access);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else {
self.sendln(format!("FAILED: ({address} already in use)"));
return Ok(false);
};
@ -164,7 +165,7 @@ impl russh::server::Handler for Handler {
// NOTE: The port we receive might not be the port that is getting forwarded from the
// client, we could include it in the message we send
self.sendln(format!("http://{address}"));
self.tunnels.insert(address);
self.tunnels.insert(address, tunnel);
Ok(true)
}
@ -176,7 +177,7 @@ impl Drop for Handler {
let mut all_tunnels = self.all_tunnels.clone();
tokio::spawn(async move {
all_tunnels.remove_tunnels(tunnels.clone()).await;
all_tunnels.remove_tunnels(&tunnels).await;
});
}
}
@ -227,7 +228,7 @@ impl russh::server::Server for Server {
tx,
rx: Some(rx),
all_tunnels: self.tunnels.clone(),
tunnels: HashSet::new(),
tunnels: IndexMap::new(),
access: None,
}
}

View File

@ -5,11 +5,8 @@ use hyper::{
service::Service,
};
use hyper_util::rt::TokioIo;
use std::{
collections::{HashMap, HashSet},
pin::Pin,
sync::Arc,
};
use indexmap::IndexMap;
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
use tracing::{debug, trace, warn};
use russh::{
@ -37,7 +34,7 @@ pub struct Tunnel {
handle: Handle,
address: String,
port: u32,
access: TunnelAccess,
access: Arc<RwLock<TunnelAccess>>,
}
impl Tunnel {
@ -51,7 +48,7 @@ impl Tunnel {
handle,
address: address.into(),
port,
access,
access: Arc::new(RwLock::new(access)),
}
}
@ -61,6 +58,10 @@ impl Tunnel {
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
.await
}
pub async fn set_access(&self, access: TunnelAccess) {
*self.access.write().await = access;
}
}
#[derive(Debug, Clone)]
@ -104,19 +105,13 @@ impl Tunnels {
Some(address)
}
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) {
pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Tunnel>) {
let mut all_tunnels = self.tunnels.write().await;
for tunnel in tunnels {
trace!(tunnel, "Removing tunnel");
all_tunnels.remove(&tunnel);
for (address, _tunnel) in tunnels {
trace!(address, "Removing tunnel");
all_tunnels.remove(address);
}
}
pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) {
if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) {
tunnel.access = access;
};
}
}
impl Service<Request<Incoming>> for Tunnels {
@ -166,7 +161,7 @@ impl Service<Request<Incoming>> for Tunnels {
return Ok(resp);
};
if let TunnelAccess::Private(owner) = &tunnel.access {
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
let user = match s.forward_auth.check_auth(req.headers()).await {
Authenticated(user) => user,
Unauthenticated(response) => return Ok(response),