Fixed duplicate detection and store failed tunnels in handler

This commit is contained in:
Dreaded_X 2025-04-08 02:20:58 +02:00
parent 5213aee232
commit 8d98431542
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
2 changed files with 34 additions and 25 deletions

View File

@ -20,7 +20,7 @@ pub struct Handler {
rx: Option<UnboundedReceiver<Vec<u8>>>, rx: Option<UnboundedReceiver<Vec<u8>>>,
all_tunnels: Tunnels, all_tunnels: Tunnels,
tunnels: IndexMap<String, Tunnel>, tunnels: IndexMap<String, Option<Tunnel>>,
access: Option<TunnelAccess>, access: Option<TunnelAccess>,
} }
@ -38,7 +38,9 @@ impl Handler {
self.access = Some(access.clone()); self.access = Some(access.clone());
for (_address, tunnel) in &self.tunnels { for (_address, tunnel) in &self.tunnels {
tunnel.set_access(access.clone()).await; if let Some(tunnel) = tunnel {
tunnel.set_access(access.clone()).await;
}
} }
} }
} }
@ -79,6 +81,16 @@ impl russh::server::Handler for Handler {
} }
}); });
// NOTE: I believe this happens as the final step when opening a session.
// At this point all the tunnels should be populated
for (address, tunnel) in &self.tunnels {
if tunnel.is_some() {
self.sendln(format!("http://{address}"));
} else {
self.sendln(format!("Failed to open {address}, address already in use"));
}
}
Ok(true) Ok(true)
} }
@ -158,14 +170,11 @@ impl russh::server::Handler for Handler {
let tunnel = Tunnel::new(session.handle(), address, *port, access); let tunnel = Tunnel::new(session.handle(), address, *port, access);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else { let Some(address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await else {
self.sendln(format!("FAILED: ({address} already in use)")); self.tunnels.insert(address.into(), None);
return Ok(false); return Ok(false);
}; };
// NOTE: The port we receive might not be the port that is getting forwarded from the self.tunnels.insert(address, Some(tunnel));
// client, we could include it in the message we send
self.sendln(format!("http://{address}"));
self.tunnels.insert(address, tunnel);
Ok(true) Ok(true)
} }

View File

@ -32,30 +32,25 @@ pub enum TunnelAccess {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Tunnel { pub struct Tunnel {
handle: Handle, handle: Handle,
address: String, name: String,
port: u32, port: u32,
access: Arc<RwLock<TunnelAccess>>, access: Arc<RwLock<TunnelAccess>>,
} }
impl Tunnel { impl Tunnel {
pub fn new( pub fn new(handle: Handle, name: impl Into<String>, port: u32, access: TunnelAccess) -> Self {
handle: Handle,
address: impl Into<String>,
port: u32,
access: TunnelAccess,
) -> Self {
Self { Self {
handle, handle,
address: address.into(), name: name.into(),
port, port,
access: Arc::new(RwLock::new(access)), access: Arc::new(RwLock::new(access)),
} }
} }
pub async fn open_tunnel(&self) -> Result<Channel<Msg>, russh::Error> { pub async fn open_tunnel(&self) -> Result<Channel<Msg>, russh::Error> {
trace!(tunnel = self.address, "Opening tunnel"); trace!(tunnel = self.name, "Opening tunnel");
self.handle self.handle
.channel_open_forwarded_tcpip(&self.address, self.port, &self.address, self.port) .channel_open_forwarded_tcpip(&self.name, self.port, &self.name, self.port)
.await .await
} }
@ -84,32 +79,37 @@ impl Tunnels {
let mut all_tunnels = self.tunnels.write().await; let mut all_tunnels = self.tunnels.write().await;
let address = if address == "localhost" { let address = if address == "localhost" {
// NOTE: It is technically possible to become stuck in this loop.
// However, that really only becomes a concern if a (very) high
// number of tunnels is open at the same time.
loop { loop {
let address = get_animal_name(); let address = get_animal_name();
if !all_tunnels.contains_key(address) { let address = format!("{address}.{}", self.domain);
if !all_tunnels.contains_key(&address) {
break address; break address;
} }
} }
} else { } else {
if all_tunnels.contains_key(address) { let address = format!("{address}.{}", self.domain);
if all_tunnels.contains_key(&address) {
return None; return None;
} }
address address
}; };
let address = format!("{address}.{}", self.domain);
trace!(tunnel = address, "Adding tunnel"); trace!(tunnel = address, "Adding tunnel");
all_tunnels.insert(address.clone(), tunnel); all_tunnels.insert(address.clone(), tunnel);
Some(address) Some(address)
} }
pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Tunnel>) { pub async fn remove_tunnels(&mut self, tunnels: &IndexMap<String, Option<Tunnel>>) {
let mut all_tunnels = self.tunnels.write().await; let mut all_tunnels = self.tunnels.write().await;
for (address, _tunnel) in tunnels { for (address, tunnel) in tunnels {
trace!(address, "Removing tunnel"); if tunnel.is_some() {
all_tunnels.remove(address); trace!(address, "Removing tunnel");
all_tunnels.remove(address);
}
} }
} }
} }