pikadick/database/
tiktok_embed.rs

1use crate::database::{
2    model::TikTokEmbedFlags,
3    Database,
4};
5use anyhow::Context;
6use rusqlite::{
7    named_params,
8    OptionalExtension,
9    TransactionBehavior,
10};
11use serenity::model::prelude::*;
12
13// Tiktok Embed SQL
14const GET_TIKTOK_EMBED_FLAGS_SQL: &str = include_str!("../../sql/get_tiktok_embed_flags.sql");
15const SET_TIKTOK_EMBED_FLAGS_SQL: &str = include_str!("../../sql/set_tiktok_embed_flags.sql");
16
17impl Database {
18    /// Set the flags for tiktok embeds.
19    ///
20    /// # Returns
21    /// Returns the old flags and new flags in a tuple in that order.
22    pub async fn set_tiktok_embed_flags(
23        &self,
24        guild_id: GuildId,
25        set_flags: TikTokEmbedFlags,
26        unset_flags: TikTokEmbedFlags,
27    ) -> anyhow::Result<(TikTokEmbedFlags, TikTokEmbedFlags)> {
28        self.access_db(move |db| {
29            let txn = db.transaction_with_behavior(TransactionBehavior::Immediate)?;
30            let old_flags: TikTokEmbedFlags = txn
31                .prepare_cached(GET_TIKTOK_EMBED_FLAGS_SQL)?
32                .query_row(
33                    named_params! {
34                        ":guild_id": i64::from(guild_id),
35                    },
36                    |row| row.get(0),
37                )
38                .optional()?
39                .unwrap_or_default();
40
41            let mut new_flags = old_flags;
42            new_flags.insert(set_flags);
43            new_flags.remove(unset_flags);
44
45            txn.prepare_cached(SET_TIKTOK_EMBED_FLAGS_SQL)?
46                .execute(named_params! {
47                    ":guild_id": i64::from(guild_id),
48                    ":flags": new_flags,
49                })?;
50
51            txn.commit().context("failed to set tiktok embed")?;
52
53            Ok((old_flags, new_flags))
54        })
55        .await?
56    }
57
58    /// Get the tiktok embed flags.
59    pub async fn get_tiktok_embed_flags(
60        &self,
61        guild_id: GuildId,
62    ) -> anyhow::Result<TikTokEmbedFlags> {
63        self.access_db(move |db| {
64            db.prepare_cached(GET_TIKTOK_EMBED_FLAGS_SQL)?
65                .query_row(
66                    named_params! {
67                        ":guild_id": i64::from(guild_id),
68                    },
69                    |row| row.get(0),
70                )
71                .optional()
72                .context("failed to read database")
73                .map(|v| v.unwrap_or_default())
74        })
75        .await?
76    }
77}