//! Deduplication of concurrent async requests for the same keyed resource.
use futures::future::{
BoxFuture,
FutureExt,
Shared,
};
use std::{
collections::{
hash_map::Entry,
HashMap,
},
fmt::Debug,
future::Future,
hash::Hash,
};
/// A type to prevent two async requests from racing the same resource.
///
/// Concurrent callers of `get_or_fetch` for the same key share a single
/// in-flight future rather than issuing duplicate requests.
#[derive(Debug)]
pub struct RequestMap<K, V> {
    // Key -> in-flight request table. A blocking `std::sync::Mutex` (not an
    // async lock) is sufficient because the lock is released before any
    // `.await` (see `get_or_fetch`). The stored future is `Shared` so every
    // interested task can hold a clone and receive a cloned output.
    map: std::sync::Mutex<HashMap<K, Shared<BoxFuture<'static, V>>>>,
}
impl<K, V> RequestMap<K, V> {
    /// Create an empty [`RequestMap`] with no requests in flight.
    pub fn new() -> Self {
        let map = std::sync::Mutex::new(HashMap::new());
        Self { map }
    }
}
impl<K, V> RequestMap<K, V>
where
    K: Eq + Hash + Clone + Debug,
    V: Clone,
{
    /// Lock the key if it is missing, or run a future to fetch the resource.
    ///
    /// If no request for `key` is in flight, `fetch_future_func` is invoked
    /// to create one and its (boxed, shared) future is stored in the map.
    /// If a request is already in flight, that task's future is cloned and
    /// awaited instead — both callers receive a clone of the same output
    /// (hence the `V: Clone` bound). `fetch_future_func` is only called in
    /// the vacant case.
    pub async fn get_or_fetch<FN, F>(&self, key: K, fetch_future_func: FN) -> V
    where
        FN: FnOnce() -> F,
        F: Future<Output = V> + Send + 'static,
    {
        // NOTE: the inner block scopes the mutex guard so the lock is
        // released before the `.await` below — holding a `std::sync::Mutex`
        // across an await point would block other tasks' threads.
        let (_maybe_guard, shared_future) = {
            // Lock the map. A poisoned lock is recovered with `into_inner`:
            // the map's contents are just future handles and remain valid
            // even if another thread panicked while holding the lock.
            let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
            // Get the entry. `key` is cloned because the vacant arm still
            // needs to move the original `key` into the drop guard.
            match map.entry(key.clone()) {
                Entry::Occupied(entry) => {
                    // A request is already in progress.
                    // Grab the response future and await it.
                    // Don't return a drop guard; only the task that started the request is allowed to clean it up.
                    (None, entry.get().clone())
                }
                Entry::Vacant(entry) => {
                    // A request is not in progress.
                    // First, make the future.
                    let fetch_future = fetch_future_func();
                    // Then, make that future sharable.
                    let shared_future = fetch_future.boxed().shared();
                    // Then, store a copy in the hashmap for others interested in this value.
                    entry.insert(shared_future.clone());
                    // Then, register a drop guard since we own this request,
                    // and are therefore responsible for cleaning it up.
                    // The guard is bound as `_maybe_guard` above so it lives
                    // until this function returns (or is cancelled), at which
                    // point it removes the entry from the map.
                    let drop_guard = RequestMapDropGuard { key, map: self };
                    // Finally, return the future so we can await it in the next step.
                    (Some(drop_guard), shared_future)
                }
            }
        };
        // Await the future.
        // It may actually be driven by another task,
        // but we share the results.
        // If we are driving the request,
        // we clean up our entry in the hashmap with the drop guard.
        shared_future.await
    }
}
impl<K, V> Default for RequestMap<K, V> {
fn default() -> Self {
Self::new()
}
}
/// This will remove an entry from the request map when it gets dropped.
///
/// Only the task that inserted an entry holds the guard for it (see the
/// occupied arm of `get_or_fetch`), so exactly one task is responsible for
/// cleaning up each key — including on cancellation, since `Drop` runs when
/// the owning future is dropped mid-await.
struct RequestMapDropGuard<'a, K, V>
where
    K: Eq + Hash + Debug,
{
    // The key whose map entry this guard owns and will remove on drop.
    key: K,
    // Back-reference to the owning map so `Drop` can remove the entry.
    map: &'a RequestMap<K, V>,
}
impl<K, V> Drop for RequestMapDropGuard<'_, K, V>
where
    K: Eq + Hash + Debug,
{
    /// Remove this guard's key from the request map, panicking if the entry
    /// was already gone (which would indicate a bookkeeping bug: only the
    /// owning task is allowed to remove its own entry).
    fn drop(&mut self) {
        // Remove the key from the request map as we are done downloading it.
        // A poisoned lock is recovered with `into_inner`, matching
        // `get_or_fetch`; the map's contents stay valid across a panic.
        let removed = self
            .map
            .map
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&self.key)
            .is_some();
        // Fix: never panic inside `drop` while the thread is already
        // unwinding (e.g. the fetch future panicked and this guard is being
        // dropped as part of that unwind) — a second panic there aborts the
        // whole process. The invariant check still fires on normal drops.
        if !removed && !std::thread::panicking() {
            panic!("key `{:?}` was unexpectedly cleaned up", self.key);
        }
    }
}