Reorganized files
This commit is contained in:
114
src/web/auth.rs
Normal file
114
src/web/auth.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use hyper::{
|
||||
HeaderMap, Method, StatusCode,
|
||||
header::{self, HeaderName, HeaderValue, ToStrError},
|
||||
};
|
||||
use reqwest::redirect::Policy;
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ForwardAuth {
|
||||
address: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct User {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn is(&self, username: impl AsRef<str>) -> bool {
|
||||
self.username.eq(username.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthStatus {
|
||||
// Contains the value of the location header that will redirect the user to the login page
|
||||
Unauthenticated(HeaderValue),
|
||||
Authenticated(User),
|
||||
Unauthorized,
|
||||
}
|
||||
|
||||
const REMOTE_USER: HeaderName = HeaderName::from_static("remote-user");
|
||||
const X_FORWARDED_METHOD: HeaderName = HeaderName::from_static("x-forwarded-method");
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("Reqwest error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("Http error: {0}")]
|
||||
Http(#[from] hyper::http::Error),
|
||||
#[error("Header '{0}' is missing from auth endpoint response")]
|
||||
MissingHeader(HeaderName),
|
||||
#[error("Header '{0}' received from auth endpoint is invalid: {1}")]
|
||||
InvalidHeader(HeaderName, ToStrError),
|
||||
#[error("Unexpected response from auth endpoint: {0:?}")]
|
||||
UnexpectedResponse(reqwest::Response),
|
||||
}
|
||||
|
||||
impl ForwardAuth {
|
||||
pub fn new(endpoint: impl Into<String>) -> Self {
|
||||
Self {
|
||||
address: endpoint.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check(
|
||||
&self,
|
||||
methods: &Method,
|
||||
headers: &HeaderMap<HeaderValue>,
|
||||
) -> Result<AuthStatus, AuthError> {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.redirect(Policy::none())
|
||||
.build()?;
|
||||
|
||||
let mut headers: HeaderMap = headers
|
||||
.clone()
|
||||
.into_iter()
|
||||
.filter_map(|(key, value)| {
|
||||
if let Some(key) = key
|
||||
&& key != header::CONTENT_LENGTH
|
||||
&& key != header::HOST
|
||||
{
|
||||
Some((key, value))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
headers.insert(
|
||||
X_FORWARDED_METHOD,
|
||||
HeaderValue::from_str(methods.as_str()).expect("method should convert to valid ascii"),
|
||||
);
|
||||
|
||||
let resp = client.get(&self.address).headers(headers).send().await?;
|
||||
|
||||
let status_code = resp.status();
|
||||
if status_code == StatusCode::FOUND {
|
||||
let location = resp
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.cloned()
|
||||
.ok_or(AuthError::MissingHeader(header::LOCATION))?;
|
||||
|
||||
return Ok(AuthStatus::Unauthenticated(location));
|
||||
} else if status_code == StatusCode::FORBIDDEN {
|
||||
return Ok(AuthStatus::Unauthorized);
|
||||
} else if !status_code.is_success() {
|
||||
return Err(AuthError::UnexpectedResponse(resp));
|
||||
}
|
||||
|
||||
let username = resp
|
||||
.headers()
|
||||
.get(REMOTE_USER)
|
||||
.ok_or(AuthError::MissingHeader(REMOTE_USER))?
|
||||
.to_str()
|
||||
.map_err(|err| AuthError::InvalidHeader(REMOTE_USER, err))?
|
||||
.to_owned();
|
||||
|
||||
debug!("Connected user is: {username}");
|
||||
|
||||
Ok(AuthStatus::Authenticated(User { username }))
|
||||
}
|
||||
}
|
||||
149
src/web/mod.rs
Normal file
149
src/web/mod.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
mod auth;
|
||||
mod response;
|
||||
|
||||
use crate::tunnel::Registry;
|
||||
use std::{ops::Deref, pin::Pin};
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt as _, Empty, combinators::BoxBody};
|
||||
use hyper::{
|
||||
Request, Response, StatusCode,
|
||||
body::Incoming,
|
||||
client::conn::http1::Builder,
|
||||
header::{self, HOST},
|
||||
};
|
||||
use tracing::{debug, error, trace, warn};
|
||||
|
||||
use crate::tunnel::TunnelAccess;
|
||||
use auth::AuthStatus;
|
||||
pub use auth::ForwardAuth;
|
||||
use response::response;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Service {
|
||||
registry: Registry,
|
||||
auth: ForwardAuth,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn new(registry: Registry, auth: ForwardAuth) -> Self {
|
||||
Self { registry, auth }
|
||||
}
|
||||
}
|
||||
|
||||
impl hyper::service::Service<Request<Incoming>> for Service {
|
||||
type Response = Response<BoxBody<Bytes, hyper::Error>>;
|
||||
type Error = hyper::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||
trace!("{:#?}", req);
|
||||
|
||||
let Some(authority) = req
|
||||
.uri()
|
||||
.authority()
|
||||
.as_ref()
|
||||
.map(|a| a.to_string())
|
||||
.or_else(|| {
|
||||
req.headers()
|
||||
.get(HOST)
|
||||
.and_then(|h| h.to_str().ok().map(|s| s.to_owned()))
|
||||
})
|
||||
else {
|
||||
let resp = response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing or invalid authority or host header",
|
||||
);
|
||||
|
||||
return Box::pin(async { Ok(resp) });
|
||||
};
|
||||
|
||||
debug!(authority, "Tunnel request");
|
||||
|
||||
let registry = self.registry.clone();
|
||||
let auth = self.auth.clone();
|
||||
Box::pin(async move {
|
||||
let Some(entry) = registry.get(&authority).await else {
|
||||
debug!(tunnel = authority, "Unknown tunnel");
|
||||
let resp = response(StatusCode::NOT_FOUND, "Unknown tunnel");
|
||||
|
||||
return Ok(resp);
|
||||
};
|
||||
|
||||
if !entry.is_public().await {
|
||||
let user = match auth.check(req.method(), req.headers()).await {
|
||||
Ok(AuthStatus::Authenticated(user)) => user,
|
||||
Ok(AuthStatus::Unauthenticated(location)) => {
|
||||
let resp = Response::builder()
|
||||
.status(StatusCode::FOUND)
|
||||
.header(header::LOCATION, location)
|
||||
.body(
|
||||
Empty::new()
|
||||
// NOTE: I have NO idea why this is able to convert from Innfallible to hyper::Error
|
||||
.map_err(|never| match never {})
|
||||
.boxed(),
|
||||
)
|
||||
.expect("configuration should be valid");
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
Ok(AuthStatus::Unauthorized) => {
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"You do not have permission to access this tunnel",
|
||||
);
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Unexpected error during authentication: {err}");
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Unexpected error during authentication",
|
||||
);
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Tunnel is getting accessed by {user:?}");
|
||||
|
||||
if let TunnelAccess::Private(owner) = entry.get_access().await.deref() {
|
||||
if !user.is(owner) {
|
||||
let resp = response(
|
||||
StatusCode::FORBIDDEN,
|
||||
"You do not have permission to access this tunnel",
|
||||
);
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let io = match entry.open().await {
|
||||
Ok(io) => io,
|
||||
Err(err) => {
|
||||
warn!(tunnel = authority, "Failed to open tunnel: {err}");
|
||||
let resp = response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to open tunnel");
|
||||
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
let (mut sender, conn) = Builder::new()
|
||||
.preserve_header_case(true)
|
||||
.title_case_headers(true)
|
||||
.handshake(io)
|
||||
.await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = conn.await {
|
||||
warn!(runnel = authority, "Connection failed: {err}");
|
||||
}
|
||||
});
|
||||
|
||||
let resp = sender.send_request(req).await?;
|
||||
Ok(resp.map(|b| b.boxed()))
|
||||
})
|
||||
}
|
||||
}
|
||||
14
src/web/response.rs
Normal file
14
src/web/response.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt as _, Full, combinators::BoxBody};
|
||||
use hyper::{Response, StatusCode};
|
||||
|
||||
pub fn response(
|
||||
status_code: StatusCode,
|
||||
body: impl Into<String>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status_code)
|
||||
.body(Full::new(Bytes::from(body.into())))
|
||||
.expect("all configuration should be valid")
|
||||
.map(|b| b.map_err(|never| match never {}).boxed())
|
||||
}
|
||||
Reference in New Issue
Block a user