Reorganized project

This commit is contained in:
2024-12-08 00:15:03 +01:00
parent 42f391cde6
commit 8877b24e84
36 changed files with 505 additions and 797 deletions

28
automation_lib/Cargo.toml Normal file
View File

@@ -0,0 +1,28 @@
[package]
name = "automation_lib"
version = "0.1.0"
edition = "2021"
[dependencies]
automation_macro = { workspace = true }
automation_cast = { workspace = true }
google_home = { workspace = true }
rumqttc = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
reqwest = { workspace = true }
serde_repr = { workspace = true }
tracing = { workspace = true }
bytes = { workspace = true }
pollster = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
indexmap = { workspace = true }
tokio-cron-scheduler = { workspace = true }
mlua = { workspace = true }
tokio-util = { workspace = true }
uuid = { workspace = true }
dyn-clone = { workspace = true }
impls = { workspace = true }

View File

@@ -0,0 +1,60 @@
use std::marker::PhantomData;
use mlua::{FromLua, IntoLuaMulti};
#[derive(Debug, Clone)]
struct Internal {
uuid: uuid::Uuid,
lua: mlua::Lua,
}
#[derive(Debug, Clone)]
pub struct ActionCallback<T> {
internal: Option<Internal>,
phantom: PhantomData<T>,
}
impl<T> Default for ActionCallback<T> {
fn default() -> Self {
Self {
internal: None,
phantom: PhantomData::<T>,
}
}
}
impl<T> FromLua for ActionCallback<T> {
fn from_lua(value: mlua::Value, lua: &mlua::Lua) -> mlua::Result<Self> {
let uuid = uuid::Uuid::new_v4();
lua.set_named_registry_value(&uuid.to_string(), value)?;
Ok(ActionCallback {
internal: Some(Internal {
uuid,
lua: lua.clone(),
}),
phantom: PhantomData::<T>,
})
}
}
// TODO: Return proper error here
impl<T> ActionCallback<T>
where
T: IntoLuaMulti + Sync + Send + Clone + 'static,
{
pub async fn call(&self, state: T) {
let Some(internal) = self.internal.as_ref() else {
return;
};
let callback: mlua::Value = internal
.lua
.named_registry_value(&internal.uuid.to_string())
.unwrap();
match callback {
mlua::Value::Function(f) => f.call_async::<()>(state).await.unwrap(),
_ => todo!("Only functions are currently supported"),
}
}
}

View File

@@ -0,0 +1,74 @@
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use rumqttc::{MqttOptions, Transport};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct MqttConfig {
pub host: String,
pub port: u16,
pub client_name: String,
pub username: String,
pub password: String,
#[serde(default)]
pub tls: bool,
}
impl From<MqttConfig> for MqttOptions {
fn from(value: MqttConfig) -> Self {
let mut mqtt_options = MqttOptions::new(value.client_name, value.host, value.port);
mqtt_options.set_credentials(value.username, value.password);
mqtt_options.set_keep_alive(Duration::from_secs(5));
if value.tls {
mqtt_options.set_transport(Transport::tls_with_default_config());
}
mqtt_options
}
}
#[derive(Debug, Deserialize)]
pub struct FulfillmentConfig {
pub openid_url: String,
#[serde(default = "default_fulfillment_ip")]
pub ip: Ipv4Addr,
#[serde(default = "default_fulfillment_port")]
pub port: u16,
}
impl From<FulfillmentConfig> for SocketAddr {
fn from(fulfillment: FulfillmentConfig) -> Self {
(fulfillment.ip, fulfillment.port).into()
}
}
fn default_fulfillment_ip() -> Ipv4Addr {
[0, 0, 0, 0].into()
}
fn default_fulfillment_port() -> u16 {
7878
}
#[derive(Debug, Clone, Deserialize)]
pub struct InfoConfig {
pub name: String,
pub room: Option<String>,
}
impl InfoConfig {
pub fn identifier(&self) -> String {
(if let Some(room) = &self.room {
room.to_ascii_lowercase().replace(' ', "_") + "_"
} else {
String::new()
}) + &self.name.to_ascii_lowercase().replace(' ', "_")
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct MqttDeviceConfig {
pub topic: String,
}

View File

@@ -0,0 +1,99 @@
use std::fmt::Debug;
use automation_cast::Cast;
use dyn_clone::DynClone;
use google_home::traits::OnOff;
use mlua::ObjectLike;
use crate::event::{OnDarkness, OnMqtt, OnNotification, OnPresence};
// TODO: Make this a proper macro
macro_rules! impl_device {
($device:ty) => {
impl mlua::UserData for $device {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_async_function("new", |_lua, config| async {
let device: $device = LuaDeviceCreate::create(config)
.await
.map_err(mlua::ExternalError::into_lua_err)?;
Ok(device)
});
methods.add_method("__box", |_lua, this, _: ()| {
let b: Box<dyn Device> = Box::new(this.clone());
Ok(b)
});
methods.add_async_method("get_id", |_lua, this, _: ()| async move { Ok(this.get_id()) });
if impls::impls!($device: google_home::traits::OnOff) {
methods.add_async_method("set_on", |_lua, this, on: bool| async move {
(this.deref().cast() as Option<&dyn google_home::traits::OnOff>)
.expect("Cast should be valid")
.set_on(on)
.await
.unwrap();
Ok(())
});
methods.add_async_method("is_on", |_lua, this, _: ()| async move {
Ok((this.deref().cast() as Option<&dyn google_home::traits::OnOff>)
.expect("Cast should be valid")
.on()
.await
.unwrap())
});
}
}
}
};
}
pub(crate) use impl_device;
#[async_trait::async_trait]
pub trait LuaDeviceCreate {
type Config;
type Error;
async fn create(config: Self::Config) -> Result<Self, Self::Error>
where
Self: Sized;
}
pub trait Device:
Debug
+ DynClone
+ Sync
+ Send
+ Cast<dyn google_home::Device>
+ Cast<dyn OnMqtt>
+ Cast<dyn OnPresence>
+ Cast<dyn OnDarkness>
+ Cast<dyn OnNotification>
+ Cast<dyn OnOff>
{
fn get_id(&self) -> String;
}
impl mlua::FromLua for Box<dyn Device> {
fn from_lua(value: mlua::Value, _lua: &mlua::Lua) -> mlua::Result<Self> {
match value {
mlua::Value::UserData(ud) => {
let ud = if ud.is::<Box<dyn Device>>() {
ud
} else {
ud.call_method::<_>("__box", ())?
};
let b = ud.borrow::<Self>()?.clone();
Ok(b)
}
_ => Err(mlua::Error::RuntimeError("Expected user data".into())),
}
}
}
impl mlua::UserData for Box<dyn Device> {}
dyn_clone::clone_trait_object!(Device);

View File

@@ -0,0 +1,189 @@
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use futures::future::join_all;
use futures::Future;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{debug, instrument, trace};
use crate::device::Device;
use crate::event::{Event, EventChannel, OnDarkness, OnMqtt, OnNotification, OnPresence};
pub type DeviceMap = HashMap<String, Box<dyn Device>>;
#[derive(Clone)]
pub struct DeviceManager {
devices: Arc<RwLock<DeviceMap>>,
event_channel: EventChannel,
scheduler: JobScheduler,
}
impl DeviceManager {
pub async fn new() -> Self {
let (event_channel, mut event_rx) = EventChannel::new();
let device_manager = Self {
devices: Arc::new(RwLock::new(HashMap::new())),
event_channel,
scheduler: JobScheduler::new().await.unwrap(),
};
tokio::spawn({
let device_manager = device_manager.clone();
async move {
loop {
if let Some(event) = event_rx.recv().await {
device_manager.handle_event(event).await;
} else {
todo!("Handle errors with the event channel properly")
}
}
}
});
device_manager.scheduler.start().await.unwrap();
device_manager
}
pub async fn add(&self, device: Box<dyn Device>) {
let id = device.get_id();
debug!(id, "Adding device");
self.devices.write().await.insert(id, device);
}
pub fn event_channel(&self) -> EventChannel {
self.event_channel.clone()
}
pub async fn get(&self, name: &str) -> Option<Box<dyn Device>> {
self.devices.read().await.get(name).cloned()
}
pub async fn devices(&self) -> RwLockReadGuard<DeviceMap> {
self.devices.read().await
}
#[instrument(skip(self))]
async fn handle_event(&self, event: Event) {
match event {
Event::MqttMessage(message) => {
let devices = self.devices.read().await;
let iter = devices.iter().map(|(id, device)| {
let message = message.clone();
async move {
let device: Option<&dyn OnMqtt> = device.cast();
if let Some(device) = device {
// let subscribed = device
// .topics()
// .iter()
// .any(|topic| matches(&message.topic, topic));
//
// if subscribed {
trace!(id, "Handling");
device.on_mqtt(message).await;
trace!(id, "Done");
// }
}
}
});
join_all(iter).await;
}
Event::Darkness(dark) => {
let devices = self.devices.read().await;
let iter = devices.iter().map(|(id, device)| async move {
let device: Option<&dyn OnDarkness> = device.cast();
if let Some(device) = device {
trace!(id, "Handling");
device.on_darkness(dark).await;
trace!(id, "Done");
}
});
join_all(iter).await;
}
Event::Presence(presence) => {
let devices = self.devices.read().await;
let iter = devices.iter().map(|(id, device)| async move {
let device: Option<&dyn OnPresence> = device.cast();
if let Some(device) = device {
trace!(id, "Handling");
device.on_presence(presence).await;
trace!(id, "Done");
}
});
join_all(iter).await;
}
Event::Ntfy(notification) => {
let devices = self.devices.read().await;
let iter = devices.iter().map(|(id, device)| {
let notification = notification.clone();
async move {
let device: Option<&dyn OnNotification> = device.cast();
if let Some(device) = device {
trace!(id, "Handling");
device.on_notification(notification).await;
trace!(id, "Done");
}
}
});
join_all(iter).await;
}
}
}
}
impl mlua::UserData for DeviceManager {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_async_method("add", |_lua, this, device: Box<dyn Device>| async move {
this.add(device).await;
Ok(())
});
methods.add_async_method(
"schedule",
|lua, this, (schedule, f): (String, mlua::Function)| async move {
debug!("schedule = {schedule}");
// This creates a function, that returns the actual job we want to run
let create_job = {
let lua = lua.clone();
move |uuid: uuid::Uuid,
_: tokio_cron_scheduler::JobScheduler|
-> Pin<Box<dyn Future<Output = ()> + Send>> {
let lua = lua.clone();
// Create the actual function we want to run on a schedule
let future = async move {
let f: mlua::Function =
lua.named_registry_value(uuid.to_string().as_str()).unwrap();
f.call_async::<()>(()).await.unwrap();
};
Box::pin(future)
}
};
let job = Job::new_async(schedule.as_str(), create_job).unwrap();
let uuid = this.scheduler.add(job).await.unwrap();
// Store the function in the registry
lua.set_named_registry_value(uuid.to_string().as_str(), f)
.unwrap();
Ok(())
},
);
methods.add_method("event_channel", |_lua, this, ()| Ok(this.event_channel()))
}
}

100
automation_lib/src/error.rs Normal file
View File

@@ -0,0 +1,100 @@
use std::{error, fmt, result};
use bytes::Bytes;
use rumqttc::ClientError;
use thiserror::Error;
#[derive(Debug, Clone)]
pub struct MissingEnv {
keys: Vec<String>,
}
// TODO: Would be nice to somehow get the line number of the missing keys
impl MissingEnv {
pub fn new() -> Self {
Self { keys: Vec::new() }
}
pub fn add_missing(&mut self, key: &str) {
self.keys.push(key.into());
}
pub fn has_missing(self) -> result::Result<(), Self> {
if !self.keys.is_empty() {
Err(self)
} else {
Ok(())
}
}
}
impl Default for MissingEnv {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for MissingEnv {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Missing environment variable")?;
if self.keys.is_empty() {
unreachable!("This error should only be returned if there are actually missing environment variables");
}
if self.keys.len() == 1 {
write!(f, " '{}'", self.keys[0])?;
} else {
write!(f, "s '{}'", self.keys[0])?;
self.keys
.iter()
.skip(1)
.try_for_each(|key| write!(f, ", '{key}'"))?;
}
Ok(())
}
}
impl error::Error for MissingEnv {}
#[derive(Debug, Error)]
pub enum ParseError {
#[error("Invalid message payload received: {0:?}")]
InvalidPayload(Bytes),
}
// TODO: Would be nice to somehow get the line number of the expected wildcard topic
#[derive(Debug, Error)]
#[error("Topic '{topic}' is expected to be a wildcard topic")]
pub struct MissingWildcard {
topic: String,
}
impl MissingWildcard {
pub fn new(topic: &str) -> Self {
Self {
topic: topic.into(),
}
}
}
#[derive(Debug, Error)]
pub enum DeviceConfigError {
#[error("Device '{0}' does not implement expected trait '{1}'")]
MissingTrait(String, String),
#[error(transparent)]
MqttClientError(#[from] rumqttc::ClientError),
}
#[derive(Debug, Error)]
pub enum PresenceError {
#[error(transparent)]
SubscribeError(#[from] ClientError),
#[error(transparent)]
MissingWildcard(#[from] MissingWildcard),
}
#[derive(Debug, Error)]
pub enum LightSensorError {
#[error(transparent)]
SubscribeError(#[from] ClientError),
}

View File

@@ -0,0 +1,55 @@
use async_trait::async_trait;
use mlua::FromLua;
use rumqttc::Publish;
use tokio::sync::mpsc;
use crate::ntfy::Notification;
#[derive(Debug, Clone)]
pub enum Event {
MqttMessage(Publish),
Darkness(bool),
Presence(bool),
Ntfy(Notification),
}
pub type Sender = mpsc::Sender<Event>;
pub type Receiver = mpsc::Receiver<Event>;
#[derive(Clone, Debug, FromLua)]
pub struct EventChannel(Sender);
impl EventChannel {
pub fn new() -> (Self, Receiver) {
let (tx, rx) = mpsc::channel(100);
(Self(tx), rx)
}
pub fn get_tx(&self) -> Sender {
self.0.clone()
}
}
impl mlua::UserData for EventChannel {}
#[async_trait]
pub trait OnMqtt: Sync + Send {
// fn topics(&self) -> Vec<&str>;
async fn on_mqtt(&self, message: Publish);
}
#[async_trait]
pub trait OnPresence: Sync + Send {
async fn on_presence(&self, presence: bool);
}
#[async_trait]
pub trait OnDarkness: Sync + Send {
async fn on_darkness(&self, dark: bool);
}
#[async_trait]
pub trait OnNotification: Sync + Send {
async fn on_notification(&self, notification: Notification);
}

View File

@@ -0,0 +1,10 @@
mod timeout;
pub use timeout::Timeout;
pub fn register_with_lua(lua: &mlua::Lua) -> mlua::Result<()> {
lua.globals()
.set("Timeout", lua.create_proxy::<Timeout>()?)?;
Ok(())
}

View File

@@ -0,0 +1,76 @@
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::debug;
use crate::action_callback::ActionCallback;
#[derive(Debug, Default)]
pub struct State {
handle: Option<JoinHandle<()>>,
}
#[derive(Debug, Clone)]
pub struct Timeout {
state: Arc<RwLock<State>>,
}
impl mlua::UserData for Timeout {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_function("new", |_lua, ()| {
let device = Self {
state: Default::default(),
};
Ok(device)
});
methods.add_async_method(
"start",
|_lua, this, (timeout, callback): (u64, ActionCallback<bool>)| async move {
if let Some(handle) = this.state.write().await.handle.take() {
handle.abort();
}
debug!("Running timeout callback after {timeout}s");
let timeout = Duration::from_secs(timeout);
this.state.write().await.handle = Some(tokio::spawn({
async move {
tokio::time::sleep(timeout).await;
callback.call(false).await;
}
}));
Ok(())
},
);
methods.add_async_method("cancel", |_lua, this, ()| async move {
debug!("Canceling timeout callback");
if let Some(handle) = this.state.write().await.handle.take() {
handle.abort();
}
Ok(())
});
methods.add_async_method("is_waiting", |_lua, this, ()| async move {
debug!("Canceling timeout callback");
if let Some(handle) = this.state.read().await.handle.as_ref() {
debug!("Join handle: {}", handle.is_finished());
return Ok(!handle.is_finished());
}
debug!("Join handle: None");
Ok(false)
});
}
}

16
automation_lib/src/lib.rs Normal file
View File

@@ -0,0 +1,16 @@
#![allow(incomplete_features)]
#![feature(specialization)]
#![feature(let_chains)]
pub mod action_callback;
pub mod config;
pub mod device;
pub mod device_manager;
pub mod error;
pub mod event;
pub mod helpers;
pub mod messages;
pub mod mqtt;
pub mod ntfy;
pub mod presence;
pub mod schedule;

View File

@@ -0,0 +1,280 @@
use std::time::{SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use rumqttc::Publish;
use serde::{Deserialize, Serialize};
use crate::error::ParseError;
// Message used to turn on and off devices and receiving their state
#[derive(Debug, Serialize, Deserialize)]
pub struct OnOffMessage {
state: String,
}
impl OnOffMessage {
pub fn new(state: bool) -> Self {
Self {
state: if state { "ON" } else { "OFF" }.into(),
}
}
pub fn state(&self) -> bool {
self.state == "ON"
}
}
impl TryFrom<Publish> for OnOffMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message send to request activating a device
#[derive(Debug, Deserialize)]
pub struct ActivateMessage {
activate: bool,
}
impl ActivateMessage {
pub fn activate(&self) -> bool {
self.activate
}
}
impl TryFrom<Publish> for ActivateMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Actions that can be performed by a remote
#[derive(Debug, Deserialize, Copy, Clone)]
#[serde(rename_all = "snake_case")]
pub enum RemoteAction {
On,
Off,
BrightnessMoveUp,
BrightnessMoveDown,
BrightnessStop,
}
// Message used to report the action performed by a remote
#[derive(Debug, Deserialize)]
pub struct RemoteMessage {
action: RemoteAction,
}
impl RemoteMessage {
pub fn action(&self) -> RemoteAction {
self.action
}
}
impl TryFrom<Publish> for RemoteMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message used to report the current presence state
#[derive(Debug, Deserialize, Serialize)]
pub struct PresenceMessage {
state: bool,
updated: Option<u128>,
}
impl PresenceMessage {
pub fn new(state: bool) -> Self {
Self {
state,
updated: Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time is after UNIX EPOCH")
.as_millis(),
),
}
}
pub fn presence(&self) -> bool {
self.state
}
}
impl TryFrom<Publish> for PresenceMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message used to report the state of a light sensor
#[derive(Debug, Deserialize)]
pub struct BrightnessMessage {
illuminance: isize,
}
impl BrightnessMessage {
pub fn illuminance(&self) -> isize {
self.illuminance
}
}
impl TryFrom<Publish> for BrightnessMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message to report the state of a contact sensor
#[derive(Debug, Deserialize)]
pub struct ContactMessage {
contact: bool,
}
impl ContactMessage {
pub fn is_closed(&self) -> bool {
self.contact
}
}
impl TryFrom<Publish> for ContactMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message used to report the current darkness state
#[derive(Debug, Deserialize, Serialize)]
pub struct DarknessMessage {
state: bool,
updated: Option<u128>,
}
impl DarknessMessage {
pub fn new(state: bool) -> Self {
Self {
state,
updated: Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time is after UNIX EPOCH")
.as_millis(),
),
}
}
pub fn is_dark(&self) -> bool {
self.state
}
}
impl TryFrom<Publish> for DarknessMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message used to report the power draw a smart plug
#[derive(Debug, Deserialize)]
pub struct PowerMessage {
power: f32,
}
impl PowerMessage {
pub fn power(&self) -> f32 {
self.power
}
}
impl TryFrom<Publish> for PowerMessage {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}
// Message used to report the power state of a hue light
#[derive(Debug, Deserialize)]
pub struct HueState {
on: bool,
}
#[derive(Debug, Deserialize)]
pub struct HueMessage {
state: HueState,
}
impl HueMessage {
pub fn is_on(&self) -> bool {
self.state.on
}
}
impl TryFrom<Bytes> for HueMessage {
type Error = ParseError;
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
serde_json::from_slice(&bytes).or(Err(ParseError::InvalidPayload(bytes.clone())))
}
}
// TODO: Import this from the air_filter code itself instead of copying
#[derive(PartialEq, Eq, Debug, Clone, Copy, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AirFilterFanState {
Off,
Low,
Medium,
High,
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub struct SetAirFilterFanState {
state: AirFilterFanState,
}
#[derive(PartialEq, Debug, Clone, Copy, Deserialize, Serialize)]
pub struct AirFilterState {
pub state: AirFilterFanState,
pub humidity: f32,
pub temperature: f32,
}
impl SetAirFilterFanState {
pub fn new(state: AirFilterFanState) -> Self {
Self { state }
}
}
impl TryFrom<Publish> for AirFilterState {
type Error = ParseError;
fn try_from(message: Publish) -> Result<Self, Self::Error> {
serde_json::from_slice(&message.payload)
.or(Err(ParseError::InvalidPayload(message.payload.clone())))
}
}

View File

@@ -0,0 +1,48 @@
use std::ops::{Deref, DerefMut};
use mlua::FromLua;
use rumqttc::{AsyncClient, Event, EventLoop, Incoming};
use tracing::{debug, warn};
use crate::event::{self, EventChannel};
#[derive(Debug, Clone, FromLua)]
pub struct WrappedAsyncClient(pub AsyncClient);
impl Deref for WrappedAsyncClient {
type Target = AsyncClient;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for WrappedAsyncClient {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl mlua::UserData for WrappedAsyncClient {}
pub fn start(mut eventloop: EventLoop, event_channel: &EventChannel) {
let tx = event_channel.get_tx();
tokio::spawn(async move {
debug!("Listening for MQTT events");
loop {
let notification = eventloop.poll().await;
match notification {
Ok(Event::Incoming(Incoming::Publish(p))) => {
tx.send(event::Event::MqttMessage(p)).await.ok();
}
Ok(..) => continue,
Err(err) => {
// Something has gone wrong
// We stay in the loop as that will attempt to reconnect
warn!("{}", err);
}
}
}
});
}

211
automation_lib/src/ntfy.rs Normal file
View File

@@ -0,0 +1,211 @@
use std::collections::HashMap;
use std::convert::Infallible;
use std::ops::Deref;
use async_trait::async_trait;
use automation_cast::Cast;
use automation_macro::LuaDeviceConfig;
use serde::Serialize;
use serde_repr::*;
use tracing::{error, trace, warn};
use crate::device::{impl_device, Device, LuaDeviceCreate};
use crate::event::{self, Event, EventChannel, OnNotification, OnPresence};
#[derive(Debug, Serialize_repr, Clone, Copy)]
#[repr(u8)]
pub enum Priority {
Min = 1,
Low,
Default,
High,
Max,
}
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "action")]
pub enum ActionType {
Broadcast {
#[serde(skip_serializing_if = "HashMap::is_empty")]
extras: HashMap<String, String>,
},
// View,
// Http
}
#[derive(Debug, Serialize, Clone)]
pub struct Action {
#[serde(flatten)]
pub action: ActionType,
pub label: String,
pub clear: Option<bool>,
}
#[derive(Serialize)]
struct NotificationFinal {
topic: String,
#[serde(flatten)]
inner: Notification,
}
#[derive(Debug, Serialize, Clone)]
pub struct Notification {
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
message: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
tags: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
priority: Option<Priority>,
#[serde(skip_serializing_if = "Vec::is_empty")]
actions: Vec<Action>,
}
impl Notification {
pub fn new() -> Self {
Self {
title: None,
message: None,
tags: Vec::new(),
priority: None,
actions: Vec::new(),
}
}
pub fn set_title(mut self, title: &str) -> Self {
self.title = Some(title.into());
self
}
pub fn set_message(mut self, message: &str) -> Self {
self.message = Some(message.into());
self
}
pub fn add_tag(mut self, tag: &str) -> Self {
self.tags.push(tag.into());
self
}
pub fn set_priority(mut self, priority: Priority) -> Self {
self.priority = Some(priority);
self
}
pub fn add_action(mut self, action: Action) -> Self {
self.actions.push(action);
self
}
fn finalize(self, topic: &str) -> NotificationFinal {
NotificationFinal {
topic: topic.into(),
inner: self,
}
}
}
impl Default for Notification {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, LuaDeviceConfig)]
pub struct Config {
#[device_config(default("https://ntfy.sh".into()))]
pub url: String,
pub topic: String,
#[device_config(rename("event_channel"), from_lua, with(|ec: EventChannel| ec.get_tx()))]
pub tx: event::Sender,
}
#[derive(Debug, Clone)]
pub struct Ntfy {
config: Config,
}
impl_device!(Ntfy);
#[async_trait]
impl LuaDeviceCreate for Ntfy {
type Config = Config;
type Error = Infallible;
async fn create(config: Self::Config) -> Result<Self, Self::Error> {
trace!(id = "ntfy", "Setting up Ntfy");
Ok(Self { config })
}
}
impl Device for Ntfy {
fn get_id(&self) -> String {
"ntfy".to_string()
}
}
impl Ntfy {
async fn send(&self, notification: Notification) {
let notification = notification.finalize(&self.config.topic);
// Create the request
let res = reqwest::Client::new()
.post(self.config.url.clone())
.json(&notification)
.send()
.await;
if let Err(err) = res {
error!("Something went wrong while sending the notification: {err}");
} else if let Ok(res) = res {
let status = res.status();
if !status.is_success() {
warn!("Received status {status} when sending notification");
}
}
}
}
#[async_trait]
impl OnPresence for Ntfy {
async fn on_presence(&self, presence: bool) {
// Setup extras for the broadcast
let extras = HashMap::from([
("cmd".into(), "presence".into()),
("state".into(), if presence { "0" } else { "1" }.into()),
]);
// Create broadcast action
let action = Action {
action: ActionType::Broadcast { extras },
label: if presence { "Set away" } else { "Set home" }.into(),
clear: Some(true),
};
// Create the notification
let notification = Notification::new()
.set_title("Presence")
.set_message(if presence { "Home" } else { "Away" })
.add_tag("house")
.add_action(action)
.set_priority(Priority::Low);
if self
.config
.tx
.send(Event::Ntfy(notification))
.await
.is_err()
{
warn!("There are no receivers on the event channel");
}
}
}
#[async_trait]
impl OnNotification for Ntfy {
async fn on_notification(&self, notification: Notification) {
self.send(notification).await;
}
}

View File

@@ -0,0 +1,132 @@
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use automation_cast::Cast;
use automation_macro::LuaDeviceConfig;
use rumqttc::Publish;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tracing::{debug, trace, warn};
use crate::config::MqttDeviceConfig;
use crate::device::{impl_device, Device, LuaDeviceCreate};
use crate::event::{self, Event, EventChannel, OnMqtt};
use crate::messages::PresenceMessage;
use crate::mqtt::WrappedAsyncClient;
#[derive(Debug, Clone, LuaDeviceConfig)]
pub struct Config {
#[device_config(flatten)]
pub mqtt: MqttDeviceConfig,
#[device_config(from_lua, rename("event_channel"), with(|ec: EventChannel| ec.get_tx()))]
pub tx: event::Sender,
#[device_config(from_lua)]
pub client: WrappedAsyncClient,
}
pub const DEFAULT_PRESENCE: bool = false;
#[derive(Debug)]
pub struct State {
devices: HashMap<String, bool>,
current_overall_presence: bool,
}
#[derive(Debug, Clone)]
pub struct Presence {
config: Config,
state: Arc<RwLock<State>>,
}
impl Presence {
async fn state(&self) -> RwLockReadGuard<State> {
self.state.read().await
}
async fn state_mut(&self) -> RwLockWriteGuard<State> {
self.state.write().await
}
}
impl_device!(Presence);
#[async_trait]
impl LuaDeviceCreate for Presence {
type Config = Config;
type Error = rumqttc::ClientError;
async fn create(config: Self::Config) -> Result<Self, Self::Error> {
trace!(id = "presence", "Setting up Presence");
config
.client
.subscribe(&config.mqtt.topic, rumqttc::QoS::AtLeastOnce)
.await?;
let state = State {
devices: HashMap::new(),
current_overall_presence: DEFAULT_PRESENCE,
};
let state = Arc::new(RwLock::new(state));
Ok(Self { config, state })
}
}
impl Device for Presence {
fn get_id(&self) -> String {
"presence".to_string()
}
}
#[async_trait]
impl OnMqtt for Presence {
async fn on_mqtt(&self, message: Publish) {
if !rumqttc::matches(&message.topic, &self.config.mqtt.topic) {
return;
}
let offset = self
.config
.mqtt
.topic
.find('+')
.or(self.config.mqtt.topic.find('#'))
.expect("Presence::create fails if it does not contain wildcards");
let device_name = message.topic[offset..].into();
if message.payload.is_empty() {
// Remove the device from the map
debug!("State of device [{device_name}] has been removed");
self.state_mut().await.devices.remove(&device_name);
} else {
let present = match PresenceMessage::try_from(message) {
Ok(state) => state.presence(),
Err(err) => {
warn!("Failed to parse message: {err}");
return;
}
};
debug!("State of device [{device_name}] has changed: {}", present);
self.state_mut().await.devices.insert(device_name, present);
}
let overall_presence = self.state().await.devices.iter().any(|(_, v)| *v);
if overall_presence != self.state().await.current_overall_presence {
debug!("Overall presence updated: {overall_presence}");
self.state_mut().await.current_overall_presence = overall_presence;
if self
.config
.tx
.send(Event::Presence(overall_presence))
.await
.is_err()
{
warn!("There are no receivers on the event channel");
}
}
}
}

View File

@@ -0,0 +1,17 @@
use indexmap::IndexMap;
use serde::Deserialize;
#[derive(Debug, Deserialize, Hash, PartialEq, Eq, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum Action {
On,
Off,
}
pub type Schedule = IndexMap<String, IndexMap<Action, Vec<String>>>;
// #[derive(Debug, Deserialize)]
// pub struct Schedule {
// pub when: String,
// pub actions: IndexMap<Action, Vec<String>>,
// }