From 3c112364040654494eb471414e13d8a8f2781d7d Mon Sep 17 00:00:00 2001 From: Victor Mignot Date: Sat, 19 Nov 2022 21:19:48 -0500 Subject: [PATCH] Add badic mongo models and functions --- src/client.rs | 8 +- src/database/client.rs | 248 ++++++++++++++++++++++++++++++++++++++--- src/database/models.rs | 20 +++- src/main.rs | 2 +- 4 files changed, 250 insertions(+), 28 deletions(-) diff --git a/src/client.rs b/src/client.rs index 76a3335..617de83 100644 --- a/src/client.rs +++ b/src/client.rs @@ -26,12 +26,12 @@ use serenity::{prelude::GatewayIntents, Client as DiscordClient}; /// /// # } /// ``` -pub struct Client<'a> { +pub struct Client { /// The Serenity Discord Client discord_client: DiscordClient, /// The database client - database_client: DatabaseClient<'a>, + database_client: DatabaseClient, } /// Yorokobot connection credentials @@ -40,10 +40,10 @@ pub struct ClientCredentials<'a> { pub discord_token: &'a String, /// MongoDB connection string. - pub db_credentials: &'a DatabaseCredentials, + pub db_credentials: DatabaseCredentials, } -impl<'a> Client<'a> { +impl<'a> Client { /// Create a Yorokobot client pub async fn new(credentials: ClientCredentials<'a>) -> Result { let discord_client = match DiscordClient::builder( diff --git a/src/database/client.rs b/src/database/client.rs index 55a6879..9107e4b 100644 --- a/src/database/client.rs +++ b/src/database/client.rs @@ -1,44 +1,258 @@ -use mongodb::Client as MongoClient; +use std::collections::HashSet; +use futures::TryStreamExt; +use mongodb::{ + bson::{doc, from_bson, Bson, Document}, + Client as MongoClient, Collection, Database, +}; +use serde::{Deserialize, Serialize}; + +use crate::environment::get_env_variable; use crate::errors::ClientError; use crate::DatabaseCredentials; +use super::models::{YorokobotModel, COLLECTIONS_NAMES}; + /// Database client -pub struct Client<'a> { +pub struct Client { mongo_client: Option, - // database: Option, - credentials: &'a DatabaseCredentials, + database: Option, + credentials: DatabaseCredentials, } -impl<'a> Client<'a> { +impl Client { /// Create a new database client - pub fn new(credentials: &'a DatabaseCredentials) -> Client { + pub fn new(credentials: DatabaseCredentials) -> Client { return Client { credentials, mongo_client: None, - // database: None, + database: None, }; } /// Connect the client pub async fn connect(&mut self) -> Result<(), ClientError> { - self.mongo_client = match MongoClient::with_options(self.credentials.clone()) { + self.mongo_client = match MongoClient::with_options(self.credentials.to_owned()) { Ok(c) => Some(c), Err(e) => return Err(ClientError::Database(e)), }; - if let None = self.mongo_client.as_ref().unwrap().default_database() { - // TODO: - // Implement an Environment Variable catcher to wrap std::env::var() - // As we often call it and always have to use a match control flow + self.database = Some( + self.mongo_client + .as_ref() + .unwrap() + .database(get_env_variable("MONGO_DEFAULT_DB").as_str()), + ); - // TODO: - // Complete error kind to be more specific. - // Ex: DatabaseConnection + // TODO: + // Complete error kind to be more specific. + // Ex: DatabaseConnection - todo!(); - } + self.check_init_error().await; Ok(()) } + + async fn check_init_error(&mut self) { + self.check_collections_presence().await; + } + + async fn check_collections_presence(&mut self) { + let mut missing_collections: Vec<&str> = vec![]; + let collections: HashSet = match self + .database + .as_ref() + .unwrap() + .list_collection_names(None) + .await + { + Ok(n) => n.into_iter().collect(), + Err(e) => panic!("Could not list collections: {e}"), + }; + + for col in COLLECTIONS_NAMES { + if !collections.contains(col) { + missing_collections.push(col); + } + } + + if missing_collections.len() != 0 { + panic!( + "Missing the following the following collections: {}", + missing_collections.join(", ") + ); + } + } + + fn get_collection(&self) -> Collection { + self.database + .as_ref() + .expect("Could not retrieve database") + .collection(&T::get_collection_name()) + } + + fn get_typed_collection(&self) -> Collection { + self.database + .as_ref() + .expect("Could not retrieve database") + .collection::(&T::get_collection_name()) + } + + #[allow(dead_code)] + pub async fn get_by_id Deserialize<'de>>(&self, id: &str) -> T { + self.get_one(doc! {"_id": id}).await + } + + #[allow(dead_code)] + pub async fn get_one Deserialize<'de>>( + &self, + filter: Document, + ) -> T { + let result = self + .get_collection::() + .find_one(filter, None) + .await + .expect("Could not issue request") + .expect("Could not find matching data"); + + return from_bson(Bson::Document(result)).expect("Could not deserialize data"); + } + + #[allow(dead_code)] + pub async fn get_all Deserialize<'de>>( + &self, + filter: Option, + ) { + let mut result: Vec = vec![]; + + let mut cursor = match filter { + Some(f) => self.get_collection::().find(f, None).await, + None => self.get_collection::().find(doc! {}, None).await, + } + .expect("Could not issue request"); + + while let Some(document) = cursor.try_next().await.expect("Could not fetch results") { + result + .push(from_bson(Bson::Document(document)).expect("Could not deserialize document")); + } + } + + #[allow(dead_code)] + // TODO: Set true error handling + pub async fn insert_one(&self, document: T) -> Result<(), ()> { + match self + .get_typed_collection::() + .insert_one(document, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + // TODO: Set true error handling + pub async fn insert_many( + &self, + documents: Vec, + ) -> Result<(), ()> { + match self + .get_typed_collection::() + .insert_many(documents, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + // TODO: Set true error handling + pub async fn delete_one(&self, document: Document) -> Result<(), ()> { + match self + .get_typed_collection::() + .delete_one(document, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + // TODO: Set true error handling + pub async fn delete_by_id(&self, id: &str) -> Result<(), ()> { + match self + .get_typed_collection::() + .delete_one(doc! {"_id": id}, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + // TODO: Set true error handling + pub async fn delete_many(&self, document: Document) -> Result<(), ()> { + match self + .get_typed_collection::() + .delete_many(document, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + //TODO: Set true error handling + pub async fn update_one( + &self, + document: Document, + update: Document, + ) -> Result<(), ()> { + match self + .get_typed_collection::() + .update_one(document, update, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + //TODO: Set true error handling + pub async fn update_by_id( + &self, + document_id: &str, + update: Document, + ) -> Result<(), ()> { + match self + .get_typed_collection::() + .update_one(doc! {"_id": document_id}, update, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } + + #[allow(dead_code)] + //TODO: Set true error handling + pub async fn update_many( + &self, + document: Document, + update: Document, + ) -> Result<(), ()> { + match self + .get_typed_collection::() + .update_many(document, update, None) + .await + { + Ok(_) => Ok(()), + Err(_) => Err(()), + } + } } diff --git a/src/database/models.rs b/src/database/models.rs index 011c40c..d5ff000 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -4,13 +4,10 @@ use serde::{Deserialize, Serialize}; -/// All the models within Mongo COllections -pub enum CollectionModels { - /// Discord Guild - Guild(Guild), +pub const COLLECTIONS_NAMES: [&str; 2] = ["guilds", "tags"]; - /// Yorokobot tags - Tag(Tag), +pub trait YorokobotModel { + fn get_collection_name() -> String; } /// Settings for a server @@ -35,3 +32,14 @@ pub struct Tag { is_nsfw: bool, subscribers: Vec, } + +impl YorokobotModel for Guild { + fn get_collection_name() -> String { + return "guilds".to_string(); + } +} +impl YorokobotModel for Tag { + fn get_collection_name() -> String { + return "traits".to_string(); + } +} diff --git a/src/main.rs b/src/main.rs index fa2e3c0..137e513 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,7 +34,7 @@ async fn main() -> std::process::ExitCode { let credentials = ClientCredentials { discord_token: &discord_token, - db_credentials: &db_credentials, + db_credentials: db_credentials, }; let mut client = match Client::new(credentials).await {