浏览代码

Response to info pings

Sergey Chushin 3 年之前
父节点
当前提交
c97b08bfaf
共有 3 个文件被更改,包括 86 次插入5 次删除
  1. 9 2
      src/server/server.rs
  2. 66 3
      src/server/udp_audio_channel.rs
  3. 11 0
      src/storage.rs

+ 9 - 2
src/server/server.rs

@@ -1,8 +1,9 @@
+use crate::protocol::parser::MUMBLE_PROTOCOL_VERSION;
 use crate::server::client::{Client, Config as ClientConfig};
 use crate::server::client::{Client, Config as ClientConfig};
 use crate::server::connection_worker::ConnectionWorker;
 use crate::server::connection_worker::ConnectionWorker;
 use crate::server::session_pool::SessionPool;
 use crate::server::session_pool::SessionPool;
 use crate::server::tcp_control_channel::TcpControlChannel;
 use crate::server::tcp_control_channel::TcpControlChannel;
-use crate::server::udp_audio_channel::{UdpAudioChannel, UdpWorker};
+use crate::server::udp_audio_channel::{ServerInfo, UdpAudioChannel, UdpWorker};
 use crate::storage::Storage;
 use crate::storage::Storage;
 use dashmap::DashMap;
 use dashmap::DashMap;
 use log::{error, info};
 use log::{error, info};
@@ -72,7 +73,13 @@ impl Server {
                 panic!();
                 panic!();
             }
             }
         };
         };
-        let udp_worker = Arc::new(UdpWorker::start(udp_socket).await);
+        let server_info = ServerInfo {
+            version: MUMBLE_PROTOCOL_VERSION.into(),
+            connected_users: self.storage.watch_connected_count(),
+            max_users: 10,
+            max_bandwidth: 128000,
+        };
+        let udp_worker = Arc::new(UdpWorker::start(udp_socket, server_info).await);
         info!("Server listening on {}", socket_address);
         info!("Server listening on {}", socket_address);
 
 
         loop {
         loop {

+ 66 - 3
src/server/udp_audio_channel.rs

@@ -3,6 +3,7 @@ use crate::crypto::Ocb2Aes128Crypto;
 use crate::protocol::connection::{AudioChannel, AudioChannelStats, Error};
 use crate::protocol::connection::{AudioChannel, AudioChannelStats, Error};
 use crate::protocol::parser::AudioPacket;
 use crate::protocol::parser::AudioPacket;
 use async_trait::async_trait;
 use async_trait::async_trait;
+use std::convert::TryInto;
 use std::net::{IpAddr, SocketAddr};
 use std::net::{IpAddr, SocketAddr};
 use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use std::sync::Arc;
@@ -15,6 +16,8 @@ use tokio::task::JoinHandle;
 const MAX_AUDIO_PACKET_SIZE: usize = 1020;
 const MAX_AUDIO_PACKET_SIZE: usize = 1020;
 const ENCRYPTION_OVERHEAD: usize = 4;
 const ENCRYPTION_OVERHEAD: usize = 4;
 const MAX_DATAGRAM_SIZE: usize = MAX_AUDIO_PACKET_SIZE + ENCRYPTION_OVERHEAD;
 const MAX_DATAGRAM_SIZE: usize = MAX_AUDIO_PACKET_SIZE + ENCRYPTION_OVERHEAD;
+const INFO_PING_SIZE: usize = 12;
+const RESPONSE_SIZE: usize = 4 + 8 + 4 + 4 + 4;
 
 
 type Data = Arc<(Vec<u8>, SocketAddr)>;
 type Data = Arc<(Vec<u8>, SocketAddr)>;
 
 
@@ -24,6 +27,13 @@ pub struct UdpWorker {
     task: JoinHandle<()>,
     task: JoinHandle<()>,
 }
 }
 
 
+pub struct ServerInfo {
+    pub version: u32,
+    pub connected_users: Arc<AtomicU32>,
+    pub max_users: u32,
+    pub max_bandwidth: u32,
+}
+
 pub struct UdpAudioChannel {
 pub struct UdpAudioChannel {
     good: AtomicU32,
     good: AtomicU32,
     late: AtomicU32,
     late: AtomicU32,
@@ -36,7 +46,7 @@ pub struct UdpAudioChannel {
 }
 }
 
 
 impl UdpWorker {
 impl UdpWorker {
-    pub async fn start(socket: UdpSocket) -> Self {
+    pub async fn start(socket: UdpSocket, info: ServerInfo) -> Self {
         let (sender, _) = broadcast::channel(8);
         let (sender, _) = broadcast::channel(8);
         let socket = Arc::new(socket);
         let socket = Arc::new(socket);
         let udp_socket = Arc::clone(&socket);
         let udp_socket = Arc::clone(&socket);
@@ -44,8 +54,18 @@ impl UdpWorker {
         let task = tokio::spawn(async move {
         let task = tokio::spawn(async move {
             let mut buf = [0; MAX_DATAGRAM_SIZE];
             let mut buf = [0; MAX_DATAGRAM_SIZE];
             loop {
             loop {
-                if let Ok((len, socket_address)) = udp_socket.recv_from(&mut buf).await {
-                    broadcast_sender.send(Arc::new((Vec::from(&buf[..len]), socket_address)));
+                if let Ok((len, address)) = udp_socket.recv_from(&mut buf).await {
+                    if len == INFO_PING_SIZE {
+                        Self::response_to_ping(
+                            &buf[..12].try_into().unwrap(),
+                            &udp_socket,
+                            address,
+                            &info,
+                        )
+                        .await;
+                    } else {
+                        broadcast_sender.send(Arc::new((Vec::from(&buf[..len]), address)));
+                    }
                 }
                 }
             }
             }
         });
         });
@@ -84,6 +104,27 @@ impl UdpWorker {
             }
             }
         }
         }
     }
     }
+
+    async fn response_to_ping(
+        ping: &[u8; INFO_PING_SIZE],
+        socket: &UdpSocket,
+        origin: SocketAddr,
+        info: &ServerInfo,
+    ) {
+        let bytes = Self::create_response(ping, info);
+        socket.send_to(&bytes, origin).await;
+    }
+
+    fn create_response(ping: &[u8; INFO_PING_SIZE], info: &ServerInfo) -> [u8; RESPONSE_SIZE] {
+        let mut response = [0u8; RESPONSE_SIZE];
+        response[..4].copy_from_slice(&info.version.to_be_bytes());
+        response[4..12].copy_from_slice(&ping[4..]);
+        response[12..16]
+            .copy_from_slice(&info.connected_users.load(Ordering::Acquire).to_be_bytes());
+        response[16..20].copy_from_slice(&info.max_users.to_be_bytes());
+        response[20..24].copy_from_slice(&info.max_bandwidth.to_be_bytes());
+        response
+    }
 }
 }
 
 
 #[async_trait]
 #[async_trait]
@@ -149,3 +190,25 @@ impl From<crypto::Error> for Error {
         ))
         ))
     }
     }
 }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_response() {
+        let ping = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0];
+        let info = ServerInfo {
+            version: 0x0123,
+            connected_users: Arc::new(AtomicU32::from(42)),
+            max_users: 100,
+            max_bandwidth: 100000,
+        };
+
+        let expected: [u8; RESPONSE_SIZE] = [
+            0, 0, 1, 35, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 42, 0, 0, 0, 100, 0, 1, 134, 160,
+        ];
+        let response = UdpWorker::create_response(&ping, &info);
+        assert_eq!(response, expected);
+    }
+}

+ 11 - 0
src/storage.rs

@@ -1,6 +1,8 @@
 use dashmap::DashMap;
 use dashmap::DashMap;
 use serde::{Deserialize, Serialize};
 use serde::{Deserialize, Serialize};
 use std::num::NonZeroU32;
 use std::num::NonZeroU32;
+use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::Arc;
 
 
 const ROOT_CHANNEL_ID: u32 = 0;
 const ROOT_CHANNEL_ID: u32 = 0;
 const USER_TREE_NAME: &[u8] = b"users";
 const USER_TREE_NAME: &[u8] = b"users";
@@ -18,6 +20,7 @@ pub struct Storage {
     session_data: DashMap<SessionId, SessionData>,
     session_data: DashMap<SessionId, SessionData>,
     guests: DashMap<SessionId, Guest>,
     guests: DashMap<SessionId, Guest>,
     connected_users: DashMap<SessionId, (UserId, Username)>,
     connected_users: DashMap<SessionId, (UserId, Username)>,
+    connected: Arc<AtomicU32>,
 }
 }
 
 
 #[derive(Serialize, Deserialize)]
 #[derive(Serialize, Deserialize)]
@@ -102,6 +105,7 @@ impl Storage {
             session_data: DashMap::new(),
             session_data: DashMap::new(),
             guests: DashMap::new(),
             guests: DashMap::new(),
             connected_users: DashMap::new(),
             connected_users: DashMap::new(),
+            connected: Default::default(),
         }
         }
     }
     }
 
 
@@ -109,12 +113,14 @@ impl Storage {
         let session_id = guest.session_id;
         let session_id = guest.session_id;
         self.guests.insert(session_id, guest);
         self.guests.insert(session_id, guest);
         self.session_data.insert(session_id, SessionData::default());
         self.session_data.insert(session_id, SessionData::default());
+        self.connected.fetch_add(1, Ordering::SeqCst);
     }
     }
 
 
     pub fn add_connected_user(&self, user: User, session_id: SessionId) {
     pub fn add_connected_user(&self, user: User, session_id: SessionId) {
         self.connected_users
         self.connected_users
             .insert(session_id, (user.id, user.username));
             .insert(session_id, (user.id, user.username));
         self.session_data.insert(session_id, SessionData::default());
         self.session_data.insert(session_id, SessionData::default());
+        self.connected.fetch_add(1, Ordering::SeqCst);
     }
     }
 
 
     pub fn get_channels(&self) -> Vec<Channel> {
     pub fn get_channels(&self) -> Vec<Channel> {
@@ -208,10 +214,15 @@ impl Storage {
         }
         }
     }
     }
 
 
+    pub fn watch_connected_count(&self) -> Arc<AtomicU32> {
+        Arc::clone(&self.connected)
+    }
+
     pub fn remove_by_session_id(&self, id: SessionId) {
     pub fn remove_by_session_id(&self, id: SessionId) {
         self.connected_users.remove(&id);
         self.connected_users.remove(&id);
         self.guests.remove(&id);
         self.guests.remove(&id);
         self.session_data.remove(&id);
         self.session_data.remove(&id);
+        self.connected.fetch_sub(1, Ordering::SeqCst);
     }
     }
 }
 }