async_rusqlite/
database.rs1use crate::{
2 BoxedError,
3 DbThreadJoinHandle,
4 Error,
5 SyncWrapper,
6};
7use rusqlite::Connection;
8use std::{
9 panic::AssertUnwindSafe,
10 path::{
11 Path,
12 PathBuf,
13 },
14 sync::Arc,
15};
16
/// Capacity of the bounded request channel feeding the database thread.
///
/// Once this many requests are queued, `Sender::send(..).await` waits for space.
const MESSAGE_CHANNEL_SIZE: usize = 128;
18
19enum Message {
20 Access {
21 func: Box<dyn FnOnce(&mut Connection) + Send + 'static>,
22 },
23 Close {
24 closed: tokio::sync::oneshot::Sender<()>,
25 },
26}
27
28impl std::fmt::Debug for Message {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 Self::Access { .. } => write!(f, "Access"),
32 Self::Close { .. } => write!(f, "Close"),
33 }
34 }
35}
36
37#[derive(Clone)]
39pub struct Database {
40 sender: tokio::sync::mpsc::Sender<Message>,
41
42 handle: Arc<std::sync::Mutex<Option<DbThreadJoinHandle>>>,
43}
44
45impl std::fmt::Debug for Database {
46 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47 f.debug_struct("Database").finish()
49 }
50}
51
52impl Database {
53 pub async fn open<P, S>(path: P, create_if_missing: bool, setup_func: S) -> Result<Self, Error>
55 where
56 P: Into<PathBuf>,
57 S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
58 {
59 let path = path.into();
60 tokio::task::spawn_blocking(move || {
61 Self::blocking_open(path, create_if_missing, setup_func)
62 })
63 .await?
64 }
65
66 pub fn blocking_open<P, S>(
68 path: P,
69 create_if_missing: bool,
70 mut setup_func: S,
71 ) -> Result<Self, Error>
72 where
73 P: AsRef<Path>,
74 S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
75 {
76 let mut flags = rusqlite::OpenFlags::default();
78 if !create_if_missing {
79 flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE)
80 }
81
82 let mut db = Connection::open_with_flags(path, flags)?;
84
85 setup_func(&mut db).map_err(Error::SetupFunc)?;
87
88 let (sender, mut rx) = tokio::sync::mpsc::channel(MESSAGE_CHANNEL_SIZE);
90
91 let handle = std::thread::spawn(move || {
93 while let Some(msg) = rx.blocking_recv() {
94 match msg {
95 Message::Access { func } => {
96 func(&mut db);
97 }
98 Message::Close { closed } => {
99 rx.close();
100
101 let _ = closed.send(()).is_ok();
103 }
104 }
105 }
106
107 db.close()
109 });
110 let handle = Arc::new(std::sync::Mutex::new(Some(handle)));
111
112 Ok(Self { sender, handle })
113 }
114
115 pub async fn access_db<F, T>(&self, func: F) -> Result<T, Error>
117 where
118 F: FnOnce(&mut Connection) -> T + Send + 'static,
119 T: Send + 'static,
120 {
121 let (tx, rx) = tokio::sync::oneshot::channel();
122 self.sender
123 .send(Message::Access {
124 func: Box::new(move |db| {
125 let result = std::panic::catch_unwind(AssertUnwindSafe(|| func(db)));
126 let _ = tx.send(result).is_ok();
127 }),
128 })
129 .await
130 .map_err(|_| Error::SendMessage)?;
131
132 rx.await
133 .map_err(Error::MissingResponse)?
134 .map_err(|e| Error::AccessPanicked(SyncWrapper::new(e)))
135 }
136
137 pub async fn close(&self) -> Result<(), Error> {
142 let (closed, rx) = tokio::sync::oneshot::channel();
143 self.sender
144 .send(Message::Close { closed })
145 .await
146 .map_err(|_| Error::SendMessage)?;
147 rx.await.map_err(Error::MissingResponse)
148 }
149
150 pub async fn join(&self) -> Result<(), Error> {
156 let handle = self.handle.clone();
158 let result = tokio::task::spawn_blocking(move || {
159 handle
160 .lock()
161 .unwrap_or_else(|e| e.into_inner())
162 .take()
163 .ok_or(Error::AlreadyJoined)?
164 .join()
165 .map_err(|e| Error::ThreadJoin(SyncWrapper::new(e)))
166 })
167 .await??;
168 if let Err((_connection, error)) = result {
169 return Err(Error::from(error));
170 }
171 Ok(())
172 }
173}