tower_batch_control/
service.rs

1//! Wrapper service for batching items to an underlying service.
2
3use std::{
4    cmp::max,
5    fmt,
6    future::Future,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{Context, Poll},
10};
11
12use futures_core::ready;
13use tokio::{
14    pin,
15    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
16    task::JoinHandle,
17};
18use tokio_util::sync::PollSemaphore;
19use tower::Service;
20use tracing::{info_span, Instrument};
21
22use crate::RequestWeight;
23
24use super::{
25    future::ResponseFuture,
26    message::Message,
27    worker::{ErrorHandle, Worker},
28    BatchControl,
29};
30
/// Upper limit on the number of batches that may wait in the queue.
///
/// Without this cap, the queue size (which follows the configured or detected
/// number of batches) could become very large on machines with hundreds or
/// thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;
35
36/// Allows batch processing of requests.
37///
38/// See the crate documentation for more details.
39pub struct Batch<T, Request: RequestWeight>
40where
41    T: Service<BatchControl<Request>>,
42{
43    // Batch management
44    //
45    /// A custom-bounded channel for sending requests to the batch worker.
46    ///
47    /// Note: this actually _is_ bounded, but rather than using Tokio's unbounded
48    /// channel, we use tokio's semaphore separately to implement the bound.
49    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
50
51    /// A semaphore used to bound the channel.
52    ///
53    /// When the buffer's channel is full, we want to exert backpressure in
54    /// `poll_ready`, so that callers such as load balancers could choose to call
55    /// another service rather than waiting for buffer capacity.
56    ///
57    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
58    /// channel, because it doesn't wake pending tasks on close. Therefore, we implement our
59    /// own bounded MPSC on top of the unbounded channel, using a semaphore to
60    /// limit how many items are in the channel.
61    semaphore: PollSemaphore,
62
63    /// A semaphore permit that allows this service to send one message on `tx`.
64    permit: Option<OwnedSemaphorePermit>,
65
66    // Errors
67    //
68    /// An error handle shared between all service clones for the same worker.
69    error_handle: ErrorHandle,
70
71    /// A worker task handle shared between all service clones for the same worker.
72    ///
73    /// Only used when the worker is spawned on the tokio runtime.
74    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
75}
76
77impl<T, Request: RequestWeight> fmt::Debug for Batch<T, Request>
78where
79    T: Service<BatchControl<Request>>,
80{
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        let name = std::any::type_name::<Self>();
83        f.debug_struct(name)
84            .field("tx", &self.tx)
85            .field("semaphore", &self.semaphore)
86            .field("permit", &self.permit)
87            .field("error_handle", &self.error_handle)
88            .field("worker_handle", &self.worker_handle)
89            .finish()
90    }
91}
92
93impl<T, Request: RequestWeight> Batch<T, Request>
94where
95    T: Service<BatchControl<Request>>,
96    T::Future: Send + 'static,
97    T::Error: Into<crate::BoxError>,
98{
99    /// Creates a new `Batch` wrapping `service`.
100    ///
101    /// The wrapper is responsible for telling the inner service when to flush a
102    /// batch of requests. These parameters control this policy:
103    ///
104    /// * `max_items_weight_in_batch` gives the maximum item weight per batch.
105    /// * `max_batches` is an upper bound on the number of batches in the queue,
106    ///   and the number of concurrently executing batches.
107    ///   If this is `None`, we use the current number of [`rayon`] threads.
108    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
109    /// * `max_latency` gives the maximum latency for a batch item to start verifying.
110    ///
111    /// The default Tokio executor is used to run the given service, which means
112    /// that this method must be called while on the Tokio runtime.
113    pub fn new(
114        service: T,
115        max_items_weight_in_batch: usize,
116        max_batches: impl Into<Option<usize>>,
117        max_latency: std::time::Duration,
118    ) -> Self
119    where
120        T: Send + 'static,
121        T::Future: Send,
122        T::Response: Send,
123        T::Error: Send + Sync,
124        Request: Send + 'static,
125    {
126        let (mut batch, worker) =
127            Self::pair(service, max_items_weight_in_batch, max_batches, max_latency);
128
129        let span = info_span!("batch worker", kind = std::any::type_name::<T>());
130
131        #[cfg(tokio_unstable)]
132        let worker_handle = {
133            let batch_kind = std::any::type_name::<T>();
134
135            // TODO: identify the unique part of the type name generically,
136            //       or make it an argument to this method
137            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
138            let batch_kind = batch_kind.trim_end_matches("::Verifier");
139
140            tokio::task::Builder::new()
141                .name(&format!("{} batch", batch_kind))
142                .spawn(worker.run().instrument(span))
143                .expect("panic on error to match tokio::spawn")
144        };
145        #[cfg(not(tokio_unstable))]
146        let worker_handle = tokio::spawn(worker.run().instrument(span));
147
148        batch.register_worker(worker_handle);
149
150        batch
151    }
152
153    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
154    ///
155    /// This is useful if you do not want to spawn directly onto the `tokio`
156    /// runtime but instead want to use your own executor. This will return the
157    /// `Batch` and the background `Worker` that you can then spawn.
158    pub fn pair(
159        service: T,
160        max_items_weight_in_batch: usize,
161        max_batches: impl Into<Option<usize>>,
162        max_latency: std::time::Duration,
163    ) -> (Self, Worker<T, Request>)
164    where
165        T: Send + 'static,
166        T::Error: Send + Sync,
167        Request: Send + 'static,
168    {
169        let (tx, rx) = mpsc::unbounded_channel();
170
171        // Clamp config to sensible values.
172        let max_items_weight_in_batch = max(max_items_weight_in_batch, 1);
173        let max_batches = max_batches
174            .into()
175            .unwrap_or_else(rayon::current_num_threads);
176        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);
177
178        // The semaphore bound limits the maximum number of concurrent requests
179        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
180        // used their semaphore reservation in a `call` yet).
181        //
182        // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
183        // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
184        // Often there is only one verifier running, when that happens we want it to take all the cores.
185        //
186        // Requests with a request weight greater than 1 won't typically exhaust the number of available
187        // permits, but will still be bounded to the maximum possible number of concurrent requests.
188        let semaphore = Semaphore::new(max_items_weight_in_batch * max_batches_in_queue);
189        let semaphore = PollSemaphore::new(Arc::new(semaphore));
190
191        let (error_handle, worker) = Worker::new(
192            service,
193            rx,
194            max_items_weight_in_batch,
195            max_batches,
196            max_latency,
197            semaphore.clone(),
198        );
199
200        let batch = Batch {
201            tx,
202            semaphore,
203            permit: None,
204            error_handle,
205            worker_handle: Arc::new(Mutex::new(None)),
206        };
207
208        (batch, worker)
209    }
210
211    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
212    ///
213    /// Only used when the task is spawned on the tokio runtime.
214    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
215        *self
216            .worker_handle
217            .lock()
218            .expect("previous task panicked while holding the worker handle mutex") =
219            Some(worker_handle);
220    }
221
222    /// Returns the error from the batch worker's `error_handle`.
223    fn get_worker_error(&self) -> crate::BoxError {
224        self.error_handle.get_error_on_closed()
225    }
226}
227
228impl<T, Request: RequestWeight> Service<Request> for Batch<T, Request>
229where
230    T: Service<BatchControl<Request>>,
231    T::Future: Send + 'static,
232    T::Error: Into<crate::BoxError>,
233{
234    type Response = T::Response;
235    type Error = crate::BoxError;
236    type Future = ResponseFuture<T::Future>;
237
238    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239        // Check to see if the worker has returned or panicked.
240        //
241        // Correctness: Registers this task for wakeup when the worker finishes.
242        if let Some(worker_handle) = self
243            .worker_handle
244            .lock()
245            .expect("previous task panicked while holding the worker handle mutex")
246            .as_mut()
247        {
248            match Pin::new(worker_handle).poll(cx) {
249                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
250                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
251                    tracing::warn!(
252                        "batch task cancelled: {task_cancelled}\n\
253                         Is Zebra shutting down?"
254                    );
255
256                    return Poll::Ready(Err(task_cancelled.into()));
257                }
258                Poll::Ready(Err(task_panic)) => {
259                    std::panic::resume_unwind(task_panic.into_panic());
260                }
261                Poll::Pending => {}
262            }
263        }
264
265        // Check if the worker has set an error and closed its channels.
266        //
267        // Correctness: Registers this task for wakeup when the channel is closed.
268        let tx = self.tx.clone();
269        let closed = tx.closed();
270        pin!(closed);
271        if closed.poll(cx).is_ready() {
272            return Poll::Ready(Err(self.get_worker_error()));
273        }
274
275        // Poll to acquire a semaphore permit.
276        //
277        // CORRECTNESS
278        //
279        // If we acquire a permit, then there's enough buffer capacity to send a new request.
280        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
281        // this task for wakeup when the next permit is available, or when the semaphore is closed.
282        //
283        // When `poll_ready()` is called multiple times, and channel capacity is 1,
284        // avoid deadlocks by dropping any previous permit before acquiring another one.
285        // This also stops tasks holding a permit after an error.
286        //
287        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
288        // to another concurrent task.
289        self.permit = None;
290
291        let permit = ready!(self.semaphore.poll_acquire(cx));
292        if let Some(permit) = permit {
293            // Calling poll_ready() more than once will drop any previous permit,
294            // releasing its capacity back to the semaphore.
295            self.permit = Some(permit);
296        } else {
297            // The semaphore has been closed.
298            return Poll::Ready(Err(self.get_worker_error()));
299        }
300
301        Poll::Ready(Ok(()))
302    }
303
304    fn call(&mut self, request: Request) -> Self::Future {
305        tracing::trace!("sending request to buffer worker");
306        let _permit = self
307            .permit
308            .take()
309            .expect("poll_ready must be called before a batch request");
310
311        // get the current Span so that we can explicitly propagate it to the worker
312        // if we didn't do this, events on the worker related to this span wouldn't be counted
313        // towards that span since the worker would have no way of entering it.
314        let span = tracing::Span::current();
315
316        // If we've made it here, then a semaphore permit has already been
317        // acquired, so we can freely allocate a oneshot.
318        let (tx, rx) = oneshot::channel();
319
320        match self.tx.send(Message {
321            request,
322            tx,
323            span,
324            _permit,
325        }) {
326            Err(_) => ResponseFuture::failed(self.get_worker_error()),
327            Ok(_) => ResponseFuture::new(rx),
328        }
329    }
330}
331
332impl<T, Request: RequestWeight> Clone for Batch<T, Request>
333where
334    T: Service<BatchControl<Request>>,
335{
336    fn clone(&self) -> Self {
337        Self {
338            tx: self.tx.clone(),
339            semaphore: self.semaphore.clone(),
340            permit: None,
341            error_handle: self.error_handle.clone(),
342            worker_handle: self.worker_handle.clone(),
343        }
344    }
345}