From e11c78808d4e92941c1a03729fc8a4138a92b6c6 Mon Sep 17 00:00:00 2001 From: Albert Liu Date: Mon, 2 Oct 2023 11:27:54 -0700 Subject: [PATCH] POC implementation for mutliproof proposal in #177 * Support multiproofs in Prio3 * Add new Prio3SumVec variant, i.e. Prio3SumVecWithMultiproof, with configuration (field size, number of proofs) * Add with_field class methods to introduce new SumVec with configurable field size --- poc/flp_generic.py | 33 +++++-- poc/vdaf_prio3.py | 222 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 196 insertions(+), 59 deletions(-) diff --git a/poc/flp_generic.py b/poc/flp_generic.py index 4d335f7a..96675c0a 100644 --- a/poc/flp_generic.py +++ b/poc/flp_generic.py @@ -120,6 +120,12 @@ def check_valid_eval(self, meas, joint_rand): if len(joint_rand) != self.JOINT_RAND_LEN: raise ERR_INPUT + @classmethod + def with_field(cls, field: field.FftField): + class _ValidWithField(cls): + Field = field + return _ValidWithField + class ProveGadget: def __init__(self, Field, wire_seeds, g, g_calls): @@ -829,6 +835,22 @@ def decode(self, output, num_measurements): return total // num_measurements +# Test encoding, truncation, then decoding. +def test_encode_truncate_decode(flp, measurements): + for measurement in measurements: + assert measurement == flp.decode( + flp.truncate(flp.encode(measurement)), 1) + + +def test_encode_truncate_decode_with_fft_fields(cls, measurements, *args): + for f in [field.Field64, field.Field96, field.Field128]: + cls_with_field = cls.with_field(f) + assert cls_with_field.Field == f + obj = cls_with_field(*args) + assert isinstance(obj, cls) + test_encode_truncate_decode(FlpGeneric(obj), measurements) + + def test(): flp = FlpGeneric(Count()) test_flp_generic(flp, [ @@ -848,9 +870,7 @@ def test(): (flp.encode(2 ** 10 - 1), True), (flp.Field.rand_vec(10), False), ]) - # Roundtrip test with no proof generated. - for meas in [0, 100, 2 ** 10 - 1]: - assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1) + test_encode_truncate_decode(flp, [0, 100, 2 ** 10 - 1]) flp = FlpGeneric(Histogram(4, 2)) test_flp_generic(flp, [ @@ -864,10 +884,9 @@ def test(): ]) # SumVec with length 2, bits 4, chunk len 1. - flp = FlpGeneric(SumVec(2, 4, 1)) - # Roundtrip test with no proof generated. - for meas in [[1, 2], [3, 4], [5, 6], [7, 8]]: - assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1) + test_encode_truncate_decode_with_fft_fields(SumVec, + [[1, 2], [3, 4], [5, 6], [7, 8]], + 2, 4, 1) flp = FlpGeneric(TestMultiGadget()) test_flp_generic(flp, [ diff --git a/poc/vdaf_prio3.py b/poc/vdaf_prio3.py index fc6fef07..69b4c726 100644 --- a/poc/vdaf_prio3.py +++ b/poc/vdaf_prio3.py @@ -7,6 +7,7 @@ import xof from common import (ERR_INPUT, ERR_VERIFY, TEST_VECTOR, Unsigned, byte, concat, front, vec_add, vec_sub, zeros) +from field import FftField, Field64, Field128 from vdaf import Vdaf, test_vdaf USAGE_MEAS_SHARE = 1 @@ -32,6 +33,9 @@ class Prio3(Vdaf): ROUNDS = 1 SHARES = None # A number between `[2, 256)` set later + # Operational parameters + PROOFS = 1 # Number of proofs, in range `[1, 256)` + # Types required by `Vdaf` Measurement = Flp.Measurement PublicShare = Optional[list[bytes]] # joint randomness parts @@ -73,7 +77,7 @@ def is_valid(agg_param, previous_agg_params): def prep_init(Prio3, verify_key, agg_id, _agg_param, nonce, public_share, input_share): k_joint_rand_parts = public_share - (meas_share, proof_share, k_blind) = \ + (meas_share, proofs_share, k_blind) = \ Prio3.expand_input_share(agg_id, input_share) out_share = Prio3.Flp.truncate(meas_share) @@ -86,18 +90,27 @@ def prep_init(Prio3, verify_key, agg_id, _agg_param, k_joint_rand_parts[agg_id] = k_joint_rand_part k_corrected_joint_rand = Prio3.joint_rand_seed( k_joint_rand_parts) - joint_rand = Prio3.joint_rand(k_corrected_joint_rand) + joint_rands = Prio3.joint_rands(k_corrected_joint_rand) # Query the measurement and proof share. - query_rand = Prio3.query_rand(verify_key, nonce) - verifier_share = Prio3.Flp.query(meas_share, - proof_share, - query_rand, - joint_rand, - Prio3.SHARES) + query_rands = Prio3.query_rands(verify_key, nonce) + verifiers_share = [] + for _ in range(Prio3.PROOFS): + proof_share, proofs_share = front( + Prio3.Flp.PROOF_LEN, proofs_share) + query_rand, query_rands = front( + Prio3.Flp.QUERY_RAND_LEN, query_rands) + if Prio3.Flp.JOINT_RAND_LEN > 0: + joint_rand, joint_rands = front( + Prio3.Flp.JOINT_RAND_LEN, joint_rands) + verifiers_share += Prio3.Flp.query(meas_share, + proof_share, + query_rand, + joint_rand, + Prio3.SHARES) prep_state = (out_share, k_corrected_joint_rand) - prep_share = (verifier_share, k_joint_rand_part) + prep_share = (verifiers_share, k_joint_rand_part) return (prep_state, prep_share) @classmethod @@ -115,16 +128,19 @@ def prep_next(Prio3, prep, prep_msg): @classmethod def prep_shares_to_prep(Prio3, _agg_param, prep_shares): # Unshard the verifier shares into the verifier message. - verifier = Prio3.Flp.Field.zeros(Prio3.Flp.VERIFIER_LEN) + verifiers = Prio3.Flp.Field.zeros( + Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS) k_joint_rand_parts = [] - for (verifier_share, k_joint_rand_part) in prep_shares: - verifier = vec_add(verifier, verifier_share) + for (verifiers_share, k_joint_rand_part) in prep_shares: + verifiers = vec_add(verifiers, verifiers_share) if Prio3.Flp.JOINT_RAND_LEN > 0: k_joint_rand_parts.append(k_joint_rand_part) - # Verify that the proof is well-formed and the input is valid. - if not Prio3.Flp.decide(verifier): - raise ERR_VERIFY # proof verifier check failed + # Verify that each proof is well-formed and accepts the measurement. + for _ in range(Prio3.PROOFS): + verifier, verifiers = front(Prio3.Flp.VERIFIER_LEN, verifiers) + if not Prio3.Flp.decide(verifier): + raise ERR_VERIFY # proof verifier check failed # Combine the joint randomness parts computed by the # Aggregators into the true joint randomness seed. This is @@ -158,7 +174,7 @@ def shard_without_joint_rand(Prio3, meas, seeds): k_helper_seeds[i] for i in range(0, (Prio3.SHARES-1) * 2, 2) ] - k_helper_proof_shares = [ + k_helper_proofs_shares = [ k_helper_seeds[i] for i in range(1, (Prio3.SHARES-1) * 2, 2) ] @@ -172,13 +188,17 @@ def shard_without_joint_rand(Prio3, meas, seeds): Prio3.helper_meas_share(j+1, k_helper_meas_shares[j]), ) - # Generate the proof and shard it into proof shares. - prove_rand = Prio3.prove_rand(k_prove) - leader_proof_share = Prio3.Flp.prove(meas, prove_rand, []) + # Generate the proofs and shard it into proof(s) shares. + prove_rands = Prio3.prove_rands(k_prove) + leader_proofs_share = [] + for _ in range(Prio3.PROOFS): + prove_rand, prove_rands = front( + Prio3.Flp.PROVE_RAND_LEN, prove_rands) + leader_proofs_share += Prio3.Flp.prove(meas, prove_rand, []) for j in range(Prio3.SHARES-1): - leader_proof_share = vec_sub( - leader_proof_share, - Prio3.helper_proof_share(j+1, k_helper_proof_shares[j]), + leader_proofs_share = vec_sub( + leader_proofs_share, + Prio3.helper_proofs_share(j+1, k_helper_proofs_shares[j]), ) # Each Aggregator's input share contains its measurement share @@ -186,13 +206,13 @@ def shard_without_joint_rand(Prio3, meas, seeds): input_shares = [] input_shares.append(( leader_meas_share, - leader_proof_share, + leader_proofs_share, None, )) for j in range(Prio3.SHARES-1): input_shares.append(( k_helper_meas_shares[j], - k_helper_proof_shares[j], + k_helper_proofs_shares[j], None, )) return (None, input_shares) @@ -204,7 +224,7 @@ def shard_with_joint_rand(Prio3, meas, nonce, seeds): k_helper_seeds[i] for i in range(0, (Prio3.SHARES-1) * 3, 3) ] - k_helper_proof_shares = [ + k_helper_proofs_shares = [ k_helper_seeds[i] for i in range(1, (Prio3.SHARES-1) * 3, 3) ] @@ -230,14 +250,21 @@ def shard_with_joint_rand(Prio3, meas, nonce, seeds): 0, k_leader_blind, leader_meas_share, nonce)) # Generate the proof and shard it into proof shares. - prove_rand = Prio3.prove_rand(k_prove) - joint_rand = Prio3.joint_rand( + prove_rands = Prio3.prove_rands(k_prove) + joint_rands = Prio3.joint_rands( Prio3.joint_rand_seed(k_joint_rand_parts)) - leader_proof_share = Prio3.Flp.prove(meas, prove_rand, joint_rand) + leader_proofs_share = [] + for _ in range(Prio3.PROOFS): + prove_rand, prove_rands = front( + Prio3.Flp.PROVE_RAND_LEN, prove_rands) + joint_rand, joint_rands = front( + Prio3.Flp.JOINT_RAND_LEN, joint_rands) + leader_proofs_share += Prio3.Flp.prove(meas, + prove_rand, joint_rand) for j in range(Prio3.SHARES-1): - leader_proof_share = vec_sub( - leader_proof_share, - Prio3.helper_proof_share(j+1, k_helper_proof_shares[j]), + leader_proofs_share = vec_sub( + leader_proofs_share, + Prio3.helper_proofs_share(j+1, k_helper_proofs_shares[j]), ) # Each Aggregator's input share contains its measurement share, @@ -246,13 +273,13 @@ def shard_with_joint_rand(Prio3, meas, nonce, seeds): input_shares = [] input_shares.append(( leader_meas_share, - leader_proof_share, + leader_proofs_share, k_leader_blind, )) for j in range(Prio3.SHARES-1): input_shares.append(( k_helper_meas_shares[j], - k_helper_proof_shares[j], + k_helper_proofs_shares[j], k_helper_blinds[j], )) return (k_joint_rand_parts, input_shares) @@ -268,41 +295,41 @@ def helper_meas_share(Prio3, agg_id, k_share): ) @classmethod - def helper_proof_share(Prio3, agg_id, k_share): + def helper_proofs_share(Prio3, agg_id, k_share): return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, k_share, Prio3.domain_separation_tag(USAGE_PROOF_SHARE), byte(agg_id), - Prio3.Flp.PROOF_LEN, + Prio3.Flp.PROOF_LEN * Prio3.PROOFS, ) @classmethod def expand_input_share(Prio3, agg_id, input_share): - (meas_share, proof_share, k_blind) = input_share + (meas_share, proofs_share, k_blind) = input_share if agg_id > 0: meas_share = Prio3.helper_meas_share(agg_id, meas_share) - proof_share = Prio3.helper_proof_share(agg_id, proof_share) - return (meas_share, proof_share, k_blind) + proofs_share = Prio3.helper_proofs_share(agg_id, proofs_share) + return (meas_share, proofs_share, k_blind) @classmethod - def prove_rand(Prio3, k_prove): + def prove_rands(Prio3, k_prove): return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, k_prove, Prio3.domain_separation_tag(USAGE_PROVE_RANDOMNESS), b'', - Prio3.Flp.PROVE_RAND_LEN, + Prio3.Flp.PROVE_RAND_LEN * Prio3.PROOFS, ) @classmethod - def query_rand(Prio3, verify_key, nonce): + def query_rands(Prio3, verify_key, nonce): return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, verify_key, Prio3.domain_separation_tag(USAGE_QUERY_RANDOMNESS), nonce, - Prio3.Flp.QUERY_RAND_LEN, + Prio3.Flp.QUERY_RAND_LEN * Prio3.PROOFS, ) @classmethod @@ -323,14 +350,15 @@ def joint_rand_seed(Prio3, k_joint_rand_parts): ) @classmethod - def joint_rand(Prio3, k_joint_rand_seed): + def joint_rands(Prio3, k_joint_rand_seed): """Derive the joint randomness from its seed.""" + binder = b'' if Prio3.PROOFS == 1 else byte(Prio3.PROOFS) return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, k_joint_rand_seed, Prio3.domain_separation_tag(USAGE_JOINT_RANDOMNESS), - b'', - Prio3.Flp.JOINT_RAND_LEN, + binder, + Prio3.Flp.JOINT_RAND_LEN * Prio3.PROOFS, ) @classmethod @@ -367,14 +395,15 @@ def test_vec_set_type_param(Prio3, test_vec): @classmethod def test_vec_encode_input_share(Prio3, input_share): - (meas_share, proof_share, k_blind) = input_share + (meas_share, proofs_share, k_blind) = input_share encoded = bytes() - if type(meas_share) == list and type(proof_share) == list: # leader + if type(meas_share) == list and type(proofs_share) == list: # Leader + assert len(proofs_share) == Prio3.Flp.PROOF_LEN * Prio3.PROOFS encoded += Prio3.Flp.Field.encode_vec(meas_share) - encoded += Prio3.Flp.Field.encode_vec(proof_share) - elif type(meas_share) == bytes and type(proof_share) == bytes: # helper + encoded += Prio3.Flp.Field.encode_vec(proofs_share) + elif type(meas_share) == bytes and type(proofs_share) == bytes: # Helper encoded += meas_share - encoded += proof_share + encoded += proofs_share if k_blind != None: # joint randomness used encoded += k_blind return encoded @@ -392,9 +421,10 @@ def test_vec_encode_agg_share(Prio3, agg_share): @classmethod def test_vec_encode_prep_share(Prio3, prep_share): - (verifier_share, k_joint_rand_part) = prep_share + (verifiers_share, k_joint_rand_part) = prep_share encoded = bytes() - encoded += Prio3.Flp.Field.encode_vec(verifier_share) + assert len(verifiers_share) == Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS + encoded += Prio3.Flp.Field.encode_vec(verifiers_share) if k_joint_rand_part != None: # joint randomness used encoded += k_joint_rand_part return encoded @@ -483,6 +513,43 @@ class Prio3HistogramWithLength(Prio3Histogram): return Prio3HistogramWithLength +# Experimental multiproof variant of Prio3SumVec +class Prio3SumVecWithMultiproof(Prio3SumVec): + # Operational parameters. + test_vec_name = 'Prio3SumVecWithMultiproof' + + @staticmethod + def is_recommended(valid_cls, + num_proofs: Unsigned, + f: FftField) -> bool: + # TODO(issue#177) Decide how many proofs to use. + if f == Field64: + # the upper bound is due to the fact + # we encode it using one byte in `joint_rands` + return 2 <= num_proofs < 256 + elif f == Field128: + return 1 <= num_proofs < 256 + return False + + @classmethod + def with_params(cls, + length: Unsigned, + bits: Unsigned, + chunk_length: Unsigned, + num_proofs: Unsigned, + field: FftField): + valid_cls = flp_generic.SumVec.with_field(field) + if not cls.is_recommended(valid_cls, num_proofs, field): + raise ValueError("parameters not recommended") + + class Prio3SumVecWithMultiproofAndParams(cls): + # Associated parameters. + ID = 0xFFFFFFFF + PROOFS = num_proofs + Flp = flp_generic.FlpGeneric(valid_cls(length, bits, chunk_length)) + return Prio3SumVecWithMultiproofAndParams + + ## # TESTS # @@ -506,6 +573,55 @@ class TestPrio3AverageWithBits(TestPrio3Average): return TestPrio3AverageWithBits +def _test_prio3sumvec(num_proofs: Unsigned, field: FftField): + valid_cls = flp_generic.SumVec.with_field(field) + assert Prio3SumVecWithMultiproof.is_recommended( + valid_cls, num_proofs, field) + + cls = Prio3SumVecWithMultiproof \ + .with_params(10, 8, 9, num_proofs, field) \ + .with_shares(2) + + assert cls.ID == 0xFFFFFFFF + assert cls.PROOFS == num_proofs + + test_vdaf( + cls, + None, + [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], + [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] + ) + test_vdaf( + cls, + None, + [ + list(range(10)), + [1] * 10, + [255] * 10 + ], + list(range(256, 266)), + print_test_vec=TEST_VECTOR, + ) + cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) + test_vdaf( + cls, + None, + [ + [10000, 32000, 9], + [19342, 19615, 3061], + [15986, 24671, 23910] + ], + [45328, 76286, 26980], + print_test_vec=TEST_VECTOR, + test_vec_instance=1, + ) + + +def test_prio3sumvec_with_multiproof(): + for n in range(2, 5): + _test_prio3sumvec(num_proofs=n, field=Field64) + + if __name__ == '__main__': num_shares = 2 # Must be in range `[2, 256)` @@ -601,3 +717,5 @@ class TestPrio3AverageWithBits(TestPrio3Average): # otherwise. assert cls.is_valid(None, set([])) assert not cls.is_valid(None, set([None])) + + test_prio3sumvec_with_multiproof()