From cfd10a7dafd862be8b248b74af28123d6926f3a2 Mon Sep 17 00:00:00 2001 From: Dreaded_X Date: Tue, 3 Jan 2023 05:26:00 +0100 Subject: [PATCH] Refactored how we deal with blocking code and added AudioSetup --- Cargo.lock | 2 + Cargo.toml | 2 + config/zeus.dev.toml | 6 + google-home/src/errors.rs | 4 + src/config.rs | 11 +- src/devices.rs | 3 + src/devices/audio_setup.rs | 277 +++++++++++++++++++++++++++++++++++++ src/devices/ikea_outlet.rs | 52 ++----- src/devices/wake_on_lan.rs | 53 +++---- src/main.rs | 18 ++- src/mqtt.rs | 99 +++++++++++-- src/ntfy.rs | 19 ++- src/presence.rs | 36 ++--- 13 files changed, 464 insertions(+), 118 deletions(-) create mode 100644 src/devices/audio_setup.rs diff --git a/Cargo.lock b/Cargo.lock index c089407..b5a997e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,10 +31,12 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bytes", "dotenv", "google-home", "impl_cast", "paste", + "pollster", "reqwest", "rumqttc", "serde", diff --git a/Cargo.toml b/Cargo.toml index 36bc300..3d249ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ axum = "0.6.1" serde_repr = "0.1.10" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +bytes = "1.3.0" +pollster = "0.2.5" [profile.release] lto=true diff --git a/config/zeus.dev.toml b/config/zeus.dev.toml index c747c63..b064ab1 100644 --- a/config/zeus.dev.toml +++ b/config/zeus.dev.toml @@ -25,3 +25,9 @@ type = "WakeOnLAN" info = { name = "Zeus", room = "Living Room" } mqtt = { topic = "automation/appliance/living_room/zeus" } mac_address = "30:9c:23:60:9c:13" + +[devices.audio] +type = "AudioSetup" +mqtt = { topic = "zigbee2mqtt/living/remote" } +mixer = [10, 0, 0, 49] +speakers = [10, 0, 0, 182] diff --git a/google-home/src/errors.rs b/google-home/src/errors.rs index a710899..12aec14 100644 --- a/google-home/src/errors.rs +++ b/google-home/src/errors.rs @@ -6,8 +6,12 @@ use serde::Serialize; pub enum DeviceError { #[error("deviceNotFound")] DeviceNotFound, + #[error("deviceOffline")] + DeviceOffline, #[error("actionNotAvailable")] ActionNotAvailable, + #[error("transientError")] + TransientError, } #[derive(Debug, Hash, PartialEq, Eq, Copy, Clone, Serialize, Error)] diff --git a/src/config.rs b/src/config.rs index 32b04bd..e6935fd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use tracing::{debug, trace}; use rumqttc::AsyncClient; use serde::Deserialize; -use crate::devices::{DeviceBox, IkeaOutlet, WakeOnLAN}; +use crate::devices::{DeviceBox, IkeaOutlet, WakeOnLAN, AudioSetup}; // @TODO Configure more defaults @@ -83,6 +83,11 @@ pub enum Device { info: InfoConfig, mqtt: MqttDeviceConfig, mac_address: String, + }, + AudioSetup { + mqtt: MqttDeviceConfig, + mixer: [u8; 4], + speakers: [u8; 4], } } @@ -110,6 +115,10 @@ impl Device { trace!(id = identifier, "WakeOnLan [{} in {:?}]", info.name, info.room); Box::new(WakeOnLAN::new(identifier, info, mqtt, mac_address, client)) }, + Device::AudioSetup { mqtt, mixer, speakers } => { + trace!(id = identifier, "AudioSetup [{}]", identifier); + Box::new(AudioSetup::new(identifier, mqtt, mixer, speakers, client)) + }, } } } diff --git a/src/devices.rs b/src/devices.rs index 5eea5ef..52e9a91 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -4,6 +4,9 @@ pub use self::ikea_outlet::IkeaOutlet; mod wake_on_lan; pub use self::wake_on_lan::WakeOnLAN; +mod audio_setup; +pub use self::audio_setup::AudioSetup; + use std::collections::HashMap; use google_home::{GoogleHomeDevice, traits::OnOff}; diff --git a/src/devices/audio_setup.rs b/src/devices/audio_setup.rs new file mode 100644 index 0000000..db12fb2 --- /dev/null +++ b/src/devices/audio_setup.rs @@ -0,0 +1,277 @@ +use std::io::{Write, Read}; +use std::net::{TcpStream, SocketAddr}; + +use bytes::{BufMut, Buf}; +use google_home::errors::{ErrorCode, DeviceError}; +use google_home::traits::{self, OnOff}; +use rumqttc::AsyncClient; +use serde::{Deserialize, Serialize}; +use tracing::warn; +use pollster::FutureExt as _; + +use crate::config::MqttDeviceConfig; +use crate::mqtt::{OnMqtt, RemoteMessage, RemoteAction}; + +use super::Device; + +struct TPLinkOutlet { + addr: SocketAddr, +} + +impl TPLinkOutlet { + pub fn new(ip: [u8; 4]) -> Self { + // @TODO Get the current state of the outlet + Self { addr: (ip, 9999).into() } + } + + pub fn encrypt(data: bytes::Bytes) -> bytes::Bytes { + let mut key: u8 = 171; + let mut encrypted = bytes::BytesMut::with_capacity(data.len() + 4); + + encrypted.put_u32(data.len() as u32); + + for c in data { + key = key ^ c; + encrypted.put_u8(key); + } + + return encrypted.freeze(); + } +} +#[derive(Debug, Serialize)] +struct RequestRelayState { + state: isize, +} + +#[derive(Debug, Serialize)] +struct RequestSysinfo; + +#[derive(Debug, Serialize)] +struct RequestSystem { + #[serde(skip_serializing_if = "Option::is_none")] + get_sysinfo: Option, + #[serde(skip_serializing_if = "Option::is_none")] + set_relay_state: Option, +} + +#[derive(Debug, Serialize)] +struct Request { + system: RequestSystem, +} + +impl Request { + fn get_sysinfo() -> Self { + Self { + system: RequestSystem { + get_sysinfo: Some(RequestSysinfo{}), + set_relay_state: None + } + } + } + + fn set_relay_state(on: bool) -> Self { + Self { + system: RequestSystem { + get_sysinfo: None, + set_relay_state: Some(RequestRelayState { + state: if on { 1 } else { 0 } + }) + } + } + } + + + fn encrypt(&self) -> bytes::Bytes { + let data: bytes::Bytes = serde_json::to_string(self).unwrap().into(); + + let mut key: u8 = 171; + let mut encrypted = bytes::BytesMut::with_capacity(data.len() + 4); + + encrypted.put_u32(data.len() as u32); + + for c in data { + key = key ^ c; + encrypted.put_u8(key); + } + + return encrypted.freeze(); + } +} + +#[derive(Debug, Deserialize)] +struct ResponseSetRelayState { + err_code: isize, +} + +#[derive(Debug, Deserialize)] +struct ResponseGetSysinfo { + err_code: isize, + relay_state: isize, +} + +#[derive(Debug, Deserialize)] +struct ResponseSystem { + set_relay_state: Option, + get_sysinfo: Option, +} + +#[derive(Debug, Deserialize)] +struct Response { + system: ResponseSystem, +} + +impl Response { + fn get_current_relay_state(&self) -> Result { + if let Some(sysinfo) = &self.system.get_sysinfo { + if sysinfo.err_code != 0 { + return Err(anyhow::anyhow!("Error code: {}", sysinfo.err_code)); + } + return Ok(sysinfo.relay_state == 1); + } + + return Err(anyhow::anyhow!("No sysinfo found in response")); + } + + fn check_set_relay_success(&self) -> Result<(), anyhow::Error> { + if let Some(set_relay_state) = &self.system.set_relay_state { + if set_relay_state.err_code != 0 { + return Err(anyhow::anyhow!("Error code: {}", set_relay_state.err_code)); + } + return Ok(()); + } + + return Err(anyhow::anyhow!("No relay_state found in response")); + } + + fn decrypt(mut data: bytes::Bytes) -> Result { + let mut key: u8 = 171; + if data.len() < 4 { + return Err(anyhow::anyhow!("Expected a minimun data length of 4")); + } + + let length = data.get_u32(); + let mut decrypted = bytes::BytesMut::with_capacity(length as usize); + + for c in data { + decrypted.put_u8(key ^ c); + key = c; + } + + let decrypted = std::str::from_utf8(&decrypted)?; + Ok(serde_json::from_str(decrypted)?) + } +} + +impl traits::OnOff for TPLinkOutlet { + fn is_on(&self) -> Result { + let mut stream = TcpStream::connect(self.addr).or::(Err(DeviceError::DeviceOffline.into()))?; + + let body = Request::get_sysinfo().encrypt(); + stream.write_all(&body).and(stream.flush()).or::(Err(DeviceError::TransientError.into()))?; + + let mut received = Vec::new(); + let mut rx_bytes = [0; 1024]; + loop { + let read = stream.read(&mut rx_bytes).or::(Err(DeviceError::TransientError.into()))?; + + received.extend_from_slice(&rx_bytes[..read]); + + if read < rx_bytes.len() { + break; + } + } + + let resp = Response::decrypt(received.into()).or::(Err(DeviceError::TransientError.into()))?; + + resp.get_current_relay_state().or(Err(DeviceError::TransientError.into())) + } + + fn set_on(&mut self, on: bool) -> Result<(), ErrorCode> { + let mut stream = TcpStream::connect(self.addr).or::(Err(DeviceError::DeviceOffline.into()))?; + + let body = Request::set_relay_state(on).encrypt(); + stream.write_all(&body).and(stream.flush()).or::(Err(DeviceError::TransientError.into()))?; + + let mut received = Vec::new(); + let mut rx_bytes = [0; 1024]; + loop { + let read = match stream.read(&mut rx_bytes) { + Ok(read) => read, + Err(_) => return Err(DeviceError::TransientError.into()), + }; + + received.extend_from_slice(&rx_bytes[..read]); + + if read < rx_bytes.len() { + break; + } + } + + let resp = Response::decrypt(received.into()).or::(Err(DeviceError::TransientError.into()))?; + + resp.check_set_relay_success().or(Err(DeviceError::TransientError.into())) + } +} + +pub struct AudioSetup { + identifier: String, + mqtt: MqttDeviceConfig, + mixer: TPLinkOutlet, + speakers: TPLinkOutlet, +} + +impl AudioSetup { + pub fn new(identifier: String, mqtt: MqttDeviceConfig, mixer_ip: [u8; 4], speakers_ip: [u8; 4], client: AsyncClient) -> Self { + let mixer = TPLinkOutlet::new(mixer_ip); + let speakers = TPLinkOutlet::new(speakers_ip); + + client.subscribe(mqtt.topic.clone(), rumqttc::QoS::AtLeastOnce).block_on().unwrap(); + + Self { identifier, mqtt, mixer, speakers } + } +} + +impl Device for AudioSetup { + fn get_id(&self) -> String { + self.identifier.clone() + } +} + +impl OnMqtt for AudioSetup { + fn on_mqtt(&mut self, message: &rumqttc::Publish) { + if message.topic != self.mqtt.topic { + return; + } + + let action = match RemoteMessage::try_from(message) { + Ok(message) => message.action(), + Err(err) => { + warn!(id = self.identifier, "Failed to parse message: {err}"); + return; + } + }; + + match action { + RemoteAction::On => { + if self.mixer.is_on().unwrap() { + self.speakers.set_on(false).unwrap(); + self.mixer.set_on(false).unwrap(); + } else { + self.speakers.set_on(true).unwrap(); + self.mixer.set_on(true).unwrap(); + } + }, + RemoteAction::BrightnessMoveUp => { + if !self.mixer.is_on().unwrap() { + self.mixer.set_on(true).unwrap(); + } else if self.speakers.is_on().unwrap() { + self.speakers.set_on(false).unwrap(); + } else { + self.speakers.set_on(true).unwrap(); + } + }, + RemoteAction::BrightnessStop => { /* Ignore this action */ }, + _ => warn!("Expected ikea shortcut button which only supports 'on' and 'brightness_move_up', got: {action:?}") + } + } +} diff --git a/src/devices/ikea_outlet.rs b/src/devices/ikea_outlet.rs index 888e1d3..a82a751 100644 --- a/src/devices/ikea_outlet.rs +++ b/src/devices/ikea_outlet.rs @@ -3,13 +3,13 @@ use std::time::Duration; use google_home::errors::ErrorCode; use google_home::{GoogleHomeDevice, device, types::Type, traits}; use rumqttc::{AsyncClient, Publish}; -use serde::{Deserialize, Serialize}; use tracing::{debug, trace, warn}; use tokio::task::JoinHandle; +use pollster::FutureExt as _; use crate::config::{KettleConfig, InfoConfig, MqttDeviceConfig}; use crate::devices::Device; -use crate::mqtt::OnMqtt; +use crate::mqtt::{OnMqtt, OnOffMessage}; use crate::presence::OnPresence; pub struct IkeaOutlet { @@ -28,22 +28,14 @@ impl IkeaOutlet { let c = client.clone(); let t = mqtt.topic.clone(); // @TODO Handle potential errors here - tokio::spawn(async move { - c.subscribe(t, rumqttc::QoS::AtLeastOnce).await.unwrap(); - }); + c.subscribe(t, rumqttc::QoS::AtLeastOnce).block_on().unwrap(); Self{ identifier, info, mqtt, kettle, client, last_known_state: false, handle: None } } } async fn set_on(client: AsyncClient, topic: String, on: bool) { - let message = StateMessage{ - state: if on { - "ON".to_owned() - } else { - "OFF".to_owned() - } - }; + let message = OnOffMessage::new(on); // @TODO Handle potential errors here client.publish(topic + "/set", rumqttc::QoS::AtLeastOnce, false, serde_json::to_string(&message).unwrap()).await.unwrap(); @@ -55,20 +47,6 @@ impl Device for IkeaOutlet { } } -#[derive(Debug, Serialize, Deserialize)] -struct StateMessage { - state: String -} - -impl TryFrom<&Publish> for StateMessage { - type Error = anyhow::Error; - - fn try_from(message: &Publish) -> Result { - serde_json::from_slice(&message.payload) - .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) - } -} - impl OnMqtt for IkeaOutlet { fn on_mqtt(&mut self, message: &Publish) { // Update the internal state based on what the device has reported @@ -76,16 +54,16 @@ impl OnMqtt for IkeaOutlet { return; } - let new_state = match StateMessage::try_from(message) { - Ok(state) => state, + let state = match OnOffMessage::try_from(message) { + Ok(state) => state.state(), Err(err) => { warn!(id = self.identifier, "Failed to parse message: {err}"); return; } - }.state == "ON"; + }; // No need to do anything if the state has not changed - if new_state == self.last_known_state { + if state == self.last_known_state { return; } @@ -94,11 +72,11 @@ impl OnMqtt for IkeaOutlet { handle.abort(); } - debug!(id = self.identifier, "Updating state to {new_state}"); - self.last_known_state = new_state; + debug!(id = self.identifier, "Updating state to {state}"); + self.last_known_state = state; // If this is a kettle start a timeout for turning it of again - if new_state { + if state { let kettle = match &self.kettle { Some(kettle) => kettle, None => return, @@ -139,9 +117,7 @@ impl OnPresence for IkeaOutlet { debug!(id = self.identifier, "Turning device off"); let client = self.client.clone(); let topic = self.mqtt.topic.clone(); - tokio::spawn(async move { - set_on(client, topic, false).await; - }); + set_on(client, topic, false).block_on(); } } } @@ -185,9 +161,7 @@ impl traits::OnOff for IkeaOutlet { fn set_on(&mut self, on: bool) -> Result<(), ErrorCode> { let client = self.client.clone(); let topic = self.mqtt.topic.clone(); - tokio::spawn(async move { - set_on(client, topic, on).await; - }); + set_on(client, topic, on).block_on(); Ok(()) } diff --git a/src/devices/wake_on_lan.rs b/src/devices/wake_on_lan.rs index 2e3bf4f..a4edb55 100644 --- a/src/devices/wake_on_lan.rs +++ b/src/devices/wake_on_lan.rs @@ -1,9 +1,9 @@ use google_home::{GoogleHomeDevice, types::Type, device, traits::{self, Scene}, errors::{ErrorCode, DeviceError}}; use tracing::{debug, warn}; use rumqttc::{AsyncClient, Publish}; -use serde::Deserialize; +use pollster::FutureExt as _; -use crate::{config::{InfoConfig, MqttDeviceConfig}, mqtt::OnMqtt}; +use crate::{config::{InfoConfig, MqttDeviceConfig}, mqtt::{OnMqtt, ActivateMessage}}; use super::Device; @@ -18,9 +18,7 @@ impl WakeOnLAN { pub fn new(identifier: String, info: InfoConfig, mqtt: MqttDeviceConfig, mac_address: String, client: AsyncClient) -> Self { let t = mqtt.topic.clone(); // @TODO Handle potential errors here - tokio::spawn(async move { - client.subscribe(t, rumqttc::QoS::AtLeastOnce).await.unwrap(); - }); + client.subscribe(t, rumqttc::QoS::AtLeastOnce).block_on().unwrap(); Self { identifier, info, mqtt, mac_address } } @@ -32,20 +30,6 @@ impl Device for WakeOnLAN { } } -#[derive(Debug, Deserialize)] -struct StateMessage { - activate: bool -} - -impl TryFrom<&Publish> for StateMessage { - type Error = anyhow::Error; - - fn try_from(message: &Publish) -> Result { - serde_json::from_slice(&message.payload) - .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) - } -} - impl OnMqtt for WakeOnLAN { fn on_mqtt(&mut self, message: &Publish) { @@ -53,15 +37,15 @@ impl OnMqtt for WakeOnLAN { return; } - let payload = match StateMessage::try_from(message) { - Ok(state) => state, + let activate = match ActivateMessage::try_from(message) { + Ok(message) => message.activate(), Err(err) => { warn!(id = self.identifier, "Failed to parse message: {err}"); return; } }; - self.set_active(payload.activate).ok(); + self.set_active(activate).ok(); } } @@ -97,19 +81,20 @@ impl traits::Scene for WakeOnLAN { // if we are inside of docker, so for now just call a webhook that does it for us let mac_address = self.mac_address.clone(); let id = self.identifier.clone(); - tokio::spawn(async move { - debug!(id, "Activating Computer: {}", mac_address); - let req = match reqwest::get(format!("http://10.0.0.2:9000/start-pc?mac={mac_address}")).await { - Ok(req) => req, - Err(err) => { - warn!(id, "Failed to call webhook: {err}"); - return; - } - }; - if req.status() != 200 { - warn!(id, "Failed to call webhook: {}", req.status()); + + debug!(id, "Activating Computer: {}", mac_address); + let req = match reqwest::get(format!("http://10.0.0.2:9000/start-pc?mac={mac_address}")).block_on() { + Ok(req) => req, + Err(err) => { + warn!(id, "Failed to call webhook: {err}"); + // @TODO Handle error + return Ok(()); } - }); + }; + + if req.status() != 200 { + warn!(id, "Failed to call webhook: {}", req.status()); + } Ok(()) } else { diff --git a/src/main.rs b/src/main.rs index b04b652..c622405 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use dotenv::dotenv; use rumqttc::{MqttOptions, Transport, AsyncClient}; use tracing::{error, info, metadata::LevelFilter}; -use automation::{devices::Devices, mqtt::Mqtt}; +use automation::{devices::{Devices}, mqtt::Mqtt}; use google_home::{GoogleHome, Request}; use tracing_subscriber::EnvFilter; @@ -58,7 +58,6 @@ async fn main() { let presence = Arc::new(RwLock::new(presence)); mqtt.add_listener(Arc::downgrade(&presence)); - // Start mqtt, this spawns a seperate async task mqtt.start(); @@ -66,6 +65,8 @@ async fn main() { config.devices .into_iter() .map(|(identifier, device_config)| { + // This can technically block, but this only happens during start-up, so should not be + // a problem device_config.into(identifier, client.clone()) }) .for_each(|device| { @@ -75,12 +76,15 @@ async fn main() { // Create google home fullfillment route let fullfillment = Router::new() .route("/google_home", post(async move |Json(payload): Json| { - // @TODO Verify that we are actually logged in - // Might also be smart to get the username from here - let gc = GoogleHome::new(&config.fullfillment.username); - let result = gc.handle_request(payload, &mut devices.write().unwrap().as_google_home_devices()).unwrap(); + // Handle request might block, so we need to spawn a blocking task + tokio::task::spawn_blocking(move || { + // @TODO Verify that we are actually logged in + // Might also be smart to get the username from here + let gc = GoogleHome::new(&config.fullfillment.username); + let result = gc.handle_request(payload, &mut devices.write().unwrap().as_google_home_devices()).unwrap(); - (StatusCode::OK, Json(result)) + return (StatusCode::OK, Json(result)); + }).await.unwrap() })); // Combine together all the routes diff --git a/src/mqtt.rs b/src/mqtt.rs index 65ab094..849b46f 100644 --- a/src/mqtt.rs +++ b/src/mqtt.rs @@ -1,5 +1,6 @@ use std::sync::{Weak, RwLock}; -use tracing::{error, debug, trace, span, Level}; +use serde::{Serialize, Deserialize}; +use tracing::{error, debug, span, Level}; use rumqttc::{Publish, Event, Incoming, EventLoop}; use tokio::task::JoinHandle; @@ -19,16 +20,12 @@ impl Mqtt { return Self { listeners: Vec::new(), eventloop } } - fn notify(&mut self, message: Publish) { - self.listeners.retain(|listener| { + fn notify(message: Publish, listeners: Vec>>) { + let _span = span!(Level::TRACE, "mqtt_message").entered(); + listeners.into_iter().for_each(|listener| { if let Some(listener) = listener.upgrade() { listener.write().unwrap().on_mqtt(&message); - return true; - } else { - trace!("Removing listener..."); } - - return false; }) } @@ -43,9 +40,15 @@ impl Mqtt { let notification = self.eventloop.poll().await; match notification { Ok(Event::Incoming(Incoming::Publish(p))) => { - // Could cause problems in async - let _span = span!(Level::TRACE, "mqtt_message").entered(); - self.notify(p); + // Remove non-existing listeners + self.listeners.retain(|listener| listener.strong_count() > 0); + // Clone the listeners + let listeners = self.listeners.clone(); + + // Notify might block, so we spawn a blocking task + tokio::task::spawn_blocking(move || { + Mqtt::notify(p, listeners); + }); }, Ok(..) => continue, Err(err) => { @@ -59,3 +62,77 @@ impl Mqtt { }) } } + +#[derive(Debug, Serialize, Deserialize)] +pub struct OnOffMessage { + state: String +} + +impl OnOffMessage { + pub fn new(state: bool) -> Self { + Self { state: if state {"ON"} else {"OFF"}.into() } + } + + pub fn state(&self) -> bool { + self.state == "ON" + } +} + +impl TryFrom<&Publish> for OnOffMessage { + type Error = anyhow::Error; + + fn try_from(message: &Publish) -> Result { + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) + } +} + +#[derive(Debug, Deserialize)] +pub struct ActivateMessage { + activate: bool +} + +impl ActivateMessage { + pub fn activate(&self) -> bool { + self.activate + } +} + +impl TryFrom<&Publish> for ActivateMessage { + type Error = anyhow::Error; + + fn try_from(message: &Publish) -> Result { + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) + } +} + +#[derive(Debug, Deserialize, Copy, Clone)] +#[serde(rename_all = "snake_case")] +pub enum RemoteAction { + On, + Off, + BrightnessMoveUp, + BrightnessMoveDown, + BrightnessStop, +} + +#[derive(Debug, Deserialize)] +pub struct RemoteMessage { + action: RemoteAction +} + +impl RemoteMessage { + pub fn action(&self) -> RemoteAction { + self.action + } +} + +impl TryFrom<&Publish> for RemoteMessage { + type Error = anyhow::Error; + + fn try_from(message: &Publish) -> Result { + serde_json::from_slice(&message.payload) + .or(Err(anyhow::anyhow!("Invalid message payload received: {:?}", message.payload))) + } +} diff --git a/src/ntfy.rs b/src/ntfy.rs index c8258d1..52b39d7 100644 --- a/src/ntfy.rs +++ b/src/ntfy.rs @@ -4,6 +4,7 @@ use tracing::{warn, error}; use reqwest::StatusCode; use serde::Serialize; use serde_repr::*; +use pollster::FutureExt as _; use crate::{presence::OnPresence, config::NtfyConfig}; @@ -122,16 +123,14 @@ impl OnPresence for Ntfy { .body(serde_json::to_string(¬ification).unwrap()); // Send the notification - tokio::spawn(async move { - let res = req.send().await; - if let Err(err) = res { - error!("Something went wrong while sending the notifcation: {err}"); - } else if let Ok(res) = res { - let status = res.status(); - if status != StatusCode::OK { - warn!("Received status {status} when sending notification"); - } + let res = req.send().block_on(); + if let Err(err) = res { + error!("Something went wrong while sending the notifcation: {err}"); + } else if let Ok(res) = res { + let status = res.status(); + if status != StatusCode::OK { + warn!("Received status {status} when sending notification"); } - }); + } } } diff --git a/src/presence.rs b/src/presence.rs index 63c496b..774e669 100644 --- a/src/presence.rs +++ b/src/presence.rs @@ -1,8 +1,9 @@ use std::{sync::{Weak, RwLock}, collections::HashMap}; -use tracing::{debug, warn, trace, span, Level}; +use tracing::{debug, warn, span, Level}; use rumqttc::{AsyncClient, Publish}; use serde::{Serialize, Deserialize}; +use pollster::FutureExt as _; use crate::{mqtt::OnMqtt, config::MqttDeviceConfig}; @@ -21,9 +22,7 @@ impl Presence { pub fn new(mqtt: MqttDeviceConfig, client: AsyncClient) -> Self { // @TODO Handle potential errors here let topic = mqtt.topic.clone() + "/+"; - tokio::spawn(async move { - client.subscribe(topic, rumqttc::QoS::AtLeastOnce).await.unwrap(); - }); + client.subscribe(topic, rumqttc::QoS::AtLeastOnce).block_on().unwrap(); Self { listeners: Vec::new(), devices: HashMap::new(), overall_presence: false, mqtt } } @@ -31,6 +30,15 @@ impl Presence { pub fn add_listener(&mut self, listener: Weak>) { self.listeners.push(listener); } + + pub fn notify(presence: bool, listeners: Vec>>) { + let _span = span!(Level::TRACE, "presence_update").entered(); + listeners.into_iter().for_each(|listener| { + if let Some(listener) = listener.upgrade() { + listener.write().unwrap().on_presence(presence); + } + }) + } } #[derive(Debug, Serialize, Deserialize)] @@ -75,19 +83,15 @@ impl OnMqtt for Presence { debug!("Overall presence updated: {overall_presence}"); self.overall_presence = overall_presence; - // This has problems in async - let _span = span!(Level::TRACE, "presence_update").entered(); + // Remove non-existing listeners + self.listeners.retain(|listener| listener.strong_count() > 0); + // Clone the listeners + let listeners = self.listeners.clone(); - self.listeners.retain(|listener| { - if let Some(listener) = listener.upgrade() { - listener.write().unwrap().on_presence(overall_presence); - return true; - } else { - trace!("Removing listener..."); - } - - return false; - }) + // Notify might block, so we spawn a blocking task + tokio::task::spawn_blocking(move || { + Presence::notify(overall_presence, listeners); + }); } } }