tower_batch_control/worker.rs
//! Batch worker item handling and run loop implementation.

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use futures::{
    future::{BoxFuture, OptionFuture},
    stream::FuturesUnordered,
    FutureExt, StreamExt,
};
use pin_project::pin_project;
use tokio::{
    sync::mpsc,
    time::{sleep, Sleep},
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing_futures::Instrument;

use crate::RequestWeight;

use super::{
    error::{Closed, ServiceError},
    message::{self, Message},
    BatchControl,
};

/// Task that handles processing the batch. This type should not be used
/// directly; instead, the `Batch` wrapper creates and runs this task to drive
/// the wrapped service.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern for including
/// "private" types in public traits that consumers of the library are not meant
/// to implement (only call).
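///
/// As a worked example of the flush policy, using illustrative numbers rather
/// than this crate's defaults: suppose `max_items_weight_in_batch = 100`,
/// `max_latency` is 10 ms, and three requests of weight 40 arrive in quick
/// succession. The first request starts the batch timer; the third pushes the
/// pending weight to 120, which reaches the limit, so the batch is flushed
/// immediately and the timer is cleared. If only two requests had arrived, the
/// timer would flush the 80-weight batch after 10 ms instead.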
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request: RequestWeight>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    // Batch management
    //
    /// A semaphore-bounded channel for receiving requests from the batch wrapper service.
    rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,

    /// The wrapped service that processes batches.
    service: T,

    /// The total weight of pending requests sent to `service`, since the last batch flush.
    pending_items_weight: usize,

    /// The timer for the pending batch, if it has any items.
    ///
    /// The timer is started when the first entry of a new batch is
    /// submitted, so that the batch latency of all entries is at most
    /// `self.max_latency`. However, we don't keep the timer running unless
    /// there is a pending request to prevent wakeups on idle services.
    pending_batch_timer: Option<Pin<Box<Sleep>>>,

    /// The batches that the worker is concurrently executing.
    concurrent_batches: FuturesUnordered<BoxFuture<'static, Result<T::Response, T::Error>>>,

    // Errors and termination
    //
    /// An error that's populated on permanent service failure.
    failed: Option<ServiceError>,

    /// A shared error handle that's populated on permanent service failure.
    error_handle: ErrorHandle,

    /// A cloned copy of the wrapper service's semaphore, used to close the semaphore.
    close: PollSemaphore,

    // Config
    //
    /// The maximum weight of pending items in a batch before it should be flushed and
    /// pending items should be added to a new batch.
    max_items_weight_in_batch: usize,

    /// The maximum number of batches that are allowed to run concurrently.
    max_concurrent_batches: usize,

    /// The maximum delay before processing a batch with items that have a total weight
    /// that is less than `max_items_weight_in_batch`.
    max_latency: std::time::Duration,
}

/// A shared handle for retrieving the error that caused the batch worker to fail.
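///
/// A minimal sketch of how the wrapper side might use this handle (the
/// surrounding `Batch` internals are assumptions for illustration, not the
/// exact wrapper code): when sending on the request channel fails, the stored
/// error is reported instead of a generic "channel closed" failure.
///
/// ```ignore
/// if tx.send(message).is_err() {
///     return Err(error_handle.get_error_on_closed());
/// }
/// ```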
#[derive(Debug)]
pub(crate) struct ErrorHandle {
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request: RequestWeight> Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new batch worker.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
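    ///
    /// A minimal sketch of how a wrapper might construct and spawn the worker;
    /// the channel and semaphore setup shown here are assumptions for
    /// illustration, not the exact code in `Batch::new()`:
    ///
    /// ```ignore
    /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    /// let semaphore = PollSemaphore::new(Arc::new(tokio::sync::Semaphore::new(item_limit)));
    /// let (error_handle, worker) = Worker::new(
    ///     service,
    ///     rx,
    ///     max_items_weight_in_batch,
    ///     max_concurrent_batches,
    ///     max_latency,
    ///     semaphore,
    /// );
    /// tokio::spawn(worker.run());
    /// ```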
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        max_items_weight_in_batch: usize,
        max_concurrent_batches: usize,
        max_latency: std::time::Duration,
        close: PollSemaphore,
    ) -> (ErrorHandle, Worker<T, Request>) {
        let error_handle = ErrorHandle {
            inner: Arc::new(Mutex::new(None)),
        };

        let worker = Worker {
            rx,
            service,
            pending_items_weight: 0,
            pending_batch_timer: None,
            concurrent_batches: FuturesUnordered::new(),
            failed: None,
            error_handle: error_handle.clone(),
            close,
            max_items_weight_in_batch,
            max_concurrent_batches,
            max_latency,
        };

        (error_handle, worker)
    }

    /// Process a single worker request.
    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
        if let Some(ref error) = self.failed {
            tracing::trace!(
                ?error,
                "notifying batch request caller about worker failure",
            );
            let _ = tx.send(Err(error.clone()));
            return;
        }

        match self.service.ready().await {
            Ok(svc) => {
                self.pending_items_weight += req.request_weight();
                let rsp = svc.call(req.into());
                let _ = tx.send(Ok(rsp));
            }
            Err(e) => {
                self.failed(e.into());
                let _ = tx.send(Err(self
                    .failed
                    .as_ref()
                    .expect("Worker::failed did not set self.failed?")
                    .clone()));
            }
        }
    }

    /// Tell the inner service to flush the current batch.
    ///
    /// Waits until the inner service is ready,
    /// then stores a future which resolves when the batch finishes.
    async fn flush_service(&mut self) {
        if self.failed.is_some() {
            tracing::trace!("worker failure: skipping flush");
            return;
        }

        match self.service.ready().await {
            Ok(ready_service) => {
                let flush_future = ready_service.call(BatchControl::Flush);
                self.concurrent_batches.push(flush_future.boxed());

                // Now we have an empty batch.
                self.pending_items_weight = 0;
                self.pending_batch_timer = None;
            }
            Err(error) => {
                self.failed(error.into());
            }
        }
    }

    /// Returns `true` if the number of concurrently executing batches is below
    /// the configured limit, so the worker can start another batch.
    fn can_spawn_new_batches(&self) -> bool {
        self.concurrent_batches.len() < self.max_concurrent_batches
    }

    /// Run loop for batch requests, which implements the batch policies.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
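    ///
    /// The loop drives the wrapped service through the [`BatchControl`] protocol:
    /// each incoming request is forwarded to the service as an item, and the
    /// pending batch is flushed when it reaches `max_items_weight_in_batch` or
    /// when its timer expires. A minimal sketch of a service on the receiving
    /// end (the `Item` type and helper methods are illustrative, not part of
    /// this crate):
    ///
    /// ```ignore
    /// fn call(&mut self, control: BatchControl<Item>) -> Self::Future {
    ///     match control {
    ///         // Queue the item into the current batch.
    ///         BatchControl::Item(item) => self.queue_item(item),
    ///         // Process everything queued so far as a single batch.
    ///         BatchControl::Flush => self.flush_current_batch(),
    ///     }
    /// }
    /// ```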
    pub async fn run(mut self) {
        loop {
            // Wait on either a new message or the batch timer.
            //
            // If both are ready, end the batch now, because the timer has elapsed.
            // If the timer elapses, any pending messages are preserved:
            // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety
            tokio::select! {
                biased;

                batch_result = self.concurrent_batches.next(), if !self.concurrent_batches.is_empty() => match batch_result.expect("only returns None when empty") {
                    Ok(_response) => {
                        tracing::trace!(
                            pending_items_weight = self.pending_items_weight,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch finished executing",
                        );
                    }
                    Err(error) => {
                        let error = error.into();
                        tracing::trace!(?error, "batch execution failed");
                        self.failed(error);
                    }
                },

                Some(()) = OptionFuture::from(self.pending_batch_timer.as_mut()), if self.pending_batch_timer.as_ref().is_some() => {
                    tracing::trace!(
                        pending_items_weight = self.pending_items_weight,
                        batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                        running_batches = self.concurrent_batches.len(),
                        "batch timer expired",
                    );

                    // TODO: use a batch-specific span to instrument this future.
                    self.flush_service().await;
                },

                maybe_msg = self.rx.recv(), if self.can_spawn_new_batches() => match maybe_msg {
                    Some(msg) => {
                        tracing::trace!(
                            pending_items_weight = self.pending_items_weight,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch message received",
                        );

                        let span = msg.span;
                        let is_new_batch = self.pending_items_weight == 0;

                        self.process_req(msg.request, msg.tx)
                            // Apply the provided span to request processing.
                            .instrument(span)
                            .await;

                        // Check whether we have too many pending items.
                        if self.pending_items_weight >= self.max_items_weight_in_batch {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is full",
                            );

                            // TODO: use a batch-specific span to instrument this future.
                            self.flush_service().await;
                        } else if is_new_batch {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is new, starting timer",
                            );

                            // The first message in a new batch.
                            self.pending_batch_timer = Some(Box::pin(sleep(self.max_latency)));
                        } else {
                            tracing::trace!(
                                pending_items_weight = self.pending_items_weight,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "waiting for full batch or batch timer",
                            );
                        }
                    }
                    None => {
                        tracing::trace!("batch channel closed and emptied, exiting worker task");

                        return;
                    }
                },
            }
        }
    }

    /// Register an inner service failure.
    ///
    /// The underlying service failed with the given `error` when we called `poll_ready` on it.
    /// We need to communicate this to all the `Batch` handles. To do so, we wrap up the error
    /// in an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
    /// requests will also fail with the same error.
    fn failed(&mut self, error: crate::BoxError) {
        tracing::debug!(?error, "batch worker error");

        // Note that we need to handle the case where some batch handle is concurrently trying
        // to send us a request. We need to make sure that *either* the send of the request
        // fails *or* it receives an error on the `oneshot` it constructed. Specifically, we
        // want to avoid the case where we send errors to all outstanding requests, and *then*
        // the caller sends its request. We do this by *first* exposing the error, *then*
        // closing the channel used to send more requests (so the client will see the error
        // when the send fails), and *then* sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.error_handle.inner.lock().unwrap();

        // Ignore duplicate failures
        if inner.is_some() {
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        tracing::trace!(
            ?error,
            "worker failure: waking pending requests so they can be failed",
        );
        self.rx.close();
        self.close.close();

        // We don't schedule any batches on an errored service
        self.pending_batch_timer = None;

        // By closing the mpsc::Receiver, we know that the run() loop will
        // drain all pending requests. We just need to make sure that any
        // requests that we receive before we've exhausted the receiver receive
        // the error:
        self.failed = Some(error);
    }
}

impl ErrorHandle {
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .expect("previous task panicked while holding the error handle mutex")
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}

impl Clone for ErrorHandle {
    fn clone(&self) -> ErrorHandle {
        ErrorHandle {
            inner: self.inner.clone(),
        }
    }
}

#[pin_project::pinned_drop]
impl<T, Request: RequestWeight> PinnedDrop for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        tracing::trace!(
            pending_items_weight = self.pending_items_weight,
            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
            running_batches = self.concurrent_batches.len(),
            error = ?self.failed,
            "dropping batch worker",
        );

        // Fail pending tasks
        self.failed(Closed::new().into());

        // Fail queued requests
        while let Ok(msg) = self.rx.try_recv() {
            let _ = msg
                .tx
                .send(Err(self.failed.as_ref().expect("just set failed").clone()));
        }

        // Clear any finished batches, ignoring any errors.
        // Ignore any batches that are still executing, because we can't cancel them.
        //
        // now_or_never() can stop futures waking up, but that's ok here,
        // because we're manually polling, then dropping the stream.
        while let Some(Some(_)) = self
            .as_mut()
            .project()
            .concurrent_batches
            .next()
            .now_or_never()
        {}
    }
}