tower_fallback/
future.rs

1//! Future types for the `Fallback` middleware.
2
3use std::{
4    fmt::Debug,
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures_core::ready;
11use pin_project::pin_project;
12use tower::Service;
13
14use crate::BoxedError;
15
16/// Future that completes either with the first service's successful response, or
17/// with the second service's response.
18#[pin_project]
19pub struct ResponseFuture<S1, S2, Request>
20where
21    S1: Service<Request>,
22    S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
23    S2::Error: Into<BoxedError>,
24{
25    #[pin]
26    state: ResponseState<S1, S2, Request>,
27}
28
29#[pin_project(project_replace = __ResponseStateProjectionOwned, project = ResponseStateProj)]
30enum ResponseState<S1, S2, Request>
31where
32    S1: Service<Request>,
33    S2: Service<Request>,
34    S2::Error: Into<BoxedError>,
35{
36    PollResponse1 {
37        #[pin]
38        fut: S1::Future,
39        req: Request,
40        svc2: S2,
41    },
42    PollReady2 {
43        req: Request,
44        svc2: S2,
45    },
46    PollResponse2 {
47        #[pin]
48        fut: S2::Future,
49    },
50    // Placeholder value to swap into the pin projection of the enum so we can take ownership of the fields.
51    Tmp,
52}
53
54impl<S1, S2, Request> ResponseFuture<S1, S2, Request>
55where
56    S1: Service<Request>,
57    S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
58    S2::Error: Into<BoxedError>,
59{
60    pub(crate) fn new(fut: S1::Future, req: Request, svc2: S2) -> Self {
61        ResponseFuture {
62            state: ResponseState::PollResponse1 { fut, req, svc2 },
63        }
64    }
65}
66
67impl<S1, S2, Request> Future for ResponseFuture<S1, S2, Request>
68where
69    S1: Service<Request>,
70    S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
71    S2::Error: Into<BoxedError>,
72{
73    type Output = Result<<S1 as Service<Request>>::Response, BoxedError>;
74
75    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
76        let mut this = self.project();
77        // CORRECTNESS
78        //
79        // The current task must be scheduled for wakeup every time we return
80        // `Poll::Pending`.
81        //
82        // This loop ensures that the task is scheduled as required, because it
83        // only returns Pending when a future or service returns Pending.
84        loop {
85            match this.state.as_mut().project() {
86                ResponseStateProj::PollResponse1 { fut, .. } => match ready!(fut.poll(cx)) {
87                    Ok(rsp) => return Poll::Ready(Ok(rsp)),
88                    Err(_) => {
89                        tracing::debug!("got error from svc1, retrying on svc2");
90                        if let __ResponseStateProjectionOwned::PollResponse1 { req, svc2, .. } =
91                            this.state.as_mut().project_replace(ResponseState::Tmp)
92                        {
93                            this.state.set(ResponseState::PollReady2 { req, svc2 });
94                        } else {
95                            unreachable!();
96                        }
97                    }
98                },
99                ResponseStateProj::PollReady2 { svc2, .. } => match ready!(svc2.poll_ready(cx)) {
100                    Err(e) => return Poll::Ready(Err(e.into())),
101                    Ok(()) => {
102                        if let __ResponseStateProjectionOwned::PollReady2 { mut svc2, req } =
103                            this.state.as_mut().project_replace(ResponseState::Tmp)
104                        {
105                            this.state.set(ResponseState::PollResponse2 {
106                                fut: svc2.call(req),
107                            });
108                        } else {
109                            unreachable!();
110                        }
111                    }
112                },
113                ResponseStateProj::PollResponse2 { fut } => {
114                    return fut.poll(cx).map_err(Into::into)
115                }
116                ResponseStateProj::Tmp => unreachable!(),
117            }
118        }
119    }
120}
121
122impl<S1, S2, Request> Debug for ResponseFuture<S1, S2, Request>
123where
124    S1: Service<Request>,
125    S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
126    Request: Debug,
127    S1::Future: Debug,
128    S2: Debug,
129    S2::Future: Debug,
130    S2::Error: Into<BoxedError>,
131{
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        f.debug_struct("ResponseFuture")
134            .field("state", &self.state)
135            .finish()
136    }
137}
138
139impl<S1, S2, Request> Debug for ResponseState<S1, S2, Request>
140where
141    S1: Service<Request>,
142    S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
143    Request: Debug,
144    S1::Future: Debug,
145    S2: Debug,
146    S2::Future: Debug,
147    S2::Error: Into<BoxedError>,
148{
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        match self {
151            ResponseState::PollResponse1 { fut, req, svc2 } => f
152                .debug_struct("ResponseState::PollResponse1")
153                .field("fut", fut)
154                .field("req", req)
155                .field("svc2", svc2)
156                .finish(),
157            ResponseState::PollReady2 { req, svc2 } => f
158                .debug_struct("ResponseState::PollReady2")
159                .field("req", req)
160                .field("svc2", svc2)
161                .finish(),
162            ResponseState::PollResponse2 { fut } => f
163                .debug_struct("ResponseState::PollResponse2")
164                .field("fut", fut)
165                .finish(),
166            ResponseState::Tmp => unreachable!(),
167        }
168    }
169}