diff --git a/zcash_client_backend/src/scanning.rs b/zcash_client_backend/src/scanning.rs index 70b6fbf023..3554587c36 100644 --- a/zcash_client_backend/src/scanning.rs +++ b/zcash_client_backend/src/scanning.rs @@ -377,7 +377,7 @@ pub(crate) fn scan_block_with_runner< sapling_keys: &[(&AccountId, SK)], sapling_nullifiers: &[(AccountId, sapling::Nullifier)], prior_block_metadata: Option<&BlockMetadata>, - mut batch_runner: Option<&mut TaggedBatchRunner>, + mut sapling_batch_runner: Option<&mut TaggedBatchRunner>, ) -> Result, ScanError> { if let Some(scan_error) = check_hash_continuity(&block, prior_block_metadata) { return Err(scan_error); @@ -489,42 +489,21 @@ pub(crate) fn scan_block_with_runner< let tx_index = u16::try_from(tx.index).expect("Cannot fit more than 2^16 transactions in a block"); - // Check for spent notes. The comparison against known-unspent nullifiers is done - // in constant time. - // TODO: However, this is O(|nullifiers| * |notes|); does using - // constant-time operations here really make sense? - let mut shielded_spends = vec![]; - let mut sapling_unlinked_nullifiers = Vec::with_capacity(tx.spends.len()); - for (index, spend) in tx.spends.into_iter().enumerate() { - let spend_nf = spend - .nf() - .expect("Could not deserialize nullifier for spend from protobuf representation."); - - // Find the first tracked nullifier that matches this spend, and produce - // a WalletShieldedSpend if there is a match, in constant time. - let spend = sapling_nullifiers - .iter() - .map(|&(account, nf)| CtOption::new(account, nf.ct_eq(&spend_nf))) - .fold(CtOption::new(AccountId::ZERO, 0.into()), |first, next| { - CtOption::conditional_select(&next, &first, first.is_some()) - }) - .map(|account| WalletSaplingSpend::from_parts(index, spend_nf, account)); - - if spend.is_some().into() { - shielded_spends.push(spend.unwrap()); - } else { - // This nullifier didn't match any we are currently tracking; save it in - // case it matches an earlier block range we haven't scanned yet. - sapling_unlinked_nullifiers.push(spend_nf); - } - } + let (sapling_spends, sapling_unlinked_nullifiers) = check_nullifiers( + &tx.spends, + sapling_nullifiers, + |spend| { + spend.nf().expect( + "Could not deserialize nullifier for spend from protobuf representation.", + ) + }, + WalletSaplingSpend::from_parts, + ); sapling_nullifier_map.push((txid, tx_index, sapling_unlinked_nullifiers)); // Collect the set of accounts that were spent from in this transaction - let spent_from_accounts: HashSet<_> = shielded_spends - .iter() - .map(|spend| spend.account()) - .collect(); + let spent_from_accounts: HashSet<_> = + sapling_spends.iter().map(|spend| spend.account()).collect(); // We keep track of the number of outputs and actions here because tx.outputs // and tx.actions end up being moved. @@ -549,7 +528,7 @@ pub(crate) fn scan_block_with_runner< }) .collect::>(); - let decrypted: Vec<_> = if let Some(runner) = batch_runner.as_mut() { + let decrypted: Vec<_> = if let Some(runner) = sapling_batch_runner.as_mut() { let sapling_keys = sapling_keys .iter() .flat_map(|(a, k)| { @@ -643,11 +622,11 @@ pub(crate) fn scan_block_with_runner< } } - if !(shielded_spends.is_empty() && shielded_outputs.is_empty()) { + if !(sapling_spends.is_empty() && shielded_outputs.is_empty()) { wtxs.push(WalletTx { txid, index: tx_index as usize, - sapling_spends: shielded_spends, + sapling_spends, sapling_outputs: shielded_outputs, }); } @@ -699,6 +678,48 @@ pub(crate) fn scan_block_with_runner< )) } +fn check_nullifiers< + Spend, + Nf: ConstantTimeEq + Copy, + WS, + FS: Fn(&Spend) -> Nf, + FWS: Fn(usize, Nf, AccountId) -> WS, +>( + spends: &[Spend], + nullifiers: &[(AccountId, Nf)], + extract_nf: FS, + construct_wallet_spend: FWS, +) -> (Vec, Vec) { + // Check for spent notes. The comparison against known-unspent nullifiers is done + // in constant time. + // TODO: However, this is O(|nullifiers| * |notes|); does using + // constant-time operations here really make sense? + let mut shielded_spends = vec![]; + let mut unlinked_nullifiers = Vec::with_capacity(spends.len()); + for (index, spend) in spends.iter().enumerate() { + let spend_nf = extract_nf(spend); + + // Find the first tracked nullifier that matches this spend, and produce + // a WalletShieldedSpend if there is a match, in constant time. + let spend = nullifiers + .iter() + .map(|&(account, nf)| CtOption::new(account, nf.ct_eq(&spend_nf))) + .fold(CtOption::new(AccountId::ZERO, 0.into()), |first, next| { + CtOption::conditional_select(&next, &first, first.is_some()) + }) + .map(|account| construct_wallet_spend(index, spend_nf, account)); + + if spend.is_some().into() { + shielded_spends.push(spend.unwrap()); + } else { + // This nullifier didn't match any we are currently tracking; save it in + // case it matches an earlier block range we haven't scanned yet. + unlinked_nullifiers.push(spend_nf); + } + } + (shielded_spends, unlinked_nullifiers) +} + #[cfg(test)] mod tests { use group::{