Store tunnel access in Arc<RwLock> and keep a copy in the handler
This should make it easier to access a tunnel in the handler and to modify its access
This commit is contained in:
17
src/ssh.rs
17
src/ssh.rs
@@ -1,6 +1,7 @@
|
||||
use std::{collections::HashSet, iter::once, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use std::{iter::once, net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use clap::Parser;
|
||||
use indexmap::IndexMap;
|
||||
use russh::{
|
||||
ChannelId,
|
||||
keys::PrivateKey,
|
||||
@@ -19,7 +20,7 @@ pub struct Handler {
|
||||
rx: Option<UnboundedReceiver<Vec<u8>>>,
|
||||
|
||||
all_tunnels: Tunnels,
|
||||
tunnels: HashSet<String>,
|
||||
tunnels: IndexMap<String, Tunnel>,
|
||||
|
||||
access: Option<TunnelAccess>,
|
||||
}
|
||||
@@ -36,8 +37,8 @@ impl Handler {
|
||||
async fn set_access(&mut self, access: TunnelAccess) {
|
||||
self.access = Some(access.clone());
|
||||
|
||||
for tunnel in &self.tunnels {
|
||||
self.all_tunnels.set_access(tunnel, access.clone()).await;
|
||||
for (_address, tunnel) in &self.tunnels {
|
||||
tunnel.set_access(access.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -156,7 +157,7 @@ impl russh::server::Handler for Handler {
|
||||
};
|
||||
|
||||
let tunnel = Tunnel::new(session.handle(), address, *port, access);
|
||||
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
|
||||
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else {
|
||||
self.sendln(format!("FAILED: ({address} already in use)"));
|
||||
return Ok(false);
|
||||
};
|
||||
@@ -164,7 +165,7 @@ impl russh::server::Handler for Handler {
|
||||
// NOTE: The port we receive might not be the port that is getting forwarded from the
|
||||
// client, we could include it in the message we send
|
||||
self.sendln(format!("http://{address}"));
|
||||
self.tunnels.insert(address);
|
||||
self.tunnels.insert(address, tunnel);
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
@@ -176,7 +177,7 @@ impl Drop for Handler {
|
||||
let mut all_tunnels = self.all_tunnels.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
all_tunnels.remove_tunnels(tunnels.clone()).await;
|
||||
all_tunnels.remove_tunnels(&tunnels).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -227,7 +228,7 @@ impl russh::server::Server for Server {
|
||||
tx,
|
||||
rx: Some(rx),
|
||||
all_tunnels: self.tunnels.clone(),
|
||||
tunnels: HashSet::new(),
|
||||
tunnels: IndexMap::new(),
|
||||
access: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,8 @@ use hyper::{
|
||||
service::Service,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use russh::{
|
||||
@@ -37,7 +34,7 @@ pub struct Tunnel {
|
||||
handle: Handle,
|
||||
address: String,
|
||||
port: u32,
|
||||
access: TunnelAccess,
|
||||
access: Arc<RwLock<TunnelAccess>>,
|
||||
}
|
||||
|
||||
impl Tunnel {
|
||||
@@ -51,7 +48,7 @@ impl Tunnel {
|
||||
handle,
|
||||
address: address.into(),
|
||||
port,
|
||||
access,
|
||||
access: Arc::new(RwLock::new(access)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +58,10 @@ impl Tunnel {
|
||||
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_access(&self, access: TunnelAccess) {
|
||||
*self.access.write().await = access;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -104,19 +105,13 @@ impl Tunnels {
|
||||
Some(address)
|
||||
}
|
||||
|
||||
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) {
|
||||
pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Tunnel>) {
|
||||
let mut all_tunnels = self.tunnels.write().await;
|
||||
for tunnel in tunnels {
|
||||
trace!(tunnel, "Removing tunnel");
|
||||
all_tunnels.remove(&tunnel);
|
||||
for (address, _tunnel) in tunnels {
|
||||
trace!(address, "Removing tunnel");
|
||||
all_tunnels.remove(address);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_access(&mut self, tunnel: &str, access: TunnelAccess) {
|
||||
if let Some(tunnel) = self.tunnels.write().await.get_mut(tunnel) {
|
||||
tunnel.access = access;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl Service<Request<Incoming>> for Tunnels {
|
||||
@@ -166,7 +161,7 @@ impl Service<Request<Incoming>> for Tunnels {
|
||||
return Ok(resp);
|
||||
};
|
||||
|
||||
if let TunnelAccess::Private(owner) = &tunnel.access {
|
||||
if let TunnelAccess::Private(owner) = tunnel.access.read().await.deref() {
|
||||
let user = match s.forward_auth.check_auth(req.headers()).await {
|
||||
Authenticated(user) => user,
|
||||
Unauthenticated(response) => return Ok(response),
|
||||
|
||||
Reference in New Issue
Block a user