tower_batch_control/
worker.rs

//! Batch worker item handling and run loop implementation.

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use futures::{
    future::{BoxFuture, OptionFuture},
    stream::FuturesUnordered,
    FutureExt, StreamExt,
};
use pin_project::pin_project;
use tokio::{
    sync::mpsc,
    time::{sleep, Sleep},
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing_futures::Instrument;

use crate::RequestWeight;

use super::{
    error::{Closed, ServiceError},
    message::{self, Message},
    BatchControl,
};

/// Task that handles processing the batch queue. This type should not be used
/// directly; instead, the `Batch` wrapper service spawns it onto an executor to
/// drive the wrapped service.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
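///
/// A minimal sketch of how the wrapper service is expected to drive this type,
/// assuming it spawns the worker onto a Tokio runtime (argument values are
/// illustrative):
///
/// ```ignore
/// let (error_handle, worker) =
///     Worker::new(service, rx, max_items_weight, max_concurrent_batches, max_latency, close);
/// tokio::spawn(worker.run());
/// ```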
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request: RequestWeight>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    // Batch management
    //
    /// A semaphore-bounded channel for receiving requests from the batch wrapper service.
    rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,

    /// The wrapped service that processes batches.
    service: T,

    /// The total weight of pending requests sent to `service`, since the last batch flush.
    pending_items_weight: usize,

    /// The timer for the pending batch, if it has any items.
    ///
    /// The timer is started when the first entry of a new batch is
    /// submitted, so that the batch latency of all entries is at most
    /// self.max_latency. However, we don't keep the timer running unless
    /// there is a pending request to prevent wakeups on idle services.
    pending_batch_timer: Option<Pin<Box<Sleep>>>,

    /// The batches that the worker is concurrently executing.
    concurrent_batches: FuturesUnordered<BoxFuture<'static, Result<T::Response, T::Error>>>,

    // Errors and termination
    //
    /// An error that's populated on permanent service failure.
    failed: Option<ServiceError>,

    /// A shared error handle that's populated on permanent service failure.
    error_handle: ErrorHandle,

    /// A cloned copy of the wrapper service's semaphore, used to close the semaphore.
    close: PollSemaphore,

    // Config
    //
    /// The maximum weight of pending items in a batch before it should be flushed and
    /// pending items should be added to a new batch.
    max_items_weight_in_batch: usize,

    /// The maximum number of batches that are allowed to run concurrently.
    max_concurrent_batches: usize,

    /// The maximum delay before processing a batch with items that have a total weight
    /// that is less than `max_items_weight_in_batch`.
    max_latency: std::time::Duration,
}

/// A shared handle that lets the `Batch` wrapper service retrieve the error that
/// caused the worker to fail permanently, if any.
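///
/// A rough sketch of the intended use by the wrapper service, assuming it checks
/// whether the request channel has been closed before reporting readiness
/// (`request_sender` is illustrative, not the wrapper's actual field name):
///
/// ```ignore
/// if request_sender.is_closed() {
///     // The worker failed or shut down; surface the recorded error (or `Closed`).
///     return Poll::Ready(Err(error_handle.get_error_on_closed()));
/// }
/// ```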
#[derive(Debug)]
pub(crate) struct ErrorHandle {
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request: RequestWeight> Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new batch worker.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        max_items_weight_in_batch: usize,
        max_concurrent_batches: usize,
        max_latency: std::time::Duration,
        close: PollSemaphore,
    ) -> (ErrorHandle, Worker<T, Request>) {
        let error_handle = ErrorHandle {
            inner: Arc::new(Mutex::new(None)),
        };

        let worker = Worker {
            rx,
            service,
            pending_items_weight: 0,
            pending_batch_timer: None,
            concurrent_batches: FuturesUnordered::new(),
            failed: None,
            error_handle: error_handle.clone(),
            close,
            max_items_weight_in_batch,
            max_concurrent_batches,
            max_latency,
        };

        (error_handle, worker)
    }

    /// Process a single worker request.
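    ///
    /// The request's weight is added to the pending batch total, which the run
    /// loop compares against `max_items_weight_in_batch`. A minimal sketch of a
    /// [`RequestWeight`] impl, assuming the trait exposes a single
    /// `request_weight(&self) -> usize` method (matching how it is called below);
    /// `MyRequest` and its `bytes` field are illustrative only:
    ///
    /// ```ignore
    /// impl RequestWeight for MyRequest {
    ///     fn request_weight(&self) -> usize {
    ///         // Weigh a request by its serialized size, so that
    ///         // `max_items_weight_in_batch` caps the total bytes per batch.
    ///         self.bytes.len()
    ///     }
    /// }
    /// ```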
    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
        if let Some(ref error) = self.failed {
            tracing::trace!(
                ?error,
                "notifying batch request caller about worker failure",
            );
            let _ = tx.send(Err(error.clone()));
            return;
        }

        match self.service.ready().await {
            Ok(svc) => {
                self.pending_items_weight += req.request_weight();
                let rsp = svc.call(req.into());
                let _ = tx.send(Ok(rsp));
            }
            Err(e) => {
                self.failed(e.into());
                let _ = tx.send(Err(self
                    .failed
                    .as_ref()
                    .expect("Worker::failed did not set self.failed?")
                    .clone()));
            }
        }
    }

    /// Tell the inner service to flush the current batch.
    ///
    /// Waits until the inner service is ready,
    /// then stores a future which resolves when the batch finishes.
    async fn flush_service(&mut self) {
        if self.failed.is_some() {
            tracing::trace!("worker failure: skipping flush");
            return;
        }

        match self.service.ready().await {
            Ok(ready_service) => {
                let flush_future = ready_service.call(BatchControl::Flush);
                self.concurrent_batches.push(flush_future.boxed());

                // Now we have an empty batch.
                self.pending_items_weight = 0;
                self.pending_batch_timer = None;
            }
            Err(error) => {
                self.failed(error.into());
            }
        }
    }

    /// Can the worker start executing another batch, i.e. is the number of
    /// concurrently running batches below the configured limit?
    fn can_spawn_new_batches(&self) -> bool {
        self.concurrent_batches.len() < self.max_concurrent_batches
    }

    /// Run loop for batch requests, which implements the batch policies.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
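    ///
    /// In short: the pending batch is flushed to the wrapped service when the total
    /// weight of its items reaches `max_items_weight_in_batch`, or when `max_latency`
    /// has elapsed since the batch's first item; new requests are not accepted while
    /// `max_concurrent_batches` batches are still executing.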
    pub async fn run(mut self) {
        loop {
            // Wait on either a new message or the batch timer.
            //
            // If both are ready, end the batch now, because the timer has elapsed.
            // If the timer elapses, any pending messages are preserved:
            // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety
            tokio::select! {
                biased;

                batch_result = self.concurrent_batches.next(), if !self.concurrent_batches.is_empty() => match batch_result.expect("only returns None when empty") {
                    Ok(_response) => {
                        tracing::trace!(
                            pending_items_weight = self.pending_items_weight,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch finished executing",
                        );
                    }
                    Err(error) => {
                        let error = error.into();
                        tracing::trace!(?error, "batch execution failed");
                        self.failed(error);
                    }
                },

                Some(()) = OptionFuture::from(self.pending_batch_timer.as_mut()), if self.pending_batch_timer.is_some() => {
                    tracing::trace!(
                        pending_items_weight = self.pending_items_weight,
                        batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                        running_batches = self.concurrent_batches.len(),
                        "batch timer expired",
                    );

                    // TODO: use a batch-specific span to instrument this future.
                    self.flush_service().await;
                },

                maybe_msg = self.rx.recv(), if self.can_spawn_new_batches() => match maybe_msg {
                    Some(msg) => {
                        tracing::trace!(
                            pending_items_weight = self.pending_items_weight,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch message received",
                        );

                        let span = msg.span;
                        let is_new_batch = self.pending_items_weight == 0;

                        self.process_req(msg.request, msg.tx)
                            // Apply the provided span to request processing.
                            .instrument(span)
                            .await;

                        // Check whether we have too many pending items.
                        if self.pending_items_weight >= self.max_items_weight_in_batch {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is full",
                            );

                            // TODO: use a batch-specific span to instrument this future.
                            self.flush_service().await;
                        } else if is_new_batch {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is new, starting timer",
                            );

                            // The first message in a new batch.
                            self.pending_batch_timer = Some(Box::pin(sleep(self.max_latency)));
                        } else {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "waiting for full batch or batch timer",
                            );
                        }
                    }
                    None => {
                        tracing::trace!("batch channel closed and emptied, exiting worker task");

                        return;
                    }
                },
            }
        }
    }

    /// Register an inner service failure.
    ///
    /// The underlying service failed with the given `error` when we called `poll_ready` on it. We
    /// need to communicate this to all the `Batch` handles. To do so, we wrap up the error in
    /// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
    /// requests will also fail with the same error.
    fn failed(&mut self, error: crate::BoxError) {
        tracing::debug!(?error, "batch worker error");

        // Note that we need to handle the case where a `Batch` handle is concurrently trying to
        // send us a request. We need to make sure that *either* the send of the request fails *or*
        // it receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.error_handle.inner.lock().unwrap();

        // Ignore duplicate failures
        if inner.is_some() {
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        tracing::trace!(
            ?error,
            "worker failure: waking pending requests so they can be failed",
        );
        self.rx.close();
        self.close.close();

        // We don't schedule any batches on an errored service
        self.pending_batch_timer = None;

        // By closing the mpsc::Receiver, we know that the run() loop will
        // drain all pending requests. We just need to make sure that any
        // requests that we receive before we've exhausted the receiver receive
        // the error:
        self.failed = Some(error);
    }
}

impl ErrorHandle {
    /// Returns the error that caused the worker to fail permanently, or a
    /// [`Closed`] error if the worker shut down without recording one.
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .expect("previous task panicked while holding the error handle mutex")
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}

impl Clone for ErrorHandle {
    fn clone(&self) -> ErrorHandle {
        ErrorHandle {
            inner: self.inner.clone(),
        }
    }
}

#[pin_project::pinned_drop]
impl<T, Request: RequestWeight> PinnedDrop for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        tracing::trace!(
            pending_items_weight = self.pending_items_weight,
            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
            running_batches = self.concurrent_batches.len(),
            error = ?self.failed,
            "dropping batch worker",
        );

        // Fail pending tasks
        self.failed(Closed::new().into());

        // Fail queued requests
        while let Ok(msg) = self.rx.try_recv() {
            let _ = msg
                .tx
                .send(Err(self.failed.as_ref().expect("just set failed").clone()));
        }

        // Clear any finished batches, ignoring any errors.
        // Ignore any batches that are still executing, because we can't cancel them.
        //
        // now_or_never() can stop futures waking up, but that's ok here,
        // because we're manually polling, then dropping the stream.
        while let Some(Some(_)) = self
            .as_mut()
            .project()
            .concurrent_batches
            .next()
            .now_or_never()
        {}
    }
}