db.rs 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. use std::collections::HashMap;
  2. use serde::{Deserialize, Serialize};
  3. use std::sync::atomic::{AtomicU32, Ordering};
  4. use tokio::sync::RwLock;
  5. const ROOT_CHANNEL_ID: u32 = 0;
  6. const USER_TREE_NAME: &[u8] = b"users";
  7. const CHANNEL_TREE_NAME: &[u8] = b"channels";
  8. type SessionId = u32;
/// Database wrapper: persistent sled trees plus an in-memory table of the
/// users currently connected to this process.
pub struct Db {
    /// The opened sled database. NOTE(review): no method in this file reads
    /// this field; the trees below are opened from it in `Db::open`.
    db: sled::Db,
    /// Tree opened under `USER_TREE_NAME`. NOTE(review): not read or written
    /// by any method in this file.
    users: sled::Tree,
    /// Tree opened under `CHANNEL_TREE_NAME`; values are bincode-encoded
    /// `Channel`s (the root channel is stored under its big-endian id).
    channels: sled::Tree,
    /// Connected users keyed by session id. Held only in memory.
    connected_users: RwLock<HashMap<SessionId, User>>,
    /// Next session id to hand out; starts at 0.
    next_session_id: AtomicU32,
}
  16. #[derive(Clone)]
  17. pub struct User {
  18. pub id: Option<u32>,
  19. pub username: String,
  20. pub channel_id: u32,
  21. pub session_id: SessionId,
  22. }
/// Serializable user record with a mandatory id.
///
/// NOTE(review): not constructed or read anywhere in this file; presumably
/// intended for the `users` tree — confirm before relying on it. If it is
/// stored with bincode, field order is part of the stored format, so do not
/// reorder the fields.
#[derive(Serialize, Deserialize)]
struct PersistentUserData {
    id: u32,
    username: String,
    channel_id: u32,
}
  29. #[derive(Serialize, Deserialize)]
  30. pub struct Channel {
  31. pub id: u32,
  32. pub name: String,
  33. }
  34. impl Db {
  35. pub fn open(path_to_db_file: &str) -> Self {
  36. let db = sled::open(path_to_db_file).expect("Unable to open database");
  37. let users = db.open_tree(USER_TREE_NAME).unwrap();
  38. let channels = db.open_tree(CHANNEL_TREE_NAME).unwrap();
  39. let root_channel = bincode::serialize(&Channel {
  40. id: 0,
  41. name: "Root".to_string(),
  42. })
  43. .unwrap();
  44. channels
  45. .compare_and_swap(
  46. ROOT_CHANNEL_ID.to_be_bytes(),
  47. Option::<&[u8]>::None,
  48. Some(root_channel),
  49. )
  50. .unwrap();
  51. Db {
  52. db,
  53. users,
  54. channels,
  55. connected_users: RwLock::new(HashMap::new()),
  56. next_session_id: AtomicU32::new(0),
  57. }
  58. }
  59. pub async fn add_new_user(&self, username: String) -> u32 {
  60. let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
  61. let mut connected_users = self.connected_users.write().await;
  62. connected_users.insert(
  63. session_id,
  64. User {
  65. id: None,
  66. username,
  67. channel_id: ROOT_CHANNEL_ID,
  68. session_id,
  69. },
  70. );
  71. session_id
  72. }
  73. pub async fn get_channels(&self) -> Vec<Channel> {
  74. self.channels
  75. .iter()
  76. .values()
  77. .map(|channel| bincode::deserialize(&channel.unwrap()).unwrap())
  78. .collect()
  79. }
  80. pub async fn get_connected_users(&self) -> Vec<User> {
  81. let users = self.connected_users.read().await;
  82. users.values().cloned().collect()
  83. }
  84. pub async fn get_user_by_session_id(&self, session_id: u32) -> Option<User> {
  85. let connected_users = self.connected_users.read().await;
  86. if let Some(user) = connected_users.get(&session_id) {
  87. return Some(user.clone());
  88. }
  89. None
  90. }
  91. pub async fn remove_connected_user(&self, session_id: u32) {
  92. let mut connected_users = self.connected_users.write().await;
  93. connected_users.remove(&session_id);
  94. }
  95. }