pikadick/commands/
tiktok_embed.rs

1use crate::{
2    client_data::{
3        CacheStatsBuilder,
4        CacheStatsProvider,
5    },
6    util::{
7        EncoderTask,
8        TimedCache,
9        TimedCacheEntry,
10    },
11    ClientDataKey,
12    LoadingReaction,
13    TikTokEmbedFlags,
14};
15use anyhow::{
16    ensure,
17    Context as _,
18};
19use camino::{
20    Utf8Path,
21    Utf8PathBuf,
22};
23use nd_util::{
24    ArcAnyhowError,
25    DropRemovePath,
26};
27use pikadick_util::RequestMap;
28use serenity::{
29    builder::{
30        CreateAttachment,
31        CreateEmbed,
32        CreateInteractionResponse,
33        CreateInteractionResponseMessage,
34        CreateMessage,
35    },
36    model::prelude::*,
37    prelude::*,
38};
39use std::sync::Arc;
40use tokio_stream::StreamExt;
41use tracing::{
42    info,
43    warn,
44};
45use url::Url;
46
/// The file size limit in bytes (8 MiB).
///
/// A downloaded video larger than this is re-encoded,
/// and a re-encoded video must end up smaller than this to be accepted.
const FILE_SIZE_LIMIT_BYTES: u64 = 8 * 1024 * 1024;
/// The file size in bytes (7 MiB) that the re-encode bitrate calculation is based on.
const TARGET_FILE_SIZE_BYTES: u64 = 7 * 1024 * 1024;
/// The names of the h264 encoders that may be selected, most preferred first.
///
/// When several of these are available, the one with the lowest index is chosen.
const ENCODER_PREFERENCE_LIST: &[&str] = &[
    "h264_nvenc",
    "h264_amf",
    "h264_qsv",
    "h264_mf",
    "h264_v4l2m2m",
    "h264_vaapi",
    "h264_omx",
    "libx264",
    "libx264rgb",
];

/// The request map used for video downloads.
///
/// Keys are video ids rendered as strings.
/// Values are either the path of the video file on disk or the error that occurred.
type VideoDownloadRequestMap = Arc<RequestMap<String, Result<Arc<Utf8Path>, ArcAnyhowError>>>;
62
/// Calculate the target bitrate.
///
/// `target_size` is in kilobits.
/// `duration` is in seconds.
/// The returned bitrate is in kilobits per second.
///
/// A `duration` of `0` is treated as one second,
/// so that a zero-length (or sub-second, truncated) video cannot cause a division by zero.
fn calc_target_bitrate(target_size: u64, duration: u64) -> u64 {
    // https://stackoverflow.com/questions/29082422/ffmpeg-video-compression-specific-file-size

    target_size / duration.max(1)
}
73
/// TikTok Data
///
/// Holds the TikTok client, the encoder task, the post cache,
/// and the on-disk video cache location used to embed TikTok videos.
#[derive(Debug, Clone)]
pub struct TikTokData {
    /// The inner client
    client: tiktok::Client,

    /// The encoder task
    encoder_task: EncoderTask,

    /// A cache of post urls => post pages
    pub post_page_cache: TimedCache<String, tiktok::Post>,

    /// The path to tiktok's cache dir
    video_download_cache_path: Utf8PathBuf,

    /// The request map for making requests for video downloads.
    video_download_request_map: VideoDownloadRequestMap,

    /// The name of the video encoder used when re-encoding videos.
    ///
    /// This is always an entry of `ENCODER_PREFERENCE_LIST`.
    video_encoder: &'static str,
}
94
95impl TikTokData {
96    /// Make a new [`TikTokData`].
97    pub async fn new<P>(cache_dir: P, encoder_task: EncoderTask) -> anyhow::Result<Self>
98    where
99        P: AsRef<Utf8Path>,
100    {
101        let cache_dir = cache_dir.as_ref();
102        let video_download_cache_path = cache_dir.join("tiktok");
103
104        // TODO: Expand into proper filecache manager
105        tokio::fs::create_dir_all(&video_download_cache_path)
106            .await
107            .context("failed to create tiktok cache dir")?;
108
109        let mut encoders = encoder_task
110            .get_encoders(true)
111            .await
112            .context("failed to get encoders")?;
113
114        // Keep only h264 encoders
115        encoders.retain(|encoder| encoder.description.ends_with("(codec h264)"));
116        info!("found h264 encoders: {encoders:#?}");
117
118        let mut best_encoder_index = None;
119        for encoder in encoders {
120            if let Some(index) = ENCODER_PREFERENCE_LIST
121                .iter()
122                .position(|name| **name == *encoder.name)
123            {
124                if best_encoder_index.is_none_or(|best_encoder_index| best_encoder_index > index) {
125                    best_encoder_index = Some(index);
126                }
127            }
128        }
129
130        let best_encoder_index = best_encoder_index.context("failed to select an encoder")?;
131        let best_encoder = ENCODER_PREFERENCE_LIST[best_encoder_index];
132
133        info!("selected encoder \"{best_encoder}\"");
134
135        Ok(Self {
136            client: tiktok::Client::new(),
137
138            encoder_task,
139
140            post_page_cache: TimedCache::new(),
141
142            video_download_cache_path,
143            video_download_request_map: Arc::new(RequestMap::new()),
144            video_encoder: best_encoder,
145        })
146    }
147
148    /// Get a post page, using the cache if needed
149    pub async fn get_post_cached(
150        &self,
151        url: &str,
152    ) -> anyhow::Result<Arc<TimedCacheEntry<tiktok::Post>>> {
153        if let Some(post_page) = self.post_page_cache.get_if_fresh(url) {
154            return Ok(post_page);
155        }
156
157        let video_id = Url::parse(url)?
158            .path_segments()
159            .context("missing path")?
160            .next_back()
161            .context("missing video id")?
162            .parse()
163            .context("invalid video id")?;
164
165        let mut feed = self
166            .client
167            .get_feed(Some(video_id))
168            .await
169            .context("failed to get feed")?;
170        ensure!(!feed.aweme_list.is_empty(), "missing post");
171
172        let post = feed.aweme_list.swap_remove(0);
173        ensure!(post.aweme_id == video_id);
174
175        Ok(self.post_page_cache.insert_and_get(url.to_string(), post))
176    }
177
178    /// Get video data, using the cache if needed
179    pub async fn get_video_data_cached(
180        &self,
181        id: u64,
182        format: &str,
183        url: &str,
184        video_duration: u64,
185    ) -> anyhow::Result<Arc<Utf8Path>> {
186        self.video_download_request_map
187            .get_or_fetch(id.to_string(), || {
188                let client = self.client.client.clone();
189
190                let encoder_task = self.encoder_task.clone();
191
192                let reencoded_file_name = format!("{id}-reencoded.mp4");
193                let reencoded_file_path = self.video_download_cache_path.join(reencoded_file_name);
194
195                let file_name = format!("{id}.{format}");
196                let file_path = self.video_download_cache_path.join(file_name);
197
198                let id = id.to_string();
199                let format = format.to_string();
200                let url = url.to_string();
201
202                let video_encoder = self.video_encoder;
203
204                async move {
205                    match tokio::fs::metadata(&reencoded_file_path).await {
206                        Ok(_metadata) => {
207                            // The reencoded file is present. Use it.
208                            return Ok(Arc::from(reencoded_file_path));
209                        }
210                        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
211                            // The transcoded file is not present.
212                            // Attempt to use the original file by passing through.
213                        }
214                        Err(e) => {
215                            return Err(e)
216                                .context("failed to get metadata of re-encoded file")
217                                .map_err(ArcAnyhowError::new);
218                        }
219                    };
220
221                    // Get the metadata of the raw file.
222                    // Download it if needed.
223                    let metadata = match tokio::fs::metadata(&file_path).await {
224                        Ok(metadata) => {
225                            // The reencoded file is present.
226                            // Return the metadata to validate its size.
227                            metadata
228                        }
229                        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
230                            // File not present. Download it.
231
232                            info!(
233                                "downloading tiktok video \
234                                with with id `{id}` \
235                                from url `{url}` \
236                                with format `{format}`"
237                            );
238
239                            let result = async {
240                                nd_util::download_to_path(&client, &url, &file_path).await?;
241                                tokio::fs::metadata(&file_path)
242                                    .await
243                                    .context("failed to get file metadata")
244                            }
245                            .await;
246
247                            result.map_err(ArcAnyhowError::new)?
248                        }
249                        Err(e) => {
250                            return Err(e)
251                                .context("failed to get metadata of file")
252                                .map_err(ArcAnyhowError::new);
253                        }
254                    };
255
256                    // If the file is greater than 8mb, we need to reencode it
257                    if metadata.len() > FILE_SIZE_LIMIT_BYTES {
258                        let result = async {
259                            // We target half of the maximum size to give ourselves some lee-way.
260                            // This merely sets the target bit-rate, and we don't take into account audio size.
261                            let target_bitrate = calc_target_bitrate(
262                                (TARGET_FILE_SIZE_BYTES / 1024) * 8 / 2,
263                                video_duration,
264                            );
265                            let reencoded_file_path_tmp_1 = DropRemovePath::new(
266                                nd_util::with_push_extension(&reencoded_file_path, "1.tmp"),
267                            );
268
269                            info!(
270                                "re-encoding tiktok video `{}` to `{}` \
271                                @ video bitrate {}",
272                                file_path,
273                                reencoded_file_path_tmp_1.display(),
274                                target_bitrate
275                            );
276
277                            {
278                                let mut stream = encoder_task
279                                    .encode()
280                                    .input(&file_path)
281                                    .output(&*reencoded_file_path_tmp_1)
282                                    .audio_codec("copy")
283                                    .video_codec(video_encoder)
284                                    .video_bitrate(format!("{target_bitrate}K"))
285                                    .output_format("mp4")
286                                    .try_send()
287                                    .await
288                                    .context("failed to start re-encoding")?;
289
290                                let mut maybe_exit_status = None;
291                                while let Some(msg) = stream.next().await {
292                                    match msg.context("ffmpeg stream error") {
293                                        Ok(tokio_ffmpeg_cli::Event::ExitStatus(exit_status)) => {
294                                            maybe_exit_status = Some(exit_status);
295                                        }
296                                        Ok(tokio_ffmpeg_cli::Event::Progress(_progress)) => {
297                                            // For now, we don't care about progress as there is no way to report it to the user on discord.
298                                        }
299                                        Ok(tokio_ffmpeg_cli::Event::Unknown(_line)) => {
300                                            // warn!("unknown ffmpeg line: `{}`", line);
301                                            // We don't care about unkown lines
302                                        }
303                                        Err(error) => {
304                                            warn!("{error:?}");
305                                        }
306                                    }
307                                }
308
309                                let exit_status = maybe_exit_status
310                                    .context("stream did not report an exit status")?;
311
312                                // Validate exit status
313                                ensure!(exit_status.success(), "invalid exit status");
314                            }
315
316                            // The RPI's ffmpeg produces invalid mp4 files.
317                            // Until we can investigate and fix, transcode the file to try to let ffmpeg fix it.
318                            let reencoded_file_path_tmp_2 = DropRemovePath::new(
319                                nd_util::with_push_extension(&reencoded_file_path, "2.tmp"),
320                            );
321
322                            {
323                                let mut stream = encoder_task
324                                    .encode()
325                                    .input(&*reencoded_file_path_tmp_1)
326                                    .output(&*reencoded_file_path_tmp_2)
327                                    .audio_codec("copy")
328                                    .video_codec("copy")
329                                    .output_format("mp4")
330                                    .try_send()
331                                    .await
332                                    .context("failed to start transcoding")?;
333
334                                let mut maybe_exit_status = None;
335                                while let Some(msg) = stream.next().await {
336                                    match msg.context("ffmpeg stream error") {
337                                        Ok(tokio_ffmpeg_cli::Event::ExitStatus(exit_status)) => {
338                                            maybe_exit_status = Some(exit_status);
339                                        }
340                                        Ok(tokio_ffmpeg_cli::Event::Progress(_progress)) => {
341                                            // For now, we don't care about progress as there is no way to report it to the user on discord.
342                                        }
343                                        Ok(tokio_ffmpeg_cli::Event::Unknown(_line)) => {
344                                            // warn!("unknown ffmpeg line: `{}`", line);
345                                            // We don't care about unkown lines
346                                        }
347                                        Err(error) => {
348                                            warn!("{error:?}");
349                                        }
350                                    }
351                                }
352
353                                let exit_status = maybe_exit_status
354                                    .context("stream did not report an exit status")?;
355
356                                // Validate exit status
357                                ensure!(exit_status.success(), "invalid exit status");
358                            }
359
360                            let mut reencoded_file_path_tmp = reencoded_file_path_tmp_2;
361
362                            // Validate file size
363                            let metadata = tokio::fs::metadata(&reencoded_file_path_tmp)
364                                .await
365                                .context("failed to get metadata of encoded file")?;
366                            let metadata_len = metadata.len();
367                            ensure!(
368                                metadata_len < FILE_SIZE_LIMIT_BYTES,
369                                "re-encoded file size ({metadata_len}) is larger than the limit {FILE_SIZE_LIMIT_BYTES}",
370                            );
371
372                            // Rename the tmp file to be the actual name.
373                            tokio::fs::rename(&*reencoded_file_path_tmp, &reencoded_file_path)
374                                .await
375                                .context("failed to rename temp file")?;
376
377                            // "Persist" the tmp file, as in don't try to remove it
378                            reencoded_file_path_tmp.persist();
379
380                            Ok(())
381                        }
382                        .await;
383
384                        result.map_err(ArcAnyhowError::new)?;
385
386                        Ok(Arc::from(reencoded_file_path))
387                    } else {
388                        Ok(Arc::from(file_path))
389                    }
390                }
391            })
392            .await
393            .map_err(From::from)
394    }
395
396    /// Try embedding a url
397    pub async fn try_embed_url(
398        &self,
399        ctx: &Context,
400        msg: &Message,
401        url: &Url,
402        loading_reaction: &mut Option<LoadingReaction>,
403        delete_link: bool,
404    ) -> anyhow::Result<()> {
405        let (video_url, video_id, video_format, video_duration) = {
406            let post = self.get_post_cached(url.as_str()).await?;
407            let post = post.data();
408
409            let video_url = post
410                .video
411                .download_addr
412                .url_list
413                .first()
414                .context("missing video url")?
415                .clone();
416            let video_id: u64 = post.aweme_id;
417            // let video_format = post.video.format.clone();
418            // TODO: Can this ever NOT be an mp4?
419            let video_format = String::from("mp4");
420            let video_duration = post.video.duration;
421
422            (video_url, video_id, video_format, video_duration)
423        };
424
425        let video_path = self
426            .get_video_data_cached(
427                video_id,
428                video_format.as_str(),
429                video_url.as_str(),
430                video_duration,
431            )
432            .await
433            .context("failed to download tiktok video")?;
434
435        let file = CreateAttachment::path(video_path.as_std_path()).await?;
436        let message_builder = CreateMessage::new().add_file(file);
437        msg.channel_id
438            .send_message(&ctx.http, message_builder)
439            .await?;
440
441        if let Some(mut loading_reaction) = loading_reaction.take() {
442            loading_reaction.send_ok();
443
444            if delete_link {
445                msg.delete(&ctx.http)
446                    .await
447                    .context("failed to delete original message")?;
448            }
449        }
450
451        Ok(())
452    }
453}
454
455impl CacheStatsProvider for TikTokData {
456    fn publish_cache_stats(&self, cache_stats_builder: &mut CacheStatsBuilder) {
457        cache_stats_builder.publish_stat(
458            "tiktok_data",
459            "post_page_cache",
460            self.post_page_cache.len() as f32,
461        );
462    }
463}
464
/// Options for tiktok-embed
///
/// Each option is optional.
/// An option that is `None` leaves the matching server flag untouched when the command is processed.
#[derive(Debug, pikadick_slash_framework::FromOptions)]
struct TikTokEmbedOptions {
    /// Whether embeds should be enabled for this server
    #[pikadick_slash_framework(description = "Whether embeds should be enabled for this server")]
    enable: Option<bool>,

    /// Whether source messages should be deleted
    #[pikadick_slash_framework(
        rename = "delete-link",
        description = "Whether source messages should be deleted"
    )]
    delete_link: Option<bool>,
}
479
480/// Create a slash command
481pub fn create_slash_command() -> anyhow::Result<pikadick_slash_framework::Command> {
482    use pikadick_slash_framework::FromOptions;
483
484    pikadick_slash_framework::CommandBuilder::new()
485        .name("tiktok-embed")
486        .description("Configure tiktok embeds for this server")
487        .check(crate::checks::admin::create_slash_check)
488        .arguments(TikTokEmbedOptions::get_argument_params()?.into_iter())
489        .on_process(|ctx, interaction, args: TikTokEmbedOptions| async move {
490            let data_lock = ctx.data.read().await;
491            let client_data = data_lock.get::<ClientDataKey>().unwrap();
492            let db = client_data.db.clone();
493            drop(data_lock);
494
495            let guild_id = match interaction.guild_id {
496                Some(id) => id,
497                None => {
498                    let message_builder = CreateInteractionResponseMessage::new()
499                        .content("Missing server id. Are you in a server right now?");
500                    let response = CreateInteractionResponse::Message(message_builder);
501                    interaction.create_response(&ctx.http, response).await?;
502                    return Ok(());
503                }
504            };
505
506            let mut set_flags = TikTokEmbedFlags::empty();
507            let mut unset_flags = TikTokEmbedFlags::empty();
508
509            if let Some(enable) = args.enable {
510                if enable {
511                    set_flags.insert(TikTokEmbedFlags::ENABLED);
512                } else {
513                    unset_flags.insert(TikTokEmbedFlags::ENABLED);
514                }
515            }
516
517            if let Some(enable) = args.delete_link {
518                if enable {
519                    set_flags.insert(TikTokEmbedFlags::DELETE_LINK);
520                } else {
521                    unset_flags.insert(TikTokEmbedFlags::DELETE_LINK);
522                }
523            }
524
525            let (_old_flags, new_flags) = db
526                .set_tiktok_embed_flags(guild_id, set_flags, unset_flags)
527                .await?;
528
529            let embed_builder = CreateEmbed::new()
530                .title("TikTok Embeds")
531                .field(
532                    "Enabled?",
533                    bool_to_str(new_flags.contains(TikTokEmbedFlags::ENABLED)),
534                    false,
535                )
536                .field(
537                    "Delete link?",
538                    bool_to_str(new_flags.contains(TikTokEmbedFlags::DELETE_LINK)),
539                    false,
540                );
541            let message_builder = CreateInteractionResponseMessage::new().embed(embed_builder);
542            let response = CreateInteractionResponse::Message(message_builder);
543            interaction.create_response(&ctx.http, response).await?;
544
545            Ok(())
546        })
547        .build()
548        .context("failed to build command")
549}
550
/// Convert a bool to a str
///
/// Returns `"True"` for `true` and `"False"` for `false`.
fn bool_to_str(value: bool) -> &'static str {
    match value {
        true => "True",
        false => "False",
    }
}