Skip to content

Commit

Permalink
refactor to replace v1 by v2 and cleanup suffix (#791)
Browse files Browse the repository at this point in the history
Previously to quick verify idea and avoid massive change, there are new
functionality with suffix `_v2`. After experiment with good result, long
time ago all logic already stick to v2 version and no longer use v1.
This PR clean up all leftover v1 version, do renaming and file
replacement without modify existing logic.

In summary
- `sumcheck/src/prover_v2.rs` ->  `sumcheck/src/prover.rs`
- `multilinear_extensions/src/virtual_poly_v2.rs` ->
`multilinear_extensions/src/virtual_poly.rs`
- clean up all `V2` suffix

This addressed previous out-dated PR
#162, and as a preparation for
#788,
#702
  • Loading branch information
hero78119 authored Jan 2, 2025
1 parent f91ec06 commit 0c39f7f
Show file tree
Hide file tree
Showing 24 changed files with 497 additions and 1,646 deletions.
2 changes: 1 addition & 1 deletion ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::SmallField;

use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use multilinear_extensions::virtual_poly::ArcMultilinearExtension;

use crate::{
circuit_builder::CircuitBuilder,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use ff_ext::ExtensionField;
use generic_static::StaticTypeMap;
use goldilocks::{GoldilocksExt2, SmallField};
use itertools::{Itertools, chain, enumerate, izip};
use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension};
use multilinear_extensions::{mle::IntoMLEs, virtual_poly::ArcMultilinearExtension};
use rand::thread_rng;
use std::{
cmp::max,
Expand Down
13 changes: 6 additions & 7 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension},
util::ceil_log2,
virtual_poly::build_eq_x_r_vec,
virtual_poly_v2::ArcMultilinearExtension,
virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec},
};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverStateV2},
structs::{IOPProverMessage, IOPProverState},
};
use transcript::{ForkableTranscript, Transcript};

Expand Down Expand Up @@ -583,7 +582,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}

tracing::debug!("main sel sumcheck start");
let (main_sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys(
let (main_sel_sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
transcript,
Expand Down Expand Up @@ -1029,7 +1028,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha);
}

let (same_r_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys(
let (same_r_sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
transcript,
Expand Down Expand Up @@ -1241,7 +1240,7 @@ impl TowerProver {
layer_polys
.iter()
.all(|f| {
f.evaluations().len() == (1 << (log_num_fanin * round))
f.evaluations().len() == 1 << (log_num_fanin * round)
})
);

Expand Down Expand Up @@ -1287,7 +1286,7 @@ impl TowerProver {
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
// shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys`
// This is likely a bug in the tracing-flame crate.
let (sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys(
let (sumcheck_proofs, state) = IOPProverState::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
transcript,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme};
use multilinear_extensions::{
mle::IntoMLE, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension,
mle::IntoMLE, util::ceil_log2, virtual_poly::ArcMultilinearExtension,
};
use transcript::{BasicTranscript, BasicTranscriptWithStat, StatisticRecorder, Transcript};

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use multilinear_extensions::{
mle::{DenseMultilinearExtension, FieldType, IntoMLE},
op_mle_xa_b, op_mle3_range,
util::ceil_log2,
virtual_poly_v2::ArcMultilinearExtension,
virtual_poly::ArcMultilinearExtension,
};
use rayon::{
iter::{
Expand Down Expand Up @@ -415,7 +415,7 @@ mod tests {
commutative_op_mle_pair,
mle::{FieldType, IntoMLE},
util::ceil_log2,
virtual_poly_v2::ArcMultilinearExtension,
virtual_poly::ArcMultilinearExtension,
};

use crate::{
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
&VPAuxInfo {
// + 1 from sel_non_lc_zero_sumcheck
max_degree: SEL_DEGREE.max(cs.max_non_lc_degree + 1),
num_variables: log2_num_instances,
max_num_variables: log2_num_instances,
phantom: PhantomData,
},
transcript,
Expand Down Expand Up @@ -634,7 +634,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
},
&VPAuxInfo {
max_degree: SEL_DEGREE,
num_variables: expected_max_rounds,
max_num_variables: expected_max_rounds,
phantom: PhantomData,
},
transcript,
Expand Down Expand Up @@ -904,7 +904,7 @@ impl TowerVerify {
},
&VPAuxInfo {
max_degree: NUM_FANIN + 1, // + 1 for eq
num_variables: (round + 1) * log2_num_fanin,
max_num_variables: (round + 1) * log2_num_fanin,
phantom: PhantomData,
},
transcript,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use ff_ext::ExtensionField;
use itertools::{Itertools, chain};
use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::{
mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension,
mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension,
};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/uint/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ mod tests {
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::{
mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension,
mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension,
};

type E = GoldilocksExt2; // 18446744069414584321
Expand Down
19 changes: 9 additions & 10 deletions ceno_zkvm/src/virtual_polys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use ff_ext::ExtensionField;
use itertools::Itertools;
use multilinear_extensions::{
util::ceil_log2,
virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2},
virtual_poly::{ArcMultilinearExtension, VirtualPolynomial},
};

use crate::{expression::Expression, utils::transpose};

pub struct VirtualPolynomials<'a, E: ExtensionField> {
num_threads: usize,
polys: Vec<VirtualPolynomialV2<'a, E>>,
polys: Vec<VirtualPolynomial<'a, E>>,
/// a storage to keep thread based mles, specific to multi-thread logic
thread_based_mles_storage: HashMap<usize, Vec<ArcMultilinearExtension<'a, E>>>,
}
Expand All @@ -26,7 +26,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
VirtualPolynomials {
num_threads,
polys: (0..num_threads)
.map(|_| VirtualPolynomialV2::new(max_num_variables - ceil_log2(num_threads)))
.map(|_| VirtualPolynomial::new(max_num_variables - ceil_log2(num_threads)))
.collect_vec(),
thread_based_mles_storage: HashMap::new(),
}
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
});
}

pub fn get_batched_polys(self) -> Vec<VirtualPolynomialV2<'a, E>> {
pub fn get_batched_polys(self) -> Vec<VirtualPolynomial<'a, E>> {
self.polys
}

Expand Down Expand Up @@ -174,10 +174,9 @@ mod tests {
use itertools::Itertools;
use multilinear_extensions::{
mle::IntoMLE,
virtual_poly::VPAuxInfo,
virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2},
virtual_poly::{ArcMultilinearExtension, VPAuxInfo, VirtualPolynomial},
};
use sumcheck::structs::{IOPProverStateV2, IOPVerifierState};
use sumcheck::structs::{IOPProverState, IOPVerifierState};
use transcript::BasicTranscript as Transcript;

use crate::{
Expand Down Expand Up @@ -284,7 +283,7 @@ mod tests {
virtual_polys.add_mle_list(f2.iter().collect(), E::ONE);
virtual_polys.add_mle_list(f3.iter().collect(), E::ONE);

let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys(
let (sumcheck_proofs, _) = IOPProverState::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
&mut transcript,
Expand All @@ -296,13 +295,13 @@ mod tests {
&sumcheck_proofs,
&VPAuxInfo {
max_degree: 3,
num_variables: max_num_vars,
max_num_variables: max_num_vars,
phantom: std::marker::PhantomData,
},
&mut transcript,
);

let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars);
let mut verifier_poly = VirtualPolynomial::new(max_num_vars);
verifier_poly.add_mle_list(f1.to_vec(), E::ONE);
verifier_poly.add_mle_list(f2.to_vec(), E::ONE);
verifier_poly.add_mle_list(f3.to_vec(), E::ONE);
Expand Down
2 changes: 1 addition & 1 deletion mpcs/benches/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use mpcs::{

use multilinear_extensions::{
mle::{DenseMultilinearExtension, MultilinearExtension},
virtual_poly_v2::ArcMultilinearExtension,
virtual_poly::ArcMultilinearExtension,
};
use transcript::{BasicTranscript, Transcript};

Expand Down
2 changes: 1 addition & 1 deletion mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ mod commit_phase;
use commit_phase::{batch_commit_phase, commit_phase, simple_batch_commit_phase};
mod encoding;
pub use encoding::{coset_fft, fft, fft_root_table};
use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use multilinear_extensions::virtual_poly::ArcMultilinearExtension;

mod query_phase;
// This sumcheck module is different from the mpcs::sumcheck module, in that
Expand Down
4 changes: 2 additions & 2 deletions mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ pub use basefold::{
EncodingScheme, RSCode, RSCodeDefaultSpec, coset_fft, fft, fft_root_table, one_level_eval_hc,
one_level_interp_hc,
};
use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
use multilinear_extensions::virtual_poly::ArcMultilinearExtension;

fn validate_input<E: ExtensionField>(
function: &str,
Expand Down Expand Up @@ -377,7 +377,7 @@ pub mod test_util {
use multilinear_extensions::mle::DenseMultilinearExtension;
#[cfg(test)]
use multilinear_extensions::{
mle::MultilinearExtension, virtual_poly_v2::ArcMultilinearExtension,
mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension,
};
use rand::rngs::OsRng;
#[cfg(test)]
Expand Down
1 change: 0 additions & 1 deletion multilinear_extensions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
pub mod mle;
pub mod util;
pub mod virtual_poly;
pub mod virtual_poly_v2;

#[cfg(test)]
mod test;
25 changes: 0 additions & 25 deletions multilinear_extensions/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,6 @@ fn test_virtual_polynomial_additions() {
}
}

#[test]
fn test_virtual_polynomial_mul_by_mle() {
let mut rng = test_rng();
for nv in 2..5 {
for num_products in 2..5 {
let base: Vec<E> = (0..nv).map(|_| E::random(&mut rng)).collect();

let (a, _a_sum) = VirtualPolynomial::<E>::random(nv, (2, 3), num_products, &mut rng);
let (b, _b_sum) = DenseMultilinearExtension::<E>::random_mle_list(nv, 1, &mut rng);
let b_mle = b[0].clone();
let coeff = Goldilocks::random(&mut rng);
let b_vp = VirtualPolynomial::new_from_mle(b_mle.clone(), coeff);

let mut c = a.clone();

c.mul_by_mle(b_mle, coeff);

assert_eq!(
a.evaluate(base.as_ref()) * b_vp.evaluate(base.as_ref()),
c.evaluate(base.as_ref())
);
}
}
}

#[test]
fn test_eq_xr() {
let mut rng = test_rng();
Expand Down
Loading

0 comments on commit 0c39f7f

Please sign in to comment.