diff --git a/src/main.rs b/src/main.rs index 03c521f..735f4a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,11 +23,15 @@ async fn main() { russh::keys::PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519).unwrap() }; - let mut ssh = Server::new(); + let port = 3000; + let domain = std::env::var("TUNNEL_DOMAIN").unwrap_or_else(|_| format!("localhost:{port}")); + + let mut ssh = Server::new(domain); + let tunnels = ssh.tunnels(); tokio::spawn(async move { ssh.run(key, ("0.0.0.0", 2222)).await }); - let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); + let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); diff --git a/src/ssh.rs b/src/ssh.rs index b77ab92..775260d 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -105,10 +105,13 @@ impl russh::server::Handler for Handler { let tunnel = Tunnel::new(session.handle(), address, *port); let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else { - self.send(&format!("{port} => FAILED ({address} already in use)\r\n")); + self.send(&format!("FAILED: ({address} already in use)\r\n")); return Ok(false); }; - self.send(&format!("{port} => https://{address}\r\n")); + + // NOTE: The port we receive might not be the port that is getting forwarded from the + // client, we could include it in the message we send + self.send(&format!("http://{address}\r\n")); self.tunnels.insert(address); Ok(true) @@ -133,9 +136,9 @@ pub struct Server { } impl Server { - pub fn new() -> Self { + pub fn new(domain: impl Into) -> Self { Server { - tunnels: Tunnels::new(), + tunnels: Tunnels::new(domain), } } @@ -164,12 +167,6 @@ impl Server { } } -impl Default for Server { - fn default() -> Self { - Self::new() - } -} - impl russh::server::Server for Server { type Handler = Handler; diff --git a/src/tunnel.rs b/src/tunnel.rs index 6865b66..5a9be4f 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -44,15 +44,21 @@ impl Tunnel { } #[derive(Debug, Clone)] -pub struct Tunnels(Arc>>); +pub struct Tunnels { + tunnels: Arc>>, + domain: String, +} impl Tunnels { - pub fn new() -> Self { - Self(Arc::new(RwLock::new(HashMap::new()))) + pub fn new(domain: impl Into) -> Self { + Self { + tunnels: Arc::new(RwLock::new(HashMap::new())), + domain: domain.into(), + } } pub async fn add_tunnel(&mut self, address: &str, tunnel: Tunnel) -> Option { - let mut all_tunnels = self.0.write().await; + let mut all_tunnels = self.tunnels.write().await; let address = if address == "localhost" { loop { @@ -68,7 +74,7 @@ impl Tunnels { address }; - let address = format!("{address}.tunnel.huizinga.dev"); + let address = format!("{address}.{}", self.domain); all_tunnels.insert(address.clone(), tunnel); @@ -76,20 +82,14 @@ impl Tunnels { } pub async fn remove_tunnels(&mut self, tunnels: HashSet) { - let mut all_tunnels = self.0.write().await; + let mut all_tunnels = self.tunnels.write().await; for tunnel in tunnels { all_tunnels.remove(&tunnel); } } pub async fn get_tunnel(&self, address: &str) -> Option { - self.0.read().await.get(address).cloned() - } -} - -impl Default for Tunnels { - fn default() -> Self { - Self::new() + self.tunnels.read().await.get(address).cloned() } }