Refactor tunnel

This commit is contained in:
Dreaded_X 2025-04-04 16:56:47 +02:00
parent 2bd26f1db2
commit bb50a59316
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
4 changed files with 63 additions and 44 deletions

View File

@ -1,3 +1,4 @@
#![feature(impl_trait_in_fn_trait_return)]
pub mod animals;
pub mod ssh;
pub mod tunnel;

View File

@ -70,7 +70,7 @@ async fn main() {
debug!("Request for {authority:?}");
let Some(tunnel) = tunnels.read().await.get(&authority).cloned() else {
let Some(tunnel) = tunnels.get_tunnel(&authority).await else {
let mut resp = Response::new(full(format!("Unknown tunnel: {authority}")));
*resp.status_mut() = StatusCode::NOT_FOUND;

View File

@ -11,10 +11,7 @@ use tokio::{
};
use tracing::{debug, error};
use crate::{
animals::get_animal_name,
tunnel::{self, Tunnel, Tunnels},
};
use crate::tunnel::{Tunnel, Tunnels};
pub struct Handler {
tx: UnboundedSender<Vec<u8>>,
@ -28,26 +25,6 @@ impl Handler {
fn send(&self, data: &str) {
let _ = self.tx.send(data.as_bytes().to_vec());
}
async fn full_address(&self, address: &str) -> Option<String> {
let all_tunnels = self.all_tunnels.read().await;
let address = if address == "localhost" {
loop {
let address = get_animal_name();
if !all_tunnels.contains_key(address) {
break address;
}
}
} else {
if all_tunnels.contains_key(address) {
return None;
}
address
};
Some(format!("{address}.tunnel.huizinga.dev"))
}
}
impl russh::server::Handler for Handler {
@ -126,18 +103,13 @@ impl russh::server::Handler for Handler {
) -> Result<bool, Self::Error> {
debug!("{address}:{port}");
let Some(full_address) = self.full_address(address).await else {
let tunnel = Tunnel::new(session.handle(), address, *port);
let Some(address) = self.all_tunnels.add_tunnel(address, tunnel).await else {
self.send(&format!("{port} => FAILED ({address} already in use)\r\n"));
return Ok(false);
};
self.tunnels.insert(full_address.clone());
self.all_tunnels.write().await.insert(
full_address.clone(),
Tunnel::new(session.handle(), address, *port),
);
self.send(&format!("{port} => https://{full_address}\r\n"));
self.send(&format!("{port} => https://{address}\r\n"));
self.tunnels.insert(address);
Ok(true)
}
@ -146,13 +118,10 @@ impl russh::server::Handler for Handler {
impl Drop for Handler {
fn drop(&mut self) {
let tunnels = self.tunnels.clone();
let all_tunnels = self.all_tunnels.clone();
let mut all_tunnels = self.all_tunnels.clone();
tokio::spawn(async move {
let mut all_tunnels = all_tunnels.write().await;
for tunnel in tunnels {
all_tunnels.remove(&tunnel);
}
all_tunnels.remove_tunnels(tunnels.clone()).await;
debug!("{all_tunnels:?}");
});
@ -166,7 +135,7 @@ pub struct Server {
impl Server {
pub fn new() -> Self {
Server {
tunnels: tunnel::new(),
tunnels: Tunnels::new(),
}
}

View File

@ -1,4 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use russh::{
Channel,
@ -6,6 +9,8 @@ use russh::{
};
use tokio::sync::RwLock;
use crate::animals::get_animal_name;
#[derive(Debug, Clone)]
pub struct Tunnel {
handle: Handle,
@ -29,8 +34,52 @@ impl Tunnel {
}
}
pub type Tunnels = Arc<RwLock<HashMap<String, Tunnel>>>;
#[derive(Debug, Clone)]
pub struct Tunnels(Arc<RwLock<HashMap<String, Tunnel>>>);
pub fn new() -> Tunnels {
Arc::new(RwLock::new(HashMap::new()))
impl Tunnels {
pub fn new() -> Self {
Self(Arc::new(RwLock::new(HashMap::new())))
}
pub async fn add_tunnel(&mut self, address: &str, tunnel: Tunnel) -> Option<String> {
let mut all_tunnels = self.0.write().await;
let address = if address == "localhost" {
loop {
let address = get_animal_name();
if !all_tunnels.contains_key(address) {
break address;
}
}
} else {
if all_tunnels.contains_key(address) {
return None;
}
address
};
let address = format!("{address}.tunnel.huizinga.dev");
all_tunnels.insert(address.clone(), tunnel);
Some(address)
}
pub async fn remove_tunnels(&mut self, tunnels: HashSet<String>) {
let mut all_tunnels = self.0.write().await;
for tunnel in tunnels {
all_tunnels.remove(&tunnel);
}
}
pub async fn get_tunnel(&self, address: &str) -> Option<Tunnel> {
self.0.read().await.get(address).cloned()
}
}
impl Default for Tunnels {
fn default() -> Self {
Self::new()
}
}