diff --git a/config.lua b/config.lua index 468055d..39d33eb 100644 --- a/config.lua +++ b/config.lua @@ -1,5 +1,9 @@ print("Hello from lua") +automation.fulfillment = { + openid_url = "https://login.huizinga.dev/api/oidc", +} + local debug, value = pcall(automation.util.get_env, "DEBUG") if debug and value ~= "true" then debug = false @@ -13,7 +17,7 @@ local function mqtt_automation(topic) return "automation/" .. topic end -local mqtt_client = automation.create_mqtt_client({ +local mqtt_client = automation.new_mqtt_client({ host = debug and "olympus.lan.huizinga.dev" or "mosquitto", port = 8883, client_name = debug and "automation-debug" or "automation_rs", @@ -24,13 +28,13 @@ local mqtt_client = automation.create_mqtt_client({ automation.device_manager:add(Ntfy.new({ topic = automation.util.get_env("NTFY_TOPIC"), - event_channel = automation.event_channel, + event_channel = automation.device_manager:event_channel(), })) automation.device_manager:add(Presence.new({ topic = "automation_dev/presence/+/#", client = mqtt_client, - event_channel = automation.event_channel, + event_channel = automation.device_manager:event_channel(), })) automation.device_manager:add(DebugBridge.new({ @@ -58,7 +62,7 @@ automation.device_manager:add(LightSensor.new({ client = mqtt_client, min = 22000, max = 23500, - event_channel = automation.event_channel, + event_channel = automation.device_manager:event_channel(), })) automation.device_manager:add(WakeOnLAN.new({ @@ -110,7 +114,7 @@ automation.device_manager:add(Washer.new({ topic = mqtt_z2m("batchroom/washer"), client = mqtt_client, threshold = 1, - event_channel = automation.event_channel, + event_channel = automation.device_manager:event_channel(), })) automation.device_manager:add(IkeaOutlet.new({ diff --git a/config/config.yml b/config/config.yml deleted file mode 100644 index 2f04777..0000000 --- a/config/config.yml +++ /dev/null @@ -1,2 +0,0 @@ -openid: - base_url: "https://login.huizinga.dev/api/oidc" diff --git a/config/zeus.dev.yml b/config/zeus.dev.yml deleted file mode 100644 index 2f04777..0000000 --- a/config/zeus.dev.yml +++ /dev/null @@ -1,2 +0,0 @@ -openid: - base_url: "https://login.huizinga.dev/api/oidc" diff --git a/src/auth.rs b/src/auth.rs index 75d5dea..5703657 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -6,11 +6,6 @@ use serde::Deserialize; use crate::error::{ApiError, ApiErrorJson}; -#[derive(Debug, Clone, Deserialize)] -pub struct OpenIDConfig { - pub base_url: String, -} - #[derive(Debug, Deserialize)] pub struct User { pub preferred_username: String, @@ -19,18 +14,18 @@ pub struct User { #[async_trait] impl FromRequestParts for User where - OpenIDConfig: FromRef, + String: FromRef, S: Send + Sync, { type Rejection = ApiError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // Get the state - let openid = OpenIDConfig::from_ref(state); + let openid_url = String::from_ref(state); // Create a request to the auth server // TODO: Do some discovery to find the correct url for this instead of assuming - let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid.base_url)); + let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid_url)); // Add auth header to the request if it exists if let Some(auth) = parts.headers.get(axum::http::header::AUTHORIZATION) { diff --git a/src/config.rs b/src/config.rs index 5be2e12..42bee56 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,21 +1,8 @@ -use std::fs; use std::net::{Ipv4Addr, SocketAddr}; use std::time::Duration; -use regex::{Captures, Regex}; use rumqttc::{MqttOptions, Transport}; use serde::Deserialize; -use tracing::debug; - -use crate::auth::OpenIDConfig; -use crate::error::{ConfigParseError, MissingEnv}; - -#[derive(Debug, Deserialize)] -pub struct Config { - pub openid: OpenIDConfig, - #[serde(default)] - pub fullfillment: FullfillmentConfig, -} #[derive(Debug, Clone, Deserialize)] pub struct MqttConfig { @@ -43,33 +30,25 @@ impl From for MqttOptions { } #[derive(Debug, Deserialize)] -pub struct FullfillmentConfig { - #[serde(default = "default_fullfillment_ip")] +pub struct FulfillmentConfig { + pub openid_url: String, + #[serde(default = "default_fulfillment_ip")] pub ip: Ipv4Addr, - #[serde(default = "default_fullfillment_port")] + #[serde(default = "default_fulfillment_port")] pub port: u16, } -impl From for SocketAddr { - fn from(fullfillment: FullfillmentConfig) -> Self { - (fullfillment.ip, fullfillment.port).into() +impl From for SocketAddr { + fn from(fulfillment: FulfillmentConfig) -> Self { + (fulfillment.ip, fulfillment.port).into() } } -impl Default for FullfillmentConfig { - fn default() -> Self { - Self { - ip: default_fullfillment_ip(), - port: default_fullfillment_port(), - } - } -} - -fn default_fullfillment_ip() -> Ipv4Addr { +fn default_fulfillment_ip() -> Ipv4Addr { [0, 0, 0, 0].into() } -fn default_fullfillment_port() -> u16 { +fn default_fulfillment_port() -> u16 { 7878 } @@ -93,31 +72,3 @@ impl InfoConfig { pub struct MqttDeviceConfig { pub topic: String, } - -impl Config { - pub fn parse_file(filename: &str) -> Result { - debug!("Loading config: {filename}"); - let file = fs::read_to_string(filename)?; - - // Substitute in environment variables - let re = Regex::new(r"\$\{(.*)\}").expect("Regex should be valid"); - let mut missing = MissingEnv::new(); - let file = re.replace_all(&file, |caps: &Captures| { - let key = caps.get(1).expect("Capture group should exist").as_str(); - debug!("Substituting '{key}' in config"); - match std::env::var(key) { - Ok(value) => value, - Err(_) => { - missing.add_missing(key); - "".into() - } - } - }); - - missing.has_missing()?; - - let config: Config = serde_yaml::from_str(&file)?; - - Ok(config) - } -} diff --git a/src/device_manager.rs b/src/device_manager.rs index 0ec40d1..668a414 100644 --- a/src/device_manager.rs +++ b/src/device_manager.rs @@ -238,6 +238,8 @@ impl mlua::UserData for DeviceManager { let schedule = lua.from_value(schedule)?; this.add_schedule(schedule).await; Ok(()) - }) + }); + + methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel())) } } diff --git a/src/devices/hue_light.rs b/src/devices/hue_group.rs similarity index 100% rename from src/devices/hue_light.rs rename to src/devices/hue_group.rs diff --git a/src/devices/mod.rs b/src/devices/mod.rs index e062289..6644d5e 100644 --- a/src/devices/mod.rs +++ b/src/devices/mod.rs @@ -3,7 +3,7 @@ mod audio_setup; mod contact_sensor; mod debug_bridge; mod hue_bridge; -mod hue_light; +mod hue_group; mod ikea_outlet; mod kasa_outlet; mod light_sensor; @@ -21,7 +21,7 @@ pub use self::audio_setup::*; pub use self::contact_sensor::*; pub use self::debug_bridge::*; pub use self::hue_bridge::*; -pub use self::hue_light::*; +pub use self::hue_group::*; pub use self::ikea_outlet::*; pub use self::kasa_outlet::*; pub use self::light_sensor::*; @@ -42,6 +42,24 @@ pub trait LuaDeviceCreate { Self: Sized; } +pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> { + AirFilter::register_with_lua(lua)?; + AudioSetup::register_with_lua(lua)?; + ContactSensor::register_with_lua(lua)?; + DebugBridge::register_with_lua(lua)?; + HueBridge::register_with_lua(lua)?; + HueGroup::register_with_lua(lua)?; + IkeaOutlet::register_with_lua(lua)?; + KasaOutlet::register_with_lua(lua)?; + LightSensor::register_with_lua(lua)?; + Ntfy::register_with_lua(lua)?; + Presence::register_with_lua(lua)?; + WakeOnLAN::register_with_lua(lua)?; + Washer::register_with_lua(lua)?; + + Ok(()) +} + #[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + OnNotification + OnOff + Timeout)] pub trait Device: AsGoogleHomeDevice + std::fmt::Debug + Sync + Send { fn get_id(&self) -> String; diff --git a/src/error.rs b/src/error.rs index 0052984..7deded4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,16 +65,6 @@ pub enum ParseError { InvalidPayload(Bytes), } -#[derive(Debug, Error)] -pub enum ConfigParseError { - #[error(transparent)] - MissingEnv(#[from] MissingEnv), - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - YamlError(#[from] serde_yaml::Error), -} - // TODO: Would be nice to somehow get the line number of the expected wildcard topic #[derive(Debug, Error)] #[error("Topic '{topic}' is expected to be a wildcard topic")] diff --git a/src/main.rs b/src/main.rs index f2008de..66a4620 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,12 @@ #![feature(async_closure)] -use std::{fs, process}; +use std::path::Path; +use std::process; -use automation::auth::{OpenIDConfig, User}; -use automation::config::{Config, MqttConfig}; +use anyhow::anyhow; +use automation::auth::User; +use automation::config::{FulfillmentConfig, MqttConfig}; use automation::device_manager::DeviceManager; -use automation::devices::{ - AirFilter, AudioSetup, ContactSensor, DebugBridge, HueBridge, HueGroup, IkeaOutlet, KasaOutlet, - LightSensor, Ntfy, Presence, WakeOnLAN, Washer, -}; +use automation::devices; use automation::error::ApiError; use automation::mqtt::{self, WrappedAsyncClient}; use axum::extract::FromRef; @@ -23,12 +22,12 @@ use tracing::{debug, error, info, warn}; #[derive(Clone)] struct AppState { - pub openid: OpenIDConfig, + pub openid_url: String, } -impl FromRef for OpenIDConfig { +impl FromRef for String { fn from_ref(input: &AppState) -> Self { - input.openid.clone() + input.openid_url.clone() } } @@ -52,77 +51,66 @@ async fn app() -> anyhow::Result<()> { info!("Starting automation_rs..."); - let config_filename = - std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.yml".into()); - let config = Config::parse_file(&config_filename)?; - // Setup the device handler let device_manager = DeviceManager::new(); - // Lua testing - { - let lua = mlua::Lua::new(); + let lua = mlua::Lua::new(); - lua.set_warning_function(|_lua, text, _cont| { - warn!("{text}"); - Ok(()) - }); + lua.set_warning_function(|_lua, text, _cont| { + warn!("{text}"); + Ok(()) + }); - let automation = lua.create_table()?; - let event_channel = device_manager.event_channel(); - let create_mqtt_client = lua.create_function(move |lua, config: mlua::Value| { - let config: MqttConfig = lua.from_value(config)?; + let automation = lua.create_table()?; + let event_channel = device_manager.event_channel(); + let new_mqtt_client = lua.create_function(move |lua, config: mlua::Value| { + let config: MqttConfig = lua.from_value(config)?; - // Create a mqtt client - // TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync - let (client, eventloop) = AsyncClient::new(config.into(), 100); - mqtt::start(eventloop, &event_channel); + // Create a mqtt client + // TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync + let (client, eventloop) = AsyncClient::new(config.into(), 100); + mqtt::start(eventloop, &event_channel); - Ok(WrappedAsyncClient(client)) - })?; + Ok(WrappedAsyncClient(client)) + })?; - automation.set("create_mqtt_client", create_mqtt_client)?; - automation.set("device_manager", device_manager.clone())?; - automation.set("event_channel", device_manager.event_channel())?; + automation.set("new_mqtt_client", new_mqtt_client)?; + automation.set("device_manager", device_manager.clone())?; - let util = lua.create_table()?; - let get_env = lua.create_function(|_lua, name: String| { - std::env::var(name).map_err(mlua::ExternalError::into_lua_err) - })?; - util.set("get_env", get_env)?; - automation.set("util", util)?; + let util = lua.create_table()?; + let get_env = lua.create_function(|_lua, name: String| { + std::env::var(name).map_err(mlua::ExternalError::into_lua_err) + })?; + util.set("get_env", get_env)?; + automation.set("util", util)?; - lua.globals().set("automation", automation)?; + lua.globals().set("automation", automation)?; - // Register all the device types - Ntfy::register_with_lua(&lua)?; - Presence::register_with_lua(&lua)?; - AirFilter::register_with_lua(&lua)?; - AudioSetup::register_with_lua(&lua)?; - ContactSensor::register_with_lua(&lua)?; - DebugBridge::register_with_lua(&lua)?; - HueBridge::register_with_lua(&lua)?; - HueGroup::register_with_lua(&lua)?; - IkeaOutlet::register_with_lua(&lua)?; - KasaOutlet::register_with_lua(&lua)?; - LightSensor::register_with_lua(&lua)?; - WakeOnLAN::register_with_lua(&lua)?; - Washer::register_with_lua(&lua)?; + devices::register_with_lua(&lua)?; - // TODO: Make this not hardcoded - let filename = "config.lua"; - let file = fs::read_to_string(filename)?; - match lua.load(file).set_name(filename).exec_async().await { - Err(error) => { - println!("{error}"); - Err(error) - } - result => result, - }?; - } + // TODO: Make this not hardcoded + let config_filename = std::env::var("AUTOMATION_CONFIG").unwrap_or("./config.lua".into()); + let config_path = Path::new(&config_filename); + match lua.load(config_path).exec_async().await { + Err(error) => { + println!("{error}"); + Err(error) + } + result => result, + }?; - // Create google home fullfillment route - let fullfillment = Router::new().route( + let automation: mlua::Table = lua.globals().get("automation")?; + let fulfillment_config: Option = automation.get("fulfillment")?; + let fulfillment_config = if let Some(fulfillment_config) = fulfillment_config { + let fulfillment_config: FulfillmentConfig = lua.from_value(fulfillment_config)?; + debug!("automation.fulfillment = {fulfillment_config:?}"); + fulfillment_config + } else { + return Err(anyhow!("Fulfillment is not configured")); + }; + + // Create google home fulfillment route + let fulfillment = Router::new().route( "/google_home", post(async move |user: User, Json(payload): Json| { debug!(username = user.preferred_username, "{payload:#?}"); @@ -144,13 +132,13 @@ async fn app() -> anyhow::Result<()> { // Combine together all the routes let app = Router::new() - .nest("/fullfillment", fullfillment) + .nest("/fulfillment", fulfillment) .with_state(AppState { - openid: config.openid, + openid_url: fulfillment_config.openid_url.clone(), }); // Start the web server - let addr = config.fullfillment.into(); + let addr = fulfillment_config.into(); info!("Server started on http://{addr}"); axum::Server::try_bind(&addr)? .serve(app.into_make_service())