Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade 5 (Batch Circuit) #1444

Closed
wants to merge 11 commits into from
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions aggregator/src/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,42 @@ pub(crate) use rlc::{RlcConfig, POWS_OF_256};

pub use circuit::BatchCircuit;
pub use config::BatchCircuitConfig;
use halo2_base::halo2_proofs::halo2curves::bn256::{Fr, G1Affine};
use snark_verifier::Protocol;

/// Alias for a list of G1 points.
pub type PreprocessedPolyCommits = Vec<G1Affine>;
/// Alias for the transcript's initial state.
pub type TranscriptInitState = Fr;

/// Alias for the fixed part of the protocol which consists of the commitments to the preprocessed
/// polynomials and the initial state of the transcript.
#[derive(Clone)]
pub struct FixedProtocol {
/// The commitments to the preprocessed polynomials.
pub preprocessed: PreprocessedPolyCommits,
/// The initial state of the transcript.
pub init_state: TranscriptInitState,
}

impl From<Protocol<G1Affine>> for FixedProtocol {
fn from(protocol: Protocol<G1Affine>) -> Self {
Self {
preprocessed: protocol.preprocessed,
init_state: protocol
.transcript_initial_state
.expect("protocol transcript init state None"),
}
}
}

impl From<&Protocol<G1Affine>> for FixedProtocol {
fn from(protocol: &Protocol<G1Affine>) -> Self {
Self {
preprocessed: protocol.preprocessed.clone(),
init_state: protocol
.transcript_initial_state
.expect("protocol transcript init state None"),
}
}
}
156 changes: 138 additions & 18 deletions aggregator/src/aggregation/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use ark_std::{end_timer, start_timer};
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
halo2curves::bn256::{Bn256, Fr, G1Affine},
plonk::{Circuit, ConstraintSystem, Error, Selector},
Expand All @@ -11,14 +12,20 @@ use snark_verifier::{
loader::halo2::{
halo2_ecc::{
ecc::EccChip,
fields::fp::FpConfig,
halo2_base::{AssignedValue, Context, ContextParams},
fields::{fp::FpConfig, FieldChip},
halo2_base::{
gates::{GateInstructions, RangeInstructions},
AssignedValue, Context, ContextParams,
QuantumCell::Existing,
},
},
Halo2Loader,
Halo2Loader, IntegerInstructions,
},
pcs::kzg::{Bdfg21, Kzg, KzgSuccinctVerifyingKey},
};
use snark_verifier_sdk::{aggregate, flatten_accumulator, CircuitExt, Snark, SnarkWitness};
use snark_verifier_sdk::{
aggregate_as_witness, flatten_accumulator, CircuitExt, Snark, SnarkWitness,
};
use std::{env, fs::File, rc::Rc};
use zkevm_circuits::util::Challenges;

Expand All @@ -30,8 +37,8 @@ use crate::{
core::{assign_batch_hashes, extract_proof_and_instances_with_pairing_check},
util::parse_hash_digest_cells,
witgen::{zstd_encode, MultiBlockProcessResult},
ConfigParams, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH, PI_CURRENT_STATE_ROOT,
PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT,
ConfigParams, FixedProtocol, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH,
PI_CURRENT_STATE_ROOT, PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT,
};

/// Batch circuit, the chunk aggregation routine below recursion circuit
Expand All @@ -55,14 +62,21 @@ pub struct BatchCircuit<const N_SNARKS: usize> {
// batch hash circuit for which the snarks are generated
// the chunks in this batch are also padded already
pub batch_hash: BatchHash<N_SNARKS>,

/// The SNARK protocol from the halo2-based inner circuit route.
pub halo2_protocol: FixedProtocol,
/// The SNARK protocol from the sp1-based inner circuit route.
pub sp1_protocol: FixedProtocol,
}

impl<const N_SNARKS: usize> BatchCircuit<N_SNARKS> {
pub fn new(
pub fn new<P: Into<FixedProtocol>>(
params: &ParamsKZG<Bn256>,
snarks_with_padding: &[Snark],
rng: impl Rng + Send,
batch_hash: BatchHash<N_SNARKS>,
halo2_protocol: P,
sp1_protocol: P,
) -> Result<Self, snark_verifier::Error> {
let timer = start_timer!(|| "generate aggregation circuit");

Expand Down Expand Up @@ -120,6 +134,8 @@ impl<const N_SNARKS: usize> BatchCircuit<N_SNARKS> {
flattened_instances,
as_proof: Value::known(as_proof),
batch_hash,
halo2_protocol: halo2_protocol.into(),
sp1_protocol: sp1_protocol.into(),
})
}

Expand Down Expand Up @@ -209,22 +225,21 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
let loader: Rc<Halo2Loader<G1Affine, EccChip<Fr, FpConfig<Fr, Fq>>>> =
Halo2Loader::new(ecc_chip, ctx);

//
// extract the assigned values for
// - instances which are the public inputs of each chunk (prefixed with 12
// instances from previous accumulators)
// - new accumulator
//
log::debug!("aggregation: chunk aggregation");
let (assigned_aggregation_instances, acc) = aggregate::<Kzg<Bn256, Bdfg21>>(
let (
assigned_aggregation_instances,
acc,
preprocessed_poly_sets,
transcript_init_states,
) = aggregate_as_witness::<Kzg<Bn256, Bdfg21>>(
&self.svk,
&loader,
&self.snarks_with_padding,
self.as_proof(),
);
for (i, e) in assigned_aggregation_instances[0].iter().enumerate() {
log::trace!("{}-th instance: {:?}", i, e.value)
}

// extract the following cells for later constraints
// - the accumulators
Expand All @@ -238,13 +253,118 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
.iter()
.flat_map(|instance_column| instance_column.iter().skip(ACC_LEN)),
);
for (i, e) in assigned_aggregation_instances[0].iter().enumerate() {
log::trace!("{}-th instance: {:?}", i, e.value)
}

loader
.ctx_mut()
.print_stats(&["snark aggregation [chunks -> batch]"]);
loader.ctx_mut().print_stats(&["snark aggregation"]);

let mut ctx = Rc::into_inner(loader).unwrap().into_ctx();
log::debug!("batching: assigning barycentric");

// We must ensure that the commitments to preprocessed polynomial and initial
// state of transcripts for every SNARK that is being aggregated belongs to the
// fixed set of values expected.
//
// First we load the constants.
let mut preprocessed_polys_halo2 = Vec::with_capacity(7);
let mut preprocessed_polys_sp1 = Vec::with_capacity(7);
for &preprocessed_poly in self.halo2_protocol.preprocessed.iter() {
preprocessed_polys_halo2.push(
config
.ecc_chip()
.assign_constant_point(&mut ctx, preprocessed_poly),
);
}
for &preprocessed_poly in self.sp1_protocol.preprocessed.iter() {
preprocessed_polys_sp1.push(
config
.ecc_chip()
.assign_constant_point(&mut ctx, preprocessed_poly),
);
}
let transcript_init_state_halo2 = config
.ecc_chip()
.field_chip()
.range()
.gate()
.assign_constant(&mut ctx, self.halo2_protocol.init_state)
.expect("IntegerInstructions::assign_constant infallible");
let transcript_init_state_sp1 = config
.ecc_chip()
.field_chip()
.range()
.gate()
.assign_constant(&mut ctx, self.sp1_protocol.init_state)
.expect("IntegerInstructions::assign_constant infallible");

// Commitments to the preprocessed polynomials.
//
// check_1: halo2-route
// check_2: sp1-route
//
// OR(check_1, check_2) == 1
let mut route_check = Vec::with_capacity(N_SNARKS);
for preprocessed_polys in preprocessed_poly_sets.iter() {
let mut preprocessed_check_1 =
config.flex_gate().load_constant(&mut ctx, Fr::ONE);
let mut preprocessed_check_2 =
config.flex_gate().load_constant(&mut ctx, Fr::ONE);
for ((commitment, comm_halo2), comm_sp1) in preprocessed_polys
.iter()
.zip_eq(preprocessed_polys_halo2.iter())
.zip_eq(preprocessed_polys_sp1.iter())
{
let check_1 =
config.ecc_chip().is_equal(&mut ctx, commitment, comm_halo2);
let check_2 =
config.ecc_chip().is_equal(&mut ctx, commitment, comm_sp1);
preprocessed_check_1 = config.flex_gate().and(
&mut ctx,
Existing(preprocessed_check_1),
Existing(check_1),
);
preprocessed_check_2 = config.flex_gate().and(
&mut ctx,
Existing(preprocessed_check_2),
Existing(check_2),
);
}
route_check.push(preprocessed_check_1);
let preprocessed_check = config.flex_gate().or(
&mut ctx,
Existing(preprocessed_check_1),
Existing(preprocessed_check_2),
);
config
.flex_gate()
.assert_is_const(&mut ctx, &preprocessed_check, Fr::ONE);
}

// Transcript initial state.
//
// If the SNARK belongs to halo2-route, the initial state is the halo2-initial
// state. Otherwise sp1-initial state.
for (transcript_init_state, &route) in
transcript_init_states.iter().zip_eq(route_check.iter())
{
let transcript_init_state = transcript_init_state
.expect("SNARK should have an initial state for transcript");
let init_state_expected = config.flex_gate().select(
&mut ctx,
Existing(transcript_init_state_halo2),
Existing(transcript_init_state_sp1),
Existing(route),
);
GateInstructions::assert_equal(
config.flex_gate(),
&mut ctx,
Existing(transcript_init_state),
Existing(init_state_expected),
);
}

ctx.print_stats(&["protocol check"]);

let barycentric = config.blob_consistency_config.assign_barycentric(
&mut ctx,
&self.batch_hash.blob_bytes,
Expand Down
7 changes: 7 additions & 0 deletions aggregator/src/tests/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ fn build_new_batch_circuit<const N_SNARKS: usize>(
})
.collect_vec()
};
let snark_protocol = real_snarks[0].protocol.clone();

// ==========================
// padded chunks
Expand All @@ -225,6 +226,8 @@ fn build_new_batch_circuit<const N_SNARKS: usize>(
[real_snarks, padded_snarks].concat().as_ref(),
rng,
batch_hash,
&snark_protocol,
&snark_protocol,
)
.unwrap()
}
Expand Down Expand Up @@ -293,6 +296,8 @@ fn build_batch_circuit_skip_encoding<const N_SNARKS: usize>() -> BatchCircuit<N_
})
.collect_vec()
};
let snark_protocol = real_snarks[0].protocol.clone();

// ==========================
// padded chunks
// ==========================
Expand All @@ -302,6 +307,8 @@ fn build_batch_circuit_skip_encoding<const N_SNARKS: usize>() -> BatchCircuit<N_
[real_snarks, padded_snarks].concat().as_ref(),
rng,
batch_hash,
&snark_protocol,
&snark_protocol,
)
.unwrap()
}
2 changes: 2 additions & 0 deletions prover/src/aggregator/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ impl<'params> Prover<'params> {
LayerId::Layer3.id(),
LayerId::Layer3.degree(),
batch_info,
&self.halo2_protocol,
&self.sp1_protocol,
&layer2_snarks,
output_dir,
)?;
Expand Down
Loading
Loading