Преглед изворни кода

Do not leave when a song is loading

Frans Bergman пре 1 година
родитељ
комит
23060fe13c
3 измењених фајлова са 51 додато и 23 уклоњено
  1. 21 11
      src/audio/audio.rs
  2. 29 11
      src/audio/audio_state.rs
  3. 1 1
      src/audio/song.rs

+ 21 - 11
src/audio/audio.rs

@@ -147,7 +147,7 @@ async fn disconnect(ctx: &Context, msg: &Message) -> CommandResult {
 
 #[command]
 async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
-    let query = args.rest();
+    let query = args.rest().to_string();
 
     message_react(ctx, msg, "🎶").await;
 
@@ -157,16 +157,26 @@ async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
         None => return Ok(()),
     };
 
-    match Song::from_query(msg.author.clone(), query).await {
-        Ok(song) => {
-            AudioState::add_audio(audio_state, song).await;
-            message_react(ctx, msg, "✅").await;
-        }
-        Err(why) => {
-            message_react(ctx, msg, "❎").await;
-            send_embed(ctx, msg, &format!("Error: {}", why)).await;
-        }
-    }
+    let song = {
+        let ctx = ctx.clone();
+        let msg = msg.clone();
+        tokio::spawn(async move {
+            let song = Song::from_query(msg.author.clone(), query).await;
+            match song {
+                Ok(song) => {
+                    message_react(&ctx, &msg, "✅").await;
+                    Some(song)
+                }
+                Err(why) => {
+                    message_react(&ctx, &msg, "❎").await;
+                    send_embed(&ctx, &msg, &format!("Error: {}", why)).await;
+                    None
+                }
+            }
+        })
+    };
+
+    AudioState::add_audio(audio_state, song).await;
 
     Ok(())
 }

+ 29 - 11
src/audio/audio_state.rs

@@ -16,9 +16,11 @@ use std::time::Duration;
 use std::sync::Arc;
 use tokio::sync::Mutex;
 use tokio::time::sleep;
+use tokio::task::JoinHandle;
 
 pub struct AudioState {
     queue: SongQueue,
+    in_flight: usize,
     handler: Arc<SerenityMutex<Call>>,
     current_song: Option<Song>,
     track_handle: Option<TrackHandle>,
@@ -33,6 +35,7 @@ impl AudioState {
     pub fn new(handler: Arc<SerenityMutex<Call>>, ctx: &Context, msg: &Message) -> Arc<Mutex<AudioState>> {
         let audio_state = AudioState {
             queue: SongQueue::new(),
+            in_flight: 0,
             handler,
             current_song: None,
             track_handle: None,
@@ -76,9 +79,12 @@ impl AudioState {
         let song = match song {
             Some(song) => song,
             None => {
-                let mut handler = state.handler.lock().await;
-                if let Err(e) = handler.leave().await {
-                    println!("Error leaving channel: {:?}", e);
+                state.current_song = None;
+                if state.in_flight == 0 {
+                    let mut handler = state.handler.lock().await;
+                    if let Err(e) = handler.leave().await {
+                        println!("Error leaving channel: {:?}", e);
+                    }
                 }
                 return;
             }
@@ -124,14 +130,26 @@ impl AudioState {
         state.track_handle = Some(handle);
     }
 
-    pub async fn add_audio(audio_state: Arc<Mutex<AudioState>>, song: Song) {
-        let mut state = audio_state.lock().await;
-        state.queue.push(vec![song]);
-        if state.current_song.is_none() {
-            let audio_state = audio_state.clone();
-            tokio::spawn(async {
-                AudioState::play_audio(audio_state).await;
-            });
+    pub async fn add_audio(audio_state: Arc<Mutex<AudioState>>, song: JoinHandle<Option<Song>>) {
+        {
+            let mut state = audio_state.lock().await;
+            state.in_flight += 1;
+        }
+        {
+            if let Ok(Some(song)) = song.await {
+                let mut state = audio_state.lock().await;
+                state.queue.push(vec![song]);
+                if state.current_song.is_none() {
+                    let audio_state = audio_state.clone();
+                    tokio::spawn(async {
+                        AudioState::play_audio(audio_state).await;
+                    });
+                }
+            }
+        }
+        {
+            let mut state = audio_state.lock().await;
+            state.in_flight -= 1;
         }
     }
 

+ 1 - 1
src/audio/song.rs

@@ -11,7 +11,7 @@ pub struct Song {
 }
 
 impl Song {
-    pub async fn from_query(user: User, query: &str) -> Result<Song, std::io::Error> {
+    pub async fn from_query(user: User, query: String) -> Result<Song, std::io::Error> {
         let query = if query.contains("watch?v=") {
             query.to_string()
         } else {