diff --git a/src/lib.rs b/src/lib.rs index 43f4421..c9dfe47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(impl_trait_in_fn_trait_return)] pub mod animals; pub mod ssh; pub mod tunnel; diff --git a/src/main.rs b/src/main.rs index cc036ce..14aca68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -70,7 +70,7 @@ async fn main() { debug!("Request for {authority:?}"); - let Some(tunnel) = tunnels.read().await.get(&authority).cloned() else { + let Some(tunnel) = tunnels.get_tunnel(&authority).await else { let mut resp = Response::new(full(format!("Unknown tunnel: {authority}"))); *resp.status_mut() = StatusCode::NOT_FOUND; diff --git a/src/ssh.rs b/src/ssh.rs index f285edd..b77ab92 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -11,10 +11,7 @@ use tokio::{ }; use tracing::{debug, error}; -use crate::{ - animals::get_animal_name, - tunnel::{self, Tunnel, Tunnels}, -}; +use crate::tunnel::{Tunnel, Tunnels}; pub struct Handler { tx: UnboundedSender>, @@ -28,26 +25,6 @@ impl Handler { fn send(&self, data: &str) { let _ = self.tx.send(data.as_bytes().to_vec()); } - - async fn full_address(&self, address: &str) -> Option { - let all_tunnels = self.all_tunnels.read().await; - - let address = if address == "localhost" { - loop { - let address = get_animal_name(); - if !all_tunnels.contains_key(address) { - break address; - } - } - } else { - if all_tunnels.contains_key(address) { - return None; - } - address - }; - - Some(format!("{address}.tunnel.huizinga.dev")) - } } impl russh::server::Handler for Handler { @@ -126,18 +103,13 @@ impl russh::server::Handler for Handler { ) -> Result { debug!("{address}:{port}"); - let Some(full_address) = self.full_address(address).await else { + let tunnel = Tunnel::new(session.handle(), address, *port); + let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else { self.send(&format!("{port} => FAILED ({address} already in use)\r\n")); return Ok(false); }; - - self.tunnels.insert(full_address.clone()); - self.all_tunnels.write().await.insert( - full_address.clone(), - Tunnel::new(session.handle(), address, *port), - ); - - self.send(&format!("{port} => https://{full_address}\r\n")); + self.send(&format!("{port} => https://{address}\r\n")); + self.tunnels.insert(address); Ok(true) } @@ -146,13 +118,10 @@ impl russh::server::Handler for Handler { impl Drop for Handler { fn drop(&mut self) { let tunnels = self.tunnels.clone(); - let all_tunnels = self.all_tunnels.clone(); + let mut all_tunnels = self.all_tunnels.clone(); tokio::spawn(async move { - let mut all_tunnels = all_tunnels.write().await; - for tunnel in tunnels { - all_tunnels.remove(&tunnel); - } + all_tunnels.remove_tunnels(tunnels.clone()).await; debug!("{all_tunnels:?}"); }); @@ -166,7 +135,7 @@ pub struct Server { impl Server { pub fn new() -> Self { Server { - tunnels: tunnel::new(), + tunnels: Tunnels::new(), } } diff --git a/src/tunnel.rs b/src/tunnel.rs index 68400dc..07a6440 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use russh::{ Channel, @@ -6,6 +9,8 @@ use russh::{ }; use tokio::sync::RwLock; +use crate::animals::get_animal_name; + #[derive(Debug, Clone)] pub struct Tunnel { handle: Handle, @@ -29,8 +34,52 @@ impl Tunnel { } } -pub type Tunnels = Arc>>; +#[derive(Debug, Clone)] +pub struct Tunnels(Arc>>); -pub fn new() -> Tunnels { - Arc::new(RwLock::new(HashMap::new())) +impl Tunnels { + pub fn new() -> Self { + Self(Arc::new(RwLock::new(HashMap::new()))) + } + + pub async fn add_tunnel(&mut self, address: &str, tunnel: Tunnel) -> Option { + let mut all_tunnels = self.0.write().await; + + let address = if address == "localhost" { + loop { + let address = get_animal_name(); + if !all_tunnels.contains_key(address) { + break address; + } + } + } else { + if all_tunnels.contains_key(address) { + return None; + } + address + }; + + let address = format!("{address}.tunnel.huizinga.dev"); + + all_tunnels.insert(address.clone(), tunnel); + + Some(address) + } + + pub async fn remove_tunnels(&mut self, tunnels: HashSet) { + let mut all_tunnels = self.0.write().await; + for tunnel in tunnels { + all_tunnels.remove(&tunnel); + } + } + + pub async fn get_tunnel(&self, address: &str) -> Option { + self.0.read().await.get(address).cloned() + } +} + +impl Default for Tunnels { + fn default() -> Self { + Self::new() + } }