1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use crate::database::{
    model::TikTokEmbedFlags,
    Database,
};
use anyhow::Context;
use rusqlite::{
    named_params,
    OptionalExtension,
    TransactionBehavior,
};
use serenity::model::prelude::*;

// Tiktok Embed SQL
//
// Both statements are pulled in with `include_str!`, so the SQL text is embedded
// at compile time from the `sql` directory two levels above this file.

/// Statement used to look up a guild's tiktok embed flags.
///
/// Callers in this file bind `:guild_id` and read column 0 of the returned row
/// as a `TikTokEmbedFlags`; a missing row is treated as "use the default flags".
const GET_TIKTOK_EMBED_FLAGS_SQL: &str = include_str!("../../sql/get_tiktok_embed_flags.sql");
/// Statement used to store a guild's tiktok embed flags.
///
/// Callers in this file bind `:guild_id` and `:flags`.
/// NOTE(review): presumably an upsert keyed on the guild id — confirm against the SQL file.
const SET_TIKTOK_EMBED_FLAGS_SQL: &str = include_str!("../../sql/set_tiktok_embed_flags.sql");

impl Database {
    /// Update the tiktok embed flags stored for a guild.
    ///
    /// The whole read-modify-write happens inside one `IMMEDIATE` transaction:
    /// the current flags are loaded (the default value is used when the guild
    /// has no row), `set_flags` is inserted, `unset_flags` is removed
    /// afterwards, and the result is written back before committing.
    ///
    /// # Returns
    /// Returns a tuple of `(old_flags, new_flags)`: the flags as they were
    /// before this call, followed by the flags that were written.
    ///
    /// # Errors
    /// Fails if database access fails, if either SQL statement fails, or if
    /// the transaction cannot be committed.
    pub async fn set_tiktok_embed_flags(
        &self,
        guild_id: GuildId,
        set_flags: TikTokEmbedFlags,
        unset_flags: TikTokEmbedFlags,
    ) -> anyhow::Result<(TikTokEmbedFlags, TikTokEmbedFlags)> {
        self.access_db(move |db| {
            // The id is bound to both statements, so convert it once up front.
            let guild_key = i64::from(guild_id);
            let txn = db.transaction_with_behavior(TransactionBehavior::Immediate)?;

            // Scope the cached statement so it is released before the next
            // statement is prepared and before the transaction is consumed.
            let old_flags = {
                let mut select = txn.prepare_cached(GET_TIKTOK_EMBED_FLAGS_SQL)?;
                let stored: Option<TikTokEmbedFlags> = select
                    .query_row(named_params! { ":guild_id": guild_key }, |row| row.get(0))
                    .optional()?;
                stored.unwrap_or_default()
            };

            // Insert first, remove second: removal is applied last.
            let mut new_flags = old_flags;
            new_flags.insert(set_flags);
            new_flags.remove(unset_flags);

            {
                let mut update = txn.prepare_cached(SET_TIKTOK_EMBED_FLAGS_SQL)?;
                update.execute(named_params! {
                    ":guild_id": guild_key,
                    ":flags": new_flags,
                })?;
            }

            txn.commit().context("failed to set tiktok embed")?;

            Ok((old_flags, new_flags))
        })
        .await?
    }

    /// Look up the tiktok embed flags stored for a guild.
    ///
    /// # Returns
    /// Returns the stored flags, or the default flags when the guild has no
    /// row.
    ///
    /// # Errors
    /// Fails if database access fails, if the statement cannot be prepared,
    /// or if the query itself fails.
    pub async fn get_tiktok_embed_flags(
        &self,
        guild_id: GuildId,
    ) -> anyhow::Result<TikTokEmbedFlags> {
        self.access_db(move |db| {
            let guild_key = i64::from(guild_id);
            let mut select = db.prepare_cached(GET_TIKTOK_EMBED_FLAGS_SQL)?;
            let stored: Option<TikTokEmbedFlags> = select
                .query_row(named_params! { ":guild_id": guild_key }, |row| row.get(0))
                .optional()
                .context("failed to read database")?;

            Ok(stored.unwrap_or_default())
        })
        .await?
    }
}