diff --git a/src/config.rs b/src/config.rs index 609ccc2..a2ea570 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,13 +2,14 @@ use std::{ collections::HashMap, fs, net::{Ipv4Addr, SocketAddr}, + time::Duration, }; use async_recursion::async_recursion; use eui48::MacAddress; use regex::{Captures, Regex}; -use rumqttc::{has_wildcards, AsyncClient}; -use serde::Deserialize; +use rumqttc::{has_wildcards, AsyncClient, MqttOptions, Transport}; +use serde::{Deserialize, Deserializer}; use tracing::{debug, trace}; use crate::{ @@ -19,7 +20,8 @@ use crate::{ #[derive(Debug, Deserialize)] pub struct Config { pub openid: OpenIDConfig, - pub mqtt: MqttConfig, + #[serde(deserialize_with = "deserialize_mqtt_options")] + pub mqtt: MqttOptions, #[serde(default)] pub fullfillment: FullfillmentConfig, pub ntfy: Option, @@ -47,6 +49,27 @@ pub struct MqttConfig { pub tls: bool, } +fn deserialize_mqtt_options<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + Ok(MqttOptions::from(MqttConfig::deserialize(deserializer)?)) +} + +impl From for MqttOptions { + fn from(value: MqttConfig) -> Self { + let mut mqtt_options = MqttOptions::new(value.client_name, value.host, value.port); + mqtt_options.set_credentials(value.username, value.password); + mqtt_options.set_keep_alive(Duration::from_secs(5)); + + if value.tls { + mqtt_options.set_transport(Transport::tls_with_default_config()); + } + + mqtt_options + } +} + #[derive(Debug, Deserialize)] pub struct FullfillmentConfig { #[serde(default = "default_fullfillment_ip")] diff --git a/src/main.rs b/src/main.rs index 1c04b6c..b23ca9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ #![feature(async_closure)] -use std::{process, time::Duration}; +use std::process; use axum::{ extract::FromRef, http::StatusCode, response::IntoResponse, routing::post, Json, Router, @@ -16,7 +16,7 @@ use automation::{ }; use dotenvy::dotenv; use futures::future::join_all; -use rumqttc::{AsyncClient, MqttOptions, Transport}; +use rumqttc::AsyncClient; use tracing::{debug, error, info}; use google_home::{GoogleHome, Request}; @@ -52,21 +52,12 @@ async fn app() -> anyhow::Result<()> { info!("Starting automation_rs..."); - let config = std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.toml".to_owned()); - let config = Config::parse_file(&config)?; - - // Configure MQTT - let mqtt = config.mqtt.clone(); - let mut mqttoptions = MqttOptions::new(mqtt.client_name, mqtt.host, mqtt.port); - mqttoptions.set_credentials(mqtt.username, mqtt.password); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - - if mqtt.tls { - mqttoptions.set_transport(Transport::tls_with_default_config()); - } + let config_filename = + std::env::var("AUTOMATION_CONFIG").unwrap_or("./config/config.toml".to_owned()); + let config = Config::parse_file(&config_filename)?; // Create a mqtt client and wrap the eventloop - let (client, eventloop) = AsyncClient::new(mqttoptions, 10); + let (client, eventloop) = AsyncClient::new(config.mqtt.clone(), 10); let mqtt = Mqtt::new(eventloop); let presence = presence::start(config.presence.clone(), mqtt.subscribe(), client.clone()).await?;