Browse Source

adding song recommendation feature for spotify

Zhizhou Ma 3 years ago
parent
commit
2d7047b4d2

+ 1 - 1
.spotify_token_cache.json

@@ -1 +1 @@
-{"access_token":"BQDtmxHBq0XMlIKy8C0cp9oFRcxGaA64qwR0CbZd4JS-Sr09JkN6ZGIpwUhjxJrhwlCmFvjCup5hwAlYwCM","expires_in":3600,"expires_at":"2021-06-11T08:05:59.739209300Z","refresh_token":null,"scope":""}
+{"access_token":"BQAniTN-wEbdSkII6q4VyxilAvCembgGe03hLmElnTWYJfKbdmTPWu_e0rEqq3_6iWTzzt7dHpg1XlzInC0","expires_in":3600,"expires_at":"2021-06-23T18:29:14.432292300Z","refresh_token":null,"scope":""}

+ 1 - 0
Cargo.lock

@@ -1030,6 +1030,7 @@ dependencies = [
 name = "octave_rust"
 version = "0.1.0"
 dependencies = [
+ "futures",
  "lazy_static",
  "rand 0.8.3",
  "rspotify",

+ 1 - 0
Cargo.toml

@@ -7,6 +7,7 @@ edition = "2018"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+futures = "0.3.5"
 lazy_static = "1.4.0"
 rand = "0.8.3"
 rspotify = {git = "https://github.com/ramsayleung/rspotify"}

+ 61 - 1
src/audio/audio.rs

@@ -26,7 +26,7 @@ use crate::util::{
 };
 
 #[group]
-#[commands(join,disconnect,play,skip,pause,resume,change_loop,shuffle,clear,splay,queue)]
+#[commands(join,disconnect,play,splay,cure,extend,skip,pause,resume,change_loop,shuffle,clear,queue)]
 struct Audio;
 
 lazy_static! {
@@ -145,6 +145,66 @@ async fn splay(ctx: &Context, msg: &Message, args: Args) -> CommandResult{
     Ok(())
 }
 
+#[command]
+async fn cure(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult{
+    let query= match args.single::<String>(){
+        Ok(query) => query,
+        Err(_) => {
+            send_embed(ctx, msg, "Error: invalid Spotify playlist").await;
+            return Ok(())
+        }
+    };
+    let amount = match args.single::<usize>(){
+        Ok(amount) => amount,
+        Err(_) => {
+            20
+        }
+    };
+
+    let audio_state = get_audio_state(ctx, msg).await;
+    let audio_state = match audio_state{
+        Some(audio_state) => audio_state,
+        None => return Ok(())
+    };
+
+    AudioState::add_recommended_songs(audio_state, &query, amount).await;
+
+    message_react(ctx, msg, "🍻").await;
+    message_react(ctx, msg, "🎶").await;
+
+    Ok(())
+}
+
+#[command]
+async fn extend(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult{
+    let query= match args.single::<String>(){
+        Ok(query) => query,
+        Err(_) => {
+            send_embed(ctx, msg, "Error: invalid Spotify playlist").await;
+            return Ok(())
+        }
+    };
+    let extend_ratio = match args.single::<f64>(){
+        Ok(amount) => amount,
+        Err(_) => {
+            0.5
+        }
+    };
+
+    let audio_state = get_audio_state(ctx, msg).await;
+    let audio_state = match audio_state{
+        Some(audio_state) => audio_state,
+        None => return Ok(())
+    };
+
+    AudioState::extend_songs(audio_state, &query, extend_ratio).await;
+
+    message_react(ctx, msg, "🍻").await;
+    message_react(ctx, msg, "🎶").await;
+
+    Ok(())
+}
+
 #[command]
 async fn skip(ctx: &Context, msg: &Message) -> CommandResult{
     let audio_state = get_audio_state(ctx, msg).await;

+ 40 - 4
src/audio/audio_state.rs

@@ -8,7 +8,10 @@ use super::{
     song::{
         Song,
     },
-    query::process_query,
+    song_searcher::{
+        process_query,
+        song_recommender,
+    },
     subprocess::ffmpeg_pcm,
 };
 use songbird::{Call, Event, EventContext, EventHandler as VoiceEventHandler, TrackEvent, 
@@ -112,6 +115,37 @@ impl AudioState{
         audio_state.queue.push(songs).await;
     }
 
+    pub async fn add_recommended_songs(audio_state: Arc<AudioState>, query: &str, amount: usize){
+        let songs = match song_recommender(query, amount).await{
+            Ok(songs) => songs,
+            Err(why) => {
+                println!("Error add_recommended_songs: {}", why);
+                return;
+            },
+        };
+        audio_state.queue.push(songs).await;
+    }
+
+    pub async fn extend_songs(audio_state: Arc<AudioState>, query: &str, extend_ratio: f64){
+        let mut songs = match process_query(query).await{
+            Ok(songs) => songs,
+            Err(why) => {
+                println!("Error extend_songs: {}", why);
+                return;
+            },
+        };
+        let recommended_songs = match song_recommender(query, (songs.len() as f64 * extend_ratio) as usize).await{
+            Ok(songs) => songs,
+            Err(why) => {
+                println!("Error add_recommended_songs: {}", why);
+                return;
+            },
+        };
+        songs.extend(recommended_songs);
+        songs.shuffle(&mut rand::thread_rng());
+        audio_state.queue.push(songs).await;
+    }
+
     pub async fn send_track_command(audio_state: Arc<AudioState>, cmd: TrackCommand) -> Result<(), String>{
         let track_handle = audio_state.track_handle.lock().await;
         match &*track_handle {
@@ -134,9 +168,11 @@ impl AudioState{
     }
 
     pub async fn change_looping(audio_state: Arc<AudioState>) -> Result<bool, String>{
-        let current_song = audio_state.current_song.lock().await;
-        if current_song.is_none() {
-            return Err("no song is playing".to_string());
+        {
+            let current_song = audio_state.current_song.lock().await;
+            if current_song.is_none() {
+                return Err("no song is playing".to_string());
+            }
         }
         let mut is_looping = audio_state.is_looping.lock().await;
         *is_looping = !*is_looping;

+ 6 - 0
src/audio/config.rs

@@ -0,0 +1,6 @@
+
+pub mod spotify_recommend {
+    pub const SAME_ARTIST: u32 = 1;
+    pub const EXPLORE_ARTIST: u32 = 1;
+    pub const EXPLORE_ALBUM: u32 = 0;
+}

+ 3 - 2
src/audio/mod.rs

@@ -2,11 +2,12 @@ pub mod audio;
 pub mod audio_state;
 
 mod song;
-mod loader;
+mod youtube_loader;
 mod song_queue;
 mod subprocess;
 mod work;
 mod spotify;
-mod query;
+mod song_searcher;
+mod config;
 
 pub use audio::*;

+ 0 - 48
src/audio/query.rs

@@ -1,48 +0,0 @@
-use super::{
-    spotify::get_playlist,
-    song::{
-        Song,
-        SongMetadata,
-    },
-    work::Work,
-};
-
-pub async fn process_query(query: &str) -> Result<Vec<(Song, Option<Work>)>, String>{
-    if query.contains("spotify") && query.contains("/playlist/"){
-        let split: Vec<&str> = query
-            .split("/playlist/")
-            .filter(|s| !s.is_empty())
-            .collect();
-        if split.len() != 2 {
-            return Err("invalid spotify playlist URL".to_string());
-        }
-        let playlist_id = split[1];
-        let playlist_id = playlist_id
-            .split('?')
-            .find(|s| !s.is_empty())
-            .expect("Logical error: process_query's playlist_id contains items?");
-        let songs = match get_playlist(playlist_id).await{
-            Ok(songs) => songs,
-            Err(why) => return Err(why),
-        };
-        return Ok(songs);
-    } else {
-        let data = if query.contains("watch?v=") {
-            (Some(query.to_string()), None)
-        } else {
-            (None, Some(query.to_string()))
-        };
-        let metadata = SongMetadata{
-            artist: None,
-            title: None,
-            duration: None,
-            search_query: data.1,
-            youtube_url: data.0,
-        };
-        let song = match Song::new_load(metadata){
-            Some(song) => song,
-            None => return Err("failed to get song from YouTube".to_string()),
-        };
-        return Ok(vec![song]);
-    };
-}

+ 7 - 5
src/audio/song_queue.rs

@@ -2,6 +2,7 @@ use std::{
     sync::Arc,
     collections::VecDeque,
     mem::drop,
+    cmp::min,
 };
 use tokio::sync::{Semaphore, Mutex};
 use rand::seq::SliceRandom;
@@ -10,11 +11,11 @@ use super::{
         Song,
         SongUrlState,
     },
-    loader::Loader,
+    youtube_loader::YoutubeLoader,
     work::Work,
 };
 pub struct SongQueue{
-    loader: Arc<Mutex<Loader>>,
+    loader: Arc<Mutex<YoutubeLoader>>,
     queue: Arc<Mutex<VecDeque<Song>>>,
     queue_sem: Semaphore,
 }
@@ -22,7 +23,7 @@ pub struct SongQueue{
 impl SongQueue{
     pub fn new() -> SongQueue {
         SongQueue{
-            loader: Arc::new(Mutex::new(Loader::new())),
+            loader: Arc::new(Mutex::new(YoutubeLoader::new())),
             queue: Arc::new(Mutex::new(VecDeque::new())),
             queue_sem: Semaphore::new(0),
         }
@@ -75,7 +76,7 @@ impl SongQueue{
     async fn reset_loader(&self) {
         let mut loader = self.loader.lock().await;
         loader.cleanup().await;
-        *loader = Loader::new();
+        *loader = YoutubeLoader::new();
     }
     pub async fn cleanup(&self) {
         let mut loader = self.loader.lock().await;
@@ -87,7 +88,8 @@ impl SongQueue{
             return "*empty*".to_string();
         };
         let mut s = String::new();
-        for song in queue.iter(){
+        s.push_str(&format!("*Showing {} of {} songs*\n", min(20, queue.len()), queue.len()));
+        for song in queue.iter().take(20){
             s += &song.get_string().await;
             s += "\n";
         }

+ 85 - 0
src/audio/song_searcher.rs

@@ -0,0 +1,85 @@
+use super::{
+    spotify::SpotifyClient,
+    song::{
+        Song,
+        SongMetadata,
+    },
+    work::Work,
+};
+
+use std::sync::Arc;
+
+pub async fn process_query(query: &str) -> Result<Vec<(Song, Option<Work>)>, String>{
+    if query.contains("spotify") && query.contains("/playlist/"){
+        let split: Vec<&str> = query
+            .split("/playlist/")
+            .filter(|s| !s.is_empty())
+            .collect();
+        if split.len() != 2 {
+            return Err("invalid spotify playlist URL".to_string());
+        }
+        let playlist_id = split[1];
+        let playlist_id = playlist_id
+            .split('?')
+            .find(|s| !s.is_empty())
+            .expect("Logical error: process_query's playlist_id contains items?");
+            
+        let client = SpotifyClient::new().await;
+        let client = match client {
+            Ok(client) => client,
+            Err(why) => return Err(why),
+        };
+        let tracks = match client.get_playlist(playlist_id).await{
+            Ok(tracks) => tracks,
+            Err(why) => return Err(why),
+        };
+        return Ok(SpotifyClient::process_track_objects(tracks));
+        /*
+        
+        */
+    } else {
+        let data = if query.contains("watch?v=") {
+            (Some(query.to_string()), None)
+        } else {
+            (None, Some(query.to_string()))
+        };
+        let metadata = SongMetadata{
+            artist: None,
+            title: None,
+            duration: None,
+            search_query: data.1,
+            youtube_url: data.0,
+        };
+        let song = match Song::new_load(metadata){
+            Some(song) => song,
+            None => return Err("failed to get song from YouTube".to_string()),
+        };
+        return Ok(vec![song]);
+    };
+}
+
+pub async fn song_recommender(query: &str, amount: usize) -> Result<Vec<(Song, Option<Work>)>, String>{
+    let split: Vec<&str> = query
+        .split("/playlist/")
+        .filter(|s| !s.is_empty())
+        .collect();
+    if split.len() != 2 {
+        return Err("invalid spotify playlist URL".to_string());
+    }
+    let playlist_id = split[1];
+    let playlist_id = playlist_id
+        .split('?')
+        .find(|s| !s.is_empty())
+        .expect("Logical error: process_query's playlist_id contains items?");
+    
+    let client = SpotifyClient::new().await;
+    let client = match client {
+        Ok(client) => Arc::new(client),
+        Err(why) => return Err(why),
+    };
+    let tracks = match SpotifyClient::recommend_playlist(client, amount, playlist_id).await{
+        Ok(tracks) => tracks,
+        Err(why) => return Err(why),
+    };
+    return Ok(SpotifyClient::process_track_objects(tracks));
+}

+ 246 - 50
src/audio/spotify.rs

@@ -1,74 +1,270 @@
 use rspotify::{
-    client::SpotifyBuilder,
+    client::{
+        Spotify,
+        SpotifyBuilder,
+    },
     model::{
         Id,
         PlayableItem,
+        Market,
+        Country,
+        FullTrack,
+        SimplifiedTrack,
     },
     oauth2::CredentialsBuilder,
 };
 
+use rand::{
+    seq::IteratorRandom,
+    distributions::{
+        WeightedIndex,
+        Distribution,
+    },
+    Rng,
+};
+
+use tokio::{
+    self,
+    time::{
+        sleep,
+        Duration,
+    },
+};
+
 use super::{
     song::{
         Song,
         SongMetadata,
     },
     work::Work,
+    config::spotify_recommend as sr,
 };
 
-pub async fn get_playlist(playlist_id: &str) -> Result<Vec<(Song, Option<Work>)>, String>{
-    let creds = CredentialsBuilder::default()
-        .id("5f573c9620494bae87890c0f08a60293")
-        .secret("212476d9b0f3472eaa762d90b19b0ba8")
-        .build()
-        .unwrap();
-    let mut spotify = SpotifyBuilder::default()
-        .credentials(creds)
-        //.oauth(oauth)
-        .build()
-        .unwrap();
-    if let Err(why) = spotify.request_client_token().await{
-        println!("error: {}", why);
-    };
-    let playlist_id = Id::from_id(playlist_id);
-    let playlist_id = match playlist_id{
-        Ok(playlist_id) => playlist_id,
-        Err(why) => {
-            return Err(format!("spotify::get_playlist: {:?}", why));
+use std::sync::Arc;
+
+pub enum TrackObject{
+    FullTrack(FullTrack),
+    SimplifiedTrack(SimplifiedTrack),
+}
+
+impl TrackObject{
+    fn artist(&self) -> &str{
+        match self{
+            TrackObject::FullTrack(track) => &track.artists[0].name,
+            TrackObject::SimplifiedTrack(track) => &track.artists[0].name,
         }
-    };
-    let tracks = spotify.playlist(playlist_id, None, None).await;
-    let tracks = match tracks{
-        Ok(tracks) => tracks,
-        Err(why)=>{
-            println!("Error in spotify.get_playlist: {:?}", why);
-            return Err(format!("spotify::get_playlist: {:?}", why));
+    }
+    fn title(&self) -> &str{
+        match self{
+            TrackObject::FullTrack(track) => &track.name,
+            TrackObject::SimplifiedTrack(track) => &track.name,
+        }
+    }
+    fn duration(&self) -> u64{
+        match self{
+            TrackObject::FullTrack(track) => track.duration.as_secs(),
+            TrackObject::SimplifiedTrack(track) => track.duration.as_secs(),
+        }
+    }
+    fn album_id(&self) -> Option<&str>{
+        match self{
+            TrackObject::FullTrack(track) => track.album.id.as_deref(),
+            TrackObject::SimplifiedTrack(_) => None,
+        }
+    }
+    fn artist_id(&self) -> Option<&str>{
+        match self{
+            TrackObject::FullTrack(track) => track.artists[0].id.as_deref(),
+            TrackObject::SimplifiedTrack(track) => track.artists[0].id.as_deref(),
+        }
+    }
+}
+
+pub struct SpotifyClient{
+    client: Spotify,
+}
+
+impl SpotifyClient{
+    pub async fn new() -> Result<SpotifyClient, String> {
+        let creds = CredentialsBuilder::default()
+            .id("5f573c9620494bae87890c0f08a60293")
+            .secret("212476d9b0f3472eaa762d90b19b0ba8")
+            .build();
+        let creds = match creds{
+            Ok(creds) => creds,
+            Err(why) => return Err(why.to_string()),
+        };
+        let mut spotify = SpotifyBuilder::default()
+            .credentials(creds)
+            //.oauth(oauth)
+            .build()
+            .unwrap();
+        if let Err(why) = spotify.request_client_token().await{
+            return Err(why.to_string());
+        };
+        Ok(SpotifyClient{
+            client: spotify,
+        })
+    }
+    pub fn process_track_objects(tracks: Vec<TrackObject>) -> Vec<(Song, Option<Work>)> {
+        let mut songs = vec![];
+        for track in tracks.into_iter(){
+            let artist = track.artist();
+            let title = track.title();
+            let metadata = SongMetadata{
+                search_query: Some(SpotifyClient::get_query_string(artist, title)),
+                artist: Some(artist.to_string()),
+                title: Some(title.to_string()),
+                youtube_url: None,
+                duration: Some(track.duration()),
+            };
+            match Song::new_load(metadata){
+                Some(data) => songs.push(data),
+                None => continue,
+            };
         }
-    };
-    let mut songs = Vec::new();
-    let tracks = tracks.tracks.items;
-    for data in tracks.iter() {
-        let track = match &data.track{
-            Some(PlayableItem::Track(track)) => track,
-            Some(_) => continue,
-            None => continue,
+        songs
+    }
+    pub async fn get_playlist(&self, playlist_id: &str) -> Result<Vec<TrackObject>, String>{
+        let playlist_id = Id::from_id(playlist_id);
+        let playlist_id = match playlist_id{
+            Ok(playlist_id) => playlist_id,
+            Err(why) => {
+                return Err(format!("spotify::get_playlist: {:?}", why));
+            }
+        };
+        let tracks = self.client.playlist(playlist_id, None, None).await;
+        let tracks = match tracks{
+            Ok(tracks) => tracks,
+            Err(why)=>{
+                println!("Error in spotify.get_playlist: {:?}", why);
+                return Err(format!("spotify::get_playlist: {:?}", why));
+            }
         };
-        let artist = &track.artists[0].name;
-        let title = &track.name;
-        let metadata = SongMetadata{
-            search_query: Some(get_query_string(artist, title)),
-            artist: Some(artist.clone()),
-            title: Some(title.clone()),
-            youtube_url: None,
-            duration: Some(track.duration.as_secs()),
+        let items = tracks.tracks.items;
+        let mut tracks = vec![];
+        for data in items.into_iter() {
+            let track = match data.track{
+                Some(PlayableItem::Track(track)) => track,
+                Some(_) => continue,
+                None => continue,
+            };
+            tracks.push(TrackObject::FullTrack(track));
+        }
+        Ok(tracks)
+    }
+    async fn random_from_artist(&self, id: &str) -> Option<TrackObject>{
+        let id = match Id::from_id(id){
+            Ok(id) => id,
+            Err(why) => {
+                println!("Error {:?}",why);
+                return None;
+            }
         };
-        match Song::new_load(metadata){
-            Some(data) => songs.push(data),
-            None => continue,
+        let tracks = self.client.artist_top_tracks(id, &Market::Country(Country::Japan)).await;
+        match tracks {
+            Ok(tracks) => Some(TrackObject::FullTrack(tracks.into_iter().choose(&mut rand::thread_rng())?)),
+            Err(why) => {
+                println!("Error SpotifyClient::random_from_artist: {:?}", why);
+                None
+            },
+        }
+    }
+    async fn random_from_album(&self, id: &str) -> Option<TrackObject>{
+        let id = match Id::from_id(id){
+            Ok(id) => id,
+            Err(why) => {
+                println!("Error {:?}",why);
+                return None;
+            }
         };
+        let album = self.client.album(id).await;
+        match album {
+            Ok(album) => {
+                let tracks = album.tracks.items;
+                Some(TrackObject::SimplifiedTrack(tracks.into_iter().choose(&mut rand::thread_rng())?))
+            },
+            Err(why) => {
+                println!("Error SpotifyClient::random_from_album: {:?}", why);
+                None
+            },
+        }
     }
-    Ok(songs)
-}
+    //  -> Result<Vec<(Song, Option<Work>)>, String>
+    pub async fn recommend_playlist(client: Arc<SpotifyClient>, amount: usize, playlist_id: &str) -> Result<Vec<TrackObject>, String>{
+        let mut tasks = vec![];
+        let tracks = match client.get_playlist(playlist_id).await{
+            Ok(tracks) => tracks,
+            Err(why) => return Err(why),
+        };
+        let tracks = Arc::new(tracks);
+        //let tracks = tracks.sample(&mut rand::thread_rng(), amount);
+        for _ in 0..amount{
+            let tracks = tracks.clone();
+            let ind = rand::thread_rng().gen::<u32>() as usize % tracks.len();
+            
+            let client = client.clone();
+            let task = tokio::spawn(
+                async move {
+                    let track = &tracks[ind];
 
-fn get_query_string(artist: &str, title: &str) -> String{
-    format!("{} {} lyrics", artist, title)
+                    let weights = [sr::SAME_ARTIST, sr::EXPLORE_ALBUM, sr::EXPLORE_ARTIST];
+                    let option = WeightedIndex::new(&weights).unwrap().sample(&mut rand::thread_rng());
+
+                    match option{
+                        // find random song from track artist
+                        0 => {
+                            let artist = match track.artist_id(){
+                                Some(artist) => artist,
+                                None => return None,
+                            };
+                            client.random_from_artist(artist).await
+                        },
+                        // find random song from track album
+                        1 => {
+                            let album = match track.album_id(){
+                                Some(album) => album,
+                                None => {
+                                    println!("album not found");
+                                    return None
+                                },
+                            };
+                            client.random_from_album(album).await
+                        },
+                        // find random song from a random similar artist
+                        _ => {
+                            let artist = match track.artist_id(){
+                                Some(artist) => artist,
+                                None => return None,
+                            };
+                            let id = Id::from_id(artist).unwrap();
+                            let artists = client.client.artist_related_artists(id).await;
+                            let artists = match artists{
+                                Ok(artists) => artists[..5].to_vec(),
+                                Err(why) => {
+                                    println!("Error related artists: {:?}", why);
+                                    return None
+                                }
+                            };
+                            let id = &artists.iter().choose(&mut rand::thread_rng()).unwrap().id;
+                            client.random_from_artist(id).await
+                        },
+                    }
+                }
+            );
+            tasks.push(task);
+            sleep(Duration::from_millis(100)).await;
+        };
+        let mut tracks = vec![];
+        for task in tasks.into_iter(){
+            match task.await.unwrap(){
+                Some(track) => tracks.push(track),
+                None => continue,
+            }
+        }
+        Ok(tracks)
+    }
+    fn get_query_string(artist: &str, title: &str) -> String{
+        format!("{} {} lyrics", artist, title)
+    }
 }

+ 6 - 6
src/audio/loader.rs → src/audio/youtube_loader.rs

@@ -6,12 +6,12 @@ use super::{
     subprocess::ytdl,
 };
 
-pub struct Loader {
+pub struct YoutubeLoader {
     work: mpsc::Sender<Work>,
     kill: mpsc::Sender<()>,
 }
 
-impl Loader{
+impl YoutubeLoader{
     pub async fn add_work(& self, work: Work){
         if let Err(err) = self.work.send(work).await{
             println!("Error in Loader::add_work: {}", err.to_string());
@@ -39,13 +39,13 @@ impl Loader{
         };
     }
     
-    pub fn new() -> Loader{
+    pub fn new() -> YoutubeLoader{
         let (work_tx, work_sx) = mpsc::channel(200);
         let (kill_tx, kill_sx) = mpsc::channel(1);
         tokio::spawn(async move{
-            Loader::start_loader_loop(work_sx, kill_sx).await
+            YoutubeLoader::start_loader_loop(work_sx, kill_sx).await
         });
-        Loader{
+        YoutubeLoader{
             work: work_tx,
             kill: kill_tx,
         }
@@ -53,7 +53,7 @@ impl Loader{
     
     async fn start_loader_loop(work: mpsc::Receiver<Work>, mut kill: mpsc::Receiver<()>){
         let f = tokio::spawn(async move {
-            Loader::loader_loop(work).await
+            YoutubeLoader::loader_loop(work).await
         });
         kill.recv().await;
         f.abort();