1use crate::{
2 checks::{
3 ADMIN_CHECK,
4 ENABLED_CHECK,
5 },
6 client_data::{
7 CacheStatsBuilder,
8 CacheStatsProvider,
9 },
10 util::{
11 LoadingReaction,
12 TimedCache,
13 TimedCacheEntry,
14 },
15 ClientDataKey,
16};
17use anyhow::{
18 bail,
19 Context as _,
20};
21use dashmap::DashMap;
22use rand::seq::SliceRandom;
23use reddit_tube::types::get_video_response::GetVideoResponseOk;
24use serenity::{
25 framework::standard::{
26 macros::command,
27 Args,
28 CommandResult,
29 },
30 model::prelude::*,
31 prelude::*,
32};
33use std::{
34 sync::Arc,
35 time::{
36 Duration,
37 Instant,
38 },
39};
40use tracing::{
41 error,
42 info,
43 warn,
44};
45use url::Url;
46
47type SubReddit = String;
48type PostId = String;
49
50type LinkVec = Vec<Arc<reddit::Link>>;
51
52#[derive(Clone)]
55pub struct RedditEmbedData {
56 reddit_client: reddit::Client,
57 reddit_tube_client: reddit_tube::Client,
58
59 pub cache: TimedCache<(SubReddit, PostId), String>,
60 pub video_data_cache: TimedCache<String, Box<GetVideoResponseOk>>,
61 random_post_cache: Arc<DashMap<String, Arc<(Instant, LinkVec)>>>,
62}
63
64impl RedditEmbedData {
65 pub fn new() -> Self {
67 RedditEmbedData {
68 reddit_client: reddit::Client::new(),
69 reddit_tube_client: reddit_tube::Client::new(),
70
71 cache: Default::default(),
72 video_data_cache: TimedCache::new(),
73 random_post_cache: Arc::new(DashMap::new()),
74 }
75 }
76
77 pub async fn get_original_post(
81 &self,
82 subreddit: &str,
83 post_id: &str,
84 ) -> anyhow::Result<Box<reddit::Link>> {
85 let mut post_data = self.reddit_client.get_post(subreddit, post_id).await?;
86
87 if post_data.is_empty() {
88 bail!("missing post");
89 }
90
91 let mut post_data = post_data
92 .swap_remove(0)
93 .data
94 .into_listing()
95 .context("missing post")?
96 .children;
97
98 if post_data.is_empty() {
99 bail!("missing post");
100 }
101
102 let mut post = post_data
103 .swap_remove(0)
104 .data
105 .into_link()
106 .context("missing post")?;
107
108 let crosspost_parent_list = std::mem::take(&mut post.crosspost_parent_list);
112 if let Some(post) = crosspost_parent_list.and_then(|mut l| {
113 if l.is_empty() {
114 None
115 } else {
116 Some(l.swap_remove(0))
117 }
118 }) {
119 Ok(Box::new(post))
122 } else {
123 Ok(post)
124 }
125 }
126
127 pub async fn get_video_data(&self, url: &str) -> anyhow::Result<Box<GetVideoResponseOk>> {
131 let main_page = self
132 .reddit_tube_client
133 .get_main_page()
134 .await
135 .context("failed to get main page")?;
136 self.reddit_tube_client
137 .get_video(&main_page, url)
138 .await
139 .context("failed to get video data")?
140 .into_result()
141 .context("bad video response")
142 }
143
144 pub async fn get_video_data_cached(
146 &self,
147 url: &str,
148 ) -> anyhow::Result<Arc<TimedCacheEntry<Box<GetVideoResponseOk>>>> {
149 if let Some(response) = self.video_data_cache.get_if_fresh(url) {
150 return Ok(response);
151 }
152
153 let video_data = self.get_video_data(url).await?;
154
155 Ok(self
156 .video_data_cache
157 .insert_and_get(url.to_string(), video_data))
158 }
159
160 pub async fn create_video_url(&self, url: &str) -> anyhow::Result<Url> {
162 let maybe_url = self
163 .get_video_data_cached(url)
164 .await
165 .with_context(|| format!("failed to get reddit video info for '{}'", url))
166 .map(|video_data| video_data.data().url.clone());
167
168 if let Err(e) = maybe_url.as_ref() {
169 warn!("{:?}", e);
170 }
171
172 maybe_url
173 }
174
175 pub async fn get_embed_url(&self, url: &Url) -> anyhow::Result<String> {
177 let (subreddit, post_id) = parse_post_url(url).context("failed to parse post")?;
178
179 let original_post = self
180 .get_original_post(subreddit, post_id)
181 .await
182 .context("failed to get reddit post")
183 .map_err(|e| {
184 warn!("{:?}", e);
185 e
186 })?;
187
188 if !original_post.is_video {
189 return Ok(original_post.url.into());
190 }
191
192 self.create_video_url(url.as_str())
193 .await
194 .map(|url| url.into())
195 }
196
197 pub async fn try_embed_url(
199 &self,
200 ctx: &Context,
201 msg: &Message,
202 url: &Url,
203 loading_reaction: &mut Option<LoadingReaction>,
204 ) -> anyhow::Result<()> {
205 if let Some((subreddit, post_id)) = parse_post_url(url) {
208 let maybe_url = self
210 .cache
211 .get_if_fresh(&(subreddit.into(), post_id.into()))
212 .map(|el| el.data().clone());
213
214 let data = if let Some(value) = maybe_url.clone() {
215 Some(value)
216 } else {
217 self.get_embed_url(url).await.ok()
218 };
219
220 if let Some(data) = data {
221 self.cache
222 .insert((subreddit.into(), post_id.into()), data.clone());
223
224 msg.channel_id.say(&ctx.http, data).await?;
226 if let Some(mut loading_reaction) = loading_reaction.take() {
227 loading_reaction.send_ok();
228 }
229 }
230 } else {
231 error!("failed to parse reddit post url");
232 }
234 Ok(())
235 }
236
237 pub async fn get_random_post(&self, subreddit: &str) -> anyhow::Result<Option<String>> {
239 {
240 let urls = self.random_post_cache.get(subreddit);
241
242 if let Some(link) = urls.and_then(|v| {
243 let entry = v.value().clone();
244 if entry.0.elapsed() > Duration::from_secs(10 * 60) {
245 return None;
246 }
247 entry.1.choose(&mut rand::thread_rng()).cloned()
248 }) {
249 let url = self.reddit_link_to_embed_url(&link).await?;
250 return Ok(Some(url));
251 }
252 }
253
254 info!("fetching reddit posts for '{}'", subreddit);
255 let mut maybe_url = None;
256 let list = self.reddit_client.get_subreddit(subreddit, 100).await?;
257 if let Some(listing) = list.data.into_listing() {
258 let posts: Vec<Arc<reddit::Link>> = listing
259 .children
260 .into_iter()
261 .filter_map(|child| child.data.into_link())
262 .filter_map(|post| {
263 if let Some(mut post) = post.crosspost_parent_list {
264 if post.is_empty() {
265 None
266 } else {
267 Some(post.swap_remove(0).into())
268 }
269 } else {
270 Some(post)
271 }
272 })
273 .map(|link| Arc::new(*link))
274 .collect();
275
276 let maybe_link = posts.choose(&mut rand::thread_rng()).cloned();
277 if let Some(link) = maybe_link {
278 maybe_url = Some(self.reddit_link_to_embed_url(&link).await?);
279 }
280
281 self.random_post_cache
282 .insert(subreddit.to_string(), Arc::new((Instant::now(), posts)));
283 }
284
285 Ok(maybe_url)
286 }
287
288 async fn reddit_link_to_embed_url(&self, link: &reddit::Link) -> anyhow::Result<String> {
290 let post_url = format!("https://www.reddit.com{}", link.permalink);
291
292 if !link.over_18 {
294 return Ok(post_url);
295 }
296
297 match link.post_hint {
298 Some(reddit::PostHint::HostedVideo) => {
299 let url = self.create_video_url(&post_url).await?;
300 Ok(url.into())
301 }
302 _ => Ok(link.url.clone().into()),
303 }
304 }
305}
306
307impl CacheStatsProvider for RedditEmbedData {
308 fn publish_cache_stats(&self, cache_stats_builder: &mut CacheStatsBuilder) {
309 cache_stats_builder.publish_stat("reddit_embed", "link_cache", self.cache.len() as f32);
310 cache_stats_builder.publish_stat(
311 "reddit_embed",
312 "video_data_cache",
313 self.video_data_cache.len() as f32,
314 );
315 cache_stats_builder.publish_stat(
316 "reddit_embed",
317 "random_post_cache",
318 self.random_post_cache
319 .iter()
320 .map(|v| v.value().1.len())
321 .sum::<usize>() as f32,
322 );
323 }
324}
325
326impl std::fmt::Debug for RedditEmbedData {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 f.debug_struct("RedditEmbedData")
330 .field("reddit_tube_client", &self.reddit_tube_client)
331 .field("cache", &self.cache)
332 .finish()
333 }
334}
335
336impl Default for RedditEmbedData {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342#[command("reddit-embed")]
346#[description("Enable automaitc reddit embedding for this server")]
347#[usage("<enable/disable>")]
348#[example("enable")]
349#[min_args(1)]
350#[max_args(1)]
351#[checks(Admin, Enabled)]
352async fn reddit_embed(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
353 let data_lock = ctx.data.read().await;
354 let client_data = data_lock.get::<ClientDataKey>().unwrap();
355 let db = client_data.db.clone();
356 drop(data_lock);
357
358 let enable = match args.trimmed().current().expect("missing arg") {
359 "enable" => true,
360 "disable" => false,
361 arg => {
362 msg.channel_id
363 .say(
364 &ctx.http,
365 format!(
366 "The argument '{}' is not recognized. Valid: enable, disable",
367 arg
368 ),
369 )
370 .await?;
371 return Ok(());
372 }
373 };
374
375 let guild_id = match msg.guild_id {
377 Some(id) => id,
378 None => {
379 msg.channel_id
380 .say(
381 &ctx.http,
382 "Missing server id. Are you in a server right now?",
383 )
384 .await?;
385 return Ok(());
386 }
387 };
388
389 let old_val = db.set_reddit_embed_enabled(guild_id, enable).await?;
390
391 let status_str = if enable { "enabled" } else { "disabled" };
392
393 if enable == old_val {
394 msg.channel_id
395 .say(
396 &ctx.http,
397 format!("Reddit embeds are already {} for this server", status_str),
398 )
399 .await?;
400 } else {
401 msg.channel_id
402 .say(
403 &ctx.http,
404 format!("Reddit embeds are now {} for this guild", status_str),
405 )
406 .await?;
407 }
408
409 Ok(())
410}
411
412pub fn parse_post_url(url: &Url) -> Option<(&str, &str)> {
417 let mut iter = url.path_segments()?;
432
433 if iter.next()? != "r" {
434 return None;
435 }
436
437 let subreddit = iter.next()?;
438
439 if iter.next()? != "comments" {
440 return None;
441 }
442
443 let post_id = iter.next()?;
444
445 Some((subreddit, post_id))
448}