From ae9527af07072ee55a3c460b4a0819a6045752de Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 14 Aug 2018 09:53:25 -0600 Subject: [PATCH] Do not rely on drop() to perform final accumulation step. --- src/circuit/blake2s.rs | 10 +- src/circuit/multieq.rs | 36 ++-- src/circuit/sha256.rs | 378 ++++++++++++++++++++--------------------- src/circuit/uint32.rs | 18 +- 4 files changed, 219 insertions(+), 223 deletions(-) diff --git a/src/circuit/blake2s.rs b/src/circuit/blake2s.rs index 93af8069..ef5084df 100644 --- a/src/circuit/blake2s.rs +++ b/src/circuit/blake2s.rs @@ -15,7 +15,7 @@ use super::uint32::{ UInt32 }; -use super::multieq::MultiEq; +use super::multieq::{MultiEq, multi_eq}; /* 2.1. Parameters @@ -202,9 +202,7 @@ fn blake2s_compression>( v[14] = v[14].xor(cs.namespace(|| "third xor"), &UInt32::constant(u32::max_value()))?; } - { - let mut cs = MultiEq::new(&mut cs); - + multi_eq::<_, _, _, Result<(), SynthesisError>>(&mut cs, |cs| { for i in 0..10 { let mut cs = cs.namespace(|| format!("round {}", i)); @@ -220,7 +218,9 @@ fn blake2s_compression>( mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?; mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?; } - } + + Ok(()) + })?; for i in 0..8 { let mut cs = cs.namespace(|| format!("h[{i}] ^ v[{i}] ^ v[{i} + 8]", i=i)); diff --git a/src/circuit/multieq.rs b/src/circuit/multieq.rs index 0f9c7556..28c67cc3 100644 --- a/src/circuit/multieq.rs +++ b/src/circuit/multieq.rs @@ -11,6 +11,24 @@ use bellman::{ Variable }; +pub fn multi_eq(cs: CS, f: F) -> R + where E: Engine, CS: ConstraintSystem, F: FnOnce(&mut MultiEq) -> R +{ + let mut cs = MultiEq { + cs: cs, + ops: 0, + bits_used: 0, + lhs: LinearCombination::zero(), + rhs: LinearCombination::zero() + }; + let tmp = f(&mut cs); + if cs.bits_used > 0 { + cs.accumulate(); + } + + tmp +} + pub struct MultiEq>{ cs: CS, ops: usize, @@ -20,16 +38,6 @@ pub struct MultiEq>{ } impl> MultiEq { - pub fn new(cs: CS) -> Self { - MultiEq { - cs: cs, - ops: 0, - bits_used: 0, - lhs: LinearCombination::zero(), - rhs: LinearCombination::zero() - } - } - fn accumulate(&mut self) { let ops = self.ops; @@ -68,14 +76,6 @@ impl> MultiEq { } } -impl> Drop for MultiEq { - fn drop(&mut self) { - if self.bits_used > 0 { - self.accumulate(); - } - } -} - impl> ConstraintSystem for MultiEq { type Root = Self; diff --git a/src/circuit/sha256.rs b/src/circuit/sha256.rs index 7b55fc89..cb18608f 100644 --- a/src/circuit/sha256.rs +++ b/src/circuit/sha256.rs @@ -1,5 +1,5 @@ use super::uint32::UInt32; -use super::multieq::MultiEq; +use super::multieq::{MultiEq, multi_eq}; use super::boolean::Boolean; use bellman::{ConstraintSystem, SynthesisError}; use pairing::Engine; @@ -94,212 +94,212 @@ fn sha256_compression_function( // We can save some constraints by combining some of // the constraints in different u32 additions - let mut cs = MultiEq::new(cs); - - for i in 16..64 { - let cs = &mut cs.namespace(|| format!("w extension {}", i)); - - // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) - let mut s0 = w[i-15].rotr(7); - s0 = s0.xor( - cs.namespace(|| "first xor for s0"), - &w[i-15].rotr(18) - )?; - s0 = s0.xor( - cs.namespace(|| "second xor for s0"), - &w[i-15].shr(3) - )?; - - // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) - let mut s1 = w[i-2].rotr(17); - s1 = s1.xor( - cs.namespace(|| "first xor for s1"), - &w[i-2].rotr(19) - )?; - s1 = s1.xor( - cs.namespace(|| "second xor for s1"), - &w[i-2].shr(10) - )?; - - let tmp = UInt32::addmany( - cs.namespace(|| "computation of w[i]"), - &[w[i-16].clone(), s0, w[i-7].clone(), s1] - )?; + multi_eq(cs, |cs| { + for i in 16..64 { + let cs = &mut cs.namespace(|| format!("w extension {}", i)); + + // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) + let mut s0 = w[i-15].rotr(7); + s0 = s0.xor( + cs.namespace(|| "first xor for s0"), + &w[i-15].rotr(18) + )?; + s0 = s0.xor( + cs.namespace(|| "second xor for s0"), + &w[i-15].shr(3) + )?; + + // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) + let mut s1 = w[i-2].rotr(17); + s1 = s1.xor( + cs.namespace(|| "first xor for s1"), + &w[i-2].rotr(19) + )?; + s1 = s1.xor( + cs.namespace(|| "second xor for s1"), + &w[i-2].shr(10) + )?; + + let tmp = UInt32::addmany( + cs.namespace(|| "computation of w[i]"), + &[w[i-16].clone(), s0, w[i-7].clone(), s1] + )?; + + // w[i] := w[i-16] + s0 + w[i-7] + s1 + w.push(tmp); + } - // w[i] := w[i-16] + s0 + w[i-7] + s1 - w.push(tmp); - } + assert_eq!(w.len(), 64); - assert_eq!(w.len(), 64); + enum Maybe { + Deferred(Vec), + Concrete(UInt32) + } - enum Maybe { - Deferred(Vec), - Concrete(UInt32) - } + impl Maybe { + fn compute( + self, + cs: M, + others: &[UInt32] + ) -> Result + where E: Engine, + CS: ConstraintSystem, + M: ConstraintSystem> + { + Ok(match self { + Maybe::Concrete(ref v) => { + return Ok(v.clone()) + }, + Maybe::Deferred(mut v) => { + v.extend(others.into_iter().cloned()); + UInt32::addmany( + cs, + &v + )? + } + }) + } + } - impl Maybe { - fn compute( - self, - cs: M, - others: &[UInt32] - ) -> Result - where E: Engine, - CS: ConstraintSystem, - M: ConstraintSystem> - { - Ok(match self { - Maybe::Concrete(ref v) => { - return Ok(v.clone()) - }, - Maybe::Deferred(mut v) => { - v.extend(others.into_iter().cloned()); - UInt32::addmany( - cs, - &v - )? - } - }) + let mut a = Maybe::Concrete(current_hash_value[0].clone()); + let mut b = current_hash_value[1].clone(); + let mut c = current_hash_value[2].clone(); + let mut d = current_hash_value[3].clone(); + let mut e = Maybe::Concrete(current_hash_value[4].clone()); + let mut f = current_hash_value[5].clone(); + let mut g = current_hash_value[6].clone(); + let mut h = current_hash_value[7].clone(); + + for i in 0..64 { + let cs = &mut cs.namespace(|| format!("compression round {}", i)); + + // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) + let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?; + let mut s1 = new_e.rotr(6); + s1 = s1.xor( + cs.namespace(|| "first xor for s1"), + &new_e.rotr(11) + )?; + s1 = s1.xor( + cs.namespace(|| "second xor for s1"), + &new_e.rotr(25) + )?; + + // ch := (e and f) xor ((not e) and g) + let ch = UInt32::sha256_ch( + cs.namespace(|| "ch"), + &new_e, + &f, + &g + )?; + + // temp1 := h + S1 + ch + k[i] + w[i] + let temp1 = vec![ + h.clone(), + s1, + ch, + UInt32::constant(ROUND_CONSTANTS[i]), + w[i].clone() + ]; + + // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) + let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?; + let mut s0 = new_a.rotr(2); + s0 = s0.xor( + cs.namespace(|| "first xor for s0"), + &new_a.rotr(13) + )?; + s0 = s0.xor( + cs.namespace(|| "second xor for s0"), + &new_a.rotr(22) + )?; + + // maj := (a and b) xor (a and c) xor (b and c) + let maj = UInt32::sha256_maj( + cs.namespace(|| "maj"), + &new_a, + &b, + &c + )?; + + // temp2 := S0 + maj + let temp2 = vec![s0, maj]; + + /* + h := g + g := f + f := e + e := d + temp1 + d := c + c := b + b := a + a := temp1 + temp2 + */ + + h = g; + g = f; + f = new_e; + e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::>()); + d = c; + c = b; + b = new_a; + a = Maybe::Deferred(temp1.iter().cloned().chain(temp2.iter().cloned()).collect::>()); } - } - let mut a = Maybe::Concrete(current_hash_value[0].clone()); - let mut b = current_hash_value[1].clone(); - let mut c = current_hash_value[2].clone(); - let mut d = current_hash_value[3].clone(); - let mut e = Maybe::Concrete(current_hash_value[4].clone()); - let mut f = current_hash_value[5].clone(); - let mut g = current_hash_value[6].clone(); - let mut h = current_hash_value[7].clone(); - - for i in 0..64 { - let cs = &mut cs.namespace(|| format!("compression round {}", i)); - - // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) - let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?; - let mut s1 = new_e.rotr(6); - s1 = s1.xor( - cs.namespace(|| "first xor for s1"), - &new_e.rotr(11) - )?; - s1 = s1.xor( - cs.namespace(|| "second xor for s1"), - &new_e.rotr(25) + /* + Add the compressed chunk to the current hash value: + h0 := h0 + a + h1 := h1 + b + h2 := h2 + c + h3 := h3 + d + h4 := h4 + e + h5 := h5 + f + h6 := h6 + g + h7 := h7 + h + */ + + let h0 = a.compute( + cs.namespace(|| "deferred h0 computation"), + &[current_hash_value[0].clone()] )?; - // ch := (e and f) xor ((not e) and g) - let ch = UInt32::sha256_ch( - cs.namespace(|| "ch"), - &new_e, - &f, - &g + let h1 = UInt32::addmany( + cs.namespace(|| "new h1"), + &[current_hash_value[1].clone(), b] )?; - // temp1 := h + S1 + ch + k[i] + w[i] - let temp1 = vec![ - h.clone(), - s1, - ch, - UInt32::constant(ROUND_CONSTANTS[i]), - w[i].clone() - ]; - - // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) - let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?; - let mut s0 = new_a.rotr(2); - s0 = s0.xor( - cs.namespace(|| "first xor for s0"), - &new_a.rotr(13) + let h2 = UInt32::addmany( + cs.namespace(|| "new h2"), + &[current_hash_value[2].clone(), c] )?; - s0 = s0.xor( - cs.namespace(|| "second xor for s0"), - &new_a.rotr(22) + + let h3 = UInt32::addmany( + cs.namespace(|| "new h3"), + &[current_hash_value[3].clone(), d] )?; - // maj := (a and b) xor (a and c) xor (b and c) - let maj = UInt32::sha256_maj( - cs.namespace(|| "maj"), - &new_a, - &b, - &c + let h4 = e.compute( + cs.namespace(|| "deferred h4 computation"), + &[current_hash_value[4].clone()] )?; - // temp2 := S0 + maj - let temp2 = vec![s0, maj]; + let h5 = UInt32::addmany( + cs.namespace(|| "new h5"), + &[current_hash_value[5].clone(), f] + )?; - /* - h := g - g := f - f := e - e := d + temp1 - d := c - c := b - b := a - a := temp1 + temp2 - */ + let h6 = UInt32::addmany( + cs.namespace(|| "new h6"), + &[current_hash_value[6].clone(), g] + )?; - h = g; - g = f; - f = new_e; - e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::>()); - d = c; - c = b; - b = new_a; - a = Maybe::Deferred(temp1.iter().cloned().chain(temp2.iter().cloned()).collect::>()); - } + let h7 = UInt32::addmany( + cs.namespace(|| "new h7"), + &[current_hash_value[7].clone(), h] + )?; - /* - Add the compressed chunk to the current hash value: - h0 := h0 + a - h1 := h1 + b - h2 := h2 + c - h3 := h3 + d - h4 := h4 + e - h5 := h5 + f - h6 := h6 + g - h7 := h7 + h - */ - - let h0 = a.compute( - cs.namespace(|| "deferred h0 computation"), - &[current_hash_value[0].clone()] - )?; - - let h1 = UInt32::addmany( - cs.namespace(|| "new h1"), - &[current_hash_value[1].clone(), b] - )?; - - let h2 = UInt32::addmany( - cs.namespace(|| "new h2"), - &[current_hash_value[2].clone(), c] - )?; - - let h3 = UInt32::addmany( - cs.namespace(|| "new h3"), - &[current_hash_value[3].clone(), d] - )?; - - let h4 = e.compute( - cs.namespace(|| "deferred h4 computation"), - &[current_hash_value[4].clone()] - )?; - - let h5 = UInt32::addmany( - cs.namespace(|| "new h5"), - &[current_hash_value[5].clone(), f] - )?; - - let h6 = UInt32::addmany( - cs.namespace(|| "new h6"), - &[current_hash_value[6].clone(), g] - )?; - - let h7 = UInt32::addmany( - cs.namespace(|| "new h7"), - &[current_hash_value[7].clone(), h] - )?; - - Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7]) + Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7]) + }) } #[cfg(test)] diff --git a/src/circuit/uint32.rs b/src/circuit/uint32.rs index fb0bfa92..16157748 100644 --- a/src/circuit/uint32.rs +++ b/src/circuit/uint32.rs @@ -419,7 +419,7 @@ mod test { use pairing::{Field}; use ::circuit::test::*; use bellman::{ConstraintSystem}; - use circuit::multieq::MultiEq; + use circuit::multieq::multi_eq; #[test] fn test_uint32_from_bits_be() { @@ -542,11 +542,9 @@ mod test { let mut expected = a.wrapping_add(b).wrapping_add(c); - let r = { - let mut cs = MultiEq::new(&mut cs); - let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap(); - r - }; + let r = multi_eq(&mut cs, |cs| { + UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap() + }); assert!(r.value == Some(expected)); @@ -584,11 +582,9 @@ mod test { let d_bit = UInt32::alloc(cs.namespace(|| "d_bit"), Some(d)).unwrap(); let r = a_bit.xor(cs.namespace(|| "xor"), &b_bit).unwrap(); - let r = { - let mut cs = MultiEq::new(&mut cs); - let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap(); - r - }; + let r = multi_eq(&mut cs, |cs| { + UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap() + }); assert!(cs.is_satisfied());