siranga/src/io/stats.rs
Dreaded_X d4bd0ef1ca
All checks were successful
Build and deploy / Build container and manifests (push) Successful in 2m48s
Use store instead of fetch_add for atomics
2025-04-18 14:58:15 +02:00

134 lines
3.3 KiB
Rust

use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use russh::ChannelStream;
use russh::server::Msg;
use crate::helper::Unit;
/// Thread-safe traffic statistics backed by atomics.
///
/// Every field is an atomic, so the methods on this type take `&self` and the
/// value can be shared (for example behind an `Arc`) without extra locking.
#[derive(Debug, Default)]
pub struct Stats {
// Connection counter, written by `add_connection` and read by `connections`.
connections: AtomicUsize,
// Byte counter written by `add_rx_bytes` and reported by `rx`.
rx: AtomicUsize,
// Byte counter written by `add_tx_bytes` and reported by `tx`.
tx: AtomicUsize,
// Flag written by `set_failed` and read by `failed`; its meaning is defined by callers.
failed: AtomicBool,
}
impl Stats {
    /// Records one additional connection.
    ///
    /// Uses `fetch_add` so that every call increments the counter; a plain
    /// `store(1)` would pin the count at 1 no matter how many connections
    /// were made. The counter wraps around on overflow.
    pub fn add_connection(&self) {
        self.connections.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `n` bytes to the running rx total.
    ///
    /// Accumulates with `fetch_add`; overwriting with `store` would discard
    /// everything counted by earlier calls. The total wraps around on overflow.
    pub fn add_rx_bytes(&self, n: usize) {
        self.rx.fetch_add(n, Ordering::Relaxed);
    }

    /// Adds `n` bytes to the running tx total.
    ///
    /// Accumulates with `fetch_add`; overwriting with `store` would discard
    /// everything counted by earlier calls. The total wraps around on overflow.
    pub fn add_tx_bytes(&self, n: usize) {
        self.tx.fetch_add(n, Ordering::Relaxed);
    }

    /// Returns the number of connections recorded so far.
    pub fn connections(&self) -> usize {
        self.connections.load(Ordering::Relaxed)
    }

    /// Returns the current value of the failure flag.
    pub fn failed(&self) -> bool {
        self.failed.load(Ordering::Relaxed)
    }

    /// Sets the failure flag to `failed`.
    ///
    /// This is a plain overwrite (not an accumulation), so `store` is the
    /// correct operation here.
    pub fn set_failed(&self, failed: bool) {
        self.failed.store(failed, Ordering::Relaxed);
    }

    /// Returns the rx byte total wrapped in a [`Unit`] with the suffix `"B"`.
    pub fn rx(&self) -> Unit {
        Unit::new(self.rx.load(Ordering::Relaxed), "B")
    }

    /// Returns the tx byte total wrapped in a [`Unit`] with the suffix `"B"`.
    pub fn tx(&self) -> Unit {
        Unit::new(self.tx.load(Ordering::Relaxed), "B")
    }
}
pin_project! {
/// Wrapper around a `ChannelStream<Msg>` that reports transferred byte
/// counts to a shared [`Stats`] instance.
pub struct TrackStats {
// The wrapped stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
#[pin]
inner: ChannelStream<Msg>,
// Shared statistics updated by the read/write implementations.
stats: Arc<Stats>,
}
}
impl TrackStats {
pub fn new(inner: ChannelStream<Msg>, stats: Arc<Stats>) -> Self {
Self { inner, stats }
}
}
// Bridges the wrapped tokio stream to hyper's `Read` trait and counts the
// bytes that pass through.
impl hyper::rt::Read for TrackStats {
    /// Reads from the wrapped stream into `buf`.
    ///
    /// On a successful read the number of bytes produced is reported through
    /// `Stats::add_tx_bytes` and the cursor is advanced by the same amount.
    /// `Poll::Pending` and I/O errors from the inner stream are returned
    /// unchanged, without touching the statistics or the cursor.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();

        let filled = {
            // SAFETY: the cursor's unfilled region is only handed to
            // `ReadBuf::uninit`, which treats it as uninitialized memory and
            // tracks what has been written; nothing here de-initializes bytes.
            let mut scratch = unsafe { tokio::io::ReadBuf::uninit(buf.as_mut()) };
            // Bail out on `Pending` (via `ready!`) or on an error (via `?`).
            std::task::ready!(tokio::io::AsyncRead::poll_read(
                this.inner,
                cx,
                &mut scratch
            ))?;
            scratch.filled().len()
        };

        this.stats.add_tx_bytes(filled);

        // SAFETY: `filled` is the length of the filled portion reported by the
        // `ReadBuf` that was backed by this cursor's memory, so that many
        // bytes have been written by the inner reader.
        unsafe { buf.advance(filled) };

        Poll::Ready(Ok(()))
    }
}
// Bridges the wrapped tokio stream to hyper's `Write` trait and counts the
// bytes that pass through.
impl hyper::rt::Write for TrackStats {
    /// Writes `buf` to the wrapped stream.
    ///
    /// When the inner write completes successfully, the number of bytes it
    /// accepted is reported through `Stats::add_rx_bytes`. The poll result is
    /// returned to the caller unchanged.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        let outcome = tokio::io::AsyncWrite::poll_write(this.inner, cx, buf);
        // Only a completed, successful write contributes to the statistics.
        if let Poll::Ready(Ok(written)) = &outcome {
            this.stats.add_rx_bytes(*written);
        }
        outcome
    }

    /// Flushes the wrapped stream; no statistics are recorded.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_flush(this.inner, cx)
    }

    /// Shuts down the wrapped stream; no statistics are recorded.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        tokio::io::AsyncWrite::poll_shutdown(this.inner, cx)
    }

    /// Reports whether the wrapped stream says it supports vectored writes.
    fn is_write_vectored(&self) -> bool {
        let stream = &self.inner;
        tokio::io::AsyncWrite::is_write_vectored(stream)
    }

    /// Vectored counterpart of `poll_write`, with the same accounting:
    /// bytes accepted by a successful write go to `Stats::add_rx_bytes`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        let outcome = tokio::io::AsyncWrite::poll_write_vectored(this.inner, cx, bufs);
        // Only a completed, successful write contributes to the statistics.
        if let Poll::Ready(Ok(written)) = &outcome {
            this.stats.add_rx_bytes(*written);
        }
        outcome
    }
}