server.rs 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. use std::collections::HashMap;
  2. use std::net::{IpAddr, SocketAddr};
  3. use std::sync::Arc;
  4. use tokio::net::{TcpListener, TcpStream};
  5. use tokio::sync::RwLock;
  6. use tokio_rustls::rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
  7. use tokio_rustls::{server::TlsStream, TlsAcceptor};
  8. use crate::client::{Client, Message, ResponseMessage};
  9. use crate::connection::{Connection, ConnectionConfig};
  10. use crate::db::Db;
  11. pub struct Config {
  12. pub ip_address: IpAddr,
  13. pub port: u16,
  14. pub certificate: Certificate,
  15. pub private_key: PrivateKey,
  16. pub path_to_db_file: String,
  17. }
  18. type Clients = Arc<RwLock<HashMap<u32, Client>>>;
  19. pub async fn run(config: Config) -> std::io::Result<()> {
  20. let db = Arc::new(Db::open(&config.path_to_db_file));
  21. let mut tls_config = ServerConfig::new(NoClientAuth::new());
  22. tls_config
  23. .set_single_cert(vec![config.certificate], config.private_key)
  24. .expect("Invalid private key");
  25. let acceptor = TlsAcceptor::from(Arc::new(tls_config));
  26. let listener = TcpListener::bind(SocketAddr::new(config.ip_address, config.port)).await?;
  27. let clients = Arc::new(RwLock::new(HashMap::new()));
  28. loop {
  29. let (stream, _) = listener.accept().await?;
  30. let acceptor = acceptor.clone();
  31. let db = Arc::clone(&db);
  32. let clients = Arc::clone(&clients);
  33. tokio::spawn(async move {
  34. let stream = acceptor.accept(stream).await;
  35. if let Ok(stream) = stream {
  36. process(db, stream, clients).await;
  37. }
  38. });
  39. }
  40. }
  41. async fn process(db: Arc<Db>, stream: TlsStream<TcpStream>, clients: Clients) {
  42. let connection_config = ConnectionConfig {
  43. max_bandwidth: 128000,
  44. welcome_text: "Welcome!".to_string(),
  45. };
  46. let connection = match Connection::setup_connection(db.clone(), stream, connection_config).await
  47. {
  48. Ok(connection) => connection,
  49. Err(_) => {
  50. eprintln!("Error establishing a connection");
  51. return;
  52. }
  53. };
  54. let session_id = connection.session_id;
  55. let (client, mut response_receiver) = Client::new(connection, db).await;
  56. {
  57. let mut clients = clients.write().await;
  58. for client in clients.values() {
  59. client.post_message(Message::UserConnected(session_id))
  60. }
  61. clients.insert(session_id, client);
  62. }
  63. loop {
  64. let message = match response_receiver.recv().await {
  65. Some(msg) => msg,
  66. None => return,
  67. };
  68. match message {
  69. ResponseMessage::Disconnected => {
  70. let mut clients = clients.write().await;
  71. clients.remove(&session_id);
  72. for client in clients.values() {
  73. client.post_message(Message::UserDisconnected(session_id))
  74. }
  75. return;
  76. }
  77. ResponseMessage::Talking(audio_data) => {
  78. let clients = clients.read().await;
  79. for client in clients
  80. .values()
  81. .filter(|client| client.session_id != session_id)
  82. {
  83. client.post_message(Message::UserTalking(audio_data.clone()));
  84. }
  85. }
  86. }
  87. }
  88. }