diff --git a/Cargo.lock b/Cargo.lock index 50143f9..8505c37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1192,9 +1192,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown", @@ -1800,7 +1800,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2788,6 +2788,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "indexmap", "rand 0.8.5", "reqwest", "russh", diff --git a/Cargo.toml b/Cargo.toml index 03fe798..8c339f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ dotenvy = "0.15.7" http-body-util = { version = "0.1.3", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] } +indexmap = "2.9.0" rand = "0.8.5" reqwest = { version = "0.12.15", features = ["rustls-tls"] } russh = "0.51.1" diff --git a/src/ssh.rs b/src/ssh.rs index f595901..eabe668 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -1,6 +1,7 @@ -use std::{collections::HashSet, iter::once, net::SocketAddr, sync::Arc, time::Duration}; +use std::{iter::once, net::SocketAddr, sync::Arc, time::Duration}; use clap::Parser; +use indexmap::IndexMap; use russh::{ ChannelId, keys::PrivateKey, @@ -19,7 +20,7 @@ pub struct Handler { rx: Option>>, all_tunnels: Tunnels, - tunnels: HashSet, + tunnels: IndexMap, access: Option, } @@ -36,8 +37,8 @@ impl Handler { async fn set_access(&mut self, access: TunnelAccess) { self.access = Some(access.clone()); - for tunnel in &self.tunnels { - self.all_tunnels.set_access(tunnel, access.clone()).await; + for (_address, tunnel) in &self.tunnels { + tunnel.set_access(access.clone()).await; } } } @@ -156,7 +157,7 @@ impl russh::server::Handler for Handler { }; let tunnel = Tunnel::new(session.handle(), address, *port, access); - let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else { + let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else { self.sendln(format!("FAILED: ({address} already in use)")); return Ok(false); }; @@ -164,7 +165,7 @@ impl russh::server::Handler for Handler { // NOTE: The port we receive might not be the port that is getting forwarded from the // client, we could include it in the message we send self.sendln(format!("http://{address}")); - self.tunnels.insert(address); + self.tunnels.insert(address, tunnel); Ok(true) } @@ -176,7 +177,7 @@ impl Drop for Handler { let mut all_tunnels = self.all_tunnels.clone(); tokio::spawn(async move { - all_tunnels.remove_tunnels(tunnels.clone()).await; + all_tunnels.remove_tunnels(&tunnels).await; }); } } @@ -227,7 +228,7 @@ impl russh::server::Server for Server { tx, rx: Some(rx), all_tunnels: self.tunnels.clone(), - tunnels: HashSet::new(), + tunnels: IndexMap::new(), access: None, } } diff --git a/src/tunnel.rs b/src/tunnel.rs index ea6a7d5..ba772bd 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -5,11 +5,8 @@ use hyper::{ service::Service, }; use hyper_util::rt::TokioIo; -use std::{ - collections::{HashMap, HashSet}, - pin::Pin, - sync::Arc, -}; +use indexmap::IndexMap; +use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc}; use tracing::{debug, trace, warn}; use russh::{ @@ -37,7 +34,7 @@ pub struct Tunnel { handle: Handle, address: String, port: u32, - access: TunnelAccess, + access: Arc>, } impl Tunnel { @@ -51,7 +48,7 @@ impl Tunnel { handle, address: address.into(), port, - access, + access: Arc::new(RwLock::new(access)), } } @@ -61,6 +58,10 @@ impl Tunnel { .channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port) .await } + + pub async fn set_access(&self, access: TunnelAccess) { + *self.access.write().await = access; + } } #[derive(Debug, Clone)] @@ -104,19 +105,13 @@ impl Tunnels { Some(address) } - pub async fn remove_tunnels(&mut self, tunnels: HashSet) { + pub async fn remove_tunnels(&mut self, tunnels: &IndexMap) { let mut all_tunnels = self.tunnels.write().await; - for tunnel in tunnels { - trace!(tunnel, "Removing tunnel"); - all_tunnels.remove(&tunnel); + for (address, _tunnel) in tunnels { + trace!(address, "Removing tunnel"); + all_tunnels.remove(address); } } - - pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) { - if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) { - tunnel.access = access; - }; - } } impl Service> for Tunnels { @@ -166,7 +161,7 @@ impl Service> for Tunnels { return Ok(resp); }; - if let TunnelAccess::Private(owner) = &tunnel.access { + if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() { let user = match s.forward_auth.check_auth(req.headers()).await { Authenticated(user) => user, Unauthenticated(response) => return Ok(response),