diff --git a/Cargo.lock b/Cargo.lock index ae81346..91edff1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3221,7 +3221,6 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", - "indexmap", "rand 0.8.5", "ratatui", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index d5a1f1b..d98c99d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ futures = "0.3.31" http-body-util = { version = "0.1.3", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] } hyper-util = { version = "0.1.11", features = ["full"] } -indexmap = "2.9.0" rand = "0.8.5" ratatui = { version = "0.29.0", features = ["unstable-backend-writer"] } reqwest = { version = "0.12.15", features = ["rustls-tls"] } diff --git a/src/handler.rs b/src/handler.rs index 03c4c46..97f534a 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,7 +1,6 @@ use std::{io::Write, iter::once}; use clap::Parser as _; -use indexmap::IndexMap; use ratatui::{Terminal, TerminalOptions, Viewport, layout::Rect, prelude::CrosstermBackend}; use russh::{ ChannelId, @@ -19,7 +18,7 @@ use crate::{ pub struct Handler { all_tunnels: Tunnels, - tunnels: IndexMap>, + tunnels: Vec, user: Option, pty_channel: Option, @@ -33,7 +32,7 @@ impl Handler { pub fn new(all_tunnels: Tunnels) -> Self { Self { all_tunnels, - tunnels: IndexMap::new(), + tunnels: Default::default(), user: None, pty_channel: None, terminal: None, @@ -43,10 +42,8 @@ impl Handler { } async fn set_access_all(&mut self, access: TunnelAccess) { - for (_address, tunnel) in &self.tunnels { - if let Some(tunnel) = tunnel { - tunnel.set_access(access.clone()).await; - } + for tunnel in &self.tunnels { + tunnel.set_access(access.clone()).await; } } @@ -92,7 +89,7 @@ impl Handler { async fn set_access_selection(&mut self, access: TunnelAccess) { if let Some(selected) = self.selected { - if let Some((_, Some(tunnel))) = self.tunnels.get_index_mut(selected) { + if let Some(tunnel) = self.tunnels.get_mut(selected) { tunnel.set_access(access).await; } } else { @@ -267,18 +264,16 @@ impl russh::server::Handler for Handler { return Err(russh::Error::Inconsistent); }; - let tunnel = Tunnel::new( - session.handle(), - address, - *port, - TunnelAccess::Private(user), - ); - let (success, address) = self.all_tunnels.add_tunnel(address, tunnel.clone()).await; + let tunnel = self + .all_tunnels + .add_tunnel(session.handle(), address, *port, user) + .await; - let tunnel = if success { Some(tunnel) } else { None }; - self.tunnels.insert(address, tunnel); + self.tunnels.push(tunnel); - Ok(success) + // Technically forwarding has failed if tunnel.domain = None, however by lying to the ssh + // client we can retry in the future + Ok(true) } async fn window_change_request( diff --git a/src/tui.rs b/src/tui.rs index 7b5299d..41ed26d 100644 --- a/src/tui.rs +++ b/src/tui.rs @@ -1,7 +1,6 @@ use std::cmp; use futures::StreamExt; -use indexmap::IndexMap; use ratatui::{ Frame, layout::{Constraint, Flex, Layout, Rect}, @@ -25,12 +24,8 @@ fn command<'c>(key: &'c str, text: &'c str) -> Vec> { impl Renderer { // NOTE: This needs to be a separate function as the render functions can not be async - pub async fn update( - &mut self, - tunnels: &IndexMap>, - index: Option, - ) { - self.table_rows = futures::stream::iter(tunnels.iter()) + pub async fn update(&mut self, tunnels: &[Tunnel], index: Option) { + self.table_rows = futures::stream::iter(tunnels) .then(tunnel::tui::to_row) .collect::>() .await; diff --git a/src/tunnel.rs b/src/tunnel.rs index e5f3663..d37caa0 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -8,8 +8,12 @@ use hyper::{ service::Service, }; use hyper_util::rt::TokioIo; -use indexmap::IndexMap; -use std::{collections::HashMap, ops::Deref, pin::Pin, sync::Arc}; +use std::{ + collections::{HashMap, hash_map::Entry}, + ops::Deref, + pin::Pin, + sync::Arc, +}; use tracing::{debug, error, trace, warn}; use russh::{ @@ -37,20 +41,12 @@ pub enum TunnelAccess { pub struct Tunnel { handle: Handle, name: String, + domain: Option, port: u32, access: Arc>, } impl Tunnel { - pub fn new(handle: Handle, name: impl Into, port: u32, access: TunnelAccess) -> Self { - Self { - handle, - name: name.into(), - port, - access: Arc::new(RwLock::new(access)), - } - } - pub async fn open_tunnel(&self) -> Result, russh::Error> { trace!(tunnel = self.name, "Opening tunnel"); self.handle @@ -65,6 +61,12 @@ impl Tunnel { pub async fn is_public(&self) -> bool { matches!(*self.access.read().await, TunnelAccess::Public) } + + pub fn get_address(&self) -> Option { + self.domain + .clone() + .map(|domain| format!("{}.{domain}", self.name)) + } } #[derive(Debug, Clone)] @@ -83,40 +85,57 @@ impl Tunnels { } } - pub async fn add_tunnel(&mut self, address: &str, tunnel: Tunnel) -> (bool, String) { - let mut all_tunnels = self.tunnels.write().await; + pub async fn add_tunnel( + &mut self, + handle: Handle, + name: impl Into, + port: u32, + user: impl Into, + ) -> Tunnel { + let mut tunnel = Tunnel { + handle, + name: name.into(), + domain: Some(self.domain.clone()), + port, + access: Arc::new(RwLock::new(TunnelAccess::Private(user.into()))), + }; - let address = if address == "localhost" { + if tunnel.name == "localhost" { // NOTE: It is technically possible to become stuck in this loop. // However, that really only becomes a concern if a (very) high // number of tunnels is open at the same time. loop { - let address = get_animal_name(); - let address = format!("{address}.{}", self.domain); - if !all_tunnels.contains_key(&address) { - break address; + tunnel.name = get_animal_name().into(); + if !self + .tunnels + .read() + .await + .contains_key(&tunnel.get_address().expect("domain is set")) + { + break; } + trace!(tunnel = tunnel.name, "Already in use, picking new name"); } - } else { - let address = format!("{address}.{}", self.domain); - if all_tunnels.contains_key(&address) { - return (false, address); - } - address }; + let address = tunnel.get_address().expect("domain is set"); - trace!(tunnel = address, "Adding tunnel"); - all_tunnels.insert(address.clone(), tunnel); + if let Entry::Vacant(e) = self.tunnels.write().await.entry(address) { + trace!(tunnel = tunnel.name, "Adding tunnel"); + e.insert(tunnel.clone()); + } else { + trace!("Address already in use"); + tunnel.domain = None + } - (true, address) + tunnel } - pub async fn remove_tunnels(&mut self, tunnels: &IndexMap>) { + pub async fn remove_tunnels(&mut self, tunnels: &[Tunnel]) { let mut all_tunnels = self.tunnels.write().await; - for (address, tunnel) in tunnels { - if tunnel.is_some() { - trace!(address, "Removing tunnel"); - all_tunnels.remove(address); + for tunnel in tunnels { + if let Some(address) = tunnel.get_address() { + trace!(tunnel.name, "Removing tunnel"); + all_tunnels.remove(&address); } } } diff --git a/src/tunnel/tui.rs b/src/tunnel/tui.rs index 602eb17..295f051 100644 --- a/src/tunnel/tui.rs +++ b/src/tunnel/tui.rs @@ -6,22 +6,30 @@ use ratatui::text::Span; use super::{Tunnel, TunnelAccess}; pub fn header() -> Vec> { - vec!["Access".into(), "Port".into(), "Address".into()] + vec![ + "Name".into(), + "Access".into(), + "Port".into(), + "Address".into(), + ] } -pub async fn to_row((address, tunnel): (&String, &Option)) -> Vec> { - let (access, port) = if let Some(tunnel) = tunnel { - let access = match tunnel.access.read().await.deref() { - TunnelAccess::Private(owner) => owner.clone().yellow(), - TunnelAccess::Protected => "PROTECTED".blue(), - TunnelAccess::Public => "PUBLIC".green(), - }; - - (access, tunnel.port.to_string().into()) - } else { - ("FAILED".red(), "".into()) +pub async fn to_row(tunnel: &Tunnel) -> Vec> { + let access = match tunnel.access.read().await.deref() { + TunnelAccess::Private(owner) => owner.clone().yellow(), + TunnelAccess::Protected => "PROTECTED".blue(), + TunnelAccess::Public => "PUBLIC".green(), }; - let address = format!("http://{address}").into(); - vec![access, port, address] + let address = tunnel + .get_address() + .map(|address| format!("http://{address}").into()) + .unwrap_or("FAILED".red()); + + vec![ + tunnel.name.clone().into(), + access, + tunnel.port.to_string().into(), + address, + ] }