Skip to content

Commit

Permalink
Poplar1: Check for trailing bits in agg param
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton committed Jan 8, 2025
1 parent c0e76dc commit 6effaf9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
11 changes: 11 additions & 0 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -5252,8 +5252,19 @@ Decoding involves the following procedure:

~~~ python
prefixes = []

last_byte_mask = 0
leftover_bits = (level + 1) % 8
if leftover_bits > 0:
for bit_index in range(8 - leftover_bits, 8):
last_byte_mask |= 1 << bit_index;
last_byte_mask ^= 255

bytes_per_prefix = ((level + 1) + 7) // 8
for chunk in itertools.batched(encoded_prefixes, bytes_per_prefix):
if chunk[-1] & last_byte_mask > 0:
raise ValueError('trailing bits in prefix')

prefix = []
for i in range(level + 1):
byte_index = i // 8
Expand Down
39 changes: 39 additions & 0 deletions poc/tests/test_vdaf_poplar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,45 @@ def test_aggregation_parameter_encoding(self) -> None:
self.assertEqual(want, cls.decode_agg_param(
cls.encode_agg_param(want)))

def test_aggregation_parameter_encoding_clear_trailing_bits(self) -> None:
cls = Poplar1(256)

# Set the first bit of the first prefix, which should be cleared.
malformed = bytearray(cls.encode_agg_param(
(6, (
(False,) * 7,
(True,) * 7,
))))
malformed[6] |= 1
with self.assertRaises(ValueError):
cls.decode_agg_param(malformed)

# Set the first bit of the second prefix, which should be cleared.
malformed = bytearray(cls.encode_agg_param(
(6, (
(False,) * 7,
(True,) * 7,
))))
malformed[7] |= 1
with self.assertRaises(ValueError):
cls.decode_agg_param(malformed)

# Try a longer prefix.
malformed = bytearray(cls.encode_agg_param(
(110, (
(False,) * 111,
))))
malformed[19] |= 1
with self.assertRaises(ValueError):
cls.decode_agg_param(malformed)

# Try setting each bit following the first level.
for level in range(1, 8):
malformed = bytearray(cls.encode_agg_param((0, ((True,),))))
malformed[6] |= 1 << (7 - level)
with self.assertRaises(ValueError):
cls.decode_agg_param(malformed)

def test_generate_test_vectors(self) -> None:
# Generate test vectors.
cls = Poplar1(4)
Expand Down
11 changes: 11 additions & 0 deletions poc/vdaf_poc/vdaf_poplar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,19 @@ def decode_agg_param(self, encoded: bytes) -> Poplar1AggParam:
# before de-indenting, to avoid warnings from xml2rfc.
# ===================================================================
prefixes = []

last_byte_mask = 0
leftover_bits = (level + 1) % 8
if leftover_bits > 0:
for bit_index in range(8 - leftover_bits, 8):
last_byte_mask |= 1 << bit_index
last_byte_mask ^= 255

bytes_per_prefix = ((level + 1) + 7) // 8
for chunk in itertools.batched(encoded_prefixes, bytes_per_prefix):
if chunk[-1] & last_byte_mask > 0:
raise ValueError('trailing bits in prefix')

prefix = []
for i in range(level + 1):
byte_index = i // 8
Expand Down

0 comments on commit 6effaf9

Please sign in to comment.