Bläddra i källkod

Refactor MumblePacket

Sergey Chushin 3 år sedan
förälder
incheckning
7650a4bac7
4 ändrade filer med 222 tillägg och 271 borttagningar
  1. 24 25
      src/client.rs
  2. 99 31
      src/connection.rs
  3. 97 213
      src/protocol.rs
  4. 2 2
      src/server.rs

+ 24 - 25
src/client.rs

@@ -1,14 +1,21 @@
 use std::sync::Arc;
 
-use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::sync::mpsc;
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 
-use crate::connection::Connection;
+use crate::connection::{Connection, ControlChannelWriter};
 use crate::db::{Db, User};
 use crate::proto::mumble::{Ping, UserRemove, UserState};
-use crate::protocol::{AudioData, MumblePacket, MumblePacketWriter, VoicePacket};
+use crate::protocol::{AudioData, MumblePacket, VoicePacket};
+use crate::client::Error::StreamError;
+
+pub struct Client {
+    pub session_id: u32,
+    inner_sender: UnboundedSender<InnerMessage>,
+    handler_task: JoinHandle<()>,
+    packet_task: JoinHandle<()>,
+}
 
 pub enum Message {
     UserConnected(u32),
@@ -21,20 +28,13 @@ pub enum ResponseMessage {
     Talking(AudioData),
 }
 
-pub struct Client {
-    pub session_id: u32,
-    inner_sender: UnboundedSender<InnerMessage>,
-    handler_task: JoinHandle<()>,
-    packet_task: JoinHandle<()>,
-}
-
 pub enum Error {
-    StreamError(crate::protocol::Error),
+    StreamError,
 }
 
-struct Handler<W> {
+struct Handler {
     db: Arc<Db>,
-    writer: MumblePacketWriter<W>,
+    writer: ControlChannelWriter,
     session_id: u32,
     response_sender: UnboundedSender<ResponseMessage>,
 }
@@ -48,14 +48,11 @@ enum InnerMessage {
 type ResponseReceiver = UnboundedReceiver<ResponseMessage>;
 
 impl Client {
-    pub async fn new<S>(connection: Connection<S>, db: Arc<Db>) -> (Client, ResponseReceiver)
-    where
-        S: 'static + AsyncRead + AsyncWrite + Unpin + Send,
-    {
+    pub async fn new(connection: Connection, db: Arc<Db>) -> (Client, ResponseReceiver) {
         let (sender, mut receiver) = mpsc::unbounded_channel();
         let (response_sender, response_receiver) = mpsc::unbounded_channel();
 
-        let writer = connection.writer;
+        let (mut reader, writer) = connection.control_channel.split();
         let session_id = connection.session_id;
         let handler_task = tokio::spawn(async move {
             let mut handler = Handler {
@@ -92,7 +89,6 @@ impl Client {
         });
 
         let inner_sender = sender.clone();
-        let mut reader = connection.reader;
         let packet_task = tokio::spawn(async move {
             loop {
                 match reader.read().await {
@@ -128,10 +124,7 @@ impl Drop for Client {
     }
 }
 
-impl<W> Handler<W>
-where
-    W: AsyncWrite + Unpin + Send,
-{
+impl Handler {
     async fn handle_packet(&mut self, packet: MumblePacket) -> Result<(), Error> {
         match packet {
             MumblePacket::Ping(ping) => {
@@ -215,7 +208,13 @@ impl From<User> for MumblePacket {
 }
 
 impl From<crate::protocol::Error> for Error {
-    fn from(err: crate::protocol::Error) -> Self {
-        Error::StreamError(err)
+    fn from(_: crate::protocol::Error) -> Self {
+        StreamError
+    }
+}
+
+impl From<crate::connection::Error> for Error {
+    fn from(_: crate::connection::Error) -> Self {
+        StreamError
     }
 }

+ 99 - 31
src/connection.rs

@@ -1,21 +1,36 @@
 use std::sync::Arc;
 
-use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
-
 use crate::db::Db;
 use crate::proto::mumble::{
-    ChannelState, CodecVersion, PermissionQuery, ServerConfig, ServerSync, UserState, Version,
-};
-use crate::protocol::{
-    MumblePacket, MumblePacketReader, MumblePacketWriter, MUMBLE_PROTOCOL_VERSION,
+    ChannelState, CodecVersion, CryptSetup, PermissionQuery, ServerConfig, ServerSync, UserState,
+    Version,
 };
+use crate::protocol::{MumblePacket, MUMBLE_PROTOCOL_VERSION};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 
-pub struct Connection<S> {
-    pub reader: MumblePacketReader<ReadHalf<S>>,
-    pub writer: MumblePacketWriter<WriteHalf<S>>,
+use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
+use tokio::net::TcpStream;
+use tokio_rustls::TlsStream;
+
+pub struct Connection {
+    pub control_channel: ControlChannel,
     pub session_id: u32,
 }
 
+pub struct ControlChannel {
+    reader: ControlChannelReader,
+    writer: ControlChannelWriter,
+}
+
+pub struct ControlChannelReader {
+    reader: ReadHalf<TlsStream<TcpStream>>,
+}
+
+pub struct ControlChannelWriter {
+    writer: WriteHalf<TlsStream<TcpStream>>,
+}
+
 pub struct ConnectionConfig {
     pub max_bandwidth: u32,
     pub welcome_text: String,
@@ -24,31 +39,30 @@ pub struct ConnectionConfig {
 pub enum Error {
     ConnectionSetupError,
     AuthenticationError,
-    StreamError(crate::protocol::Error),
+    StreamError,
 }
 
-impl<S> Connection<S>
-where
-    S: AsyncRead + AsyncWrite + Unpin + Send,
-{
+impl Connection {
     pub async fn setup_connection(
         db: Arc<Db>,
-        stream: S,
+        stream: TlsStream<TcpStream>,
         config: ConnectionConfig,
-    ) -> Result<Connection<S>, Error> {
-        let (mut reader, mut writer) = crate::protocol::new(stream);
+    ) -> Result<Connection, Error> {
+        let mut control_channel = ControlChannel::new(stream);
 
         //Version exchange
-        let _ = match reader.read().await? {
+        let _ = match control_channel.read().await? {
             MumblePacket::Version(version) => version,
             _ => return Err(Error::ConnectionSetupError),
         };
         let mut version = Version::new();
         version.set_version(MUMBLE_PROTOCOL_VERSION);
-        writer.write(MumblePacket::Version(version)).await?;
+        control_channel
+            .write(MumblePacket::Version(version))
+            .await?;
 
         //Authentication
-        let mut auth = match reader.read().await? {
+        let mut auth = match control_channel.read().await? {
             MumblePacket::Authenticate(auth) => auth,
             _ => return Err(Error::ConnectionSetupError),
         };
@@ -57,7 +71,7 @@ where
         }
         let session_id = db.add_new_user(auth.take_username()).await;
 
-        //Crypt setup TODO
+        //Crypt setup
 
         //CodecVersion
         let mut codec_version = CodecVersion::new();
@@ -65,7 +79,7 @@ where
         codec_version.set_beta(0);
         codec_version.set_prefer_alpha(true);
         codec_version.set_opus(true);
-        writer
+        control_channel
             .write(MumblePacket::CodecVersion(codec_version))
             .await?;
 
@@ -75,7 +89,7 @@ where
             let mut channel_state = ChannelState::new();
             channel_state.set_channel_id(channel.id);
             channel_state.set_name(channel.name);
-            writer
+            control_channel
                 .write(MumblePacket::ChannelState(channel_state))
                 .await?;
         }
@@ -84,7 +98,7 @@ where
         let mut permission_query = PermissionQuery::new();
         permission_query.set_permissions(134743822);
         permission_query.set_channel_id(0);
-        writer
+        control_channel
             .write(MumblePacket::PermissionQuery(permission_query))
             .await?;
 
@@ -95,7 +109,9 @@ where
             user_state.set_name(user.username);
             user_state.set_session(user.session_id);
             user_state.set_channel_id(user.channel_id);
-            writer.write(MumblePacket::UserState(user_state)).await?;
+            control_channel
+                .write(MumblePacket::UserState(user_state))
+                .await?;
         }
 
         //Server sync
@@ -104,7 +120,9 @@ where
         server_sync.set_welcome_text(config.welcome_text);
         server_sync.set_max_bandwidth(config.max_bandwidth);
         server_sync.set_permissions(134743822);
-        writer.write(MumblePacket::ServerSync(server_sync)).await?;
+        control_channel
+            .write(MumblePacket::ServerSync(server_sync))
+            .await?;
 
         //ServerConfig
         let mut server_config = ServerConfig::new();
@@ -112,20 +130,70 @@ where
         server_config.set_allow_html(true);
         server_config.set_message_length(5000);
         server_config.set_image_message_length(131072);
-        writer
+        control_channel
             .write(MumblePacket::ServerConfig(server_config))
             .await?;
 
         Ok(Connection {
-            reader,
-            writer,
+            control_channel,
             session_id,
         })
     }
 }
 
+impl ControlChannel {
+    pub async fn read(&mut self) -> Result<MumblePacket, Error> {
+        self.reader.read().await
+    }
+
+    pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
+        self.writer.write(packet).await
+    }
+
+    pub fn split(self) -> (ControlChannelReader, ControlChannelWriter) {
+        (self.reader, self.writer)
+    }
+
+    fn new(stream: TlsStream<TcpStream>) -> Self {
+        let (reader, writer) = tokio::io::split(stream);
+        ControlChannel {
+            reader: ControlChannelReader { reader },
+            writer: ControlChannelWriter { writer },
+        }
+    }
+}
+
+impl ControlChannelReader {
+    pub async fn read(&mut self) -> Result<MumblePacket, Error> {
+        let mut packet_type = [0; 2];
+        let mut length = [0; 4];
+        self.reader.read_exact(&mut packet_type).await?;
+        self.reader.read_exact(&mut length).await?;
+        let (packet_type, length) = MumblePacket::parse_prefix(packet_type, length);
+
+        let mut payload = vec![0; length as usize];
+        self.reader.read_exact(&mut payload).await?;
+        Ok(MumblePacket::parse_payload(packet_type, &payload)?)
+    }
+}
+
+impl ControlChannelWriter {
+    pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
+        let bytes = packet.serialize();
+        self.writer.write_all(&bytes).await?;
+        self.writer.flush().await?;
+        Ok(())
+    }
+}
+
 impl From<crate::protocol::Error> for Error {
-    fn from(err: crate::protocol::Error) -> Self {
-        Error::StreamError(err)
+    fn from(_: crate::protocol::Error) -> Self {
+        Error::StreamError
+    }
+}
+
+impl From<std::io::Error> for Error {
+    fn from(_: std::io::Error) -> Self {
+        Error::StreamError
     }
 }

+ 97 - 213
src/protocol.rs

@@ -1,5 +1,4 @@
 use protobuf::{Message, ProtobufError};
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
 
 use crate::proto::mumble::{
     Authenticate, BanList, ChannelRemove, ChannelState, CodecVersion, ContextAction,
@@ -36,7 +35,9 @@ const USER_STATS: u16 = 22;
 const REQUEST_BLOB: u16 = 23;
 const SERVER_CONFIG: u16 = 24;
 const SUGGEST_CONFIG: u16 = 25;
-const MAX_AUDIO_PACKET_SIZE: usize = 1020;
+
+const TYPE_SIZE: usize = 2;
+const LENGTH_SIZE: usize = 4;
 
 pub enum MumblePacket {
     Version(Version),
@@ -74,18 +75,9 @@ pub enum VoicePacket {
 
 pub enum Error {
     UnknownPacketType,
-    ConnectionError(std::io::Error),
     ParsingError,
 }
 
-pub struct MumblePacketReader<R> {
-    reader: R,
-}
-
-pub struct MumblePacketWriter<W> {
-    writer: W,
-}
-
 pub struct VoicePing {
     bytes: Vec<u8>,
 }
@@ -96,100 +88,75 @@ pub struct AudioData {
     bytes: Vec<u8>,
 }
 
-pub fn new<S>(
-    stream: S,
-) -> (
-    MumblePacketReader<ReadHalf<S>>,
-    MumblePacketWriter<WriteHalf<S>>,
-)
-where
-    S: AsyncRead + AsyncWrite + Unpin + Send,
-{
-    let (reader, writer) = tokio::io::split(stream);
-    (
-        MumblePacketReader::new(reader),
-        MumblePacketWriter::new(writer),
-    )
-}
-
-impl<R> MumblePacketReader<R>
-where
-    R: AsyncRead + Unpin + Send,
-{
-    pub fn new(reader: R) -> Self {
-        MumblePacketReader { reader }
+impl MumblePacket {
+    pub fn parse_prefix(packet_type: [u8; TYPE_SIZE], length: [u8; LENGTH_SIZE]) -> (u16, u32) {
+        (u16::from_be_bytes(packet_type), u32::from_be_bytes(length))
     }
 
-    pub async fn read(&mut self) -> Result<MumblePacket, Error> {
-        let packet_type = self.reader.read_u16().await?;
-        let payload_length = self.reader.read_u32().await?;
-        let payload = self.read_payload(payload_length).await?;
-
+    pub fn parse_payload(packet_type: u16, payload: &[u8]) -> Result<Self, Error> {
         match packet_type {
-            VERSION => Ok(MumblePacket::Version(Version::parse_from_bytes(&payload)?)),
-            UDP_TUNNEL => Ok(MumblePacket::UdpTunnel(VoicePacket::parse_from_bytes(
-                payload,
+            VERSION => Ok(MumblePacket::Version(Version::parse_from_bytes(payload)?)),
+            UDP_TUNNEL => Ok(MumblePacket::UdpTunnel(VoicePacket::parse(
+                payload.to_vec(),
             )?)),
             AUTHENTICATE => Ok(MumblePacket::Authenticate(Authenticate::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
-            PING => Ok(MumblePacket::Ping(Ping::parse_from_bytes(&payload)?)),
-            REJECT => Ok(MumblePacket::Reject(Reject::parse_from_bytes(&payload)?)),
+            PING => Ok(MumblePacket::Ping(Ping::parse_from_bytes(payload)?)),
+            REJECT => Ok(MumblePacket::Reject(Reject::parse_from_bytes(payload)?)),
             SERVER_SYNC => Ok(MumblePacket::ServerSync(ServerSync::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             CHANNEL_REMOVE => Ok(MumblePacket::ChannelRemove(
-                ChannelRemove::parse_from_bytes(&payload)?,
+                ChannelRemove::parse_from_bytes(payload)?,
             )),
             CHANNEL_STATE => Ok(MumblePacket::ChannelState(ChannelState::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             USER_REMOVE => Ok(MumblePacket::UserRemove(UserRemove::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             USER_STATE => Ok(MumblePacket::UserState(UserState::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
-            BAN_LIST => Ok(MumblePacket::BanList(BanList::parse_from_bytes(&payload)?)),
+            BAN_LIST => Ok(MumblePacket::BanList(BanList::parse_from_bytes(payload)?)),
             TEXT_MESSAGE => Ok(MumblePacket::TextMessage(TextMessage::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             PERMISSION_DENIED => Ok(MumblePacket::PermissionDenied(
-                PermissionDenied::parse_from_bytes(&payload)?,
+                PermissionDenied::parse_from_bytes(payload)?,
             )),
-            ACL => Ok(MumblePacket::Acl(Acl::parse_from_bytes(&payload)?)),
+            ACL => Ok(MumblePacket::Acl(Acl::parse_from_bytes(payload)?)),
             QUERY_USERS => Ok(MumblePacket::QueryUsers(QueryUsers::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             CRYPT_SETUP => Ok(MumblePacket::CryptSetup(CryptSetup::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             CONTEXT_ACTION_MODIFY => Ok(MumblePacket::ContextActionModify(
-                ContextActionModify::parse_from_bytes(&payload)?,
+                ContextActionModify::parse_from_bytes(payload)?,
             )),
             CONTEXT_ACTION => Ok(MumblePacket::ContextAction(
-                ContextAction::parse_from_bytes(&payload)?,
+                ContextAction::parse_from_bytes(payload)?,
             )),
-            USER_LIST => Ok(MumblePacket::UserList(UserList::parse_from_bytes(
-                &payload,
-            )?)),
+            USER_LIST => Ok(MumblePacket::UserList(UserList::parse_from_bytes(payload)?)),
             VOICE_TARGET => Ok(MumblePacket::VoiceTarget(VoiceTarget::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             PERMISSION_QUERY => Ok(MumblePacket::PermissionQuery(
-                PermissionQuery::parse_from_bytes(&payload)?,
+                PermissionQuery::parse_from_bytes(payload)?,
             )),
             CODEC_VERSION => Ok(MumblePacket::CodecVersion(CodecVersion::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             USER_STATS => Ok(MumblePacket::UserStats(UserStats::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             REQUEST_BLOB => Ok(MumblePacket::RequestBlob(RequestBlob::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             SERVER_CONFIG => Ok(MumblePacket::ServerConfig(ServerConfig::parse_from_bytes(
-                &payload,
+                payload,
             )?)),
             SUGGEST_CONFIG => Ok(MumblePacket::SuggestConfig(
                 SuggestConfig::parse_from_bytes(&payload)?,
@@ -198,156 +165,95 @@ where
         }
     }
 
-    async fn read_varint(&mut self) -> Result<u64, Error> {
-        //TODO negative number decode
-        let header = self.reader.read_u8().await?;
-
-        //7-bit number
-        if (header & 0b1000_0000) == 0b0000_0000 {
-            return Ok(header as u64);
-        }
-        //14-bit number
-        if (header & 0b1100_0000) == 0b1000_0000 {
-            let first_number_byte = header ^ 0b1000_0000;
-            return Ok(((first_number_byte as u64) << 8) | (self.reader.read_u8().await? as u64));
-        }
-        //21-bit number
-        if (header & 0b1110_0000) == 0b1100_0000 {
-            let first_number_byte = header ^ 0b1100_0000;
-            return Ok(((first_number_byte as u64) << 16)
-                | ((self.reader.read_u8().await? as u64) << 8)
-                | (self.reader.read_u8().await? as u64));
-        }
-        //28-bit number
-        if (header & 0b1111_0000) == 0b1110_0000 {
-            let first_number_byte = header ^ 0b1110_0000;
-            return Ok(((first_number_byte as u64) << 24)
-                | ((self.reader.read_u8().await? as u64) << 16)
-                | ((self.reader.read_u8().await? as u64) << 8)
-                | (self.reader.read_u8().await? as u64));
-        }
-        //32-bit number
-        if (header & 0b1111_1100) == 0b1111_0000 {
-            return Ok(self.reader.read_u32().await? as u64);
-        }
-        //64-bit number
-        if (header & 0b1111_1100) == 0b1111_0100 {
-            return Ok(self.reader.read_u64().await?);
-        }
-
-        Err(Error::ParsingError)
-    }
-
-    async fn read_payload(&mut self, payload_length: u32) -> tokio::io::Result<Vec<u8>> {
-        let mut payload = vec![0; payload_length as usize];
-        self.reader.read_exact(&mut payload).await?;
-        Ok(payload)
-    }
-}
-
-impl<W> MumblePacketWriter<W>
-where
-    W: AsyncWrite + Unpin + Send,
-{
-    pub fn new(writer: W) -> Self {
-        MumblePacketWriter { writer }
-    }
-
-    pub async fn write(&mut self, packet: MumblePacket) -> Result<(), Error> {
-        match packet {
-            MumblePacket::UdpTunnel(value) => {
-                let bytes = serialize_voice_packet(value);
-                self.writer.write_u16(UDP_TUNNEL).await?;
-                self.writer.write_u32(bytes.len() as u32).await?;
-                self.writer.write_all(&bytes).await?;
+    pub fn serialize(self) -> Vec<u8> {
+        match self {
+            MumblePacket::UdpTunnel(voice_packet) => {
+                let bytes = voice_packet.serialize();
+                return UDP_TUNNEL
+                    .to_be_bytes()
+                    .iter()
+                    .cloned()
+                    .chain((bytes.len() as u32).to_be_bytes().iter().cloned())
+                    .chain(bytes)
+                    .collect();
             }
-            MumblePacket::Version(value) => self.write_protobuf_packet(value, VERSION).await?,
+            MumblePacket::Version(value) => Self::serialize_protobuf_packet(&value, VERSION),
             MumblePacket::Authenticate(value) => {
-                self.write_protobuf_packet(value, AUTHENTICATE).await?
-            }
-            MumblePacket::Ping(value) => self.write_protobuf_packet(value, PING).await?,
-            MumblePacket::Reject(value) => self.write_protobuf_packet(value, REJECT).await?,
-            MumblePacket::ServerSync(value) => {
-                self.write_protobuf_packet(value, SERVER_SYNC).await?
+                Self::serialize_protobuf_packet(&value, AUTHENTICATE)
             }
+            MumblePacket::Ping(value) => Self::serialize_protobuf_packet(&value, PING),
+            MumblePacket::Reject(value) => Self::serialize_protobuf_packet(&value, REJECT),
+            MumblePacket::ServerSync(value) => Self::serialize_protobuf_packet(&value, SERVER_SYNC),
             MumblePacket::ChannelRemove(value) => {
-                self.write_protobuf_packet(value, CHANNEL_REMOVE).await?
+                Self::serialize_protobuf_packet(&value, CHANNEL_REMOVE)
             }
             MumblePacket::ChannelState(value) => {
-                self.write_protobuf_packet(value, CHANNEL_STATE).await?
+                Self::serialize_protobuf_packet(&value, CHANNEL_STATE)
             }
-            MumblePacket::UserRemove(value) => {
-                self.write_protobuf_packet(value, USER_REMOVE).await?
-            }
-            MumblePacket::UserState(value) => self.write_protobuf_packet(value, USER_STATE).await?,
-            MumblePacket::BanList(value) => self.write_protobuf_packet(value, BAN_LIST).await?,
+            MumblePacket::UserRemove(value) => Self::serialize_protobuf_packet(&value, USER_REMOVE),
+            MumblePacket::UserState(value) => Self::serialize_protobuf_packet(&value, USER_STATE),
+            MumblePacket::BanList(value) => Self::serialize_protobuf_packet(&value, BAN_LIST),
             MumblePacket::TextMessage(value) => {
-                self.write_protobuf_packet(value, TEXT_MESSAGE).await?
+                Self::serialize_protobuf_packet(&value, TEXT_MESSAGE)
             }
             MumblePacket::PermissionDenied(value) => {
-                self.write_protobuf_packet(value, PERMISSION_DENIED).await?
-            }
-            MumblePacket::Acl(value) => self.write_protobuf_packet(value, ACL).await?,
-            MumblePacket::QueryUsers(value) => {
-                self.write_protobuf_packet(value, QUERY_USERS).await?
-            }
-            MumblePacket::CryptSetup(value) => {
-                self.write_protobuf_packet(value, CRYPT_SETUP).await?
+                Self::serialize_protobuf_packet(&value, PERMISSION_DENIED)
             }
+            MumblePacket::Acl(value) => Self::serialize_protobuf_packet(&value, ACL),
+            MumblePacket::QueryUsers(value) => Self::serialize_protobuf_packet(&value, QUERY_USERS),
+            MumblePacket::CryptSetup(value) => Self::serialize_protobuf_packet(&value, CRYPT_SETUP),
             MumblePacket::ContextActionModify(value) => {
-                self.write_protobuf_packet(value, CONTEXT_ACTION_MODIFY)
-                    .await?
+                Self::serialize_protobuf_packet(&value, CONTEXT_ACTION_MODIFY)
             }
             MumblePacket::ContextAction(value) => {
-                self.write_protobuf_packet(value, CONTEXT_ACTION).await?
+                Self::serialize_protobuf_packet(&value, CONTEXT_ACTION)
             }
-            MumblePacket::UserList(value) => self.write_protobuf_packet(value, USER_LIST).await?,
+            MumblePacket::UserList(value) => Self::serialize_protobuf_packet(&value, USER_LIST),
             MumblePacket::VoiceTarget(value) => {
-                self.write_protobuf_packet(value, VOICE_TARGET).await?
+                Self::serialize_protobuf_packet(&value, VOICE_TARGET)
             }
             MumblePacket::PermissionQuery(value) => {
-                self.write_protobuf_packet(value, PERMISSION_QUERY).await?
+                Self::serialize_protobuf_packet(&value, PERMISSION_QUERY)
             }
             MumblePacket::CodecVersion(value) => {
-                self.write_protobuf_packet(value, CODEC_VERSION).await?
+                Self::serialize_protobuf_packet(&value, CODEC_VERSION)
             }
-            MumblePacket::UserStats(value) => self.write_protobuf_packet(value, USER_STATS).await?,
+            MumblePacket::UserStats(value) => Self::serialize_protobuf_packet(&value, USER_STATS),
             MumblePacket::RequestBlob(value) => {
-                self.write_protobuf_packet(value, REQUEST_BLOB).await?
+                Self::serialize_protobuf_packet(&value, REQUEST_BLOB)
             }
             MumblePacket::ServerConfig(value) => {
-                self.write_protobuf_packet(value, SERVER_CONFIG).await?
+                Self::serialize_protobuf_packet(&value, SERVER_CONFIG)
             }
             MumblePacket::SuggestConfig(value) => {
-                self.write_protobuf_packet(value, SUGGEST_CONFIG).await?
+                Self::serialize_protobuf_packet(&value, SUGGEST_CONFIG)
             }
         }
-
-        self.writer.flush().await?;
-        Ok(())
     }
 
-    async fn write_protobuf_packet<T>(&mut self, packet: T, packet_type: u16) -> Result<(), Error>
+    fn serialize_protobuf_packet<T>(packet: &T, packet_type: u16) -> Vec<u8>
     where
         T: Message,
     {
-        let bytes = packet.write_to_bytes()?;
-        self.writer.write_u16(packet_type).await?;
-        self.writer.write_u32(bytes.len() as u32).await?;
-        self.writer.write_all(&bytes).await?;
-
-        Ok(())
+        let bytes = packet.write_to_bytes().unwrap();
+        return packet_type
+            .to_be_bytes()
+            .iter()
+            .cloned()
+            .chain((bytes.len() as u32).to_be_bytes().iter().cloned())
+            .chain(bytes)
+            .collect();
     }
 }
 
 impl VoicePacket {
-    fn parse_from_bytes(bytes: Vec<u8>) -> Result<VoicePacket, Error> {
+    pub fn parse(bytes: Vec<u8>) -> Result<Self, Error> {
         if bytes.is_empty() {
             return Err(Error::ParsingError);
         }
 
         let header = bytes.first().unwrap();
-        let (packet_type, _) = decode_header(header.clone());
+        let (packet_type, _) = decode_header(*header);
         if packet_type == 1 {
             return Ok(VoicePacket::Ping(VoicePing { bytes }));
         }
@@ -357,6 +263,23 @@ impl VoicePacket {
             bytes,
         }))
     }
+
+    pub fn serialize(self) -> Vec<u8> {
+        match self {
+            VoicePacket::Ping(ping) => ping.bytes,
+            VoicePacket::AudioData(audio_data) => {
+                if let Some(session_id) = audio_data.session_id {
+                    let mut bytes = audio_data.bytes;
+                    let varint = encode_varint(session_id as u64);
+                    return std::iter::once(bytes.remove(0))
+                        .chain(varint)
+                        .chain(bytes)
+                        .collect();
+                }
+                audio_data.bytes
+            }
+        }
+    }
 }
 
 fn decode_header(header: u8) -> (u8, u8) {
@@ -365,12 +288,7 @@ fn decode_header(header: u8) -> (u8, u8) {
     (packet_type, target)
 }
 
-fn encode_header(packet_type: u8, target: u8) -> u8 {
-    (packet_type << 5) | target
-}
-
 fn encode_varint(number: u64) -> Vec<u8> {
-    //TODO negative number encode
     let mut result = vec![];
 
     if number < 0x80 {
@@ -414,31 +332,6 @@ fn encode_varint(number: u64) -> Vec<u8> {
     result
 }
 
-fn serialize_voice_packet(packet: VoicePacket) -> Vec<u8> {
-    match packet {
-        VoicePacket::Ping(ping) => ping.bytes,
-        VoicePacket::AudioData(audio_data) => {
-            if let Some(session_id) = audio_data.session_id {
-                let mut bytes = audio_data.bytes;
-                let mut varint = encode_varint(session_id as u64);
-                let mut result = Vec::with_capacity(bytes.len() + varint.len());
-                let header = bytes.remove(0);
-                result.push(header);
-                result.append(&mut varint);
-                result.append(&mut bytes);
-                return result;
-            }
-            audio_data.bytes
-        }
-    }
-}
-
-impl From<std::io::Error> for Error {
-    fn from(err: std::io::Error) -> Self {
-        Error::ConnectionError(err)
-    }
-}
-
 impl From<ProtobufError> for Error {
     fn from(_: ProtobufError) -> Self {
         Error::ParsingError
@@ -447,8 +340,6 @@ impl From<ProtobufError> for Error {
 
 #[cfg(test)]
 mod tests {
-    use tokio::net::TcpStream;
-
     use super::*;
 
     #[test]
@@ -458,13 +349,6 @@ mod tests {
         assert_eq!(decode_header(0b1000_0000), (4, 0));
     }
 
-    #[test]
-    fn test_encode_header() {
-        assert_eq!(encode_header(2, 8), 0b0100_1000);
-        assert_eq!(encode_header(3, 31), 0b0111_1111);
-        assert_eq!(encode_header(4, 0), 0b1000_0000);
-    }
-
     #[test]
     fn test_encode_varint() {
         let varint_7bit_positive = vec![0b0000_1000];

+ 2 - 2
src/server.rs

@@ -5,7 +5,7 @@ use std::sync::Arc;
 use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::RwLock;
 use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
-use tokio_rustls::{server::TlsStream, TlsAcceptor};
+use tokio_rustls::{TlsAcceptor, TlsStream};
 
 use crate::client::{Client, Message, ResponseMessage};
 use crate::connection::{Connection, ConnectionConfig};
@@ -42,7 +42,7 @@ pub async fn run(config: Config) -> std::io::Result<()> {
         tokio::spawn(async move {
             let stream = acceptor.accept(stream).await;
             if let Ok(stream) = stream {
-                process(db, stream, clients).await;
+                process(db, TlsStream::from(stream), clients).await;
             }
         });
     }