pikadick/checks/enabled.rs

1use crate::ClientDataKey;
2use parking_lot::Mutex;
3use pikadick_slash_framework::{
4    BoxFuture,
5    Command,
6    Reason as SlashReason,
7};
8use serenity::{
9    client::Context,
10    framework::standard::{
11        macros::check,
12        Args,
13        Check,
14        CommandGroup,
15        CommandOptions,
16        Reason,
17    },
18    model::{
19        application::CommandInteraction,
20        prelude::*,
21    },
22};
23use std::{
24    collections::HashMap,
25    sync::Arc,
26};
27use tracing::error;
28
29type MutexGuard<'a, T> = parking_lot::lock_api::MutexGuard<'a, parking_lot::RawMutex, T>;
30
31#[derive(Debug, Default, Clone)]
32pub struct EnabledCheckData {
33    /// The set of all commands as strings.
34    command_name_cache: Arc<Mutex<Vec<String>>>,
35
36    /// A way to look up commands by CommandOptions and fn addr.
37    ///
38    /// XXX MASSIVE HACK XXX
39    /// This uses the addresses of the `names` field of [`CommandOptions`] impls in order to compare commands.
40    /// This is necessary as this is all serenity gives to [`Check`] functions.
41    /// The only reason this works is because the serenity macro for making commands is used as the only way to make commands,
42    /// as it recreates each names array for each command uniquely.
43    command_lookup: Arc<Mutex<HashMap<usize, String>>>,
44}
45
46impl EnabledCheckData {
47    /// Make a new [`EnabledCheckData`].
48    pub fn new() -> Self {
49        EnabledCheckData {
50            command_name_cache: Arc::new(Mutex::new(Vec::new())),
51            command_lookup: Arc::new(Mutex::new(HashMap::new())),
52        }
53    }
54
55    /// Add a group to have its commands enabled/disabled.
56    pub fn add_groups(&self, groups: &[&CommandGroup]) {
57        let mut names = Vec::with_capacity(4);
58        let mut queue = Vec::new();
59        let mut command_name_cache = self.command_name_cache.lock();
60        let mut command_lookup = self.command_lookup.lock();
61
62        for group in groups.iter() {
63            command_name_cache.reserve(group.options.commands.len());
64            queue.reserve(group.options.commands.len());
65
66            queue.extend(group.options.commands.iter().map(|command| (0, command)));
67
68            while let Some((depth, command)) = queue.pop() {
69                let has_enabled_check = command
70                    .options
71                    .checks
72                    .iter()
73                    .any(|&check| checks_are_same(check, &ENABLED_CHECK));
74
75                if !has_enabled_check {
76                    continue;
77                }
78
79                names.truncate(depth);
80
81                let command_name = command
82                    .options
83                    .names
84                    .first()
85                    .expect("command does not have a name");
86
87                names.push(*command_name);
88                let command_name = names.join("::");
89                command_lookup.insert(
90                    command.options.names.as_ptr() as usize,
91                    command_name.clone(),
92                );
93                command_name_cache.push(command_name);
94
95                queue.extend(
96                    command
97                        .options
98                        .sub_commands
99                        .iter()
100                        .map(|command| (depth + 1, command)),
101                );
102            }
103        }
104    }
105
106    pub fn get_command_name_from_options(&self, options: &CommandOptions) -> Option<String> {
107        self.command_lookup
108            .lock()
109            .get(&(options.names.as_ptr() as usize))
110            .cloned()
111    }
112
113    /// Returns a mutex guard to the list of command names.
114    pub fn get_command_names(&self) -> MutexGuard<'_, Vec<String>> {
115        self.command_name_cache.lock()
116    }
117}
118
119/// Check if 2 [`Check`]s are the same.
120///
121/// This includes their function pointers, though the argument references do not necessarily have to point to the same check.
122/// This is necessary as `serenity`'s `PartialEq` for [`Check`] only checks the name.
123fn checks_are_same(check1: &Check, check2: &Check) -> bool {
124    let is_same_partial_eq = check1 == check2;
125
126    // HACK:
127    // Use pointers as ids since checks have no unique identifiers
128    let function1_addr = check1.function as usize;
129    let function2_addr = check2.function as usize;
130    let is_same_function_ptr = function1_addr == function2_addr;
131
132    is_same_partial_eq && is_same_function_ptr
133}
134
135#[check]
136#[name("Enabled")]
137pub async fn enabled_check(
138    ctx: &Context,
139    msg: &Message,
140    _args: &mut Args,
141    opts: &CommandOptions,
142) -> Result<(), Reason> {
143    let guild_id = match msg.guild_id {
144        Some(id) => id,
145        None => {
146            // Let's not care about dms for now.
147            // They'll probably need special handling anyways.
148            // This will also probably only be useful in Group DMs,
149            // which I don't think bots can participate in anyways.
150            return Ok(());
151        }
152    };
153
154    let data_lock = ctx.data.read().await;
155    let client_data = data_lock
156        .get::<ClientDataKey>()
157        .expect("missing client data");
158    let enabled_check_data = client_data.enabled_check_data.clone();
159    let db = client_data.db.clone();
160    drop(data_lock);
161
162    let command_name = match enabled_check_data.get_command_name_from_options(opts) {
163        Some(name) => name,
164        None => {
165            // The name is not present.
166            // This is fine, as that just means we haven't added it to the translation map
167            // aka it is not disable-able
168            return Ok(());
169        }
170    };
171
172    match db.is_command_disabled(guild_id, &command_name).await {
173        Ok(true) => Err(Reason::User("Command Disabled".to_string())),
174        Ok(false) => Ok(()),
175        Err(e) => {
176            error!("failed to read disabled commands: {}", e);
177            // DB failure, return false to be safe.
178            // Avoid being specific with error to prevent users from spamming knowingly.
179            Err(Reason::Unknown)
180        }
181    }
182}
183
184/// Check if a command is enabled via slash framework
185pub fn create_slash_check<'a>(
186    ctx: &'a Context,
187    interaction: &'a CommandInteraction,
188    command: &'a Command,
189) -> BoxFuture<'a, Result<(), SlashReason>> {
190    Box::pin(async move {
191        let guild_id = match interaction.guild_id {
192            Some(id) => id,
193            None => {
194                // Let's not care about dms for now.
195                // They'll probably need special handling anyways.
196                // This will also probably only be useful in Group DMs,
197                // which I don't think bots can participate in anyways.
198                return Ok(());
199            }
200        };
201
202        let data_lock = ctx.data.read().await;
203        let client_data = data_lock
204            .get::<ClientDataKey>()
205            .expect("missing client data");
206        let db = client_data.db.clone();
207        drop(data_lock);
208
209        let command_name = command.name();
210
211        match db.is_command_disabled(guild_id, command_name).await {
212            Ok(true) => Err(SlashReason::new_user("Command Disabled.".to_string())),
213            Ok(false) => Ok(()),
214            Err(e) => {
215                error!("failed to read disabled commands: {}", e);
216                // DB failure, return false to be safe.
217                // Avoid being specific with error to prevent users from spamming knowingly.
218                Err(SlashReason::new_unknown())
219            }
220        }
221    })
222}