diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal index f6455f8..9ded858 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal @@ -129,3 +129,37 @@ bool bigint_wide_gte( return true; } + +BigInt get_bn254_basefield_modulus() { + BigInt modulus; + modulus.limbs[0] = BN254_BASEFIELD_MODULUS_LIMB_0; + modulus.limbs[1] = BN254_BASEFIELD_MODULUS_LIMB_1; + modulus.limbs[2] = BN254_BASEFIELD_MODULUS_LIMB_2; + modulus.limbs[3] = BN254_BASEFIELD_MODULUS_LIMB_3; + modulus.limbs[4] = BN254_BASEFIELD_MODULUS_LIMB_4; + modulus.limbs[5] = BN254_BASEFIELD_MODULUS_LIMB_5; + modulus.limbs[6] = BN254_BASEFIELD_MODULUS_LIMB_6; + modulus.limbs[7] = BN254_BASEFIELD_MODULUS_LIMB_7; + modulus.limbs[8] = BN254_BASEFIELD_MODULUS_LIMB_8; + modulus.limbs[9] = BN254_BASEFIELD_MODULUS_LIMB_9; + modulus.limbs[10] = BN254_BASEFIELD_MODULUS_LIMB_10; + modulus.limbs[11] = BN254_BASEFIELD_MODULUS_LIMB_11; + modulus.limbs[12] = BN254_BASEFIELD_MODULUS_LIMB_12; + modulus.limbs[13] = BN254_BASEFIELD_MODULUS_LIMB_13; + modulus.limbs[14] = BN254_BASEFIELD_MODULUS_LIMB_14; + modulus.limbs[15] = BN254_BASEFIELD_MODULUS_LIMB_15; + return modulus; +} + +// Overload Operators +constexpr BigInt operator+(const BigInt lhs, const BigInt rhs) { + return bigint_add_unsafe(lhs, rhs); +} + +constexpr BigInt operator-(const BigInt lhs, const BigInt rhs) { + return bigint_sub(lhs, rhs); +} + +constexpr bool operator>=(const BigInt lhs, const BigInt rhs) { + return bigint_gte(lhs, rhs); +} diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal index 7791963..ce54857 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal @@ -11,10 +11,8 @@ kernel void run( device BigInt* result [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - BigInt res = bigint_add_unsafe(a, b); - result->limbs = res.limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt res = a + b; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal index 30ffd35..f150055 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal @@ -11,10 +11,8 @@ kernel void run( device BigIntWide* result [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; + BigInt a = *lhs; + BigInt b = *rhs; BigIntWide res = bigint_add_wide(a, b); - result->limbs = res.limbs; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal index 552f012..9621d1c 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal @@ -11,10 +11,8 @@ kernel void run( device BigInt* result [[ buffer(2) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - BigInt res = bigint_sub(a, b); - result->limbs = res.limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt res = a - b; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal index 38ecd6f..313c9c9 100644 --- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal @@ -18,19 +18,19 @@ kernel void run( device BigInt* result_zr [[ buffer(9) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt p; p.limbs = prime->limbs; - BigInt x1; x1.limbs = a_xr->limbs; - BigInt y1; y1.limbs = a_yr->limbs; - BigInt z1; z1.limbs = a_zr->limbs; - BigInt x2; x2.limbs = b_xr->limbs; - BigInt y2; y2.limbs = b_yr->limbs; - BigInt z2; z2.limbs = b_zr->limbs; + BigInt p = *prime; + BigInt x1 = *a_xr; + BigInt y1 = *a_yr; + BigInt z1 = *a_zr; + BigInt x2 = *b_xr; + BigInt y2 = *b_yr; + BigInt z2 = *b_zr; Jacobian a; a.x = x1; a.y = y1; a.z = z1; Jacobian b; b.x = x2; b.y = y2; b.z = z2; Jacobian res = jacobian_add_2007_bl(a, b, p); - result_xr->limbs = res.x.limbs; - result_yr->limbs = res.y.limbs; - result_zr->limbs = res.z.limbs; + *result_xr = res.x; + *result_yr = res.y; + *result_zr = res.z; } diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal index c6dcad1..567368a 100644 --- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal @@ -15,15 +15,15 @@ kernel void run( device BigInt* result_zr [[ buffer(6) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt p; p.limbs = prime->limbs; - BigInt x1; x1.limbs = a_xr->limbs; - BigInt y1; y1.limbs = a_yr->limbs; - BigInt z1; z1.limbs = a_zr->limbs; + BigInt p = *prime; + BigInt x1 = *a_xr; + BigInt y1 = *a_yr; + BigInt z1 = *a_zr; Jacobian a; a.x = x1; a.y = y1; a.z = z1; Jacobian res = jacobian_dbl_2009_l(a, p); - result_xr->limbs = res.x.limbs; - result_yr->limbs = res.y.limbs; - result_zr->limbs = res.z.limbs; + *result_xr = res.x; + *result_yr = res.y; + *result_zr = res.z; } diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal index b122cc0..523acdd 100644 --- a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal @@ -17,18 +17,18 @@ kernel void run( device BigInt* result_zr [[ buffer(8) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt p; p.limbs = prime->limbs; - BigInt x1; x1.limbs = a_xr->limbs; - BigInt y1; y1.limbs = a_yr->limbs; - BigInt z1; z1.limbs = a_zr->limbs; - BigInt x2; x2.limbs = b_xr->limbs; - BigInt y2; y2.limbs = b_yr->limbs; + BigInt p = *prime; + BigInt x1 = *a_xr; + BigInt y1 = *a_yr; + BigInt z1 = *a_zr; + BigInt x2 = *b_xr; + BigInt y2 = *b_yr; Jacobian a; a.x = x1; a.y = y1; a.z = z1; Affine b; b.x = x2; b.y = y2; Jacobian res = jacobian_madd_2007_bl(a, b, p); - result_xr->limbs = res.x.limbs; - result_yr->limbs = res.y.limbs; - result_zr->limbs = res.z.limbs; + *result_xr = res.x; + *result_yr = res.y; + *result_zr = res.z; } diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal index 43a05d3..d46bcfb 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal @@ -10,12 +10,12 @@ BigInt ff_add( BigInt b, BigInt p ) { - BigInt sum = bigint_add_unsafe(a, b); + BigInt sum = a + b; BigInt res; - if (bigint_gte(sum, p)) { + if (sum >= p) { // s = a + b - p - BigInt s = bigint_sub(sum, p); + BigInt s = sum - p; for (uint i = 0; i < NUM_LIMBS; i ++) { res.limbs[i] = s.limbs[i]; } @@ -34,17 +34,17 @@ BigInt ff_sub( BigInt p ) { // if a >= b - if (bigint_gte(a, b)) { + if (a >= b) { // a - b - BigInt res = bigint_sub(a, b); + BigInt res = a - b; for (uint i = 0; i < NUM_LIMBS; i ++) { res.limbs[i] = res.limbs[i]; } return res; } else { // p - (b - a) - BigInt r = bigint_sub(b, a); - BigInt res = bigint_sub(p, r); + BigInt r = b - a; + BigInt res = p - r; for (uint i = 0; i < NUM_LIMBS; i ++) { res.limbs[i] = res.limbs[i]; } diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal index b3539ba..a10a17e 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal @@ -12,13 +12,10 @@ kernel void run( device BigInt* result [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; BigInt res = ff_add(a, b, p); - result->limbs = res.limbs; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal index 1783ae0..721e7d7 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal @@ -9,8 +9,8 @@ BigInt conditional_reduce( BigInt x, BigInt y ) { - if (bigint_gte(x, y)) { - return bigint_sub(x, y); + if (x >= y) { + return x - y; } return x; diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal index ec0c2ee..140474d 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios.metal @@ -12,14 +12,10 @@ kernel void run( device BigInt* result [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; BigInt res = mont_mul_cios(a, b, p); - result->limbs = res.limbs; - + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal index 74c675e..6b49d7c 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal @@ -11,12 +11,9 @@ kernel void run( device BigInt* result [[ buffer(4) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; array cost_arr = *cost; BigInt c = mont_mul_cios(a, a, p); @@ -24,5 +21,5 @@ kernel void run( c = mont_mul_cios(c, a, p); } BigInt res = mont_mul_cios(c, b, p); - result->limbs = res.limbs; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal index 77020d1..62d70f0 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal @@ -12,14 +12,10 @@ kernel void run( device BigInt* result [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; BigInt res = mont_mul_modified(a, b, p); - result->limbs = res.limbs; - + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified_benchmarks.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified_benchmarks.metal index 3d3c275..dc7047e 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified_benchmarks.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified_benchmarks.metal @@ -11,12 +11,9 @@ kernel void run( device BigInt* result [[ buffer(4) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; array cost_arr = *cost; BigInt c = mont_mul_modified(a, a, p); @@ -24,5 +21,5 @@ kernel void run( c = mont_mul_modified(c, a, p); } BigInt res = mont_mul_modified(c, b, p); - result->limbs = res.limbs; + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal index ffb0844..eaa95fc 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal @@ -12,14 +12,10 @@ kernel void run( device BigInt* result [[ buffer(3) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; BigInt res = mont_mul_optimised(a, b, p); - result->limbs = res.limbs; - + *result = res; } diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised_benchmarks.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised_benchmarks.metal index 18057eb..d1ce801 100644 --- a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised_benchmarks.metal +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised_benchmarks.metal @@ -11,12 +11,9 @@ kernel void run( device BigInt* result [[ buffer(4) ]], uint gid [[ thread_position_in_grid ]] ) { - BigInt a; - BigInt b; - BigInt p; - a.limbs = lhs->limbs; - b.limbs = rhs->limbs; - p.limbs = prime->limbs; + BigInt a = *lhs; + BigInt b = *rhs; + BigInt p = *prime; array cost_arr = *cost; BigInt c = mont_mul_optimised(a, a, p); @@ -24,5 +21,5 @@ kernel void run( c = mont_mul_optimised(c, a, p); } BigInt res = mont_mul_optimised(c, b, p); - result->limbs = res.limbs; + *result = res; }