pikadick/commands/
reddit_embed.rs

1use crate::{
2    checks::{
3        ADMIN_CHECK,
4        ENABLED_CHECK,
5    },
6    client_data::{
7        CacheStatsBuilder,
8        CacheStatsProvider,
9    },
10    util::{
11        LoadingReaction,
12        TimedCache,
13        TimedCacheEntry,
14    },
15    ClientDataKey,
16};
17use anyhow::{
18    bail,
19    Context as _,
20};
21use dashmap::DashMap;
22use rand::seq::SliceRandom;
23use reddit_tube::types::get_video_response::GetVideoResponseOk;
24use serenity::{
25    framework::standard::{
26        macros::command,
27        Args,
28        CommandResult,
29    },
30    model::prelude::*,
31    prelude::*,
32};
33use std::{
34    sync::Arc,
35    time::{
36        Duration,
37        Instant,
38    },
39};
40use tracing::{
41    error,
42    info,
43    warn,
44};
45use url::Url;
46
/// The name of a subreddit, used as the first half of the embed cache key.
type SubReddit = String;
/// The id of a reddit post, used as the second half of the embed cache key.
type PostId = String;

/// A list of shared reddit links, as stored in the random post cache.
type LinkVec = Vec<Arc<reddit::Link>>;
51
52// pub struct SubredditPostIdentifier {}
53
/// Clients and caches backing the reddit embed feature.
#[derive(Clone)]
pub struct RedditEmbedData {
    /// Client used to fetch single posts and subreddit listings.
    reddit_client: reddit::Client,
    /// Client used to look up video data for reddit video posts.
    reddit_tube_client: reddit_tube::Client,

    /// Embed urls, keyed by `(subreddit, post id)`.
    pub cache: TimedCache<(SubReddit, PostId), String>,
    /// Video lookup responses, keyed by the reddit url that was looked up.
    pub video_data_cache: TimedCache<String, Box<GetVideoResponseOk>>,
    /// Subreddit listings used for random posts, keyed by subreddit name.
    ///
    /// Each entry stores the time the listing was fetched next to the links.
    random_post_cache: Arc<DashMap<String, Arc<(Instant, LinkVec)>>>,
}
63
64impl RedditEmbedData {
65    /// Make a new [`RedditEmbedData`].
66    pub fn new() -> Self {
67        RedditEmbedData {
68            reddit_client: reddit::Client::new(),
69            reddit_tube_client: reddit_tube::Client::new(),
70
71            cache: Default::default(),
72            video_data_cache: TimedCache::new(),
73            random_post_cache: Arc::new(DashMap::new()),
74        }
75    }
76
77    /// Get the original post from a given subreddit and post id.
78    ///
79    /// This resolves crossposts. Currently only resolves 1 layer.
80    pub async fn get_original_post(
81        &self,
82        subreddit: &str,
83        post_id: &str,
84    ) -> anyhow::Result<Box<reddit::Link>> {
85        let mut post_data = self.reddit_client.get_post(subreddit, post_id).await?;
86
87        if post_data.is_empty() {
88            bail!("missing post");
89        }
90
91        let mut post_data = post_data
92            .swap_remove(0)
93            .data
94            .into_listing()
95            .context("missing post")?
96            .children;
97
98        if post_data.is_empty() {
99            bail!("missing post");
100        }
101
102        let mut post = post_data
103            .swap_remove(0)
104            .data
105            .into_link()
106            .context("missing post")?;
107
108        // If cross post, resolve one level. Is it possible to crosspost a crosspost?
109
110        // Remove crosspost list from response...
111        let crosspost_parent_list = std::mem::take(&mut post.crosspost_parent_list);
112        if let Some(post) = crosspost_parent_list.and_then(|mut l| {
113            if l.is_empty() {
114                None
115            } else {
116                Some(l.swap_remove(0))
117            }
118        }) {
119            // TODO: Crossposts are not stored in boxes, but in a vec. We need to unify the return types somehow.
120            // Should we choose to move out of a box, or move into a box? Which will be used more?
121            Ok(Box::new(post))
122        } else {
123            Ok(post)
124        }
125    }
126
127    /// Get video data from reddit.tube.
128    ///
129    /// Takes a reddit url.
130    pub async fn get_video_data(&self, url: &str) -> anyhow::Result<Box<GetVideoResponseOk>> {
131        let main_page = self
132            .reddit_tube_client
133            .get_main_page()
134            .await
135            .context("failed to get main page")?;
136        self.reddit_tube_client
137            .get_video(&main_page, url)
138            .await
139            .context("failed to get video data")?
140            .into_result()
141            .context("bad video response")
142    }
143
144    /// Get video data, but using a cache.
145    pub async fn get_video_data_cached(
146        &self,
147        url: &str,
148    ) -> anyhow::Result<Arc<TimedCacheEntry<Box<GetVideoResponseOk>>>> {
149        if let Some(response) = self.video_data_cache.get_if_fresh(url) {
150            return Ok(response);
151        }
152
153        let video_data = self.get_video_data(url).await?;
154
155        Ok(self
156            .video_data_cache
157            .insert_and_get(url.to_string(), video_data))
158    }
159
160    /// Create a video url for a url to a reddit video post.
161    pub async fn create_video_url(&self, url: &str) -> anyhow::Result<Url> {
162        let maybe_url = self
163            .get_video_data_cached(url)
164            .await
165            .with_context(|| format!("failed to get reddit video info for '{}'", url))
166            .map(|video_data| video_data.data().url.clone());
167
168        if let Err(e) = maybe_url.as_ref() {
169            warn!("{:?}", e);
170        }
171
172        maybe_url
173    }
174
175    /// Get a reddit embed url for a given subreddit and post id
176    pub async fn get_embed_url(&self, url: &Url) -> anyhow::Result<String> {
177        let (subreddit, post_id) = parse_post_url(url).context("failed to parse post")?;
178
179        let original_post = self
180            .get_original_post(subreddit, post_id)
181            .await
182            .context("failed to get reddit post")
183            .map_err(|e| {
184                warn!("{:?}", e);
185                e
186            })?;
187
188        if !original_post.is_video {
189            return Ok(original_post.url.into());
190        }
191
192        self.create_video_url(url.as_str())
193            .await
194            .map(|url| url.into())
195    }
196
197    /// Try to embed a url
198    pub async fn try_embed_url(
199        &self,
200        ctx: &Context,
201        msg: &Message,
202        url: &Url,
203        loading_reaction: &mut Option<LoadingReaction>,
204    ) -> anyhow::Result<()> {
205        // This is sometimes TOO smart and finds data for invalid urls...
206        // TODO: Consider making parsing stricter
207        if let Some((subreddit, post_id)) = parse_post_url(url) {
208            // Try cache
209            let maybe_url = self
210                .cache
211                .get_if_fresh(&(subreddit.into(), post_id.into()))
212                .map(|el| el.data().clone());
213
214            let data = if let Some(value) = maybe_url.clone() {
215                Some(value)
216            } else {
217                self.get_embed_url(url).await.ok()
218            };
219
220            if let Some(data) = data {
221                self.cache
222                    .insert((subreddit.into(), post_id.into()), data.clone());
223
224                // TODO: Consider downloading and reposting?
225                msg.channel_id.say(&ctx.http, data).await?;
226                if let Some(mut loading_reaction) = loading_reaction.take() {
227                    loading_reaction.send_ok();
228                }
229            }
230        } else {
231            error!("failed to parse reddit post url");
232            // TODO: Maybe expand this to an actual error to give better feedback
233        }
234        Ok(())
235    }
236
237    /// Get a random post url for a subreddit
238    pub async fn get_random_post(&self, subreddit: &str) -> anyhow::Result<Option<String>> {
239        {
240            let urls = self.random_post_cache.get(subreddit);
241
242            if let Some(link) = urls.and_then(|v| {
243                let entry = v.value().clone();
244                if entry.0.elapsed() > Duration::from_secs(10 * 60) {
245                    return None;
246                }
247                entry.1.choose(&mut rand::thread_rng()).cloned()
248            }) {
249                let url = self.reddit_link_to_embed_url(&link).await?;
250                return Ok(Some(url));
251            }
252        }
253
254        info!("fetching reddit posts for '{}'", subreddit);
255        let mut maybe_url = None;
256        let list = self.reddit_client.get_subreddit(subreddit, 100).await?;
257        if let Some(listing) = list.data.into_listing() {
258            let posts: Vec<Arc<reddit::Link>> = listing
259                .children
260                .into_iter()
261                .filter_map(|child| child.data.into_link())
262                .filter_map(|post| {
263                    if let Some(mut post) = post.crosspost_parent_list {
264                        if post.is_empty() {
265                            None
266                        } else {
267                            Some(post.swap_remove(0).into())
268                        }
269                    } else {
270                        Some(post)
271                    }
272                })
273                .map(|link| Arc::new(*link))
274                .collect();
275
276            let maybe_link = posts.choose(&mut rand::thread_rng()).cloned();
277            if let Some(link) = maybe_link {
278                maybe_url = Some(self.reddit_link_to_embed_url(&link).await?);
279            }
280
281            self.random_post_cache
282                .insert(subreddit.to_string(), Arc::new((Instant::now(), posts)));
283        }
284
285        Ok(maybe_url)
286    }
287
288    /// Convert a reddit link to an embed url
289    async fn reddit_link_to_embed_url(&self, link: &reddit::Link) -> anyhow::Result<String> {
290        let post_url = format!("https://www.reddit.com{}", link.permalink);
291
292        // Discord should be able to embed non-18 stuff
293        if !link.over_18 {
294            return Ok(post_url);
295        }
296
297        match link.post_hint {
298            Some(reddit::PostHint::HostedVideo) => {
299                let url = self.create_video_url(&post_url).await?;
300                Ok(url.into())
301            }
302            _ => Ok(link.url.clone().into()),
303        }
304    }
305}
306
307impl CacheStatsProvider for RedditEmbedData {
308    fn publish_cache_stats(&self, cache_stats_builder: &mut CacheStatsBuilder) {
309        cache_stats_builder.publish_stat("reddit_embed", "link_cache", self.cache.len() as f32);
310        cache_stats_builder.publish_stat(
311            "reddit_embed",
312            "video_data_cache",
313            self.video_data_cache.len() as f32,
314        );
315        cache_stats_builder.publish_stat(
316            "reddit_embed",
317            "random_post_cache",
318            self.random_post_cache
319                .iter()
320                .map(|v| v.value().1.len())
321                .sum::<usize>() as f32,
322        );
323    }
324}
325
326impl std::fmt::Debug for RedditEmbedData {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        // TODO: Replace with manual impl if/when reddit_client becomes debug
329        f.debug_struct("RedditEmbedData")
330            .field("reddit_tube_client", &self.reddit_tube_client)
331            .field("cache", &self.cache)
332            .finish()
333    }
334}
335
336impl Default for RedditEmbedData {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
342// Broken in help:
343// #[required_permissions("ADMINISTRATOR")]
344
345#[command("reddit-embed")]
346#[description("Enable automaitc reddit embedding for this server")]
347#[usage("<enable/disable>")]
348#[example("enable")]
349#[min_args(1)]
350#[max_args(1)]
351#[checks(Admin, Enabled)]
352async fn reddit_embed(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
353    let data_lock = ctx.data.read().await;
354    let client_data = data_lock.get::<ClientDataKey>().unwrap();
355    let db = client_data.db.clone();
356    drop(data_lock);
357
358    let enable = match args.trimmed().current().expect("missing arg") {
359        "enable" => true,
360        "disable" => false,
361        arg => {
362            msg.channel_id
363                .say(
364                    &ctx.http,
365                    format!(
366                        "The argument '{}' is not recognized. Valid: enable, disable",
367                        arg
368                    ),
369                )
370                .await?;
371            return Ok(());
372        }
373    };
374
375    // TODO: Probably can unwrap if i add a check to the command
376    let guild_id = match msg.guild_id {
377        Some(id) => id,
378        None => {
379            msg.channel_id
380                .say(
381                    &ctx.http,
382                    "Missing server id. Are you in a server right now?",
383                )
384                .await?;
385            return Ok(());
386        }
387    };
388
389    let old_val = db.set_reddit_embed_enabled(guild_id, enable).await?;
390
391    let status_str = if enable { "enabled" } else { "disabled" };
392
393    if enable == old_val {
394        msg.channel_id
395            .say(
396                &ctx.http,
397                format!("Reddit embeds are already {} for this server", status_str),
398            )
399            .await?;
400    } else {
401        msg.channel_id
402            .say(
403                &ctx.http,
404                format!("Reddit embeds are now {} for this guild", status_str),
405            )
406            .await?;
407    }
408
409    Ok(())
410}
411
412/// Gets the subreddit and post id from a reddit url.
413///
414/// # Returns
415/// Returns a tuple or the the subreddit and post id in that order.
416pub fn parse_post_url(url: &Url) -> Option<(&str, &str)> {
417    // Reddit path:
418    // /r/dankmemes/comments/h966lq/davie_is_shookt/
419
420    // Template:
421    // /r/<subreddit>/comments/<post_id>/<post_title (irrelevant)>/
422
423    // Parts:
424    // r
425    // <subreddit>
426    // comments
427    // <post_id>
428    // <post_title>
429    // (Nothing, should be empty or not existent)
430
431    let mut iter = url.path_segments()?;
432
433    if iter.next()? != "r" {
434        return None;
435    }
436
437    let subreddit = iter.next()?;
438
439    if iter.next()? != "comments" {
440        return None;
441    }
442
443    let post_id = iter.next()?;
444
445    // TODO: Should we reject urls with the wrong ending?
446
447    Some((subreddit, post_id))
448}