Moved last config items to lua + small cleanup

This commit is contained in:
Dreaded_X 2024-04-29 03:38:30 +02:00
parent 2a3b14267b
commit 2ff59872b2
Signed by: Dreaded_X
GPG Key ID: FA5F485356B0D2D4
11 changed files with 105 additions and 161 deletions

View File

@ -1,7 +1,7 @@
FROM gcr.io/distroless/cc-debian12:nonroot
ENV AUTOMATION_CONFIG=/app/config.yml
COPY ./config/config.yml /app/config.yml
ENV AUTOMATION_CONFIG=/app/config.lua
COPY ./config.lua /app/config.lua
COPY ./build/automation /app/automation

View File

@ -1,5 +1,9 @@
print("Hello from lua")
automation.fulfillment = {
openid_url = "https://login.huizinga.dev/api/oidc",
}
local debug, value = pcall(automation.util.get_env, "DEBUG")
if debug and value ~= "true" then
debug = false
@ -13,7 +17,7 @@ local function mqtt_automation(topic)
return "automation/" .. topic
end
local mqtt_client = automation.create_mqtt_client({
local mqtt_client = automation.new_mqtt_client({
host = debug and "olympus.lan.huizinga.dev" or "mosquitto",
port = 8883,
client_name = debug and "automation-debug" or "automation_rs",
@ -24,13 +28,13 @@ local mqtt_client = automation.create_mqtt_client({
automation.device_manager:add(Ntfy.new({
topic = automation.util.get_env("NTFY_TOPIC"),
event_channel = automation.event_channel,
event_channel = automation.device_manager:event_channel(),
}))
automation.device_manager:add(Presence.new({
topic = "automation_dev/presence/+/#",
client = mqtt_client,
event_channel = automation.event_channel,
event_channel = automation.device_manager:event_channel(),
}))
automation.device_manager:add(DebugBridge.new({
@ -58,7 +62,7 @@ automation.device_manager:add(LightSensor.new({
client = mqtt_client,
min = 22000,
max = 23500,
event_channel = automation.event_channel,
event_channel = automation.device_manager:event_channel(),
}))
automation.device_manager:add(WakeOnLAN.new({
@ -110,7 +114,7 @@ automation.device_manager:add(Washer.new({
topic = mqtt_z2m("batchroom/washer"),
client = mqtt_client,
threshold = 1,
event_channel = automation.event_channel,
event_channel = automation.device_manager:event_channel(),
}))
automation.device_manager:add(IkeaOutlet.new({

View File

@ -1,2 +0,0 @@
openid:
base_url: "https://login.huizinga.dev/api/oidc"

View File

@ -1,2 +0,0 @@
openid:
base_url: "https://login.huizinga.dev/api/oidc"

View File

@ -6,11 +6,6 @@ use serde::Deserialize;
use crate::error::{ApiError, ApiErrorJson};
#[derive(Debug, Clone, Deserialize)]
pub struct OpenIDConfig {
pub base_url: String,
}
#[derive(Debug, Deserialize)]
pub struct User {
pub preferred_username: String,
@ -19,18 +14,18 @@ pub struct User {
#[async_trait]
impl<S> FromRequestParts<S> for User
where
OpenIDConfig: FromRef<S>,
String: FromRef<S>,
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Get the state
let openid = OpenIDConfig::from_ref(state);
let openid_url = String::from_ref(state);
// Create a request to the auth server
// TODO: Do some discovery to find the correct url for this instead of assuming
let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid.base_url));
let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid_url));
// Add auth header to the request if it exists
if let Some(auth) = parts.headers.get(axum::http::header::AUTHORIZATION) {

View File

@ -1,21 +1,8 @@
use std::fs;
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use regex::{Captures, Regex};
use rumqttc::{MqttOptions, Transport};
use serde::Deserialize;
use tracing::debug;
use crate::auth::OpenIDConfig;
use crate::error::{ConfigParseError, MissingEnv};
#[derive(Debug, Deserialize)]
pub struct Config {
pub openid: OpenIDConfig,
#[serde(default)]
pub fullfillment: FullfillmentConfig,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MqttConfig {
@ -43,33 +30,25 @@ impl From<MqttConfig> for MqttOptions {
}
#[derive(Debug, Deserialize)]
pub struct FullfillmentConfig {
#[serde(default = "default_fullfillment_ip")]
pub struct FulfillmentConfig {
pub openid_url: String,
#[serde(default = "default_fulfillment_ip")]
pub ip: Ipv4Addr,
#[serde(default = "default_fullfillment_port")]
#[serde(default = "default_fulfillment_port")]
pub port: u16,
}
impl From<FullfillmentConfig> for SocketAddr {
fn from(fullfillment: FullfillmentConfig) -> Self {
(fullfillment.ip, fullfillment.port).into()
impl From<FulfillmentConfig> for SocketAddr {
fn from(fulfillment: FulfillmentConfig) -> Self {
(fulfillment.ip, fulfillment.port).into()
}
}
impl Default for FullfillmentConfig {
fn default() -> Self {
Self {
ip: default_fullfillment_ip(),
port: default_fullfillment_port(),
}
}
}
fn default_fullfillment_ip() -> Ipv4Addr {
fn default_fulfillment_ip() -> Ipv4Addr {
[0, 0, 0, 0].into()
}
fn default_fullfillment_port() -> u16 {
fn default_fulfillment_port() -> u16 {
7878
}
@ -93,31 +72,3 @@ impl InfoConfig {
pub struct MqttDeviceConfig {
pub topic: String,
}
impl Config {
pub fn parse_file(filename: &str) -> Result<Self, ConfigParseError> {
debug!("Loading config: {filename}");
let file = fs::read_to_string(filename)?;
// Substitute in environment variables
let re = Regex::new(r"\$\{(.*)\}").expect("Regex should be valid");
let mut missing = MissingEnv::new();
let file = re.replace_all(&file, |caps: &Captures| {
let key = caps.get(1).expect("Capture group should exist").as_str();
debug!("Substituting '{key}' in config");
match std::env::var(key) {
Ok(value) => value,
Err(_) => {
missing.add_missing(key);
"".into()
}
}
});
missing.has_missing()?;
let config: Config = serde_yaml::from_str(&file)?;
Ok(config)
}
}

View File

@ -238,6 +238,8 @@ impl mlua::UserData for DeviceManager {
let schedule = lua.from_value(schedule)?;
this.add_schedule(schedule).await;
Ok(())
})
});
methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel()))
}
}

View File

@ -3,7 +3,7 @@ mod audio_setup;
mod contact_sensor;
mod debug_bridge;
mod hue_bridge;
mod hue_light;
mod hue_group;
mod ikea_outlet;
mod kasa_outlet;
mod light_sensor;
@ -24,7 +24,7 @@ pub use self::audio_setup::*;
pub use self::contact_sensor::*;
pub use self::debug_bridge::*;
pub use self::hue_bridge::*;
pub use self::hue_light::*;
pub use self::hue_group::*;
pub use self::ikea_outlet::*;
pub use self::kasa_outlet::*;
pub use self::light_sensor::*;
@ -45,6 +45,24 @@ pub trait LuaDeviceCreate {
Self: Sized;
}
pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> {
AirFilter::register_with_lua(lua)?;
AudioSetup::register_with_lua(lua)?;
ContactSensor::register_with_lua(lua)?;
DebugBridge::register_with_lua(lua)?;
HueBridge::register_with_lua(lua)?;
HueGroup::register_with_lua(lua)?;
IkeaOutlet::register_with_lua(lua)?;
KasaOutlet::register_with_lua(lua)?;
LightSensor::register_with_lua(lua)?;
Ntfy::register_with_lua(lua)?;
Presence::register_with_lua(lua)?;
WakeOnLAN::register_with_lua(lua)?;
Washer::register_with_lua(lua)?;
Ok(())
}
pub trait Device:
Debug
+ Sync

View File

@ -65,16 +65,6 @@ pub enum ParseError {
InvalidPayload(Bytes),
}
#[derive(Debug, Error)]
pub enum ConfigParseError {
#[error(transparent)]
MissingEnv(#[from] MissingEnv),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
YamlError(#[from] serde_yaml::Error),
}
// TODO: Would be nice to somehow get the line number of the expected wildcard topic
#[derive(Debug, Error)]
#[error("Topic '{topic}' is expected to be a wildcard topic")]

View File

@ -1,13 +1,12 @@
#![feature(async_closure)]
use std::{fs, process};
use std::path::Path;
use std::process;
use automation::auth::{OpenIDConfig, User};
use automation::config::{Config, MqttConfig};
use anyhow::anyhow;
use automation::auth::User;
use automation::config::{FulfillmentConfig, MqttConfig};
use automation::device_manager::DeviceManager;
use automation::devices::{
AirFilter, AudioSetup, ContactSensor, DebugBridge, HueBridge, HueGroup, IkeaOutlet, KasaOutlet,
LightSensor, Ntfy, Presence, WakeOnLAN, Washer,
};
use automation::devices;
use automation::error::ApiError;
use automation::mqtt::{self, WrappedAsyncClient};
use axum::extract::FromRef;
@ -23,12 +22,12 @@ use tracing::{debug, error, info, warn};
#[derive(Clone)]
struct AppState {
pub openid: OpenIDConfig,
pub openid_url: String,
}
impl FromRef<AppState> for OpenIDConfig {
impl FromRef<AppState> for String {
fn from_ref(input: &AppState) -> Self {
input.openid.clone()
input.openid_url.clone()
}
}
@ -53,77 +52,66 @@ async fn app() -> anyhow::Result<()> {
info!("Starting automation_rs...");
let config_filename =
std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.yml".into());
let config = Config::parse_file(&config_filename)?;
// Setup the device handler
let device_manager = DeviceManager::new();
// Lua testing
{
let lua = mlua::Lua::new();
let lua = mlua::Lua::new();
lua.set_warning_function(|_lua, text, _cont| {
warn!("{text}");
Ok(())
});
lua.set_warning_function(|_lua, text, _cont| {
warn!("{text}");
Ok(())
});
let automation = lua.create_table()?;
let event_channel = device_manager.event_channel();
let create_mqtt_client = lua.create_function(move |lua, config: mlua::Value| {
let config: MqttConfig = lua.from_value(config)?;
let automation = lua.create_table()?;
let event_channel = device_manager.event_channel();
let new_mqtt_client = lua.create_function(move |lua, config: mlua::Value| {
let config: MqttConfig = lua.from_value(config)?;
// Create a mqtt client
// TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync
let (client, eventloop) = AsyncClient::new(config.into(), 100);
mqtt::start(eventloop, &event_channel);
// Create a mqtt client
// TODO: When starting up, the devices are not yet created, this could lead to a device being out of sync
let (client, eventloop) = AsyncClient::new(config.into(), 100);
mqtt::start(eventloop, &event_channel);
Ok(WrappedAsyncClient(client))
})?;
Ok(WrappedAsyncClient(client))
})?;
automation.set("create_mqtt_client", create_mqtt_client)?;
automation.set("device_manager", device_manager.clone())?;
automation.set("event_channel", device_manager.event_channel())?;
automation.set("new_mqtt_client", new_mqtt_client)?;
automation.set("device_manager", device_manager.clone())?;
let util = lua.create_table()?;
let get_env = lua.create_function(|_lua, name: String| {
std::env::var(name).map_err(mlua::ExternalError::into_lua_err)
})?;
util.set("get_env", get_env)?;
automation.set("util", util)?;
let util = lua.create_table()?;
let get_env = lua.create_function(|_lua, name: String| {
std::env::var(name).map_err(mlua::ExternalError::into_lua_err)
})?;
util.set("get_env", get_env)?;
automation.set("util", util)?;
lua.globals().set("automation", automation)?;
lua.globals().set("automation", automation)?;
// Register all the device types
Ntfy::register_with_lua(&lua)?;
Presence::register_with_lua(&lua)?;
AirFilter::register_with_lua(&lua)?;
AudioSetup::register_with_lua(&lua)?;
ContactSensor::register_with_lua(&lua)?;
DebugBridge::register_with_lua(&lua)?;
HueBridge::register_with_lua(&lua)?;
HueGroup::register_with_lua(&lua)?;
IkeaOutlet::register_with_lua(&lua)?;
KasaOutlet::register_with_lua(&lua)?;
LightSensor::register_with_lua(&lua)?;
WakeOnLAN::register_with_lua(&lua)?;
Washer::register_with_lua(&lua)?;
devices::register_with_lua(&lua)?;
// TODO: Make this not hardcoded
let filename = "config.lua";
let file = fs::read_to_string(filename)?;
match lua.load(file).set_name(filename).exec_async().await {
Err(error) => {
println!("{error}");
Err(error)
}
result => result,
}?;
}
// TODO: Make this not hardcoded
let config_filename = std::env::var("AUTOMATION_CONFIG").unwrap_or("./config.lua".into());
let config_path = Path::new(&config_filename);
match lua.load(config_path).exec_async().await {
Err(error) => {
println!("{error}");
Err(error)
}
result => result,
}?;
// Create google home fullfillment route
let fullfillment = Router::new().route(
let automation: mlua::Table = lua.globals().get("automation")?;
let fulfillment_config: Option<mlua::Value> = automation.get("fulfillment")?;
let fulfillment_config = if let Some(fulfillment_config) = fulfillment_config {
let fulfillment_config: FulfillmentConfig = lua.from_value(fulfillment_config)?;
debug!("automation.fulfillment = {fulfillment_config:?}");
fulfillment_config
} else {
return Err(anyhow!("Fulfillment is not configured"));
};
// Create google home fulfillment route
let fulfillment = Router::new().route(
"/google_home",
post(async move |user: User, Json(payload): Json<Request>| {
debug!(username = user.preferred_username, "{payload:#?}");
@ -145,13 +133,13 @@ async fn app() -> anyhow::Result<()> {
// Combine together all the routes
let app = Router::new()
.nest("/fullfillment", fullfillment)
.nest("/fulfillment", fulfillment)
.with_state(AppState {
openid: config.openid,
openid_url: fulfillment_config.openid_url.clone(),
});
// Start the web server
let addr = config.fullfillment.into();
let addr = fulfillment_config.into();
info!("Server started on http://{addr}");
axum::Server::try_bind(&addr)?
.serve(app.into_make_service())