// pikadick_util/request_map.rs
1use futures::future::{
2 BoxFuture,
3 FutureExt,
4 Shared,
5};
6use std::{
7 collections::{
8 hash_map::Entry,
9 HashMap,
10 },
11 fmt::Debug,
12 future::Future,
13 hash::Hash,
14};
15
16/// A type to prevent two async requests from racing the same resource.
17#[derive(Debug)]
18pub struct RequestMap<K, V> {
19 map: std::sync::Mutex<HashMap<K, Shared<BoxFuture<'static, V>>>>,
20}
21
22impl<K, V> RequestMap<K, V> {
23 /// Make a new [`RequestMap`].
24 pub fn new() -> Self {
25 Self {
26 map: std::sync::Mutex::new(HashMap::new()),
27 }
28 }
29}
30
31impl<K, V> RequestMap<K, V>
32where
33 K: Eq + Hash + Clone + Debug,
34 V: Clone,
35{
36 /// Lock the key if it is missing, or run a future to fetch the resource.
37 pub async fn get_or_fetch<FN, F>(&self, key: K, fetch_future_func: FN) -> V
38 where
39 FN: FnOnce() -> F,
40 F: Future<Output = V> + Send + 'static,
41 {
42 let (_maybe_guard, shared_future) = {
43 // Lock the map
44 let mut map = self.map.lock().unwrap_or_else(|e| e.into_inner());
45
46 // Get the entry
47 match map.entry(key.clone()) {
48 Entry::Occupied(entry) => {
49 // A request is already in progress.
50
51 // Grab the response future and await it.
52 // Don't return a drop guard; only the task that started the request is allowed to clean it up.
53 (None, entry.get().clone())
54 }
55 Entry::Vacant(entry) => {
56 // A request is not in progress.
57
58 // First, make the future.
59 let fetch_future = fetch_future_func();
60
61 // Then, make that future sharable.
62 let shared_future = fetch_future.boxed().shared();
63
64 // Then, store a copy in the hashmap for others interested in this value.
65 entry.insert(shared_future.clone());
66
67 // Then, register a drop guard since we own this request,
68 // and are therefore responsible for cleaning it up.
69 let drop_guard = RequestMapDropGuard { key, map: self };
70
71 // Finally, return the future so we can await it in the next step.
72 (Some(drop_guard), shared_future)
73 }
74 }
75 };
76
77 // Await the future.
78 // It may actually be driven by another task,
79 // but we share the results.
80 // If we are driving the request,
81 // we clean up our entry in the hashmap with the drop guard.
82 shared_future.await
83 }
84}
85
86impl<K, V> Default for RequestMap<K, V> {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92/// This will remove an entry from the request map when it gets dropped.
93struct RequestMapDropGuard<'a, K, V>
94where
95 K: Eq + Hash + Debug,
96{
97 key: K,
98 map: &'a RequestMap<K, V>,
99}
100
101impl<K, V> Drop for RequestMapDropGuard<'_, K, V>
102where
103 K: Eq + Hash + Debug,
104{
105 fn drop(&mut self) {
106 // Remove the key from the request map as we are done downloading it.
107 if self
108 .map
109 .map
110 .lock()
111 .unwrap_or_else(|e| e.into_inner())
112 .remove(&self.key)
113 .is_none()
114 {
115 panic!("key `{:?}` was unexpectedly cleaned up", self.key);
116 }
117 }
118}