pikadick_util/
request_map.rs

1use futures::future::{
2    BoxFuture,
3    FutureExt,
4    Shared,
5};
6use std::{
7    collections::{
8        hash_map::Entry,
9        HashMap,
10    },
11    fmt::Debug,
12    future::Future,
13    hash::Hash,
14};
15
16/// A type to prevent two async requests from racing the same resource.
17#[derive(Debug)]
18pub struct RequestMap<K, V> {
19    map: std::sync::Mutex<HashMap<K, Shared<BoxFuture<'static, V>>>>,
20}
21
22impl<K, V> RequestMap<K, V> {
23    /// Make a new [`RequestMap`].
24    pub fn new() -> Self {
25        Self {
26            map: std::sync::Mutex::new(HashMap::new()),
27        }
28    }
29}
30
31impl<K, V> RequestMap<K, V>
32where
33    K: Eq + Hash + Clone + Debug,
34    V: Clone,
35{
36    /// Lock the key if it is missing, or run a future to fetch the resource.
37    pub async fn get_or_fetch<FN, F>(&self, key: K, fetch_future_func: FN) -> V
38    where
39        FN: FnOnce() -> F,
40        F: Future<Output = V> + Send + 'static,
41    {
42        let (_maybe_guard, shared_future) = {
43            // Lock the map
44            let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
45
46            // Get the entry
47            match map.entry(key.clone()) {
48                Entry::Occupied(entry) => {
49                    // A request is already in progress.
50
51                    // Grab the response future and await it.
52                    // Don't return a drop guard; only the task that started the request is allowed to clean it up.
53                    (None, entry.get().clone())
54                }
55                Entry::Vacant(entry) => {
56                    // A request is not in progress.
57
58                    // First, make the future.
59                    let fetch_future = fetch_future_func();
60
61                    // Then, make that future sharable.
62                    let shared_future = fetch_future.boxed().shared();
63
64                    // Then, store a copy in the hashmap for others interested in this value.
65                    entry.insert(shared_future.clone());
66
67                    // Then, register a drop guard since we own this request,
68                    // and are therefore responsible for cleaning it up.
69                    let drop_guard = RequestMapDropGuard { key, map: self };
70
71                    // Finally, return the future so we can await it in the next step.
72                    (Some(drop_guard), shared_future)
73                }
74            }
75        };
76
77        // Await the future.
78        // It may actually be driven by another task,
79        // but we share the results.
80        // If we are driving the request,
81        // we clean up our entry in the hashmap with the drop guard.
82        shared_future.await
83    }
84}
85
86impl<K, V> Default for RequestMap<K, V> {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92/// This will remove an entry from the request map when it gets dropped.
93struct RequestMapDropGuard<'a, K, V>
94where
95    K: Eq + Hash + Debug,
96{
97    key: K,
98    map: &'a RequestMap<K, V>,
99}
100
101impl<K, V> Drop for RequestMapDropGuard<'_, K, V>
102where
103    K: Eq + Hash + Debug,
104{
105    fn drop(&mut self) {
106        // Remove the key from the request map as we are done downloading it.
107        if self
108            .map
109            .map
110            .lock()
111            .unwrap_or_else(|e| e.into_inner())
112            .remove(&self.key)
113            .is_none()
114        {
115            panic!("key `{:?}` was unexpectedly cleaned up", self.key);
116        }
117    }
118}