diff --git a/src/auth.rs b/src/auth.rs index 840d8f8..ab45dc1 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,5 +1,5 @@ use hyper::{ - HeaderMap, StatusCode, + HeaderMap, Method, StatusCode, header::{self, HeaderName, HeaderValue, ToStrError}, }; use reqwest::redirect::Policy; @@ -30,6 +30,7 @@ pub enum AuthStatus { } const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user"); +const X_FORWARDED_METHOD: HeaderName = HeaderName::from_static("x-forwarded-method"); #[derive(Debug, thiserror::Error)] pub enum AuthError { @@ -54,13 +55,14 @@ impl ForwardAuth { pub async fn check_auth( &self, + methods: &Method, headers: &HeaderMap, ) -> Result { let client = reqwest::ClientBuilder::new() .redirect(Policy::none()) .build()?; - let headers = headers + let mut headers: HeaderMap = headers .clone() .into_iter() .filter_map(|(key, value)| { @@ -75,6 +77,11 @@ impl ForwardAuth { }) .collect(); + headers.insert( + X_FORWARDED_METHOD, + HeaderValue::from_str(methods.as_str()).expect("method should convert to valid ascii"), + ); + let resp = client.get(&self.address).headers(headers).send().await?; let status_code = resp.status(); diff --git a/src/tunnel.rs b/src/tunnel.rs index fc7e4f3..6f9a55b 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -215,7 +215,7 @@ impl Service> for Tunnels { }; if !matches!(tunnel.access.read().await.deref(), TunnelAccess::Public) { - let user = match s.forward_auth.check_auth(req.headers()).await { + let user = match s.forward_auth.check_auth(req.method(), req.headers()).await { Ok(AuthStatus::Authenticated(user)) => user, Ok(AuthStatus::Unauthenticated(location)) => { let resp = Response::builder()