Skip to content

Commit

Permalink
refactor(metal): overload Bigint operators and improve code readability
Browse files Browse the repository at this point in the history
  • Loading branch information
moven0831 committed Jan 5, 2025
1 parent 22efc19 commit 6f98450
Show file tree
Hide file tree
Showing 16 changed files with 108 additions and 104 deletions.
34 changes: 34 additions & 0 deletions mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,37 @@ bool bigint_wide_gte(

return true;
}

BigInt get_bn254_basefield_modulus() {
BigInt modulus;
modulus.limbs[0] = BN254_BASEFIELD_MODULUS_LIMB_0;
modulus.limbs[1] = BN254_BASEFIELD_MODULUS_LIMB_1;
modulus.limbs[2] = BN254_BASEFIELD_MODULUS_LIMB_2;
modulus.limbs[3] = BN254_BASEFIELD_MODULUS_LIMB_3;
modulus.limbs[4] = BN254_BASEFIELD_MODULUS_LIMB_4;
modulus.limbs[5] = BN254_BASEFIELD_MODULUS_LIMB_5;
modulus.limbs[6] = BN254_BASEFIELD_MODULUS_LIMB_6;
modulus.limbs[7] = BN254_BASEFIELD_MODULUS_LIMB_7;
modulus.limbs[8] = BN254_BASEFIELD_MODULUS_LIMB_8;
modulus.limbs[9] = BN254_BASEFIELD_MODULUS_LIMB_9;
modulus.limbs[10] = BN254_BASEFIELD_MODULUS_LIMB_10;
modulus.limbs[11] = BN254_BASEFIELD_MODULUS_LIMB_11;
modulus.limbs[12] = BN254_BASEFIELD_MODULUS_LIMB_12;
modulus.limbs[13] = BN254_BASEFIELD_MODULUS_LIMB_13;
modulus.limbs[14] = BN254_BASEFIELD_MODULUS_LIMB_14;
modulus.limbs[15] = BN254_BASEFIELD_MODULUS_LIMB_15;
return modulus;
}

// Overload Operators
constexpr BigInt operator+(const BigInt lhs, const BigInt rhs) {
return bigint_add_unsafe(lhs, rhs);
}

constexpr BigInt operator-(const BigInt lhs, const BigInt rhs) {
return bigint_sub(lhs, rhs);
}

constexpr bool operator>=(const BigInt lhs, const BigInt rhs) {
return bigint_gte(lhs, rhs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ kernel void run(
device BigInt* result [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
BigInt res = bigint_add_unsafe(a, b);
result->limbs = res.limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt res = a + b;
*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ kernel void run(
device BigIntWide* result [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigIntWide res = bigint_add_wide(a, b);
result->limbs = res.limbs;
*result = res;
}
10 changes: 4 additions & 6 deletions mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ kernel void run(
device BigInt* result [[ buffer(2) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
BigInt res = bigint_sub(a, b);
result->limbs = res.limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt res = a - b;
*result = res;
}
20 changes: 10 additions & 10 deletions mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl.metal
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ kernel void run(
device BigInt* result_zr [[ buffer(9) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt p; p.limbs = prime->limbs;
BigInt x1; x1.limbs = a_xr->limbs;
BigInt y1; y1.limbs = a_yr->limbs;
BigInt z1; z1.limbs = a_zr->limbs;
BigInt x2; x2.limbs = b_xr->limbs;
BigInt y2; y2.limbs = b_yr->limbs;
BigInt z2; z2.limbs = b_zr->limbs;
BigInt p = *prime;
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
BigInt x2 = *b_xr;
BigInt y2 = *b_yr;
BigInt z2 = *b_zr;

Jacobian a; a.x = x1; a.y = y1; a.z = z1;
Jacobian b; b.x = x2; b.y = y2; b.z = z2;

Jacobian res = jacobian_add_2007_bl(a, b, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
*result_xr = res.x;
*result_yr = res.y;
*result_zr = res.z;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ kernel void run(
device BigInt* result_zr [[ buffer(6) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt p; p.limbs = prime->limbs;
BigInt x1; x1.limbs = a_xr->limbs;
BigInt y1; y1.limbs = a_yr->limbs;
BigInt z1; z1.limbs = a_zr->limbs;
BigInt p = *prime;
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;

Jacobian a; a.x = x1; a.y = y1; a.z = z1;

Jacobian res = jacobian_dbl_2009_l(a, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
*result_xr = res.x;
*result_yr = res.y;
*result_zr = res.z;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ kernel void run(
device BigInt* result_zr [[ buffer(8) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt p; p.limbs = prime->limbs;
BigInt x1; x1.limbs = a_xr->limbs;
BigInt y1; y1.limbs = a_yr->limbs;
BigInt z1; z1.limbs = a_zr->limbs;
BigInt x2; x2.limbs = b_xr->limbs;
BigInt y2; y2.limbs = b_yr->limbs;
BigInt p = *prime;
BigInt x1 = *a_xr;
BigInt y1 = *a_yr;
BigInt z1 = *a_zr;
BigInt x2 = *b_xr;
BigInt y2 = *b_yr;

Jacobian a; a.x = x1; a.y = y1; a.z = z1;
Affine b; b.x = x2; b.y = y2;

Jacobian res = jacobian_madd_2007_bl(a, b, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
*result_xr = res.x;
*result_yr = res.y;
*result_zr = res.z;
}
14 changes: 7 additions & 7 deletions mopro-msm/src/msm/metal_msm/shader/field/ff.metal
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ BigInt ff_add(
BigInt b,
BigInt p
) {
BigInt sum = bigint_add_unsafe(a, b);
BigInt sum = a + b;

BigInt res;
if (bigint_gte(sum, p)) {
if (sum >= p) {
// s = a + b - p
BigInt s = bigint_sub(sum, p);
BigInt s = sum - p;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = s.limbs[i];
}
Expand All @@ -34,17 +34,17 @@ BigInt ff_sub(
BigInt p
) {
// if a >= b
if (bigint_gte(a, b)) {
if (a >= b) {
// a - b
BigInt res = bigint_sub(a, b);
BigInt res = a - b;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = res.limbs[i];
}
return res;
} else {
// p - (b - a)
BigInt r = bigint_sub(b, a);
BigInt res = bigint_sub(p, r);
BigInt r = b - a;
BigInt res = p - r;
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = res.limbs[i];
}
Expand Down
11 changes: 4 additions & 7 deletions mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ kernel void run(
device BigInt* result [[ buffer(3) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;

BigInt res = ff_add(a, b, p);
result->limbs = res.limbs;
*result = res;
}
4 changes: 2 additions & 2 deletions mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ BigInt conditional_reduce(
BigInt x,
BigInt y
) {
if (bigint_gte(x, y)) {
return bigint_sub(x, y);
if (x >= y) {
return x - y;
}

return x;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ kernel void run(
device BigInt* result [[ buffer(3) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;

BigInt res = mont_mul_cios(a, b, p);
result->limbs = res.limbs;

*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@ kernel void run(
device BigInt* result [[ buffer(4) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;
array<uint, 1> cost_arr = *cost;

BigInt c = mont_mul_cios(a, a, p);
for (uint i = 1; i < cost_arr[0]; i ++) {
c = mont_mul_cios(c, a, p);
}
BigInt res = mont_mul_cios(c, b, p);
result->limbs = res.limbs;
*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ kernel void run(
device BigInt* result [[ buffer(3) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;

BigInt res = mont_mul_modified(a, b, p);
result->limbs = res.limbs;

*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@ kernel void run(
device BigInt* result [[ buffer(4) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;
array<uint, 1> cost_arr = *cost;

BigInt c = mont_mul_modified(a, a, p);
for (uint i = 1; i < cost_arr[0]; i ++) {
c = mont_mul_modified(c, a, p);
}
BigInt res = mont_mul_modified(c, b, p);
result->limbs = res.limbs;
*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ kernel void run(
device BigInt* result [[ buffer(3) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;

BigInt res = mont_mul_optimised(a, b, p);
result->limbs = res.limbs;

*result = res;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@ kernel void run(
device BigInt* result [[ buffer(4) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt a;
BigInt b;
BigInt p;
a.limbs = lhs->limbs;
b.limbs = rhs->limbs;
p.limbs = prime->limbs;
BigInt a = *lhs;
BigInt b = *rhs;
BigInt p = *prime;
array<uint, 1> cost_arr = *cost;

BigInt c = mont_mul_optimised(a, a, p);
for (uint i = 1; i < cost_arr[0]; i ++) {
c = mont_mul_optimised(c, a, p);
}
BigInt res = mont_mul_optimised(c, b, p);
result->limbs = res.limbs;
*result = res;
}

0 comments on commit 6f98450

Please sign in to comment.