Compare commits

...

3 Commits

Author SHA1 Message Date
67ed13463a
Started work on reimplementing schedules
All checks were successful
Build and deploy automation_rs / Build Docker image (push) Successful in 40s
Build and deploy automation_rs / Deploy Docker container (push) Has been skipped
Build and deploy automation_rs / Build automation_rs (push) Successful in 5m22s
2024-04-29 04:55:39 +02:00
b16f2ae420
Fixed spelling mistakes 2024-04-29 04:55:39 +02:00
96f260492b
Moved last config items to lua + small cleanup 2024-04-29 04:55:30 +02:00
20 changed files with 153 additions and 232 deletions

5
Cargo.lock generated
View File

@ -88,6 +88,7 @@ dependencies = [
"impl_cast", "impl_cast",
"indexmap 2.0.0", "indexmap 2.0.0",
"mlua", "mlua",
"once_cell",
"paste", "paste",
"pollster", "pollster",
"regex", "regex",
@ -1108,9 +1109,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.18.0" version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]] [[package]]
name = "openssl-probe" name = "openssl-probe"

View File

@ -42,13 +42,8 @@ enum_dispatch = "0.3.12"
indexmap = { version = "2.0.0", features = ["serde"] } indexmap = { version = "2.0.0", features = ["serde"] }
serde_yaml = "0.9.27" serde_yaml = "0.9.27"
tokio-cron-scheduler = "0.9.4" tokio-cron-scheduler = "0.9.4"
mlua = { version = "0.9.7", features = [ mlua = { version = "0.9.7", features = ["lua54", "vendored", "macros", "serialize", "async", "send"] }
"lua54", once_cell = "1.19.0"
"vendored",
"macros",
"serialize",
"async",
] }
[patch.crates-io] [patch.crates-io]
wakey = { git = "https://git.huizinga.dev/Dreaded_X/wakey" } wakey = { git = "https://git.huizinga.dev/Dreaded_X/wakey" }

View File

@ -1,7 +1,7 @@
FROM gcr.io/distroless/cc-debian12:nonroot FROM gcr.io/distroless/cc-debian12:nonroot
ENV AUTOMATION_CONFIG=/app/config.yml ENV AUTOMATION_CONFIG=/app/config.lua
COPY ./config/config.yml /app/config.yml COPY ./config.lua /app/config.lua
COPY ./build/automation /app/automation COPY ./build/automation /app/automation

View File

@ -1,5 +1,9 @@
print("Hello from lua") print("Hello from lua")
automation.fulfillment = {
openid_url = "https://login.huizinga.dev/api/oidc",
}
local debug, value = pcall(automation.util.get_env, "DEBUG") local debug, value = pcall(automation.util.get_env, "DEBUG")
if debug and value ~= "true" then if debug and value ~= "true" then
debug = false debug = false
@ -13,7 +17,7 @@ local function mqtt_automation(topic)
return "automation/" .. topic return "automation/" .. topic
end end
local mqtt_client = automation.create_mqtt_client({ local mqtt_client = automation.new_mqtt_client({
host = debug and "olympus.lan.huizinga.dev" or "mosquitto", host = debug and "olympus.lan.huizinga.dev" or "mosquitto",
port = 8883, port = 8883,
client_name = debug and "automation-debug" or "automation_rs", client_name = debug and "automation-debug" or "automation_rs",
@ -24,13 +28,13 @@ local mqtt_client = automation.create_mqtt_client({
automation.device_manager:add(Ntfy.new({ automation.device_manager:add(Ntfy.new({
topic = automation.util.get_env("NTFY_TOPIC"), topic = automation.util.get_env("NTFY_TOPIC"),
event_channel = automation.event_channel, event_channel = automation.device_manager:event_channel(),
})) }))
automation.device_manager:add(Presence.new({ automation.device_manager:add(Presence.new({
topic = "automation_dev/presence/+/#", topic = "automation_dev/presence/+/#",
client = mqtt_client, client = mqtt_client,
event_channel = automation.event_channel, event_channel = automation.device_manager:event_channel(),
})) }))
automation.device_manager:add(DebugBridge.new({ automation.device_manager:add(DebugBridge.new({
@ -58,7 +62,7 @@ automation.device_manager:add(LightSensor.new({
client = mqtt_client, client = mqtt_client,
min = 22000, min = 22000,
max = 23500, max = 23500,
event_channel = automation.event_channel, event_channel = automation.device_manager:event_channel(),
})) }))
automation.device_manager:add(WakeOnLAN.new({ automation.device_manager:add(WakeOnLAN.new({
@ -110,7 +114,7 @@ automation.device_manager:add(Washer.new({
topic = mqtt_z2m("batchroom/washer"), topic = mqtt_z2m("batchroom/washer"),
client = mqtt_client, client = mqtt_client,
threshold = 1, threshold = 1,
event_channel = automation.event_channel, event_channel = automation.device_manager:event_channel(),
})) }))
automation.device_manager:add(IkeaOutlet.new({ automation.device_manager:add(IkeaOutlet.new({
@ -156,23 +160,14 @@ automation.device_manager:add(ContactSensor.new({
}, },
})) }))
local bedroom_air_filter = automation.device_manager:add(AirFilter.new({ local bedroom_air_filter = AirFilter.new({
name = "Air Filter", name = "Air Filter",
room = "Bedroom", room = "Bedroom",
topic = "pico/filter/bedroom", topic = "pico/filter/bedroom",
client = mqtt_client, client = mqtt_client,
}))
-- TODO: Use the wrapped device bedroom_air_filter instead of the string
automation.device_manager:add_schedule({
["0 0 19 * * *"] = {
on = {
"bedroom_air_filter",
},
},
["0 0 20 * * *"] = {
off = {
"bedroom_air_filter",
},
},
}) })
automation.device_manager:add(bedroom_air_filter)
automation.device_manager:schedule("0/1 * * * * *", function()
print("Device: " .. bedroom_air_filter:get_id())
end)

View File

@ -1,2 +0,0 @@
openid:
base_url: "https://login.huizinga.dev/api/oidc"

View File

@ -1,2 +0,0 @@
openid:
base_url: "https://login.huizinga.dev/api/oidc"

View File

@ -49,7 +49,7 @@ pub trait GoogleHomeDevice: AsGoogleHomeDevice + Sync + Send + 'static {
fn get_id(&self) -> String; fn get_id(&self) -> String;
fn is_online(&self) -> bool; fn is_online(&self) -> bool;
// Default values that can optionally be overriden // Default values that can optionally be overridden
fn will_report_state(&self) -> bool { fn will_report_state(&self) -> bool {
false false
} }

View File

@ -17,7 +17,7 @@ pub struct GoogleHome {
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum FullfillmentError { pub enum FulfillmentError {
#[error("Expected at least one ResponsePayload")] #[error("Expected at least one ResponsePayload")]
ExpectedOnePayload, ExpectedOnePayload,
} }
@ -33,7 +33,7 @@ impl GoogleHome {
&self, &self,
request: Request, request: Request,
devices: &HashMap<String, Arc<RwLock<Box<T>>>>, devices: &HashMap<String, Arc<RwLock<Box<T>>>>,
) -> Result<Response, FullfillmentError> { ) -> Result<Response, FulfillmentError> {
// TODO: What do we do if we actually get more then one thing in the input array, right now // TODO: What do we do if we actually get more then one thing in the input array, right now
// we only respond to the first thing // we only respond to the first thing
let intent = request.inputs.into_iter().next(); let intent = request.inputs.into_iter().next();
@ -54,7 +54,7 @@ impl GoogleHome {
payload payload
.await .await
.ok_or(FullfillmentError::ExpectedOnePayload) .ok_or(FulfillmentError::ExpectedOnePayload)
.map(|payload| Response::new(&request.request_id, payload)) .map(|payload| Response::new(&request.request_id, payload))
} }

View File

@ -2,7 +2,7 @@
#![feature(specialization)] #![feature(specialization)]
#![feature(let_chains)] #![feature(let_chains)]
pub mod device; pub mod device;
mod fullfillment; mod fulfillment;
mod request; mod request;
mod response; mod response;
@ -13,6 +13,6 @@ pub mod traits;
pub mod types; pub mod types;
pub use device::GoogleHomeDevice; pub use device::GoogleHomeDevice;
pub use fullfillment::{FullfillmentError, GoogleHome}; pub use fulfillment::{FulfillmentError, GoogleHome};
pub use request::Request; pub use request::Request;
pub use response::Response; pub use response::Response;

View File

@ -6,11 +6,6 @@ use serde::Deserialize;
use crate::error::{ApiError, ApiErrorJson}; use crate::error::{ApiError, ApiErrorJson};
#[derive(Debug, Clone, Deserialize)]
pub struct OpenIDConfig {
pub base_url: String,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct User { pub struct User {
pub preferred_username: String, pub preferred_username: String,
@ -19,18 +14,18 @@ pub struct User {
#[async_trait] #[async_trait]
impl<S> FromRequestParts<S> for User impl<S> FromRequestParts<S> for User
where where
OpenIDConfig: FromRef<S>, String: FromRef<S>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ApiError; type Rejection = ApiError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Get the state // Get the state
let openid = OpenIDConfig::from_ref(state); let openid_url = String::from_ref(state);
// Create a request to the auth server // Create a request to the auth server
// TODO: Do some discovery to find the correct url for this instead of assuming // TODO: Do some discovery to find the correct url for this instead of assuming
let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid.base_url)); let mut req = reqwest::Client::new().get(format!("{}/userinfo", openid_url));
// Add auth header to the request if it exists // Add auth header to the request if it exists
if let Some(auth) = parts.headers.get(axum::http::header::AUTHORIZATION) { if let Some(auth) = parts.headers.get(axum::http::header::AUTHORIZATION) {

View File

@ -1,21 +1,8 @@
use std::fs;
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration; use std::time::Duration;
use regex::{Captures, Regex};
use rumqttc::{MqttOptions, Transport}; use rumqttc::{MqttOptions, Transport};
use serde::Deserialize; use serde::Deserialize;
use tracing::debug;
use crate::auth::OpenIDConfig;
use crate::error::{ConfigParseError, MissingEnv};
#[derive(Debug, Deserialize)]
pub struct Config {
pub openid: OpenIDConfig,
#[serde(default)]
pub fullfillment: FullfillmentConfig,
}
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct MqttConfig { pub struct MqttConfig {
@ -43,33 +30,25 @@ impl From<MqttConfig> for MqttOptions {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct FullfillmentConfig { pub struct FulfillmentConfig {
#[serde(default = "default_fullfillment_ip")] pub openid_url: String,
#[serde(default = "default_fulfillment_ip")]
pub ip: Ipv4Addr, pub ip: Ipv4Addr,
#[serde(default = "default_fullfillment_port")] #[serde(default = "default_fulfillment_port")]
pub port: u16, pub port: u16,
} }
impl From<FullfillmentConfig> for SocketAddr { impl From<FulfillmentConfig> for SocketAddr {
fn from(fullfillment: FullfillmentConfig) -> Self { fn from(fulfillment: FulfillmentConfig) -> Self {
(fullfillment.ip, fullfillment.port).into() (fulfillment.ip, fulfillment.port).into()
} }
} }
impl Default for FullfillmentConfig { fn default_fulfillment_ip() -> Ipv4Addr {
fn default() -> Self {
Self {
ip: default_fullfillment_ip(),
port: default_fullfillment_port(),
}
}
}
fn default_fullfillment_ip() -> Ipv4Addr {
[0, 0, 0, 0].into() [0, 0, 0, 0].into()
} }
fn default_fullfillment_port() -> u16 { fn default_fulfillment_port() -> u16 {
7878 7878
} }
@ -93,31 +72,3 @@ impl InfoConfig {
pub struct MqttDeviceConfig { pub struct MqttDeviceConfig {
pub topic: String, pub topic: String,
} }
impl Config {
pub fn parse_file(filename: &str) -> Result<Self, ConfigParseError> {
debug!("Loading config: {filename}");
let file = fs::read_to_string(filename)?;
// Substitute in environment variables
let re = Regex::new(r"\$\{(.*)\}").expect("Regex should be valid");
let mut missing = MissingEnv::new();
let file = re.replace_all(&file, |caps: &Captures| {
let key = caps.get(1).expect("Capture group should exist").as_str();
debug!("Substituting '{key}' in config");
match std::env::var(key) {
Ok(value) => value,
Err(_) => {
missing.add_missing(key);
"".into()
}
}
});
missing.has_missing()?;
let config: Config = serde_yaml::from_str(&file)?;
Ok(config)
}
}

View File

@ -3,15 +3,14 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc; use std::sync::Arc;
use futures::future::join_all; use futures::future::join_all;
use google_home::traits::OnOff; use mlua::FromLua;
use mlua::{FromLua, LuaSerdeExt};
use tokio::sync::{RwLock, RwLockReadGuard}; use tokio::sync::{RwLock, RwLockReadGuard};
use tokio_cron_scheduler::{Job, JobScheduler}; use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{debug, instrument, trace}; use tracing::{debug, instrument, trace};
use crate::devices::{As, Device}; use crate::devices::{As, Device};
use crate::event::{Event, EventChannel, OnDarkness, OnMqtt, OnNotification, OnPresence}; use crate::event::{Event, EventChannel, OnDarkness, OnMqtt, OnNotification, OnPresence};
use crate::schedule::{Action, Schedule}; use crate::LUA;
#[derive(Debug, FromLua, Clone)] #[derive(Debug, FromLua, Clone)]
pub struct WrappedDevice(Arc<RwLock<Box<dyn Device>>>); pub struct WrappedDevice(Arc<RwLock<Box<dyn Device>>>);
@ -35,23 +34,31 @@ impl DerefMut for WrappedDevice {
&mut self.0 &mut self.0
} }
} }
impl mlua::UserData for WrappedDevice {} impl mlua::UserData for WrappedDevice {
fn add_methods<'lua, M: mlua::prelude::LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("get_id", |_lua, this, _: ()| async {
Ok(crate::devices::Device::get_id(this.0.read().await.as_ref()))
});
}
}
pub type DeviceMap = HashMap<String, Arc<RwLock<Box<dyn Device>>>>; pub type DeviceMap = HashMap<String, Arc<RwLock<Box<dyn Device>>>>;
#[derive(Debug, Clone)] #[derive(Clone)]
pub struct DeviceManager { pub struct DeviceManager {
devices: Arc<RwLock<DeviceMap>>, devices: Arc<RwLock<DeviceMap>>,
event_channel: EventChannel, event_channel: EventChannel,
scheduler: JobScheduler,
} }
impl DeviceManager { impl DeviceManager {
pub fn new() -> Self { pub async fn new() -> Self {
let (event_channel, mut event_rx) = EventChannel::new(); let (event_channel, mut event_rx) = EventChannel::new();
let device_manager = Self { let device_manager = Self {
devices: Arc::new(RwLock::new(HashMap::new())), devices: Arc::new(RwLock::new(HashMap::new())),
event_channel, event_channel,
scheduler: JobScheduler::new().await.unwrap(),
}; };
tokio::spawn({ tokio::spawn({
@ -67,58 +74,11 @@ impl DeviceManager {
} }
}); });
device_manager.scheduler.start().await.unwrap();
device_manager device_manager
} }
// TODO: This function is currently extremely cursed...
pub async fn add_schedule(&self, schedule: Schedule) {
let sched = JobScheduler::new().await.unwrap();
for (when, actions) in schedule {
let manager = self.clone();
sched
.add(
Job::new_async(when.as_str(), move |_uuid, _l| {
let actions = actions.clone();
let manager = manager.clone();
Box::pin(async move {
for (action, targets) in actions {
for target in targets {
let device = manager.get(&target).await.unwrap();
match action {
Action::On => {
As::<dyn OnOff>::cast_mut(
device.write().await.as_mut(),
)
.unwrap()
.set_on(true)
.await
.unwrap();
}
Action::Off => {
As::<dyn OnOff>::cast_mut(
device.write().await.as_mut(),
)
.unwrap()
.set_on(false)
.await
.unwrap();
}
}
}
}
})
})
.unwrap(),
)
.await
.unwrap();
}
sched.start().await.unwrap();
}
pub async fn add(&self, device: &WrappedDevice) { pub async fn add(&self, device: &WrappedDevice) {
let id = device.read().await.get_id(); let id = device.read().await.get_id();
@ -220,12 +180,6 @@ impl DeviceManager {
} }
} }
impl Default for DeviceManager {
fn default() -> Self {
Self::new()
}
}
impl mlua::UserData for DeviceManager { impl mlua::UserData for DeviceManager {
fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("add", |_lua, this, device: WrappedDevice| async move { methods.add_async_method("add", |_lua, this, device: WrappedDevice| async move {
@ -234,10 +188,41 @@ impl mlua::UserData for DeviceManager {
Ok(()) Ok(())
}); });
methods.add_async_method("add_schedule", |lua, this, schedule| async { methods.add_async_method(
let schedule = lua.from_value(schedule)?; "schedule",
this.add_schedule(schedule).await; |lua, this, (schedule, f): (String, mlua::Function)| async move {
Ok(()) debug!("schedule = {schedule}");
}) let uuid = this
.scheduler
.add(
Job::new_async(schedule.as_str(), |uuid, _lock| {
Box::pin(async move {
let lua = LUA.lock().await;
let f: mlua::Function =
lua.named_registry_value(uuid.to_string().as_str()).unwrap();
f.call::<_, ()>(()).unwrap();
})
})
.unwrap(),
)
.await
.unwrap();
// Store the function in the registry
lua.set_named_registry_value(uuid.to_string().as_str(), f)
.unwrap();
Ok(())
},
);
// methods.add_async_method("add_schedule", |lua, this, schedule| async {
// let schedule = lua.from_value(schedule)?;
// this.add_schedule(schedule).await;
// Ok(())
// });
methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel()))
} }
} }

View File

@ -233,7 +233,7 @@ impl crate::traits::Timeout for IkeaOutlet {
tokio::time::sleep(timeout).await; tokio::time::sleep(timeout).await;
debug!(id, "Turning outlet off!"); debug!(id, "Turning outlet off!");
// TODO: Idealy we would call self.set_on(false), however since we want to do // TODO: Idealy we would call self.set_on(false), however since we want to do
// it after a timeout we have to put it in a seperate task. // it after a timeout we have to put it in a separate task.
// I don't think we can really get around calling outside function // I don't think we can really get around calling outside function
set_on(client, &topic, false).await; set_on(client, &topic, false).await;
})); }));

View File

@ -3,7 +3,7 @@ mod audio_setup;
mod contact_sensor; mod contact_sensor;
mod debug_bridge; mod debug_bridge;
mod hue_bridge; mod hue_bridge;
mod hue_light; mod hue_group;
mod ikea_outlet; mod ikea_outlet;
mod kasa_outlet; mod kasa_outlet;
mod light_sensor; mod light_sensor;
@ -21,7 +21,7 @@ pub use self::audio_setup::*;
pub use self::contact_sensor::*; pub use self::contact_sensor::*;
pub use self::debug_bridge::*; pub use self::debug_bridge::*;
pub use self::hue_bridge::*; pub use self::hue_bridge::*;
pub use self::hue_light::*; pub use self::hue_group::*;
pub use self::ikea_outlet::*; pub use self::ikea_outlet::*;
pub use self::kasa_outlet::*; pub use self::kasa_outlet::*;
pub use self::light_sensor::*; pub use self::light_sensor::*;
@ -42,6 +42,24 @@ pub trait LuaDeviceCreate {
Self: Sized; Self: Sized;
} }
pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> {
AirFilter::register_with_lua(lua)?;
AudioSetup::register_with_lua(lua)?;
ContactSensor::register_with_lua(lua)?;
DebugBridge::register_with_lua(lua)?;
HueBridge::register_with_lua(lua)?;
HueGroup::register_with_lua(lua)?;
IkeaOutlet::register_with_lua(lua)?;
KasaOutlet::register_with_lua(lua)?;
LightSensor::register_with_lua(lua)?;
Ntfy::register_with_lua(lua)?;
Presence::register_with_lua(lua)?;
WakeOnLAN::register_with_lua(lua)?;
Washer::register_with_lua(lua)?;
Ok(())
}
#[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + OnNotification + OnOff + Timeout)] #[impl_cast::device(As: OnMqtt + OnPresence + OnDarkness + OnNotification + OnOff + Timeout)]
pub trait Device: AsGoogleHomeDevice + std::fmt::Debug + Sync + Send { pub trait Device: AsGoogleHomeDevice + std::fmt::Debug + Sync + Send {
fn get_id(&self) -> String; fn get_id(&self) -> String;

View File

@ -154,7 +154,7 @@ impl Ntfy {
.await; .await;
if let Err(err) = res { if let Err(err) = res {
error!("Something went wrong while sending the notifcation: {err}"); error!("Something went wrong while sending the notification: {err}");
} else if let Ok(res) = res { } else if let Ok(res) = res {
let status = res.status(); let status = res.status();
if !status.is_success() { if !status.is_success() {

View File

@ -54,7 +54,7 @@ impl Device for Washer {
} }
} }
// The washer needs to have a power draw above the theshold multiple times before the washer is // The washer needs to have a power draw above the threshold multiple times before the washer is
// actually marked as running // actually marked as running
// This helps prevent false positives // This helps prevent false positives
const HYSTERESIS: isize = 10; const HYSTERESIS: isize = 10;

View File

@ -65,16 +65,6 @@ pub enum ParseError {
InvalidPayload(Bytes), InvalidPayload(Bytes),
} }
#[derive(Debug, Error)]
pub enum ConfigParseError {
#[error(transparent)]
MissingEnv(#[from] MissingEnv),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
YamlError(#[from] serde_yaml::Error),
}
// TODO: Would be nice to somehow get the line number of the expected wildcard topic // TODO: Would be nice to somehow get the line number of the expected wildcard topic
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("Topic '{topic}' is expected to be a wildcard topic")] #[error("Topic '{topic}' is expected to be a wildcard topic")]

View File

@ -1,6 +1,9 @@
#![allow(incomplete_features)] #![allow(incomplete_features)]
#![feature(specialization)] #![feature(specialization)]
#![feature(let_chains)] #![feature(let_chains)]
use once_cell::sync::Lazy;
use tokio::sync::Mutex;
pub mod auth; pub mod auth;
pub mod config; pub mod config;
pub mod device_manager; pub mod device_manager;
@ -11,3 +14,5 @@ pub mod messages;
pub mod mqtt; pub mod mqtt;
pub mod schedule; pub mod schedule;
pub mod traits; pub mod traits;
pub static LUA: Lazy<Mutex<mlua::Lua>> = Lazy::new(|| Mutex::new(mlua::Lua::new()));

View File

@ -1,15 +1,14 @@
#![feature(async_closure)] #![feature(async_closure)]
use std::{fs, process}; use std::path::Path;
use std::process;
use automation::auth::{OpenIDConfig, User}; use anyhow::anyhow;
use automation::config::{Config, MqttConfig}; use automation::auth::User;
use automation::config::{FulfillmentConfig, MqttConfig};
use automation::device_manager::DeviceManager; use automation::device_manager::DeviceManager;
use automation::devices::{
AirFilter, AudioSetup, ContactSensor, DebugBridge, HueBridge, HueGroup, IkeaOutlet, KasaOutlet,
LightSensor, Ntfy, Presence, WakeOnLAN, Washer,
};
use automation::error::ApiError; use automation::error::ApiError;
use automation::mqtt::{self, WrappedAsyncClient}; use automation::mqtt::{self, WrappedAsyncClient};
use automation::{devices, LUA};
use axum::extract::FromRef; use axum::extract::FromRef;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::IntoResponse; use axum::response::IntoResponse;
@ -23,12 +22,12 @@ use tracing::{debug, error, info, warn};
#[derive(Clone)] #[derive(Clone)]
struct AppState { struct AppState {
pub openid: OpenIDConfig, pub openid_url: String,
} }
impl FromRef<AppState> for OpenIDConfig { impl FromRef<AppState> for String {
fn from_ref(input: &AppState) -> Self { fn from_ref(input: &AppState) -> Self {
input.openid.clone() input.openid_url.clone()
} }
} }
@ -52,16 +51,11 @@ async fn app() -> anyhow::Result<()> {
info!("Starting automation_rs..."); info!("Starting automation_rs...");
let config_filename =
std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.yml".into());
let config = Config::parse_file(&config_filename)?;
// Setup the device handler // Setup the device handler
let device_manager = DeviceManager::new(); let device_manager = DeviceManager::new().await;
// Lua testing let fulfillment_config = {
{ let lua = LUA.lock().await;
let lua = mlua::Lua::new();
lua.set_warning_function(|_lua, text, _cont| { lua.set_warning_function(|_lua, text, _cont| {
warn!("{text}"); warn!("{text}");
@ -70,7 +64,7 @@ async fn app() -> anyhow::Result<()> {
let automation = lua.create_table()?; let automation = lua.create_table()?;
let event_channel = device_manager.event_channel(); let event_channel = device_manager.event_channel();
let create_mqtt_client = lua.create_function(move |lua, config: mlua::Value| { let new_mqtt_client = lua.create_function(move |lua, config: mlua::Value| {
let config: MqttConfig = lua.from_value(config)?; let config: MqttConfig = lua.from_value(config)?;
// Create a mqtt client // Create a mqtt client
@ -81,9 +75,8 @@ async fn app() -> anyhow::Result<()> {
Ok(WrappedAsyncClient(client)) Ok(WrappedAsyncClient(client))
})?; })?;
automation.set("create_mqtt_client", create_mqtt_client)?; automation.set("new_mqtt_client", new_mqtt_client)?;
automation.set("device_manager", device_manager.clone())?; automation.set("device_manager", device_manager.clone())?;
automation.set("event_channel", device_manager.event_channel())?;
let util = lua.create_table()?; let util = lua.create_table()?;
let get_env = lua.create_function(|_lua, name: String| { let get_env = lua.create_function(|_lua, name: String| {
@ -94,35 +87,32 @@ async fn app() -> anyhow::Result<()> {
lua.globals().set("automation", automation)?; lua.globals().set("automation", automation)?;
// Register all the device types devices::register_with_lua(&lua)?;
Ntfy::register_with_lua(&lua)?;
Presence::register_with_lua(&lua)?;
AirFilter::register_with_lua(&lua)?;
AudioSetup::register_with_lua(&lua)?;
ContactSensor::register_with_lua(&lua)?;
DebugBridge::register_with_lua(&lua)?;
HueBridge::register_with_lua(&lua)?;
HueGroup::register_with_lua(&lua)?;
IkeaOutlet::register_with_lua(&lua)?;
KasaOutlet::register_with_lua(&lua)?;
LightSensor::register_with_lua(&lua)?;
WakeOnLAN::register_with_lua(&lua)?;
Washer::register_with_lua(&lua)?;
// TODO: Make this not hardcoded // TODO: Make this not hardcoded
let filename = "config.lua"; let config_filename = std::env::var("AUTOMATION_CONFIG").unwrap_or("./config.lua".into());
let file = fs::read_to_string(filename)?; let config_path = Path::new(&config_filename);
match lua.load(file).set_name(filename).exec_async().await { match lua.load(config_path).exec_async().await {
Err(error) => { Err(error) => {
println!("{error}"); println!("{error}");
Err(error) Err(error)
} }
result => result, result => result,
}?; }?;
}
// Create google home fullfillment route let automation: mlua::Table = lua.globals().get("automation")?;
let fullfillment = Router::new().route( let fulfillment_config: Option<mlua::Value> = automation.get("fulfillment")?;
if let Some(fulfillment_config) = fulfillment_config {
let fulfillment_config: FulfillmentConfig = lua.from_value(fulfillment_config)?;
debug!("automation.fulfillment = {fulfillment_config:?}");
fulfillment_config
} else {
return Err(anyhow!("Fulfillment is not configured"));
}
};
// Create google home fulfillment route
let fulfillment = Router::new().route(
"/google_home", "/google_home",
post(async move |user: User, Json(payload): Json<Request>| { post(async move |user: User, Json(payload): Json<Request>| {
debug!(username = user.preferred_username, "{payload:#?}"); debug!(username = user.preferred_username, "{payload:#?}");
@ -144,13 +134,13 @@ async fn app() -> anyhow::Result<()> {
// Combine together all the routes // Combine together all the routes
let app = Router::new() let app = Router::new()
.nest("/fullfillment", fullfillment) .nest("/fulfillment", fulfillment)
.with_state(AppState { .with_state(AppState {
openid: config.openid, openid_url: fulfillment_config.openid_url.clone(),
}); });
// Start the web server // Start the web server
let addr = config.fullfillment.into(); let addr = fulfillment_config.into();
info!("Server started on http://{addr}"); info!("Server started on http://{addr}");
axum::Server::try_bind(&addr)? axum::Server::try_bind(&addr)?
.serve(app.into_make_service()) .serve(app.into_make_service())