1#![deny(
2 unused_import_braces,
3 unused_lifetimes,
4 unreachable_pub,
5 trivial_numeric_casts,
6 missing_debug_implementations,
7 missing_copy_implementations,
8 deprecated_in_future,
9 meta_variable_misuse,
10 non_ascii_idents,
11 rust_2018_compatibility,
12 rust_2018_idioms,
13 future_incompatible,
14 nonstandard_style,
15 clippy::all
16)]
17#![warn(variant_size_differences, let_underscore_drop)]
18#![allow(deprecated)]
33
34pub mod checks;
37pub mod cli_options;
38pub mod client_data;
39pub mod commands;
40pub mod config;
41pub mod database;
42pub mod logger;
43pub mod setup;
44pub mod util;
45
46use crate::{
47 cli_options::CliOptions,
48 client_data::ClientData,
49 commands::*,
50 config::{
51 ActivityKind,
52 Config,
53 },
54 database::{
55 model::TikTokEmbedFlags,
56 Database,
57 },
58 util::LoadingReaction,
59};
60use anyhow::{
61 bail,
62 ensure,
63 Context as _,
64};
65use pikadick_util::AsyncLockFile;
66use serenity::{
67 framework::standard::{
68 buckets::BucketBuilder,
69 help_commands,
70 macros::{
71 group,
72 help,
73 },
74 Args,
75 CommandGroup,
76 CommandResult,
77 Configuration as StandardFrameworkConfiguration,
78 DispatchError,
79 HelpOptions,
80 Reason,
81 StandardFramework,
82 },
83 futures::future::BoxFuture,
84 gateway::{
85 ActivityData,
86 ShardManager,
87 },
88 model::{
89 application::Interaction,
90 prelude::*,
91 },
92 prelude::*,
93 FutureExt,
94};
95use songbird::SerenityInit;
96use std::{
97 collections::HashSet,
98 sync::Arc,
99 time::{
100 Duration,
101 Instant,
102 },
103};
104use tokio::runtime::Builder as RuntimeBuilder;
105use tracing::{
106 error,
107 info,
108 warn,
109};
110use tracing_appender::non_blocking::WorkerGuard;
111use url::Url;
112
/// Upper bound on how long `real_main` waits for the tokio runtime to shut down
/// (passed to `Runtime::shutdown_timeout`).
const TOKIO_RT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
114
/// Serenity event handler for the bot (installed via `event_handler` in `setup_client`).
struct Handler;
116
117#[serenity::async_trait]
118impl EventHandler for Handler {
119 async fn ready(&self, ctx: Context, ready: Ready) {
120 let data_lock = ctx.data.read().await;
121 let client_data = data_lock
122 .get::<ClientDataKey>()
123 .expect("missing client data");
124 let slash_framework = data_lock
125 .get::<SlashFrameworkKey>()
126 .expect("missing slash framework")
127 .clone();
128 let config = client_data.config.clone();
129 drop(data_lock);
130
131 if let (Some(status), Some(kind)) = (config.status_name(), config.status_type()) {
132 match kind {
133 ActivityKind::Listening => {
134 ctx.set_activity(Some(ActivityData::listening(status)));
135 }
136 ActivityKind::Streaming => {
137 let result: Result<_, anyhow::Error> = async {
138 let activity = ActivityData::streaming(
139 status,
140 config.status_url().context("failed to get status url")?,
141 )?;
142
143 ctx.set_activity(Some(activity));
144
145 Ok(())
146 }
147 .await;
148
149 if let Err(error) = result.context("failed to set activity") {
150 error!("{error:?}");
151 }
152 }
153 ActivityKind::Playing => {
154 ctx.set_activity(Some(ActivityData::playing(status)));
155 }
156 }
157 }
158
159 info!("logged in as \"{}\"", ready.user.name);
160
161 if let Err(error) = slash_framework
163 .register(ctx.clone(), config.test_guild)
164 .await
165 .context("failed to register slash commands")
166 {
167 error!("{error:?}");
168 }
169
170 info!("registered slash commands");
171 }
172
173 async fn resume(&self, _ctx: Context, _resumed: ResumedEvent) {
174 warn!("resumed connection");
175 }
176
177 #[tracing::instrument(skip(self, ctx, msg), fields(author = %msg.author.id, guild = ?msg.guild_id, content = %msg.content))]
178 async fn message(&self, ctx: Context, msg: Message) {
179 let data_lock = ctx.data.read().await;
180 let client_data = data_lock
181 .get::<ClientDataKey>()
182 .expect("missing client data");
183 let reddit_embed_data = client_data.reddit_embed_data.clone();
184 let tiktok_data = client_data.tiktok_data.clone();
185 let db = client_data.db.clone();
186 drop(data_lock);
187
188 {
190 let guild_id = match msg.guild_id {
192 Some(id) => id,
193 None => {
194 return;
195 }
196 };
197
198 if msg.author.bot {
200 return;
201 }
202
203 let reddit_embed_is_enabled_for_guild = db
205 .get_reddit_embed_enabled(guild_id)
206 .await
207 .with_context(|| format!("failed to get reddit-embed server data for {guild_id}"))
208 .unwrap_or_else(|error| {
209 error!("{error:?}");
210 false
211 });
212 let tiktok_embed_flags = db
213 .get_tiktok_embed_flags(guild_id)
214 .await
215 .with_context(|| format!("failed to get tiktok-embed server data for {guild_id}"))
216 .unwrap_or_else(|error| {
217 error!("{error:?}");
218 TikTokEmbedFlags::empty()
219 });
220
221 let urls: Vec<Url> = util::extract_urls(&msg.content).collect();
224
225 let will_try_embedding = urls.iter().any(|url| {
227 let url_host = match url.host() {
228 Some(host) => host,
229 None => return false,
230 };
231
232 let reddit_url =
233 matches!(url_host, url::Host::Domain("www.reddit.com" | "reddit.com"));
234
235 let tiktok_url = matches!(
236 url_host,
237 url::Host::Domain("vm.tiktok.com" | "tiktok.com" | "www.tiktok.com")
238 );
239
240 (reddit_url && reddit_embed_is_enabled_for_guild)
241 || (tiktok_url && tiktok_embed_flags.contains(TikTokEmbedFlags::ENABLED))
242 });
243
244 if !will_try_embedding {
246 return;
247 }
248
249 let mut loading_reaction = Some(LoadingReaction::new(ctx.http.clone(), &msg));
250
251 for url in urls.iter() {
255 match url.host() {
256 Some(url::Host::Domain("www.reddit.com" | "reddit.com")) => {
257 if reddit_embed_is_enabled_for_guild {
259 if let Err(error) = reddit_embed_data
260 .try_embed_url(&ctx, &msg, url, &mut loading_reaction)
261 .await
262 .context("failed to generate reddit embed")
263 {
264 error!("{error:?}");
265 }
266 }
267 }
268 Some(url::Host::Domain("vm.tiktok.com" | "tiktok.com" | "www.tiktok.com")) => {
269 if tiktok_embed_flags.contains(TikTokEmbedFlags::ENABLED) {
270 if let Err(error) = tiktok_data
271 .try_embed_url(
272 &ctx,
273 &msg,
274 url,
275 &mut loading_reaction,
276 tiktok_embed_flags.contains(TikTokEmbedFlags::DELETE_LINK),
277 )
278 .await
279 .context("failed to generate tiktok embed")
280 {
281 error!("{error:?}");
282 }
283 }
284 }
285 _ => {}
286 }
287 }
288
289 reddit_embed_data.cache.trim();
291 reddit_embed_data.video_data_cache.trim();
292 tiktok_data.post_page_cache.trim();
293 }
294 }
295
296 async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
297 let data_lock = ctx.data.read().await;
298 let framework = data_lock
299 .get::<SlashFrameworkKey>()
300 .expect("missing slash framework")
301 .clone();
302 drop(data_lock);
303
304 framework.process_interaction_create(ctx, interaction).await;
305 }
306}
307
/// [`TypeMapKey`] under which the bot's [`ClientData`] is stored in the client's shared data.
///
/// It is inserted before the client is started and removed again once the client stops; the
/// event handlers read it while the client runs.
#[derive(Debug, Clone, Copy)]
pub struct ClientDataKey;

impl TypeMapKey for ClientDataKey {
    type Value = ClientData;
}
314
/// [`TypeMapKey`] under which the slash-command framework is stored in the client's shared data.
///
/// Inserted by `setup_client`; read by the `ready` handler (to register slash commands) and the
/// `interaction_create` handler (to dispatch interactions).
#[derive(Debug, Clone, Copy)]
pub struct SlashFrameworkKey;

impl TypeMapKey for SlashFrameworkKey {
    type Value = pikadick_slash_framework::Framework;
}
321
322#[help]
323async fn help(
324 ctx: &Context,
325 msg: &Message,
326 args: Args,
327 help_options: &'static HelpOptions,
328 groups: &[&'static CommandGroup],
329 owners: HashSet<UserId>,
330) -> CommandResult {
331 match help_commands::with_embeds(ctx, msg, args, help_options, groups, owners)
332 .await
333 .context("failed to send help")
334 {
335 Ok(_) => {}
336 Err(error) => {
337 error!("{error:?}");
338 }
339 }
340 Ok(())
341}
342
// The prefix-command group containing every standard-framework command. It is registered with
// the standard framework in `setup_client`, and its commands are also registered with the
// enabled-check data at startup.
#[group]
#[commands(
    system,
    quizizz,
    fml,
    zalgo,
    shift,
    reddit_embed,
    invite,
    vaporwave,
    cmd,
    latency,
    uwuify,
    cache_stats,
    insta_dl,
    deviantart,
    urban,
    xkcd,
    tic_tac_toe,
    iqdb,
    reddit,
    leave,
    stop,
    sauce_nao
)]
struct General;
369
370async fn handle_ctrl_c(shard_manager: Arc<ShardManager>) {
371 match tokio::signal::ctrl_c()
372 .await
373 .context("failed to set ctrl-c handler")
374 {
375 Ok(_) => {
376 info!("shutting down...");
377 info!("stopping client...");
378 shard_manager.shutdown_all().await;
379 }
380 Err(error) => {
381 warn!("{error}");
382 }
384 };
385}
386
387#[tracing::instrument(skip(_ctx, msg), fields(author = %msg.author.id, guild = ?msg.guild_id, content = %msg.content))]
388fn before_handler<'fut>(
389 _ctx: &'fut Context,
390 msg: &'fut Message,
391 cmd_name: &'fut str,
392) -> BoxFuture<'fut, bool> {
393 info!("allowing command to process");
394 async move { true }.boxed()
395}
396
397fn after_handler<'fut>(
398 _ctx: &'fut Context,
399 _msg: &'fut Message,
400 command_name: &'fut str,
401 command_result: CommandResult,
402) -> BoxFuture<'fut, ()> {
403 async move {
404 if let Err(error) = command_result {
405 error!("failed to process command \"{command_name}\": {error}");
406 }
407 }
408 .boxed()
409}
410
411fn unrecognised_command_handler<'fut>(
412 ctx: &'fut Context,
413 msg: &'fut Message,
414 command_name: &'fut str,
415) -> BoxFuture<'fut, ()> {
416 async move {
417 error!("unrecognized command \"{command_name}\"");
418
419 let _ = msg
420 .channel_id
421 .say(
422 &ctx.http,
423 format!("Could not find command \"{command_name}\""),
424 )
425 .await
426 .is_ok();
427 }
428 .boxed()
429}
430
431fn process_dispatch_error<'fut>(
432 ctx: &'fut Context,
433 msg: &'fut Message,
434 error: DispatchError,
435 cmd_name: &'fut str,
436) -> BoxFuture<'fut, ()> {
437 process_dispatch_error_future(ctx, msg, error, cmd_name).boxed()
438}
439
440async fn process_dispatch_error_future<'fut>(
441 ctx: &'fut Context,
442 msg: &'fut Message,
443 error: DispatchError,
444 _cmd_name: &'fut str,
445) {
446 match error {
447 DispatchError::Ratelimited(duration) => {
448 let seconds = duration.as_secs();
449 let _ = msg
450 .channel_id
451 .say(
452 &ctx.http,
453 format!("Wait {seconds} seconds to use that command again"),
454 )
455 .await
456 .is_ok();
457 }
458 DispatchError::NotEnoughArguments { min, given } => {
459 let _ = msg
460 .channel_id
461 .say(
462 &ctx.http,
463 format!(
464 "Expected at least {min} argument(s) for this command, but only got {given}",
465 ),
466 )
467 .await
468 .is_ok();
469 }
470 DispatchError::TooManyArguments { max, given } => {
471 let response_str = format!("Expected no more than {max} argument(s) for this command, but got {given}. Try using quotation marks if your argument has spaces.");
472 let _ = msg.channel_id.say(&ctx.http, response_str).await.is_ok();
473 }
474 DispatchError::CheckFailed(check_name, reason) => match reason {
475 Reason::User(user_reason_str) => {
476 let _ = msg.channel_id.say(&ctx.http, user_reason_str).await.is_ok();
477 }
478 _ => {
479 let _ = msg
480 .channel_id
481 .say(
482 &ctx.http,
483 format!("\"{check_name}\" check failed: {reason:#?}"),
484 )
485 .await
486 .is_ok();
487 }
488 },
489 error => {
490 let _ = msg
491 .channel_id
492 .say(&ctx.http, format!("Unhandled Dispatch Error: {error:?}"))
493 .await
494 .is_ok();
495 }
496 };
497}
498
499async fn setup_client(config: Arc<Config>) -> anyhow::Result<Client> {
501 let slash_framework = pikadick_slash_framework::FrameworkBuilder::new()
503 .check(self::checks::enabled::create_slash_check)
504 .help_command(create_slash_help_command()?)
505 .command(nekos::create_slash_command()?)
506 .command(ping::create_slash_command()?)
507 .command(r6stats::create_slash_command()?)
508 .command(r6tracker::create_slash_command()?)
509 .command(rule34::create_slash_command()?)
510 .command(tiktok_embed::create_slash_command()?)
511 .command(chat::create_slash_command()?)
512 .command(yodaspeak::create_slash_command()?)
513 .build()?;
514
515 let config_prefix = config.prefix.clone();
517 let uppercase_prefix = config_prefix.to_uppercase();
518
519 info!("using prefix \"{config_prefix}\"");
521 let framework_config = StandardFrameworkConfiguration::new()
522 .prefixes([config_prefix, uppercase_prefix])
523 .case_insensitivity(true);
524 let framework = StandardFramework::new();
525 framework.configure(framework_config);
526 let framework = framework
527 .help(&HELP)
528 .group(&GENERAL_GROUP)
529 .bucket("r6stats", BucketBuilder::new_channel().delay(7))
530 .await
531 .bucket("r6tracker", BucketBuilder::new_channel().delay(7))
532 .await
533 .bucket("system", BucketBuilder::new_channel().delay(30))
534 .await
535 .bucket("quizizz", BucketBuilder::new_channel().delay(10))
536 .await
537 .bucket("insta-dl", BucketBuilder::new_channel().delay(10))
538 .await
539 .bucket("ttt-board", BucketBuilder::new_channel().delay(1))
540 .await
541 .bucket("default", BucketBuilder::new_channel().delay(1))
542 .await
543 .before(before_handler)
544 .after(after_handler)
545 .unrecognised_command(unrecognised_command_handler)
546 .on_dispatch_error(process_dispatch_error);
547
548 let config_token = config.token.clone();
550 let client = Client::builder(
551 config_token,
552 GatewayIntents::non_privileged() | GatewayIntents::MESSAGE_CONTENT,
553 )
554 .event_handler(Handler)
555 .application_id(ApplicationId::new(config.application_id))
556 .framework(framework)
557 .register_songbird()
558 .await
559 .context("failed to create client")?;
560
561 {
562 client
563 .data
564 .write()
565 .await
566 .insert::<SlashFrameworkKey>(slash_framework);
567 }
568
569 tokio::spawn(handle_ctrl_c(client.shard_manager.clone()));
572
573 Ok(client)
574}
575
/// Everything produced by `setup` that `real_main` needs in order to run and then tear down
/// the bot.
struct SetupData {
    /// Multi-threaded runtime that `real_main` runs `async_main` on and later shuts down.
    tokio_rt: tokio::runtime::Runtime,
    /// The loaded configuration.
    config: Arc<Config>,
    /// The database opened at `<data_dir>/pikadick.sqlite`.
    database: Database,
    /// The locked `<data_dir>/pikadick.lock`; `real_main` unlocks it after runtime shutdown.
    lock_file: AsyncLockFile,
    /// Guard returned by `logger::setup`; `real_main` drops it after its final log line.
    worker_guard: WorkerGuard,
}
584
585fn setup(cli_options: CliOptions) -> anyhow::Result<SetupData> {
587 eprintln!("starting tokio runtime...");
588 let tokio_rt = RuntimeBuilder::new_multi_thread()
589 .enable_all()
590 .thread_name("pikadick-tokio-worker")
591 .build()
592 .context("failed to start tokio runtime")?;
593
594 let config = setup::load_config(&cli_options.config)
595 .map(Arc::new)
596 .context("failed to load config")?;
597
598 eprintln!("opening data directory...");
599 let data_dir_metadata = match std::fs::metadata(&config.data_dir) {
600 Ok(metadata) => Some(metadata),
601 Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
602 Err(e) => {
603 return Err(e).context("failed to get metadata for the data dir");
604 }
605 };
606
607 let _missing_data_dir = data_dir_metadata.is_none();
608 match data_dir_metadata.as_ref() {
609 Some(metadata) => {
610 if metadata.is_dir() {
611 eprintln!("data directory already exists.");
612 } else if metadata.is_file() {
613 bail!("failed to create or open data directory, the path is a file");
614 }
615 }
616 None => {
617 eprintln!("data directory does not exist. creating...");
618 std::fs::create_dir_all(&config.data_dir).context("failed to create data directory")?;
619 }
620 }
621
622 eprintln!("creating lockfile...");
623 let lock_file_path = config.data_dir.join("pikadick.lock");
624 let lock_file = AsyncLockFile::blocking_open(lock_file_path.as_std_path())
625 .context("failed to open lockfile")?;
626 let lock_file_locked = lock_file
627 .try_lock_with_pid_blocking()
628 .context("failed to try to lock the lockfile")?;
629 ensure!(lock_file_locked, "another process has locked the lockfile");
630
631 std::fs::create_dir_all(config.log_file_dir()).context("failed to create log file dir")?;
632 std::fs::create_dir_all(config.cache_dir()).context("failed to create cache dir")?;
633
634 eprintln!("opening database...");
636 let database_path = config.data_dir.join("pikadick.sqlite");
637
638 let database = unsafe {
641 Database::blocking_new(database_path, true) .context("failed to open database")?
643 };
644
645 let _enter_guard = tokio_rt.handle().enter();
647
648 eprintln!("setting up logger...");
649 let worker_guard = logger::setup(&config).context("failed to initialize logger")?;
650
651 eprintln!();
652 Ok(SetupData {
653 tokio_rt,
654 config,
655 database,
656 lock_file,
657 worker_guard,
658 })
659}
660
661fn main() -> anyhow::Result<()> {
668 let cli_options = argh::from_env();
672
673 let setup_data = setup(cli_options)?;
674 real_main(setup_data)?;
675 Ok(())
676}
677
678fn real_main(setup_data: SetupData) -> anyhow::Result<()> {
680 let _enter_guard = setup_data.tokio_rt.enter();
682 let ret = setup_data.tokio_rt.block_on(tokio::spawn(async_main(
683 setup_data.config,
684 setup_data.database,
685 )));
686
687 let shutdown_start = Instant::now();
688 info!(
689 "shutting down tokio runtime (shutdown timeout is {:?})...",
690 TOKIO_RT_SHUTDOWN_TIMEOUT
691 );
692 setup_data
693 .tokio_rt
694 .shutdown_timeout(TOKIO_RT_SHUTDOWN_TIMEOUT);
695 info!("shutdown tokio runtime in {:?}", shutdown_start.elapsed());
696
697 info!("unlocking lockfile...");
698 setup_data
699 .lock_file
700 .blocking_unlock()
701 .context("failed to unlock lockfile")?;
702
703 info!("successful shutdown");
704
705 drop(setup_data.worker_guard);
707
708 ret?
709}
710
711async fn async_main(config: Arc<Config>, database: Database) -> anyhow::Result<()> {
713 info!("setting up client...");
715 let mut client = setup_client(config.clone())
716 .await
717 .context("failed to set up client")?;
718
719 let client_data = ClientData::init(client.shard_manager.clone(), config, database.clone())
720 .await
721 .context("client data initialization failed")?;
722
723 {
725 client_data.enabled_check_data.add_groups(&[&GENERAL_GROUP]);
726 }
727
728 {
729 let mut data = client.data.write().await;
730 data.insert::<ClientDataKey>(client_data);
731 }
732
733 info!("logging in...");
734 client.start().await.context("failed to run client")?;
735 let client_data = {
736 let mut data = client.data.write().await;
737 data.remove::<ClientDataKey>().expect("missing client data")
738 };
739 drop(client);
740
741 info!("running shutdown routine for client data");
742 client_data.shutdown().await;
743 drop(client_data);
744
745 info!("closing database...");
746 database.close().await.context("failed to close database")?;
747
748 Ok(())
749}