From 82859d8e46daa91af1157c8885d88a9ec11f7398 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Fri, 6 Jan 2023 03:34:33 +0100 Subject: [PATCH] Added authentication to fullfillment endpoint --- config/zeus.dev.toml | 6 ++--- src/auth.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 28 +++++++++++++++++++-- src/lib.rs | 1 + src/main.rs | 28 +++++++++++++++------ 5 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 src/auth.rs diff --git a/config/zeus.dev.toml b/config/zeus.dev.toml index 7976b62..58fa33f 100644 --- a/config/zeus.dev.toml +++ b/config/zeus.dev.toml @@ -1,11 +1,11 @@ +[openid] +base_url = "https://login.huizinga.dev/api/oidc" + [mqtt] host="olympus.lan.huizinga.dev" port=8883 username="mqtt" -[fullfillment] -username="Dreaded_X" - [presence] topic = "automation_dev/presence" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..16a0bbd --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,60 @@ +use axum::{ + async_trait, + extract::{FromRequestParts, FromRef}, + http::{StatusCode, request::Parts}, + response::{IntoResponse, Response}, +}; +use serde::Deserialize; + +use crate::config::OpenIDConfig; + +#[derive(Debug, Deserialize)] +pub struct User { + pub preferred_username: String, +} + +#[async_trait] +impl FromRequestParts for User +where + OpenIDConfig: FromRef, + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // Get the state + let openid = OpenIDConfig::from_ref(state); + + // Create a request to the auth server + // @TODO Do some discovery to find the correct url for this instead of assuming + let mut req = reqwest::Client::new() + .get(format!("{}/userinfo", openid.base_url)); + + // Add auth header to the request if it exists + if let Some(auth) = parts.headers.get(axum::http::header::AUTHORIZATION) { + req = req.header(reqwest::header::AUTHORIZATION, auth); + } + + // Send the request + let res = req.send() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + + // If the request is success full the auth token is valid and we are given userinfo + let status = res.status(); + if status.is_success() { + let user = res.json() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + + return Ok(user); + } else { + let err = res + .text() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + + return Err((status, err).into_response()); + } + } +} diff --git a/src/config.rs b/src/config.rs index 6d3bccf..36c7180 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use std::{fs, error::Error, collections::HashMap, net::Ipv4Addr}; +use std::{fs, error::Error, collections::HashMap, net::{Ipv4Addr, SocketAddr}}; use tracing::{debug, trace}; use rumqttc::AsyncClient; @@ -10,7 +10,9 @@ use crate::devices::{DeviceBox, IkeaOutlet, WakeOnLAN, AudioSetup, ContactSensor #[derive(Debug, Deserialize)] pub struct Config { + pub openid: OpenIDConfig, pub mqtt: MqttConfig, + #[serde(default)] pub fullfillment: FullfillmentConfig, #[serde(default)] pub ntfy: NtfyConfig, @@ -21,6 +23,11 @@ pub struct Config { pub devices: HashMap } +#[derive(Debug, Clone, Deserialize)] +pub struct OpenIDConfig { + pub base_url: String +} + #[derive(Debug, Deserialize)] pub struct MqttConfig { pub host: String, @@ -31,9 +38,26 @@ pub struct MqttConfig { #[derive(Debug, Deserialize)] pub struct FullfillmentConfig { + #[serde(default = "default_fullfillment_ip")] + pub ip: Ipv4Addr, #[serde(default = "default_fullfillment_port")] pub port: u16, - pub username: String, +} + +impl From for SocketAddr { + fn from(fullfillment: FullfillmentConfig) -> Self { + (fullfillment.ip, fullfillment.port).into() + } +} + +impl Default for FullfillmentConfig { + fn default() -> Self { + Self { ip: default_fullfillment_ip(), port: default_fullfillment_port() } + } +} + +fn default_fullfillment_ip() -> Ipv4Addr { + [127, 0, 0, 1].into() } fn default_fullfillment_port() -> u16 { diff --git a/src/lib.rs b/src/lib.rs index 4138364..770d471 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,3 +6,4 @@ pub mod presence; pub mod ntfy; pub mod light_sensor; pub mod hue_bridge; +pub mod auth; diff --git a/src/main.rs b/src/main.rs index 9161426..4118cb8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,9 @@ #![feature(async_closure)] use std::{time::Duration, sync::{Arc, RwLock}, process, net::SocketAddr}; -use axum::{Router, Json, routing::post, http::StatusCode}; +use axum::{Router, Json, routing::post, http::StatusCode, extract::FromRef}; -use automation::{config::Config, presence::Presence, ntfy::Ntfy, light_sensor::{self, LightSensor}, hue_bridge::HueBridge}; +use automation::{config::{Config, OpenIDConfig}, presence::Presence, ntfy::Ntfy, light_sensor::LightSensor, hue_bridge::HueBridge, auth::User}; use dotenv::dotenv; use rumqttc::{MqttOptions, Transport, AsyncClient}; use tracing::{error, info, metadata::LevelFilter}; @@ -12,6 +12,17 @@ use automation::{devices::Devices, mqtt::Mqtt}; use google_home::{GoogleHome, Request}; use tracing_subscriber::EnvFilter; +#[derive(Clone)] +struct AppState { + pub openid: OpenIDConfig +} + +impl FromRef for automation::config::OpenIDConfig { + fn from_ref(input: &AppState) -> Self { + input.openid.clone() + } +} + #[tokio::main] async fn main() { dotenv().ok(); @@ -85,12 +96,10 @@ async fn main() { // Create google home fullfillment route let fullfillment = Router::new() - .route("/google_home", post(async move |Json(payload): Json| { + .route("/google_home", post(async move |user: User, Json(payload): Json| { // Handle request might block, so we need to spawn a blocking task tokio::task::spawn_blocking(move || { - // @TODO Verify that we are actually logged in - // Might also be smart to get the username from here - let gc = GoogleHome::new(&config.fullfillment.username); + let gc = GoogleHome::new(&user.preferred_username); let result = gc.handle_request(payload, &mut devices.write().unwrap().as_google_home_devices()).unwrap(); return (StatusCode::OK, Json(result)); @@ -99,10 +108,13 @@ async fn main() { // Combine together all the routes let app = Router::new() - .nest("/fullfillment", fullfillment); + .nest("/fullfillment", fullfillment) + .with_state(AppState { + openid: config.openid + }); // Start the web server - let addr: SocketAddr = ([127, 0, 0, 1], config.fullfillment.port).into(); + let addr = config.fullfillment.into(); info!("Server started on http://{addr}"); axum::Server::bind(&addr) .serve(app.into_make_service())