tower_batch_control/worker.rs

//! Batch worker item handling and run loop implementation.

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use futures::{
    future::{BoxFuture, OptionFuture},
    stream::FuturesUnordered,
    FutureExt, StreamExt,
};
use pin_project::pin_project;
use tokio::{
    sync::mpsc,
    time::{sleep, Sleep},
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing_futures::Instrument;

use super::{
    error::{Closed, ServiceError},
    message::{self, Message},
    BatchControl,
};

/// Task that handles processing the batch. This type should not be used
/// directly; instead, `Batch` spawns this task to drive the wrapped service.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    // Batch management
    //
    /// A semaphore-bounded channel for receiving requests from the batch wrapper service.
    rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,

    /// The wrapped service that processes batches.
    service: T,

    /// The number of pending items sent to `service`, since the last batch flush.
    pending_items: usize,

    /// The timer for the pending batch, if it has any items.
    ///
    /// The timer is started when the first entry of a new batch is
    /// submitted, so that the batch latency of all entries is at most
    /// self.max_latency. However, we don't keep the timer running unless
    /// there is a pending request to prevent wakeups on idle services.
    pending_batch_timer: Option<Pin<Box<Sleep>>>,

    /// The batches that the worker is concurrently executing.
    concurrent_batches: FuturesUnordered<BoxFuture<'static, Result<T::Response, T::Error>>>,

    // Errors and termination
    //
    /// An error that's populated on permanent service failure.
    failed: Option<ServiceError>,

    /// A shared error handle that's populated on permanent service failure.
    error_handle: ErrorHandle,

    /// A cloned copy of the wrapper service's semaphore, used to close the semaphore.
    close: PollSemaphore,

    // Config
    //
    /// The maximum number of items allowed in a batch.
    max_items_in_batch: usize,

    /// The maximum number of batches that are allowed to run concurrently.
    max_concurrent_batches: usize,

    /// The maximum delay before processing a batch with fewer than `max_items_in_batch`.
    max_latency: std::time::Duration,
}

/// A shared handle for extracting the error that caused the worker to fail permanently.
#[derive(Debug)]
pub(crate) struct ErrorHandle {
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request> Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new batch worker.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        max_items_in_batch: usize,
        max_concurrent_batches: usize,
        max_latency: std::time::Duration,
        close: PollSemaphore,
    ) -> (ErrorHandle, Worker<T, Request>) {
        let error_handle = ErrorHandle {
            inner: Arc::new(Mutex::new(None)),
        };

        let worker = Worker {
            rx,
            service,
            pending_items: 0,
            pending_batch_timer: None,
            concurrent_batches: FuturesUnordered::new(),
            failed: None,
            error_handle: error_handle.clone(),
            close,
            max_items_in_batch,
            max_concurrent_batches,
            max_latency,
        };

        (error_handle, worker)
    }

    /// Process a single worker request: notify the caller if the worker has
    /// already failed, or submit the request to the wrapped service and send
    /// the response future back to the caller.
    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
        if let Some(ref error) = self.failed {
            tracing::trace!(
                ?error,
                "notifying batch request caller about worker failure",
            );
            let _ = tx.send(Err(error.clone()));
            return;
        }

        match self.service.ready().await {
            Ok(svc) => {
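                // `req.into()` converts the request into its `BatchControl`
                // wrapper, adding it to the batch the wrapped service is
                // accumulating.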
                let rsp = svc.call(req.into());
                let _ = tx.send(Ok(rsp));

                self.pending_items += 1;
            }
            Err(e) => {
                self.failed(e.into());
                let _ = tx.send(Err(self
                    .failed
                    .as_ref()
                    .expect("Worker::failed did not set self.failed?")
                    .clone()));
            }
        }
    }

    /// Tell the inner service to flush the current batch.
    ///
    /// Waits until the inner service is ready,
    /// then stores a future which resolves when the batch finishes.
    async fn flush_service(&mut self) {
        if self.failed.is_some() {
            tracing::trace!("worker failure: skipping flush");
            return;
        }

        match self.service.ready().await {
            Ok(ready_service) => {
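                // `BatchControl::Flush` tells the wrapped service to execute
                // the items it has accumulated; the returned future resolves
                // when that batch completes.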
                let flush_future = ready_service.call(BatchControl::Flush);
                self.concurrent_batches.push(flush_future.boxed());

                // Now we have an empty batch.
                self.pending_items = 0;
                self.pending_batch_timer = None;
            }
            Err(error) => {
                self.failed(error.into());
            }
        }
    }

    /// Can the worker start a new batch without exceeding `max_concurrent_batches`?
    fn can_spawn_new_batches(&self) -> bool {
        self.concurrent_batches.len() < self.max_concurrent_batches
    }

    /// Run loop for batch requests, which implements the batch policies:
    /// flush the batch when it reaches `max_items_in_batch`, flush a partial
    /// batch once `max_latency` has elapsed, and stop accepting new requests
    /// while `max_concurrent_batches` batches are executing.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
    pub async fn run(mut self) {
        loop {
            // Wait on a new message, the batch timer, or a batch finishing execution.
            //
            // If both the timer and a message are ready, end the batch now, because
            // the timer has elapsed. If the timer elapses, any pending messages are
            // preserved:
            // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety
            tokio::select! {
                biased;
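                // `biased` polls these branches in the order they are written,
                // so finished batches and expired timers take priority over
                // accepting new requests. The `is_empty` guard stops the select
                // from polling `concurrent_batches` when no batches are
                // executing, because `next()` would immediately return `None`.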

                batch_result = self.concurrent_batches.next(), if !self.concurrent_batches.is_empty() => match batch_result.expect("only returns None when empty") {
                    Ok(_response) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch finished executing",
                        );
                    }
                    Err(error) => {
                        let error = error.into();
                        tracing::trace!(?error, "batch execution failed");
                        self.failed(error);
                    }
                },

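                // Flush the pending batch when its timer expires. `OptionFuture`
                // wraps the optional timer, and the guard keeps this branch
                // disabled while no batch is pending.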
                Some(()) = OptionFuture::from(self.pending_batch_timer.as_mut()), if self.pending_batch_timer.is_some() => {
                    tracing::trace!(
                        pending_items = self.pending_items,
                        batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                        running_batches = self.concurrent_batches.len(),
                        "batch timer expired",
                    );

                    // TODO: use a batch-specific span to instrument this future.
                    self.flush_service().await;
                },

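                // Accept a new request, unless too many batches are already
                // executing. Requests wait in the semaphore-bounded channel,
                // which applies backpressure to callers.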
                maybe_msg = self.rx.recv(), if self.can_spawn_new_batches() => match maybe_msg {
                    Some(msg) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch message received",
                        );

                        let span = msg.span;

                        self.process_req(msg.request, msg.tx)
                            // Apply the provided span to request processing.
                            .instrument(span)
                            .await;

                        // Check whether we have too many pending items.
                        if self.pending_items >= self.max_items_in_batch {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is full",
                            );

                            // TODO: use a batch-specific span to instrument this future.
                            self.flush_service().await;
                        } else if self.pending_items == 1 {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is new, starting timer",
                            );

                            // The first message in a new batch.
                            self.pending_batch_timer = Some(Box::pin(sleep(self.max_latency)));
                        } else {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "waiting for full batch or batch timer",
                            );
                        }
                    }
                    None => {
                        tracing::trace!("batch channel closed and emptied, exiting worker task");

                        return;
                    }
                },
            }
        }
    }

    /// Register an inner service failure.
    ///
    /// The underlying service failed with the given `error` when we called `poll_ready` on it.
    /// We need to communicate this to all the `Batch` handles. To do so, we wrap up the error in
    /// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
    /// requests will also fail with the same error.
    fn failed(&mut self, error: crate::BoxError) {
        tracing::debug!(?error, "batch worker error");

        // Note that we need to handle the case where a `Batch` handle is concurrently trying to
        // send us a request. We need to make sure that *either* the send of the request fails *or*
        // it receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.error_handle.inner.lock().unwrap();

        // Ignore duplicate failures
        if inner.is_some() {
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        tracing::trace!(
            ?error,
            "worker failure: waking pending requests so they can be failed",
        );
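        // Stop accepting new requests, and wake any tasks waiting on the
        // wrapper's semaphore so they observe the error.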
        self.rx.close();
        self.close.close();

        // We don't schedule any batches on an errored service
        self.pending_batch_timer = None;

        // By closing the mpsc::Receiver, we know that the run() loop will
        // drain all pending requests. We just need to make sure that any
        // requests that we receive before we've exhausted the receiver receive
        // the error:
        self.failed = Some(error);
    }
}

impl ErrorHandle {
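    /// Returns the error that caused the worker to fail, or a `Closed` error
    /// if the worker exited without recording one.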
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .expect("previous task panicked while holding the error handle mutex")
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}

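// Cloned error handles share the same underlying error slot through the `Arc`.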
impl Clone for ErrorHandle {
    fn clone(&self) -> ErrorHandle {
        ErrorHandle {
            inner: self.inner.clone(),
        }
    }
}

#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        tracing::trace!(
            pending_items = self.pending_items,
            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
            running_batches = self.concurrent_batches.len(),
            error = ?self.failed,
            "dropping batch worker",
        );

        // Fail pending tasks
        self.failed(Closed::new().into());

        // Fail queued requests
        while let Ok(msg) = self.rx.try_recv() {
            let _ = msg
                .tx
                .send(Err(self.failed.as_ref().expect("just set failed").clone()));
        }

        // Clear any finished batches, ignoring any errors.
        // Ignore any batches that are still executing, because we can't cancel them.
        //
        // now_or_never() can stop futures waking up, but that's ok here,
        // because we're manually polling, then dropping the stream.
        while let Some(Some(_)) = self
            .as_mut()
            .project()
            .concurrent_batches
            .next()
            .now_or_never()
        {}
    }
}