// udp_audio_channel.rs
  1. use crate::crypto;
  2. use crate::crypto::Ocb2Aes128Crypto;
  3. use crate::protocol::connection::{AudioChannel, AudioChannelStats, Error};
  4. use crate::protocol::parser::AudioPacket;
  5. use async_trait::async_trait;
  6. use std::convert::TryInto;
  7. use std::net::{IpAddr, SocketAddr};
  8. use std::sync::atomic::{AtomicU32, Ordering};
  9. use std::sync::Arc;
  10. use tokio::net::UdpSocket;
  11. use tokio::sync::broadcast::error::RecvError;
  12. use tokio::sync::broadcast::{Receiver, Sender};
  13. use tokio::sync::{broadcast, Mutex};
  14. use tokio::task::JoinHandle;
  /// Bytes that encryption adds on top of a plain audio packet.
  const ENCRYPTION_OVERHEAD: usize = 4;
  /// Largest audio packet (before encryption) that fits the receive buffer.
  const MAX_AUDIO_PACKET_SIZE: usize = 1020;
  /// Size of the UDP receive buffer: the largest audio packet plus encryption overhead.
  const MAX_DATAGRAM_SIZE: usize = ENCRYPTION_OVERHEAD + MAX_AUDIO_PACKET_SIZE;
  /// Exact length of a server-info ping datagram.
  const INFO_PING_SIZE: usize = 12;
  /// Length of the reply to a server-info ping: version, the 8 echoed ping bytes,
  /// then connected users, max users and max bandwidth.
  const RESPONSE_SIZE: usize =
      std::mem::size_of::<u32>() + std::mem::size_of::<u64>() + 3 * std::mem::size_of::<u32>();
  /// A received datagram's bytes paired with its sender address, shared by
  /// reference count between every broadcast subscriber.
  type Data = Arc<(Vec<u8>, SocketAddr)>;
  21. pub struct UdpWorker {
  22. sender: Sender<Data>,
  23. socket: Arc<UdpSocket>,
  24. task: JoinHandle<()>,
  25. }
  /// Server details reported in replies to server-info pings.
  #[derive(Debug)]
  pub struct ServerInfo {
      /// Server version, sent big-endian as the first 4 bytes of each reply.
      pub version: u32,
      /// Current number of connected users. It is shared and atomic so it can be
      /// read by the UDP worker while (presumably) being updated elsewhere.
      pub connected_users: Arc<AtomicU32>,
      /// Maximum number of users reported to clients.
      pub max_users: u32,
      /// Maximum bandwidth reported to clients (units not defined here;
      /// presumably bits per second, so confirm against the caller).
      pub max_bandwidth: u32,
  }
  32. pub struct UdpAudioChannel {
  33. good: AtomicU32,
  34. late: AtomicU32,
  35. lost: AtomicU32,
  36. received: AtomicU32,
  37. receiver: Mutex<Receiver<Data>>,
  38. crypto: Mutex<Ocb2Aes128Crypto>,
  39. socket: Arc<UdpSocket>,
  40. destination: SocketAddr,
  41. }
  42. impl UdpWorker {
  43. pub async fn start(socket: UdpSocket, info: ServerInfo) -> Self {
  44. let (sender, _) = broadcast::channel(8);
  45. let socket = Arc::new(socket);
  46. let udp_socket = Arc::clone(&socket);
  47. let broadcast_sender = sender.clone();
  48. let task = tokio::spawn(async move {
  49. let mut buf = [0; MAX_DATAGRAM_SIZE];
  50. loop {
  51. if let Ok((len, address)) = udp_socket.recv_from(&mut buf).await {
  52. if len == INFO_PING_SIZE {
  53. Self::response_to_ping(
  54. &buf[..12].try_into().unwrap(),
  55. &udp_socket,
  56. address,
  57. &info,
  58. )
  59. .await;
  60. } else {
  61. broadcast_sender.send(Arc::new((Vec::from(&buf[..len]), address)));
  62. }
  63. }
  64. }
  65. });
  66. UdpWorker {
  67. sender,
  68. socket,
  69. task,
  70. }
  71. }
  72. pub async fn new_audio_channel(
  73. &self,
  74. ip: IpAddr,
  75. mut crypto: Ocb2Aes128Crypto,
  76. ) -> UdpAudioChannel {
  77. loop {
  78. let mut receiver = self.sender.subscribe();
  79. let data = match receiver.recv().await {
  80. Ok(data) => data,
  81. Err(RecvError::Lagged(_)) => receiver.recv().await.unwrap(),
  82. Err(_) => panic!(),
  83. };
  84. let (bytes, address) = data.as_ref();
  85. if address.ip() == ip && crypto.decrypt(bytes).is_ok() {
  86. return UdpAudioChannel {
  87. good: AtomicU32::new(1),
  88. late: AtomicU32::new(0),
  89. lost: AtomicU32::new(0),
  90. received: AtomicU32::new(1),
  91. receiver: Mutex::new(receiver),
  92. crypto: Mutex::new(crypto),
  93. socket: Arc::clone(&self.socket),
  94. destination: address.clone(),
  95. };
  96. }
  97. }
  98. }
  99. async fn response_to_ping(
  100. ping: &[u8; INFO_PING_SIZE],
  101. socket: &UdpSocket,
  102. origin: SocketAddr,
  103. info: &ServerInfo,
  104. ) {
  105. let bytes = Self::create_response(ping, info);
  106. socket.send_to(&bytes, origin).await;
  107. }
  108. fn create_response(ping: &[u8; INFO_PING_SIZE], info: &ServerInfo) -> [u8; RESPONSE_SIZE] {
  109. let mut response = [0u8; RESPONSE_SIZE];
  110. response[..4].copy_from_slice(&info.version.to_be_bytes());
  111. response[4..12].copy_from_slice(&ping[4..]);
  112. response[12..16]
  113. .copy_from_slice(&info.connected_users.load(Ordering::Acquire).to_be_bytes());
  114. response[16..20].copy_from_slice(&info.max_users.to_be_bytes());
  115. response[20..24].copy_from_slice(&info.max_bandwidth.to_be_bytes());
  116. response
  117. }
  118. }
  119. #[async_trait]
  120. impl AudioChannel for UdpAudioChannel {
  121. async fn send(&self, packet: AudioPacket) -> Result<(), Error> {
  122. let bytes = packet.serialize();
  123. let encrypted = {
  124. let mut crypto = self.crypto.lock().await;
  125. crypto.encrypt(&bytes)?
  126. };
  127. self.socket.send_to(&encrypted, self.destination).await?;
  128. Ok(())
  129. }
  130. async fn receive(&self) -> Result<AudioPacket, Error> {
  131. let mut receiver = self.receiver.lock().await;
  132. let data = loop {
  133. let data = match receiver.recv().await {
  134. Ok(data) => data,
  135. Err(RecvError::Lagged(_)) => receiver.recv().await.unwrap(),
  136. Err(_) => panic!(),
  137. };
  138. if data.1 == self.destination {
  139. break data;
  140. }
  141. };
  142. drop(receiver);
  143. let mut crypto = self.crypto.lock().await;
  144. let decrypted = crypto.decrypt(&data.0)?;
  145. self.good.swap(crypto.good, Ordering::Release);
  146. self.late.swap(crypto.late, Ordering::Release);
  147. self.lost.swap(crypto.lost, Ordering::Release);
  148. drop(crypto);
  149. let packet = AudioPacket::parse(decrypted)?;
  150. self.received.fetch_add(1, Ordering::Relaxed);
  151. Ok(packet)
  152. }
  153. fn get_stats(&self) -> AudioChannelStats {
  154. AudioChannelStats {
  155. good: self.good.load(Ordering::Acquire),
  156. late: self.late.load(Ordering::Acquire),
  157. lost: self.lost.load(Ordering::Acquire),
  158. received: self.received.load(Ordering::Acquire),
  159. }
  160. }
  161. }
  162. impl Drop for UdpWorker {
  163. fn drop(&mut self) {
  164. self.task.abort();
  165. }
  166. }
  167. impl From<crypto::Error> for Error {
  168. fn from(_: crypto::Error) -> Self {
  169. Error::IO(std::io::Error::new(
  170. std::io::ErrorKind::InvalidData,
  171. "crypto fail",
  172. ))
  173. }
  174. }
  #[cfg(test)]
  mod tests {
      use super::*;

      #[test]
      fn test_create_response() {
          // Four zero bytes, then the 8 bytes the server must echo back.
          let ping: [u8; INFO_PING_SIZE] = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0];
          let info = ServerInfo {
              version: 0x0123,
              connected_users: Arc::new(AtomicU32::new(42)),
              max_users: 100,
              max_bandwidth: 100_000,
          };
          let expected: [u8; RESPONSE_SIZE] = [
              0x00, 0x00, 0x01, 0x23, // version
              0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x00, // echoed ping bytes
              0x00, 0x00, 0x00, 0x2a, // connected users (42)
              0x00, 0x00, 0x00, 0x64, // max users (100)
              0x00, 0x01, 0x86, 0xa0, // max bandwidth (100000)
          ];
          assert_eq!(UdpWorker::create_response(&ping, &info), expected);
      }
  }