1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
use crate::ClientDataKey;
use parking_lot::Mutex;
use pikadick_slash_framework::{
BoxFuture,
Command,
Reason as SlashReason,
};
use serenity::{
client::Context,
framework::standard::{
macros::check,
Args,
Check,
CommandGroup,
CommandOptions,
Reason,
},
model::{
application::CommandInteraction,
prelude::*,
},
};
use std::{
collections::HashMap,
sync::Arc,
};
use tracing::error;
/// Guard produced by locking a [`parking_lot::Mutex`].
///
/// Used as the return type of [`EnabledCheckData::get_command_names`].
type MutexGuard<'a, T> = parking_lot::lock_api::MutexGuard<'a, parking_lot::RawMutex, T>;
/// Registry of the commands that can be enabled/disabled per guild through the `Enabled` check.
///
/// Both fields are behind [`Arc`], so clones of this value share the same underlying data.
#[derive(Debug, Default, Clone)]
pub struct EnabledCheckData {
    /// Fully-qualified names (path segments joined with `::`) of every registered command.
    ///
    /// This is a `Vec`, not a set: registering the same command twice adds a duplicate entry.
    command_name_cache: Arc<Mutex<Vec<String>>>,
    /// A way to look up commands by CommandOptions and fn addr.
    ///
    /// XXX MASSIVE HACK XXX
    /// This uses the addresses of the `names` field of [`CommandOptions`] impls in order to compare commands.
    /// This is necessary as this is all serenity gives to [`Check`] functions.
    /// The only reason this works is because the serenity macro for making commands is used as the only way to make commands,
    /// as it recreates each names array for each command uniquely.
    command_lookup: Arc<Mutex<HashMap<usize, String>>>,
}
impl EnabledCheckData {
    /// Make a new, empty [`EnabledCheckData`].
    pub fn new() -> Self {
        EnabledCheckData {
            command_name_cache: Arc::new(Mutex::new(Vec::new())),
            command_lookup: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Register the commands of `groups` so they can be enabled/disabled.
    ///
    /// Every command and, recursively, every sub-command that carries the `Enabled` check is
    /// recorded under its fully-qualified name. That name is the first name of each command on
    /// the path from the group root, joined with `::` (e.g. `parent::child`).
    ///
    /// Commands without the check are not registered, but their sub-commands are still visited.
    /// A sub-command that carries the check is therefore registered even when its parent does not.
    ///
    /// # Panics
    ///
    /// Panics if a visited command has an empty `names` list.
    pub fn add_groups(&self, groups: &[&CommandGroup]) {
        // First names of the commands on the path from the group root to the current command.
        let mut names = Vec::with_capacity(4);
        // Explicit depth-first stack of (depth, command) pairs.
        let mut stack = Vec::new();

        let mut command_name_cache = self.command_name_cache.lock();
        let mut command_lookup = self.command_lookup.lock();

        for group in groups.iter() {
            command_name_cache.reserve(group.options.commands.len());
            stack.reserve(group.options.commands.len());
            stack.extend(group.options.commands.iter().map(|command| (0, command)));

            while let Some((depth, command)) = stack.pop() {
                // The stack is LIFO, so anything in `names` past `depth` belongs to a previously
                // visited branch. Drop it so `names` holds exactly this command's ancestors.
                names.truncate(depth);
                let command_name = command
                    .options
                    .names
                    .first()
                    .expect("command does not have a name");
                names.push(*command_name);

                // Always descend: a sub-command may carry the check even if this command does not.
                stack.extend(
                    command
                        .options
                        .sub_commands
                        .iter()
                        .map(|command| (depth + 1, command)),
                );

                let has_enabled_check = command
                    .options
                    .checks
                    .iter()
                    .any(|&check| checks_are_same(check, &ENABLED_CHECK));
                if !has_enabled_check {
                    continue;
                }

                let full_name = names.join("::");
                // Key on the address of the `names` slice; see the note on `command_lookup`.
                command_lookup.insert(command.options.names.as_ptr() as usize, full_name.clone());
                command_name_cache.push(full_name);
            }
        }
    }

    /// Look up the fully-qualified name registered for the command that owns `options`.
    ///
    /// Returns `None` if that command was never registered by [`EnabledCheckData::add_groups`],
    /// meaning it does not carry the `Enabled` check.
    pub fn get_command_name_from_options(&self, options: &CommandOptions) -> Option<String> {
        self.command_lookup
            .lock()
            .get(&(options.names.as_ptr() as usize))
            .cloned()
    }

    /// Returns a mutex guard to the list of registered command names.
    ///
    /// The lock is held until the guard is dropped, so keep the guard short-lived.
    pub fn get_command_names(&self) -> MutexGuard<'_, Vec<String>> {
        self.command_name_cache.lock()
    }
}
/// Returns `true` when two [`Check`]s are the same check.
///
/// `serenity`'s `PartialEq` for [`Check`] only compares names, so this also compares the
/// addresses of the check functions. The two references need not point at the same
/// [`Check`] value.
fn checks_are_same(check1: &Check, check2: &Check) -> bool {
    // HACK: checks carry no unique identifier, so the function pointer's address stands in for one.
    let same_function = check1.function as usize == check2.function as usize;
    let same_name = check1 == check2;

    same_name && same_function
}
#[check]
#[name("Enabled")]
pub async fn enabled_check(
ctx: &Context,
msg: &Message,
_args: &mut Args,
opts: &CommandOptions,
) -> Result<(), Reason> {
let guild_id = match msg.guild_id {
Some(id) => id,
None => {
// Let's not care about dms for now.
// They'll probably need special handling anyways.
// This will also probably only be useful in Group DMs,
// which I don't think bots can participate in anyways.
return Ok(());
}
};
let data_lock = ctx.data.read().await;
let client_data = data_lock
.get::<ClientDataKey>()
.expect("missing client data");
let enabled_check_data = client_data.enabled_check_data.clone();
let db = client_data.db.clone();
drop(data_lock);
let command_name = match enabled_check_data.get_command_name_from_options(opts) {
Some(name) => name,
None => {
// The name is not present.
// This is fine, as that just means we haven't added it to the translation map
// aka it is not disable-able
return Ok(());
}
};
match db.is_command_disabled(guild_id, &command_name).await {
Ok(true) => Err(Reason::User("Command Disabled".to_string())),
Ok(false) => Ok(()),
Err(e) => {
error!("failed to read disabled commands: {}", e);
// DB failure, return false to be safe.
// Avoid being specific with error to prevent users from spamming knowingly.
Err(Reason::Unknown)
}
}
}
/// Check if a command is enabled via slash framework
pub fn create_slash_check<'a>(
ctx: &'a Context,
interaction: &'a CommandInteraction,
command: &'a Command,
) -> BoxFuture<'a, Result<(), SlashReason>> {
Box::pin(async move {
let guild_id = match interaction.guild_id {
Some(id) => id,
None => {
// Let's not care about dms for now.
// They'll probably need special handling anyways.
// This will also probably only be useful in Group DMs,
// which I don't think bots can participate in anyways.
return Ok(());
}
};
let data_lock = ctx.data.read().await;
let client_data = data_lock
.get::<ClientDataKey>()
.expect("missing client data");
let db = client_data.db.clone();
drop(data_lock);
let command_name = command.name();
match db.is_command_disabled(guild_id, command_name).await {
Ok(true) => Err(SlashReason::new_user("Command Disabled.".to_string())),
Ok(false) => Ok(()),
Err(e) => {
error!("failed to read disabled commands: {}", e);
// DB failure, return false to be safe.
// Avoid being specific with error to prevent users from spamming knowingly.
Err(SlashReason::new_unknown())
}
}
})
}