pikadick/checks/
enabled.rs1use crate::ClientDataKey;
2use parking_lot::Mutex;
3use pikadick_slash_framework::{
4 BoxFuture,
5 Command,
6 Reason as SlashReason,
7};
8use serenity::{
9 client::Context,
10 framework::standard::{
11 macros::check,
12 Args,
13 Check,
14 CommandGroup,
15 CommandOptions,
16 Reason,
17 },
18 model::{
19 application::CommandInteraction,
20 prelude::*,
21 },
22};
23use std::{
24 collections::HashMap,
25 sync::Arc,
26};
27use tracing::error;
28
29type MutexGuard<'a, T> = parking_lot::lock_api::MutexGuard<'a, parking_lot::RawMutex, T>;
30
31#[derive(Debug, Default, Clone)]
32pub struct EnabledCheckData {
33 command_name_cache: Arc<Mutex<Vec<String>>>,
35
36 command_lookup: Arc<Mutex<HashMap<usize, String>>>,
44}
45
46impl EnabledCheckData {
47 pub fn new() -> Self {
49 EnabledCheckData {
50 command_name_cache: Arc::new(Mutex::new(Vec::new())),
51 command_lookup: Arc::new(Mutex::new(HashMap::new())),
52 }
53 }
54
55 pub fn add_groups(&self, groups: &[&CommandGroup]) {
57 let mut names = Vec::with_capacity(4);
58 let mut queue = Vec::new();
59 let mut command_name_cache = self.command_name_cache.lock();
60 let mut command_lookup = self.command_lookup.lock();
61
62 for group in groups.iter() {
63 command_name_cache.reserve(group.options.commands.len());
64 queue.reserve(group.options.commands.len());
65
66 queue.extend(group.options.commands.iter().map(|command| (0, command)));
67
68 while let Some((depth, command)) = queue.pop() {
69 let has_enabled_check = command
70 .options
71 .checks
72 .iter()
73 .any(|&check| checks_are_same(check, &ENABLED_CHECK));
74
75 if !has_enabled_check {
76 continue;
77 }
78
79 names.truncate(depth);
80
81 let command_name = command
82 .options
83 .names
84 .first()
85 .expect("command does not have a name");
86
87 names.push(*command_name);
88 let command_name = names.join("::");
89 command_lookup.insert(
90 command.options.names.as_ptr() as usize,
91 command_name.clone(),
92 );
93 command_name_cache.push(command_name);
94
95 queue.extend(
96 command
97 .options
98 .sub_commands
99 .iter()
100 .map(|command| (depth + 1, command)),
101 );
102 }
103 }
104 }
105
106 pub fn get_command_name_from_options(&self, options: &CommandOptions) -> Option<String> {
107 self.command_lookup
108 .lock()
109 .get(&(options.names.as_ptr() as usize))
110 .cloned()
111 }
112
113 pub fn get_command_names(&self) -> MutexGuard<'_, Vec<String>> {
115 self.command_name_cache.lock()
116 }
117}
118
119fn checks_are_same(check1: &Check, check2: &Check) -> bool {
124 let is_same_partial_eq = check1 == check2;
125
126 let function1_addr = check1.function as usize;
129 let function2_addr = check2.function as usize;
130 let is_same_function_ptr = function1_addr == function2_addr;
131
132 is_same_partial_eq && is_same_function_ptr
133}
134
135#[check]
136#[name("Enabled")]
137pub async fn enabled_check(
138 ctx: &Context,
139 msg: &Message,
140 _args: &mut Args,
141 opts: &CommandOptions,
142) -> Result<(), Reason> {
143 let guild_id = match msg.guild_id {
144 Some(id) => id,
145 None => {
146 return Ok(());
151 }
152 };
153
154 let data_lock = ctx.data.read().await;
155 let client_data = data_lock
156 .get::<ClientDataKey>()
157 .expect("missing client data");
158 let enabled_check_data = client_data.enabled_check_data.clone();
159 let db = client_data.db.clone();
160 drop(data_lock);
161
162 let command_name = match enabled_check_data.get_command_name_from_options(opts) {
163 Some(name) => name,
164 None => {
165 return Ok(());
169 }
170 };
171
172 match db.is_command_disabled(guild_id, &command_name).await {
173 Ok(true) => Err(Reason::User("Command Disabled".to_string())),
174 Ok(false) => Ok(()),
175 Err(e) => {
176 error!("failed to read disabled commands: {}", e);
177 Err(Reason::Unknown)
180 }
181 }
182}
183
184pub fn create_slash_check<'a>(
186 ctx: &'a Context,
187 interaction: &'a CommandInteraction,
188 command: &'a Command,
189) -> BoxFuture<'a, Result<(), SlashReason>> {
190 Box::pin(async move {
191 let guild_id = match interaction.guild_id {
192 Some(id) => id,
193 None => {
194 return Ok(());
199 }
200 };
201
202 let data_lock = ctx.data.read().await;
203 let client_data = data_lock
204 .get::<ClientDataKey>()
205 .expect("missing client data");
206 let db = client_data.db.clone();
207 drop(data_lock);
208
209 let command_name = command.name();
210
211 match db.is_command_disabled(guild_id, command_name).await {
212 Ok(true) => Err(SlashReason::new_user("Command Disabled.".to_string())),
213 Ok(false) => Ok(()),
214 Err(e) => {
215 error!("failed to read disabled commands: {}", e);
216 Err(SlashReason::new_unknown())
219 }
220 }
221 })
222}