소스 검색

Use a single Mutex per AudioState

Frans Bergman 3 년 전
부모
커밋
402ea25fd7
4개의 변경된 파일91개의 추가작업 그리고 119개의 파일을 삭제
  1. 23 11
      src/audio/audio.rs
  2. 51 86
      src/audio/audio_state.rs
  3. 1 1
      src/audio/song.rs
  4. 16 21
      src/audio/song_queue.rs

+ 23 - 11
src/audio/audio.rs

@@ -19,10 +19,11 @@ use crate::util::{message_react, send_embed};
 struct Audio;
 
 lazy_static! {
-    static ref AUDIO_STATES: Mutex<HashMap<GuildId, Arc<AudioState>>> = Mutex::new(HashMap::new());
+    static ref AUDIO_STATES: Mutex<HashMap<GuildId, Arc<Mutex<AudioState>>>> =
+        Mutex::new(HashMap::new());
 }
 
-async fn get_audio_state(ctx: &Context, msg: &Message) -> Option<Arc<AudioState>> {
+async fn get_audio_state(ctx: &Context, msg: &Message) -> Option<Arc<Mutex<AudioState>>> {
     let guild = msg.guild(&ctx.cache).await.unwrap();
     let guild_id = guild.id;
 
@@ -80,13 +81,18 @@ async fn get_audio_state(ctx: &Context, msg: &Message) -> Option<Arc<AudioState>
     }
 
     match audio_states.get(&guild_id) {
-        Some(state) => {
-            let state = state.clone();
-            AudioState::set_context(state.clone(), ctx, msg).await;
-            Some(state)
+        Some(state_mutex) => {
+            let mut state = state_mutex.lock().await;
+            state.set_context(ctx, msg).await;
+            drop(state);
+            Some(state_mutex.clone())
         }
         None => {
             let audio_state = AudioState::new(manager.get(guild_id).unwrap(), ctx, msg);
+            let cloned_as = audio_state.clone();
+            tokio::spawn(async {
+                AudioState::play_audio(cloned_as).await;
+            });
             audio_states.insert(guild_id, audio_state.clone());
             Some(audio_state)
         }
@@ -157,8 +163,9 @@ async fn skip(ctx: &Context, msg: &Message) -> CommandResult {
         Some(audio_state) => audio_state,
         None => return Ok(()),
     };
+    let mut audio_state = audio_state.lock().await;
 
-    if let Err(why) = AudioState::send_track_command(audio_state, TrackCommand::Stop).await {
+    if let Err(why) = audio_state.send_track_command(TrackCommand::Stop).await {
         send_embed(ctx, msg, &format!("Error: {}", why)).await;
     } else {
         message_react(ctx, msg, "↪").await;
@@ -173,8 +180,9 @@ async fn pause(ctx: &Context, msg: &Message) -> CommandResult {
         Some(audio_state) => audio_state,
         None => return Ok(()),
     };
+    let mut audio_state = audio_state.lock().await;
 
-    if let Err(why) = AudioState::send_track_command(audio_state, TrackCommand::Pause).await {
+    if let Err(why) = audio_state.send_track_command(TrackCommand::Pause).await {
         send_embed(ctx, msg, &format!("Error: {}", why)).await;
     } else {
         message_react(ctx, msg, "⏸").await;
@@ -189,8 +197,9 @@ async fn resume(ctx: &Context, msg: &Message) -> CommandResult {
         Some(audio_state) => audio_state,
         None => return Ok(()),
     };
+    let mut audio_state = audio_state.lock().await;
 
-    if let Err(why) = AudioState::send_track_command(audio_state, TrackCommand::Play).await {
+    if let Err(why) = audio_state.send_track_command(TrackCommand::Play).await {
         send_embed(ctx, msg, &format!("Error: {}", why)).await;
     } else {
         message_react(ctx, msg, "▶").await;
@@ -205,8 +214,9 @@ async fn clear(ctx: &Context, msg: &Message) -> CommandResult {
         Some(audio_state) => audio_state,
         None => return Ok(()),
     };
+    let mut audio_state = audio_state.lock().await;
 
-    if let Err(why) = AudioState::clear(audio_state.clone()).await {
+    if let Err(why) = audio_state.clear().await {
         send_embed(ctx, msg, &format!("Error: {}", why)).await;
     } else {
         message_react(ctx, msg, "🗑").await;
@@ -223,7 +233,9 @@ async fn queue(ctx: &Context, msg: &Message) -> CommandResult {
         None => return Ok(()),
     };
 
-    send_embed(ctx, msg, &AudioState::get_string(audio_state).await).await;
+    let audio_state = audio_state.lock().await;
+    let text = audio_state.get_string();
+    send_embed(ctx, msg, &text).await;
 
     Ok(())
 }

+ 51 - 86
src/audio/audio_state.rs

@@ -12,42 +12,44 @@ use songbird::{
     tracks::{TrackCommand, TrackHandle},
     Call, Event, EventContext, EventHandler as VoiceEventHandler, TrackEvent,
 };
+use std::sync::Arc;
 use std::time::Duration;
-use std::{mem::drop, sync::Arc};
 use tokio::sync::Mutex;
 use tokio::time::sleep;
 
 pub struct AudioState {
     queue: SongQueue,
     handler: Arc<SerenityMutex<Call>>,
-    current_song: Mutex<Option<Song>>,
-    track_handle: Mutex<Option<TrackHandle>>,
-    is_looping: Mutex<bool>,
+    current_song: Option<Song>,
+    track_handle: Option<TrackHandle>,
 
-    channel_id: Mutex<ChannelId>,
-    http: Mutex<Arc<Http>>,
+    channel_id: ChannelId,
+    http: Arc<Http>,
 }
 
 impl AudioState {
-    pub fn new(handler: Arc<SerenityMutex<Call>>, ctx: &Context, msg: &Message) -> Arc<AudioState> {
-        let audio_state = AudioState {
+    pub fn new(
+        handler: Arc<SerenityMutex<Call>>,
+        ctx: &Context,
+        msg: &Message,
+    ) -> Arc<Mutex<AudioState>> {
+        let audio_state = Arc::new(Mutex::new(AudioState {
             queue: SongQueue::new(),
             handler,
-            current_song: Mutex::new(None),
-            track_handle: Mutex::new(None),
-            is_looping: Mutex::new(false),
+            current_song: None,
+            track_handle: None,
 
-            channel_id: Mutex::new(msg.channel_id),
-            http: Mutex::new(ctx.http.clone()),
-        };
-        let audio_state = Arc::new(audio_state);
+            channel_id: msg.channel_id,
+            http: ctx.http.clone(),
+        }));
         let my_audio_state = audio_state.clone();
         tokio::spawn(async move {
             // Leave if no music is playing within 1 minute
             sleep(Duration::from_secs(60)).await;
-            let current_song = my_audio_state.current_song.lock().await;
+            let audio_state = my_audio_state.lock().await;
+            let current_song = audio_state.current_song.as_ref();
             if current_song.is_none() {
-                let mut handler = my_audio_state.handler.lock().await;
+                let mut handler = audio_state.handler.lock().await;
                 if let Err(e) = handler.leave().await {
                     println!("Automatic leave failed: {:?}", e);
                 }
@@ -56,31 +58,19 @@ impl AudioState {
         audio_state
     }
 
-    pub async fn set_context(audio_state: Arc<AudioState>, ctx: &Context, msg: &Message) {
-        {
-            let mut channel_id = audio_state.channel_id.lock().await;
-            *channel_id = msg.channel_id;
-        }
-        {
-            let mut http = audio_state.http.lock().await;
-            *http = ctx.http.clone();
-        }
+    pub async fn set_context(self: &mut AudioState, ctx: &Context, msg: &Message) {
+        self.channel_id = msg.channel_id;
+        self.http = ctx.http.clone();
     }
 
-    async fn play_audio(audio_state: Arc<AudioState>) {
-        let is_looping = audio_state.is_looping.lock().await;
-        let song = if *is_looping {
-            let mut current_song = audio_state.current_song.lock().await;
-            current_song.take()
-        } else {
-            audio_state.queue.pop().await
-        };
-        drop(is_looping);
+    pub async fn play_audio(state_mutex: Arc<Mutex<AudioState>>) {
+        let mut state = state_mutex.lock().await;
+        let song = state.queue.pop().await;
 
         let song = match song {
             Some(song) => song,
             None => {
-                let mut handler = audio_state.handler.lock().await;
+                let mut handler = state.handler.lock().await;
                 if let Err(e) = handler.leave().await {
                     println!("Error leaving channel: {:?}", e);
                 }
@@ -99,36 +89,30 @@ impl AudioState {
         let reader = Reader::Extension(source);
         let source = input::Input::float_pcm(true, reader);
 
-        let mut handler = audio_state.handler.lock().await;
+        let mut handler = state.handler.lock().await;
 
         let handle = handler.play_source(source);
+        drop(handler);
 
         if let Err(why) = handle.add_event(
             Event::Track(TrackEvent::End),
             SongEndNotifier {
-                audio_state: audio_state.clone(),
+                audio_state: state_mutex.clone(),
             },
         ) {
             panic!("Err AudioState::play_audio: {:?}", why);
         }
-        {
-            let text = song.get_string().await;
-            let channel_id = audio_state.channel_id.lock().await;
-            let http = audio_state.http.lock().await;
-            send_embed_http(
-                *channel_id,
-                http.clone(),
-                &format!("Now playing:\n\n {}", text),
-            )
-            .await;
-        }
-        let mut current_song = audio_state.current_song.lock().await;
-        *current_song = Some(song);
-        let mut track_handle = audio_state.track_handle.lock().await;
-        *track_handle = Some(handle);
+        send_embed_http(
+            state.channel_id,
+            state.http.clone(),
+            &format!("Now playing:\n\n {}", song.get_string()),
+        )
+        .await;
+        state.current_song = Some(song);
+        state.track_handle = Some(handle);
     }
 
-    pub async fn add_audio(audio_state: Arc<AudioState>, query: &str) {
+    pub async fn add_audio(audio_state: Arc<Mutex<AudioState>>, query: &str) {
         let song = match Song::from_query(query).await {
             Ok(song) => song,
             Err(why) => {
@@ -136,23 +120,14 @@ impl AudioState {
                 return;
             }
         };
-        audio_state.queue.push(vec![song]).await;
-        let current_song = audio_state.current_song.lock().await;
-        if current_song.is_none() {
-            let audio_state = audio_state.clone();
-            tokio::spawn(async {
-                AudioState::play_audio(audio_state).await;
-            });
-        }
+        let mut state = audio_state.lock().await;
+        state.queue.push(vec![song]).await;
+        if state.current_song.is_none() {}
     }
 
-    pub async fn send_track_command(
-        audio_state: Arc<AudioState>,
-        cmd: TrackCommand,
-    ) -> Result<(), String> {
-        let track_handle = audio_state.track_handle.lock().await;
-        match track_handle.as_ref() {
-            Some(track_handle) => match track_handle.send(cmd) {
+    pub async fn send_track_command(&mut self, cmd: TrackCommand) -> Result<(), String> {
+        match self.track_handle.as_ref() {
+            Some(th) => match th.send(cmd) {
                 Ok(()) => Ok(()),
                 Err(why) => Err(format!("{:?}", why)),
             },
@@ -160,35 +135,25 @@ impl AudioState {
         }
     }
 
-    // on success, returns a bool that specifies whether the queue is now being looped
-    pub async fn change_looping(audio_state: Arc<AudioState>) -> Result<bool, String> {
-        {
-            let current_song = audio_state.current_song.lock().await;
-            if current_song.is_none() {
-                return Err("no song is playing".to_string());
-            }
-        }
-        let mut is_looping = audio_state.is_looping.lock().await;
-        *is_looping = !*is_looping;
-        Ok(*is_looping)
+    pub async fn clear(&mut self) -> Result<(), String> {
+        self.queue.clear().await
     }
 
-    pub async fn get_string(audio_state: Arc<AudioState>) -> String {
-        let current_song = audio_state.current_song.lock().await;
-        let current_song = match &*current_song {
-            Some(song) => song.get_string().await,
+    pub fn get_string(&self) -> String {
+        let current_song = match self.current_song.as_ref() {
+            Some(song) => song.get_string(),
             None => "*Not playing*\n".to_string(),
         };
         format!(
             "**Current Song:**\n{}\n\n**Queue:**\n{}",
             current_song,
-            audio_state.queue.get_string().await
+            self.queue.get_string()
         )
     }
 }
 
 struct SongEndNotifier {
-    audio_state: Arc<AudioState>,
+    audio_state: Arc<Mutex<AudioState>>,
 }
 
 #[async_trait]

+ 1 - 1
src/audio/song.rs

@@ -31,7 +31,7 @@ impl Song {
         Ok(song)
     }
 
-    pub async fn get_string(&self) -> String {
+    pub fn get_string(&self) -> String {
         let artist = self
             .metadata
             .artist

+ 16 - 21
src/audio/song_queue.rs

@@ -1,48 +1,43 @@
 use super::song::Song;
-use std::{cmp::min, collections::VecDeque, sync::Arc};
-use tokio::sync::Mutex;
+use std::{cmp::min, collections::VecDeque};
 pub struct SongQueue {
-    queue: Arc<Mutex<VecDeque<Song>>>,
+    queue: VecDeque<Song>,
 }
 
 impl SongQueue {
     pub fn new() -> SongQueue {
         SongQueue {
-            queue: Arc::new(Mutex::new(VecDeque::new())),
+            queue: VecDeque::new(),
         }
     }
-    pub async fn push(&self, songs: Vec<Song>) {
-        let mut queue = self.queue.lock().await;
+    pub async fn push(&mut self, songs: Vec<Song>) {
         for item in songs.into_iter() {
-            queue.push_back(item);
+            self.queue.push_back(item);
         }
     }
-    pub async fn pop(&self) -> Option<Song> {
-        let mut queue = self.queue.lock().await;
-        queue.pop_front()
+    pub async fn pop(&mut self) -> Option<Song> {
+        self.queue.pop_front()
     }
-    pub async fn clear(&self) -> Result<(), String> {
-        let mut queue = self.queue.lock().await;
-        if queue.len() == 0 {
+    pub async fn clear(&mut self) -> Result<(), String> {
+        if self.queue.len() == 0 {
             return Err("queue is empty".to_string());
         };
-        queue.clear();
+        self.queue.clear();
         Ok(())
     }
-    pub async fn get_string(&self) -> String {
-        let queue = self.queue.lock().await;
-        if queue.len() == 0 {
+    pub fn get_string(&self) -> String {
+        if self.queue.len() == 0 {
             return "*empty*".to_string();
         };
         let mut s = String::new();
         s.push_str(&format!(
             "*Showing {} of {} songs*\n",
-            min(20, queue.len()),
-            queue.len()
+            min(20, self.queue.len()),
+            self.queue.len()
         ));
-        for (i, song) in queue.iter().take(20).enumerate() {
+        for (i, song) in self.queue.iter().take(20).enumerate() {
             s += &format!("{}: ", i);
-            s += &song.get_string().await;
+            s += &song.get_string();
             s += "\n";
         }
         s