Skip to content

Commit

Permalink
feat(curve): add pADD, pMADD, pDBL algo in Jacobian Coordinates
Browse files Browse the repository at this point in the history
* wip: rust host code for metal ec op

* chore(mont): ignore heavy test

* test(mont): change to suitable limb size

* refactor(curve): use mont_cios for Jacobian Coordinates

* feat(curve): impl jacobian_madd_2007_bl algorithm and corresponding tests

* feat(tests): add jacobian_add_2007_b1 test and refactor jacobian_dbl_2009_l test
  • Loading branch information
moven0831 authored Jan 4, 2025
1 parent dadc88c commit c225fcf
Show file tree
Hide file tree
Showing 11 changed files with 651 additions and 51 deletions.
122 changes: 98 additions & 24 deletions mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// source: https://github.com/geometryxyz/msl-secp256k1
// algorithms: https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html

using namespace metal;
#include <metal_stdlib>
Expand All @@ -11,7 +12,12 @@ struct Jacobian {
BigInt z;
};

Jacobian jacobian_add_2007_bl_unsafe(
struct Affine {
BigInt x;
BigInt y;
};

Jacobian jacobian_add_2007_bl(
Jacobian a,
Jacobian b,
BigInt p
Expand All @@ -23,36 +29,36 @@ Jacobian jacobian_add_2007_bl_unsafe(
BigInt y2 = b.y;
BigInt z2 = b.z;

BigInt z1z1 = mont_mul_optimised(z1, z1, p);
BigInt z2z2 = mont_mul_optimised(z2, z2, p);
BigInt u1 = mont_mul_optimised(x1, z2z2, p);
BigInt u2 = mont_mul_optimised(x2, z1z1, p);
BigInt y1z2 = mont_mul_optimised(y1, z2, p);
BigInt s1 = mont_mul_optimised(y1z2, z2z2, p);
BigInt z1z1 = mont_mul_cios(z1, z1, p);
BigInt z2z2 = mont_mul_cios(z2, z2, p);
BigInt u1 = mont_mul_cios(x1, z2z2, p);
BigInt u2 = mont_mul_cios(x2, z1z1, p);
BigInt y1z2 = mont_mul_cios(y1, z2, p);
BigInt s1 = mont_mul_cios(y1z2, z2z2, p);

BigInt y2z1 = mont_mul_optimised(y2, z1, p);
BigInt s2 = mont_mul_optimised(y2z1, z1z1, p);
BigInt y2z1 = mont_mul_cios(y2, z1, p);
BigInt s2 = mont_mul_cios(y2z1, z1z1, p);
BigInt h = ff_sub(u2, u1, p);
BigInt h2 = ff_add(h, h, p);
BigInt i = mont_mul_optimised(h2, h2, p);
BigInt j = mont_mul_optimised(h, i, p);
BigInt i = mont_mul_cios(h2, h2, p);
BigInt j = mont_mul_cios(h, i, p);

BigInt s2s1 = ff_sub(s2, s1, p);
BigInt r = ff_add(s2s1, s2s1, p);
BigInt v = mont_mul_optimised(u1, i, p);
BigInt v = mont_mul_cios(u1, i, p);
BigInt v2 = ff_add(v, v, p);
BigInt r2 = mont_mul_optimised(r, r, p);
BigInt r2 = mont_mul_cios(r, r, p);
BigInt jv2 = ff_add(j, v2, p);
BigInt x3 = ff_sub(r2, jv2, p);

BigInt vx3 = ff_sub(v, x3, p);
BigInt rvx3 = mont_mul_optimised(r, vx3, p);
BigInt rvx3 = mont_mul_cios(r, vx3, p);
BigInt s12 = ff_add(s1, s1, p);
BigInt s12j = mont_mul_optimised(s12, j, p);
BigInt s12j = mont_mul_cios(s12, j, p);
BigInt y3 = ff_sub(rvx3, s12j, p);

BigInt z1z2 = mont_mul_optimised(z1, z2, p);
BigInt z1z2h = mont_mul_optimised(z1z2, h, p);
BigInt z1z2 = mont_mul_cios(z1, z2, p);
BigInt z1z2h = mont_mul_cios(z1z2, h, p);
BigInt z3 = ff_add(z1z2h, z1z2h, p);

Jacobian result;
Expand All @@ -70,26 +76,26 @@ Jacobian jacobian_dbl_2009_l(
BigInt y = pt.y;
BigInt z = pt.z;

BigInt a = mont_mul_optimised(x, x, p);
BigInt b = mont_mul_optimised(y, y, p);
BigInt c = mont_mul_optimised(b, b, p);
BigInt a = mont_mul_cios(x, x, p);
BigInt b = mont_mul_cios(y, y, p);
BigInt c = mont_mul_cios(b, b, p);
BigInt x1b = ff_add(x, b, p);
BigInt x1b2 = mont_mul_optimised(x1b, x1b, p);
BigInt x1b2 = mont_mul_cios(x1b, x1b, p);
BigInt ac = ff_add(a, c, p);
BigInt x1b2ac = ff_sub(x1b2, ac, p);
BigInt d = ff_add(x1b2ac, x1b2ac, p);
BigInt a2 = ff_add(a, a, p);
BigInt e = ff_add(a2, a, p);
BigInt f = mont_mul_optimised(e, e, p);
BigInt f = mont_mul_cios(e, e, p);
BigInt d2 = ff_add(d, d, p);
BigInt x3 = ff_sub(f, d2, p);
BigInt c2 = ff_add(c, c, p);
BigInt c4 = ff_add(c2, c2, p);
BigInt c8 = ff_add(c4, c4, p);
BigInt dx3 = ff_sub(d, x3, p);
BigInt edx3 = mont_mul_optimised(e, dx3, p);
BigInt edx3 = mont_mul_cios(e, dx3, p);
BigInt y3 = ff_sub(edx3, c8, p);
BigInt y1z1 = mont_mul_optimised(y, z, p);
BigInt y1z1 = mont_mul_cios(y, z, p);
BigInt z3 = ff_add(y1z1, y1z1, p);

Jacobian result;
Expand All @@ -98,3 +104,71 @@ Jacobian jacobian_dbl_2009_l(
result.z = z3;
return result;
}

// Mixed point addition: a Jacobian point `a` plus a point `b` given only by (x, y).
// Formula "madd-2007-bl":
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
//
// Every field operation below goes through mont_mul_cios / ff_add / ff_sub with modulus `p`.
// NOTE(review): coordinates are presumably in Montgomery form (suggested by mont_mul_cios) — confirm.
// NOTE(review): the body is branch-free, so special inputs get no dedicated handling here —
// confirm that callers deal with them.
Jacobian jacobian_madd_2007_bl(
    Jacobian a,
    Affine b,
    BigInt p
) {
    // Unpack the operands into locals.
    BigInt ax = a.x;
    BigInt ay = a.y;
    BigInt az = a.z;
    BigInt bx = b.x;
    BigInt by = b.y;

    // Z1Z1 = Z1^2
    BigInt zz = mont_mul_cios(az, az, p);

    // U2 = X2*Z1Z1
    BigInt u2 = mont_mul_cios(bx, zz, p);

    // S2 = Y2*Z1*Z1Z1 (two multiplications)
    BigInt y2z1 = mont_mul_cios(by, az, p);
    BigInt s2 = mont_mul_cios(y2z1, zz, p);

    // H = U2-X1, HH = H^2
    BigInt h = ff_sub(u2, ax, p);
    BigInt hh = mont_mul_cios(h, h, p);

    // I = 4*HH, obtained by doubling twice
    BigInt hh_doubled = ff_add(hh, hh, p);
    BigInt i = ff_add(hh_doubled, hh_doubled, p);

    // J = H*I and V = X1*I
    BigInt j = mont_mul_cios(h, i, p);
    BigInt v = mont_mul_cios(ax, i, p);

    // r = 2*(S2-Y1)
    BigInt s_diff = ff_sub(s2, ay, p);
    BigInt r = ff_add(s_diff, s_diff, p);

    Jacobian sum;

    // X3 = r^2-J-2*V, evaluated as r^2 - (J + 2*V)
    BigInt r_sq = mont_mul_cios(r, r, p);
    BigInt v_doubled = ff_add(v, v, p);
    BigInt j_plus_2v = ff_add(j, v_doubled, p);
    BigInt x_out = ff_sub(r_sq, j_plus_2v, p);
    sum.x = x_out;

    // Y3 = r*(V-X3)-2*Y1*J
    BigInt v_less_x = ff_sub(v, x_out, p);
    BigInt r_term = mont_mul_cios(r, v_less_x, p);
    BigInt y1_j = mont_mul_cios(ay, j, p);
    BigInt y1_j_doubled = ff_add(y1_j, y1_j, p);
    sum.y = ff_sub(r_term, y1_j_doubled, p);

    // Z3 = (Z1+H)^2-Z1Z1-HH
    BigInt z_plus_h = ff_add(az, h, p);
    BigInt z_plus_h_sq = mont_mul_cios(z_plus_h, z_plus_h, p);
    BigInt without_zz = ff_sub(z_plus_h_sq, zz, p);
    sum.z = ff_sub(without_zz, hh, p);

    return sum;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ kernel void run(
Jacobian a; a.x = x1; a.y = y1; a.z = z1;
Jacobian b; b.x = x2; b.y = y2; b.z = z2;

Jacobian res = jacobian_add_2007_bl_unsafe(a, b, p);
Jacobian res = jacobian_add_2007_bl(a, b, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// source: https://github.com/geometryxyz/msl-secp256k1

using namespace metal;
#include <metal_stdlib>
#include <metal_math>
#include "jacobian.metal"

// Test kernel: reads one Jacobian point (buffers 1-3) and one (x, y) point (buffers 4-5),
// runs jacobian_madd_2007_bl on them with the modulus from buffer 0, and writes the
// resulting x / y / z limbs to buffers 6-8.
// `gid` is declared but never read.
// NOTE(review): the `_xr` / `_yr` / `_zr` suffixes presumably mean "value times the Montgomery
// radix R" — confirm against the host code.
kernel void run(
    device BigInt* prime [[ buffer(0) ]],
    device BigInt* a_xr [[ buffer(1) ]],
    device BigInt* a_yr [[ buffer(2) ]],
    device BigInt* a_zr [[ buffer(3) ]],
    device BigInt* b_xr [[ buffer(4) ]],
    device BigInt* b_yr [[ buffer(5) ]],
    device BigInt* result_xr [[ buffer(6) ]],
    device BigInt* result_yr [[ buffer(7) ]],
    device BigInt* result_zr [[ buffer(8) ]],
    uint gid [[ thread_position_in_grid ]]
) {
    // Modulus, copied limb-array-wise out of the device buffer.
    BigInt p;
    p.limbs = prime->limbs;

    // First operand: all three Jacobian coordinates.
    Jacobian a;
    a.x.limbs = a_xr->limbs;
    a.y.limbs = a_yr->limbs;
    a.z.limbs = a_zr->limbs;

    // Second operand: x and y only.
    Affine b;
    b.x.limbs = b_xr->limbs;
    b.y.limbs = b_yr->limbs;

    Jacobian res = jacobian_madd_2007_bl(a, b, p);

    // Publish the result back to the host-visible buffers.
    result_xr->limbs = res.x.limbs;
    result_yr->limbs = res.y.limbs;
    result_zr->limbs = res.z.limbs;
}
176 changes: 176 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// adapted from https://github.com/geometryxyz/msl-secp256k1

use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as GAffine, G1Projective as G};
use ark_ec::AffineRepr;
use metal::*;

use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_ff::{BigInt, PrimeField};
use ark_std::{rand::thread_rng, UniformRand};
use num_bigint::BigUint;

/// Checks the `jacobian_add_2007_bl.metal` kernel against arkworks' own point addition.
///
/// Two distinct random multiples of the G1 generator are taken, each coordinate is multiplied
/// by the Montgomery radix `r` (mod `p`) and split into limbs, the kernel is run with a single
/// thread, and the returned limbs are multiplied by `rinv` (mod `p`) before being compared
/// with `a + b` computed on the host.
#[test]
#[serial_test::serial]
pub fn test_jacobian_add_2007_bl() {
    let log_limb_size = 16;
    let p: BigUint = BaseField::MODULUS.try_into().unwrap();

    let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
    let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

    let r = calc_mont_radix(num_limbs, log_limb_size);
    let res = calc_rinv_and_n0(&p, &r, log_limb_size);
    let rinv = res.0;
    let n0 = res.1;
    let nsafe = calc_nsafe(log_limb_size);

    // Generate 2 random points as scalar multiples of the generator.
    let (a, b) = {
        let mut rng = thread_rng();
        let base_point = GAffine::generator().into_group();

        let s1 = ScalarField::rand(&mut rng);
        let mut s2 = ScalarField::rand(&mut rng);

        // Ensure s1 and s2 are different (if s1 == s2, we use pDBL instead of pADD)
        while s1 == s2 {
            s2 = ScalarField::rand(&mut rng);
        }

        (base_point * s1, base_point * s2)
    };

    // Reference result computed by arkworks.
    let expected = a + b;

    // Encodes one coordinate for the shader: multiply by r mod p, then split into limbs.
    let to_mont_limbs = |coord: BaseField| {
        let value: BigUint = coord.into();
        let value_r = (value * &r) % &p;
        BigInt::<4>::try_from(value_r)
            .unwrap()
            .to_limbs(num_limbs, log_limb_size)
    };

    let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
    let axr_limbs = to_mont_limbs(a.x);
    let ayr_limbs = to_mont_limbs(a.y);
    let azr_limbs = to_mont_limbs(a.z);
    let bxr_limbs = to_mont_limbs(b.x);
    let byr_limbs = to_mont_limbs(b.y);
    let bzr_limbs = to_mont_limbs(b.z);

    let device = get_default_device();
    let prime_buf = create_buffer(&device, &p_limbs);
    let axr_buf = create_buffer(&device, &axr_limbs);
    let ayr_buf = create_buffer(&device, &ayr_limbs);
    let azr_buf = create_buffer(&device, &azr_limbs);
    let bxr_buf = create_buffer(&device, &bxr_limbs);
    let byr_buf = create_buffer(&device, &byr_limbs);
    let bzr_buf = create_buffer(&device, &bzr_limbs);
    let result_xr_buf = create_empty_buffer(&device, num_limbs);
    let result_yr_buf = create_empty_buffer(&device, num_limbs);
    let result_zr_buf = create_empty_buffer(&device, num_limbs);

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let compute_pass_descriptor = ComputePassDescriptor::new();
    let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

    // The shader constants are written before the shader is compiled.
    write_constants(
        "../mopro-msm/src/msm/metal_msm/shader",
        num_limbs,
        log_limb_size,
        n0,
        nsafe,
    );
    let library_path = compile_metal(
        "../mopro-msm/src/msm/metal_msm/shader/curve",
        "jacobian_add_2007_bl.metal",
    );
    let library = device.new_library_with_file(library_path).unwrap();
    let kernel = library.get_function("run", None).unwrap();

    // Build the pipeline state straight from the kernel function; no descriptor round-trip
    // is needed to obtain the same function reference.
    let pipeline_state = device
        .new_compute_pipeline_state_with_function(&kernel)
        .unwrap();

    // Buffer indices must match the `[[ buffer(n) ]]` slots declared by the kernel.
    encoder.set_compute_pipeline_state(&pipeline_state);
    encoder.set_buffer(0, Some(&prime_buf), 0);
    encoder.set_buffer(1, Some(&axr_buf), 0);
    encoder.set_buffer(2, Some(&ayr_buf), 0);
    encoder.set_buffer(3, Some(&azr_buf), 0);
    encoder.set_buffer(4, Some(&bxr_buf), 0);
    encoder.set_buffer(5, Some(&byr_buf), 0);
    encoder.set_buffer(6, Some(&bzr_buf), 0);
    encoder.set_buffer(7, Some(&result_xr_buf), 0);
    encoder.set_buffer(8, Some(&result_yr_buf), 0);
    encoder.set_buffer(9, Some(&result_zr_buf), 0);

    // One thread group containing one thread: a single addition is performed.
    let single = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(single, single);
    encoder.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Decodes one coordinate from the shader: join the limbs, then multiply by rinv mod p.
    let from_mont_limbs = |limbs: Vec<u32>| -> BaseField {
        let value_r: BigUint = BigInt::from_limbs(&limbs, log_limb_size)
            .try_into()
            .unwrap();
        ((value_r * &rinv) % &p).into()
    };

    let result_xr_limbs: Vec<u32> = read_buffer(&result_xr_buf, num_limbs);
    let result_yr_limbs: Vec<u32> = read_buffer(&result_yr_buf, num_limbs);
    let result_zr_limbs: Vec<u32> = read_buffer(&result_zr_buf, num_limbs);

    let result = G::new(
        from_mont_limbs(result_xr_limbs),
        from_mont_limbs(result_yr_limbs),
        from_mont_limbs(result_zr_limbs),
    );
    assert!(
        result == expected,
        "GPU jacobian_add_2007_bl result does not match the arkworks reference sum"
    );
}
Loading

0 comments on commit c225fcf

Please sign in to comment.