pikadick/
database.rs

1mod disabled_commands;
2mod kv_store;
3pub mod model;
4mod reddit_embed;
5mod tic_tac_toe;
6mod tiktok_embed;
7
8pub use self::tic_tac_toe::{
9    TicTacToeCreateGameError,
10    TicTacToeTryMoveError,
11    TicTacToeTryMoveResponse,
12};
13use anyhow::Context;
14use camino::{
15    Utf8Path,
16    Utf8PathBuf,
17};
18use once_cell::sync::Lazy;
19use std::{
20    os::raw::c_int,
21    sync::Arc,
22};
23use tracing::{
24    error,
25    warn,
26};
27
28// Setup
29const SETUP_TABLES_SQL: &str = include_str!("../sql/setup_tables.sql");
30
31static LOGGER_INIT: Lazy<Result<(), Arc<rusqlite::Error>>> = Lazy::new(|| {
32    // Safety:
33    // 1. `sqlite_logger_func` is threadsafe.
34    // 2. This is called only once.
35    // 3. This is called before any sqlite functions are used
36    // 4. sqlite functions cannot be used until the logger initializes.
37    unsafe { rusqlite::trace::config_log(Some(sqlite_logger_func)).map_err(Arc::new) }
38});
39
40fn sqlite_logger_func(error_code: c_int, msg: &str) {
41    warn!("sqlite error code ({}): {}", error_code, msg);
42}
43
44/// The database
45#[derive(Clone, Debug)]
46pub struct Database {
47    db: async_rusqlite::Database,
48}
49
50impl Database {
51    //// Make a new [`Database`].
52    ///
53    /// # Safety
54    /// This must be called before any other sqlite functions are called.
55    pub async unsafe fn new<P>(path: P, create_if_missing: bool) -> anyhow::Result<Self>
56    where
57        P: Into<Utf8PathBuf>,
58    {
59        let path = path.into();
60        tokio::task::spawn_blocking(move || Self::blocking_new(&path, create_if_missing))
61            .await
62            .context("failed to join tokio task")?
63    }
64
65    /// Make a new [`Database`] in a blocking manner.
66    ///
67    /// # Safety
68    /// This must be called before any other sqlite functions are called.
69    pub unsafe fn blocking_new<P>(path: P, create_if_missing: bool) -> anyhow::Result<Self>
70    where
71        P: AsRef<Utf8Path>,
72    {
73        LOGGER_INIT
74            .clone()
75            .context("failed to init sqlite logger")?;
76
77        let db = async_rusqlite::Database::blocking_open(path.as_ref(), create_if_missing, |db| {
78            db.execute_batch(SETUP_TABLES_SQL)
79                .context("failed to setup database")?;
80            Ok(())
81        })
82        .context("failed to open database")?;
83
84        Ok(Database { db })
85    }
86
87    /// Access the db
88    async fn access_db<F, R>(&self, func: F) -> anyhow::Result<R>
89    where
90        F: FnOnce(&mut rusqlite::Connection) -> R + Send + 'static,
91        R: Send + 'static,
92    {
93        Ok(self.db.access_db(move |db| func(db)).await?)
94    }
95
96    /// Close the db
97    pub async fn close(&self) -> anyhow::Result<()> {
98        // Failing to run shutdown commands is not critical and should not prevent shutdown.
99        if let Err(e) = self
100            .db
101            .access_db(|db| {
102                db.execute("PRAGMA OPTIMIZE;", [])?;
103                db.execute("VACUUM;", [])
104            })
105            .await
106            .context("failed to access db")
107            .and_then(|v| v.context("failed to execute shutdown commands"))
108        {
109            error!("{}", e);
110        }
111        self.db
112            .close()
113            .await
114            .context("failed to send close request to db")?;
115        self.db.join().await.context("failed to join db thread")?;
116
117        Ok(())
118    }
119}