tower_batch_control/worker.rs
//! Batch worker item handling and run loop implementation.

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use futures::{
    future::{BoxFuture, OptionFuture},
    stream::FuturesUnordered,
    FutureExt, StreamExt,
};
use pin_project::pin_project;
use tokio::{
    sync::mpsc,
    time::{sleep, Sleep},
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing_futures::Instrument;

use super::{
    error::{Closed, ServiceError},
    message::{self, Message},
    BatchControl,
};

/// Task that handles processing the batch queue. This type should not be used
/// directly; instead, `Batch` requires an executor that can accept this task.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern for including "private"
/// types in public traits that consumers of the library are only meant to call,
/// not implement.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    // Batch management
    //
    /// A semaphore-bounded channel for receiving requests from the batch wrapper service.
    rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,

    /// The wrapped service that processes batches.
    service: T,

    /// The number of pending items sent to `service` since the last batch flush.
    pending_items: usize,

    /// The timer for the pending batch, if it has any items.
    ///
    /// The timer is started when the first entry of a new batch is
    /// submitted, so that the batch latency of all entries is at most
    /// `self.max_latency`. However, we don't keep the timer running unless
    /// there is a pending request, to prevent wakeups on idle services.
    pending_batch_timer: Option<Pin<Box<Sleep>>>,

    /// The batches that the worker is concurrently executing.
    concurrent_batches: FuturesUnordered<BoxFuture<'static, Result<T::Response, T::Error>>>,

    // Errors and termination
    //
    /// An error that's populated on permanent service failure.
    failed: Option<ServiceError>,

    /// A shared error handle that's populated on permanent service failure.
    error_handle: ErrorHandle,

    /// A cloned copy of the wrapper service's semaphore, used to close the semaphore.
    close: PollSemaphore,

    // Config
    //
    /// The maximum number of items allowed in a batch.
    max_items_in_batch: usize,

    /// The maximum number of batches that are allowed to run concurrently.
    max_concurrent_batches: usize,

    /// The maximum delay before processing a batch with fewer than `max_items_in_batch` items.
    max_latency: std::time::Duration,
}

/// A shared handle for retrieving the worker's error after a permanent failure.
#[derive(Debug)]
pub(crate) struct ErrorHandle {
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request> Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new batch worker.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
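    ///
    /// # Example
    ///
    /// A minimal wiring sketch, marked `ignore` because it uses crate-private
    /// plumbing. It assumes the `Batch` wrapper creates the request channel and
    /// semaphore and then spawns `Worker::run()` onto the runtime; the channel,
    /// semaphore bound, and `tokio::spawn()` call shown here are illustrative
    /// assumptions, not a copy of the wrapper's code.
    ///
    /// ```ignore
    /// use std::{sync::Arc, time::Duration};
    /// use tokio::sync::{mpsc, Semaphore};
    /// use tokio_util::sync::PollSemaphore;
    ///
    /// let (tx, rx) = mpsc::unbounded_channel();
    /// let semaphore = PollSemaphore::new(Arc::new(Semaphore::new(max_concurrent_batches)));
    ///
    /// let (error_handle, worker) = Worker::new(
    ///     my_service,             // the wrapped `Service<BatchControl<Request>>`
    ///     rx,                     // receives `Message`s sent by the `Batch` wrapper
    ///     max_items_in_batch,
    ///     max_concurrent_batches,
    ///     max_latency,
    ///     semaphore.clone(),      // cloned so the worker can close it on failure
    /// );
    ///
    /// // The wrapper keeps `tx`, `semaphore`, and `error_handle`, and drives
    /// // the worker as a background task.
    /// tokio::spawn(worker.run());
    /// ```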
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        max_items_in_batch: usize,
        max_concurrent_batches: usize,
        max_latency: std::time::Duration,
        close: PollSemaphore,
    ) -> (ErrorHandle, Worker<T, Request>) {
        let error_handle = ErrorHandle {
            inner: Arc::new(Mutex::new(None)),
        };

        let worker = Worker {
            rx,
            service,
            pending_items: 0,
            pending_batch_timer: None,
            concurrent_batches: FuturesUnordered::new(),
            failed: None,
            error_handle: error_handle.clone(),
            close,
            max_items_in_batch,
            max_concurrent_batches,
            max_latency,
        };

        (error_handle, worker)
    }

    /// Process a single worker request.
    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
        if let Some(ref error) = self.failed {
            tracing::trace!(
                ?error,
                "notifying batch request caller about worker failure",
            );
            let _ = tx.send(Err(error.clone()));
            return;
        }

        match self.service.ready().await {
            Ok(svc) => {
                let rsp = svc.call(req.into());
                let _ = tx.send(Ok(rsp));

                self.pending_items += 1;
            }
            Err(e) => {
                self.failed(e.into());
                let _ = tx.send(Err(self
                    .failed
                    .as_ref()
                    .expect("Worker::failed did not set self.failed?")
                    .clone()));
            }
        }
    }

    /// Tell the inner service to flush the current batch.
    ///
    /// Waits until the inner service is ready,
    /// then stores a future which resolves when the batch finishes.
    async fn flush_service(&mut self) {
        if self.failed.is_some() {
            tracing::trace!("worker failure: skipping flush");
            return;
        }

        match self.service.ready().await {
            Ok(ready_service) => {
                let flush_future = ready_service.call(BatchControl::Flush);
                self.concurrent_batches.push(flush_future.boxed());

                // Now we have an empty batch.
                self.pending_items = 0;
                self.pending_batch_timer = None;
            }
            Err(error) => {
                self.failed(error.into());
            }
        }
    }

    /// Returns `true` if the number of concurrently running batches is below the
    /// configured limit, so the worker can spawn a new batch.
    fn can_spawn_new_batches(&self) -> bool {
        self.concurrent_batches.len() < self.max_concurrent_batches
    }

    /// Run loop for batch requests, which implements the batch policies:
    /// flush a batch when it reaches `max_items_in_batch` items or when
    /// `max_latency` has elapsed since its first item, and run at most
    /// `max_concurrent_batches` batches at the same time.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
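    ///
    /// # Example
    ///
    /// A caller-side sketch of these policies. It is marked `ignore`, and it
    /// assumes that `Batch::new` takes the wrapped service and the same three
    /// limits, in the same order, as `Worker::new()`, and that `my_service`
    /// implements `Service<BatchControl<Request>>`:
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use tower::{Service, ServiceExt};
    ///
    /// // Flush after 32 items, or 50ms after the first item in a batch,
    /// // and run at most 2 flushed batches concurrently.
    /// let mut batch = Batch::new(my_service, 32, 2, Duration::from_millis(50));
    ///
    /// // Each call becomes one batch item; whether its response future resolves
    /// // immediately or only after the batch is flushed is up to the wrapped service.
    /// let response = batch.ready().await?.call(request).await?;
    /// ```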
    pub async fn run(mut self) {
        loop {
            // Wait on either a new message or the batch timer.
            //
            // If both are ready, end the batch now, because the timer has elapsed.
            // If the timer elapses, any pending messages are preserved:
            // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety
            tokio::select! {
                biased;

                batch_result = self.concurrent_batches.next(), if !self.concurrent_batches.is_empty() => match batch_result.expect("only returns None when empty") {
                    Ok(_response) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch finished executing",
                        );
                    }
                    Err(error) => {
                        let error = error.into();
                        tracing::trace!(?error, "batch execution failed");
                        self.failed(error);
                    }
                },

                Some(()) = OptionFuture::from(self.pending_batch_timer.as_mut()), if self.pending_batch_timer.is_some() => {
                    tracing::trace!(
                        pending_items = self.pending_items,
                        batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                        running_batches = self.concurrent_batches.len(),
                        "batch timer expired",
                    );

                    // TODO: use a batch-specific span to instrument this future.
                    self.flush_service().await;
                },

                maybe_msg = self.rx.recv(), if self.can_spawn_new_batches() => match maybe_msg {
                    Some(msg) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch message received",
                        );

                        let span = msg.span;

                        self.process_req(msg.request, msg.tx)
                            // Apply the provided span to request processing.
                            .instrument(span)
                            .await;

                        // Check whether we have too many pending items.
                        if self.pending_items >= self.max_items_in_batch {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is full",
                            );

                            // TODO: use a batch-specific span to instrument this future.
                            self.flush_service().await;
                        } else if self.pending_items == 1 {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is new, starting timer",
                            );

                            // The first message in a new batch.
                            self.pending_batch_timer = Some(Box::pin(sleep(self.max_latency)));
                        } else {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "waiting for full batch or batch timer",
                            );
                        }
                    }
                    None => {
                        tracing::trace!("batch channel closed and emptied, exiting worker task");

                        return;
                    }
                },
            }
        }
    }

    /// Register an inner service failure.
    ///
    /// The underlying service failed when we called `poll_ready` on it with the given `error`.
    /// We need to communicate this to all the `Batch` handles. To do so, we wrap up the error in
    /// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
    /// requests will also fail with the same error.
    fn failed(&mut self, error: crate::BoxError) {
        tracing::debug!(?error, "batch worker error");

        // Note that we need to handle the case where some `Batch` handle is concurrently trying
        // to send us a request. We need to make sure that *either* the send of the request fails
        // *or* it receives an error on the `oneshot` it constructed. Specifically, we want to
        // avoid the case where we send errors to all outstanding requests, and *then* the caller
        // sends its request. We do this by *first* exposing the error, *then* closing the channel
        // used to send more requests (so the client will see the error when the send fails), and
        // *then* sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.error_handle.inner.lock().unwrap();

        // Ignore duplicate failures
        if inner.is_some() {
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        tracing::trace!(
            ?error,
            "worker failure: waking pending requests so they can be failed",
        );
        self.rx.close();
        self.close.close();

        // We don't schedule any batches on an errored service
        self.pending_batch_timer = None;

        // By closing the mpsc::Receiver, we know that the run() loop will
        // drain all pending requests. We just need to make sure that any
        // requests received before the receiver is exhausted also get the
        // error:
        self.failed = Some(error);
    }
}

impl ErrorHandle {
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .expect("previous task panicked while holding the error handle mutex")
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}

impl Clone for ErrorHandle {
    fn clone(&self) -> ErrorHandle {
        ErrorHandle {
            inner: self.inner.clone(),
        }
    }
}

#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        tracing::trace!(
            pending_items = self.pending_items,
            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
            running_batches = self.concurrent_batches.len(),
            error = ?self.failed,
            "dropping batch worker",
        );

        // Fail pending tasks
        self.failed(Closed::new().into());

        // Fail queued requests
        while let Ok(msg) = self.rx.try_recv() {
            let _ = msg
                .tx
                .send(Err(self.failed.as_ref().expect("just set failed").clone()));
        }

        // Clear any finished batches, ignoring any errors.
        // Ignore any batches that are still executing, because we can't cancel them.
        //
        // now_or_never() can stop futures waking up, but that's ok here,
        // because we're manually polling, then dropping the stream.
        while let Some(Some(_)) = self
            .as_mut()
            .project()
            .concurrent_batches
            .next()
            .now_or_never()
        {}
    }
}
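
// A minimal sketch (not part of the worker implementation above) of the
// inner-service side of the protocol that `process_req()` and `flush_service()`
// drive: one `BatchControl::Item` call per queued request, then a
// `BatchControl::Flush` call when the batch is full or its timer expires.
// `CountingService` is a hypothetical service invented for this sketch, and the
// test assumes the crate's dev-dependencies enable `tokio`'s test macros.
#[cfg(test)]
mod batch_control_protocol_sketch {
    use std::task::{Context, Poll};

    use futures::future::{ready, Ready};
    use tower::{Service, ServiceExt};

    use super::BatchControl;

    /// A hypothetical inner service that counts items and reports the batch
    /// size when it is flushed.
    #[derive(Default)]
    struct CountingService {
        pending: usize,
    }

    impl Service<BatchControl<u8>> for CountingService {
        type Response = usize;
        type Error = crate::BoxError;
        type Future = Ready<Result<usize, crate::BoxError>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, control: BatchControl<u8>) -> Self::Future {
            match control {
                // The worker sends one `Item` per queued request. A real batch
                // service would usually return a future that resolves after the
                // flush; this sketch just returns a placeholder value.
                BatchControl::Item(_request) => {
                    self.pending += 1;
                    ready(Ok(0))
                }
                // The worker sends `Flush` when the batch is full or its timer
                // expires; this service responds with the number of items flushed.
                BatchControl::Flush => {
                    let flushed = std::mem::take(&mut self.pending);
                    ready(Ok(flushed))
                }
            }
        }
    }

    #[tokio::test]
    async fn flush_reports_the_number_of_batched_items() {
        let mut service = CountingService::default();

        // Simulate the worker submitting three items to the current batch.
        for item in 0..3u8 {
            let rsp = service
                .ready()
                .await
                .expect("poll_ready never fails")
                .call(BatchControl::Item(item))
                .await
                .expect("item calls never fail");
            assert_eq!(rsp, 0, "items return a placeholder response");
        }

        // Simulate the worker flushing the batch.
        let flushed = service
            .ready()
            .await
            .expect("poll_ready never fails")
            .call(BatchControl::Flush)
            .await
            .expect("flush never fails");
        assert_eq!(flushed, 3);
    }
}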