-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(curve): impl jacobian_madd_2007_bl algorithm and corresponding t…
…ests
- Loading branch information
Showing
4 changed files
with
282 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// source: https://github.com/geometryxyz/msl-secp256k1 | ||
|
||
using namespace metal; | ||
#include <metal_stdlib> | ||
#include <metal_math> | ||
#include "jacobian.metal" | ||
|
||
kernel void run( | ||
device BigInt* prime [[ buffer(0) ]], | ||
device BigInt* a_xr [[ buffer(1) ]], | ||
device BigInt* a_yr [[ buffer(2) ]], | ||
device BigInt* a_zr [[ buffer(3) ]], | ||
device BigInt* b_xr [[ buffer(4) ]], | ||
device BigInt* b_yr [[ buffer(5) ]], | ||
device BigInt* result_xr [[ buffer(6) ]], | ||
device BigInt* result_yr [[ buffer(7) ]], | ||
device BigInt* result_zr [[ buffer(8) ]], | ||
uint gid [[ thread_position_in_grid ]] | ||
) { | ||
BigInt p; p.limbs = prime->limbs; | ||
BigInt x1; x1.limbs = a_xr->limbs; | ||
BigInt y1; y1.limbs = a_yr->limbs; | ||
BigInt z1; z1.limbs = a_zr->limbs; | ||
BigInt x2; x2.limbs = b_xr->limbs; | ||
BigInt y2; y2.limbs = b_yr->limbs; | ||
|
||
Jacobian a; a.x = x1; a.y = y1; a.z = z1; | ||
Affine b; b.x = x2; b.y = y2; | ||
|
||
Jacobian res = jacobian_madd_2007_bl(a, b, p); | ||
result_xr->limbs = res.x.limbs; | ||
result_yr->limbs = res.y.limbs; | ||
result_zr->limbs = res.z.limbs; | ||
} |
172 changes: 172 additions & 0 deletions
172
mopro-msm/src/msm/metal_msm/tests/curve/jacobian_madd_2007_bl.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
// adapted from https://github.com/geometryxyz/msl-secp256k1 | ||
|
||
use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as GAffine, G1Projective as G}; | ||
use ark_ec::{AffineRepr, CurveGroup}; | ||
use metal::*; | ||
|
||
use crate::msm::metal_msm::host::gpu::{ | ||
create_buffer, create_empty_buffer, get_default_device, read_buffer, | ||
}; | ||
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; | ||
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; | ||
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; | ||
use ark_ff::{BigInt, PrimeField}; | ||
use ark_std::{rand::thread_rng, UniformRand}; | ||
use num_bigint::BigUint; | ||
|
||
#[test] | ||
#[serial_test::serial] | ||
pub fn test_jacobian_add_2007_bl() { | ||
let log_limb_size = 16; | ||
let p: BigUint = BaseField::MODULUS.try_into().unwrap(); | ||
|
||
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; | ||
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; | ||
|
||
let r = calc_mont_radix(num_limbs, log_limb_size); | ||
let res = calc_rinv_and_n0(&p, &r, log_limb_size); | ||
let rinv = res.0; | ||
let n0 = res.1; | ||
let nsafe = calc_nsafe(log_limb_size); | ||
|
||
// Generate 2 random affine points | ||
let (a, b) = { | ||
let mut rng = thread_rng(); | ||
let base_point = GAffine::generator().into_group(); | ||
|
||
let s1 = ScalarField::rand(&mut rng); | ||
let mut s2 = ScalarField::rand(&mut rng); | ||
|
||
// Ensure s1 and s2 are different (if s1 == s2, we use pDBL instead of pADD) | ||
while s1 == s2 { | ||
s2 = ScalarField::rand(&mut rng); | ||
} | ||
|
||
(base_point * s1, base_point * s2) | ||
}; | ||
|
||
// Compute the sum in projective form using Arkworks | ||
let expected = a + b; | ||
|
||
// Set B into Affine form | ||
let b = b.into_affine(); | ||
|
||
let ax: BigUint = a.x.into(); | ||
let ay: BigUint = a.y.into(); | ||
let az: BigUint = a.z.into(); | ||
let bx: BigUint = b.x.into(); | ||
let by: BigUint = b.y.into(); | ||
|
||
let axr = (&ax * &r) % &p; | ||
let ayr = (&ay * &r) % &p; | ||
let azr = (&az * &r) % &p; | ||
let bxr = (&bx * &r) % &p; | ||
let byr = (&by * &r) % &p; | ||
|
||
let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size); | ||
let axr_limbs = ark_ff::BigInt::<4>::try_from(axr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let ayr_limbs = ark_ff::BigInt::<4>::try_from(ayr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let azr_limbs = ark_ff::BigInt::<4>::try_from(azr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let bxr_limbs = ark_ff::BigInt::<4>::try_from(bxr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let byr_limbs = ark_ff::BigInt::<4>::try_from(byr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
|
||
let device = get_default_device(); | ||
let prime_buf = create_buffer(&device, &p_limbs); | ||
let axr_buf = create_buffer(&device, &axr_limbs); | ||
let ayr_buf = create_buffer(&device, &ayr_limbs); | ||
let azr_buf = create_buffer(&device, &azr_limbs); | ||
let bxr_buf = create_buffer(&device, &bxr_limbs); | ||
let byr_buf = create_buffer(&device, &byr_limbs); | ||
let result_xr_buf = create_empty_buffer(&device, num_limbs); | ||
let result_yr_buf = create_empty_buffer(&device, num_limbs); | ||
let result_zr_buf = create_empty_buffer(&device, num_limbs); | ||
|
||
let command_queue = device.new_command_queue(); | ||
let command_buffer = command_queue.new_command_buffer(); | ||
|
||
let compute_pass_descriptor = ComputePassDescriptor::new(); | ||
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); | ||
|
||
write_constants( | ||
"../mopro-msm/src/msm/metal_msm/shader", | ||
num_limbs, | ||
log_limb_size, | ||
n0, | ||
nsafe, | ||
); | ||
let library_path = compile_metal( | ||
"../mopro-msm/src/msm/metal_msm/shader/curve", | ||
"jacobian_madd_2007_bl.metal", | ||
); | ||
let library = device.new_library_with_file(library_path).unwrap(); | ||
let kernel = library.get_function("run", None).unwrap(); | ||
|
||
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); | ||
pipeline_state_descriptor.set_compute_function(Some(&kernel)); | ||
|
||
let pipeline_state = device | ||
.new_compute_pipeline_state_with_function( | ||
pipeline_state_descriptor.compute_function().unwrap(), | ||
) | ||
.unwrap(); | ||
|
||
encoder.set_compute_pipeline_state(&pipeline_state); | ||
encoder.set_buffer(0, Some(&prime_buf), 0); | ||
encoder.set_buffer(1, Some(&axr_buf), 0); | ||
encoder.set_buffer(2, Some(&ayr_buf), 0); | ||
encoder.set_buffer(3, Some(&azr_buf), 0); | ||
encoder.set_buffer(4, Some(&bxr_buf), 0); | ||
encoder.set_buffer(5, Some(&byr_buf), 0); | ||
encoder.set_buffer(6, Some(&result_xr_buf), 0); | ||
encoder.set_buffer(7, Some(&result_yr_buf), 0); | ||
encoder.set_buffer(8, Some(&result_zr_buf), 0); | ||
|
||
let thread_group_count = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
let thread_group_size = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); | ||
encoder.end_encoding(); | ||
|
||
command_buffer.commit(); | ||
command_buffer.wait_until_completed(); | ||
|
||
let result_xr_limbs: Vec<u32> = read_buffer(&result_xr_buf, num_limbs); | ||
let result_yr_limbs: Vec<u32> = read_buffer(&result_yr_buf, num_limbs); | ||
let result_zr_limbs: Vec<u32> = read_buffer(&result_zr_buf, num_limbs); | ||
|
||
let result_xr: BigUint = BigInt::from_limbs(&result_xr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
let result_yr: BigUint = BigInt::from_limbs(&result_yr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
let result_zr: BigUint = BigInt::from_limbs(&result_zr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
|
||
let result_x = (result_xr * &rinv) % &p; | ||
let result_y = (result_yr * &rinv) % &p; | ||
let result_z = (result_zr * &rinv) % &p; | ||
|
||
let result = G::new(result_x.into(), result_y.into(), result_z.into()); | ||
assert!(result == expected); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
#[cfg(test)] | ||
pub mod jacobian_add_2007_b1_unsafe; | ||
pub mod jacobian_add_2007_b1; | ||
#[cfg(test)] | ||
pub mod jacobian_dbl_2009_l; | ||
#[cfg(test)] | ||
pub mod jacobian_madd_2007_bl; |