use crate::{
BoxedError,
DbThreadJoinHandle,
Error,
SyncWrapper,
};
use rusqlite::Connection;
use std::{
panic::AssertUnwindSafe,
path::{
Path,
PathBuf,
},
sync::Arc,
};
const MESSAGE_CHANNEL_SIZE: usize = 128;
/// A request sent from a [`Database`] handle to the worker thread that owns
/// the connection.
enum Message {
    /// Ask the worker to stop accepting new requests; `closed` is signalled
    /// when the worker handles this request.
    Close {
        closed: tokio::sync::oneshot::Sender<()>,
    },
    /// Run `func` on the worker thread with the connection.
    Access {
        func: Box<dyn FnOnce(&mut Connection) + Send + 'static>,
    },
}
impl std::fmt::Debug for Message {
    /// Prints only the variant name; the boxed closure and the oneshot
    /// sender have nothing useful to show.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Access { .. } => "Access",
            Self::Close { .. } => "Close",
        };
        f.write_str(variant)
    }
}
/// Handle to a SQLite connection that lives on a dedicated worker thread.
///
/// Clones share the same request channel and the same join-handle slot, so
/// any clone can send requests, close the worker, or join it.
#[derive(Clone)]
pub struct Database {
    // Sends requests to the worker thread.
    sender: tokio::sync::mpsc::Sender<Message>,
    // Worker thread's join handle; becomes `None` once `join` has taken it.
    handle: Arc<std::sync::Mutex<Option<DbThreadJoinHandle>>>,
}
impl std::fmt::Debug for Database {
    /// Renders as the bare type name; the channel sender and the join handle
    /// are not meaningful to print.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Database")
    }
}
impl Database {
pub async fn open<P, S>(path: P, create_if_missing: bool, setup_func: S) -> Result<Self, Error>
where
P: Into<PathBuf>,
S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
{
let path = path.into();
tokio::task::spawn_blocking(move || {
Self::blocking_open(path, create_if_missing, setup_func)
})
.await?
}
pub fn blocking_open<P, S>(
path: P,
create_if_missing: bool,
mut setup_func: S,
) -> Result<Self, Error>
where
P: AsRef<Path>,
S: FnMut(&mut rusqlite::Connection) -> Result<(), BoxedError> + Send + 'static,
{
let mut flags = rusqlite::OpenFlags::default();
if !create_if_missing {
flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE)
}
let mut db = Connection::open_with_flags(path, flags)?;
setup_func(&mut db).map_err(Error::SetupFunc)?;
let (sender, mut rx) = tokio::sync::mpsc::channel(MESSAGE_CHANNEL_SIZE);
let handle = std::thread::spawn(move || {
while let Some(msg) = rx.blocking_recv() {
match msg {
Message::Access { func } => {
func(&mut db);
}
Message::Close { closed } => {
rx.close();
let _ = closed.send(()).is_ok();
}
}
}
db.close()
});
let handle = Arc::new(std::sync::Mutex::new(Some(handle)));
Ok(Self { sender, handle })
}
pub async fn access_db<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> T + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = tokio::sync::oneshot::channel();
self.sender
.send(Message::Access {
func: Box::new(move |db| {
let result = std::panic::catch_unwind(AssertUnwindSafe(|| func(db)));
let _ = tx.send(result).is_ok();
}),
})
.await
.map_err(|_| Error::SendMessage)?;
rx.await
.map_err(Error::MissingResponse)?
.map_err(|e| Error::AccessPanicked(SyncWrapper::new(e)))
}
pub async fn close(&self) -> Result<(), Error> {
let (closed, rx) = tokio::sync::oneshot::channel();
self.sender
.send(Message::Close { closed })
.await
.map_err(|_| Error::SendMessage)?;
rx.await.map_err(Error::MissingResponse)
}
pub async fn join(&self) -> Result<(), Error> {
let handle = self.handle.clone();
let result = tokio::task::spawn_blocking(move || {
handle
.lock()
.unwrap_or_else(|e| e.into_inner())
.take()
.ok_or(Error::AlreadyJoined)?
.join()
.map_err(|e| Error::ThreadJoin(SyncWrapper::new(e)))
})
.await??;
if let Err((_connection, error)) = result {
return Err(Error::from(error));
}
Ok(())
}
}