Store tunnel access in Arc<RwLock> and keep a copy in the handler
This should make it easier to access a tunnel in the handler and to modify its access
This commit is contained in:
parent
12398b8e5a
commit
5213aee232
7
Cargo.lock
generated
7
Cargo.lock
generated
|
@ -1192,9 +1192,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "2.8.0"
|
version = "2.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
|
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"equivalent",
|
"equivalent",
|
||||||
"hashbrown",
|
"hashbrown",
|
||||||
|
@ -1800,7 +1800,7 @@ dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"socket2",
|
"socket2",
|
||||||
"tracing",
|
"tracing",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2788,6 +2788,7 @@ dependencies = [
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
|
"indexmap",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"russh",
|
"russh",
|
||||||
|
|
|
@ -11,6 +11,7 @@ dotenvy = "0.15.7"
|
||||||
http-body-util = { version = "0.1.3", features = ["full"] }
|
http-body-util = { version = "0.1.3", features = ["full"] }
|
||||||
hyper = { version = "1.6.0", features = ["full"] }
|
hyper = { version = "1.6.0", features = ["full"] }
|
||||||
hyper-util = { version = "0.1.11", features = ["full"] }
|
hyper-util = { version = "0.1.11", features = ["full"] }
|
||||||
|
indexmap = "2.9.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
|
reqwest = { version = "0.12.15", features = ["rustls-tls"] }
|
||||||
russh = "0.51.1"
|
russh = "0.51.1"
|
||||||
|
|
17
src/ssh.rs
17
src/ssh.rs
|
@ -1,6 +1,7 @@
|
||||||
use std::{collections::HashSet, iter::once, net::SocketAddr, sync::Arc, time::Duration};
|
use std::{iter::once, net::SocketAddr, sync::Arc, time::Duration};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use indexmap::IndexMap;
|
||||||
use russh::{
|
use russh::{
|
||||||
ChannelId,
|
ChannelId,
|
||||||
keys::PrivateKey,
|
keys::PrivateKey,
|
||||||
|
@ -19,7 +20,7 @@ pub struct Handler {
|
||||||
rx: Option<UnboundedReceiver<Vec<u8>>>,
|
rx: Option<UnboundedReceiver<Vec<u8>>>,
|
||||||
|
|
||||||
all_tunnels: Tunnels,
|
all_tunnels: Tunnels,
|
||||||
tunnels: HashSet<String>,
|
tunnels: IndexMap<String, Tunnel>,
|
||||||
|
|
||||||
access: Option<TunnelAccess>,
|
access: Option<TunnelAccess>,
|
||||||
}
|
}
|
||||||
|
@ -36,8 +37,8 @@ impl Handler {
|
||||||
async fn set_access(&mut self, access: TunnelAccess) {
|
async fn set_access(&mut self, access: TunnelAccess) {
|
||||||
self.access = Some(access.clone());
|
self.access = Some(access.clone());
|
||||||
|
|
||||||
for tunnel in &self.tunnels {
|
for (_address, tunnel) in &self.tunnels {
|
||||||
self.all_tunnels.set_access(tunnel, access.clone()).await;
|
tunnel.set_access(access.clone()).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,7 +157,7 @@ impl russh::server::Handler for Handler {
|
||||||
};
|
};
|
||||||
|
|
||||||
let tunnel = Tunnel::new(session.handle(), address, *port, access);
|
let tunnel = Tunnel::new(session.handle(), address, *port, access);
|
||||||
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
|
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else {
|
||||||
self.sendln(format!("FAILED: ({address} already in use)"));
|
self.sendln(format!("FAILED: ({address} already in use)"));
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
};
|
};
|
||||||
|
@ -164,7 +165,7 @@ impl russh::server::Handler for Handler {
|
||||||
// NOTE: The port we receive might not be the port that is getting forwarded from the
|
// NOTE: The port we receive might not be the port that is getting forwarded from the
|
||||||
// client, we could include it in the message we send
|
// client, we could include it in the message we send
|
||||||
self.sendln(format!("http://{address}"));
|
self.sendln(format!("http://{address}"));
|
||||||
self.tunnels.insert(address);
|
self.tunnels.insert(address, tunnel);
|
||||||
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ impl Drop for Handler {
|
||||||
let mut all_tunnels = self.all_tunnels.clone();
|
let mut all_tunnels = self.all_tunnels.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
all_tunnels.remove_tunnels(tunnels.clone()).await;
|
all_tunnels.remove_tunnels(&tunnels).await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -227,7 +228,7 @@ impl russh::server::Server for Server {
|
||||||
tx,
|
tx,
|
||||||
rx: Some(rx),
|
rx: Some(rx),
|
||||||
all_tunnels: self.tunnels.clone(),
|
all_tunnels: self.tunnels.clone(),
|
||||||
tunnels: HashSet::new(),
|
tunnels: IndexMap::new(),
|
||||||
access: None,
|
access: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,11 +5,8 @@ use hyper::{
|
||||||
service::Service,
|
service::Service,
|
||||||
};
|
};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use std::{
|
use indexmap::IndexMap;
|
||||||
collections::{HashMap, HashSet},
|
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
|
||||||
pin::Pin,
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
use russh::{
|
use russh::{
|
||||||
|
@ -37,7 +34,7 @@ pub struct Tunnel {
|
||||||
handle: Handle,
|
handle: Handle,
|
||||||
address: String,
|
address: String,
|
||||||
port: u32,
|
port: u32,
|
||||||
access: TunnelAccess,
|
access: Arc<RwLock<TunnelAccess>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tunnel {
|
impl Tunnel {
|
||||||
|
@ -51,7 +48,7 @@ impl Tunnel {
|
||||||
handle,
|
handle,
|
||||||
address: address.into(),
|
address: address.into(),
|
||||||
port,
|
port,
|
||||||
access,
|
access: Arc::new(RwLock::new(access)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,6 +58,10 @@ impl Tunnel {
|
||||||
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
|
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn set_access(&self, access: TunnelAccess) {
|
||||||
|
*self.access.write().await = access;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -104,19 +105,13 @@ impl Tunnels {
|
||||||
Some(address)
|
Some(address)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) {
|
pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Tunnel>) {
|
||||||
let mut all_tunnels = self.tunnels.write().await;
|
let mut all_tunnels = self.tunnels.write().await;
|
||||||
for tunnel in tunnels {
|
for (address, _tunnel) in tunnels {
|
||||||
trace!(tunnel, "Removing tunnel");
|
trace!(address, "Removing tunnel");
|
||||||
all_tunnels.remove(&tunnel);
|
all_tunnels.remove(address);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) {
|
|
||||||
if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) {
|
|
||||||
tunnel.access = access;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service<Request<Incoming>> for Tunnels {
|
impl Service<Request<Incoming>> for Tunnels {
|
||||||
|
@ -166,7 +161,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
};
|
};
|
||||||
|
|
||||||
if let TunnelAccess::Private(owner) = &tunnel.access {
|
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
|
||||||
let user = match s.forward_auth.check_auth(req.headers()).await {
|
let user = match s.forward_auth.check_auth(req.headers()).await {
|
||||||
Authenticated(user) => user,
|
Authenticated(user) => user,
|
||||||
Unauthenticated(response) => return Ok(response),
|
Unauthenticated(response) => return Ok(response),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user