Refactor tunnel

This commit is contained in:
Dreaded_X 2025-04-04 16:56:47 +02:00
parent 2bd26f1db2
commit bb50a59316
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
4 changed files with 63 additions and 44 deletions

View File

@ -1,3 +1,4 @@
#![feature(impl_trait_in_fn_trait_return)]
pub mod animals; pub mod animals;
pub mod ssh; pub mod ssh;
pub mod tunnel; pub mod tunnel;

View File

@ -70,7 +70,7 @@ async fn main() {
debug!("Request for {authority:?}"); debug!("Request for {authority:?}");
let Some(tunnel) = tunnels.read().await.get(&authority).cloned() else { let Some(tunnel) = tunnels.get_tunnel(&authority).await else {
let mut resp = Response::new(full(format!("Unknown tunnel: {authority}"))); let mut resp = Response::new(full(format!("Unknown tunnel: {authority}")));
*resp.status_mut() = StatusCode::NOT_FOUND; *resp.status_mut() = StatusCode::NOT_FOUND;

View File

@ -11,10 +11,7 @@ use tokio::{
}; };
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{ use crate::tunnel::{Tunnel, Tunnels};
animals::get_animal_name,
tunnel::{self, Tunnel, Tunnels},
};
pub struct Handler { pub struct Handler {
tx: UnboundedSender<Vec<u8>>, tx: UnboundedSender<Vec<u8>>,
@ -28,26 +25,6 @@ impl Handler {
fn send(&self, data: &str) { fn send(&self, data: &str) {
let _ = self.tx.send(data.as_bytes().to_vec()); let _ = self.tx.send(data.as_bytes().to_vec());
} }
async fn full_address(&self, address: &str) -> Option<String> {
let all_tunnels = self.all_tunnels.read().await;
let address = if address == "localhost" {
loop {
let address = get_animal_name();
if !all_tunnels.contains_key(address) {
break address;
}
}
} else {
if all_tunnels.contains_key(address) {
return None;
}
address
};
Some(format!("{address}.tunnel.huizinga.dev"))
}
} }
impl russh::server::Handler for Handler { impl russh::server::Handler for Handler {
@ -126,18 +103,13 @@ impl russh::server::Handler for Handler {
) -> Result<bool, Self::Error> { ) -> Result<bool, Self::Error> {
debug!("{address}:{port}"); debug!("{address}:{port}");
let Some(full_address) = self.full_address(address).await else { let tunnel = Tunnel::new(session.handle(), address, *port);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
self.send(&format!("{port} => FAILED ({address} already in use)\r\n")); self.send(&format!("{port} => FAILED ({address} already in use)\r\n"));
return Ok(false); return Ok(false);
}; };
self.send(&format!("{port} => https://{address}\r\n"));
self.tunnels.insert(full_address.clone()); self.tunnels.insert(address);
self.all_tunnels.write().await.insert(
full_address.clone(),
Tunnel::new(session.handle(), address, *port),
);
self.send(&format!("{port} => https://{full_address}\r\n"));
Ok(true) Ok(true)
} }
@ -146,13 +118,10 @@ impl russh::server::Handler for Handler {
impl Drop for Handler { impl Drop for Handler {
fn drop(&mut self) { fn drop(&mut self) {
let tunnels = self.tunnels.clone(); let tunnels = self.tunnels.clone();
let all_tunnels = self.all_tunnels.clone(); let mut all_tunnels = self.all_tunnels.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut all_tunnels = all_tunnels.write().await; all_tunnels.remove_tunnels(tunnels.clone()).await;
for tunnel in tunnels {
all_tunnels.remove(&tunnel);
}
debug!("{all_tunnels:?}"); debug!("{all_tunnels:?}");
}); });
@ -166,7 +135,7 @@ pub struct Server {
impl Server { impl Server {
pub fn new() -> Self { pub fn new() -> Self {
Server { Server {
tunnels: tunnel::new(), tunnels: Tunnels::new(),
} }
} }

View File

@ -1,4 +1,7 @@
use std::{collections::HashMap, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use russh::{ use russh::{
Channel, Channel,
@ -6,6 +9,8 @@ use russh::{
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::animals::get_animal_name;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Tunnel { pub struct Tunnel {
handle: Handle, handle: Handle,
@ -29,8 +34,52 @@ impl Tunnel {
} }
} }
pub type Tunnels = Arc<RwLock<HashMap<String, Tunnel>>>; #[derive(Debug, Clone)]
pub struct Tunnels(Arc<RwLock<HashMap<String, Tunnel>>>);
pub fn new() -> Tunnels { impl Tunnels {
Arc::new(RwLock::new(HashMap::new())) pub fn new() -> Self {
Self(Arc::new(RwLock::new(HashMap::new())))
}
pub async fn add_tunnel(&mut self, address: &str, tunnel: Tunnel) -> Option<String> {
let mut all_tunnels = self.0.write().await;
let address = if address == "localhost" {
loop {
let address = get_animal_name();
if !all_tunnels.contains_key(address) {
break address;
}
}
} else {
if all_tunnels.contains_key(address) {
return None;
}
address
};
let address = format!("{address}.tunnel.huizinga.dev");
all_tunnels.insert(address.clone(), tunnel);
Some(address)
}
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) {
let mut all_tunnels = self.0.write().await;
for tunnel in tunnels {
all_tunnels.remove(&tunnel);
}
}
pub async fn get_tunnel(&self, address: &str) -> Option<Tunnel> {
self.0.read().await.get(address).cloned()
}
}
impl Default for Tunnels {
fn default() -> Self {
Self::new()
}
} }