1use std::{
29    collections::{HashMap, HashSet},
30    pin::Pin,
31    task::{Context, Poll},
32    time::Duration,
33};
34
35use futures::{
36    future::TryFutureExt,
37    ready,
38    stream::{FuturesUnordered, Stream},
39    FutureExt,
40};
41use pin_project::{pin_project, pinned_drop};
42use thiserror::Error;
43use tokio::{sync::oneshot, task::JoinHandle};
44use tower::{Service, ServiceExt};
45use tracing_futures::Instrument;
46
47use zebra_chain::{
48    block::Height,
49    transaction::{self, UnminedTxId, VerifiedUnminedTx},
50    transparent,
51};
52use zebra_consensus::transaction as tx;
53use zebra_network::{self as zn, PeerSocketAddr};
54use zebra_node_services::mempool::Gossip;
55use zebra_state::{self as zs, CloneError};
56
57use crate::components::{
58    mempool::crawler::RATE_LIMIT_DELAY,
59    sync::{BLOCK_DOWNLOAD_TIMEOUT, BLOCK_VERIFY_TIMEOUT},
60};
61
62use super::MempoolError;
63
/// A boxed, thread-safe error type, used as the error type of the network, verifier, and
/// state services in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Controls how long we wait for a transaction download request to complete.
///
/// Currently equal to the block syncer's `BLOCK_DOWNLOAD_TIMEOUT`.
///
/// NOTE(review): this constant isn't referenced in the visible part of this file; the
/// download and verify task spawned in `download_if_needed_and_verify` is bounded by
/// `RATE_LIMIT_DELAY` instead. Confirm where it is used in the crate.
pub(crate) const TRANSACTION_DOWNLOAD_TIMEOUT: Duration = BLOCK_DOWNLOAD_TIMEOUT;

/// Controls how long we wait for a transaction verify request to complete.
///
/// Currently equal to the block syncer's `BLOCK_VERIFY_TIMEOUT`.
///
/// NOTE(review): like `TRANSACTION_DOWNLOAD_TIMEOUT`, this isn't referenced in the visible
/// part of this file. Confirm where it is used in the crate.
pub(crate) const TRANSACTION_VERIFY_TIMEOUT: Duration = BLOCK_VERIFY_TIMEOUT;

/// The maximum number of download and verify tasks that can be pending at once.
///
/// Once this many tasks are pending, `Downloads::download_if_needed_and_verify` rejects new
/// transactions with `MempoolError::FullQueue`.
pub const MAX_INBOUND_CONCURRENCY: usize = 25;
106
/// Marker value sent over a task's oneshot cancel channel to ask the download and verify
/// task to stop before it finishes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CancelDownloadAndVerify;
110
/// Errors produced by a mempool transaction download and verify task.
#[derive(Error, Debug, Clone)]
// NOTE(review): every variant is constructed in this file, so the reason for this `allow`
// isn't visible here; confirm whether it is still needed.
#[allow(dead_code)]
pub enum TransactionDownloadVerifyError {
    /// The state service already has a transaction with this transaction's mined ID.
    #[error("transaction is already in state")]
    InState,

    /// A state service request failed, or the service failed to become ready.
    #[error("error in state service: {0}")]
    StateError(#[source] CloneError),

    /// The network download request failed, or it returned no transactions.
    #[error("error downloading transaction: {0}")]
    DownloadFailed(#[source] CloneError),

    /// The task received a cancel signal (or its cancel channel closed) before finishing.
    #[error("transaction download / verification was cancelled")]
    Cancelled,

    /// The transaction verifier rejected the transaction.
    #[error("transaction did not pass consensus validation: {error}")]
    Invalid {
        /// The verifier's error, converted into a `TransactionError`.
        error: zebra_consensus::error::TransactionError,
        /// The address of the peer the transaction was downloaded from, if the network
        /// response included one. Always `None` for transactions gossiped in full.
        advertiser_addr: Option<PeerSocketAddr>,
    },
}
133
/// A queue of in-flight mempool transaction download and verify tasks.
///
/// Each queued transaction gets a spawned task plus a cancel handle keyed by its ID.
/// Task results are yielded through this type's `Stream` implementation, and dropping
/// the queue cancels every queued task.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Downloads<ZN, ZV, ZS>
where
    ZN: Service<zn::Request, Response = zn::Response, Error = BoxError> + Send + Clone + 'static,
    ZN::Future: Send,
    ZV: Service<tx::Request, Response = tx::Response, Error = BoxError> + Send + Clone + 'static,
    ZV::Future: Send,
    ZS: Service<zs::Request, Response = zs::Response, Error = BoxError> + Send + Clone + 'static,
    ZS::Future: Send,
{
    /// Network service used to download transactions by ID (cloned into each task).
    network: ZN,

    /// Transaction verifier used for mempool verification requests (cloned into each task).
    verifier: ZV,

    /// State service used to check the chain tip and whether a transaction is already in
    /// the state (cloned into each task).
    state: ZS,

    /// Join handles for the spawned download and verify tasks.
    ///
    /// The outer `Result` is the task's timeout. The inner `Result` is either:
    /// - the verified transaction, the mempool outpoints it spends, the tip height, and the
    ///   caller's optional response channel; or
    /// - the error, paired with the transaction ID.
    #[pin]
    pending: FuturesUnordered<
        JoinHandle<
            Result<
                Result<
                    (
                        VerifiedUnminedTx,
                        Vec<transparent::OutPoint>,
                        Option<Height>,
                        Option<oneshot::Sender<Result<(), BoxError>>>,
                    ),
                    (TransactionDownloadVerifyError, UnminedTxId),
                >,
                tokio::time::error::Elapsed,
            >,
        >,
    >,

    /// For each queued transaction ID: the sender that cancels its task, and the original
    /// gossip (ID or full transaction) it was queued from.
    cancel_handles: HashMap<UnminedTxId, (oneshot::Sender<CancelDownloadAndVerify>, Gossip)>,
}
181
182impl<ZN, ZV, ZS> Stream for Downloads<ZN, ZV, ZS>
183where
184    ZN: Service<zn::Request, Response = zn::Response, Error = BoxError> + Send + Clone + 'static,
185    ZN::Future: Send,
186    ZV: Service<tx::Request, Response = tx::Response, Error = BoxError> + Send + Clone + 'static,
187    ZV::Future: Send,
188    ZS: Service<zs::Request, Response = zs::Response, Error = BoxError> + Send + Clone + 'static,
189    ZS::Future: Send,
190{
191    type Item = Result<
192        Result<
193            (
194                VerifiedUnminedTx,
195                Vec<transparent::OutPoint>,
196                Option<Height>,
197                Option<oneshot::Sender<Result<(), BoxError>>>,
198            ),
199            (UnminedTxId, TransactionDownloadVerifyError),
200        >,
201        tokio::time::error::Elapsed,
202    >;
203
204    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
205        let this = self.project();
206        let item = if let Some(join_result) = ready!(this.pending.poll_next(cx)) {
216            let result = join_result.expect("transaction download and verify tasks must not panic");
217            let result = match result {
218                Ok(Ok((tx, spent_mempool_outpoints, tip_height, rsp_tx))) => {
219                    this.cancel_handles.remove(&tx.transaction.id);
220                    Ok(Ok((tx, spent_mempool_outpoints, tip_height, rsp_tx)))
221                }
222                Ok(Err((e, hash))) => {
223                    this.cancel_handles.remove(&hash);
224                    Ok(Err((hash, e)))
225                }
226                Err(elapsed) => Err(elapsed),
227            };
228
229            Some(result)
230        } else {
231            None
232        };
233
234        Poll::Ready(item)
235    }
236
237    fn size_hint(&self) -> (usize, Option<usize>) {
238        self.pending.size_hint()
239    }
240}
241
impl<ZN, ZV, ZS> Downloads<ZN, ZV, ZS>
where
    ZN: Service<zn::Request, Response = zn::Response, Error = BoxError> + Send + Clone + 'static,
    ZN::Future: Send,
    ZV: Service<tx::Request, Response = tx::Response, Error = BoxError> + Send + Clone + 'static,
    ZV::Future: Send,
    ZS: Service<zs::Request, Response = zs::Response, Error = BoxError> + Send + Clone + 'static,
    ZS::Future: Send,
{
    /// Creates an empty download and verify queue.
    ///
    /// - `network` downloads transactions that were gossiped by ID.
    /// - `verifier` verifies transactions for the mempool.
    /// - `state` provides the chain tip and reports whether a transaction is already known.
    ///
    /// Each queued task uses its own clone of these services.
    pub fn new(network: ZN, verifier: ZV, state: ZS) -> Self {
        Self {
            network,
            verifier,
            state,
            pending: FuturesUnordered::new(),
            cancel_handles: HashMap::new(),
        }
    }

    /// Queues a task that downloads `gossiped_tx` if needed, then verifies it for the
    /// mempool.
    ///
    /// The spawned task:
    /// 1. asks the state whether it already has the transaction, then asks for the tip;
    /// 2. downloads the transaction from the network if only its ID was gossiped;
    /// 3. verifies it at the height after the tip (height 0 if the state has no tip).
    ///
    /// The task's result is yielded by this type's `Stream` implementation. The task ends
    /// early if `cancel`/`cancel_all` signals it, or after `RATE_LIMIT_DELAY` elapses.
    ///
    /// If `rsp_tx` is provided, it is sent an error if the task is cancelled, times out, or
    /// fails. On success it is not sent to here; it is returned in the stream item instead.
    ///
    /// # Errors
    ///
    /// - `MempoolError::AlreadyQueued` if this transaction ID already has a cancel handle.
    /// - `MempoolError::FullQueue` if `MAX_INBOUND_CONCURRENCY` tasks are already pending.
    ///
    /// In both cases no task is spawned, and `rsp_tx` is dropped without being sent to.
    ///
    /// # Panics
    ///
    /// Panics inside the spawned task (for example `unreachable!` on an unexpected service
    /// response) resurface in `poll_next`, which expects tasks not to panic.
    #[instrument(skip(self, gossiped_tx), fields(txid = %gossiped_tx.id()))]
    // NOTE(review): presumably silences the lint for the `expect`/`panic!` calls inside the
    // spawned future below, which aren't part of this function's own error path; confirm.
    #[allow(clippy::unwrap_in_result)]
    pub fn download_if_needed_and_verify(
        &mut self,
        gossiped_tx: Gossip,
        mut rsp_tx: Option<oneshot::Sender<Result<(), BoxError>>>,
    ) -> Result<(), MempoolError> {
        let txid = gossiped_tx.id();

        // Reject transaction IDs that already have a cancel handle.
        if self.cancel_handles.contains_key(&txid) {
            debug!(
                ?txid,
                queue_len = self.pending.len(),
                ?MAX_INBOUND_CONCURRENCY,
                "transaction id already queued for inbound download: ignored transaction"
            );
            metrics::gauge!("mempool.currently.queued.transactions",)
                .set(self.pending.len() as f64);

            return Err(MempoolError::AlreadyQueued);
        }

        // Limit concurrency. `pending.len()` also counts finished tasks whose results
        // haven't been yielded by the stream yet.
        if self.pending.len() >= MAX_INBOUND_CONCURRENCY {
            debug!(
                ?txid,
                queue_len = self.pending.len(),
                ?MAX_INBOUND_CONCURRENCY,
                "too many transactions queued for inbound download: ignored transaction"
            );
            metrics::gauge!("mempool.currently.queued.transactions",)
                .set(self.pending.len() as f64);

            return Err(MempoolError::FullQueue);
        }

        // The sender is stored in `cancel_handles` below so `cancel`/`cancel_all` can stop
        // the task; the receiver moves into the task.
        let (cancel_tx, mut cancel_rx) = oneshot::channel::<CancelDownloadAndVerify>();

        // Each task gets its own clones of the services.
        let network = self.network.clone();
        let verifier = self.verifier.clone();
        let mut state = self.state.clone();

        // Stored with the cancel handle, so `transaction_requests` can return it.
        let gossiped_tx_req = gossiped_tx.clone();

        let fut = async move {
            // Don't download or verify a transaction the state already has.
            Self::transaction_in_best_chain(&mut state, txid).await?;

            trace!(?txid, "transaction is not in best chain");

            // Verify at the height after the current tip, or at height 0 with no tip.
            let (tip_height, next_height) = match state.oneshot(zs::Request::Tip).await {
                Ok(zs::Response::Tip(None)) => Ok((None, Height(0))),
                Ok(zs::Response::Tip(Some((height, _hash)))) => {
                    let next_height =
                        (height + 1).expect("valid heights are far below the maximum");
                    Ok((Some(height), next_height))
                }
                Ok(_) => unreachable!("wrong response"),
                Err(e) => Err(TransactionDownloadVerifyError::StateError(e.into())),
            }?;

            trace!(?txid, ?next_height, "got next height");

            // Download the transaction if only its ID was gossiped. Transactions gossiped in
            // full are used directly and have no advertiser address.
            let (tx, advertiser_addr) = match gossiped_tx {
                Gossip::Id(txid) => {
                    let req = zn::Request::TransactionsById(std::iter::once(txid).collect());

                    let tx = match network
                        .oneshot(req)
                        .await
                        .map_err(CloneError::from)
                        .map_err(TransactionDownloadVerifyError::DownloadFailed)?
                    {
                        zn::Response::Transactions(mut txs) => txs.pop().ok_or_else(|| {
                            TransactionDownloadVerifyError::DownloadFailed(
                                BoxError::from("no transactions returned").into(),
                            )
                        })?,
                        _ => unreachable!("wrong response to transaction request"),
                    };

                    let (tx, advertiser_addr) = tx.available().expect(
                        "unexpected missing tx status: single tx failures should be errors",
                    );

                    metrics::counter!(
                        "mempool.downloaded.transactions.total",
                        "version" => format!("{}",tx.transaction.version()),
                    ).increment(1);
                    (tx, advertiser_addr)
                }
                Gossip::Tx(tx) => {
                    metrics::counter!(
                        "mempool.pushed.transactions.total",
                        "version" => format!("{}",tx.transaction.version()),
                    ).increment(1);
                    (tx, None)
                }
            };

            trace!(?txid, "got tx");

            // NOTE(review): `tx` isn't used after this request, so this clone looks
            // unnecessary.
            let result = verifier
                .oneshot(tx::Request::Mempool {
                    transaction: tx.clone(),
                    height: next_height,
                })
                .map_ok(|rsp| {
                    let tx::Response::Mempool { transaction, spent_mempool_outpoints } = rsp else {
                        panic!("unexpected non-mempool response to mempool request")
                    };

                    (transaction, spent_mempool_outpoints, tip_height)
                })
                .await;

            trace!(?txid, result = ?result.as_ref().map(|_tx| ()), "verified transaction for the mempool");

            result.map_err(|e| TransactionDownloadVerifyError::Invalid { error: e.into(), advertiser_addr } )
        }
        // Count verified transactions by version.
        .map_ok(|(tx, spent_mempool_outpoints, tip_height)| {
            metrics::counter!(
                "mempool.verified.transactions.total",
                "version" => format!("{}", tx.transaction.transaction.version()),
            ).increment(1);
            (tx, spent_mempool_outpoints, tip_height)
        })
        // Pair errors with the transaction ID, so `poll_next` knows which handle to remove.
        .map_err(move |e| (e, txid))
        .inspect(move |result| {
            let result = result.as_ref().map(|_tx| txid);
            debug!("mempool transaction result: {result:?}");
        })
        // Run the future inside the current tracing span (created by `#[instrument]`).
        .in_current_span();

        let task = tokio::spawn(async move {
            // Bound the whole download and verify future. On expiry the task outputs
            // `Err(Elapsed)`, which doesn't carry the transaction ID.
            let fut = tokio::time::timeout(RATE_LIMIT_DELAY, fut);

            let result = tokio::select! {
                // Poll the branches in order, so a pending cancellation wins over a
                // simultaneously-ready result.
                biased;
                // `_` also matches a closed channel: if the cancel sender is dropped without
                // sending, the task is treated as cancelled too.
                _ = &mut cancel_rx => {
                    trace!("task cancelled prior to completion");
                    metrics::counter!("mempool.cancelled.verify.tasks.total").increment(1);
                    if let Some(rsp_tx) = rsp_tx.take() {
                        let _ = rsp_tx.send(Err("verification cancelled".into()));
                    }

                    Ok(Err((TransactionDownloadVerifyError::Cancelled, txid)))
                }
                verification = fut => {
                    // On timeout or failure, report the error to the caller (if any). On
                    // success, hand `rsp_tx` back unused in the task output.
                    verification
                        .inspect_err(|_elapsed| {
                            if let Some(rsp_tx) = rsp_tx.take() {
                                let _ = rsp_tx.send(Err("timeout waiting for verification result".into()));
                            }
                        })
                        .map(|inner_result| {
                            match inner_result {
                                Ok((transaction, spent_mempool_outpoints, tip_height)) => Ok((transaction, spent_mempool_outpoints, tip_height, rsp_tx)),
                                Err((tx_verifier_error, tx_id)) => {
                                    if let Some(rsp_tx) = rsp_tx.take() {
                                        let error_msg = format!(
                                            "failed to validate tx: {tx_id}, error: {tx_verifier_error}"
                                        );
                                        let _ = rsp_tx.send(Err(error_msg.into()));
                                    };

                                    Err((tx_verifier_error, tx_id))
                                }
                            }
                        })
                },
            };

            result
        });

        self.pending.push(task);
        // Can't fail: the `contains_key` check above returns early for queued IDs, and
        // `&mut self` prevents any change to the map in between.
        assert!(
            self.cancel_handles
                .insert(txid, (cancel_tx, gossiped_tx_req))
                .is_none(),
            "transactions are only queued once"
        );

        debug!(
            ?txid,
            queue_len = self.pending.len(),
            ?MAX_INBOUND_CONCURRENCY,
            "queued transaction hash for download"
        );
        metrics::gauge!("mempool.currently.queued.transactions",).set(self.pending.len() as f64);
        metrics::counter!("mempool.queued.transactions.total").increment(1);

        Ok(())
    }

    /// Cancels the queued tasks of transactions whose mined IDs are in `mined_ids`.
    ///
    /// Their cancel handles are removed immediately. A cancelled task still yields a
    /// `Cancelled` error through the stream, unless it had already finished, in which case
    /// its own result is yielded.
    pub fn cancel(&mut self, mined_ids: &HashSet<transaction::Hash>) {
        // Collect first: the map can't be modified while its keys are being iterated.
        let removed_txids: Vec<UnminedTxId> = self
            .cancel_handles
            .keys()
            .filter(|txid| mined_ids.contains(&txid.mined_id()))
            .cloned()
            .collect();

        for txid in removed_txids {
            if let Some(handle) = self.cancel_handles.remove(&txid) {
                // Ignore send errors: the task may have finished and dropped its receiver.
                let _ = handle.0.send(CancelDownloadAndVerify);
            }
        }
    }

    /// Cancels every queued task and empties the queue.
    ///
    /// The join handles are dropped, so results of these tasks are never yielded by the
    /// stream.
    pub fn cancel_all(&mut self) {
        // Dropping the join handles detaches the tasks; the cancel signals below stop them.
        let _ = std::mem::take(&mut self.pending);
        for (_hash, cancel) in self.cancel_handles.drain() {
            // Ignore send errors: the task may have already finished.
            let _ = cancel.0.send(CancelDownloadAndVerify);
        }
        assert!(self.pending.is_empty());
        assert!(self.cancel_handles.is_empty());
        metrics::gauge!("mempool.currently.queued.transactions",).set(self.pending.len() as f64);
    }

    /// Returns the number of pending tasks, including finished tasks whose results haven't
    /// been yielded by the stream yet.
    // NOTE(review): the reason for this `allow` isn't visible here (presumably only some
    // builds or tests call this); confirm.
    #[allow(dead_code)]
    pub fn in_flight(&self) -> usize {
        self.pending.len()
    }

    /// Returns the original gossip (ID or full transaction) of each transaction that has a
    /// cancel handle, in arbitrary order.
    pub fn transaction_requests(&self) -> impl Iterator<Item = &Gossip> {
        self.cancel_handles.iter().map(|(_tx_id, (_handle, tx))| tx)
    }

    /// Checks whether the state already has a transaction with `txid`'s mined ID.
    ///
    /// Returns `Ok(())` if it doesn't, `InState` if it does, and `StateError` if the state
    /// service fails to become ready or fails the request.
    ///
    /// # Panics
    ///
    /// If the state returns a response other than `Transaction`.
    async fn transaction_in_best_chain(
        state: &mut ZS,
        txid: UnminedTxId,
    ) -> Result<(), TransactionDownloadVerifyError> {
        match state
            .ready()
            .await
            .map_err(CloneError::from)
            .map_err(TransactionDownloadVerifyError::StateError)?
            .call(zs::Request::Transaction(txid.mined_id()))
            .await
        {
            Ok(zs::Response::Transaction(None)) => Ok(()),
            Ok(zs::Response::Transaction(Some(_))) => Err(TransactionDownloadVerifyError::InState),
            Ok(_) => unreachable!("wrong response"),
            Err(e) => Err(TransactionDownloadVerifyError::StateError(e.into())),
        }?;

        Ok(())
    }
}
542
543#[pinned_drop]
544impl<ZN, ZV, ZS> PinnedDrop for Downloads<ZN, ZV, ZS>
545where
546    ZN: Service<zn::Request, Response = zn::Response, Error = BoxError> + Send + Clone + 'static,
547    ZN::Future: Send,
548    ZV: Service<tx::Request, Response = tx::Response, Error = BoxError> + Send + Clone + 'static,
549    ZV::Future: Send,
550    ZS: Service<zs::Request, Response = zs::Response, Error = BoxError> + Send + Clone + 'static,
551    ZS::Future: Send,
552{
553    fn drop(mut self: Pin<&mut Self>) {
554        self.cancel_all();
555
556        metrics::gauge!("mempool.currently.queued.transactions").set(0 as f64);
557    }
558}