Skip to content

Commit

Permalink
Implement Prio3MultiHotHistogram that checks bounded number of 1s
Browse files Browse the repository at this point in the history
Add a reference implementation of the idea in issue #287, which
needs a new FLP 'MultiHotHistogram' and a new Prio3 type
'Prio3MultiHotHistogram' that checks a Client's measurement has a
bounded number of 1s.
  • Loading branch information
junyechen1996 authored and cjpatton committed Oct 16, 2023
1 parent 3d108f4 commit 1d960b2
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 1 deletion.
158 changes: 157 additions & 1 deletion poc/flp_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ class Histogram(Valid):

def __init__(self, length, chunk_length):
"""
Instantiate an instace of the `Histogram` circuit with the given
Instantiate an instance of the `Histogram` circuit with the given
length and chunk_length.
"""

Expand Down Expand Up @@ -646,6 +646,138 @@ def test_vec_set_type_param(self, test_vec):
return ['length', 'chunk_length']


class MultiHotHistogram(Valid):
"""
A validity circuit that checks each Client's measurement is a bit vector
with at most `max_count` number of 1s.
In order to check the Client measurement `meas` has at most `max_count` 1s,
we ask the Client to send an additional `bits_for_count` bits that
encode the number of 1s in the measurement, with an offset. Specifically:
- Let `bits_for_count = max_count.bit_length()`, i.e. the number of bits
to represent `max_count`.
- Let `offset = 2**bits_for_count - 1 - max_count`.
- Client will encode `count = offset + \sum_i meas_i` in
`bits_for_count` bits.
- We can naturally bound `count` as the following:
`0 <= count <= 2**bits_for_count - 1`, and therefore:
`-offset <= \sum_i meas_i <= max_count`.
- Since we also verify each `meas_i` is a bit, we can lower bound the
summation by 0. Therefore, we will be able to verify
`0 <= \sum_i meas_i <= max_count`.
"""
# Operational parameters
length = None # Set by the constructor
max_count = None # Set by constructor
chunk_length = None # Set by constructor

# Associated types
Measurement = list[Unsigned] # A vector of bits.
AggResult = list[Unsigned] # A vector of counts as unsigned integers.
Field = field.Field128

# Associated parameters
GADGETS = None # Set by constructor
GADGET_CALLS = None # Set by constructor
MEAS_LEN = None # Set by constructor
JOINT_RAND_LEN = 2
OUTPUT_LEN = None # Set by constructor

def __init__(self, length, max_count, chunk_length):
"""
Instantiate an instance of the `MultiHotHistogram` circuit with the
given length, max_count, and chunk_length.
"""
if length <= 0:
raise ValueError('invalid length')
if max_count <= 0 or max_count > length:
raise ValueError('invalid max_count')
if chunk_length <= 0:
raise ValueError('invalid chunk_length')

# Compute the number of bits to represent `max_count`.
self.bits_for_count = max_count.bit_length()
self.offset = self.Field((1 << self.bits_for_count) - 1 - max_count)
# Sanity check `offset + length` doesn't overflow field size,
# because in validity circuit, we will compute `offset + \sum_i meas_i`.
if self.Field.MODULUS - self.offset.as_unsigned() <= length:
raise ValueError('length and max_count are too large '
'for the current field size')

self.length = length
self.max_count = max_count
self.chunk_length = chunk_length
self.GADGETS = [ParallelSum(Mul(), chunk_length)]
# The number of bit entries are `length + bits_for_count`,
# so the number of gadget calls is equal to
# `ceil((length + bits_for_count) / chunk_length)`.
self.GADGET_CALLS = [
(length + self.bits_for_count + chunk_length - 1) // chunk_length
]
self.MEAS_LEN = self.length + self.bits_for_count
self.OUTPUT_LEN = self.length

def eval(self, meas, joint_rand, num_shares):
self.check_valid_eval(meas, joint_rand)

# Check that each bucket is one or zero.
range_check = self.Field(0)
r = joint_rand[0]
r_power = r
shares_inv = self.Field(num_shares).inv()
for i in range(self.GADGET_CALLS[0]):
inputs = [None] * (2 * self.chunk_length)
for j in range(self.chunk_length):
index = i * self.chunk_length + j
if index < len(meas):
meas_elem = meas[index]
else:
meas_elem = self.Field(0)

inputs[j * 2] = r_power * meas_elem
inputs[j * 2 + 1] = meas_elem - shares_inv

r_power *= r

range_check += self.GADGETS[0].eval(self.Field, inputs)

# Check that `offset` plus the sum of the buckets is equal to the
# value claimed by the Client.
count_check = self.offset * shares_inv
for i in range(self.length):
count_check += meas[i]
count_check -= self.Field.decode_from_bit_vector(meas[self.length:])

out = joint_rand[1] * range_check + \
joint_rand[1] ** 2 * count_check
return out

def encode(self, measurement):
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

encoded = list(map(self.Field, measurement))
# Encode the result of `offset + \sum_i measurement_i` into
# `bits_for_count` bits.
count = self.offset + sum(encoded, self.Field(0))
encoded += self.Field.encode_into_bit_vector(
count.as_unsigned(), self.bits_for_count
)
return encoded

def truncate(self, meas):
return meas[:self.length]

def decode(self, output, _num_measurements):
return [bucket_count.as_unsigned() for bucket_count in output]

def test_vec_set_type_param(self, test_vec):
test_vec['length'] = int(self.length)
test_vec['max_count'] = int(self.max_count)
test_vec['chunk_length'] = int(self.chunk_length)
return ['length', 'max_count', 'chunk_length']


class SumVec(Valid):
# Operational parameters
length = None # Set by constructor
Expand Down Expand Up @@ -883,6 +1015,30 @@ def test():
(flp.Field.rand_vec(4), False),
])

# MultiHotHistogram with length = 4, max_count = 2, chunk_length = 2.
flp = FlpGeneric(MultiHotHistogram(4, 2, 2))
# Successful cases:
cases = [
(flp.encode([0, 0, 0, 0]), True),
(flp.encode([0, 1, 0, 0]), True),
(flp.encode([0, 1, 1, 0]), True),
(flp.encode([1, 1, 0, 0]), True),
]
# Failure cases: too many number of 1s, should fail count check.
cases += [
(
[flp.Field(1)] * i +
[flp.Field(0)] * (flp.Valid.length - i) +
# Try to lie about the encoded count.
[flp.Field(0)] * flp.Valid.bits_for_count,
False
)
for i in range(flp.Valid.max_count + 1, flp.Valid.length + 1)
]
# Failure case: pass count check but fail bit check.
cases += [(flp.encode([flp.Field.MODULUS - 1, 1, 0, 0]), False)]
test_flp_generic(flp, cases)

# SumVec with length 2, bits 4, chunk len 1.
test_encode_truncate_decode_with_fft_fields(SumVec,
[[1, 2], [3, 4], [5, 6], [7, 8]],
Expand Down
47 changes: 47 additions & 0 deletions poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,30 @@ class Prio3SumVecWithMultiproofAndParams(cls):
return Prio3SumVecWithMultiproofAndParams


class Prio3MultiHotHistogram(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128

# Associated parameters.
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
# Private codepoint just for testing.
ID = 0xFFFFFFFF

# Operational parameters.
test_vec_name = 'Prio3MultiHotHistogram'

@classmethod
def with_params(Prio3MultiHotHistogram,
length: Unsigned,
max_count: Unsigned,
chunk_length: Unsigned):
class Prio3MultiHotHistogramWithParams(Prio3MultiHotHistogram):
Flp = flp_generic.FlpGeneric(flp_generic.MultiHotHistogram(
length, max_count, chunk_length
))
return Prio3MultiHotHistogramWithParams


##
# TESTS
#
Expand Down Expand Up @@ -710,6 +734,29 @@ def test_prio3sumvec_with_multiproof():
test_vec_instance=1,
)

# Prio3MultiHotHistogram with length = 4, max_count = 2, chunk_length = 2.
cls = Prio3MultiHotHistogram \
.with_params(4, 2, 2) \
.with_shares(num_shares)
assert cls.ID == 0xFFFFFFFF
test_vdaf(cls, None, [[0, 0, 0, 0]], [0, 0, 0, 0])
test_vdaf(cls, None, [[0, 1, 0, 0]], [0, 1, 0, 0])
test_vdaf(cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0])
test_vdaf(cls, None, [[0, 1, 1, 0], [0, 1, 0, 1]], [0, 2, 1, 1])
test_vdaf(
cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0], print_test_vec=TEST_VECTOR
)
# Prio3MultiHotHistogram with length = 11, max_count = 5, chunk_length = 3.
cls = Prio3MultiHotHistogram.with_params(11, 5, 3).with_shares(3)
test_vdaf(
cls,
None,
[[1] * 5 + [0] * 6],
[1] * 5 + [0] * 6,
print_test_vec=TEST_VECTOR,
test_vec_instance=1,
)

cls = TestPrio3Average.with_bits(3).with_shares(num_shares)
test_vdaf(cls, None, [1, 5, 1, 1, 4, 1, 3, 2], 2)

Expand Down

0 comments on commit 1d960b2

Please sign in to comment.