pikadick_slash_framework/
framework.rs

1use crate::{
2    BoxError,
3    BuilderError,
4    CheckFn,
5    Command,
6    HelpCommand,
7};
8use serenity::{
9    builder::{
10        CreateCommand,
11        CreateInteractionResponse,
12        CreateInteractionResponseMessage,
13    },
14    client::Context,
15    model::{
16        application::{
17            Command as ApplicationCommand,
18            CommandInteraction,
19            Interaction,
20        },
21        prelude::GuildId,
22    },
23};
24use std::{
25    collections::HashMap,
26    sync::Arc,
27};
28use tracing::{
29    info,
30    warn,
31};
32
33/// A wrapper for [`BoxError`] that impls error
34struct WrapBoxError(BoxError);
35
36impl WrapBoxError {
37    /// Make a new [`WrapBoxError`] from an error
38    fn new(e: BoxError) -> Self {
39        Self(e)
40    }
41}
42
43impl std::fmt::Debug for WrapBoxError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        self.0.fmt(f)
46    }
47}
48
49impl std::fmt::Display for WrapBoxError {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        self.0.fmt(f)
52    }
53}
54
55impl std::error::Error for WrapBoxError {}
56
57struct FmtOptionsHelper<'a>(&'a CommandInteraction);
58
59impl std::fmt::Display for FmtOptionsHelper<'_> {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "[")?;
62        let len = self.0.data.options.len();
63        for (i, option) in self.0.data.options.iter().enumerate() {
64            if i + 1 == len {
65                write!(f, "'{}'={:?}", option.name, option.value)?;
66            }
67        }
68        write!(f, "]")?;
69
70        Ok(())
71    }
72}
73
74/// A framework
75#[derive(Clone)]
76pub struct Framework {
77    commands: Arc<HashMap<Box<str>, Command>>,
78    help_command: Option<Arc<HelpCommand>>,
79    checks: Arc<[CheckFn]>,
80}
81
82impl Framework {
83    /// Register the framework.
84    ///
85    /// `test_guild_id` is an optional guild where the commands will be registered as guild commands,
86    /// so they update faster for testing purposes.
87    pub async fn register(
88        &self,
89        ctx: Context,
90        test_guild_id: Option<GuildId>,
91    ) -> Result<(), serenity::Error> {
92        for framework_command in self.commands.values() {
93            let mut command_builder = CreateCommand::new(framework_command.name());
94            command_builder = framework_command.register(command_builder);
95            ApplicationCommand::create_global_command(&ctx.http, command_builder).await?;
96        }
97
98        if let Some(framework_command) = self.help_command.as_deref() {
99            let mut command_builder = CreateCommand::new("help");
100            command_builder = framework_command.register(command_builder);
101            ApplicationCommand::create_global_command(&ctx.http, command_builder).await?;
102        }
103
104        if let Some(guild_id) = test_guild_id {
105            let mut create_commands = Vec::new();
106            for framework_command in self.commands.values() {
107                let mut command_builder = CreateCommand::new(framework_command.name());
108                command_builder = framework_command.register(command_builder);
109                create_commands.push(command_builder);
110            }
111            if let Some(framework_command) = self.help_command.as_deref() {
112                let mut command_builder = CreateCommand::new("help");
113                command_builder = framework_command.register(command_builder);
114                create_commands.push(command_builder);
115            }
116
117            GuildId::set_commands(guild_id, &ctx.http, create_commands).await?;
118        }
119
120        Ok(())
121    }
122
123    /// Process an interaction create event
124    pub async fn process_interaction_create(&self, ctx: Context, interaction: Interaction) {
125        if let Interaction::Command(command) = interaction {
126            self.process_interaction_create_application_command(ctx, command)
127                .await
128        }
129    }
130
131    #[tracing::instrument(skip(self, ctx, command), fields(id = %command.id, author = %command.user.id, guild = ?command.guild_id, channel_id = %command.channel_id))]
132    async fn process_interaction_create_application_command(
133        &self,
134        ctx: Context,
135        command: CommandInteraction,
136    ) {
137        if command.data.name.as_str() == "help" {
138            // Keep comments
139            #[allow(clippy::single_match)]
140            match self.help_command.as_ref() {
141                Some(framework_command) => {
142                    info!(
143                        "processing help command, options={}",
144                        FmtOptionsHelper(&command)
145                    );
146                    if let Err(error) = framework_command
147                        .fire_on_process(ctx, command, self.commands.clone())
148                        .await
149                        .map_err(WrapBoxError::new)
150                    {
151                        // TODO: handle error with handler
152                        warn!("{error}");
153                    }
154                }
155                None => {
156                    // Don't log, as we assume the user does not want to provide help.
157                    // Logging would be extra noise.
158                }
159            }
160
161            return;
162        }
163
164        let framework_command = match self.commands.get(command.data.name.as_str()) {
165            Some(command) => command,
166            None => {
167                // TODO: Unknown command handler
168                let command_name = command.data.name.as_str();
169                warn!("unknown command \"{command_name}\"");
170                return;
171            }
172        };
173
174        // TODO: Consider making parallel
175        let mut check_result = Ok(());
176        for check in self.checks.iter().chain(framework_command.checks().iter()) {
177            check_result = check_result.and(check(&ctx, &command, framework_command).await);
178        }
179
180        match check_result {
181            Ok(()) => {
182                let command_name = framework_command.name();
183                info!(
184                    "processing command \"{command_name}\", options={}",
185                    FmtOptionsHelper(&command)
186                );
187                if let Err(error) = framework_command
188                    .fire_on_process(ctx, command)
189                    .await
190                    .map_err(WrapBoxError::new)
191                {
192                    // TODO: handle error with handler
193                    warn!("{error}");
194                }
195            }
196            Err(error) => {
197                let content = error
198                    .user
199                    .as_deref()
200                    .unwrap_or("check failed for unknown reason");
201
202                if let Some(log) = error.log {
203                    warn!("{log}");
204                }
205
206                let response = CreateInteractionResponseMessage::new().content(content);
207                if let Err(error) = command
208                    .create_response(&ctx.http, CreateInteractionResponse::Message(response))
209                    .await
210                {
211                    warn!("{error}");
212                }
213            }
214        }
215    }
216}
217
218impl std::fmt::Debug for Framework {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        f.debug_struct("Framework")
221            .field("commands", &self.commands)
222            .finish()
223    }
224}
225
226/// A FrameworkBuilder for slash commands.
227pub struct FrameworkBuilder {
228    commands: HashMap<Box<str>, Command>,
229    help_command: Option<HelpCommand>,
230    checks: Vec<CheckFn>,
231
232    error: Option<BuilderError>,
233}
234
235impl FrameworkBuilder {
236    /// Make a new [`FrameworkBuilder`].
237    pub fn new() -> Self {
238        Self {
239            commands: HashMap::new(),
240            help_command: None,
241            checks: Vec::new(),
242
243            error: None,
244        }
245    }
246
247    /// Add a command
248    pub fn command(&mut self, command: Command) -> &mut Self {
249        if self.error.is_some() {
250            return self;
251        }
252
253        let command_name: Box<str> = command.name().into();
254
255        // A help command cannot be registered like this
256        if &*command_name == "help" {
257            self.error = Some(BuilderError::Duplicate(command_name));
258            return self;
259        }
260
261        // Don't overwrite commands
262        if self.commands.contains_key(&command_name) {
263            self.error = Some(BuilderError::Duplicate(command_name));
264            return self;
265        }
266
267        self.commands.insert(command_name, command);
268
269        self
270    }
271
272    /// Add a help command
273    pub fn help_command(&mut self, command: HelpCommand) -> &mut Self {
274        if self.error.is_some() {
275            return self;
276        }
277
278        // Don't overwrite commands
279        if self.help_command.is_some() {
280            self.error = Some(BuilderError::Duplicate("help".into()));
281            return self;
282        }
283
284        self.help_command = Some(command);
285
286        self
287    }
288
289    /// Add a check
290    pub fn check(&mut self, check: CheckFn) -> &mut Self {
291        if self.error.is_some() {
292            return self;
293        }
294
295        self.checks.push(check);
296        self
297    }
298
299    /// Build a framework
300    pub fn build(&mut self) -> Result<Framework, BuilderError> {
301        if let Some(error) = self.error.take() {
302            return Err(error);
303        }
304
305        Ok(Framework {
306            commands: Arc::new(std::mem::take(&mut self.commands)),
307            help_command: self.help_command.take().map(Arc::new),
308
309            checks: std::mem::take(&mut self.checks).into(),
310        })
311    }
312}
313
314impl std::fmt::Debug for FrameworkBuilder {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        f.debug_struct("FrameworkBuilder")
317            .field("commands", &self.commands)
318            .finish()
319    }
320}
321
322impl Default for FrameworkBuilder {
323    fn default() -> Self {
324        Self::new()
325    }
326}