tower_batch_control/service.rs
//! Wrapper service for batching items to an underlying service.

use std::{
    cmp::max,
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::{
    pin,
    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
    task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tracing::{info_span, Instrument};

use crate::RequestWeight;

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{ErrorHandle, Worker},
    BatchControl,
};

/// The maximum number of batches in the queue.
///
/// This avoids having very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;

/// Allows batch processing of requests.
///
/// See the crate documentation for more details.
pub struct Batch<T, Request: RequestWeight>
where
    T: Service<BatchControl<Request>>,
{
    // Batch management
    //
    /// A custom-bounded channel for sending requests to the batch worker.
    ///
    /// Note: this actually _is_ bounded, but rather than using Tokio's bounded
    /// channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    /// A semaphore used to bound the channel.
    ///
    /// When the buffer's channel is full, we want to exert backpressure in
    /// `poll_ready`, so that callers such as load balancers could choose to call
    /// another service rather than waiting for buffer capacity.
    ///
    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
    /// channel, because it doesn't wake pending tasks on close. Therefore, we
    /// implement our own bounded MPSC on top of the unbounded channel, using a
    /// semaphore to limit how many items are in the channel.
    semaphore: PollSemaphore,

    /// A semaphore permit that allows this service to send one message on `tx`.
    permit: Option<OwnedSemaphorePermit>,

    // Errors
    //
    /// An error handle shared between all service clones for the same worker.
    error_handle: ErrorHandle,

    /// A worker task handle shared between all service clones for the same worker.
    ///
    /// Only used when the worker is spawned on the tokio runtime.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl<T, Request: RequestWeight> fmt::Debug for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = std::any::type_name::<Self>();
        f.debug_struct(name)
            .field("tx", &self.tx)
            .field("semaphore", &self.semaphore)
            .field("permit", &self.permit)
            .field("error_handle", &self.error_handle)
            .field("worker_handle", &self.worker_handle)
            .finish()
    }
}

impl<T, Request: RequestWeight> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests. These parameters control this policy:
    ///
    /// * `max_items_weight_in_batch` gives the maximum total item weight per batch.
    /// * `max_batches` is an upper bound on the number of batches in the queue,
    ///   and the number of concurrently executing batches.
    ///   If this is `None`, we use the current number of [`rayon`] threads.
    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
    /// * `max_latency` gives the maximum latency for a batch item to start verifying.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
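    ///
    /// # Example
    ///
    /// A minimal sketch of wrapping a verifier service. `MyVerifier` and `MyItem`
    /// are hypothetical placeholders; any service implementing
    /// `Service<BatchControl<MyItem>>` for a weighted item type works the same way:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Hypothetical limits: flush a batch once its items reach a combined weight
    /// // of 20, allow at most 4 queued or executing batches, and start verifying
    /// // each item within 100ms.
    /// let batch = Batch::new(MyVerifier::default(), 20, Some(4), Duration::from_millis(100));
    ///
    /// // The returned `Batch` implements `tower::Service<MyItem>`, so items are
    /// // submitted through the usual `poll_ready`/`call` flow.
    /// ```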
    pub fn new(
        service: T,
        max_items_weight_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Response: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (mut batch, worker) =
            Self::pair(service, max_items_weight_in_batch, max_batches, max_latency);

        let span = info_span!("batch worker", kind = std::any::type_name::<T>());

        #[cfg(tokio_unstable)]
        let worker_handle = {
            let batch_kind = std::any::type_name::<T>();

            // TODO: identify the unique part of the type name generically,
            //       or make it an argument to this method
            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
            let batch_kind = batch_kind.trim_end_matches("::Verifier");

            tokio::task::Builder::new()
                .name(&format!("{} batch", batch_kind))
                .spawn(worker.run().instrument(span))
                .expect("panic on error to match tokio::spawn")
        };
        #[cfg(not(tokio_unstable))]
        let worker_handle = tokio::spawn(worker.run().instrument(span));

        batch.register_worker(worker_handle);

        batch
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
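    ///
    /// # Example
    ///
    /// A minimal sketch of spawning the worker yourself. `my_service` and `runtime`
    /// are hypothetical placeholders; the handle must be a Tokio [`JoinHandle`] for
    /// [`Batch::register_worker`] to monitor it:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (mut batch, worker) = Batch::pair(my_service, 20, None, Duration::from_millis(100));
    ///
    /// // Spawn the background worker on our own runtime, then let the service
    /// // monitor its task handle so `poll_ready` can report worker failures.
    /// let worker_handle = runtime.spawn(worker.run());
    /// batch.register_worker(worker_handle);
    /// ```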
    pub fn pair(
        service: T,
        max_items_weight_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // Clamp config to sensible values.
        let max_items_weight_in_batch = max(max_items_weight_in_batch, 1);
        let max_batches = max_batches
            .into()
            .unwrap_or_else(rayon::current_num_threads);
        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);

        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        //
        // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
        // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
        // Often there is only one verifier running; when that happens, we want it to take all the cores.
        //
        // Requests with a request weight greater than 1 won't typically exhaust the number of available
        // permits, but will still be bounded to the maximum possible number of concurrent requests.
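        //
        // For illustration only (hypothetical numbers, not defaults): with
        // `max_items_weight_in_batch = 20` and `max_batches_in_queue = 4`, the
        // semaphore below starts with 20 * 4 = 80 permits, so at most 80 requests
        // can hold a `poll_ready` reservation at the same time.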
        let semaphore = Semaphore::new(max_items_weight_in_batch * max_batches_in_queue);
        let semaphore = PollSemaphore::new(Arc::new(semaphore));

        let (error_handle, worker) = Worker::new(
            service,
            rx,
            max_items_weight_in_batch,
            max_batches,
            max_latency,
            semaphore.clone(),
        );

        let batch = Batch {
            tx,
            semaphore,
            permit: None,
            error_handle,
            worker_handle: Arc::new(Mutex::new(None)),
        };

        (batch, worker)
    }

    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
    ///
    /// Only used when the task is spawned on the tokio runtime.
    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
        *self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex") =
            Some(worker_handle);
    }

    /// Returns the error from the batch worker's `error_handle`.
    fn get_worker_error(&self) -> crate::BoxError {
        self.error_handle.get_error_on_closed()
    }
}

impl<T, Request: RequestWeight> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Check to see if the worker has returned or panicked.
        //
        // Correctness: Registers this task for wakeup when the worker finishes.
        if let Some(worker_handle) = self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex")
            .as_mut()
        {
            match Pin::new(worker_handle).poll(cx) {
                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
                    tracing::warn!(
                        "batch task cancelled: {task_cancelled}\n\
                         Is Zebra shutting down?"
                    );

                    return Poll::Ready(Err(task_cancelled.into()));
                }
                Poll::Ready(Err(task_panic)) => {
                    std::panic::resume_unwind(task_panic.into_panic());
                }
                Poll::Pending => {}
            }
        }

        // Check if the worker has set an error and closed its channels.
        //
        // Correctness: Registers this task for wakeup when the channel is closed.
        let tx = self.tx.clone();
        let closed = tx.closed();
        pin!(closed);
        if closed.poll(cx).is_ready() {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Poll to acquire a semaphore permit.
        //
        // CORRECTNESS
        //
        // If we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
        // this task for wakeup when the next permit is available, or when the semaphore is closed.
        //
        // When `poll_ready()` is called multiple times, and channel capacity is 1,
        // avoid deadlocks by dropping any previous permit before acquiring another one.
        // This also stops tasks holding a permit after an error.
        //
        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
        // to another concurrent task.
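        //
        // Illustration (hypothetical scenario): with only a single permit available,
        // a task that kept its previous permit while polling again would wait on
        // itself forever; dropping the old permit first lets the new `poll_acquire()`
        // succeed.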
        self.permit = None;

        let permit = ready!(self.semaphore.poll_acquire(cx));
        if let Some(permit) = permit {
            // Calling `poll_ready()` more than once will drop any previous permit,
            // releasing its capacity back to the semaphore.
            self.permit = Some(permit);
        } else {
            // The semaphore has been closed.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to buffer worker");
        let _permit = self
            .permit
            .take()
            .expect("poll_ready must be called before a batch request");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request: RequestWeight> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            permit: None,
            error_handle: self.error_handle.clone(),
            worker_handle: self.worker_handle.clone(),
        }
    }
}