zebra_test/
transcript.rs
1use std::{
4 fmt::Debug,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use color_eyre::{
10 eyre::{eyre, Report, WrapErr},
11 section::Section,
12 section::SectionExt,
13};
14use futures::future::{ready, Ready};
15use tower::{Service, ServiceExt};
16
/// A thread-safe boxed error, used for service errors and error-checker results.
///
/// (`Box<dyn Trait>` already defaults to a `'static` trait object.)
type BoxError = Box<dyn std::error::Error + Sync + Send>;

/// An error-checking function used by [`ExpectedTranscriptError::Exact`].
///
/// - Called with `Some(error)`: return `Ok(())` if `error` is the expected error,
///   otherwise return an `Err` explaining the mismatch.
/// - Called with `None`: return `Err(mock_error)`, the error a mock service should
///   produce in place of the expected error.
pub type ErrorChecker = fn(Option<BoxError>) -> Result<(), BoxError>;
24
25#[derive(Debug, Clone)]
27pub enum ExpectedTranscriptError {
28 Any,
30 Exact(Arc<ErrorChecker>),
32}
33
34impl ExpectedTranscriptError {
35 pub fn exact(verifier: ErrorChecker) -> Self {
37 ExpectedTranscriptError::Exact(verifier.into())
38 }
39
40 #[track_caller]
42 fn check(&self, e: BoxError) -> Result<(), Report> {
43 match self {
44 ExpectedTranscriptError::Any => Ok(()),
45 ExpectedTranscriptError::Exact(checker) => checker(Some(e)),
46 }
47 .map_err(ErrorCheckerError)
48 .wrap_err("service returned an error but it didn't match the expected error")
49 }
50
51 fn mock(&self) -> Report {
52 match self {
53 ExpectedTranscriptError::Any => eyre!("mock error"),
54 ExpectedTranscriptError::Exact(checker) => {
55 checker(None).map_err(|e| eyre!(e)).expect_err(
56 "transcript should correctly produce the expected mock error when passed None",
57 )
58 }
59 }
60 }
61}
62
63#[derive(Debug, thiserror::Error)]
64#[error("ErrorChecker Error: {0}")]
65struct ErrorCheckerError(BoxError);
66
67#[must_use]
69pub struct Transcript<R, S, I>
70where
71 I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
72{
73 messages: I,
74}
75
76impl<R, S, I> From<I> for Transcript<R, S, I::IntoIter>
77where
78 I: IntoIterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
79{
80 fn from(messages: I) -> Self {
81 Self {
82 messages: messages.into_iter(),
83 }
84 }
85}
86
87impl<R, S, I> Transcript<R, S, I>
88where
89 I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
90 R: Debug,
91 S: Debug + Eq,
92{
93 pub async fn check<C>(mut self, mut to_check: C) -> Result<(), Report>
95 where
96 C: Service<R, Response = S>,
97 C::Error: Into<BoxError>,
98 {
99 for (req, expected_rsp) in &mut self.messages {
100 let fut = to_check
103 .ready()
104 .await
105 .map_err(Into::into)
106 .map_err(|e| eyre!(e))
107 .expect("expected service to not fail during execution of transcript");
108
109 let response = fut.call(req).await;
110
111 match (response, expected_rsp) {
112 (Ok(rsp), Ok(expected_rsp)) => {
113 if rsp != expected_rsp {
114 Err(eyre!(
115 "response doesn't match transcript's expected response"
116 ))
117 .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))
118 .with_section(|| format!("{rsp:?}").header("Found Response:"))?;
119 }
120 }
121 (Ok(rsp), Err(error_checker)) => {
122 let error = Err(eyre!("received a response when an error was expected"))
123 .with_section(|| format!("{rsp:?}").header("Found Response:"));
124
125 let error = match std::panic::catch_unwind(|| error_checker.mock()) {
126 Ok(expected_err) => error
127 .with_section(|| format!("{expected_err:?}").header("Expected Error:")),
128 Err(pi) => {
129 let payload = pi
130 .downcast_ref::<String>()
131 .cloned()
132 .or_else(|| pi.downcast_ref::<&str>().map(ToString::to_string))
133 .unwrap_or_else(|| "<non string panic payload>".into());
134
135 error
136 .section(payload.header("Panic:"))
137 .wrap_err("ErrorChecker panicked when producing expected response")
138 }
139 };
140
141 error?;
142 }
143 (Err(e), Ok(expected_rsp)) => {
144 Err(eyre!("received an error when a response was expected"))
145 .with_error(|| ErrorCheckerError(e.into()))
146 .with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))?
147 }
148 (Err(e), Err(error_checker)) => {
149 error_checker.check(e.into())?;
150 continue;
151 }
152 }
153 }
154 Ok(())
155 }
156}
157
158impl<R, S, I> Service<R> for Transcript<R, S, I>
159where
160 R: Debug + Eq,
161 I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
162{
163 type Response = S;
164 type Error = Report;
165 type Future = Ready<Result<S, Report>>;
166
167 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
169 }
170
171 #[track_caller]
172 fn call(&mut self, request: R) -> Self::Future {
173 if let Some((expected_request, response)) = self.messages.next() {
174 match response {
175 Ok(response) => {
176 if request == expected_request {
177 ready(Ok(response))
178 } else {
179 ready(
180 Err(eyre!("received unexpected request"))
181 .with_section(|| {
182 format!("{expected_request:?}").header("Expected Request:")
183 })
184 .with_section(|| format!("{request:?}").header("Found Request:")),
185 )
186 }
187 }
188 Err(check_fn) => ready(Err(check_fn.mock())),
189 }
190 } else {
191 ready(Err(eyre!("Got request after transcript ended")))
192 }
193 }
194}