tower_batch_control/service.rs
//! Wrapper service for batching items to an underlying service.

use std::{
    cmp::max,
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::{
    pin,
    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
    task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tracing::{info_span, Instrument};

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{ErrorHandle, Worker},
    BatchControl,
};

/// The maximum number of batches in the queue.
///
/// This avoids having very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;

/// Allows batch processing of requests.
///
/// See the crate documentation for more details.
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Batch management
    //
    /// A custom-bounded channel for sending requests to the batch worker.
    ///
    /// Note: this actually _is_ bounded, but rather than using Tokio's bounded
    /// channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    /// A semaphore used to bound the channel.
    ///
    /// When the buffer's channel is full, we want to exert backpressure in
    /// `poll_ready`, so that callers such as load balancers can choose to call
    /// another service rather than waiting for buffer capacity.
    ///
    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
    /// channel, because it doesn't wake pending tasks on close. Therefore, we
    /// implement our own bounded MPSC on top of the unbounded channel, using a
    /// semaphore to limit how many items are in the channel.
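    ///
    /// In sketch form, the pattern (simplified from `poll_ready` and `call`
    /// below, with error handling omitted) is:
    ///
    /// ```ignore
    /// // poll_ready: reserve channel capacity by acquiring a semaphore permit.
    /// let permit = ready!(self.semaphore.poll_acquire(cx));
    /// // call: move the permit into the message; dropping the permit after the
    /// // item is processed releases the capacity back to the semaphore.
    /// self.tx.send(Message { request, _permit, /* ... */ });
    /// ```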
    semaphore: PollSemaphore,

    /// A semaphore permit that allows this service to send one message on `tx`.
    permit: Option<OwnedSemaphorePermit>,

    // Errors
    //
    /// An error handle shared between all service clones for the same worker.
    error_handle: ErrorHandle,

    /// A worker task handle shared between all service clones for the same worker.
    ///
    /// Only used when the worker is spawned on the tokio runtime.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl<T, Request> fmt::Debug for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = std::any::type_name::<Self>();
        f.debug_struct(name)
            .field("tx", &self.tx)
            .field("semaphore", &self.semaphore)
            .field("permit", &self.permit)
            .field("error_handle", &self.error_handle)
            .field("worker_handle", &self.worker_handle)
            .finish()
    }
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests. These parameters control this policy:
    ///
    /// * `max_items_in_batch` gives the maximum number of items per batch.
    /// * `max_batches` is an upper bound on the number of batches in the queue,
    ///   and the number of concurrently executing batches.
    ///   If this is `None`, we use the current number of [`rayon`] threads.
    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
    /// * `max_latency` gives the maximum latency for a batch item to start verifying.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
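    ///
    /// # Example
    ///
    /// A minimal construction sketch, assuming a `MyVerifier` type that
    /// implements `Service<BatchControl<Item>>` (both names are illustrative):
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Flush after at most 32 items, limit the queue to one batch per rayon
    /// // thread (`None`), and start verifying each item within 100ms.
    /// let batch = Batch::new(MyVerifier::default(), 32, None, Duration::from_millis(100));
    /// ```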
    pub fn new(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Response: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (mut batch, worker) =
            Self::pair(service, max_items_in_batch, max_batches, max_latency);

        let span = info_span!("batch worker", kind = std::any::type_name::<T>());

        #[cfg(tokio_unstable)]
        let worker_handle = {
            let batch_kind = std::any::type_name::<T>();

            // TODO: identify the unique part of the type name generically,
            // or make it an argument to this method
            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
            let batch_kind = batch_kind.trim_end_matches("::Verifier");

            tokio::task::Builder::new()
                .name(&format!("{} batch", batch_kind))
                .spawn(worker.run().instrument(span))
                .expect("panic on error to match tokio::spawn")
        };
        #[cfg(not(tokio_unstable))]
        let worker_handle = tokio::spawn(worker.run().instrument(span));

        batch.register_worker(worker_handle);

        batch
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
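    ///
    /// # Example
    ///
    /// A sketch of spawning the worker yourself (here back onto Tokio, just to
    /// show the shape; `MyVerifier` and `max_latency` are illustrative):
    ///
    /// ```ignore
    /// let (mut batch, worker) = Batch::pair(MyVerifier::default(), 32, None, max_latency);
    /// // Spawn the returned `Worker` on the executor of your choice, then
    /// // (optionally) let the `Batch` monitor the task for errors and panics.
    /// let worker_handle = tokio::spawn(worker.run());
    /// batch.register_worker(worker_handle);
    /// ```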
    pub fn pair(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // Clamp config to sensible values.
        let max_items_in_batch = max(max_items_in_batch, 1);
        let max_batches = max_batches
            .into()
            .unwrap_or_else(rayon::current_num_threads);
        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);

        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        //
        // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
        // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
        // Often there is only one verifier running; when that happens, we want it to take all the cores.
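        //
        // For example (illustrative numbers): with `max_items_in_batch = 32` and an
        // 8-thread rayon pool, `max_batches_in_queue` is 8, so the semaphore allows
        // 32 * 8 = 256 outstanding reservations.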
        let semaphore = Semaphore::new(max_items_in_batch * max_batches_in_queue);
        let semaphore = PollSemaphore::new(Arc::new(semaphore));

        let (error_handle, worker) = Worker::new(
            service,
            rx,
            max_items_in_batch,
            max_batches,
            max_latency,
            semaphore.clone(),
        );

        let batch = Batch {
            tx,
            semaphore,
            permit: None,
            error_handle,
            worker_handle: Arc::new(Mutex::new(None)),
        };

        (batch, worker)
    }

    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
    ///
    /// Only used when the task is spawned on the tokio runtime.
    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
        *self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex") =
            Some(worker_handle);
    }

    /// Returns the error from the batch worker's `error_handle`.
    fn get_worker_error(&self) -> crate::BoxError {
        self.error_handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Check to see if the worker has returned or panicked.
        //
        // Correctness: Registers this task for wakeup when the worker finishes.
        if let Some(worker_handle) = self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex")
            .as_mut()
        {
            match Pin::new(worker_handle).poll(cx) {
                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
                    tracing::warn!(
                        "batch task cancelled: {task_cancelled}\n\
                         Is Zebra shutting down?"
                    );

                    return Poll::Ready(Err(task_cancelled.into()));
                }
                Poll::Ready(Err(task_panic)) => {
                    std::panic::resume_unwind(task_panic.into_panic());
                }
                Poll::Pending => {}
            }
        }

        // Check if the worker has set an error and closed its channels.
        //
        // Correctness: Registers this task for wakeup when the channel is closed.
        let tx = self.tx.clone();
        let closed = tx.closed();
        pin!(closed);
        if closed.poll(cx).is_ready() {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Poll to acquire a semaphore permit.
        //
        // CORRECTNESS
        //
        // If we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
        // this task for wakeup when the next permit is available, or when the semaphore is closed.
        //
        // When `poll_ready()` is called multiple times, and channel capacity is 1,
        // avoid deadlocks by dropping any previous permit before acquiring another one.
        // This also stops tasks from holding a permit after an error.
        //
        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
        // to another concurrent task.
        self.permit = None;

        let permit = ready!(self.semaphore.poll_acquire(cx));
        if let Some(permit) = permit {
            // Calling `poll_ready()` more than once will drop any previous permit,
            // releasing its capacity back to the semaphore.
            self.permit = Some(permit);
        } else {
            // The semaphore has been closed.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to buffer worker");
        let _permit = self
            .permit
            .take()
            .expect("poll_ready must be called before a batch request");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
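            // Note: permits are deliberately not cloned. Each clone starts without
            // a reservation, and must call `poll_ready` before it can `call`.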
            permit: None,
            error_handle: self.error_handle.clone(),
            worker_handle: self.worker_handle.clone(),
        }
    }
}
339}