From a19102387291e5e393cf8b035dc554078668e139 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 25 Nov 2024 20:10:30 -0500 Subject: [PATCH] Restore transcript consistency --- spartan_parallel/src/dense_mlpoly.rs | 46 +++++++++++++++++--- spartan_parallel/src/lib.rs | 7 ++++ spartan_parallel/src/nizk/bullet.rs | 1 + spartan_parallel/src/nizk/mod.rs | 60 +++++++++++++++++++++++---- spartan_parallel/src/product_tree.rs | 5 +-- spartan_parallel/src/r1csproof.rs | 7 ++++ spartan_parallel/src/sparse_mlpoly.rs | 19 +++------ spartan_parallel/src/sumcheck.rs | 48 ++++++++++++++++++++- 8 files changed, 162 insertions(+), 31 deletions(-) diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index a94bf6d7..fe5ad3f3 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -370,19 +370,34 @@ impl PolyEvalProof { pub fn verify( &self, - _transcript: &mut Transcript, - _r: &[S], // point at which the polynomial is evaluated + transcript: &mut Transcript, + r: &[S], // point at which the polynomial is evaluated ) -> Result<(), ProofVerifyError> { + >::append_protocol_name( + transcript, + PolyEvalProof::::protocol_name(), + ); + + // compute L and R + let eq = EqPolynomial::new(r.to_vec()); + let (L, R) = eq.compute_factored_evals(); + + let _ = self + .proof + .verify(R.len(), transcript, &R); + // TODO: Alternative PCS Verification Ok(()) } pub fn verify_plain( &self, - _transcript: &mut Transcript, - _r: &[S], // point at which the polynomial is evaluated + transcript: &mut Transcript, + r: &[S], // point at which the polynomial is evaluated _Zr: &S, // evaluation \widetilde{Z}(r) ) -> Result<(), ProofVerifyError> { + self.verify(transcript, r); + // TODO: Alternative PCS Verification Ok(()) } @@ -758,6 +773,7 @@ impl PolyEvalProof { } let mut proof_list = Vec::new(); + for i in 0..LZ_list.len() { let L = &L_list[i]; let L_size = L.len(); @@ -781,8 +797,10 @@ impl PolyEvalProof { &Zc_list[i], blind_Zr, ); + proof_list.push(PolyEvalProof { proof }); } + proof_list } @@ -801,6 +819,7 @@ impl PolyEvalProof { // We need one proof per poly size let mut index_map: HashMap<(usize, usize), usize> = HashMap::new(); + let mut LZ_list: Vec = Vec::new(); let mut L_list = Vec::new(); let mut R_list = Vec::new(); @@ -815,7 +834,11 @@ impl PolyEvalProof { if let Some(index) = index_map.get(&(num_proofs, num_inputs)) { c = c * c_base; let _L = &L_list[*index]; + + let LZ = S::field_zero(); + LZ_list[*index] = LZ_list[*index] + c * LZ; } else { + index_map.insert((num_proofs, num_inputs), LZ_list.len()); let num_vars_q = num_proofs.log_2(); let num_vars_y = num_inputs.log_2(); // pad or trim rq and ry to correct length @@ -837,11 +860,24 @@ impl PolyEvalProof { eq.compute_factored_evals() }; // compute a weighted sum of commitments and L + let LZ = S::field_zero(); L_list.push(L); - R_list.push(R); + R_list.push(R); + LZ_list.push(LZ); } } + assert_eq!(LZ_list.len(), proof_list.len()); + + // Verify proofs + for i in 0..LZ_list.len() { + let R = &R_list[i]; + + proof_list[i] + .proof + .verify(R.len(), transcript, R)?; + } + Ok(()) } diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 8658077f..21850a32 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -1846,6 +1846,7 @@ impl SNARK { &block_w3_prover, &block_w3_shifted_prover, ]; + let (block_r1cs_sat_proof, block_challenges) = { let (proof, block_challenges) = { R1CSProof::prove( @@ -1867,6 +1868,7 @@ impl SNARK { (proof, block_challenges) }; + // Final evaluation on BLOCK let (block_inst_evals_bound_rp, block_inst_evals_list, block_r1cs_eval_proof_list) = { let [rp, _, rx, ry] = block_challenges; @@ -1881,6 +1883,7 @@ impl SNARK { for r in &inst_evals_list { S::append_field_to_transcript(b"ABCr_claim", transcript, *r); } + // Sample random combinations of A, B, C for inst_evals_bound_rp check in the Verifier // The random values are not used by the prover, but need to be appended to the transcript let _: S = transcript.challenge_scalar(b"challenge_c0"); @@ -1901,6 +1904,7 @@ impl SNARK { transcript, &mut random_tape, ); + let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len())); @@ -2864,6 +2868,7 @@ impl SNARK { &block_w3_verifier, &block_w3_shifted_verifier, ]; + let block_challenges = self.block_r1cs_sat_proof.verify( block_num_instances, block_max_num_proofs, @@ -2883,6 +2888,7 @@ impl SNARK { for r in &self.block_inst_evals_list { S::append_field_to_transcript(b"ABCr_claim", transcript, *r); } + // Sample random combinations of A, B, C for inst_evals_bound_rp check let c0: S = transcript.challenge_scalar(b"challenge_c0"); let c1: S = transcript.challenge_scalar(b"challenge_c1"); @@ -2908,6 +2914,7 @@ impl SNARK { transcript, )?; } + // Permute block_inst_evals_list to the correct order for RP evaluation let _ABC_evals: Vec = (0..block_num_instances) .map(|i| ABC_evals[block_index[i]]) diff --git a/spartan_parallel/src/nizk/bullet.rs b/spartan_parallel/src/nizk/bullet.rs index a9d4370e..549c924d 100644 --- a/spartan_parallel/src/nizk/bullet.rs +++ b/spartan_parallel/src/nizk/bullet.rs @@ -56,6 +56,7 @@ impl BulletReductionProof { let (blind_L, blind_R) = blinds_iter.next().unwrap(); let u: S = transcript.challenge_scalar(b"u"); + let u_inv = u.invert().unwrap(); for i in 0..n { diff --git a/spartan_parallel/src/nizk/mod.rs b/spartan_parallel/src/nizk/mod.rs index 4bb6a96b..25728835 100644 --- a/spartan_parallel/src/nizk/mod.rs +++ b/spartan_parallel/src/nizk/mod.rs @@ -44,7 +44,14 @@ impl KnowledgeProof { KnowledgeProof { z1, z2 } } - pub fn verify(&self, _transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + pub fn verify(&self, transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + >::append_protocol_name( + transcript, + KnowledgeProof::::protocol_name(), + ); + + let c: S = transcript.challenge_scalar(b"c"); + // TODO: Alternative PCS Verification Ok(()) } @@ -81,7 +88,14 @@ impl EqualityProof { EqualityProof { z } } - pub fn verify(&self, _transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + pub fn verify(&self, transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + >::append_protocol_name( + transcript, + EqualityProof::::protocol_name(), + ); + + let c: S = transcript.challenge_scalar(b"c"); + // TODO: Alternative PCS Verification Ok(()) } @@ -136,7 +150,14 @@ impl ProductProof { true } - pub fn verify(&self, _transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + pub fn verify(&self, transcript: &mut Transcript) -> Result<(), ProofVerifyError> { + >::append_protocol_name( + transcript, + ProductProof::::protocol_name(), + ); + + let c: S = transcript.challenge_scalar(b"c"); + // TODO: Alternative PCS Verification Ok(()) } @@ -183,6 +204,7 @@ impl DotProductProof { let _dotproduct_a_d = DotProductProof::compute_dotproduct(a_vec, &d_vec); + S::append_field_vector_to_transcript(b"a", transcript, a_vec); let c: S = transcript.challenge_scalar(b"c"); let z = (0..d_vec.len()) @@ -201,7 +223,8 @@ impl DotProductProof { DotProductProof::::protocol_name(), ); S::append_field_vector_to_transcript(b"a", transcript, a); - let _c: S = transcript.challenge_scalar(b"c"); + let c: S = transcript.challenge_scalar(b"c"); + let _dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, a); // TODO: Alternative PCS Verification @@ -275,10 +298,33 @@ impl DotProductProofLog { pub fn verify( &self, - _n: usize, - _transcript: &mut Transcript, - _a: &[S], + n: usize, + transcript: &mut Transcript, + a: &[S], ) -> Result<(), ProofVerifyError> { + assert_eq!(a.len(), n); + + >::append_protocol_name( + transcript, + DotProductProofLog::::protocol_name(), + ); + + S::append_field_vector_to_transcript(b"a", transcript, a); + + // sample a random base and scale the generator used for + // the output of the inner product + let r: S = transcript.challenge_scalar(b"r"); + + // BulletReductionProof - verification_scalars + let mut m = a.len(); + while m != 1 { + m /= 2; + + let u: S = transcript.challenge_scalar(b"u"); + } + + let c: S = transcript.challenge_scalar(b"c"); + // TODO: Alternative PCS Verification Ok(()) } diff --git a/spartan_parallel/src/product_tree.rs b/spartan_parallel/src/product_tree.rs index c3b2f05a..c12ad115 100644 --- a/spartan_parallel/src/product_tree.rs +++ b/spartan_parallel/src/product_tree.rs @@ -322,6 +322,7 @@ impl ProductCircuitEvalProofBatched { // produce a fresh set of coeffs and a joint claim let coeff_vec = transcript.challenge_vector(b"rand_coeffs_next_layer", claims_to_verify.len()); + let claim = (0..claims_to_verify.len()) .map(|i| claims_to_verify[i] * coeff_vec[i]) .sum(); @@ -407,7 +408,7 @@ impl ProductCircuitEvalProofBatched { .map(|i| claims_to_verify[i] * coeff_vec[i]) .sum(); - let (_claim_last, rand_prod) = self.proof[i].verify(claim, num_rounds, 3, transcript); + let (claim_last, rand_prod) = self.proof[i].verify(claim, num_rounds, 3, transcript); let claims_prod_left = &self.proof[i].claims_prod_left; let claims_prod_right = &self.proof[i].claims_prod_right; @@ -446,9 +447,7 @@ impl ProductCircuitEvalProofBatched { } } - /* TODO: IMPORTANT, DEBUG, CHECK FAIL assert_eq!(claim_expected, claim_last); - */ // produce a random challenge let r_layer = transcript.challenge_scalar(b"challenge_r_layer"); diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index fa787af0..0e794b0b 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -251,6 +251,7 @@ impl R1CSProof { transcript, random_tape, ); + assert_eq!(poly_tau_p.len(), 1); assert_eq!(poly_tau_q.len(), 1); assert_eq!(poly_tau_x.len(), 1); @@ -464,6 +465,7 @@ impl R1CSProof { } } } + let proof_eval_vars_at_ry_list = PolyEvalProof::prove_batched_instances_disjoint_rounds( &poly_list, &num_proofs_list, @@ -752,6 +754,11 @@ impl R1CSProof { timer_commit_opening.stop(); + // verify proof that expected_claim_post_phase2 == claim_post_phase2 + self.proof_eq_sc_phase2.verify( + transcript, + )?; + Ok([rp, rq_rev, rx, [rw, ry].concat()]) } } diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 89d15eb0..140c830b 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -831,19 +831,17 @@ impl HashLayerProof { let eval_init_addr = IdentityPolynomial::new(rand_mem.len()).evaluate(rand_mem); let eval_init_val = EqPolynomial::new(r.to_vec()).evaluate(rand_mem); let hash_init_at_rand_mem = - hash_func(&eval_init_addr, &eval_init_val, &S::field_zero()) - *r_multiset_check; // verify the claim_last of init chunk - /* TODO: IMPORTANT, DEBUG, CHECK FAIL - assert_eq!(&hash_init_at_rand_mem, claim_init); - */ + hash_func(&eval_init_addr, &eval_init_val, &S::field_zero()) - *r_multiset_check; + + // verify the claim_last of init chunk + assert_eq!(&hash_init_at_rand_mem, claim_init); // read for i in 0..eval_ops_addr.len() { let hash_read_at_rand_ops = hash_func(&eval_ops_addr[i], &eval_ops_val[i], &eval_read_ts[i]) - *r_multiset_check; // verify the claim_last of init chunk - /* TODO: IMPORTANT, DEBUG, CHECK FAIL assert_eq!(&hash_read_at_rand_ops, &claim_read[i]); - */ } // write: shares addr, val component; only decommit write_ts @@ -852,9 +850,7 @@ impl HashLayerProof { let hash_write_at_rand_ops = hash_func(&eval_ops_addr[i], &eval_ops_val[i], &eval_write_ts) - *r_multiset_check; // verify the claim_last of init chunk - /* TODO: IMPORTANT, DEBUG, CHECK FAIL assert_eq!(&hash_write_at_rand_ops, &claim_write[i]); - */ } // audit: shares addr and val with init @@ -862,9 +858,7 @@ impl HashLayerProof { let eval_audit_val = eval_init_val; let hash_audit_at_rand_mem = hash_func(&eval_audit_addr, &eval_audit_val, eval_audit_ts) - *r_multiset_check; - /* TODO: IMPORTANT, DEBUG, CHECK FAIL assert_eq!(&hash_audit_at_rand_mem, claim_audit); // verify the last step of the sum-check for audit - */ Ok(()) } @@ -905,11 +899,9 @@ impl HashLayerProof { let claim_col_ops_val = claims_dotp[3 * i + 1]; let claim_val = claims_dotp[3 * i + 2]; - /* TODO: IMPORTANT, DEBUG, CHECK FAIL assert_eq!(claim_row_ops_val, eval_row_ops_val[i]); assert_eq!(claim_col_ops_val, eval_col_ops_val[i]); - assert_eq!(claim_val, eval_val_vec[i]);\ - */ + assert_eq!(claim_val, eval_val_vec[i]); } // verify addr-timestamps using comm_comb_ops at rand_ops @@ -1170,7 +1162,6 @@ impl ProductLayerProof { transcript, ProductLayerProof::::protocol_name(), ); - let timer = Timer::new("verify_prod_proof"); let num_instances = eval.len(); diff --git a/spartan_parallel/src/sumcheck.rs b/spartan_parallel/src/sumcheck.rs index d72b3f01..68c97aba 100644 --- a/spartan_parallel/src/sumcheck.rs +++ b/spartan_parallel/src/sumcheck.rs @@ -57,6 +57,9 @@ impl SumcheckInstanceProof { // derive the verifier's challenge for the next round let r_i = transcript.challenge_scalar(b"challenge_nextround"); + // scalar_debug + // println!("=> SumcheckInstanceProof-verify, challenge round {:?} - {:?}", i, r_i); + r.push(r_i); // evaluate the claimed degree-ell polynomial at r_i @@ -80,14 +83,51 @@ impl ZKSumcheckInstanceProof { pub fn verify( &self, num_rounds: usize, - _degree_bound: usize, + degree_bound: usize, transcript: &mut Transcript, ) -> Result, ProofVerifyError> { let mut r: Vec = Vec::new(); - for _i in 0..num_rounds { + for i in 0..num_rounds { // derive the verifier's challenge for the next round let r_i = transcript.challenge_scalar(b"challenge_nextround"); + + // verify the proof of sum-check and evals + let res = { + // produce two weights + let w: Vec = transcript.challenge_vector(b"combine_two_claims_to_one", 2); + + let a = { + // the vector to use to decommit for sum-check test + let a_sc = { + let mut a = vec![S::field_one(); degree_bound + 1]; + a[0] = a[0] + S::field_one(); + a + }; + + // the vector to use to decommit for evaluation + let a_eval = { + let mut a = vec![S::field_one(); degree_bound + 1]; + for j in 1..a.len() { + a[j] = a[j - 1] * r_i; + } + a + }; + + // take weighted sum of the two vectors using w + assert_eq!(a_sc.len(), a_eval.len()); + (0..a_sc.len()) + .map(|i| w[0] * a_sc[i] + w[1] * a_eval[i]) + .collect::>() + }; + + self.proofs[i] + .verify( + transcript, + &a, + ) + .is_ok() + }; r.push(r_i); } @@ -291,6 +331,10 @@ impl SumcheckInstanceProof { //derive the verifier's challenge for the next round let r_j = transcript.challenge_scalar(b"challenge_nextround"); + + // scalar_debug + // println!("=> prove_cubic_batched, challenge round {:?} - {:?}", _j, r_j); + r.push(r_j); // bound all tables to the verifier's challenege