pikadick/
main.rs

1#![deny(
2    unused_import_braces,
3    unused_lifetimes,
4    unreachable_pub,
5    trivial_numeric_casts,
6    missing_debug_implementations,
7    missing_copy_implementations,
8    deprecated_in_future,
9    meta_variable_misuse,
10    non_ascii_idents,
11    rust_2018_compatibility,
12    rust_2018_idioms,
13    future_incompatible,
14    nonstandard_style,
15    clippy::all
16)]
17#![warn(variant_size_differences, let_underscore_drop)]
18// TODO: Document everything properly
19// clippy::default_trait_access
20// clippy::use_self
21// clippy::undocumented_unsafe_blocks
22// clippy::allow_attributes_without_reason
23// clippy::as_underscore
24// clippy::cast_possible_truncation
25// clippy::cast_possible_wrap
26// clippy::cast_sign_loss
27// clippy::fn_to_numeric_cast_any
28// clippy::redundant_closure_for_method_calls
29// clippy::too_many_lines
30
31// TODO: Switch to poise
32#![allow(deprecated)]
33
34//! # Pikadick
35
36pub mod checks;
37pub mod cli_options;
38pub mod client_data;
39pub mod commands;
40pub mod config;
41pub mod database;
42pub mod logger;
43pub mod setup;
44pub mod util;
45
46use crate::{
47    cli_options::CliOptions,
48    client_data::ClientData,
49    commands::*,
50    config::{
51        ActivityKind,
52        Config,
53    },
54    database::{
55        model::TikTokEmbedFlags,
56        Database,
57    },
58    util::LoadingReaction,
59};
60use anyhow::{
61    bail,
62    ensure,
63    Context as _,
64};
65use pikadick_util::AsyncLockFile;
66use serenity::{
67    framework::standard::{
68        buckets::BucketBuilder,
69        help_commands,
70        macros::{
71            group,
72            help,
73        },
74        Args,
75        CommandGroup,
76        CommandResult,
77        Configuration as StandardFrameworkConfiguration,
78        DispatchError,
79        HelpOptions,
80        Reason,
81        StandardFramework,
82    },
83    futures::future::BoxFuture,
84    gateway::{
85        ActivityData,
86        ShardManager,
87    },
88    model::{
89        application::Interaction,
90        prelude::*,
91    },
92    prelude::*,
93    FutureExt,
94};
95use songbird::SerenityInit;
96use std::{
97    collections::HashSet,
98    sync::Arc,
99    time::{
100        Duration,
101        Instant,
102    },
103};
104use tokio::runtime::Builder as RuntimeBuilder;
105use tracing::{
106    error,
107    info,
108    warn,
109};
110use tracing_appender::non_blocking::WorkerGuard;
111use url::Url;
112
/// Upper bound on how long shutting down the tokio runtime may take
/// (passed to `Runtime::shutdown_timeout` in `real_main`).
const TOKIO_RT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10 * 1000);
114
115struct Handler;
116
117#[serenity::async_trait]
118impl EventHandler for Handler {
119    async fn ready(&self, ctx: Context, ready: Ready) {
120        let data_lock = ctx.data.read().await;
121        let client_data = data_lock
122            .get::<ClientDataKey>()
123            .expect("missing client data");
124        let slash_framework = data_lock
125            .get::<SlashFrameworkKey>()
126            .expect("missing slash framework")
127            .clone();
128        let config = client_data.config.clone();
129        drop(data_lock);
130
131        if let (Some(status), Some(kind)) = (config.status_name(), config.status_type()) {
132            match kind {
133                ActivityKind::Listening => {
134                    ctx.set_activity(Some(ActivityData::listening(status)));
135                }
136                ActivityKind::Streaming => {
137                    let result: Result<_, anyhow::Error> = async {
138                        let activity = ActivityData::streaming(
139                            status,
140                            config.status_url().context("failed to get status url")?,
141                        )?;
142
143                        ctx.set_activity(Some(activity));
144
145                        Ok(())
146                    }
147                    .await;
148
149                    if let Err(error) = result.context("failed to set activity") {
150                        error!("{error:?}");
151                    }
152                }
153                ActivityKind::Playing => {
154                    ctx.set_activity(Some(ActivityData::playing(status)));
155                }
156            }
157        }
158
159        info!("logged in as \"{}\"", ready.user.name);
160
161        // TODO: Consider shutting down the bot. It might be possible to use old data though.
162        if let Err(error) = slash_framework
163            .register(ctx.clone(), config.test_guild)
164            .await
165            .context("failed to register slash commands")
166        {
167            error!("{error:?}");
168        }
169
170        info!("registered slash commands");
171    }
172
173    async fn resume(&self, _ctx: Context, _resumed: ResumedEvent) {
174        warn!("resumed connection");
175    }
176
177    #[tracing::instrument(skip(self, ctx, msg), fields(author = %msg.author.id, guild = ?msg.guild_id, content = %msg.content))]
178    async fn message(&self, ctx: Context, msg: Message) {
179        let data_lock = ctx.data.read().await;
180        let client_data = data_lock
181            .get::<ClientDataKey>()
182            .expect("missing client data");
183        let reddit_embed_data = client_data.reddit_embed_data.clone();
184        let tiktok_data = client_data.tiktok_data.clone();
185        let db = client_data.db.clone();
186        drop(data_lock);
187
188        // Process URL Embeds
189        {
190            // Only embed guild links
191            let guild_id = match msg.guild_id {
192                Some(id) => id,
193                None => {
194                    return;
195                }
196            };
197
198            // No Bots
199            if msg.author.bot {
200                return;
201            }
202
203            // Get enabled data for embeds
204            let reddit_embed_is_enabled_for_guild = db
205                .get_reddit_embed_enabled(guild_id)
206                .await
207                .with_context(|| format!("failed to get reddit-embed server data for {guild_id}"))
208                .unwrap_or_else(|error| {
209                    error!("{error:?}");
210                    false
211                });
212            let tiktok_embed_flags = db
213                .get_tiktok_embed_flags(guild_id)
214                .await
215                .with_context(|| format!("failed to get tiktok-embed server data for {guild_id}"))
216                .unwrap_or_else(|error| {
217                    error!("{error:?}");
218                    TikTokEmbedFlags::empty()
219                });
220
221            // Extract urls.
222            // We collect into a `Vec` as the regex iterator is not Sync and cannot be held across await points.
223            let urls: Vec<Url> = util::extract_urls(&msg.content).collect();
224
225            // Check to see if it we will even try to embed
226            let will_try_embedding = urls.iter().any(|url| {
227                let url_host = match url.host() {
228                    Some(host) => host,
229                    None => return false,
230                };
231
232                let reddit_url =
233                    matches!(url_host, url::Host::Domain("www.reddit.com" | "reddit.com"));
234
235                let tiktok_url = matches!(
236                    url_host,
237                    url::Host::Domain("vm.tiktok.com" | "tiktok.com" | "www.tiktok.com")
238                );
239
240                (reddit_url && reddit_embed_is_enabled_for_guild)
241                    || (tiktok_url && tiktok_embed_flags.contains(TikTokEmbedFlags::ENABLED))
242            });
243
244            // Return if we won't try embedding
245            if !will_try_embedding {
246                return;
247            }
248
249            let mut loading_reaction = Some(LoadingReaction::new(ctx.http.clone(), &msg));
250
251            // Embed for each url
252            // NOTE: we short circuit on failure since sending a msg to a channel and failing is most likely a permissions problem,
253            // especially since serenity retries each req once
254            for url in urls.iter() {
255                match url.host() {
256                    Some(url::Host::Domain("www.reddit.com" | "reddit.com")) => {
257                        // Don't process if it isn't enabled
258                        if reddit_embed_is_enabled_for_guild {
259                            if let Err(error) = reddit_embed_data
260                                .try_embed_url(&ctx, &msg, url, &mut loading_reaction)
261                                .await
262                                .context("failed to generate reddit embed")
263                            {
264                                error!("{error:?}");
265                            }
266                        }
267                    }
268                    Some(url::Host::Domain("vm.tiktok.com" | "tiktok.com" | "www.tiktok.com")) => {
269                        if tiktok_embed_flags.contains(TikTokEmbedFlags::ENABLED) {
270                            if let Err(error) = tiktok_data
271                                .try_embed_url(
272                                    &ctx,
273                                    &msg,
274                                    url,
275                                    &mut loading_reaction,
276                                    tiktok_embed_flags.contains(TikTokEmbedFlags::DELETE_LINK),
277                                )
278                                .await
279                                .context("failed to generate tiktok embed")
280                            {
281                                error!("{error:?}");
282                            }
283                        }
284                    }
285                    _ => {}
286                }
287            }
288
289            // Trim caches
290            reddit_embed_data.cache.trim();
291            reddit_embed_data.video_data_cache.trim();
292            tiktok_data.post_page_cache.trim();
293        }
294    }
295
296    async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
297        let data_lock = ctx.data.read().await;
298        let framework = data_lock
299            .get::<SlashFrameworkKey>()
300            .expect("missing slash framework")
301            .clone();
302        drop(data_lock);
303
304        framework.process_interaction_create(ctx, interaction).await;
305    }
306}
307
308#[derive(Debug, Clone, Copy)]
309pub struct ClientDataKey;
310
311impl TypeMapKey for ClientDataKey {
312    type Value = ClientData;
313}
314
315#[derive(Debug, Clone, Copy)]
316pub struct SlashFrameworkKey;
317
318impl TypeMapKey for SlashFrameworkKey {
319    type Value = pikadick_slash_framework::Framework;
320}
321
322#[help]
323async fn help(
324    ctx: &Context,
325    msg: &Message,
326    args: Args,
327    help_options: &'static HelpOptions,
328    groups: &[&'static CommandGroup],
329    owners: HashSet<UserId>,
330) -> CommandResult {
331    match help_commands::with_embeds(ctx, msg, args, help_options, groups, owners)
332        .await
333        .context("failed to send help")
334    {
335        Ok(_) => {}
336        Err(error) => {
337            error!("{error:?}");
338        }
339    }
340    Ok(())
341}
342
// The prefix-command group. `#[group]` generates the `GENERAL_GROUP` static from this struct;
// that static is registered with the standard framework (`.group(&GENERAL_GROUP)`) and added
// to the enabled-check data (`add_groups(&[&GENERAL_GROUP])`).
// Plain `//` comments are used on purpose: `///` doc comments on this struct may be picked up
// by the `#[group]` macro as the group description.
// NOTE(review): the order of `#[commands(...)]` may affect help output order — confirm before reordering.
#[group]
#[commands(
    system,
    quizizz,
    fml,
    zalgo,
    shift,
    reddit_embed,
    invite,
    vaporwave,
    cmd,
    latency,
    uwuify,
    cache_stats,
    insta_dl,
    deviantart,
    urban,
    xkcd,
    tic_tac_toe,
    iqdb,
    reddit,
    leave,
    stop,
    sauce_nao
)]
struct General;
369
370async fn handle_ctrl_c(shard_manager: Arc<ShardManager>) {
371    match tokio::signal::ctrl_c()
372        .await
373        .context("failed to set ctrl-c handler")
374    {
375        Ok(_) => {
376            info!("shutting down...");
377            info!("stopping client...");
378            shard_manager.shutdown_all().await;
379        }
380        Err(error) => {
381            warn!("{error}");
382            // The default "kill everything" handler is probably still installed, so this isn't a problem?
383        }
384    };
385}
386
387#[tracing::instrument(skip(_ctx, msg), fields(author = %msg.author.id, guild = ?msg.guild_id, content = %msg.content))]
388fn before_handler<'fut>(
389    _ctx: &'fut Context,
390    msg: &'fut Message,
391    cmd_name: &'fut str,
392) -> BoxFuture<'fut, bool> {
393    info!("allowing command to process");
394    async move { true }.boxed()
395}
396
397fn after_handler<'fut>(
398    _ctx: &'fut Context,
399    _msg: &'fut Message,
400    command_name: &'fut str,
401    command_result: CommandResult,
402) -> BoxFuture<'fut, ()> {
403    async move {
404        if let Err(error) = command_result {
405            error!("failed to process command \"{command_name}\": {error}");
406        }
407    }
408    .boxed()
409}
410
411fn unrecognised_command_handler<'fut>(
412    ctx: &'fut Context,
413    msg: &'fut Message,
414    command_name: &'fut str,
415) -> BoxFuture<'fut, ()> {
416    async move {
417        error!("unrecognized command \"{command_name}\"");
418
419        let _ = msg
420            .channel_id
421            .say(
422                &ctx.http,
423                format!("Could not find command \"{command_name}\""),
424            )
425            .await
426            .is_ok();
427    }
428    .boxed()
429}
430
431fn process_dispatch_error<'fut>(
432    ctx: &'fut Context,
433    msg: &'fut Message,
434    error: DispatchError,
435    cmd_name: &'fut str,
436) -> BoxFuture<'fut, ()> {
437    process_dispatch_error_future(ctx, msg, error, cmd_name).boxed()
438}
439
440async fn process_dispatch_error_future<'fut>(
441    ctx: &'fut Context,
442    msg: &'fut Message,
443    error: DispatchError,
444    _cmd_name: &'fut str,
445) {
446    match error {
447        DispatchError::Ratelimited(duration) => {
448            let seconds = duration.as_secs();
449            let _ = msg
450                .channel_id
451                .say(
452                    &ctx.http,
453                    format!("Wait {seconds} seconds to use that command again"),
454                )
455                .await
456                .is_ok();
457        }
458        DispatchError::NotEnoughArguments { min, given } => {
459            let _ = msg
460                .channel_id
461                .say(
462                    &ctx.http,
463                    format!(
464                        "Expected at least {min} argument(s) for this command, but only got {given}",
465                    ),
466                )
467                .await
468                .is_ok();
469        }
470        DispatchError::TooManyArguments { max, given } => {
471            let response_str = format!("Expected no more than {max} argument(s) for this command, but got {given}. Try using quotation marks if your argument has spaces.");
472            let _ = msg.channel_id.say(&ctx.http, response_str).await.is_ok();
473        }
474        DispatchError::CheckFailed(check_name, reason) => match reason {
475            Reason::User(user_reason_str) => {
476                let _ = msg.channel_id.say(&ctx.http, user_reason_str).await.is_ok();
477            }
478            _ => {
479                let _ = msg
480                    .channel_id
481                    .say(
482                        &ctx.http,
483                        format!("\"{check_name}\" check failed: {reason:#?}"),
484                    )
485                    .await
486                    .is_ok();
487            }
488        },
489        error => {
490            let _ = msg
491                .channel_id
492                .say(&ctx.http, format!("Unhandled Dispatch Error: {error:?}"))
493                .await
494                .is_ok();
495        }
496    };
497}
498
499/// Set up a serenity client
500async fn setup_client(config: Arc<Config>) -> anyhow::Result<Client> {
501    // Setup slash framework
502    let slash_framework = pikadick_slash_framework::FrameworkBuilder::new()
503        .check(self::checks::enabled::create_slash_check)
504        .help_command(create_slash_help_command()?)
505        .command(nekos::create_slash_command()?)
506        .command(ping::create_slash_command()?)
507        .command(r6stats::create_slash_command()?)
508        .command(r6tracker::create_slash_command()?)
509        .command(rule34::create_slash_command()?)
510        .command(tiktok_embed::create_slash_command()?)
511        .command(chat::create_slash_command()?)
512        .command(yodaspeak::create_slash_command()?)
513        .build()?;
514
515    // Create second prefix that is uppercase so we are case-insensitive
516    let config_prefix = config.prefix.clone();
517    let uppercase_prefix = config_prefix.to_uppercase();
518
519    // Build the standard framework
520    info!("using prefix \"{config_prefix}\"");
521    let framework_config = StandardFrameworkConfiguration::new()
522        .prefixes([config_prefix, uppercase_prefix])
523        .case_insensitivity(true);
524    let framework = StandardFramework::new();
525    framework.configure(framework_config);
526    let framework = framework
527        .help(&HELP)
528        .group(&GENERAL_GROUP)
529        .bucket("r6stats", BucketBuilder::new_channel().delay(7))
530        .await
531        .bucket("r6tracker", BucketBuilder::new_channel().delay(7))
532        .await
533        .bucket("system", BucketBuilder::new_channel().delay(30))
534        .await
535        .bucket("quizizz", BucketBuilder::new_channel().delay(10))
536        .await
537        .bucket("insta-dl", BucketBuilder::new_channel().delay(10))
538        .await
539        .bucket("ttt-board", BucketBuilder::new_channel().delay(1))
540        .await
541        .bucket("default", BucketBuilder::new_channel().delay(1))
542        .await
543        .before(before_handler)
544        .after(after_handler)
545        .unrecognised_command(unrecognised_command_handler)
546        .on_dispatch_error(process_dispatch_error);
547
548    // Build the client
549    let config_token = config.token.clone();
550    let client = Client::builder(
551        config_token,
552        GatewayIntents::non_privileged() | GatewayIntents::MESSAGE_CONTENT,
553    )
554    .event_handler(Handler)
555    .application_id(ApplicationId::new(config.application_id))
556    .framework(framework)
557    .register_songbird()
558    .await
559    .context("failed to create client")?;
560
561    {
562        client
563            .data
564            .write()
565            .await
566            .insert::<SlashFrameworkKey>(slash_framework);
567    }
568
569    // TODO: Spawn a task for this earlier?
570    // Spawn the ctrl-c handler
571    tokio::spawn(handle_ctrl_c(client.shard_manager.clone()));
572
573    Ok(client)
574}
575
/// Data from the setup function, handed from `setup` to `real_main`.
struct SetupData {
    /// The multi-threaded tokio runtime the bot runs on.
    tokio_rt: tokio::runtime::Runtime,
    /// The loaded config.
    config: Arc<Config>,
    /// The opened database.
    database: Database,
    /// The data-dir lockfile, locked in `setup` and unlocked by `real_main` at shutdown.
    lock_file: AsyncLockFile,
    /// Guard returned by the logger setup; logging is not reliable once it is dropped.
    worker_guard: WorkerGuard,
}
584
585/// Pre-main setup
586fn setup(cli_options: CliOptions) -> anyhow::Result<SetupData> {
587    eprintln!("starting tokio runtime...");
588    let tokio_rt = RuntimeBuilder::new_multi_thread()
589        .enable_all()
590        .thread_name("pikadick-tokio-worker")
591        .build()
592        .context("failed to start tokio runtime")?;
593
594    let config = setup::load_config(&cli_options.config)
595        .map(Arc::new)
596        .context("failed to load config")?;
597
598    eprintln!("opening data directory...");
599    let data_dir_metadata = match std::fs::metadata(&config.data_dir) {
600        Ok(metadata) => Some(metadata),
601        Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
602        Err(e) => {
603            return Err(e).context("failed to get metadata for the data dir");
604        }
605    };
606
607    let _missing_data_dir = data_dir_metadata.is_none();
608    match data_dir_metadata.as_ref() {
609        Some(metadata) => {
610            if metadata.is_dir() {
611                eprintln!("data directory already exists.");
612            } else if metadata.is_file() {
613                bail!("failed to create or open data directory, the path is a file");
614            }
615        }
616        None => {
617            eprintln!("data directory does not exist. creating...");
618            std::fs::create_dir_all(&config.data_dir).context("failed to create data directory")?;
619        }
620    }
621
622    eprintln!("creating lockfile...");
623    let lock_file_path = config.data_dir.join("pikadick.lock");
624    let lock_file = AsyncLockFile::blocking_open(lock_file_path.as_std_path())
625        .context("failed to open lockfile")?;
626    let lock_file_locked = lock_file
627        .try_lock_with_pid_blocking()
628        .context("failed to try to lock the lockfile")?;
629    ensure!(lock_file_locked, "another process has locked the lockfile");
630
631    std::fs::create_dir_all(config.log_file_dir()).context("failed to create log file dir")?;
632    std::fs::create_dir_all(config.cache_dir()).context("failed to create cache dir")?;
633
634    // TODO: Init db
635    eprintln!("opening database...");
636    let database_path = config.data_dir.join("pikadick.sqlite");
637
638    // Safety: This is called before any other sqlite functions.
639    // TODO: Is there a good reason to not remake the db if it is missing?
640    let database = unsafe {
641        Database::blocking_new(database_path, true) // missing_data_dir
642            .context("failed to open database")?
643    };
644
645    // Everything past here is assumed to need tokio
646    let _enter_guard = tokio_rt.handle().enter();
647
648    eprintln!("setting up logger...");
649    let worker_guard = logger::setup(&config).context("failed to initialize logger")?;
650
651    eprintln!();
652    Ok(SetupData {
653        tokio_rt,
654        config,
655        database,
656        lock_file,
657        worker_guard,
658    })
659}
660
661/// The main entry.
662///
663/// Sets up the program and calls `real_main`.
664/// This allows more things to drop correctly.
665/// This also calls setup operations like loading config and setting up the tokio runtime,
666/// logging errors to the stderr instead of the loggers, which are not initialized yet.
667fn main() -> anyhow::Result<()> {
668    // This line MUST run first.
669    // It is needed to exit early if the options are invalid,
670    // and this will NOT run destructors if it does so.
671    let cli_options = argh::from_env();
672
673    let setup_data = setup(cli_options)?;
674    real_main(setup_data)?;
675    Ok(())
676}
677
678/// The actual entry point
679fn real_main(setup_data: SetupData) -> anyhow::Result<()> {
680    // We spawn this is a seperate thread/task as the main thread does not have enough stack space
681    let _enter_guard = setup_data.tokio_rt.enter();
682    let ret = setup_data.tokio_rt.block_on(tokio::spawn(async_main(
683        setup_data.config,
684        setup_data.database,
685    )));
686
687    let shutdown_start = Instant::now();
688    info!(
689        "shutting down tokio runtime (shutdown timeout is {:?})...",
690        TOKIO_RT_SHUTDOWN_TIMEOUT
691    );
692    setup_data
693        .tokio_rt
694        .shutdown_timeout(TOKIO_RT_SHUTDOWN_TIMEOUT);
695    info!("shutdown tokio runtime in {:?}", shutdown_start.elapsed());
696
697    info!("unlocking lockfile...");
698    setup_data
699        .lock_file
700        .blocking_unlock()
701        .context("failed to unlock lockfile")?;
702
703    info!("successful shutdown");
704
705    // Logging no longer reliable past this point
706    drop(setup_data.worker_guard);
707
708    ret?
709}
710
711/// The async entry
712async fn async_main(config: Arc<Config>, database: Database) -> anyhow::Result<()> {
713    // TODO: See if it is possible to start serenity without a network
714    info!("setting up client...");
715    let mut client = setup_client(config.clone())
716        .await
717        .context("failed to set up client")?;
718
719    let client_data = ClientData::init(client.shard_manager.clone(), config, database.clone())
720        .await
721        .context("client data initialization failed")?;
722
723    // Add all post-init client data changes here
724    {
725        client_data.enabled_check_data.add_groups(&[&GENERAL_GROUP]);
726    }
727
728    {
729        let mut data = client.data.write().await;
730        data.insert::<ClientDataKey>(client_data);
731    }
732
733    info!("logging in...");
734    client.start().await.context("failed to run client")?;
735    let client_data = {
736        let mut data = client.data.write().await;
737        data.remove::<ClientDataKey>().expect("missing client data")
738    };
739    drop(client);
740
741    info!("running shutdown routine for client data");
742    client_data.shutdown().await;
743    drop(client_data);
744
745    info!("closing database...");
746    database.close().await.context("failed to close database")?;
747
748    Ok(())
749}