From 6effaf908c7b7eb585f43bef2ec1aedca3e4bfb5 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Tue, 7 Jan 2025 14:29:50 -0800 Subject: [PATCH] Poplar1: Check for trailing bits in agg param --- draft-irtf-cfrg-vdaf.md | 11 ++++++++++ poc/tests/test_vdaf_poplar1.py | 39 ++++++++++++++++++++++++++++++++++ poc/vdaf_poc/vdaf_poplar1.py | 11 ++++++++++ 3 files changed, 61 insertions(+) diff --git a/draft-irtf-cfrg-vdaf.md b/draft-irtf-cfrg-vdaf.md index d740cc10..e1ef279d 100644 --- a/draft-irtf-cfrg-vdaf.md +++ b/draft-irtf-cfrg-vdaf.md @@ -5252,8 +5252,19 @@ Decoding involves the following procedure: ~~~ python prefixes = [] + +last_byte_mask = 0 +leftover_bits = (level + 1) % 8 +if leftover_bits > 0: + for bit_index in range(8 - leftover_bits, 8): + last_byte_mask |= 1 << bit_index; + last_byte_mask ^= 255 + bytes_per_prefix = ((level + 1) + 7) // 8 for chunk in itertools.batched(encoded_prefixes, bytes_per_prefix): + if chunk[-1] & last_byte_mask > 0: + raise ValueError('trailing bits in prefix') + prefix = [] for i in range(level + 1): byte_index = i // 8 diff --git a/poc/tests/test_vdaf_poplar1.py b/poc/tests/test_vdaf_poplar1.py index 7d887188..85c98e2f 100644 --- a/poc/tests/test_vdaf_poplar1.py +++ b/poc/tests/test_vdaf_poplar1.py @@ -285,6 +285,45 @@ def test_aggregation_parameter_encoding(self) -> None: self.assertEqual(want, cls.decode_agg_param( cls.encode_agg_param(want))) + def test_aggregation_parameter_encoding_clear_trailing_bits(self) -> None: + cls = Poplar1(256) + + # Set the first bit of the first prefix, which should be cleared. + malformed = bytearray(cls.encode_agg_param( + (6, ( + (False,) * 7, + (True,) * 7, + )))) + malformed[6] |= 1 + with self.assertRaises(ValueError): + cls.decode_agg_param(malformed) + + # Set the first bit of the second prefix, which should be cleared. + malformed = bytearray(cls.encode_agg_param( + (6, ( + (False,) * 7, + (True,) * 7, + )))) + malformed[7] |= 1 + with self.assertRaises(ValueError): + cls.decode_agg_param(malformed) + + # Try a longer prefix. + malformed = bytearray(cls.encode_agg_param( + (110, ( + (False,) * 111, + )))) + malformed[19] |= 1 + with self.assertRaises(ValueError): + cls.decode_agg_param(malformed) + + # Try setting each bit following the first level. + for level in range(1, 8): + malformed = bytearray(cls.encode_agg_param((0, ((True,),)))) + malformed[6] |= 1 << (7 - level) + with self.assertRaises(ValueError): + cls.decode_agg_param(malformed) + def test_generate_test_vectors(self) -> None: # Generate test vectors. cls = Poplar1(4) diff --git a/poc/vdaf_poc/vdaf_poplar1.py b/poc/vdaf_poc/vdaf_poplar1.py index d5a7bc63..6b0c6a06 100644 --- a/poc/vdaf_poc/vdaf_poplar1.py +++ b/poc/vdaf_poc/vdaf_poplar1.py @@ -458,8 +458,19 @@ def decode_agg_param(self, encoded: bytes) -> Poplar1AggParam: # before de-indenting, to avoid warnings from xml2rfc. # =================================================================== prefixes = [] + + last_byte_mask = 0 + leftover_bits = (level + 1) % 8 + if leftover_bits > 0: + for bit_index in range(8 - leftover_bits, 8): + last_byte_mask |= 1 << bit_index + last_byte_mask ^= 255 + bytes_per_prefix = ((level + 1) + 7) // 8 for chunk in itertools.batched(encoded_prefixes, bytes_per_prefix): + if chunk[-1] & last_byte_mask > 0: + raise ValueError('trailing bits in prefix') + prefix = [] for i in range(level + 1): byte_index = i // 8