diff --git a/src/ssh.rs b/src/ssh.rs index 8a24251..b77ab92 100644 --- a/src/ssh.rs +++ b/src/ssh.rs @@ -1,12 +1,5 @@ -use std::{collections::HashSet, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; +use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration}; -use bytes::Bytes; -use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; -use hyper::{ - Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST, - service::Service, -}; -use hyper_util::rt::TokioIo; use russh::{ ChannelId, keys::PrivateKey, @@ -16,7 +9,7 @@ use tokio::{ net::ToSocketAddrs, sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}, }; -use tracing::{debug, error, trace, warn}; +use tracing::{debug, error}; use crate::tunnel::{Tunnel, Tunnels}; @@ -195,78 +188,3 @@ impl russh::server::Server for Server { error!("Session error: {error:#?}"); } } - -impl Service> for Tunnels { - type Response = Response>; - type Error = hyper::Error; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - fn response( - status_code: StatusCode, - body: impl Into, - ) -> Response> { - Response::builder() - .status(status_code) - .body(Full::new(Bytes::from(body.into()))) - .unwrap() - .map(|b| b.map_err(|never| match never {}).boxed()) - } - - trace!(?req); - - let Some(authority) = req - .uri() - .authority() - .as_ref() - .map(|a| a.to_string()) - .or_else(|| { - req.headers() - .get(HOST) - .map(|h| h.to_str().unwrap().to_owned()) - }) - else { - let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header"); - - return Box::pin(async { Ok(resp) }); - }; - - debug!("Request for {authority:?}"); - - let tunnels = self.clone(); - Box::pin(async move { - let Some(tunnel) = tunnels.get_tunnel(&authority).await else { - let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); - - return Ok::<_, hyper::Error>(resp); - }; - - debug!("Opening channel"); - let channel = match tunnel.open_tunnel().await { - Ok(channel) => channel, - Err(err) => { - warn!("Failed to open tunnel: {err}"); - let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); - - return Ok::<_, hyper::Error>(resp); - } - }; - let io = TokioIo::new(channel.into_stream()); - - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - - tokio::spawn(async move { - if let Err(err) = conn.await { - warn!("Connection failed: {err}"); - } - }); - - let resp = sender.send_request(req).await.unwrap(); - Ok(resp.map(|b| b.boxed())) - }) - } -} diff --git a/src/tunnel.rs b/src/tunnel.rs index 07a6440..6865b66 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,7 +1,16 @@ +use bytes::Bytes; +use http_body_util::{BodyExt as _, Full, combinators::BoxBody}; +use hyper::{ + Request, Response, StatusCode, body::Incoming, client::conn::http1::Builder, header::HOST, + service::Service, +}; +use hyper_util::rt::TokioIo; use std::{ collections::{HashMap, HashSet}, + pin::Pin, sync::Arc, }; +use tracing::{debug, trace, warn}; use russh::{ Channel, @@ -83,3 +92,78 @@ impl Default for Tunnels { Self::new() } } + +impl Service> for Tunnels { + type Response = Response>; + type Error = hyper::Error; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + fn response( + status_code: StatusCode, + body: impl Into, + ) -> Response> { + Response::builder() + .status(status_code) + .body(Full::new(Bytes::from(body.into()))) + .unwrap() + .map(|b| b.map_err(|never| match never {}).boxed()) + } + + trace!(?req); + + let Some(authority) = req + .uri() + .authority() + .as_ref() + .map(|a| a.to_string()) + .or_else(|| { + req.headers() + .get(HOST) + .map(|h| h.to_str().unwrap().to_owned()) + }) + else { + let resp = response(StatusCode::BAD_REQUEST, "Missing authority or host header"); + + return Box::pin(async { Ok(resp) }); + }; + + debug!("Request for {authority:?}"); + + let tunnels = self.clone(); + Box::pin(async move { + let Some(tunnel) = tunnels.get_tunnel(&authority).await else { + let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel"); + + return Ok::<_, hyper::Error>(resp); + }; + + debug!("Opening channel"); + let channel = match tunnel.open_tunnel().await { + Ok(channel) => channel, + Err(err) => { + warn!("Failed to open tunnel: {err}"); + let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel"); + + return Ok::<_, hyper::Error>(resp); + } + }; + let io = TokioIo::new(channel.into_stream()); + + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!("Connection failed: {err}"); + } + }); + + let resp = sender.send_request(req).await.unwrap(); + Ok(resp.map(|b| b.boxed())) + }) + } +}