Răsfoiți Sursa

Use UDP for audio channels

Sergey Chushin 3 ani în urmă
părinte
comite
7b9b8aca3d
8 a modificat fișierele cu 640 adăugiri și 275 ștergeri
  1. 1 0
      Cargo.lock
  2. 1 0
      Cargo.toml
  3. 261 44
      src/client.rs
  4. 107 142
      src/connection.rs
  5. 17 17
      src/crypto.rs
  6. 3 1
      src/main.rs
  7. 13 11
      src/protocol.rs
  8. 237 60
      src/server.rs

+ 1 - 0
Cargo.lock

@@ -500,6 +500,7 @@ dependencies = [
  "clap",
  "protobuf",
  "protoc-rust",
+ "rand",
  "serde",
  "sled",
  "tokio",

+ 1 - 0
Cargo.toml

@@ -16,6 +16,7 @@ sled = "0.34.6"
 serde = { version = "1.0.126", features = ["derive"] }
 bincode = "1.3.3"
 aes = "0.7.3"
+rand = "0.8.3"
 
 [build-dependencies]
 protoc-rust = "2.23.0"

+ 261 - 44
src/client.rs

@@ -1,68 +1,231 @@
 use std::sync::Arc;
 
 use tokio::sync::mpsc;
-use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+use tokio::sync::mpsc::{Receiver, Sender};
 use tokio::task::JoinHandle;
 
-use crate::connection::{Connection, ControlChannelWriter};
-use crate::db::{Db, User};
-use crate::proto::mumble::{Ping, UserRemove, UserState};
-use crate::protocol::{AudioData, MumblePacket, VoicePacket};
 use crate::client::Error::StreamError;
+use crate::connection::{AudioChannel, AudioChannelSender, ControlChannel, ControlChannelSender};
+use crate::db::{Db, User};
+use crate::proto::mumble::{
+    ChannelState, CodecVersion, CryptSetup, Ping, ServerConfig, ServerSync, UserRemove, UserState,
+    Version,
+};
+use crate::protocol::{AudioData, AudioPacket, MumblePacket, MUMBLE_PROTOCOL_VERSION};
 
 pub struct Client {
     pub session_id: u32,
-    inner_sender: UnboundedSender<InnerMessage>,
+    inner_sender: Sender<InnerMessage>,
     handler_task: JoinHandle<()>,
     packet_task: JoinHandle<()>,
+    audio_task: Option<JoinHandle<()>>,
+}
+
+pub struct Config {
+    pub crypto_key: [u8; 16],
+    pub server_nonce: [u8; 16],
+    pub client_nonce: [u8; 16],
+    pub alpha_codec_version: i32,
+    pub beta_codec_version: i32,
+    pub prefer_alpha: bool,
+    pub opus_support: bool,
+    pub welcome_text: String,
+    pub max_bandwidth: u32,
+    pub max_users: u32,
+    pub allow_html: bool,
+    pub max_message_length: u32,
+    pub max_image_message_length: u32,
 }
 
+// from other connected users
 pub enum Message {
     UserConnected(u32),
     UserDisconnected(u32),
     UserTalking(AudioData),
 }
 
+// to other connected users
 pub enum ResponseMessage {
     Disconnected,
     Talking(AudioData),
 }
 
 pub enum Error {
+    AuthenticationError,
     StreamError,
+    WrongPacket,
 }
 
 struct Handler {
-    db: Arc<Db>,
-    writer: ControlChannelWriter,
     session_id: u32,
-    response_sender: UnboundedSender<ResponseMessage>,
+    db: Arc<Db>,
+    control_channel_sender: ControlChannelSender,
+    audio_channel_sender: Option<AudioChannelSender>,
+    response_sender: Sender<ResponseMessage>,
+    is_audio_tunneling: bool,
 }
 
 enum InnerMessage {
     Message(Message),
-    Packet(MumblePacket),
-    Disconnected,
+    Packet(Box<MumblePacket>),
+    Audio(AudioPacket),
+    AudioChannel(AudioChannelSender),
+    SelfDisconnected,
 }
 
-type ResponseReceiver = UnboundedReceiver<ResponseMessage>;
+type Responder = Receiver<ResponseMessage>;
 
 impl Client {
-    pub async fn new(connection: Connection, db: Arc<Db>) -> (Client, ResponseReceiver) {
-        let (sender, mut receiver) = mpsc::unbounded_channel();
-        let (response_sender, response_receiver) = mpsc::unbounded_channel();
+    pub async fn establish_connection(
+        db: Arc<Db>,
+        mut control_channel: ControlChannel,
+        config: Config,
+    ) -> Result<(Self, Responder), Error> {
+        match control_channel.receive().await? {
+            MumblePacket::Version(version) => version,
+            _ => return Err(Error::WrongPacket),
+        };
+        let mut auth = match control_channel.receive().await? {
+            MumblePacket::Authenticate(auth) => auth,
+            _ => return Err(Error::WrongPacket),
+        };
+        if !auth.has_username() {
+            return Err(Error::AuthenticationError);
+        }
+        let session_id = db.add_new_user(auth.take_username()).await;
+
+        let version = {
+            let mut version = Version::new();
+            version.set_version(MUMBLE_PROTOCOL_VERSION);
+            MumblePacket::Version(version)
+        };
+        let crypt_setup = {
+            let key = config.crypto_key;
+            let server_nonce = config.server_nonce;
+            let client_nonce = config.client_nonce;
+            let mut crypt_setup = CryptSetup::new();
+            crypt_setup.set_key(Vec::from(key));
+            crypt_setup.set_server_nonce(Vec::from(server_nonce));
+            crypt_setup.set_client_nonce(Vec::from(client_nonce));
+            MumblePacket::CryptSetup(crypt_setup)
+        };
+        let codec_version = {
+            let mut codec_version = CodecVersion::new();
+            codec_version.set_alpha(config.alpha_codec_version);
+            codec_version.set_beta(config.beta_codec_version);
+            codec_version.set_prefer_alpha(config.prefer_alpha);
+            codec_version.set_opus(config.opus_support);
+            MumblePacket::CodecVersion(codec_version)
+        };
+        let channel_states: Vec<MumblePacket> = {
+            db.get_channels()
+                .await
+                .into_iter()
+                .map(|channel| {
+                    let mut channel_state = ChannelState::new();
+                    channel_state.set_channel_id(channel.id);
+                    channel_state.set_name(channel.name);
+                    MumblePacket::ChannelState(channel_state)
+                })
+                .collect()
+        };
+        let user_states: Vec<MumblePacket> = {
+            db.get_connected_users()
+                .await
+                .into_iter()
+                .map(|user| {
+                    let mut user_state = UserState::new();
+                    user_state.set_name(user.username);
+                    user_state.set_session(user.session_id);
+                    user_state.set_channel_id(user.channel_id);
+                    MumblePacket::UserState(user_state)
+                })
+                .collect()
+        };
+        let server_sync = {
+            let mut server_sync = ServerSync::new();
+            server_sync.set_session(session_id);
+            server_sync.set_welcome_text(config.welcome_text);
+            server_sync.set_max_bandwidth(config.max_bandwidth);
+            MumblePacket::ServerSync(server_sync)
+        };
+        let server_config = {
+            let mut server_config = ServerConfig::new();
+            server_config.set_max_users(config.max_users);
+            server_config.set_allow_html(config.allow_html);
+            server_config.set_message_length(config.max_message_length);
+            server_config.set_image_message_length(config.max_image_message_length);
+            MumblePacket::ServerConfig(server_config)
+        };
+
+        control_channel.send(version).await?;
+        control_channel.send(crypt_setup).await?;
+        control_channel.send(codec_version).await?;
+        for channel_state in channel_states {
+            control_channel.send(channel_state).await?;
+        }
+        for user_state in user_states {
+            control_channel.send(user_state).await?;
+        }
+        control_channel.send(server_sync).await?;
+        control_channel.send(server_config).await?;
+
+        let (client, response_receiver) = Client::new(control_channel, db, session_id).await;
+        Ok((client, response_receiver))
+    }
 
-        let (mut reader, writer) = connection.control_channel.split();
-        let session_id = connection.session_id;
+    pub async fn set_audio_channel(&mut self, audio_channel: AudioChannel) {
+        let (mut receiver, sender) = audio_channel.split();
+        let inner_sender = self.inner_sender.clone();
+        self.audio_task = Some(tokio::spawn(async move {
+            loop {
+                match receiver.receive().await {
+                    Ok(packet) => {
+                        if inner_sender.try_send(InnerMessage::Audio(packet)).is_err() {
+                            return;
+                        }
+                    }
+                    Err(_) => return,
+                }
+            }
+        }));
+
+        self.inner_sender
+            .send(InnerMessage::AudioChannel(sender))
+            .await;
+    }
+
+    pub async fn send_message(&self, message: Message) {
+        match message {
+            Message::UserTalking(_) => {
+                self.inner_sender.try_send(InnerMessage::Message(message));
+            }
+            _ => {
+                self.inner_sender.send(InnerMessage::Message(message)).await;
+            }
+        }
+    }
+
+    async fn new(
+        control_channel: ControlChannel,
+        db: Arc<Db>,
+        session_id: u32,
+    ) -> (Client, Responder) {
+        let (inner_sender, mut inner_receiver) = mpsc::channel(2);
+        let (response_sender, response_receiver) = mpsc::channel(2);
+
+        let (mut control_channel_receiver, control_channel_sender) = control_channel.split();
         let handler_task = tokio::spawn(async move {
             let mut handler = Handler {
-                db,
-                writer,
                 session_id,
+                db,
+                control_channel_sender,
+                audio_channel_sender: None,
                 response_sender,
+                is_audio_tunneling: false,
             };
             loop {
-                let message = match receiver.recv().await {
+                let message = match inner_receiver.recv().await {
                     Some(msg) => msg,
                     None => return,
                 };
@@ -75,26 +238,35 @@ impl Client {
                         }
                     }
                     InnerMessage::Packet(packet) => {
-                        let result = handler.handle_packet(packet).await;
+                        let result = handler.handle_mumble_packet(*packet).await;
                         if result.is_err() {
                             return;
                         }
                     }
-                    InnerMessage::Disconnected => {
+                    InnerMessage::SelfDisconnected => {
                         handler.self_disconnected().await;
                         return;
                     }
+                    InnerMessage::Audio(audio) => {
+                        let result = handler.handle_audio_packet(audio).await;
+                        if result.is_err() {
+                            return;
+                        }
+                    }
+                    InnerMessage::AudioChannel(sender) => {
+                        handler.audio_channel_sender = Some(sender)
+                    }
                 }
             }
         });
 
-        let inner_sender = sender.clone();
+        let sender = inner_sender.clone();
         let packet_task = tokio::spawn(async move {
             loop {
-                match reader.read().await {
-                    Ok(packet) => sender.send(InnerMessage::Packet(packet)),
+                match control_channel_receiver.receive().await {
+                    Ok(packet) => sender.send(InnerMessage::Packet(Box::from(packet))).await,
                     Err(_) => {
-                        sender.send(InnerMessage::Disconnected);
+                        sender.send(InnerMessage::SelfDisconnected).await;
                         return;
                     }
                 };
@@ -107,41 +279,45 @@ impl Client {
                 inner_sender,
                 handler_task,
                 packet_task,
+                audio_task: None,
             },
             response_receiver,
         );
     }
-
-    pub fn post_message(&self, message: Message) {
-        self.inner_sender.send(InnerMessage::Message(message));
-    }
 }
 
 impl Drop for Client {
     fn drop(&mut self) {
         self.handler_task.abort();
         self.packet_task.abort();
+        if let Some(audio_task) = self.audio_task.as_ref() {
+            audio_task.abort();
+        }
     }
 }
 
 impl Handler {
-    async fn handle_packet(&mut self, packet: MumblePacket) -> Result<(), Error> {
+    async fn handle_mumble_packet(&mut self, packet: MumblePacket) -> Result<(), Error> {
         match packet {
             MumblePacket::Ping(ping) => {
                 if ping.has_timestamp() {
                     let mut ping = Ping::new();
                     ping.set_timestamp(ping.get_timestamp());
-                    self.writer.write(MumblePacket::Ping(ping)).await?;
+                    self.control_channel_sender
+                        .send(MumblePacket::Ping(ping))
+                        .await?;
                 }
             }
             MumblePacket::UdpTunnel(voice) => match voice {
-                VoicePacket::Ping(_) => {
-                    self.writer.write(MumblePacket::UdpTunnel(voice)).await;
+                AudioPacket::Ping(_) => {
+                    self.control_channel_sender
+                        .send(MumblePacket::UdpTunnel(voice))
+                        .await?;
                 }
-                VoicePacket::AudioData(mut audio_data) => {
+                AudioPacket::AudioData(mut audio_data) => {
                     audio_data.session_id = Some(self.session_id);
                     self.response_sender
-                        .send(ResponseMessage::Talking(audio_data));
+                        .try_send(ResponseMessage::Talking(audio_data));
                 }
             },
             _ => println!("unimplemented!"),
@@ -159,9 +335,37 @@ impl Handler {
         Ok(())
     }
 
+    async fn handle_audio_packet(&mut self, packet: AudioPacket) -> Result<(), Error> {
+        match packet {
+            AudioPacket::Ping(_) => {
+                if !self.is_audio_tunneling && self.audio_channel_sender.is_some() {
+                    self.audio_channel_sender
+                        .as_mut()
+                        .unwrap()
+                        .send(packet)
+                        .await?;
+                } else {
+                    self.control_channel_sender
+                        .send(MumblePacket::UdpTunnel(packet))
+                        .await?;
+                }
+            }
+            AudioPacket::AudioData(mut audio_data) => {
+                audio_data.session_id = Some(self.session_id);
+                // It isn't critical to lost some audio packets
+                self.response_sender
+                    .try_send(ResponseMessage::Talking(audio_data));
+            }
+        }
+
+        Ok(())
+    }
+
     async fn new_user_connected(&mut self, session_id: u32) -> Result<(), Error> {
         if let Some(user) = self.db.get_user_by_session_id(session_id).await {
-            self.writer.write(MumblePacket::from(user)).await?;
+            self.control_channel_sender
+                .send(MumblePacket::from(user))
+                .await?;
         }
         Ok(())
     }
@@ -170,21 +374,34 @@ impl Handler {
         let mut user_remove = UserRemove::new();
         user_remove.set_session(session_id);
         Ok(self
-            .writer
-            .write(MumblePacket::UserRemove(user_remove))
+            .control_channel_sender
+            .send(MumblePacket::UserRemove(user_remove))
             .await?)
     }
 
     async fn self_disconnected(&mut self) {
         self.db.remove_connected_user(self.session_id).await;
-        self.response_sender.send(ResponseMessage::Disconnected);
+        self.response_sender
+            .send(ResponseMessage::Disconnected)
+            .await;
     }
 
     async fn user_talking(&mut self, audio_data: AudioData) -> Result<(), Error> {
-        Ok(self
-            .writer
-            .write(MumblePacket::UdpTunnel(VoicePacket::AudioData(audio_data)))
-            .await?)
+        let audio_packet = AudioPacket::AudioData(audio_data);
+
+        if !self.is_audio_tunneling && self.audio_channel_sender.is_some() {
+            self.audio_channel_sender
+                .as_mut()
+                .unwrap()
+                .send(audio_packet)
+                .await?;
+        } else {
+            self.control_channel_sender
+                .send(MumblePacket::UdpTunnel(audio_packet))
+                .await?;
+        }
+
+        Ok(())
     }
 }
 

+ 107 - 142
src/connection.rs

@@ -1,170 +1,102 @@
 use std::sync::Arc;
 
-use crate::db::Db;
-use crate::proto::mumble::{
-    ChannelState, CodecVersion, CryptSetup, PermissionQuery, ServerConfig, ServerSync, UserState,
-    Version,
-};
-use crate::protocol::{MumblePacket, MUMBLE_PROTOCOL_VERSION};
-use rand::rngs::StdRng;
-use rand::{Rng, SeedableRng};
+use crate::protocol::{AudioPacket, MumblePacket};
 
+use crate::crypto::Ocb2Aes128Crypto;
+
+use std::net::SocketAddr;
 use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
-use tokio::net::TcpStream;
+use tokio::net::{TcpStream, UdpSocket};
+use tokio::sync::Mutex;
+
+use tokio::sync::mpsc::Receiver;
 use tokio_rustls::TlsStream;
 
-pub struct Connection {
-    pub control_channel: ControlChannel,
-    pub session_id: u32,
+pub struct ControlChannel {
+    receiver: ControlChannelReceiver,
+    sender: ControlChannelSender,
 }
 
-pub struct ControlChannel {
-    reader: ControlChannelReader,
-    writer: ControlChannelWriter,
+pub struct AudioChannel {
+    receiver: AudioChannelReceiver,
+    sender: AudioChannelSender,
 }
 
-pub struct ControlChannelReader {
+pub struct ControlChannelReceiver {
     reader: ReadHalf<TlsStream<TcpStream>>,
 }
 
-pub struct ControlChannelWriter {
+pub struct ControlChannelSender {
     writer: WriteHalf<TlsStream<TcpStream>>,
 }
 
-pub struct ConnectionConfig {
-    pub max_bandwidth: u32,
-    pub welcome_text: String,
+pub struct AudioChannelReceiver {
+    raw_bytes_receiver: Receiver<Vec<u8>>,
+    crypto: Arc<Mutex<Ocb2Aes128Crypto>>,
+}
+
+pub struct AudioChannelSender {
+    socket: Arc<UdpSocket>,
+    crypto: Arc<Mutex<Ocb2Aes128Crypto>>,
+    destination: SocketAddr,
 }
 
 pub enum Error {
-    ConnectionSetupError,
-    AuthenticationError,
-    StreamError,
-}
-
-impl Connection {
-    pub async fn setup_connection(
-        db: Arc<Db>,
-        stream: TlsStream<TcpStream>,
-        config: ConnectionConfig,
-    ) -> Result<Connection, Error> {
-        let mut control_channel = ControlChannel::new(stream);
-
-        //Version exchange
-        let _ = match control_channel.read().await? {
-            MumblePacket::Version(version) => version,
-            _ => return Err(Error::ConnectionSetupError),
-        };
-        let mut version = Version::new();
-        version.set_version(MUMBLE_PROTOCOL_VERSION);
-        control_channel
-            .write(MumblePacket::Version(version))
-            .await?;
-
-        //Authentication
-        let mut auth = match control_channel.read().await? {
-            MumblePacket::Authenticate(auth) => auth,
-            _ => return Err(Error::ConnectionSetupError),
-        };
-        if !auth.has_username() {
-            return Err(Error::AuthenticationError);
-        }
-        let session_id = db.add_new_user(auth.take_username()).await;
-
-        //Crypt setup
-
-        //CodecVersion
-        let mut codec_version = CodecVersion::new();
-        codec_version.set_alpha(-2147483632);
-        codec_version.set_beta(0);
-        codec_version.set_prefer_alpha(true);
-        codec_version.set_opus(true);
-        control_channel
-            .write(MumblePacket::CodecVersion(codec_version))
-            .await?;
-
-        //Channel state
-        let channels = db.get_channels().await;
-        for channel in channels {
-            let mut channel_state = ChannelState::new();
-            channel_state.set_channel_id(channel.id);
-            channel_state.set_name(channel.name);
-            control_channel
-                .write(MumblePacket::ChannelState(channel_state))
-                .await?;
-        }
+    IOError(std::io::Error),
+    ParsingError(crate::protocol::Error),
+    CryptError(crate::crypto::Error),
+}
 
-        //PermissionQuery
-        let mut permission_query = PermissionQuery::new();
-        permission_query.set_permissions(134743822);
-        permission_query.set_channel_id(0);
-        control_channel
-            .write(MumblePacket::PermissionQuery(permission_query))
-            .await?;
-
-        //User states
-        let connected_users = db.get_connected_users().await;
-        for user in connected_users {
-            let mut user_state = UserState::new();
-            user_state.set_name(user.username);
-            user_state.set_session(user.session_id);
-            user_state.set_channel_id(user.channel_id);
-            control_channel
-                .write(MumblePacket::UserState(user_state))
-                .await?;
-        }
+impl ControlChannel {
+    pub fn new(stream: TlsStream<TcpStream>) -> Self {
+        let (reader, writer) = tokio::io::split(stream);
+        let receiver = ControlChannelReceiver { reader };
+        let sender = ControlChannelSender { writer };
 
-        //Server sync
-        let mut server_sync = ServerSync::new();
-        server_sync.set_session(session_id);
-        server_sync.set_welcome_text(config.welcome_text);
-        server_sync.set_max_bandwidth(config.max_bandwidth);
-        server_sync.set_permissions(134743822);
-        control_channel
-            .write(MumblePacket::ServerSync(server_sync))
-            .await?;
-
-        //ServerConfig
-        let mut server_config = ServerConfig::new();
-        server_config.set_max_users(10);
-        server_config.set_allow_html(true);
-        server_config.set_message_length(5000);
-        server_config.set_image_message_length(131072);
-        control_channel
-            .write(MumblePacket::ServerConfig(server_config))
-            .await?;
-
-        Ok(Connection {
-            control_channel,
-            session_id,
-        })
+        ControlChannel { receiver, sender }
     }
-}
 
-impl ControlChannel {
-    pub async fn read(&mut self) -> Result<MumblePacket, Error> {
-        self.reader.read().await
+    pub async fn receive(&mut self) -> Result<MumblePacket, Error> {
+        self.receiver.receive().await
     }
 
-    pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
-        self.writer.write(packet).await
+    pub async fn send(&mut self, packet: MumblePacket) -> Result<(), Error> {
+        self.sender.send(packet).await
     }
 
-    pub fn split(self) -> (ControlChannelReader, ControlChannelWriter) {
-        (self.reader, self.writer)
+    pub fn split(self) -> (ControlChannelReceiver, ControlChannelSender) {
+        (self.receiver, self.sender)
     }
+}
 
-    fn new(stream: TlsStream<TcpStream>) -> Self {
-        let (reader, writer) = tokio::io::split(stream);
-        ControlChannel {
-            reader: ControlChannelReader { reader },
-            writer: ControlChannelWriter { writer },
-        }
+impl AudioChannel {
+    pub fn new(
+        incoming_bytes_receiver: Receiver<Vec<u8>>,
+        socket: Arc<UdpSocket>,
+        crypto: Ocb2Aes128Crypto,
+        destination: SocketAddr,
+    ) -> Self {
+        let crypto = Arc::new(Mutex::new(crypto));
+        let receiver = AudioChannelReceiver {
+            raw_bytes_receiver: incoming_bytes_receiver,
+            crypto: Arc::clone(&crypto),
+        };
+        let sender = AudioChannelSender {
+            socket,
+            crypto: Arc::clone(&crypto),
+            destination,
+        };
+
+        AudioChannel { receiver, sender }
+    }
+
+    pub fn split(self) -> (AudioChannelReceiver, AudioChannelSender) {
+        (self.receiver, self.sender)
     }
 }
 
-impl ControlChannelReader {
-    pub async fn read(&mut self) -> Result<MumblePacket, Error> {
+impl ControlChannelReceiver {
+    pub async fn receive(&mut self) -> Result<MumblePacket, Error> {
         let mut packet_type = [0; 2];
         let mut length = [0; 4];
         self.reader.read_exact(&mut packet_type).await?;
@@ -177,8 +109,8 @@ impl ControlChannelReader {
     }
 }
 
-impl ControlChannelWriter {
-    pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
+impl ControlChannelSender {
+    pub async fn send(&mut self, packet: MumblePacket) -> Result<(), Error> {
         let bytes = packet.serialize();
         self.writer.write_all(&bytes).await?;
         self.writer.flush().await?;
@@ -186,14 +118,47 @@ impl ControlChannelWriter {
     }
 }
 
-impl From<crate::protocol::Error> for Error {
-    fn from(_: crate::protocol::Error) -> Self {
-        Error::StreamError
+impl AudioChannelSender {
+    pub async fn send(&mut self, packet: AudioPacket) -> Result<(), Error> {
+        let bytes = packet.serialize();
+        let encrypted = {
+            let mut crypto = self.crypto.lock().await;
+            crypto.encrypt(&bytes)?
+        };
+        self.socket.send_to(&encrypted, self.destination).await?;
+        Ok(())
+    }
+}
+
+impl AudioChannelReceiver {
+    pub async fn receive(&mut self) -> Result<AudioPacket, Error> {
+        match self.raw_bytes_receiver.recv().await {
+            Some(bytes) => {
+                let decrypted = {
+                    let mut crypto = self.crypto.lock().await;
+                    crypto.decrypt(&bytes)?
+                };
+                Ok(AudioPacket::parse(decrypted)?)
+            }
+            None => unimplemented!(),
+        }
     }
 }
 
 impl From<std::io::Error> for Error {
-    fn from(_: std::io::Error) -> Self {
-        Error::StreamError
+    fn from(error: std::io::Error) -> Self {
+        Error::IOError(error)
+    }
+}
+
+impl From<crate::protocol::Error> for Error {
+    fn from(error: crate::protocol::Error) -> Self {
+        Error::ParsingError(error)
+    }
+}
+
+impl From<crate::crypto::Error> for Error {
+    fn from(error: crate::crypto::Error) -> Self {
+        Error::CryptError(error)
     }
 }

+ 17 - 17
src/crypto.rs

@@ -10,11 +10,7 @@ type Key = [u8; 16];
 type Nonce = [u8; 16];
 type Tag = [u8; 16];
 
-enum Error {
-    Fail,
-}
-
-struct CryptState {
+pub struct Ocb2Aes128Crypto {
     cipher: Aes128,
     encrypt_iv: Nonce,
     decrypt_iv: Nonce,
@@ -24,11 +20,15 @@ struct CryptState {
     lost: u32,
 }
 
+pub enum Error {
+    Fail,
+}
+
 // Based on the official Mumble project CryptState implementation
 // TODO refactor this mess
-impl CryptState {
-    pub fn new(key: Key, encrypt_iv: Nonce, decrypt_iv: Nonce) -> CryptState {
-        CryptState {
+impl Ocb2Aes128Crypto {
+    pub fn new(key: Key, encrypt_iv: Nonce, decrypt_iv: Nonce) -> Ocb2Aes128Crypto {
+        Ocb2Aes128Crypto {
             cipher: Aes128::new(&GenericArray::from(key)),
             encrypt_iv,
             decrypt_iv,
@@ -346,7 +346,7 @@ fn swapped(value: u8) -> u8 {
 
 #[cfg(test)]
 mod tests {
-    use crate::crypto::{CryptState, AES_BLOCK_SIZE};
+    use crate::crypto::{Ocb2Aes128Crypto, AES_BLOCK_SIZE};
 
     #[test]
     fn test_reverse_recovery() {
@@ -359,8 +359,8 @@ mod tests {
             0x9d, 0xb0, 0xcd, 0xf8, 0x80, 0xf7, 0x3e, 0x3e, 0x10, 0xd4, 0xeb, 0x32, 0x17, 0x76,
             0x66, 0x88,
         ];
-        let mut encryption = CryptState::new(key, encrypt_iv, decrypt_iv);
-        let mut decryption = CryptState::new(key, decrypt_iv, encrypt_iv);
+        let mut encryption = Ocb2Aes128Crypto::new(key, encrypt_iv, decrypt_iv);
+        let mut decryption = Ocb2Aes128Crypto::new(key, decrypt_iv, encrypt_iv);
         let secret = b"MyVerySecret".to_vec();
         let mut encrypted = vec![vec![]; 512];
 
@@ -399,8 +399,8 @@ mod tests {
             0x9d, 0xb0, 0xcd, 0xf8, 0x80, 0xf7, 0x3e, 0x3e, 0x10, 0xd4, 0xeb, 0x32, 0x17, 0x76,
             0x66, 0x88,
         ];
-        let mut encryption = CryptState::new(key, encrypt_iv, decrypt_iv);
-        let mut decryption = CryptState::new(key, decrypt_iv, encrypt_iv);
+        let mut encryption = Ocb2Aes128Crypto::new(key, encrypt_iv, decrypt_iv);
+        let mut decryption = Ocb2Aes128Crypto::new(key, decrypt_iv, encrypt_iv);
         let secret = b"MyVerySecret".to_vec();
 
         let mut encrypted = encryption.encrypt(&secret).ok().unwrap();
@@ -442,7 +442,7 @@ mod tests {
             0x0e, 0x0f,
         ];
         let mut tag = [0; 16];
-        let crypt_state = CryptState::new(key, [0; 16], [0; 16]);
+        let crypt_state = Ocb2Aes128Crypto::new(key, [0; 16], [0; 16]);
 
         assert!(crypt_state.ocb_encrypt(&source, &mut destination, key, &mut tag, true));
 
@@ -482,7 +482,7 @@ mod tests {
             0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22,
             0x11, 0x00,
         ];
-        let crypt_state = CryptState::new(key, [0; 16], [0; 16]);
+        let crypt_state = Ocb2Aes128Crypto::new(key, [0; 16], [0; 16]);
 
         for len in 0..128 {
             let mut src = Vec::with_capacity(len);
@@ -523,7 +523,7 @@ mod tests {
             0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22,
             0x11, 0x00,
         ];
-        let crypt = CryptState::new(key, nonce, nonce);
+        let crypt = Ocb2Aes128Crypto::new(key, nonce, nonce);
         let mut src = [0; AES_BLOCK_SIZE * 2];
         src[AES_BLOCK_SIZE - 1] = (AES_BLOCK_SIZE * 8) as u8;
         src.split_at_mut(AES_BLOCK_SIZE).1.fill(42);
@@ -562,7 +562,7 @@ mod tests {
             0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22,
             0x11, 0x00,
         ];
-        let mut crypt = CryptState::new(key, nonce, nonce);
+        let mut crypt = Ocb2Aes128Crypto::new(key, nonce, nonce);
         let message = b"It was a funky funky town!";
         let mut encrypted = crypt.encrypt(message).ok().unwrap();
 

+ 3 - 1
src/main.rs

@@ -1,6 +1,7 @@
 use std::fs::File;
 use std::io::BufReader;
 
+use crate::server::Server;
 use clap::{App, Arg};
 use tokio::runtime::Builder;
 use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey};
@@ -64,9 +65,10 @@ fn main() {
         path_to_db_file: path,
     };
 
+    let server = Server::new(config);
     let tokio_rt = Builder::new_multi_thread().enable_all().build().unwrap();
     tokio_rt.block_on(async {
-        server::run(config).await.unwrap();
+        server.run().await;
     });
 }
 

+ 13 - 11
src/protocol.rs

@@ -8,6 +8,7 @@ use crate::proto::mumble::{
 };
 
 pub const MUMBLE_PROTOCOL_VERSION: u32 = 0b0000_0001_0011_0100;
+pub const MAX_AUDIO_PACKET_SIZE: usize = 1020;
 
 const VERSION: u16 = 0;
 const UDP_TUNNEL: u16 = 1;
@@ -41,7 +42,7 @@ const LENGTH_SIZE: usize = 4;
 
 pub enum MumblePacket {
     Version(Version),
-    UdpTunnel(VoicePacket),
+    UdpTunnel(AudioPacket),
     Authenticate(Authenticate),
     Ping(Ping),
     Reject(Reject),
@@ -68,8 +69,8 @@ pub enum MumblePacket {
     SuggestConfig(SuggestConfig),
 }
 
-pub enum VoicePacket {
-    Ping(VoicePing),
+pub enum AudioPacket {
+    Ping(AudioPing),
     AudioData(AudioData),
 }
 
@@ -78,7 +79,7 @@ pub enum Error {
     ParsingError,
 }
 
-pub struct VoicePing {
+pub struct AudioPing {
     bytes: Vec<u8>,
 }
 
@@ -96,7 +97,7 @@ impl MumblePacket {
     pub fn parse_payload(packet_type: u16, payload: &[u8]) -> Result<Self, Error> {
         match packet_type {
             VERSION => Ok(MumblePacket::Version(Version::parse_from_bytes(payload)?)),
-            UDP_TUNNEL => Ok(MumblePacket::UdpTunnel(VoicePacket::parse(
+            UDP_TUNNEL => Ok(MumblePacket::UdpTunnel(AudioPacket::parse(
                 payload.to_vec(),
             )?)),
             AUTHENTICATE => Ok(MumblePacket::Authenticate(Authenticate::parse_from_bytes(
@@ -246,19 +247,19 @@ impl MumblePacket {
     }
 }
 
-impl VoicePacket {
+impl AudioPacket {
     pub fn parse(bytes: Vec<u8>) -> Result<Self, Error> {
-        if bytes.is_empty() {
+        if bytes.is_empty() || bytes.len() > MAX_AUDIO_PACKET_SIZE {
             return Err(Error::ParsingError);
         }
 
         let header = bytes.first().unwrap();
         let (packet_type, _) = decode_header(*header);
         if packet_type == 1 {
-            return Ok(VoicePacket::Ping(VoicePing { bytes }));
+            return Ok(AudioPacket::Ping(AudioPing { bytes }));
         }
 
-        Ok(VoicePacket::AudioData(AudioData {
+        Ok(AudioPacket::AudioData(AudioData {
             session_id: None,
             bytes,
         }))
@@ -266,8 +267,8 @@ impl VoicePacket {
 
     pub fn serialize(self) -> Vec<u8> {
         match self {
-            VoicePacket::Ping(ping) => ping.bytes,
-            VoicePacket::AudioData(audio_data) => {
+            AudioPacket::Ping(ping) => ping.bytes,
+            AudioPacket::AudioData(audio_data) => {
                 if let Some(session_id) = audio_data.session_id {
                     let mut bytes = audio_data.bytes;
                     let varint = encode_varint(session_id as u64);
@@ -338,6 +339,7 @@ impl From<ProtobufError> for Error {
     }
 }
 
+//TODO write more tests
 #[cfg(test)]
 mod tests {
     use super::*;

+ 237 - 60
src/server.rs

@@ -2,14 +2,22 @@ use std::collections::HashMap;
 use std::net::{IpAddr, SocketAddr};
 use std::sync::Arc;
 
-use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::RwLock;
+use tokio::net::{TcpListener, TcpStream, UdpSocket};
+use tokio::sync::{mpsc, Mutex, RwLock};
 use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
 use tokio_rustls::{TlsAcceptor, TlsStream};
 
 use crate::client::{Client, Message, ResponseMessage};
-use crate::connection::{Connection, ConnectionConfig};
+use crate::connection::{AudioChannel, ControlChannel};
+use crate::crypto::Ocb2Aes128Crypto;
 use crate::db::Db;
+use crate::protocol::AudioData;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+
+use tokio::sync::mpsc::{Receiver, Sender};
+
+pub const MAX_UDP_DATAGRAM_SIZE: usize = 1024;
 
 pub struct Config {
     pub ip_address: IpAddr,
@@ -19,83 +27,252 @@ pub struct Config {
     pub path_to_db_file: String,
 }
 
-type Clients = Arc<RwLock<HashMap<u32, Client>>>;
+pub struct Server {
+    config: Config,
+    db: Arc<Db>,
+    clients: RwLock<HashMap<SessionId, Client>>,
+    waiting_for_audio_channel: Mutex<Vec<(SessionId, IpAddr, Ocb2Aes128Crypto)>>,
+    address_to_channel: RwLock<HashMap<SocketAddr, Sender<Vec<u8>>>>,
+}
 
-pub async fn run(config: Config) -> std::io::Result<()> {
-    let db = Arc::new(Db::open(&config.path_to_db_file));
+type SessionId = u32;
 
-    let mut tls_config = ServerConfig::new(NoClientAuth::new());
-    tls_config
-        .set_single_cert(vec![config.certificate], config.private_key)
-        .expect("Invalid private key");
+impl Server {
+    pub fn new(config: Config) -> Arc<Self> {
+        let path_to_db_file = config.path_to_db_file.clone();
 
-    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
-    let listener = TcpListener::bind(SocketAddr::new(config.ip_address, config.port)).await?;
+        Arc::new(Server {
+            config,
+            clients: RwLock::new(HashMap::new()),
+            db: Arc::new(Db::open(&path_to_db_file)),
+            waiting_for_audio_channel: Mutex::new(vec![]),
+            address_to_channel: RwLock::new(HashMap::new()),
+        })
+    }
 
-    let clients = Arc::new(RwLock::new(HashMap::new()));
-    loop {
-        let (stream, _) = listener.accept().await?;
-        let acceptor = acceptor.clone();
-        let db = Arc::clone(&db);
-        let clients = Arc::clone(&clients);
+    pub async fn run(self: Arc<Self>) {
+        let mut tls_config = ServerConfig::new(NoClientAuth::new());
+        tls_config
+            .set_single_cert(
+                vec![self.config.certificate.clone()],
+                self.config.private_key.clone(),
+            )
+            .expect("Invalid private key");
+
+        let socket_address = SocketAddr::new(self.config.ip_address, self.config.port);
+        let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
+        let tcp_listener = TcpListener::bind(socket_address).await.unwrap();
+        let udp_socket = UdpSocket::bind(socket_address).await.unwrap();
+
+        Arc::clone(&self).run_udp_task(udp_socket).await;
+        Arc::clone(&self)
+            .listen_for_new_connections(tcp_listener, tls_acceptor)
+            .await;
+    }
 
+    async fn run_udp_task(self: Arc<Self>, socket: UdpSocket) {
+        let socket = Arc::new(socket);
         tokio::spawn(async move {
-            let stream = acceptor.accept(stream).await;
-            if let Ok(stream) = stream {
-                process(db, TlsStream::from(stream), clients).await;
+            let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
+            loop {
+                if let Ok((len, socket_address)) = socket.recv_from(&mut buf).await {
+                    if !Arc::clone(&self)
+                        .send_to_audio_channel(&buf[..len], &socket_address)
+                        .await
+                    {
+                        // TODO Move to a separate task
+                        Arc::clone(&self)
+                            .match_address_to_channel(
+                                &buf[..len],
+                                socket_address,
+                                Arc::clone(&socket),
+                            )
+                            .await;
+                    }
+                }
             }
         });
     }
-}
 
-async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
-    let connection_config = ConnectionConfig {
-        max_bandwidth: 128000,
-        welcome_text: "Welcome!".to_string(),
-    };
-    let connection = match Connection::setup_connection(db.clone(), stream, connection_config).await
-    {
-        Ok(connection) => connection,
-        Err(_) => {
-            eprintln!("Error establishing a connection");
-            return;
+    async fn send_to_audio_channel(self: &Arc<Self>, buf: &[u8], address: &SocketAddr) -> bool {
+        let connected = self.address_to_channel.read().await;
+        if let Some(sender) = connected.get(address) {
+            sender.try_send(Vec::from(buf));
+            return true;
         }
-    };
-    let session_id = connection.session_id;
-    let (client, mut response_receiver) = Client::new(connection, db).await;
 
-    {
-        let mut clients = clients.write().await;
-        for client in clients.values() {
-            client.post_message(Message::UserConnected(session_id))
-        }
-        clients.insert(session_id, client);
+        false
     }
 
-    loop {
-        let message = match response_receiver.recv().await {
-            Some(msg) => msg,
+    async fn match_address_to_channel(
+        self: &Arc<Self>,
+        buf: &[u8],
+        address: SocketAddr,
+        udp_socket: Arc<UdpSocket>,
+    ) {
+        let mut waiting = self.waiting_for_audio_channel.lock().await;
+        let index = match waiting
+            .iter_mut()
+            .position(|(_, ip, crypto)| &address.ip() == ip && crypto.decrypt(buf).is_ok())
+        {
+            Some(index) => index,
             None => return,
         };
+        let (session_id, _, crypto) = waiting.remove(index);
+        drop(waiting);
 
-        match message {
-            ResponseMessage::Disconnected => {
-                let mut clients = clients.write().await;
-                clients.remove(&session_id);
-                for client in clients.values() {
-                    client.post_message(Message::UserDisconnected(session_id))
+        let (sender, receiver) = mpsc::channel(1);
+        let mut clients = self.clients.write().await;
+        if let Some(client) = clients.get_mut(&session_id) {
+            let audio_channel = AudioChannel::new(receiver, udp_socket, crypto, address);
+            client.set_audio_channel(audio_channel).await;
+        }
+        drop(clients);
+
+        let mut address_to_channel = self.address_to_channel.write().await;
+        address_to_channel.insert(address, sender);
+    }
+
+    async fn listen_for_new_connections(
+        self: Arc<Self>,
+        listener: TcpListener,
+        acceptor: TlsAcceptor,
+    ) {
+        loop {
+            let (stream, _) = match listener.accept().await {
+                Ok(stream) => stream,
+                Err(_) => continue,
+            };
+            let acceptor = acceptor.clone();
+            let server = Arc::clone(&self);
+
+            tokio::spawn(async move {
+                let stream = acceptor.accept(stream).await;
+                if let Ok(stream) = stream {
+                    server.process_new_connection(TlsStream::from(stream)).await;
                 }
-                return;
-            }
-            ResponseMessage::Talking(audio_data) => {
-                let clients = clients.read().await;
-                for client in clients
-                    .values()
-                    .filter(|client| client.session_id != session_id)
-                {
-                    client.post_message(Message::UserTalking(audio_data.clone()));
+            });
+        }
+    }
+
+    async fn process_new_connection(self: Arc<Self>, stream: TlsStream<TcpStream>) {
+        let (session_id, mut responder) = match self.new_client(stream).await {
+            Ok(id) => id,
+            Err(_) => unimplemented!(),
+        };
+
+        loop {
+            let message = match responder.recv().await {
+                Some(msg) => msg,
+                None => return,
+            };
+
+            match message {
+                ResponseMessage::Disconnected => {
+                    self.client_disconnected(session_id).await;
+                    return;
+                }
+                ResponseMessage::Talking(audio_data) => {
+                    self.client_talking(session_id, audio_data).await;
                 }
             }
         }
     }
+
+    async fn client_disconnected(&self, session_id: SessionId) {
+        let mut clients = self.clients.write().await;
+        clients.remove(&session_id);
+        for client in clients.values() {
+            client
+                .send_message(Message::UserDisconnected(session_id))
+                .await;
+        }
+        drop(clients);
+
+        //TODO optimize
+        let mut waiting = self.waiting_for_audio_channel.lock().await;
+        if let Some(index) = waiting.iter().position(|(id, _, _)| session_id == *id) {
+            waiting.remove(index);
+        } else {
+            drop(waiting);
+
+            let mut address_to_channel = self.address_to_channel.write().await;
+            if let Some(key) = address_to_channel
+                .keys()
+                .find(|key| address_to_channel.get(key).unwrap().is_closed())
+                .cloned()
+            {
+                address_to_channel.remove(&key);
+            }
+        }
+    }
+
+    async fn client_talking(&self, session_id: SessionId, audio: AudioData) {
+        let clients = self.clients.read().await;
+        for client in clients
+            .values()
+            .filter(|client| client.session_id != session_id)
+        {
+            client
+                .send_message(Message::UserTalking(audio.clone()))
+                .await;
+        }
+    }
+
+    async fn new_client(
+        self: &Arc<Self>,
+        stream: TlsStream<TcpStream>,
+    ) -> Result<(SessionId, Receiver<ResponseMessage>), crate::client::Error> {
+        let ip = stream.get_ref().0.peer_addr().unwrap().ip();
+        let config = self.create_client_config();
+        let crypto =
+            Ocb2Aes128Crypto::new(config.crypto_key, config.server_nonce, config.client_nonce);
+        let (client, receiver) =
+            Client::establish_connection(Arc::clone(&self.db), ControlChannel::new(stream), config)
+                .await?;
+
+        let session_id = client.session_id;
+        let mut clients = self.clients.write().await;
+        for client in clients.values() {
+            client
+                .send_message(Message::UserConnected(session_id))
+                .await;
+        }
+        clients.insert(session_id, client);
+        drop(clients);
+
+        let mut waiting = self.waiting_for_audio_channel.lock().await;
+        waiting.push((session_id, ip, crypto));
+        drop(waiting);
+
+        Ok((session_id, receiver))
+    }
+
+    fn create_client_config(&self) -> crate::client::Config {
+        let crypto_key = self.generate_key();
+        let server_nonce = self.generate_key();
+        let client_nonce = self.generate_key();
+        crate::client::Config {
+            crypto_key,
+            server_nonce,
+            client_nonce,
+            alpha_codec_version: 0,
+            beta_codec_version: 0,
+            prefer_alpha: true,
+            opus_support: true,
+            welcome_text: "Welcome".to_string(),
+            max_bandwidth: 128000,
+            max_users: 10,
+            allow_html: true,
+            max_message_length: 512,
+            max_image_message_length: 100000,
+        }
+    }
+
+    fn generate_key(&self) -> [u8; 16] {
+        let mut buffer = [0; 16];
+        let mut rng = StdRng::from_entropy();
+        rng.fill(&mut buffer);
+        buffer
+    }
 }