tower_batch_control/
service.rs

//! Wrapper service for batching items to an underlying service.

use std::{
    cmp::max,
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::{
    pin,
    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
    task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tracing::{info_span, Instrument};

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{ErrorHandle, Worker},
    BatchControl,
};

/// The maximum number of batches in the queue.
///
/// This avoids having very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;

/// Allows batch processing of requests.
///
/// See the crate documentation for more details.
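///
/// Calling the wrapped service follows the usual `tower` pattern. A minimal
/// sketch, assuming a `batch` service and an `item` value of the `Request`
/// type accepted by the inner service (both bindings are hypothetical):
///
/// ```ignore
/// use tower::{Service, ServiceExt};
///
/// // `ready()` waits for a semaphore permit, then `call()` queues the item
/// // into the current batch and returns a future that resolves once the
/// // worker flushes the batch and the inner service responds.
/// let response = batch.ready().await?.call(item).await?;
/// ```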
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Batch management
    //
    /// A custom-bounded channel for sending requests to the batch worker.
    ///
    /// Note: this actually _is_ bounded, but rather than using Tokio's bounded
    /// channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    /// A semaphore used to bound the channel.
    ///
    /// When the buffer's channel is full, we want to exert backpressure in
    /// `poll_ready`, so that callers such as load balancers could choose to call
    /// another service rather than waiting for buffer capacity.
    ///
    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
    /// channel, because it doesn't expose a polling-based interface, only an
    /// `async`-based one. Therefore, we implement our own bounded MPSC on top
    /// of the unbounded channel, using a semaphore to limit how many items are
    /// in the channel.
    semaphore: PollSemaphore,

    /// A semaphore permit that allows this service to send one message on `tx`.
    permit: Option<OwnedSemaphorePermit>,

    // Errors
    //
    /// An error handle shared between all service clones for the same worker.
    error_handle: ErrorHandle,

    /// A worker task handle shared between all service clones for the same worker.
    ///
    /// Only used when the worker is spawned on the tokio runtime.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl<T, Request> fmt::Debug for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = std::any::type_name::<Self>();
        f.debug_struct(name)
            .field("tx", &self.tx)
            .field("semaphore", &self.semaphore)
            .field("permit", &self.permit)
            .field("error_handle", &self.error_handle)
            .field("worker_handle", &self.worker_handle)
            .finish()
    }
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests. These parameters control this policy:
    ///
    /// * `max_items_in_batch` gives the maximum number of items per batch.
    /// * `max_batches` is an upper bound on the number of batches in the queue,
    ///   and the number of concurrently executing batches.
    ///   If this is `None`, we use the current number of [`rayon`] threads.
    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
    /// * `max_latency` gives the maximum latency for a batch item to start verifying.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
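    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a hypothetical `MyVerifier` type that
    /// implements `Service<BatchControl<Item>>`:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Flush when a batch reaches 30 items, or when queued items have
    /// // waited up to 10ms, whichever comes first. `None` derives the batch
    /// // limit from the number of rayon threads.
    /// let batch = Batch::new(MyVerifier::default(), 30, None, Duration::from_millis(10));
    /// ```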
    pub fn new(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Response: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (mut batch, worker) = Self::pair(service, max_items_in_batch, max_batches, max_latency);

        let span = info_span!("batch worker", kind = std::any::type_name::<T>());

        #[cfg(tokio_unstable)]
        let worker_handle = {
            let batch_kind = std::any::type_name::<T>();

            // TODO: identify the unique part of the type name generically,
            //       or make it an argument to this method
            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
            let batch_kind = batch_kind.trim_end_matches("::Verifier");

            tokio::task::Builder::new()
                .name(&format!("{} batch", batch_kind))
                .spawn(worker.run().instrument(span))
                .expect("panic on error to match tokio::spawn")
        };
        #[cfg(not(tokio_unstable))]
        let worker_handle = tokio::spawn(worker.run().instrument(span));

        batch.register_worker(worker_handle);

        batch
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
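    ///
    /// A minimal sketch, assuming the same hypothetical `MyVerifier` as in
    /// [`Batch::new`]:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (mut batch, worker) =
    ///     Batch::pair(MyVerifier::default(), 30, None, Duration::from_millis(10));
    ///
    /// // Spawn the worker on an executor of your choice, then let the
    /// // service monitor it.
    /// let worker_handle = tokio::spawn(worker.run());
    /// batch.register_worker(worker_handle);
    /// ```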
    pub fn pair(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // Clamp config to sensible values.
        let max_items_in_batch = max(max_items_in_batch, 1);
        let max_batches = max_batches
            .into()
            .unwrap_or_else(rayon::current_num_threads);
        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);

        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        //
        // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
        // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
        // Often there is only one verifier running; when that happens, we want it to take all the cores.
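        //
        // For example (illustrative numbers only): with `max_items_in_batch = 30`
        // and 8 rayon threads, the semaphore below starts with 30 * 8 = 240 permits.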
        let semaphore = Semaphore::new(max_items_in_batch * max_batches_in_queue);
        let semaphore = PollSemaphore::new(Arc::new(semaphore));

        let (error_handle, worker) = Worker::new(
            service,
            rx,
            max_items_in_batch,
            max_batches,
            max_latency,
            semaphore.clone(),
        );

        let batch = Batch {
            tx,
            semaphore,
            permit: None,
            error_handle,
            worker_handle: Arc::new(Mutex::new(None)),
        };

        (batch, worker)
    }

    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
    ///
    /// Only used when the task is spawned on the tokio runtime.
    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
        *self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex") =
            Some(worker_handle);
    }

    /// Returns the error from the batch worker's `error_handle`.
    fn get_worker_error(&self) -> crate::BoxError {
        self.error_handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Check to see if the worker has returned or panicked.
        //
        // Correctness: Registers this task for wakeup when the worker finishes.
        if let Some(worker_handle) = self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex")
            .as_mut()
        {
            match Pin::new(worker_handle).poll(cx) {
                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
                    tracing::warn!(
                        "batch task cancelled: {task_cancelled}\n\
                         Is Zebra shutting down?"
                    );

                    return Poll::Ready(Err(task_cancelled.into()));
                }
                Poll::Ready(Err(task_panic)) => {
                    std::panic::resume_unwind(task_panic.into_panic());
                }
                Poll::Pending => {}
            }
        }

        // Check if the worker has set an error and closed its channels.
        //
        // Correctness: Registers this task for wakeup when the channel is closed.
        let tx = self.tx.clone();
        let closed = tx.closed();
        pin!(closed);
        if closed.poll(cx).is_ready() {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Poll to acquire a semaphore permit.
        //
        // CORRECTNESS
        //
        // If we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
        // this task for wakeup when the next permit is available, or when the semaphore is closed.
        //
        // When `poll_ready()` is called multiple times, and channel capacity is 1,
        // avoid deadlocks by dropping any previous permit before acquiring another one.
        // This also stops tasks holding a permit after an error.
        //
        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
        // to another concurrent task.
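        //
        // For example, with a capacity of 1: if this task kept its previous permit
        // while acquiring another one, it would wait forever for itself to release
        // the only permit.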
        self.permit = None;

        let permit = ready!(self.semaphore.poll_acquire(cx));
        if let Some(permit) = permit {
            // Calling poll_ready() more than once will drop any previous permit,
            // releasing its capacity back to the semaphore.
            self.permit = Some(permit);
        } else {
            // The semaphore has been closed.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to batch worker");
        let _permit = self
            .permit
            .take()
            .expect("poll_ready must be called before a batch request");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
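            // Correctness: permits can't be shared between clones; each clone
            // must acquire its own permit in `poll_ready()` before calling.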
            permit: None,
            error_handle: self.error_handle.clone(),
            worker_handle: self.worker_handle.clone(),
        }
    }
}