Skip to content

Commit

Permalink
feat(curve): impl jacobian_madd_2007_bl algorithm and corresponding t…
Browse files Browse the repository at this point in the history
…ests
  • Loading branch information
moven0831 committed Jan 4, 2025
1 parent 4805191 commit ebc0865
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 1 deletion.
73 changes: 73 additions & 0 deletions mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ struct Jacobian {
BigInt z;
};

struct Affine {
BigInt x;
BigInt y;
};

Jacobian jacobian_add_2007_bl(
Jacobian a,
Jacobian b,
Expand Down Expand Up @@ -99,3 +104,71 @@ Jacobian jacobian_dbl_2009_l(
result.z = z3;
return result;
}

//http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
Jacobian jacobian_madd_2007_bl(
Jacobian a,
Affine b,
BigInt p
) {
BigInt x1 = a.x;
BigInt y1 = a.y;
BigInt z1 = a.z;
BigInt x2 = b.x;
BigInt y2 = b.y;

// Z1Z1 = Z1^2
BigInt z1z1 = mont_mul_cios(z1, z1, p);

// U2 = X2*Z1Z1
BigInt u2 = mont_mul_cios(x2, z1z1, p);

// S2 = Y2*Z1*Z1Z1
BigInt temp_s2 = mont_mul_cios(y2, z1, p);
BigInt s2 = mont_mul_cios(temp_s2, z1z1, p);

// H = U2-X1
BigInt h = ff_sub(u2, x1, p);

// HH = H^2
BigInt hh = mont_mul_cios(h, h, p);

// I = 4*HH
BigInt i = ff_add(hh, hh, p); // *2
i = ff_add(i, i, p); // *4

// J = H*I
BigInt j = mont_mul_cios(h, i, p);

// r = 2*(S2-Y1)
BigInt s2_minus_y1 = ff_sub(s2, y1, p);
BigInt r = ff_add(s2_minus_y1, s2_minus_y1, p);

// V = X1*I
BigInt v = mont_mul_cios(x1, i, p);

// X3 = r^2-J-2*V
BigInt r2 = mont_mul_cios(r, r, p);
BigInt v2 = ff_add(v, v, p);
BigInt jv2 = ff_add(j, v2, p);
BigInt x3 = ff_sub(r2, jv2, p);

// Y3 = r*(V-X3)-2*Y1*J
BigInt v_minus_x3 = ff_sub(v, x3, p);
BigInt r_vmx3 = mont_mul_cios(r, v_minus_x3, p);
BigInt y1j = mont_mul_cios(y1, j, p);
BigInt y1j2 = ff_add(y1j, y1j, p);
BigInt y3 = ff_sub(r_vmx3, y1j2, p);

// Z3 = (Z1+H)^2-Z1Z1-HH
BigInt z1_plus_h = ff_add(z1, h, p);
BigInt z1_plus_h_squared = mont_mul_cios(z1_plus_h, z1_plus_h, p);
BigInt temp = ff_sub(z1_plus_h_squared, z1z1, p);
BigInt z3 = ff_sub(temp, hh, p);

Jacobian result;
result.x = x3;
result.y = y3;
result.z = z3;
return result;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// source: https://github.com/geometryxyz/msl-secp256k1

using namespace metal;
#include <metal_stdlib>
#include <metal_math>
#include "jacobian.metal"

kernel void run(
device BigInt* prime [[ buffer(0) ]],
device BigInt* a_xr [[ buffer(1) ]],
device BigInt* a_yr [[ buffer(2) ]],
device BigInt* a_zr [[ buffer(3) ]],
device BigInt* b_xr [[ buffer(4) ]],
device BigInt* b_yr [[ buffer(5) ]],
device BigInt* result_xr [[ buffer(6) ]],
device BigInt* result_yr [[ buffer(7) ]],
device BigInt* result_zr [[ buffer(8) ]],
uint gid [[ thread_position_in_grid ]]
) {
BigInt p; p.limbs = prime->limbs;
BigInt x1; x1.limbs = a_xr->limbs;
BigInt y1; y1.limbs = a_yr->limbs;
BigInt z1; z1.limbs = a_zr->limbs;
BigInt x2; x2.limbs = b_xr->limbs;
BigInt y2; y2.limbs = b_yr->limbs;

Jacobian a; a.x = x1; a.y = y1; a.z = z1;
Affine b; b.x = x2; b.y = y2;

Jacobian res = jacobian_madd_2007_bl(a, b, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
}
172 changes: 172 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// adapted from https://github.com/geometryxyz/msl-secp256k1

use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as GAffine, G1Projective as G};
use ark_ec::{AffineRepr, CurveGroup};
use metal::*;

use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_ff::{BigInt, PrimeField};
use ark_std::{rand::thread_rng, UniformRand};
use num_bigint::BigUint;

#[test]
#[serial_test::serial]
pub fn test_jacobian_add_2007_bl() {
let log_limb_size = 16;
let p: BigUint = BaseField::MODULUS.try_into().unwrap();

let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

let r = calc_mont_radix(num_limbs, log_limb_size);
let res = calc_rinv_and_n0(&p, &r, log_limb_size);
let rinv = res.0;
let n0 = res.1;
let nsafe = calc_nsafe(log_limb_size);

// Generate 2 random affine points
let (a, b) = {
let mut rng = thread_rng();
let base_point = GAffine::generator().into_group();

let s1 = ScalarField::rand(&mut rng);
let mut s2 = ScalarField::rand(&mut rng);

// Ensure s1 and s2 are different (if s1 == s2, we use pDBL instead of pADD)
while s1 == s2 {
s2 = ScalarField::rand(&mut rng);
}

(base_point * s1, base_point * s2)
};

// Compute the sum in projective form using Arkworks
let expected = a + b;

// Set B into Affine form
let b = b.into_affine();

let ax: BigUint = a.x.into();
let ay: BigUint = a.y.into();
let az: BigUint = a.z.into();
let bx: BigUint = b.x.into();
let by: BigUint = b.y.into();

let axr = (&ax * &r) % &p;
let ayr = (&ay * &r) % &p;
let azr = (&az * &r) % &p;
let bxr = (&bx * &r) % &p;
let byr = (&by * &r) % &p;

let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
let axr_limbs = ark_ff::BigInt::<4>::try_from(axr.clone())
.unwrap()
.to_limbs(num_limbs, log_limb_size);
let ayr_limbs = ark_ff::BigInt::<4>::try_from(ayr.clone())
.unwrap()
.to_limbs(num_limbs, log_limb_size);
let azr_limbs = ark_ff::BigInt::<4>::try_from(azr.clone())
.unwrap()
.to_limbs(num_limbs, log_limb_size);
let bxr_limbs = ark_ff::BigInt::<4>::try_from(bxr.clone())
.unwrap()
.to_limbs(num_limbs, log_limb_size);
let byr_limbs = ark_ff::BigInt::<4>::try_from(byr.clone())
.unwrap()
.to_limbs(num_limbs, log_limb_size);

let device = get_default_device();
let prime_buf = create_buffer(&device, &p_limbs);
let axr_buf = create_buffer(&device, &axr_limbs);
let ayr_buf = create_buffer(&device, &ayr_limbs);
let azr_buf = create_buffer(&device, &azr_limbs);
let bxr_buf = create_buffer(&device, &bxr_limbs);
let byr_buf = create_buffer(&device, &byr_limbs);
let result_xr_buf = create_empty_buffer(&device, num_limbs);
let result_yr_buf = create_empty_buffer(&device, num_limbs);
let result_zr_buf = create_empty_buffer(&device, num_limbs);

let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let compute_pass_descriptor = ComputePassDescriptor::new();
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
n0,
nsafe,
);
let library_path = compile_metal(
"../mopro-msm/src/msm/metal_msm/shader/curve",
"jacobian_madd_2007_bl.metal",
);
let library = device.new_library_with_file(library_path).unwrap();
let kernel = library.get_function("run", None).unwrap();

let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&kernel));

let pipeline_state = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();

encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&prime_buf), 0);
encoder.set_buffer(1, Some(&axr_buf), 0);
encoder.set_buffer(2, Some(&ayr_buf), 0);
encoder.set_buffer(3, Some(&azr_buf), 0);
encoder.set_buffer(4, Some(&bxr_buf), 0);
encoder.set_buffer(5, Some(&byr_buf), 0);
encoder.set_buffer(6, Some(&result_xr_buf), 0);
encoder.set_buffer(7, Some(&result_yr_buf), 0);
encoder.set_buffer(8, Some(&result_zr_buf), 0);

let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};

let thread_group_size = MTLSize {
width: 1,
height: 1,
depth: 1,
};

encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

command_buffer.commit();
command_buffer.wait_until_completed();

let result_xr_limbs: Vec<u32> = read_buffer(&result_xr_buf, num_limbs);
let result_yr_limbs: Vec<u32> = read_buffer(&result_yr_buf, num_limbs);
let result_zr_limbs: Vec<u32> = read_buffer(&result_zr_buf, num_limbs);

let result_xr: BigUint = BigInt::from_limbs(&result_xr_limbs, log_limb_size)
.try_into()
.unwrap();
let result_yr: BigUint = BigInt::from_limbs(&result_yr_limbs, log_limb_size)
.try_into()
.unwrap();
let result_zr: BigUint = BigInt::from_limbs(&result_zr_limbs, log_limb_size)
.try_into()
.unwrap();

let result_x = (result_xr * &rinv) % &p;
let result_y = (result_yr * &rinv) % &p;
let result_z = (result_zr * &rinv) % &p;

let result = G::new(result_x.into(), result_y.into(), result_z.into());
assert!(result == expected);
}
4 changes: 3 additions & 1 deletion mopro-msm/src/msm/metal_msm/tests/curve/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[cfg(test)]
pub mod jacobian_add_2007_b1_unsafe;
pub mod jacobian_add_2007_b1;
#[cfg(test)]
pub mod jacobian_dbl_2009_l;
#[cfg(test)]
pub mod jacobian_madd_2007_bl;

0 comments on commit ebc0865

Please sign in to comment.