From 14be0e93b49a42d2a3e63258a0aeff211f304fab Mon Sep 17 00:00:00 2001 From: Victor Mignot Date: Mon, 13 Feb 2023 00:05:50 +0100 Subject: [PATCH] Implementing core functionnalities of the bot --- Cargo.toml | 2 +- log4rs_config.yaml | 15 +- src/client.rs | 32 +- src/database/client.rs | 178 +++++----- src/database/models.rs | 24 +- src/discord.rs | 1 + src/discord/commands.rs | 14 +- src/discord/commands/bulk_create_tag.rs | 22 -- src/discord/commands/commands.rs | 49 --- src/discord/commands/commons.rs | 77 +++++ src/discord/commands/create_tag.rs | 141 +++++--- src/discord/commands/delete_tag.rs | 139 +++++++- src/discord/commands/list_tags.rs | 98 +++++- src/discord/commands/source_code.rs | 74 ++++- src/discord/commands/subscribe.rs | 150 +++++++++ src/discord/commands/tag_notify.rs | 164 ++++++++++ src/discord/event_handler.rs | 178 +++++++++- src/discord/message_builders.rs | 2 + src/discord/message_builders/embed_builder.rs | 113 +++++++ .../message_builders/selector_builder.rs | 305 ++++++++++++++++++ src/environment.rs | 2 +- src/lib.rs | 2 +- 22 files changed, 1516 insertions(+), 266 deletions(-) delete mode 100644 src/discord/commands/bulk_create_tag.rs delete mode 100644 src/discord/commands/commands.rs create mode 100644 src/discord/commands/commons.rs create mode 100644 src/discord/commands/subscribe.rs create mode 100644 src/discord/commands/tag_notify.rs create mode 100644 src/discord/message_builders.rs create mode 100644 src/discord/message_builders/embed_builder.rs create mode 100644 src/discord/message_builders/selector_builder.rs diff --git a/Cargo.toml b/Cargo.toml index ba41558..0495a73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -serenity = { version="0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "collector" ] } +serenity = { version="0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "model", "collector", "absolute_ratelimits" ] } tokio = { version = "1", features = ["macros", "rt-multi-thread"] } mongodb = { version = "2.3.0", default-features = false, features = ["tokio-runtime"] } serde = { version = "1.0", features = [ "derive" ] } diff --git a/log4rs_config.yaml b/log4rs_config.yaml index 5b9ff44..52b5452 100644 --- a/log4rs_config.yaml +++ b/log4rs_config.yaml @@ -38,18 +38,19 @@ appenders: count: 20 root: - level: warn + level: info appenders: - stdout loggers: - bot_infos: - level: info - appenders: - - rolling_debug + serenity: + level: error - bot_warn_errors: - level: warn + tracing: + level: error + + logs: + level: info appenders: - rolling_logs diff --git a/src/client.rs b/src/client.rs index 36cfe91..26ed17b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::Mutex; use serenity::{prelude::GatewayIntents, Client as SerenityClient}; @@ -6,21 +8,33 @@ use crate::{database::Client as DatabaseClient, environment::get_env_variable}; use crate::discord::event_handler::Handler; +/// The Yorokobot client. +/// +/// To launch Yorokobot, you have to set the following environment variables: +/// - DISCORD_TOKEN: The secret Discord provide you when creating a new bot in the +/// Discord Developper websites. +/// - MONGODB_URI: The connection string to your Mongo database. +/// - MONGODB_DATABASE: The database to use in your Mongo instance (falcultative if given in the +/// MONGODB_URI connection string). pub struct Client { serenity_client: SerenityClient, - database_client: Arc>, } impl Client { + /// Create a new Yorokobot instance pub async fn new() -> Self { - let database_client = Arc::new(Mutex::new(DatabaseClient::new())); - database_client.clone().lock().unwrap().connect(); + let mut database_client = DatabaseClient::new(); + + database_client.connect().await; let discord_token = get_env_variable("DISCORD_TOKEN"); - let intents = GatewayIntents::GUILD_MESSAGES | GatewayIntents::MESSAGE_CONTENT; + let intents = GatewayIntents::GUILD_MESSAGES + | GatewayIntents::MESSAGE_CONTENT + | GatewayIntents::GUILD_MESSAGE_REACTIONS; let event_handler = Handler { - database: database_client.clone(), + database: Arc::new(database_client), + users_with_running_selector: Arc::new(Mutex::new(HashSet::new())), }; let serenity_client = match SerenityClient::builder(discord_token, intents) @@ -31,12 +45,10 @@ impl Client { Err(e) => panic!("Failed to instantiate Discord Client: {e}"), }; - Client { - serenity_client, - database_client, - } + Client { serenity_client } } + /// Start the bot, connecting it to the database and the Discord API. pub async fn start(&mut self) { if let Err(e) = self.serenity_client.start().await { panic!("Could not connect the bot: {e}"); diff --git a/src/database/client.rs b/src/database/client.rs index 1764ea0..c9906c2 100644 --- a/src/database/client.rs +++ b/src/database/client.rs @@ -1,15 +1,15 @@ use std::collections::HashSet; use futures::TryStreamExt; +use log::{error, info, trace}; use mongodb::{ - bson::{doc, from_bson, Bson, Document}, + bson::{doc, from_bson, to_document, Bson, Document}, Client as MongoClient, Collection, Database, }; -use serde::{Deserialize, Serialize}; use crate::environment::get_env_variable; -use super::models::{YorokobotModel, COLLECTIONS_NAMES}; +use super::models::{YorokobotCollection, COLLECTIONS_NAMES}; /// Database client pub struct Client { @@ -26,42 +26,58 @@ impl Client { } } + pub fn get_database(&self) -> &Database { + match &self.database { + Some(db) => db, + None => { + error!("Tried to access to the database before instantiating it"); + panic!(); + } + } + } + /// Connect the client pub async fn connect(&mut self) { self.mongo_client = match MongoClient::with_uri_str(get_env_variable("MONGODB_URI")).await { - Ok(c) => Some(c), - Err(e) => panic!("Failed to connect to Mongo database: {e}"), + Ok(c) => { + info!("Successfully connected to the Mongo database"); + Some(c) + } + Err(e) => { + error!("Failed to connect to the Mongo database: {e:#?}"); + panic!(); + } }; - self.database = Some( - self.mongo_client - .as_ref() - .unwrap() - .database(get_env_variable("MONGODB_DATABASE").as_str()), - ); - - // TODO: - // Complete error kind to be more specific. - // Ex: DatabaseConnection + self.database = match &self.mongo_client { + Some(c) => Some(c.database(get_env_variable("MONGODB_DATABASE").as_str())), + None => { + error!("Got an unexpected None from self.database"); + panic!(); + } + }; self.check_init_error().await; } async fn check_init_error(&mut self) { - self.check_collections_presence().await; + info!("Launching initial database checks"); + + let database = self.get_database(); + + Self::check_collections_presence(database).await; } - async fn check_collections_presence(&mut self) { + async fn check_collections_presence(db: &Database) { + trace!("Starting the collections presence check for the database"); + let mut missing_collections: Vec<&str> = vec![]; - let collections: HashSet = match self - .database - .as_ref() - .unwrap() - .list_collection_names(None) - .await - { + let collections: HashSet = match db.list_collection_names(None).await { Ok(n) => n.into_iter().collect(), - Err(e) => panic!("Could not list collections: {e}"), + Err(e) => { + error!("Failed to get the collections for the database: {e:#?}"); + panic!(); + } }; for col in COLLECTIONS_NAMES { @@ -71,71 +87,71 @@ impl Client { } if !missing_collections.is_empty() { - panic!( - "Missing the following the following collections: {}", + error!( + "Missing the following collections in the Database: {}", missing_collections.join(", ") ); + panic!(); } } - fn get_collection(&self) -> Collection { - self.database - .as_ref() - .expect("Could not retrieve database") - .collection(&T::get_collection_name()) + #[allow(dead_code)] + fn get_collection(&self) -> Collection { + self.get_database().collection(&T::get_collection_name()) } - fn get_typed_collection(&self) -> Collection { - self.database - .as_ref() - .expect("Could not retrieve database") + fn get_typed_collection(&self) -> Collection { + self.get_database() .collection::(&T::get_collection_name()) } - #[allow(dead_code)] - pub async fn get_by_id Deserialize<'de>>(&self, id: &str) -> T { - self.get_one(doc! {"_id": id}).await + pub async fn get_by_id(&self, id: &str) -> Result, ()> { + self.get_one::(doc! {"_id": id}).await } - #[allow(dead_code)] - pub async fn get_one Deserialize<'de>>( - &self, - filter: Document, - ) -> T { - let result = self - .get_collection::() + pub async fn get_one(&self, filter: Document) -> Result, ()> { + match self + .get_typed_collection::() .find_one(filter, None) .await - .expect("Could not issue request") - .expect("Could not find matching data"); - - from_bson(Bson::Document(result)).expect("Could not deserialize data") + { + Ok(e) => Ok(e), + Err(_) => Err(()), + } } #[allow(dead_code)] - pub async fn get_all Deserialize<'de>>( + pub async fn get_all( &self, filter: Option, - ) -> Vec { - let mut result: Vec = vec![]; + ) -> Result, ()> { + let mut matching_docs: Vec = vec![]; - let mut cursor = match filter { + let result = match filter { Some(f) => self.get_collection::().find(f, None).await, None => self.get_collection::().find(doc! {}, None).await, - } - .expect("Could not issue request"); + }; - while let Some(document) = cursor.try_next().await.expect("Could not fetch results") { - result - .push(from_bson(Bson::Document(document)).expect("Could not deserialize document")); + let mut cursor = match result { + Ok(c) => c, + Err(_) => return Err(()), + }; + + while let Some(document) = match cursor.try_next().await { + Ok(e) => e, + Err(_) => return Err(()), + } { + match from_bson(Bson::Document(document)) { + Ok(d) => matching_docs.push(d), + Err(_) => return Err(()), + } } - return result; + Ok(matching_docs) } - #[allow(dead_code)] // TODO: Set true error handling - pub async fn insert_one(&self, document: T) -> Result<(), ()> { + pub async fn insert_one(&self, document: T) -> Result<(), ()> { match self .get_typed_collection::() .insert_one(document, None) @@ -146,12 +162,9 @@ impl Client { } } - #[allow(dead_code)] // TODO: Set true error handling - pub async fn insert_many( - &self, - documents: Vec, - ) -> Result<(), ()> { + #[allow(dead_code)] + pub async fn insert_many(&self, documents: Vec) -> Result<(), ()> { match self .get_typed_collection::() .insert_many(documents, None) @@ -163,7 +176,8 @@ impl Client { } // TODO: Set true error handling - pub async fn delete_one(&self, document: Document) -> Result { + #[allow(dead_code)] + pub async fn delete_one(&self, document: Document) -> Result { match self .get_typed_collection::() .delete_one(document, None) @@ -174,9 +188,9 @@ impl Client { } } - #[allow(dead_code)] // TODO: Set true error handling - pub async fn delete_by_id(&self, id: &str) -> Result<(), ()> { + #[allow(dead_code)] + pub async fn delete_by_id(&self, id: &str) -> Result<(), ()> { match self .get_typed_collection::() .delete_one(doc! {"_id": id}, None) @@ -187,9 +201,9 @@ impl Client { } } - #[allow(dead_code)] // TODO: Set true error handling - pub async fn delete_many(&self, document: Document) -> Result<(), ()> { + #[allow(dead_code)] + pub async fn delete_many(&self, document: Document) -> Result<(), ()> { match self .get_typed_collection::() .delete_many(document, None) @@ -200,16 +214,20 @@ impl Client { } } - #[allow(dead_code)] //TODO: Set true error handling - pub async fn update_one( + pub async fn update_one( &self, - document: Document, + object: T, update: Document, ) -> Result<(), ()> { + let serialized_doc = match to_document(&object) { + Ok(d) => d, + Err(_) => return Err(()), + }; + match self .get_typed_collection::() - .update_one(document, update, None) + .update_one(serialized_doc, update, None) .await { Ok(_) => Ok(()), @@ -217,9 +235,9 @@ impl Client { } } - #[allow(dead_code)] //TODO: Set true error handling - pub async fn update_by_id( + #[allow(dead_code)] + pub async fn update_by_id( &self, document_id: &str, update: Document, @@ -234,9 +252,9 @@ impl Client { } } - #[allow(dead_code)] //TODO: Set true error handling - pub async fn update_many( + #[allow(dead_code)] + pub async fn update_many( &self, document: Document, update: Document, diff --git a/src/database/models.rs b/src/database/models.rs index 9e5237c..987af91 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -1,29 +1,31 @@ //! Bot data models -#![allow(dead_code)] - -use mongodb::bson::oid::ObjectId; use serde::{Deserialize, Serialize}; -pub const COLLECTIONS_NAMES: [&str; 1] = ["tags"]; +pub const COLLECTIONS_NAMES: [&str; 1] = ["guilds"]; -pub trait YorokobotModel { +pub trait YorokobotCollection: for<'de> Deserialize<'de> + Serialize + Unpin + Send + Sync { fn get_collection_name() -> String; } /// Tags -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct Tag { - #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] - pub id: Option, pub name: String, - pub guild_id: String, pub is_nsfw: bool, pub subscribers: Vec, } -impl YorokobotModel for Tag { +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Guild { + #[serde(rename = "_id")] + pub id: String, + pub ban_list: Vec, + pub tags: Vec, +} + +impl YorokobotCollection for Guild { fn get_collection_name() -> String { - "tags".to_string() + COLLECTIONS_NAMES[0].to_string() } } diff --git a/src/discord.rs b/src/discord.rs index c1f06ed..fd7c559 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -1,3 +1,4 @@ pub mod event_handler; +mod message_builders; mod commands; diff --git a/src/discord/commands.rs b/src/discord/commands.rs index 6791455..ef1e33c 100644 --- a/src/discord/commands.rs +++ b/src/discord/commands.rs @@ -1,6 +1,16 @@ -mod bulk_create_tag; -pub mod commands; +// mod bulk_create_tag; +pub mod commons; mod create_tag; mod delete_tag; mod list_tags; mod source_code; +mod subscribe; +mod tag_notify; + +pub use commons::BotCommand; +pub use create_tag::CreateTagCommand; +pub use delete_tag::DeleteTagCommand; +pub use list_tags::ListTagCommand; +pub use source_code::SourceCodeCommand; +pub use subscribe::SubscribeCommand; +pub use tag_notify::TagNotifyCommand; diff --git a/src/discord/commands/bulk_create_tag.rs b/src/discord/commands/bulk_create_tag.rs deleted file mode 100644 index 07ce50c..0000000 --- a/src/discord/commands/bulk_create_tag.rs +++ /dev/null @@ -1,22 +0,0 @@ -use serenity::{builder::CreateApplicationCommand, model::prelude::command::CommandOptionType}; - -pub fn register( - command: &mut CreateApplicationCommand, - max_args_number: u32, -) -> &mut CreateApplicationCommand { - command - .name("bulk_create_tag") - .description("Add multiples tags"); - - for i in 0..max_args_number { - command.create_option(|option| { - option - .name(format!("tag{}", i + 1)) - .description("A new tag to add") - .kind(CommandOptionType::String) - .required(i == 0) - }); - } - - command -} diff --git a/src/discord/commands/commands.rs b/src/discord/commands/commands.rs deleted file mode 100644 index f88846b..0000000 --- a/src/discord/commands/commands.rs +++ /dev/null @@ -1,49 +0,0 @@ -use serenity::{ - async_trait, - builder::CreateInteractionResponseData, - model::prelude::{ - command::{Command, CommandOptionType}, - interaction::application_command::ApplicationCommandInteraction, - }, - prelude::Context, -}; - -use crate::database::Client as DatabaseClient; - -pub struct BotCommandOption { - pub name: String, - pub description: String, - pub kind: CommandOptionType, - pub required: bool, -} - -#[async_trait] -pub trait BotCommand { - fn new(context: ApplicationCommandInteraction) -> Self; - fn name() -> String; - fn description() -> String; - fn options_list() -> Vec; - async fn run(&self, response: &mut CreateInteractionResponseData, database: &DatabaseClient); - - async fn register(context: &Context) { - match Command::create_global_application_command(context, |command| { - let mut new_command = command.name(Self::name()).description(Self::description()); - - for opt in Self::options_list() { - new_command = new_command.create_option(|option| { - option - .name(opt.name) - .description(opt.description) - .kind(opt.kind) - .required(opt.required) - }); - } - new_command - }) - .await - { - Ok(_) => println!("Successfully registered the {} command", Self::name()), - Err(_) => panic!("Failed to register the {} command", Self::name()), - }; - } -} diff --git a/src/discord/commands/commons.rs b/src/discord/commands/commons.rs new file mode 100644 index 0000000..c294130 --- /dev/null +++ b/src/discord/commands/commons.rs @@ -0,0 +1,77 @@ +use log::{info, warn}; +use std::sync::Arc; + +use serenity::{ + async_trait, + model::prelude::{ + command::{Command, CommandOptionType}, + interaction::application_command::{ + ApplicationCommandInteraction, CommandDataOption, CommandDataOptionValue, + }, + }, + prelude::Context, +}; + +use crate::database::Client as DatabaseClient; + +#[derive(Debug)] +pub enum CommandExecutionError { + ArgumentExtractionError(String), + ArgumentDeserializationError(String), + DatabaseQueryError(String), + ContextRetrievalError(String), + DiscordAPICallError(String), + SelectorError(String), + UnknownCommand(String), +} + +pub struct BotCommandOption { + pub name: String, + pub description: String, + pub kind: CommandOptionType, + pub required: bool, +} + +#[async_trait] +pub trait BotCommand { + fn new(interaction: ApplicationCommandInteraction) -> Self; + fn name() -> String; + fn description() -> String; + fn options_list() -> Vec; + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError>; + + fn extract_option_value( + options: &[CommandDataOption], + index: usize, + ) -> Option { + let serialized_opt = options.get(index)?; + serialized_opt.resolved.clone() + } + + async fn register(context: &Context) { + info!("Starting the {} command registration", Self::name()); + match Command::create_global_application_command(context, |command| { + let mut new_command = command.name(Self::name()).description(Self::description()); + + for opt in Self::options_list() { + new_command = new_command.create_option(|option| { + option + .name(opt.name) + .description(opt.description) + .kind(opt.kind) + .required(opt.required) + }); + } + new_command + }) + .await + { + Ok(_) => info!("Successfully registered the {} command", Self::name()), + Err(_) => warn!("Failed to register the {} command", Self::name()), + }; + } +} diff --git a/src/discord/commands/create_tag.rs b/src/discord/commands/create_tag.rs index 231100e..3e1d15c 100644 --- a/src/discord/commands/create_tag.rs +++ b/src/discord/commands/create_tag.rs @@ -1,22 +1,29 @@ -use mongodb::bson::doc; +use log::debug; +use std::sync::Arc; + +use mongodb::bson::{doc, to_bson}; use serenity::{ async_trait, - builder::CreateInteractionResponseData, model::{ application::interaction::application_command::ApplicationCommandInteraction, prelude::{ - command::CommandOptionType, interaction::application_command::CommandDataOptionValue, + command::CommandOptionType, + interaction::{application_command::CommandDataOptionValue, InteractionResponseType}, }, }, + prelude::Context, }; -use crate::database::{models::Tag, Client as DatabaseClient}; +use crate::database::{ + models::{Guild, Tag}, + Client as DatabaseClient, +}; -use super::commands::{BotCommand, BotCommandOption}; +use super::commons::{BotCommand, BotCommandOption, CommandExecutionError}; -struct CreateTagCommand { - context: ApplicationCommandInteraction, +pub struct CreateTagCommand { + interaction: ApplicationCommandInteraction, } #[async_trait] @@ -38,49 +45,93 @@ impl BotCommand for CreateTagCommand { }] } - fn new(context: ApplicationCommandInteraction) -> Self { - CreateTagCommand { context } + fn new(interaction: ApplicationCommandInteraction) -> Self { + debug!("Creating a new CreateTagCommand object"); + CreateTagCommand { interaction } } - async fn run(&self, response: &mut CreateInteractionResponseData, database: &DatabaseClient) { - let arg = self - .context - .data - .options - .get(0) - .expect("Missing option") - .resolved - .as_ref() - .expect("Could not deserialize option"); + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError> { + // Extract tag_name parameter + let tag_name = match self.interaction.data.options.get(0) { + Some(a) => match &a.resolved { + Some(r) => match r { + CommandDataOptionValue::String(r_str) => Ok(r_str), + _ => Err(CommandExecutionError::ArgumentDeserializationError( + "Received non String argument for the CreateTagCommand".to_string(), + )), + }, + None => Err(CommandExecutionError::ArgumentDeserializationError( + "Could not deserialize the argument for the CreateTagCommand".to_string(), + )), + }, + None => Err(CommandExecutionError::ArgumentExtractionError( + "Failed to get the CreateTagCommand argument".to_string(), + )), + }?; - let guild_id = self - .context - .guild_id - .expect("Could not fetch guild id") - .to_string(); + // Extract guild id from Serenity context + let guild_id = match self.interaction.guild_id { + Some(a) => Ok(a.to_string()), + None => Err(CommandExecutionError::ContextRetrievalError( + "Could not fetch guild id from issued command".to_string(), + )), + }?; - if let CommandDataOptionValue::String(tag_name) = arg { - let matching_tags = database - .get_all::(Some(doc! {"name": tag_name, "guild_id": guild_id.as_str()})) - .await; + let guild = match database.get_by_id::(&guild_id).await { + Ok(query) => match query { + Some(r) => Ok(r), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to retrieve the guild where the command was issued".to_string(), + )), + }, + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Could not access to the database".to_string(), + )), + }?; - if !matching_tags.is_empty() { - response.content("This tag already exist for this server"); - } else { - match database - .insert_one(Tag { - id: None, - name: tag_name.to_string(), - guild_id: guild_id.clone(), - is_nsfw: false, - subscribers: vec![], - }) - .await - { - Ok(_) => response.content("Tag successfully created."), - Err(_) => response.content("Error creating the tag"), - }; - } + let matching_tag = guild.tags.iter().find(|t| t.name == *tag_name); + + let response_content = if matching_tag.is_some() { + String::from("This tag already exist for this server.") + } else { + let mut new_tags = guild.tags.clone(); + new_tags.push(Tag { + name: tag_name.to_string(), + is_nsfw: false, + subscribers: vec![], + }); + + match database + .update_one( + guild, + doc! {"$set": { "tags": to_bson(&new_tags).unwrap() }}, + ) + .await + { + Ok(_) => Ok(String::from("Tag successfully created.")), + Err(_) => Err(CommandExecutionError::DatabaseQueryError( + "Could not add new tag to the database".to_string(), + )), + }? + }; + + match self + .interaction + .create_interaction_response(context.http, |response| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message| message.content(response_content)) + }) + .await + { + Ok(_) => Ok(()), + Err(_e) => Err(CommandExecutionError::DiscordAPICallError( + "Failed to answer to the initial command".to_string(), + )), } } } diff --git a/src/discord/commands/delete_tag.rs b/src/discord/commands/delete_tag.rs index 63e8c6b..be74257 100644 --- a/src/discord/commands/delete_tag.rs +++ b/src/discord/commands/delete_tag.rs @@ -1,14 +1,129 @@ -use serenity::{builder::CreateApplicationCommand, model::prelude::command::CommandOptionType}; +use std::sync::Arc; -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("delete_tag") - .description("Delete a tag") - .create_option(|option| { - option - .name("tag") - .description("The tag to delete") - .kind(CommandOptionType::String) - .required(true) - }) +use log::debug; +use mongodb::bson::{doc, to_bson}; + +use crate::database::{models::Guild, Client as DatabaseClient}; + +use serenity::{ + async_trait, + model::prelude::{ + command::CommandOptionType, + interaction::{ + application_command::{ApplicationCommandInteraction, CommandDataOptionValue}, + InteractionResponseType, + }, + }, + prelude::Context, +}; + +use super::{ + commons::{BotCommandOption, CommandExecutionError}, + BotCommand, +}; + +pub struct DeleteTagCommand { + interaction: ApplicationCommandInteraction, +} + +#[async_trait] +impl BotCommand for DeleteTagCommand { + fn new(interaction: ApplicationCommandInteraction) -> Self { + debug!("Creating a new DeleteTagCommand object"); + DeleteTagCommand { interaction } + } + + fn name() -> String { + String::from("delete_tag") + } + + fn description() -> String { + String::from("Delete a tag from the server") + } + + fn options_list() -> Vec { + vec![BotCommandOption { + name: String::from("tag"), + description: String::from("The tag to delete"), + kind: CommandOptionType::String, + required: true, + }] + } + + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError> { + let tag_name = match self.interaction.data.options.get(0) { + Some(a) => match &a.resolved { + Some(r) => match r { + CommandDataOptionValue::String(r_str) => Ok(r_str), + _ => Err(CommandExecutionError::ArgumentDeserializationError( + "Received non String argument for DeleteTagCommand".to_string(), + )), + }, + None => Err(CommandExecutionError::ArgumentDeserializationError( + "Failed to deserialize argument for DeleteTagCommand".to_string(), + )), + }, + None => Err(CommandExecutionError::ArgumentExtractionError( + "Failed to find argument in DeleteTagCommand".to_string(), + )), + }?; + + let guild_id = match self.interaction.guild_id { + Some(r) => Ok(r.to_string()), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to extract guild id from current context".to_string(), + )), + }?; + + let guild = match database.get_by_id::(&guild_id).await { + Ok(query) => match query { + Some(r) => Ok(r), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to retrieve the guild where the command was issued".to_string(), + )), + }, + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Failed to access to the database".to_string(), + )), + }?; + + let response: String = + if let Some(tag_index) = guild.tags.iter().position(|t| t.name == *tag_name) { + let mut clone_tags = guild.tags.clone(); + clone_tags.remove(tag_index); + + match database + .update_one::( + guild, + doc! {"$set": { "tags": to_bson(&clone_tags).unwrap() }}, + ) + .await + { + Ok(_) => Ok(String::from("Successfully remove the tag")), + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Failed to remove tag from the database".to_string(), + )), + }? + } else { + String::from("No matching tag for this server.") + }; + + match self + .interaction + .create_interaction_response(context.http, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message| message.content(response)) + }) + .await + { + Ok(()) => Ok(()), + Err(_e) => Err(CommandExecutionError::DiscordAPICallError( + "Failed to answer the initial command".to_string(), + )), + } + } } diff --git a/src/discord/commands/list_tags.rs b/src/discord/commands/list_tags.rs index 67b1101..faea1fc 100644 --- a/src/discord/commands/list_tags.rs +++ b/src/discord/commands/list_tags.rs @@ -1,5 +1,97 @@ -use serenity::builder::CreateApplicationCommand; +use std::sync::Arc; -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command.name("list_tags").description("List your own tags") +use log::debug; + +use serenity::{ + async_trait, + model::prelude::interaction::{ + application_command::ApplicationCommandInteraction, InteractionResponseType, + }, + prelude::Context, +}; + +use crate::database::{models::Guild, Client as DatabaseClient}; + +use super::{ + commons::{BotCommandOption, CommandExecutionError}, + BotCommand, +}; + +pub struct ListTagCommand { + interaction: ApplicationCommandInteraction, +} + +#[async_trait] +impl BotCommand for ListTagCommand { + fn name() -> String { + String::from("list_tags") + } + + fn description() -> String { + String::from("List available tags") + } + + fn options_list() -> Vec { + vec![] + } + + fn new(interaction: ApplicationCommandInteraction) -> Self { + debug!("Creating a new ListTagCommand object"); + ListTagCommand { interaction } + } + + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError> { + let guild_id = match self.interaction.guild_id { + Some(id) => Ok(id.to_string()), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to extract guild id from current context".to_string(), + )), + }?; + + let guild = match database.get_by_id::(&guild_id).await { + Ok(query) => match query { + Some(r) => Ok(r), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to retrieve the guild where the command was issued".to_string(), + )), + }, + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Failed to access to the database".to_string(), + )), + }?; + + let mut response_content: String; + + if guild.tags.is_empty() { + response_content = String::from("No tag available on this server."); + } else { + response_content = String::from("Available tags on this server:\n"); + for tag in guild.tags { + if response_content.len() + tag.name.len() < 1995 { + response_content.push_str(format!("`{}` ", tag.name.as_str()).as_str()); + } else { + response_content.push_str("..."); + } + } + } + + match self + .interaction + .create_interaction_response(context.http, |response| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message| message.content(response_content)) + }) + .await + { + Ok(()) => Ok(()), + Err(_) => Err(CommandExecutionError::DiscordAPICallError( + "Failed to answer to the initial command".to_string(), + )), + } + } } diff --git a/src/discord/commands/source_code.rs b/src/discord/commands/source_code.rs index ce2e919..daef252 100644 --- a/src/discord/commands/source_code.rs +++ b/src/discord/commands/source_code.rs @@ -1,17 +1,65 @@ -use serenity::builder::{CreateApplicationCommand, CreateInteractionResponseData}; +use log::debug; +use std::sync::Arc; -pub fn register(command: &mut CreateApplicationCommand) -> &mut CreateApplicationCommand { - command - .name("source_code") - .description("Access to the bot source code") +use serenity::{ + async_trait, + model::prelude::interaction::{ + application_command::ApplicationCommandInteraction, InteractionResponseType, + }, + prelude::Context, +}; + +use crate::{ + database::Client as DatabaseClient, + discord::message_builders::embed_builder::EmbedMessageBuilder, +}; + +use super::commons::{BotCommand, BotCommandOption, CommandExecutionError}; + +pub struct SourceCodeCommand { + interaction: ApplicationCommandInteraction, } -pub fn run<'a, 'b>( - response: &'a mut CreateInteractionResponseData<'b>, -) -> &'a mut CreateInteractionResponseData<'b> { - response.embed(|embed| { - embed - .title("Yorokobot repository") - .description("https://sr.ht/~victormignot/yorokobot/") - }) +#[async_trait] +impl BotCommand for SourceCodeCommand { + fn name() -> String { + String::from("about") + } + + fn description() -> String { + String::from("Display the bot credentials") + } + + fn options_list() -> Vec { + vec![] + } + + fn new(interaction: ApplicationCommandInteraction) -> Self { + debug!("Creating a new SourceCodeCommand object"); + SourceCodeCommand { interaction } + } + + async fn run( + &self, + context: Context, + _: Arc, + ) -> Result<(), CommandExecutionError> { + let embed_builder = EmbedMessageBuilder::new(&context).await?; + match self + .interaction + .create_interaction_response(context.http, |response| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|r| { + r.set_embed(embed_builder.create_bot_credentials_embed()) + }) + }) + .await + { + Ok(()) => Ok(()), + Err(_e) => Err(CommandExecutionError::DiscordAPICallError( + "Failed to answer to the issued command".to_string(), + )), + } + } } diff --git a/src/discord/commands/subscribe.rs b/src/discord/commands/subscribe.rs new file mode 100644 index 0000000..cd7ae85 --- /dev/null +++ b/src/discord/commands/subscribe.rs @@ -0,0 +1,150 @@ +use std::{collections::HashSet, sync::Arc}; + +use log::debug; +use mongodb::bson::{doc, to_bson}; + +use serenity::{ + async_trait, + model::prelude::interaction::{ + application_command::ApplicationCommandInteraction, InteractionResponseType, + }, + prelude::Context, +}; + +use super::{ + commons::{BotCommandOption, CommandExecutionError}, + BotCommand, +}; +use crate::{ + database::{client::Client as DatabaseClient, models::Guild}, + discord::message_builders::selector_builder::EmbedSelector, +}; + +pub struct SubscribeCommand { + interaction: ApplicationCommandInteraction, +} + +#[async_trait] +impl BotCommand for SubscribeCommand { + fn new(interaction: ApplicationCommandInteraction) -> Self { + debug!("Creating a new SubscribeCommand object"); + SubscribeCommand { interaction } + } + + fn name() -> String { + String::from("subscribe") + } + + fn description() -> String { + String::from("Subscribe to a selection of tags") + } + + fn options_list() -> Vec { + vec![] + } + + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError> { + self.interaction + .create_interaction_response(&context.http, |response| { + response.kind(InteractionResponseType::DeferredChannelMessageWithSource) + }) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to answer with a temporary response".to_string(), + ) + })?; + + let user_id = self.interaction.user.id.to_string(); + let guild_id = match self.interaction.guild_id { + Some(id) => Ok(id.to_string()), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to extract guild id from current context".to_string(), + )), + }?; + + let guild = match database.get_by_id::(&guild_id).await { + Ok(query) => match query { + Some(r) => Ok(r), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to retrieve the guild where the command was issued".to_string(), + )), + }, + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Failed to access to the database".to_string(), + )), + }?; + + let mut available_tags_str: Vec = Vec::new(); + let mut user_subscriptions: HashSet = HashSet::new(); + + for tag in guild.tags.as_slice() { + available_tags_str.push(tag.name.clone()); + if tag.subscribers.contains(&user_id) { + user_subscriptions.insert(tag.name.clone()); + } + } + + let mut selector = EmbedSelector::new( + "Select the tags you want to subscribe to".to_string(), + Self::description(), + &self.interaction, + &context, + available_tags_str, + Some(user_subscriptions), + ); + + let user_selection = selector.get_user_selection().await?; + + if let Some(selection) = user_selection { + let mut cloned_tags = guild.tags.clone(); + for t in cloned_tags.as_mut_slice() { + if selection.contains(&t.name) && !t.subscribers.contains(&user_id) { + t.subscribers.push(user_id.clone()); + } else if !selection.contains(&t.name) && t.subscribers.contains(&user_id) { + t.subscribers.retain(|x| *x != user_id); + } + } + + database + .update_one::( + guild, + doc! {"$set": {"tags": to_bson(&cloned_tags.to_vec()).unwrap()}}, + ) + .await + .map_err(|_| { + CommandExecutionError::DatabaseQueryError( + "Failed to update user subscriptions in database".to_string(), + ) + })?; + } + + let mut response = match self.interaction.get_interaction_response(&context).await { + Ok(r) => Ok(r), + Err(_e) => Err(CommandExecutionError::ContextRetrievalError( + "Failed to fetch initial interaction response".to_string(), + )), + }?; + + response.delete_reactions(&context).await.map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to remove reactions from initial response".to_string(), + ) + })?; + + response + .edit(&context, |msg| msg.suppress_embeds(true).content("Done !")) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to edit content of the original response message".to_string(), + ) + })?; + + Ok(()) + } +} diff --git a/src/discord/commands/tag_notify.rs b/src/discord/commands/tag_notify.rs new file mode 100644 index 0000000..f09c2d0 --- /dev/null +++ b/src/discord/commands/tag_notify.rs @@ -0,0 +1,164 @@ +use std::sync::Arc; + +use log::debug; + +use serenity::{ + async_trait, + model::prelude::{ + interaction::{ + application_command::ApplicationCommandInteraction, InteractionResponseType, + }, + UserId, + }, + prelude::{Context, Mentionable}, +}; + +use super::{ + commons::{BotCommandOption, CommandExecutionError}, + BotCommand, +}; + +use crate::{ + database::{models::Guild, Client as DatabaseClient}, + discord::message_builders::selector_builder::EmbedSelector, +}; + +pub struct TagNotifyCommand { + interaction: ApplicationCommandInteraction, +} + +#[async_trait] +impl BotCommand for TagNotifyCommand { + fn name() -> String { + debug!("Creating a new TagNotifyCommand object"); + String::from("notify") + } + + fn new(interaction: ApplicationCommandInteraction) -> Self { + TagNotifyCommand { interaction } + } + + fn description() -> String { + String::from("Ping users according to a list of tag") + } + + fn options_list() -> Vec { + vec![] + } + + async fn run( + &self, + context: Context, + database: Arc, + ) -> Result<(), CommandExecutionError> { + self.interaction + .create_interaction_response(&context.http, |response| { + response.kind(InteractionResponseType::DeferredChannelMessageWithSource) + }) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to answer with a temporary response".to_string(), + ) + })?; + + let guild_id = match self.interaction.guild_id { + Some(id) => Ok(id.to_string()), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to extract guild id from current context".to_string(), + )), + }?; + + let guild = match database.get_by_id::(&guild_id).await { + Ok(query) => match query { + Some(r) => Ok(r), + None => Err(CommandExecutionError::ContextRetrievalError( + "Failed to retrieve the guild where the command was issued".to_string(), + )), + }, + Err(()) => Err(CommandExecutionError::DatabaseQueryError( + "Failed to access to the database".to_string(), + )), + }?; + + let mut available_tags_str: Vec = Vec::new(); + + for tag in guild.tags.as_slice() { + available_tags_str.push(tag.name.clone()); + } + + let mut selector = EmbedSelector::new( + "Select the tags to notify".to_string(), + Self::description(), + &self.interaction, + &context, + available_tags_str, + None, + ); + + let user_selection = selector.get_user_selection().await?; + + if let Some(selection) = user_selection { + let mut answer = String::new(); + for selected_tag in selection { + let t = match guild.tags.iter().find(|s| s.name == selected_tag) { + Some(t) => Ok(t), + None => Err(CommandExecutionError::ArgumentExtractionError( + "No matching tag found for selection".to_string(), + )), + }?; + + for user_id in t.subscribers.as_slice() { + match user_id.parse::() { + Ok(id) => match &UserId(id).to_user(&context).await { + Ok(e) => { + answer += &e.mention().to_string(); + } + Err(_e) => {} + }, + Err(_e) => {} + } + } + } + let mut response = match self.interaction.get_interaction_response(&context).await { + Ok(r) => Ok(r), + Err(_e) => Err(CommandExecutionError::ContextRetrievalError( + "Failed to fetch initial interaction response".to_string(), + )), + }?; + + response.delete_reactions(&context).await.map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to remove reactions from initial response".to_string(), + ) + })?; + + response + .edit(&context, |msg| { + msg.suppress_embeds(true); + msg.content(if answer.is_empty() { + "Nobody to ping for your selection".to_string() + } else { + answer + }) + }) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to edit content of the original response message".to_string(), + ) + })?; + } else { + self.interaction + .delete_original_interaction_response(&context) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to delete the original interaction message".to_string(), + ) + })?; + } + + Ok(()) + } +} diff --git a/src/discord/event_handler.rs b/src/discord/event_handler.rs index 7bf826c..7d9a11f 100644 --- a/src/discord/event_handler.rs +++ b/src/discord/event_handler.rs @@ -1,34 +1,194 @@ -use std::sync::{Arc, Mutex}; +use log::{debug, info, warn}; +use std::{collections::HashSet, sync::Arc}; +use tokio::sync::Mutex; -use crate::database::Client as DatabaseClient; +use super::commands::{ + BotCommand, CreateTagCommand, ListTagCommand, SourceCodeCommand, TagNotifyCommand, +}; + +use crate::{ + database::{models::Guild, Client as DatabaseClient}, + discord::commands::{commons::CommandExecutionError, DeleteTagCommand, SubscribeCommand}, +}; use serenity::{ async_trait, model::gateway::Ready, - model::prelude::{interaction::Interaction, ResumedEvent}, + model::prelude::{ + interaction::{ + application_command::ApplicationCommandInteraction, Interaction, + InteractionResponseType, + }, + ResumedEvent, UserId, + }, prelude::{Context, EventHandler}, }; -const MAX_ARGS_NUMBER: u32 = 25; +async fn answer_with_error(ctx: &Context, interaction: &ApplicationCommandInteraction) { + const ERROR_MSG: &str = "Internal error while executing your command."; + let result = match interaction.get_interaction_response(&ctx).await { + Ok(mut m) => { + m.edit(&ctx, |msg| msg.suppress_embeds(true).content(ERROR_MSG)) + .await + } + Err(_) => { + interaction + .create_interaction_response(&ctx, |msg| { + msg.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|c| c.content(ERROR_MSG)) + }) + .await + } + }; + + if let Err(e) = result { + warn!("Could not reply to user with error message: {e:#?}"); + } +} pub struct Handler { - pub database: Arc>, + pub database: Arc, + pub users_with_running_selector: Arc>>, +} + +impl Handler { + async fn reset_selector_list(&self) { + self.users_with_running_selector.lock().await.clear(); + debug!("List of user with running selector reset"); + } } #[async_trait] impl EventHandler for Handler { async fn ready(&self, ctx: Context, ready: Ready) { - println!("Successfully connected as {}", ready.user.name); + // Unregister all application commands before registering them again + if let Ok(commands) = ctx.http.get_global_application_commands().await { + for command in commands { + match ctx + .http + .delete_global_application_command(*command.id.as_u64()) + .await + { + Ok(_) => debug!("Successfully unregistered {} command.", command.name), + Err(_) => debug!("Failed to unregister {} command", command.name), + } + } + } - // TODO: Register commands + info!("Successfully connected as {}", ready.user.name); + CreateTagCommand::register(&ctx).await; + SourceCodeCommand::register(&ctx).await; + ListTagCommand::register(&ctx).await; + DeleteTagCommand::register(&ctx).await; + TagNotifyCommand::register(&ctx).await; + SubscribeCommand::register(&ctx).await; } async fn resume(&self, _: Context, _: ResumedEvent) { - println!("Successfully reconnected.") + self.reset_selector_list().await; + info!("Successfully reconnected.") } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + let mut failed_command = false; + if let Interaction::ApplicationCommand(command) = interaction { - println!("Received command {}", command.data.name); + info!("Received command {}", command.data.name); + + if let Some(guild_id) = command.guild_id { + if let Ok(None) = self + .database + .get_by_id::(&guild_id.to_string()) + .await + { + let new_guild = Guild { + id: guild_id.to_string(), + ban_list: vec![], + tags: vec![], + }; + + match self.database.insert_one(new_guild).await { + Ok(()) => info!("Unregistered guild: Adding it to the database"), + Err(()) => { + warn!("Error adding a new guild in the database"); + failed_command = true; + } + }; + } + } + + if failed_command { + answer_with_error(&ctx, &command).await; + return; + } + + let command_result = match command.data.name.as_str() { + "create_tag" => { + CreateTagCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } + "about" => { + SourceCodeCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } + "list_tags" => { + ListTagCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } + "delete_tag" => { + DeleteTagCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } + "notify" => { + let mut users_selector = self.users_with_running_selector.lock().await; + + if !users_selector.contains(&command.user.id) { + users_selector.insert(command.user.id); + drop(users_selector); + + TagNotifyCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } else { + Err(CommandExecutionError::SelectorError( + "User has already a selector running".to_string(), + )) + } + } + "subscribe" => { + let mut users_selector = self.users_with_running_selector.lock().await; + + if !users_selector.contains(&command.user.id) { + users_selector.insert(command.user.id); + drop(users_selector); + + SubscribeCommand::new(command.clone()) + .run(ctx.clone(), self.database.clone()) + .await + } else { + Err(CommandExecutionError::SelectorError( + "User has already a selector running".to_string(), + )) + } + } + + _ => Err(CommandExecutionError::UnknownCommand( + "Received an unknon command from Discord".to_string(), + )), + }; + + if let Err(e) = command_result { + warn!("Error while executing command: {e:#?}"); + answer_with_error(&ctx, &command).await; + } else { + let mut users_lock = self.users_with_running_selector.lock().await; + if users_lock.contains(&command.user.id) { + users_lock.remove(&command.user.id); + } + } } } } diff --git a/src/discord/message_builders.rs b/src/discord/message_builders.rs new file mode 100644 index 0000000..f644051 --- /dev/null +++ b/src/discord/message_builders.rs @@ -0,0 +1,2 @@ +pub mod embed_builder; +pub mod selector_builder; diff --git a/src/discord/message_builders/embed_builder.rs b/src/discord/message_builders/embed_builder.rs new file mode 100644 index 0000000..0f2e6b7 --- /dev/null +++ b/src/discord/message_builders/embed_builder.rs @@ -0,0 +1,113 @@ +use std::collections::{HashMap, HashSet}; + +use serenity::{ + builder::{CreateEmbed, CreateEmbedAuthor, CreateEmbedFooter}, + prelude::Context, +}; + +use crate::discord::commands::commons::CommandExecutionError; + +const HTML_COLOR_CODE: u32 = 0xffffff; + +pub struct EmbedMessageBuilder { + color_code: u32, + embed_author: String, + embed_avatar_url: String, +} + +impl EmbedMessageBuilder { + pub async fn new(context: &Context) -> Result { + let bot_user = context.http.get_current_user().await.map_err(|_e| { + CommandExecutionError::ContextRetrievalError( + "Failed to get current bot user".to_string(), + ) + })?; + + let embed_author = bot_user.name.clone(); + + let embed_avatar_url = bot_user + .avatar_url() + .unwrap_or_else(|| bot_user.default_avatar_url()); + + Ok(EmbedMessageBuilder { + color_code: HTML_COLOR_CODE, + embed_author, + embed_avatar_url, + }) + } + + fn create_embed_author(&self) -> CreateEmbedAuthor { + CreateEmbedAuthor(HashMap::new()) + .name(self.embed_author.clone()) + .icon_url(self.embed_avatar_url.clone()) + .to_owned() + } + + fn create_embed_base(&self) -> CreateEmbed { + CreateEmbed(HashMap::new()) + .set_author(self.create_embed_author()) + .colour(self.color_code) + .to_owned() + } + + fn create_embed_pages_footer( + &self, + current_page: usize, + total_pages: usize, + ) -> CreateEmbedFooter { + CreateEmbedFooter(HashMap::new()) + .text(format!("Page {current_page}/{total_pages}")) + .to_owned() + } + + pub fn create_bot_credentials_embed(&self) -> CreateEmbed { + self.create_embed_base() + .title("Credentials") + .fields(vec![ + ("Creator", "This bot was created by Victor Mignot (aka Dala).\nMastodon link: https://fosstodon.org/@Dala", false), + ("License", "The source code is under the GNU Affero General Public License v3.0", false), + ("Source code", "https://sr.ht/~victormignot/yorokobot/", false), + ("Illustrator's Twitter", "https://twitter.com/MaewenMitzuki", false), + ("Developer's Discord Server", "https://discord.gg/e8Q4zQbJb3", false), + ]) + .to_owned() + } + + pub fn create_selection_embed( + &self, + title: &str, + description: &str, + selectable: &[String], + selected: &HashSet, + pages: usize, + current_page: usize, + ) -> CreateEmbed { + let mut content = "".to_string(); + + const SELECTION_EMOTES: [&str; 10] = [ + ":zero:", ":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", + ":eight:", ":nine:", + ]; + + for (value, emote) in selectable.iter().zip(SELECTION_EMOTES.iter()) { + content.push_str( + format!( + "{emote} - *{value}* {}\n", + if selected.contains(value) { + ":white_check_mark:" + } else { + "" + } + ) + .as_str(), + ); + } + + self.create_embed_base() + .title(title) + .description(description) + .field("Selection", content, false) + .set_footer(self.create_embed_pages_footer(current_page, pages)) + .to_owned() + } +} diff --git a/src/discord/message_builders/selector_builder.rs b/src/discord/message_builders/selector_builder.rs new file mode 100644 index 0000000..6c4d51c --- /dev/null +++ b/src/discord/message_builders/selector_builder.rs @@ -0,0 +1,305 @@ +use std::collections::HashSet; + +use futures::StreamExt; +use log::warn; +use serenity::{ + collector::EventCollectorBuilder, + model::prelude::{ + interaction::application_command::ApplicationCommandInteraction, Event, EventType, Message, + ReactionType, + }, + prelude::Context, +}; + +use crate::discord::commands::commons::CommandExecutionError; + +use super::embed_builder::EmbedMessageBuilder; + +const MAX_SELECTABLE_PER_PAGE: usize = 10; + +/* + * Match the Discord 0 to 9 icon (that are encoded with three utf-8 character) + * We can notice here that only the first character change between the emote. + * This character is in fact the encoded number related to the emote. + * Ex: '\u{0030}' = '0', \u{0031} = '1' ... + */ +const SELECTION_EMOTES: [&str; 10] = [ + "\u{0030}\u{FE0F}\u{20E3}", + "\u{0031}\u{FE0F}\u{20E3}", + "\u{0032}\u{FE0F}\u{20E3}", + "\u{0033}\u{FE0F}\u{20E3}", + "\u{0034}\u{FE0F}\u{20E3}", + "\u{0035}\u{FE0F}\u{20E3}", + "\u{0036}\u{FE0F}\u{20E3}", + "\u{0037}\u{FE0F}\u{20E3}", + "\u{0038}\u{FE0F}\u{20E3}", + "\u{0039}\u{FE0F}\u{20E3}", +]; + +const PREVIOUS_PAGE_EMOTE: &str = "\u{2B05}"; +const NEXT_PAGE_EMOTE: &str = "\u{27A1}"; +const CONFIRM_EMOTE: &str = "\u{2705}"; +const CANCEL_EMOTE: &str = "\u{274C}"; + +pub struct EmbedSelector<'a> { + interaction: &'a ApplicationCommandInteraction, + context: &'a Context, + embed_answer: Option, + title: String, + description: String, + selection: HashSet, + selectable: Vec, + page_number: usize, + current_page: usize, + aborted: bool, +} + +impl<'a> EmbedSelector<'a> { + pub fn new( + title: String, + description: String, + interaction: &'a ApplicationCommandInteraction, + context: &'a Context, + selectable: Vec, + initial_selection: Option>, + ) -> Self { + let selection = match initial_selection { + Some(r) => r, + None => HashSet::new(), + }; + + let mut selector = EmbedSelector { + interaction, + context, + embed_answer: None, + title, + description, + selection, + selectable: selectable.clone(), + page_number: (selectable.len() as f32 / MAX_SELECTABLE_PER_PAGE as f32).ceil() as usize, + current_page: 1, + aborted: false, + }; + + selector.selectable.sort(); + selector + } + + pub async fn get_user_selection( + &mut self, + ) -> Result>, CommandExecutionError> { + let embed_builder = EmbedMessageBuilder::new(self.context).await?; + + match self + .interaction + .edit_original_interaction_response(self.context, |response| { + response.set_embed(embed_builder.create_selection_embed( + &self.title, + &self.description, + &self.selectable[0..self.get_current_page_choice_number()], + &self.selection, + self.page_number, + 1, + )) + }) + .await + { + Ok(m) => { + self.embed_answer = Some(m); + self.refresh_reactions().await?; + Ok(self.wait_selector_end().await?) + } + Err(_e) => Err(CommandExecutionError::DiscordAPICallError( + "Failed to edit original interaction responnse".to_string(), + )), + } + } + + async fn refresh_reactions(&self) -> Result<(), CommandExecutionError> { + if let Some(answer) = &self.embed_answer { + answer.delete_reactions(self.context).await.map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to delete reaction on the current selector".to_string(), + ) + })?; + + for emote in SELECTION_EMOTES[0..self.get_current_page_choice_number()] + .iter() + .chain( + [ + PREVIOUS_PAGE_EMOTE, + NEXT_PAGE_EMOTE, + CANCEL_EMOTE, + CONFIRM_EMOTE, + ] + .iter(), + ) + { + match &self.embed_answer { + Some(a) => a + .react(self.context, ReactionType::Unicode(emote.to_string())) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to add reactions on the current selector".to_string(), + ) + }), + None => Err(CommandExecutionError::SelectorError( + "Failed to refresh the reactions of the current selector".to_string(), + )), + }?; + } + Ok(()) + } else { + Err(CommandExecutionError::SelectorError( + "Tried to delete reaction from a non existent message".to_string(), + )) + } + } + + async fn wait_selector_end( + &mut self, + ) -> Result>, CommandExecutionError> { + let answer = match &self.embed_answer { + Some(a) => Ok(a), + None => Err(CommandExecutionError::SelectorError( + "Tried to start collector before sending it".to_string(), + )), + }?; + + let mut collector = EventCollectorBuilder::new(self.context) + .add_event_type(EventType::ReactionAdd) + .add_user_id(self.interaction.user.id) + .add_message_id(*answer.id.as_u64()) + .build() + .map_err(|_e| { + CommandExecutionError::SelectorError( + "Failed to build the EventCollector".to_string(), + ) + })?; + + while let Some(reaction_event) = collector.next().await { + let reaction = match *reaction_event { + Event::ReactionAdd(ref r) => &r.reaction, + ref e => { + warn!("Received unexpected event in the selector EventCollector: {e:#?}"); + continue; + } + }; + + if let ReactionType::Unicode(reaction_str) = &reaction.emoji { + match reaction_str.as_str() { + PREVIOUS_PAGE_EMOTE => self.previous_page().await?, + NEXT_PAGE_EMOTE => self.next_page().await?, + CANCEL_EMOTE => { + self.aborted = true; + break; + } + CONFIRM_EMOTE => { + break; + } + r => { + if SELECTION_EMOTES.contains(&reaction_str.as_str()) { + // Extract the number part of the emote unicode + let selected_nb = match r.chars().next() { + Some(c) => match c.to_digit(10) { + Some(nb) => nb as usize, + None => { + warn!("Failed to cast emote code into number"); + continue; + } + }, + None => { + warn!("Received empty emote code in ReactionAdd event"); + continue; + } + }; + + let tag_index: usize = + (self.current_page - 1) * MAX_SELECTABLE_PER_PAGE + selected_nb; + + let tag = &self.selectable[tag_index]; + + if self.selection.contains(tag) { + self.selection.remove(tag); + } else { + self.selection.insert(tag.clone()); + } + } + + reaction.delete(self.context).await.map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to delete reaction from selector".to_string(), + ) + })?; + + self.refresh_embed_selection().await?; + } + } + } + } + collector.stop(); + + if self.aborted { + Ok(None) + } else { + Ok(Some(self.selection.clone())) + } + } + + async fn next_page(&mut self) -> Result<(), CommandExecutionError> { + if self.current_page != self.page_number { + self.current_page += 1; + self.refresh_embed_selection().await?; + self.refresh_reactions().await?; + } + Ok(()) + } + + async fn previous_page(&mut self) -> Result<(), CommandExecutionError> { + if self.current_page != 1 { + self.current_page -= 1; + self.refresh_embed_selection().await?; + self.refresh_reactions().await?; + } + Ok(()) + } + + fn get_current_page_choice_number(&self) -> usize { + if self.page_number == self.current_page + && self.selectable.len() % MAX_SELECTABLE_PER_PAGE != 0 + { + self.selectable.len() % MAX_SELECTABLE_PER_PAGE + } else { + MAX_SELECTABLE_PER_PAGE + } + } + + async fn refresh_embed_selection(&mut self) -> Result<(), CommandExecutionError> { + let embed_builder = EmbedMessageBuilder::new(self.context).await?; + + let curr_choices = self.selectable + [(self.current_page - 1) * 10..(self.current_page * 10).min(self.selectable.len())] + .to_vec(); + + self.embed_answer + .as_mut() + .unwrap() + .edit(self.context, |msg| { + msg.set_embed(embed_builder.create_selection_embed( + &self.title, + &self.description, + &curr_choices, + &self.selection, + self.page_number, + self.current_page, + )) + }) + .await + .map_err(|_e| { + CommandExecutionError::DiscordAPICallError( + "Failed to edit selector content".to_string(), + ) + }) + } +} diff --git a/src/environment.rs b/src/environment.rs index 0cbfacd..1dd52af 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -3,7 +3,7 @@ use log::error; use std::env; -/// Get the environment variiable var_name or panic +/// Fetch the `var_name` environment variable. pub fn get_env_variable(var_name: &str) -> String { match env::var(var_name) { Ok(v) => v, diff --git a/src/lib.rs b/src/lib.rs index f2183ff..ccb1e9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ //! //! [`Serenity`]: https://github.com/serenity-rs/serenity -//#![deny(missing_docs)] +#![deny(missing_docs)] #![deny(warnings)] mod client;