1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
//! A [`Service`] implementation based on a fixed transcript.

use std::{
    fmt::Debug,
    sync::Arc,
    task::{Context, Poll},
};

use color_eyre::{
    eyre::{eyre, Report, WrapErr},
    section::Section,
    section::SectionExt,
};
use futures::future::{ready, Ready};
use tower::{Service, ServiceExt};

/// A type-erased error that can be sent and shared between threads.
type BoxError = Box<dyn std::error::Error + Sync + Send + 'static>;

/// A function that decides whether an error is the one a transcript expects.
///
/// When passed `Some(error)`, it returns `Ok(())` if `error` is the expected error;
/// otherwise it hands the checked error back, wrapped in `Err`.
///
/// When passed `None`, it should return `Err` containing an example of the expected
/// error, which the transcript uses as a mock error.
pub type ErrorChecker = fn(Option<BoxError>) -> Result<(), BoxError>;

/// The outcome a transcript expects when a request should fail.
#[derive(Clone, Debug)]
pub enum ExpectedTranscriptError {
    /// Accept whatever error the service returns.
    Any,
    /// Accept only errors for which this [`ErrorChecker`] returns `Ok(())`.
    Exact(Arc<ErrorChecker>),
}

impl ExpectedTranscriptError {
    /// Convert the `verifier` function into an exact error checker
    pub fn exact(verifier: ErrorChecker) -> Self {
        ExpectedTranscriptError::Exact(verifier.into())
    }

    /// Check the actual error `e` against this expected error.
    #[track_caller]
    fn check(&self, e: BoxError) -> Result<(), Report> {
        match self {
            ExpectedTranscriptError::Any => Ok(()),
            ExpectedTranscriptError::Exact(checker) => checker(Some(e)),
        }
        .map_err(ErrorCheckerError)
        .wrap_err("service returned an error but it didn't match the expected error")
    }

    fn mock(&self) -> Report {
        match self {
            ExpectedTranscriptError::Any => eyre!("mock error"),
            ExpectedTranscriptError::Exact(checker) => {
                checker(None).map_err(|e| eyre!(e)).expect_err(
                    "transcript should correctly produce the expected mock error when passed None",
                )
            }
        }
    }
}

#[derive(Debug, thiserror::Error)]
#[error("ErrorChecker Error: {0}")]
struct ErrorCheckerError(BoxError);

/// A transcript: an ordered list of requests, each paired with the response or
/// error expected for it.
///
/// The bound on `I` is kept on the struct because it is the only place `R` and `S`
/// are mentioned.
#[must_use]
pub struct Transcript<R, S, I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>> {
    /// The remaining `(request, expected result)` pairs, consumed in order.
    messages: I,
}

impl<R, S, I> From<I> for Transcript<R, S, I::IntoIter>
where
    I: IntoIterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
{
    fn from(messages: I) -> Self {
        Self {
            messages: messages.into_iter(),
        }
    }
}

impl<R, S, I> Transcript<R, S, I>
where
    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
    R: Debug,
    S: Debug + Eq,
{
    /// Check this transcript against the responses from the `to_check` service
    pub async fn check<C>(mut self, mut to_check: C) -> Result<(), Report>
    where
        C: Service<R, Response = S>,
        C::Error: Into<BoxError>,
    {
        for (req, expected_rsp) in &mut self.messages {
            // These unwraps could propagate errors with the correct
            // bound on C::Error
            let fut = to_check
                .ready()
                .await
                .map_err(Into::into)
                .map_err(|e| eyre!(e))
                .expect("expected service to not fail during execution of transcript");

            let response = fut.call(req).await;

            match (response, expected_rsp) {
                (Ok(rsp), Ok(expected_rsp)) => {
                    if rsp != expected_rsp {
                        Err(eyre!(
                            "response doesn't match transcript's expected response"
                        ))
                        .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))
                        .with_section(|| format!("{rsp:?}").header("Found Response:"))?;
                    }
                }
                (Ok(rsp), Err(error_checker)) => {
                    let error = Err(eyre!("received a response when an error was expected"))
                        .with_section(|| format!("{rsp:?}").header("Found Response:"));

                    let error = match std::panic::catch_unwind(|| error_checker.mock()) {
                        Ok(expected_err) => error
                            .with_section(|| format!("{expected_err:?}").header("Expected Error:")),
                        Err(pi) => {
                            let payload = pi
                                .downcast_ref::<String>()
                                .cloned()
                                .or_else(|| pi.downcast_ref::<&str>().map(ToString::to_string))
                                .unwrap_or_else(|| "<non string panic payload>".into());

                            error
                                .section(payload.header("Panic:"))
                                .wrap_err("ErrorChecker panicked when producing expected response")
                        }
                    };

                    error?;
                }
                (Err(e), Ok(expected_rsp)) => {
                    Err(eyre!("received an error when a response was expected"))
                        .with_error(|| ErrorCheckerError(e.into()))
                        .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))?
                }
                (Err(e), Err(error_checker)) => {
                    error_checker.check(e.into())?;
                    continue;
                }
            }
        }
        Ok(())
    }
}

/// Replays the transcript as a [`Service`], so it can stand in for a real service.
///
/// Each call consumes the next `(request, expected result)` pair from the transcript.
impl<R, S, I> Service<R> for Transcript<R, S, I>
where
    R: Debug + Eq,
    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
{
    type Response = S;
    type Error = Report;
    type Future = Ready<Result<S, Report>>;

    /// Always ready: every result is produced immediately by `call`.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Answer `request` with the next transcript entry.
    ///
    /// Returns an error in two cases:
    ///
    /// - the transcript has no entries left;
    /// - `request` differs from the entry's expected request.
    ///
    /// Otherwise, returns the entry's expected response, or the mock error produced by
    /// its expected error.
    ///
    /// # Panics
    ///
    /// If the entry expects an [`ExpectedTranscriptError::Exact`] error whose checker
    /// returns `Ok(())` when passed `None`.
    #[track_caller]
    fn call(&mut self, request: R) -> Self::Future {
        let (expected_request, expected_result) = match self.messages.next() {
            Some(message) => message,
            None => return ready(Err(eyre!("Got request after transcript ended"))),
        };

        // Check the request for every entry, including entries that expect an error.
        // Otherwise a wrong or out-of-order request would silently receive the mock error.
        if request != expected_request {
            return ready(
                Err(eyre!("received unexpected request"))
                    .with_section(|| {
                        format!("{expected_request:?}").header("Expected Request:")
                    })
                    .with_section(|| format!("{request:?}").header("Found Request:")),
            );
        }

        ready(expected_result.map_err(|expected_error| expected_error.mock()))
    }
}