async_rusqlite/
database.rs

1use crate::{
2    BoxedError,
3    DbThreadJoinHandle,
4    Error,
5    SyncWrapper,
6};
7use rusqlite::Connection;
8use std::{
9    panic::AssertUnwindSafe,
10    path::{
11        Path,
12        PathBuf,
13    },
14    sync::Arc,
15};
16
/// Capacity of the bounded channel that carries [`Message`]s from [`Database`] handles
/// to the background database thread.
const MESSAGE_CHANNEL_SIZE: usize = 128;
18
19enum Message {
20    Access {
21        func: Box<dyn FnOnce(&mut Connection) + Send + 'static>,
22    },
23    Close {
24        closed: tokio::sync::oneshot::Sender<()>,
25    },
26}
27
28impl std::fmt::Debug for Message {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        match self {
31            Self::Access { .. } => write!(f, "Access"),
32            Self::Close { .. } => write!(f, "Close"),
33        }
34    }
35}
36
/// A handle to a database connection that is driven by a background thread.
///
/// Cloning yields another handle to the same connection: both the message sender and the
/// join-handle slot are shared between clones.
#[derive(Clone)]
pub struct Database {
    /// Queues [`Message`]s for the background thread.
    sender: tokio::sync::mpsc::Sender<Message>,

    /// Join handle of the background thread; `join` takes it out, leaving `None` behind.
    handle: Arc<std::sync::Mutex<Option<DbThreadJoinHandle>>>,
}
44
45impl std::fmt::Debug for Database {
46    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47        // TODO: Add more data
48        f.debug_struct("Database").finish()
49    }
50}
51
52impl Database {
53    /// Open a database at the given path with the setup func.
54    pub async fn open<P, S>(path: P, create_if_missing: bool, setup_func: S) -> Result<Self, Error>
55    where
56        P: Into<PathBuf>,
57        S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
58    {
59        let path = path.into();
60        tokio::task::spawn_blocking(move || {
61            Self::blocking_open(path, create_if_missing, setup_func)
62        })
63        .await?
64    }
65
66    /// Open a db in a blocking manner.
67    pub fn blocking_open<P, S>(
68        path: P,
69        create_if_missing: bool,
70        mut setup_func: S,
71    ) -> Result<Self, Error>
72    where
73        P: AsRef<Path>,
74        S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
75    {
76        // Setup flags
77        let mut flags = rusqlite::OpenFlags::default();
78        if !create_if_missing {
79            flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE)
80        }
81
82        // Open db
83        let mut db = Connection::open_with_flags(path, flags)?;
84
85        // Init connection
86        setup_func(&mut db).map_err(Error::SetupFunc)?;
87
88        // Setup channel
89        let (sender, mut rx) = tokio::sync::mpsc::channel(MESSAGE_CHANNEL_SIZE);
90
91        // Start background handling thread
92        let handle = std::thread::spawn(move || {
93            while let Some(msg) = rx.blocking_recv() {
94                match msg {
95                    Message::Access { func } => {
96                        func(&mut db);
97                    }
98                    Message::Close { closed } => {
99                        rx.close();
100
101                        // We don't care if a send failed.
102                        let _ = closed.send(()).is_ok();
103                    }
104                }
105            }
106
107            // Try close db
108            db.close()
109        });
110        let handle = Arc::new(std::sync::Mutex::new(Some(handle)));
111
112        Ok(Self { sender, handle })
113    }
114
115    /// Access the database.
116    pub async fn access_db<F, T>(&self, func: F) -> Result<T, Error>
117    where
118        F: FnOnce(&mut Connection) -> T + Send + 'static,
119        T: Send + 'static,
120    {
121        let (tx, rx) = tokio::sync::oneshot::channel();
122        self.sender
123            .send(Message::Access {
124                func: Box::new(move |db| {
125                    let result = std::panic::catch_unwind(AssertUnwindSafe(|| func(db)));
126                    let _ = tx.send(result).is_ok();
127                }),
128            })
129            .await
130            .map_err(|_| Error::SendMessage)?;
131
132        rx.await
133            .map_err(Error::MissingResponse)?
134            .map_err(|e| Error::AccessPanicked(SyncWrapper::new(e)))
135    }
136
137    /// Close the db.
138    ///
139    /// Commands will be able to be queued until this future completes.
140    /// Then, all commands that come after will process, though new commands cannot be queued.
141    pub async fn close(&self) -> Result<(), Error> {
142        let (closed, rx) = tokio::sync::oneshot::channel();
143        self.sender
144            .send(Message::Close { closed })
145            .await
146            .map_err(|_| Error::SendMessage)?;
147        rx.await.map_err(Error::MissingResponse)
148    }
149
150    /// Join background thread.
151    ///    
152    /// This can only be called once.
153    /// Future calls will fail.
154    /// You should generally close the db connection before joining.
155    pub async fn join(&self) -> Result<(), Error> {
156        // Clone to allow user to retry if failed to spawn tokio task.
157        let handle = self.handle.clone();
158        let result = tokio::task::spawn_blocking(move || {
159            handle
160                .lock()
161                .unwrap_or_else(|e| e.into_inner())
162                .take()
163                .ok_or(Error::AlreadyJoined)?
164                .join()
165                .map_err(|e| Error::ThreadJoin(SyncWrapper::new(e)))
166        })
167        .await??;
168        if let Err((_connection, error)) = result {
169            return Err(Error::from(error));
170        }
171        Ok(())
172    }
173}