Skip to content

Commit

Permalink
Move encode() before eval() (#509)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Oct 17, 2024
1 parent ac9aa5c commit dcadeaa
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 100 deletions.
88 changes: 44 additions & 44 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -3770,6 +3770,9 @@ class Count(Valid[int, int, F]):
def __init__(self, field: type[F]):
self.field = field

def encode(self, measurement: int) -> list[F]:
return [self.field(measurement)]

def eval(
self,
meas: list[F],
Expand All @@ -3779,9 +3782,6 @@ class Count(Valid[int, int, F]):
[meas[0], meas[0]])
return [squared - meas[0]]

def encode(self, measurement: int) -> list[F]:
return [self.field(measurement)]

def truncate(self, meas: list[F]) -> list[F]:
return meas

Expand Down Expand Up @@ -3845,6 +3845,18 @@ class Sum(Valid[int, int, F]):
self.MEAS_LEN = 2 * self.bits
self.EVAL_OUTPUT_LEN = 2 * self.bits + 1

def encode(self, measurement: int) -> list[F]:
encoded = []
encoded += self.field.encode_into_bit_vec(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vec(
measurement + self.offset.int(),
self.bits
)
return encoded

def eval(
self,
meas: list[F],
Expand All @@ -3862,18 +3874,6 @@ class Sum(Valid[int, int, F]):
out.append(range_check)
return out

def encode(self, measurement: int) -> list[F]:
encoded = []
encoded += self.field.encode_into_bit_vec(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vec(
measurement + self.offset.int(),
self.bits
)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return [self.field.decode_from_bit_vec(meas[:self.bits])]

Expand Down Expand Up @@ -3957,6 +3957,13 @@ class SumVec(Valid[list[int], list[int], F]):
self.OUTPUT_LEN = length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
encoded = []
for val in measurement:
encoded += self.field.encode_into_bit_vec(
val, self.bits)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -3988,13 +3995,6 @@ class SumVec(Valid[list[int], list[int], F]):

return [out]

def encode(self, measurement: list[int]) -> list[F]:
encoded = []
for val in measurement:
encoded += self.field.encode_into_bit_vec(
val, self.bits)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
truncated = []
for i in range(self.length):
Expand Down Expand Up @@ -4094,6 +4094,11 @@ class Histogram(Valid[int, list[int], F]):
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: int) -> list[F]:
encoded = [self.field(0)] * self.length
encoded[measurement] = self.field(1)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -4131,11 +4136,6 @@ class Histogram(Valid[int, list[int], F]):

return [range_check, sum_check]

def encode(self, measurement: int) -> list[F]:
encoded = [self.field(0)] * self.length
encoded[measurement] = self.field(1)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return meas

Expand Down Expand Up @@ -4232,6 +4232,23 @@ class MultihotCountVec(Valid[list[int], list[int], F]):
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))

encoded = []
encoded += count_vec
encoded += self.field.encode_into_bit_vec(
(self.offset + weight_reported).int(),
self.bits_for_weight)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -4273,23 +4290,6 @@ class MultihotCountVec(Valid[list[int], list[int], F]):

return [range_check, weight_check]

def encode(self, measurement: list[int]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))

encoded = []
encoded += count_vec
encoded += self.field.encode_into_bit_vec(
(self.offset + weight_reported).int(),
self.bits_for_weight)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return meas[:self.length]

Expand Down
112 changes: 56 additions & 56 deletions poc/vdaf_poc/flp_bbcggi19.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,11 @@ class Count(Valid[int, int, F]):
def __init__(self, field: type[F]):
self.field = field

def encode(self, measurement: int) -> list[F]:
if measurement not in range(2): # REMOVE ME
raise ValueError('measurement out of range') # REMOVE ME
return [self.field(measurement)]

def eval(
self,
meas: list[F],
Expand All @@ -628,11 +633,6 @@ def eval(
[meas[0], meas[0]])
return [squared - meas[0]]

def encode(self, measurement: int) -> list[F]:
if measurement not in range(2): # REMOVE ME
raise ValueError('measurement out of range') # REMOVE ME
return [self.field(measurement)]

def truncate(self, meas: list[F]) -> list[F]:
if len(meas) != 1: # REMOVE ME
raise ValueError('incorrect measurement length') # REMOVE ME
Expand Down Expand Up @@ -675,6 +675,11 @@ def __init__(self,
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: int) -> list[F]:
encoded = [self.field(0)] * self.length
encoded[measurement] = self.field(1)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -714,11 +719,6 @@ def eval(

return [range_check, sum_check]

def encode(self, measurement: int) -> list[F]:
encoded = [self.field(0)] * self.length
encoded[measurement] = self.field(1)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return meas

Expand Down Expand Up @@ -818,6 +818,23 @@ def __init__(self,
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))

encoded = []
encoded += count_vec
encoded += self.field.encode_into_bit_vec(
(self.offset + weight_reported).int(),
self.bits_for_weight)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -861,23 +878,6 @@ def eval(

return [range_check, weight_check]

def encode(self, measurement: list[int]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))

encoded = []
encoded += count_vec
encoded += self.field.encode_into_bit_vec(
(self.offset + weight_reported).int(),
self.bits_for_weight)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return meas[:self.length]

Expand Down Expand Up @@ -936,6 +936,23 @@ def __init__(self,
self.OUTPUT_LEN = length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
# REMOVE ME
if len(measurement) != self.length:
raise ValueError('incorrect measurement length')

encoded = []
for val in measurement:
# REMOVE ME
if val not in range(2**self.bits):
raise ValueError(
'entry of measurement vector is out of range'
)

encoded += self.field.encode_into_bit_vec(
val, self.bits)
return encoded

def eval(
self,
meas: list[F],
Expand Down Expand Up @@ -969,23 +986,6 @@ def eval(

return [out]

def encode(self, measurement: list[int]) -> list[F]:
# REMOVE ME
if len(measurement) != self.length:
raise ValueError('incorrect measurement length')

encoded = []
for val in measurement:
# REMOVE ME
if val not in range(2**self.bits):
raise ValueError(
'entry of measurement vector is out of range'
)

encoded += self.field.encode_into_bit_vec(
val, self.bits)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
truncated = []
for i in range(self.length):
Expand Down Expand Up @@ -1057,6 +1057,18 @@ def __init__(self, field: type[F], max_measurement: int):
self.MEAS_LEN = 2 * self.bits
self.EVAL_OUTPUT_LEN = 2 * self.bits + 1

def encode(self, measurement: int) -> list[F]:
encoded = []
encoded += self.field.encode_into_bit_vec(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vec(
measurement + self.offset.int(),
self.bits
)
return encoded

def eval(
self,
meas: list[F],
Expand All @@ -1075,18 +1087,6 @@ def eval(
out.append(range_check)
return out

def encode(self, measurement: int) -> list[F]:
encoded = []
encoded += self.field.encode_into_bit_vec(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vec(
measurement + self.offset.int(),
self.bits
)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return [self.field.decode_from_bit_vec(meas[:self.bits])]

Expand Down

0 comments on commit dcadeaa

Please sign in to comment.