//! Wrapper service for batching items to an underlying service.
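//!
//! Each [`Batch`] clone sends requests over a bounded channel to a shared
//! [`Worker`] task, which tells the wrapped service to flush when a batch
//! fills up or when the batch's maximum latency expires.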

use std::{
    cmp::max,
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::{
    pin,
    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
    task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tracing::{info_span, Instrument};

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{ErrorHandle, Worker},
    BatchControl,
};

/// The maximum number of batches in the queue.
///
/// This avoids having very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;

/// Allows batch processing of requests.
///
/// See the crate documentation for more details.
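///
/// # Examples
///
/// A minimal usage sketch, not compiled as a doctest; `batch` is a clone of a
/// [`Batch`] service and `item` is a hypothetical request:
///
/// ```ignore
/// use tower::{Service, ServiceExt};
///
/// // `ready()` polls `poll_ready`, which acquires a semaphore permit;
/// // `call` consumes that permit when it queues the request.
/// let response = batch.ready().await?.call(item).await?;
/// ```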
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Batch management
    //
    /// A custom-bounded channel for sending requests to the batch worker.
    ///
    /// Note: this actually _is_ bounded, but rather than using Tokio's bounded
    /// channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    /// A semaphore used to bound the channel.
    ///
    /// When the buffer's channel is full, we want to exert backpressure in
    /// `poll_ready`, so that callers such as load balancers can choose to call
    /// another service rather than waiting for buffer capacity.
    ///
    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
    /// channel, because it doesn't wake pending tasks on close. Therefore, we implement our
    /// own bounded MPSC on top of the unbounded channel, using a semaphore to
    /// limit how many items are in the channel.
    semaphore: PollSemaphore,

    /// A semaphore permit that allows this service to send one message on `tx`.
    permit: Option<OwnedSemaphorePermit>,

    // Errors
    //
    /// An error handle shared between all service clones for the same worker.
    error_handle: ErrorHandle,

    /// A worker task handle shared between all service clones for the same worker.
    ///
    /// Only used when the worker is spawned on the tokio runtime.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl<T, Request> fmt::Debug for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = std::any::type_name::<Self>();
        f.debug_struct(name)
            .field("tx", &self.tx)
            .field("semaphore", &self.semaphore)
            .field("permit", &self.permit)
            .field("error_handle", &self.error_handle)
            .field("worker_handle", &self.worker_handle)
            .finish()
    }
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests. These parameters control this policy:
    ///
    /// * `max_items_in_batch` gives the maximum number of items per batch.
    /// * `max_batches` is an upper bound on the number of batches in the queue,
    ///   and the number of concurrently executing batches.
    ///   If this is `None`, we use the current number of [`rayon`] threads.
    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
    /// * `max_latency` gives the maximum time an item can wait before its batch is flushed.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
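    ///
    /// # Examples
    ///
    /// A minimal sketch, not compiled as a doctest; `ItemVerifier` is a
    /// hypothetical service implementing `Service<BatchControl<Item>>`:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Flush after 20 items, or after 100ms, whichever comes first.
    /// // `None` lets `max_batches` default to the rayon thread count.
    /// let batch = Batch::new(ItemVerifier::default(), 20, None, Duration::from_millis(100));
    /// ```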
    pub fn new(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Response: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (mut batch, worker) = Self::pair(service, max_items_in_batch, max_batches, max_latency);

        let span = info_span!("batch worker", kind = std::any::type_name::<T>());

        #[cfg(tokio_unstable)]
        let worker_handle = {
            let batch_kind = std::any::type_name::<T>();

            // TODO: identify the unique part of the type name generically,
            //       or make it an argument to this method
            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
            let batch_kind = batch_kind.trim_end_matches("::Verifier");

            tokio::task::Builder::new()
                .name(&format!("{} batch", batch_kind))
                .spawn(worker.run().instrument(span))
                .expect("panic on spawn error, to match tokio::spawn behavior")
        };
        #[cfg(not(tokio_unstable))]
        let worker_handle = tokio::spawn(worker.run().instrument(span));

        batch.register_worker(worker_handle);

        batch
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
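    ///
    /// # Examples
    ///
    /// A minimal sketch, not compiled as a doctest; `ItemVerifier` and
    /// `my_executor` are hypothetical:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (batch, worker) =
    ///     Batch::pair(ItemVerifier::default(), 20, None, Duration::from_millis(100));
    ///
    /// // Spawn the worker on a custom executor, then use `batch` as a `tower::Service`.
    /// my_executor.spawn(worker.run());
    /// ```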
    pub fn pair(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // Clamp config to sensible values.
        let max_items_in_batch = max(max_items_in_batch, 1);
        let max_batches = max_batches
            .into()
            .unwrap_or_else(rayon::current_num_threads);
        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);

        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        //
        // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
        // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
        // Often only one verifier is running; when that happens, we want it to take all the cores.
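        //
        // For example, with `max_items_in_batch = 20` and 8 rayon threads
        // (so `max_batches_in_queue = 8`, well under `QUEUE_BATCH_LIMIT`),
        // the semaphore starts with 20 * 8 = 160 permits.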
        let semaphore = Semaphore::new(max_items_in_batch * max_batches_in_queue);
        let semaphore = PollSemaphore::new(Arc::new(semaphore));

        let (error_handle, worker) = Worker::new(
            service,
            rx,
            max_items_in_batch,
            max_batches,
            max_latency,
            semaphore.clone(),
        );

        let batch = Batch {
            tx,
            semaphore,
            permit: None,
            error_handle,
            worker_handle: Arc::new(Mutex::new(None)),
        };

        (batch, worker)
    }

    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
    ///
    /// Only used when the task is spawned on the tokio runtime.
    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
        *self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex") =
            Some(worker_handle);
    }

    /// Returns the error from the batch worker's `error_handle`.
    fn get_worker_error(&self) -> crate::BoxError {
        self.error_handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Check to see if the worker has returned or panicked.
        //
        // Correctness: Registers this task for wakeup when the worker finishes.
        if let Some(worker_handle) = self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex")
            .as_mut()
        {
            match Pin::new(worker_handle).poll(cx) {
                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
                    tracing::warn!(
                        "batch task cancelled: {task_cancelled}\n\
                         Is Zebra shutting down?"
                    );

                    return Poll::Ready(Err(task_cancelled.into()));
                }
                Poll::Ready(Err(task_panic)) => {
                    std::panic::resume_unwind(task_panic.into_panic());
                }
                Poll::Pending => {}
            }
        }

        // Check if the worker has set an error and closed its channels.
        //
        // Correctness: Registers this task for wakeup when the channel is closed.
        let tx = self.tx.clone();
        let closed = tx.closed();
        pin!(closed);
        if closed.poll(cx).is_ready() {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Poll to acquire a semaphore permit.
        //
        // CORRECTNESS
        //
        // If we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
        // this task for wakeup when the next permit is available, or when the semaphore is closed.
        //
        // When `poll_ready()` is called multiple times, and channel capacity is 1,
        // we avoid deadlocks by dropping any previous permit before acquiring another one.
        // This also stops tasks from holding a permit after an error.
        //
        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
        // to another concurrent task.
        self.permit = None;

        let permit = ready!(self.semaphore.poll_acquire(cx));
        if let Some(permit) = permit {
            // Calling poll_ready() more than once will drop any previous permit,
            // releasing its capacity back to the semaphore.
            self.permit = Some(permit);
        } else {
            // The semaphore has been closed.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to batch worker");
        let _permit = self
            .permit
            .take()
            .expect("poll_ready must be called before a batch request");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            permit: None,
            error_handle: self.error_handle.clone(),
            worker_handle: self.worker_handle.clone(),
        }
    }
}