zebra_network/peer/client/
tests.rs
1#![cfg_attr(feature = "proptest-impl", allow(dead_code))]
5
6use std::{
7 net::{Ipv4Addr, SocketAddrV4},
8 sync::Arc,
9 time::Duration,
10};
11
12use chrono::Utc;
13use futures::{
14 channel::{mpsc, oneshot},
15 future::{self, AbortHandle, Future, FutureExt},
16};
17use tokio::{
18 sync::broadcast::{self, error::TryRecvError},
19 task::JoinHandle,
20};
21
22use zebra_chain::block::Height;
23
24use crate::{
25 constants,
26 peer::{
27 error::SharedPeerError, CancelHeartbeatTask, Client, ClientRequest, ConnectionInfo,
28 ErrorSlot,
29 },
30 peer_set::InventoryChange,
31 protocol::{
32 external::{types::Version, AddrInVersion},
33 types::{Nonce, PeerServices},
34 },
35 BoxError, VersionMessage,
36};
37
38#[cfg(test)]
39mod vectors;
40
/// How long the placeholder background tasks spawned by [`ClientTestHarnessBuilder`] sleep
/// when a test does not provide its own connection or heartbeat task.
const MAX_PEER_CONNECTION_TIME: Duration = Duration::from_millis(10_000);
43
/// The test-side counterpart of a [`Client`] created by [`ClientTestHarness::build`].
///
/// It holds the other ends of the channels handed to the `Client`, plus abort handles for the
/// client's background tasks, so tests can observe and control the client.
pub struct ClientTestHarness {
    /// Receives the requests the `Client` sends on its `server_tx` channel.
    /// `None` once the harness has dropped it.
    client_request_receiver: Option<mpsc::Receiver<ClientRequest>>,
    /// Receives the heartbeat cancellation signal sent on the `Client`'s `shutdown_tx`.
    /// `None` once the harness has dropped it.
    shutdown_receiver: Option<oneshot::Receiver<CancelHeartbeatTask>>,
    /// Receives the inventory changes the `Client` sends on its `inv_collector`.
    /// `None` once the harness has dropped it.
    // Only read by `drop_inventory_change_receiver` and `try_to_receive_inventory_change`,
    // which are themselves `allow(dead_code)`, so this field is unused in builds that call neither.
    #[allow(dead_code)]
    inv_receiver: Option<broadcast::Receiver<InventoryChange>>,
    /// Error slot shared with the `Client` (a clone is handed to it when the harness is built).
    error_slot: ErrorSlot,
    /// The protocol version the fake remote peer advertised.
    remote_version: Version,
    /// Aborts the `Client`'s spawned connection task.
    connection_aborter: AbortHandle,
    /// Aborts the `Client`'s spawned heartbeat task.
    heartbeat_aborter: AbortHandle,
}
55
56impl ClientTestHarness {
57 pub fn build() -> ClientTestHarnessBuilder {
60 ClientTestHarnessBuilder {
61 version: None,
62 connection_task: None,
63 heartbeat_task: None,
64 }
65 }
66
67 pub fn remote_version(&self) -> Version {
69 self.remote_version
70 }
71
72 pub fn wants_connection_heartbeats(&mut self) -> bool {
79 let receive_result = self
80 .shutdown_receiver
81 .as_mut()
82 .expect("heartbeat shutdown receiver endpoint has been dropped")
83 .try_recv();
84
85 match receive_result {
86 Ok(None) => true,
87 Ok(Some(CancelHeartbeatTask)) | Err(oneshot::Canceled) => false,
88 }
89 }
90
91 pub fn drop_heartbeat_shutdown_receiver(&mut self) {
93 let hearbeat_future = self
94 .shutdown_receiver
95 .take()
96 .expect("unexpected test failure: heartbeat shutdown receiver endpoint has already been dropped");
97
98 std::mem::drop(hearbeat_future);
99 }
100
101 pub fn close_outbound_client_request_receiver(&mut self) {
106 self.client_request_receiver
107 .as_mut()
108 .expect("request receiver endpoint has been dropped")
109 .close();
110 }
111
112 pub fn drop_outbound_client_request_receiver(&mut self) {
116 self.client_request_receiver
117 .take()
118 .expect("request receiver endpoint has already been dropped");
119 }
120
121 pub(crate) fn try_to_receive_outbound_client_request(&mut self) -> ReceiveRequestAttempt {
125 let receive_result = self
126 .client_request_receiver
127 .as_mut()
128 .expect("request receiver endpoint has been dropped")
129 .try_next();
130
131 match receive_result {
132 Ok(Some(request)) => ReceiveRequestAttempt::Request(request),
133 Ok(None) => ReceiveRequestAttempt::Closed,
134 Err(_) => ReceiveRequestAttempt::Empty,
135 }
136 }
137
138 #[allow(dead_code)]
144 pub fn drop_inventory_change_receiver(&mut self) {
145 self.inv_receiver
146 .take()
147 .expect("inventory change receiver endpoint has already been dropped");
148 }
149
150 #[allow(dead_code)]
156 #[allow(clippy::unwrap_in_result)]
157 pub(crate) fn try_to_receive_inventory_change(&mut self) -> Option<InventoryChange> {
158 let receive_result = self
159 .inv_receiver
160 .as_mut()
161 .expect("inventory change receiver endpoint has been dropped")
162 .try_recv();
163
164 match receive_result {
165 Ok(change) => Some(change),
166 Err(TryRecvError::Empty) => None,
167 Err(TryRecvError::Closed) => None,
168 Err(TryRecvError::Lagged(skipped_messages)) => unreachable!(
169 "unexpected lagged inventory receiver in tests, skipped {} messages",
170 skipped_messages,
171 ),
172 }
173 }
174
175 pub fn current_error(&self) -> Option<SharedPeerError> {
177 self.error_slot.try_get_error()
178 }
179
180 pub fn set_error(&self, error: impl Into<SharedPeerError>) {
186 self.error_slot
187 .try_update_error(error.into())
188 .expect("unexpected earlier error in error slot")
189 }
190
191 pub async fn stop_connection_task(&self) {
193 self.connection_aborter.abort();
194
195 tokio::task::yield_now().await;
197 }
198
199 pub async fn stop_heartbeat_task(&self) {
201 self.heartbeat_aborter.abort();
202
203 tokio::task::yield_now().await;
205 }
206}
207
208pub(crate) enum ReceiveRequestAttempt {
212 Closed,
214
215 Empty,
217
218 Request(ClientRequest),
220}
221
222impl ReceiveRequestAttempt {
223 pub fn is_closed(&self) -> bool {
226 matches!(self, ReceiveRequestAttempt::Closed)
227 }
228
229 pub fn is_empty(&self) -> bool {
231 matches!(self, ReceiveRequestAttempt::Empty)
232 }
233
234 #[allow(dead_code)]
236 pub fn request(self) -> Option<ClientRequest> {
237 match self {
238 ReceiveRequestAttempt::Request(request) => Some(request),
239 ReceiveRequestAttempt::Closed | ReceiveRequestAttempt::Empty => None,
240 }
241 }
242}
243
/// Builder for a [`Client`] paired with a [`ClientTestHarness`].
///
/// `C` and `H` are the futures spawned as the client's connection and heartbeat tasks. The
/// `future::Ready<()>` defaults are only placeholders for the type parameters: when a task is
/// not supplied, `finish` spawns a `MAX_PEER_CONNECTION_TIME` sleep instead.
pub struct ClientTestHarnessBuilder<C = future::Ready<()>, H = future::Ready<()>> {
    /// The future to spawn as the connection task, or `None` to use the fallback sleep.
    connection_task: Option<C>,
    /// The future to spawn as the heartbeat task, or `None` to use the fallback sleep.
    heartbeat_task: Option<H>,
    /// The version the fake remote peer advertises, or `None` for `Version(0)`.
    version: Option<Version>,
}
254
255impl<C, H> ClientTestHarnessBuilder<C, H>
256where
257 C: Future<Output = ()> + Send + 'static,
258 H: Future<Output = ()> + Send + 'static,
259{
260 pub fn with_version(mut self, version: Version) -> Self {
262 self.version = Some(version);
263 self
264 }
265
266 pub fn with_connection_task<NewC>(
268 self,
269 connection_task: NewC,
270 ) -> ClientTestHarnessBuilder<NewC, H> {
271 ClientTestHarnessBuilder {
272 connection_task: Some(connection_task),
273 heartbeat_task: self.heartbeat_task,
274 version: self.version,
275 }
276 }
277
278 pub fn with_heartbeat_task<NewH>(
280 self,
281 heartbeat_task: NewH,
282 ) -> ClientTestHarnessBuilder<C, NewH> {
283 ClientTestHarnessBuilder {
284 connection_task: self.connection_task,
285 heartbeat_task: Some(heartbeat_task),
286 version: self.version,
287 }
288 }
289
290 pub fn finish(self) -> (Client, ClientTestHarness) {
292 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
293 let (client_request_sender, client_request_receiver) = mpsc::channel(1);
294 let (inv_sender, inv_receiver) = broadcast::channel(5);
295
296 let error_slot = ErrorSlot::default();
297 let remote_version = self.version.unwrap_or(Version(0));
298
299 let (connection_task, connection_aborter) =
300 Self::spawn_background_task_or_fallback(self.connection_task);
301 let (heartbeat_task, heartbeat_aborter) =
302 Self::spawn_background_task_or_fallback_with_result(self.heartbeat_task);
303
304 let negotiated_version =
305 std::cmp::min(remote_version, constants::CURRENT_NETWORK_PROTOCOL_VERSION);
306
307 let remote = VersionMessage {
308 version: remote_version,
309 services: PeerServices::default(),
310 timestamp: Utc::now(),
311 address_recv: AddrInVersion::new(
312 SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1),
313 PeerServices::default(),
314 ),
315 address_from: AddrInVersion::new(
316 SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2),
317 PeerServices::default(),
318 ),
319 nonce: Nonce::default(),
320 user_agent: "client test harness".to_string(),
321 start_height: Height(0),
322 relay: true,
323 };
324
325 let connection_info = Arc::new(ConnectionInfo {
326 connected_addr: crate::peer::ConnectedAddr::Isolated,
327 remote,
328 negotiated_version,
329 });
330
331 let client = Client {
332 connection_info,
333 shutdown_tx: Some(shutdown_sender),
334 server_tx: client_request_sender,
335 inv_collector: inv_sender,
336 error_slot: error_slot.clone(),
337 connection_task,
338 heartbeat_task,
339 };
340
341 let harness = ClientTestHarness {
342 client_request_receiver: Some(client_request_receiver),
343 shutdown_receiver: Some(shutdown_receiver),
344 inv_receiver: Some(inv_receiver),
345 error_slot,
346 remote_version,
347 connection_aborter,
348 heartbeat_aborter,
349 };
350
351 (client, harness)
352 }
353
354 fn spawn_background_task_or_fallback<T>(task_future: Option<T>) -> (JoinHandle<()>, AbortHandle)
359 where
360 T: Future<Output = ()> + Send + 'static,
361 {
362 match task_future {
363 Some(future) => Self::spawn_background_task(future),
364 None => Self::spawn_background_task(tokio::time::sleep(MAX_PEER_CONNECTION_TIME)),
365 }
366 }
367
368 fn spawn_background_task<T>(task_future: T) -> (JoinHandle<()>, AbortHandle)
370 where
371 T: Future<Output = ()> + Send + 'static,
372 {
373 let (task, abort_handle) = future::abortable(task_future);
374 let task_handle = tokio::spawn(task.map(|_result| ()));
375
376 (task_handle, abort_handle)
377 }
378
379 fn spawn_background_task_or_fallback_with_result<T>(
386 task_future: Option<T>,
387 ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
388 where
389 T: Future<Output = ()> + Send + 'static,
390 {
391 match task_future {
392 Some(future) => Self::spawn_background_task_with_result(future),
393 None => Self::spawn_background_task_with_result(tokio::time::sleep(
394 MAX_PEER_CONNECTION_TIME,
395 )),
396 }
397 }
398
399 fn spawn_background_task_with_result<T>(
401 task_future: T,
402 ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
403 where
404 T: Future<Output = ()> + Send + 'static,
405 {
406 let (task, abort_handle) = future::abortable(task_future);
407 let task_handle = tokio::spawn(task.map(|_result| Ok(())));
408
409 (task_handle, abort_handle)
410 }
411}