1use std::{
4 fmt::Debug,
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures_core::ready;
11use pin_project::pin_project;
12use tower::Service;
13
14use crate::BoxedError;
15
16#[pin_project]
19pub struct ResponseFuture<S1, S2, Request>
20where
21 S1: Service<Request>,
22 S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
23 S2::Error: Into<BoxedError>,
24{
25 #[pin]
26 state: ResponseState<S1, S2, Request>,
27}
28
29#[pin_project(project_replace = __ResponseStateProjectionOwned, project = ResponseStateProj)]
30enum ResponseState<S1, S2, Request>
31where
32 S1: Service<Request>,
33 S2: Service<Request>,
34 S2::Error: Into<BoxedError>,
35{
36 PollResponse1 {
37 #[pin]
38 fut: S1::Future,
39 req: Request,
40 svc2: S2,
41 },
42 PollReady2 {
43 req: Request,
44 svc2: S2,
45 },
46 PollResponse2 {
47 #[pin]
48 fut: S2::Future,
49 },
50 Tmp,
52}
53
54impl<S1, S2, Request> ResponseFuture<S1, S2, Request>
55where
56 S1: Service<Request>,
57 S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
58 S2::Error: Into<BoxedError>,
59{
60 pub(crate) fn new(fut: S1::Future, req: Request, svc2: S2) -> Self {
61 ResponseFuture {
62 state: ResponseState::PollResponse1 { fut, req, svc2 },
63 }
64 }
65}
66
67impl<S1, S2, Request> Future for ResponseFuture<S1, S2, Request>
68where
69 S1: Service<Request>,
70 S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
71 S2::Error: Into<BoxedError>,
72{
73 type Output = Result<<S1 as Service<Request>>::Response, BoxedError>;
74
75 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
76 let mut this = self.project();
77 loop {
85 match this.state.as_mut().project() {
86 ResponseStateProj::PollResponse1 { fut, .. } => match ready!(fut.poll(cx)) {
87 Ok(rsp) => return Poll::Ready(Ok(rsp)),
88 Err(_) => {
89 tracing::debug!("got error from svc1, retrying on svc2");
90 if let __ResponseStateProjectionOwned::PollResponse1 { req, svc2, .. } =
91 this.state.as_mut().project_replace(ResponseState::Tmp)
92 {
93 this.state.set(ResponseState::PollReady2 { req, svc2 });
94 } else {
95 unreachable!();
96 }
97 }
98 },
99 ResponseStateProj::PollReady2 { svc2, .. } => match ready!(svc2.poll_ready(cx)) {
100 Err(e) => return Poll::Ready(Err(e.into())),
101 Ok(()) => {
102 if let __ResponseStateProjectionOwned::PollReady2 { mut svc2, req } =
103 this.state.as_mut().project_replace(ResponseState::Tmp)
104 {
105 this.state.set(ResponseState::PollResponse2 {
106 fut: svc2.call(req),
107 });
108 } else {
109 unreachable!();
110 }
111 }
112 },
113 ResponseStateProj::PollResponse2 { fut } => {
114 return fut.poll(cx).map_err(Into::into)
115 }
116 ResponseStateProj::Tmp => unreachable!(),
117 }
118 }
119 }
120}
121
122impl<S1, S2, Request> Debug for ResponseFuture<S1, S2, Request>
123where
124 S1: Service<Request>,
125 S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
126 Request: Debug,
127 S1::Future: Debug,
128 S2: Debug,
129 S2::Future: Debug,
130 S2::Error: Into<BoxedError>,
131{
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_struct("ResponseFuture")
134 .field("state", &self.state)
135 .finish()
136 }
137}
138
139impl<S1, S2, Request> Debug for ResponseState<S1, S2, Request>
140where
141 S1: Service<Request>,
142 S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
143 Request: Debug,
144 S1::Future: Debug,
145 S2: Debug,
146 S2::Future: Debug,
147 S2::Error: Into<BoxedError>,
148{
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 ResponseState::PollResponse1 { fut, req, svc2 } => f
152 .debug_struct("ResponseState::PollResponse1")
153 .field("fut", fut)
154 .field("req", req)
155 .field("svc2", svc2)
156 .finish(),
157 ResponseState::PollReady2 { req, svc2 } => f
158 .debug_struct("ResponseState::PollReady2")
159 .field("req", req)
160 .field("svc2", svc2)
161 .finish(),
162 ResponseState::PollResponse2 { fut } => f
163 .debug_struct("ResponseState::PollResponse2")
164 .field("fut", fut)
165 .finish(),
166 ResponseState::Tmp => unreachable!(),
167 }
168 }
169}