zebra_scan/service/scan_task/
commands.rs

1//! Types and method implementations for [`ScanTaskCommand`]
2
3use std::collections::{HashMap, HashSet};
4
5use color_eyre::{eyre::eyre, Report};
6use tokio::sync::{
7    mpsc::{error::TrySendError, Receiver, Sender},
8    oneshot,
9};
10
11use sapling_crypto::zip32::DiversifiableFullViewingKey;
12use zebra_chain::{block::Height, parameters::Network};
13use zebra_node_services::scan_service::response::ScanResult;
14use zebra_state::SaplingScanningKey;
15
16use crate::scan::sapling_key_to_dfvk;
17
18use super::ScanTask;
19
/// Capacity of the bounded results channel created for each `SubscribeResults` command.
const RESULTS_SENDER_BUFFER_SIZE: usize = 100;
21
22#[derive(Debug)]
23/// Commands that can be sent to [`ScanTask`]
24pub enum ScanTaskCommand {
25    /// Start scanning for new viewing keys
26    RegisterKeys {
27        /// New keys to start scanning for
28        keys: Vec<(String, Option<u32>)>,
29        /// Returns the set of keys the scanner accepted.
30        rsp_tx: oneshot::Sender<Vec<SaplingScanningKey>>,
31    },
32
33    /// Stop scanning for deleted viewing keys
34    RemoveKeys {
35        /// Notify the caller once the key is removed (so the caller can wait before clearing results)
36        done_tx: oneshot::Sender<()>,
37
38        /// Key hashes that are to be removed
39        keys: Vec<String>,
40    },
41
42    /// Start sending results for key hashes to `result_sender`
43    SubscribeResults {
44        /// Key hashes to send the results of to result channel
45        keys: HashSet<String>,
46
47        /// Returns the result receiver once the subscribed keys have been added.
48        rsp_tx: oneshot::Sender<Receiver<ScanResult>>,
49    },
50}
51
52impl ScanTask {
53    /// Accepts the scan task's `parsed_key` collection and a reference to the command channel receiver
54    ///
55    /// Processes messages in the scan task channel, updating `parsed_keys` if required.
56    ///
57    /// Returns newly registered keys for scanning.
58    pub fn process_messages(
59        cmd_receiver: &mut tokio::sync::mpsc::Receiver<ScanTaskCommand>,
60        registered_keys: &mut HashMap<SaplingScanningKey, DiversifiableFullViewingKey>,
61        network: &Network,
62    ) -> Result<
63        (
64            HashMap<SaplingScanningKey, (DiversifiableFullViewingKey, Height)>,
65            HashMap<SaplingScanningKey, Sender<ScanResult>>,
66            Vec<(Receiver<ScanResult>, oneshot::Sender<Receiver<ScanResult>>)>,
67        ),
68        Report,
69    > {
70        use tokio::sync::mpsc::error::TryRecvError;
71
72        let mut new_keys = HashMap::new();
73        let mut new_result_senders = HashMap::new();
74        let mut new_result_receivers = Vec::new();
75        let sapling_activation_height = network.sapling_activation_height();
76
77        loop {
78            let cmd = match cmd_receiver.try_recv() {
79                Ok(cmd) => cmd,
80
81                Err(TryRecvError::Empty) => break,
82                Err(TryRecvError::Disconnected) => {
83                    // Return early if the sender has been dropped.
84                    return Err(eyre!("command channel disconnected"));
85                }
86            };
87
88            match cmd {
89                ScanTaskCommand::RegisterKeys { keys, rsp_tx } => {
90                    // Determine what keys we pass to the scanner.
91                    let keys: Vec<_> = keys
92                        .into_iter()
93                        .filter_map(|key| {
94                            // Don't accept keys that:
95                            // 1. the scanner already has, and
96                            // 2. were already submitted.
97                            if registered_keys.contains_key(&key.0)
98                                && !new_keys.contains_key(&key.0)
99                            {
100                                return None;
101                            }
102
103                            let birth_height = if let Some(height) = key.1 {
104                                match Height::try_from(height) {
105                                    Ok(height) => height,
106                                    // Don't accept the key if its birth height is not a valid height.
107                                    Err(_) => return None,
108                                }
109                            } else {
110                                // Use the Sapling activation height if the key has no birth height.
111                                sapling_activation_height
112                            };
113
114                            sapling_key_to_dfvk(&key.0, network)
115                                .ok()
116                                .map(|parsed| (key.0, (parsed, birth_height)))
117                        })
118                        .collect();
119
120                    // Send the accepted keys back.
121                    let _ = rsp_tx.send(keys.iter().map(|key| key.0.clone()).collect());
122
123                    new_keys.extend(keys.clone());
124
125                    registered_keys.extend(keys.into_iter().map(|(key, (dfvk, _))| (key, dfvk)));
126                }
127
128                ScanTaskCommand::RemoveKeys { done_tx, keys } => {
129                    for key in keys {
130                        registered_keys.remove(&key);
131                        new_keys.remove(&key);
132                    }
133
134                    // Ignore send errors for the done notification, caller is expected to use a timeout.
135                    let _ = done_tx.send(());
136                }
137
138                ScanTaskCommand::SubscribeResults { rsp_tx, keys } => {
139                    let keys: Vec<_> = keys
140                        .into_iter()
141                        .filter(|key| registered_keys.contains_key(key))
142                        .collect();
143
144                    if keys.is_empty() {
145                        continue;
146                    }
147
148                    let (result_sender, result_receiver) =
149                        tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE);
150
151                    new_result_receivers.push((result_receiver, rsp_tx));
152
153                    for key in keys {
154                        new_result_senders.insert(key, result_sender.clone());
155                    }
156                }
157            }
158        }
159
160        Ok((new_keys, new_result_senders, new_result_receivers))
161    }
162
163    /// Sends a command to the scan task
164    pub fn send(
165        &mut self,
166        command: ScanTaskCommand,
167    ) -> Result<(), tokio::sync::mpsc::error::TrySendError<ScanTaskCommand>> {
168        self.cmd_sender.try_send(command)
169    }
170
171    /// Sends a message to the scan task to remove the provided viewing keys.
172    ///
173    /// Returns a oneshot channel receiver to notify the caller when the keys have been removed.
174    pub fn remove_keys(
175        &mut self,
176        keys: Vec<String>,
177    ) -> Result<oneshot::Receiver<()>, TrySendError<ScanTaskCommand>> {
178        let (done_tx, done_rx) = oneshot::channel();
179
180        self.send(ScanTaskCommand::RemoveKeys { keys, done_tx })?;
181
182        Ok(done_rx)
183    }
184
185    /// Sends a message to the scan task to start scanning for the provided viewing keys.
186    pub fn register_keys(
187        &mut self,
188        keys: Vec<(String, Option<u32>)>,
189    ) -> Result<oneshot::Receiver<Vec<String>>, TrySendError<ScanTaskCommand>> {
190        let (rsp_tx, rsp_rx) = oneshot::channel();
191
192        self.send(ScanTaskCommand::RegisterKeys { keys, rsp_tx })?;
193
194        Ok(rsp_rx)
195    }
196
197    /// Sends a message to the scan task to start sending the results for the provided viewing keys to a channel.
198    ///
199    /// Returns the channel receiver.
200    pub fn subscribe(
201        &mut self,
202        keys: HashSet<SaplingScanningKey>,
203    ) -> Result<oneshot::Receiver<Receiver<ScanResult>>, TrySendError<ScanTaskCommand>> {
204        let (rsp_tx, rsp_rx) = oneshot::channel();
205
206        self.send(ScanTaskCommand::SubscribeResults { keys, rsp_tx })?;
207
208        Ok(rsp_rx)
209    }
210}