zebra_test/
transcript.rs

1//! A [`Service`] implementation based on a fixed transcript.
2
3use std::{
4    fmt::Debug,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use color_eyre::{
10    eyre::{eyre, Report, WrapErr},
11    section::Section,
12    section::SectionExt,
13};
14use futures::future::{ready, Ready};
15use tower::{Service, ServiceExt};
16
/// A boxed, thread-safe error that holds no non-`'static` borrows.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// An error-checking function: is the value an expected error?
///
/// The function is called in two ways:
///
/// - With `Some(error)`, when a service returned `error` and the transcript expected an
///   error. If the checked error is the expected error, the function should return
///   `Ok(())`. Otherwise, it should just return the checked error, wrapped inside `Err`.
/// - With `None`, when the transcript has to produce the expected error itself (for
///   example, when the transcript is used as a mock service). The function must then
///   return `Err` containing a representative expected error. Returning `Ok(())` for
///   `None` is a bug in the checker, and makes mock-error generation panic.
pub type ErrorChecker = fn(Option<BoxError>) -> Result<(), BoxError>;
24
25/// An expected error in a transcript.
26#[derive(Debug, Clone)]
27pub enum ExpectedTranscriptError {
28    /// Match any error
29    Any,
30    /// Use a validator function to check for matching errors
31    Exact(Arc<ErrorChecker>),
32}
33
34impl ExpectedTranscriptError {
35    /// Convert the `verifier` function into an exact error checker
36    pub fn exact(verifier: ErrorChecker) -> Self {
37        ExpectedTranscriptError::Exact(verifier.into())
38    }
39
40    /// Check the actual error `e` against this expected error.
41    #[track_caller]
42    fn check(&self, e: BoxError) -> Result<(), Report> {
43        match self {
44            ExpectedTranscriptError::Any => Ok(()),
45            ExpectedTranscriptError::Exact(checker) => checker(Some(e)),
46        }
47        .map_err(ErrorCheckerError)
48        .wrap_err("service returned an error but it didn't match the expected error")
49    }
50
51    fn mock(&self) -> Report {
52        match self {
53            ExpectedTranscriptError::Any => eyre!("mock error"),
54            ExpectedTranscriptError::Exact(checker) => {
55                checker(None).map_err(|e| eyre!(e)).expect_err(
56                    "transcript should correctly produce the expected mock error when passed None",
57                )
58            }
59        }
60    }
61}
62
63#[derive(Debug, thiserror::Error)]
64#[error("ErrorChecker Error: {0}")]
65struct ErrorCheckerError(BoxError);
66
67/// A transcript: a list of requests and expected results.
68#[must_use]
69pub struct Transcript<R, S, I>
70where
71    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
72{
73    messages: I,
74}
75
76impl<R, S, I> From<I> for Transcript<R, S, I::IntoIter>
77where
78    I: IntoIterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
79{
80    fn from(messages: I) -> Self {
81        Self {
82            messages: messages.into_iter(),
83        }
84    }
85}
86
87impl<R, S, I> Transcript<R, S, I>
88where
89    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
90    R: Debug,
91    S: Debug + Eq,
92{
93    /// Check this transcript against the responses from the `to_check` service
94    pub async fn check<C>(mut self, mut to_check: C) -> Result<(), Report>
95    where
96        C: Service<R, Response = S>,
97        C::Error: Into<BoxError>,
98    {
99        for (req, expected_rsp) in &mut self.messages {
100            // These unwraps could propagate errors with the correct
101            // bound on C::Error
102            let fut = to_check
103                .ready()
104                .await
105                .map_err(Into::into)
106                .map_err(|e| eyre!(e))
107                .expect("expected service to not fail during execution of transcript");
108
109            let response = fut.call(req).await;
110
111            match (response, expected_rsp) {
112                (Ok(rsp), Ok(expected_rsp)) => {
113                    if rsp != expected_rsp {
114                        Err(eyre!(
115                            "response doesn't match transcript's expected response"
116                        ))
117                        .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))
118                        .with_section(|| format!("{rsp:?}").header("Found Response:"))?;
119                    }
120                }
121                (Ok(rsp), Err(error_checker)) => {
122                    let error = Err(eyre!("received a response when an error was expected"))
123                        .with_section(|| format!("{rsp:?}").header("Found Response:"));
124
125                    let error = match std::panic::catch_unwind(|| error_checker.mock()) {
126                        Ok(expected_err) => error
127                            .with_section(|| format!("{expected_err:?}").header("Expected Error:")),
128                        Err(pi) => {
129                            let payload = pi
130                                .downcast_ref::<String>()
131                                .cloned()
132                                .or_else(|| pi.downcast_ref::<&str>().map(ToString::to_string))
133                                .unwrap_or_else(|| "<non string panic payload>".into());
134
135                            error
136                                .section(payload.header("Panic:"))
137                                .wrap_err("ErrorChecker panicked when producing expected response")
138                        }
139                    };
140
141                    error?;
142                }
143                (Err(e), Ok(expected_rsp)) => {
144                    Err(eyre!("received an error when a response was expected"))
145                        .with_error(|| ErrorCheckerError(e.into()))
146                        .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))?
147                }
148                (Err(e), Err(error_checker)) => {
149                    error_checker.check(e.into())?;
150                    continue;
151                }
152            }
153        }
154        Ok(())
155    }
156}
157
158impl<R, S, I> Service<R> for Transcript<R, S, I>
159where
160    R: Debug + Eq,
161    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
162{
163    type Response = S;
164    type Error = Report;
165    type Future = Ready<Result<S, Report>>;
166
167    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        Poll::Ready(Ok(()))
169    }
170
171    #[track_caller]
172    fn call(&mut self, request: R) -> Self::Future {
173        if let Some((expected_request, response)) = self.messages.next() {
174            match response {
175                Ok(response) => {
176                    if request == expected_request {
177                        ready(Ok(response))
178                    } else {
179                        ready(
180                            Err(eyre!("received unexpected request"))
181                                .with_section(|| {
182                                    format!("{expected_request:?}").header("Expected Request:")
183                                })
184                                .with_section(|| format!("{request:?}").header("Found Request:")),
185                        )
186                    }
187                }
188                Err(check_fn) => ready(Err(check_fn.mock())),
189            }
190        } else {
191            ready(Err(eyre!("Got request after transcript ended")))
192        }
193    }
194}