diff --git a/src/main.rs b/src/main.rs index 14aca68..03c521f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,27 +1,13 @@ use std::{net::SocketAddr, path::Path}; -use bytes::Bytes; -use http_body_util::{BodyExt, Full, combinators::BoxBody}; -use hyper::{ - Method, Request, Response, StatusCode, - client::conn::http1::Builder, - header::HOST, - server::conn::http1::{self}, - service::service_fn, -}; +use hyper::server::conn::http1::{self}; use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use tokio::net::TcpListener; -use tracing::{debug, trace, warn}; +use tracing::warn; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use tunnel_rs::ssh::Server; -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} - #[tokio::main] async fn main() { let env_filter = EnvFilter::try_from_default_env() @@ -41,85 +27,18 @@ async fn main() { let tunnels = ssh.tunnels(); tokio::spawn(async move { ssh.run(key, ("0.0.0.0", 2222)).await }); - let service = service_fn(move |req: Request<_>| { - let tunnels = tunnels.clone(); - async move { - if req.method() == Method::CONNECT { - let mut resp = Response::new(full("CONNECT not supported")); - *resp.status_mut() = StatusCode::BAD_REQUEST; - - Ok::<_, hyper::Error>(resp) - } else { - trace!(?req); - let Some(authority) = req - .uri() - .authority() - .as_ref() - .map(|a| a.to_string()) - .or_else(|| { - req.headers() - .get(HOST) - .map(|h| h.to_str().unwrap().to_owned()) - }) - else { - let mut resp = Response::new(full("Missing authority or host header")); - *resp.status_mut() = StatusCode::BAD_REQUEST; - - return Ok::<_, hyper::Error>(resp); - }; - - debug!("Request for {authority:?}"); - - let Some(tunnel) = tunnels.get_tunnel(&authority).await else { - let mut resp = Response::new(full(format!("Unknown tunnel: {authority}"))); - *resp.status_mut() = StatusCode::NOT_FOUND; - - return Ok::<_, hyper::Error>(resp); - }; - - debug!("Opening channel"); - let channel = match tunnel.open_tunnel().await { - Ok(channel) => channel, - Err(err) => { - warn!("Failed to tunnel: {err}"); - let mut resp = Response::new(full("Failed to open tunnel")); - *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - - return Ok::<_, hyper::Error>(resp); - } - }; - let io = TokioIo::new(channel.into_stream()); - - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - - tokio::spawn(async move { - if let Err(err) = conn.await { - warn!("Connection failed: {err}"); - } - }); - - let resp = sender.send_request(req).await.unwrap(); - Ok(resp.map(|b| b.boxed())) - } - } - }); - let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); let listener = TcpListener::bind(addr).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); let io = TokioIo::new(stream); - let service = service.clone(); + let tunnels = tunnels.clone(); tokio::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(io, service) + .serve_connection(io, tunnels) .with_upgrades() .await { diff --git a/src/ssh.rs b/src/ssh.rs index b77ab92..0044bbe 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -1,5 +1,12 @@ -use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration}; +use std::{collections::HashSet, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; +use bytes::Bytes; +use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; +use hyper::{ + Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST, + service::Service, +}; +use hyper_util::rt::TokioIo; use russh::{ ChannelId, keys::PrivateKey, @@ -9,7 +16,7 @@ use tokio::{ net::ToSocketAddrs, sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}, }; -use tracing::{debug, error}; +use tracing::{debug, error, trace, warn}; use crate::tunnel::{Tunnel, Tunnels}; @@ -188,3 +195,76 @@ impl russh::server::Server for Server { error!("Session error: {error:#?}"); } } + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +impl Service> for Tunnels { + type Response = Response>; + type Error = hyper::Error; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + trace!(?req); + + let Some(authority) = req + .uri() + .authority() + .as_ref() + .map(|a| a.to_string()) + .or_else(|| { + req.headers() + .get(HOST) + .map(|h| h.to_str().unwrap().to_owned()) + }) + else { + let mut resp = Response::new(full("Missing authority or host header")); + *resp.status_mut() = StatusCode::BAD_REQUEST; + + return Box::pin(async { Ok(resp) }); + }; + + debug!("Request for {authority:?}"); + + let tunnels = self.clone(); + Box::pin(async move { + let Some(tunnel) = tunnels.get_tunnel(&authority).await else { + let mut resp = Response::new(full(format!("Unknown tunnel: {authority}"))); + *resp.status_mut() = StatusCode::NOT_FOUND; + + return Ok::<_, hyper::Error>(resp); + }; + + debug!("Opening channel"); + let channel = match tunnel.open_tunnel().await { + Ok(channel) => channel, + Err(err) => { + warn!("Failed to tunnel: {err}"); + let mut resp = Response::new(full("Failed to open tunnel")); + *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + + return Ok::<_, hyper::Error>(resp); + } + }; + let io = TokioIo::new(channel.into_stream()); + + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!("Connection failed: {err}"); + } + }); + + let resp = sender.send_request(req).await.unwrap(); + Ok(resp.map(|b| b.boxed())) + }) + } +}