From 32ba1c9e77245bce254c4558d188eb66d088f9eb Mon Sep 17 00:00:00 2001 From: Joey Yandle Date: Tue, 28 Nov 2023 13:16:04 -0500 Subject: [PATCH] Add custom serde for all structs (#75) * add custom serde for scalar point and field * use while let instead of loop with single match branch * impl custom serde with tests for ecdsa signature and public key types * add custom serde for schnorr sig * remove unnecessary context from ecsda signature, since it costs very little to create and having it there causes problems * remove hardwired std::fmt * add Display for structs and errors that were missing it * inc major version for release --- p256k1/Cargo.toml | 3 +- p256k1/src/ecdsa.rs | 111 ++++++++++++++++++++++++++++++------ p256k1/src/field.rs | 74 +++++++++++++++++++++++- p256k1/src/keys.rs | 128 +++++++++++++++++++++++++++++++++++++++++- p256k1/src/point.rs | 91 +++++++++++++++++++++++++++--- p256k1/src/scalar.rs | 71 ++++++++++++++++++++++- p256k1/src/schnorr.rs | 96 ++++++++++++++++++++++++++++--- 7 files changed, 535 insertions(+), 39 deletions(-) diff --git a/p256k1/Cargo.toml b/p256k1/Cargo.toml index 524543e..dca7da5 100644 --- a/p256k1/Cargo.toml +++ b/p256k1/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "p256k1" -version = "5.5.0" +version = "6.0.0" edition = "2021" authors = ["Joey Yandle "] license = "Apache-2.0" @@ -40,6 +40,7 @@ syn = { version = "2.0.10", features = ["full"] } [dev-dependencies] libc = "0.2" criterion = "0.4.0" +serde_json = "1.0" [[bench]] name = "point_bench" diff --git a/p256k1/src/ecdsa.rs b/p256k1/src/ecdsa.rs index 62c3d87..e189865 100644 --- a/p256k1/src/ecdsa.rs +++ b/p256k1/src/ecdsa.rs @@ -1,13 +1,17 @@ +use core::fmt::{Debug, Display, Formatter, Result as FmtResult}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use std::array::TryFromSliceError; use crate::_rename::{ secp256k1_ecdsa_sign, secp256k1_ecdsa_signature_parse_compact, secp256k1_ecdsa_signature_serialize_compact, secp256k1_ecdsa_verify, }; -use crate::bindings::secp256k1_ecdsa_signature; -use crate::context::Context; -use crate::errors::ConversionError; -use crate::scalar::Scalar; +use crate::{ + bindings::secp256k1_ecdsa_signature, context::Context, errors::ConversionError, scalar::Scalar, +}; pub use crate::keys::{Error as KeyError, PublicKey}; @@ -21,8 +25,8 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } @@ -36,23 +40,22 @@ impl From for Error { /** Signature is a wrapper around libsecp256k1's secp256k1_ecdsa_signature struct. */ +#[derive(Debug, Clone)] pub struct Signature { /// The wrapped libsecp256k1 signature pub signature: secp256k1_ecdsa_signature, - /// The context associated with the signature - pub context: Context, } impl Signature { /// Construct an ECDSA signature pub fn new(hash: &[u8], sec_key: &Scalar) -> Result { + let context = Context::default(); let mut sig = Self { signature: secp256k1_ecdsa_signature { data: [0; 64] }, - context: Context::default(), }; if unsafe { secp256k1_ecdsa_sign( - sig.context.context, + context.context, &mut sig.signature, hash.as_ptr(), sec_key.to_bytes().as_ptr(), @@ -68,9 +71,11 @@ impl Signature { /// Verify an ECDSA signature pub fn verify(&self, hash: &[u8], pub_key: &PublicKey) -> bool { + let context = Context::default(); + 1 == unsafe { secp256k1_ecdsa_verify( - self.context.context, + context.context, &self.signature, hash.as_ptr(), &pub_key.key, @@ -80,11 +85,12 @@ impl Signature { /// Returns the signature's deserialized underlying data pub fn to_bytes(&self) -> [u8; 64] { + let context = Context::default(); let mut bytes = [0u8; 64]; //Deserialize the signature's data unsafe { secp256k1_ecdsa_signature_serialize_compact( - self.context.context, + context.context, bytes.as_mut_ptr(), &self.signature, ); @@ -93,6 +99,63 @@ impl Signature { } } +impl Display for Signature { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", bs58::encode(self.to_bytes()).into_string()) + } +} + +impl Serialize for Signature { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct SignatureVisitor; + +impl<'de> Visitor<'de> for SignatureVisitor { + type Value = Signature; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents two scalars on the secp256k1 curve") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match Signature::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for Signature { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(SignatureVisitor) + } +} + impl TryFrom<&[u8]> for Signature { type Error = Error; /// Create an ECDSA signature given a slice of signed data. @@ -108,14 +171,14 @@ impl TryFrom<[u8; 64]> for Signature { /// Create an ECDSA signature given an array of signed data. /// Note it also serializes the data in compact (64 byte) format fn try_from(input: [u8; 64]) -> Result { + let context = Context::default(); let mut sig = Self { signature: secp256k1_ecdsa_signature { data: [0u8; 64] }, - context: Context::default(), }; //Attempt to serialize the data into the signature let parsed = unsafe { secp256k1_ecdsa_signature_parse_compact( - sig.context.context, + context.context, &mut sig.signature, input.as_ptr(), ) @@ -196,7 +259,6 @@ mod tests { let sig_from_struct = Signature { signature: secp256k1_ecdsa_signature { data: bytes }, - context: Context::default(), }; let sig_from_slice = Signature::try_from(bytes.as_slice()).unwrap(); let sig_from_array = Signature::try_from(bytes).unwrap(); @@ -215,7 +277,7 @@ mod tests { } #[test] - fn signature_serde() { + fn manual_serde() { // Generate random data bytes let mut rng = OsRng::default(); let mut bytes = [0u8; 64]; @@ -226,4 +288,21 @@ mod tests { assert_ne!(sig.signature.data, bytes); assert_eq!(sig.to_bytes(), bytes); } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let mut hash = [0u8; 32]; + rng.fill_bytes(&mut hash); + let private_key = Scalar::random(&mut rng); + let public_key = PublicKey::new(&private_key).expect("failed to create public key"); + let sig = Signature::new(&hash, &private_key).expect("failed to create sig"); + + assert!(sig.verify(&hash, &public_key)); + + let ssig = serde_json::to_string(&sig).expect("failed to serialize"); + let dsig: Signature = serde_json::from_str(&ssig).expect("failed to deserialize"); + + assert!(dsig.verify(&hash, &public_key)); + } } diff --git a/p256k1/src/field.rs b/p256k1/src/field.rs index 12c3d29..30bb079 100644 --- a/p256k1/src/field.rs +++ b/p256k1/src/field.rs @@ -9,6 +9,10 @@ use core::{ }; use num_traits::{One, Zero}; use rand_core::{CryptoRng, RngCore}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use crate::_rename::{ secp256k1_fe_add, secp256k1_fe_cmp_var, secp256k1_fe_get_b32, secp256k1_fe_inv, @@ -34,7 +38,13 @@ pub enum Error { Conversion(ConversionError), } -#[derive(Copy, Clone, Debug, serde::Serialize, serde::Deserialize)] +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self) + } +} + +#[derive(Copy, Clone, Debug)] /** Element is a wrapper around libsecp256k1's internal secp256k1_fe struct. It provides a field element, which is like a scalar but not necessarily reduced modulo the group order */ @@ -160,6 +170,58 @@ impl PartialEq for Element { impl Eq for Element {} +impl Serialize for Element { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct ElementVisitor; + +impl<'de> Visitor<'de> for ElementVisitor { + type Value = Element; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter + .write_str("an array of bytes which represents field element for the secp256k1 curve") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match Self::Value::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for Element { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(ElementVisitor) + } +} + impl Hash for Element { fn hash(&self, state: &mut H) { state.write(&self.to_bytes()[..]); @@ -608,4 +670,14 @@ mod tests { assert_eq!(a, c); assert_eq!(s, t); } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let x = Element::random(&mut rng); + let s = serde_json::to_string(&x).expect("failed to serialize"); + let y = serde_json::from_str(&s).expect("failed to deserialize"); + + assert_eq!(x, y); + } } diff --git a/p256k1/src/keys.rs b/p256k1/src/keys.rs index f2eff3e..b1e3700 100644 --- a/p256k1/src/keys.rs +++ b/p256k1/src/keys.rs @@ -1,5 +1,9 @@ use bs58; use core::fmt::{Debug, Display, Formatter, Result as FmtResult}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use std::array::TryFromSliceError; use crate::_rename::{ @@ -29,8 +33,8 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } @@ -100,6 +104,57 @@ impl Display for PublicKey { } } +impl Serialize for PublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct PublicKeyVisitor; + +impl<'de> Visitor<'de> for PublicKeyVisitor { + type Value = PublicKey; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents a ECDSA public key") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match PublicKey::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for PublicKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(PublicKeyVisitor) + } +} + impl TryFrom<&str> for PublicKey { type Error = Error; /// Create a pubkey from the passed byte slice @@ -180,6 +235,57 @@ impl Display for XOnlyPublicKey { } } +impl Serialize for XOnlyPublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct XOnlyPublicKeyVisitor; + +impl<'de> Visitor<'de> for XOnlyPublicKeyVisitor { + type Value = XOnlyPublicKey; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents a ECDSA public key") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match XOnlyPublicKey::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for XOnlyPublicKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(XOnlyPublicKeyVisitor) + } +} + impl TryFrom<&str> for XOnlyPublicKey { type Error = Error; /// Create a pubkey from the passed byte slice @@ -386,4 +492,22 @@ mod tests { assert_eq!(pubkey.to_bytes(), pubkey2.to_bytes()); assert_eq!(pubkey.to_bytes(), pubkey3.to_bytes()); } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let private_key = Scalar::random(&mut rng); + let public_key = PublicKey::new(&private_key).expect("failed to create public key"); + let ser = serde_json::to_string(&public_key).expect("failed to serialize"); + let deser: PublicKey = serde_json::from_str(&ser).expect("failed to deserialize"); + + assert_eq!(public_key.to_bytes(), deser.to_bytes()); + + let xonly_public_key = + XOnlyPublicKey::new(&private_key).expect("failed to create XOnlyPublicKey"); + let xoser = serde_json::to_string(&xonly_public_key).expect("failed to serialize"); + let xodeser: XOnlyPublicKey = serde_json::from_str(&xoser).expect("failed to deserialize"); + + assert_eq!(xonly_public_key.to_bytes(), xodeser.to_bytes()); + } } diff --git a/p256k1/src/point.rs b/p256k1/src/point.rs index b257bca..ee283c9 100644 --- a/p256k1/src/point.rs +++ b/p256k1/src/point.rs @@ -11,8 +11,18 @@ use core::{ }; use num_traits::Zero; use primitive_types::U256; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use std::os::raw::c_void; +use crate::_rename::{ + secp256k1_ecmult, secp256k1_ecmult_multi_var, secp256k1_fe_get_b32, secp256k1_fe_is_odd, + secp256k1_fe_normalize_var, secp256k1_fe_set_b32, secp256k1_ge_set_xo_var, + secp256k1_gej_add_var, secp256k1_gej_neg, secp256k1_gej_set_ge, secp256k1_scratch_space_create, + secp256k1_scratch_space_destroy, +}; use crate::{ bindings::{ secp256k1_callback, secp256k1_ecmult_multi_callback, secp256k1_fe, secp256k1_ge, @@ -26,13 +36,6 @@ use crate::{ traits::MultiMult, }; -use crate::_rename::{ - secp256k1_ecmult, secp256k1_ecmult_multi_var, secp256k1_fe_get_b32, secp256k1_fe_is_odd, - secp256k1_fe_normalize_var, secp256k1_fe_set_b32, secp256k1_ge_set_xo_var, - secp256k1_gej_add_var, secp256k1_gej_neg, secp256k1_gej_set_ge, secp256k1_scratch_space_create, - secp256k1_scratch_space_destroy, -}; - /// The secp256k1 base point pub const G: Point = Point { gej: secp256k1_gej { @@ -76,7 +79,13 @@ pub enum Error { LiftFailed, } -#[derive(Copy, Clone, serde::Serialize, serde::Deserialize)] +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self) + } +} + +#[derive(Copy, Clone)] /** Point is a wrapper around libsecp256k1's internal secp256k1_gej struct. It provides a point on the secp256k1 curve in Jacobian coordinates. This allows for extremely fast curve point operations, and avoids expensive conversions from byte buffers. */ @@ -337,6 +346,60 @@ impl PartialEq for Point { impl Eq for Point {} +impl Serialize for Point { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(self.compress().as_bytes()) + } +} + +struct PointVisitor; + +impl<'de> Visitor<'de> for PointVisitor { + type Value = Point; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents a point on the secp256k1 curve") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match Compressed::try_from(value) { + Ok(c) => match Point::try_from(&c) { + Ok(p) => Ok(p), + Err(e) => Err(E::custom(format!("{:?}", e))), + }, + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for Point { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(PointVisitor) + } +} + impl Hash for Point { fn hash(&self, state: &mut H) { state.write(&self.compress().data[..]); @@ -806,7 +869,7 @@ mod tests { } #[test] #[allow(non_snake_case)] - fn test_bip_340() { + fn bip_340() { let mut rng = OsRng::default(); for _ in 0..0xff { @@ -827,4 +890,14 @@ mod tests { } } } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let p = Point::from(Scalar::random(&mut rng)); + let s = serde_json::to_string(&p).expect("failed to serialize"); + let q = serde_json::from_str(&s).expect("failed to deserialize"); + + assert_eq!(p, q); + } } diff --git a/p256k1/src/scalar.rs b/p256k1/src/scalar.rs index 495bb18..25945f9 100644 --- a/p256k1/src/scalar.rs +++ b/p256k1/src/scalar.rs @@ -9,6 +9,10 @@ use core::{ }; use num_traits::{One, Zero}; use rand_core::{CryptoRng, RngCore}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use crate::_rename::{ secp256k1_ecmult, secp256k1_scalar_add, secp256k1_scalar_eq, secp256k1_scalar_get_b32, @@ -27,13 +31,13 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } -#[derive(Copy, Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Copy, Clone, Debug)] /** Scalar is a wrapper around libsecp256k1's internal secp256k1_scalar struct. It provides a scalar modulo the group order. Storing scalars in this format avoids unnecessary conversions from byte bffers, which provides a significant performance enhancement. */ @@ -169,6 +173,57 @@ impl PartialEq for Scalar { impl Eq for Scalar {} +impl Serialize for Scalar { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct ScalarVisitor; + +impl<'de> Visitor<'de> for ScalarVisitor { + type Value = Scalar; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents a scalar for the secp256k1 curve") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match Self::Value::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for Scalar { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(ScalarVisitor) + } +} + impl Hash for Scalar { fn hash(&self, state: &mut H) { state.write(&self.to_bytes()[..]); @@ -770,4 +825,14 @@ mod tests { assert_eq!(a, c); assert_eq!(s, t); } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let x = Scalar::random(&mut rng); + let s = serde_json::to_string(&x).expect("failed to serialize"); + let y = serde_json::from_str(&s).expect("failed to deserialize"); + + assert_eq!(x, y); + } } diff --git a/p256k1/src/schnorr.rs b/p256k1/src/schnorr.rs index d118049..4be3b3f 100644 --- a/p256k1/src/schnorr.rs +++ b/p256k1/src/schnorr.rs @@ -1,10 +1,17 @@ +use core::fmt::{Debug, Display, Formatter, Result as FmtResult}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use std::array::TryFromSliceError; use crate::_rename::{secp256k1_schnorrsig_sign32, secp256k1_schnorrsig_verify}; -use crate::context::Context; -use crate::errors::ConversionError; -use crate::keys::{Error as KeyError, KeyPair, XOnlyPublicKey}; -use crate::scalar::Scalar; +use crate::{ + context::Context, + errors::ConversionError, + keys::{Error as KeyError, KeyPair, XOnlyPublicKey}, + scalar::Scalar, +}; #[derive(Debug, Clone)] /// Errors in Schnorr signature operations @@ -18,8 +25,8 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } @@ -39,6 +46,7 @@ impl From for Error { /** Signature is a wrapper around libsecp256k1's secp256k1_schnorr_signature struct. */ +#[derive(Debug, Clone)] pub struct Signature { /// The wrapped libsecp256k1 signature pub data: [u8; 64], @@ -89,6 +97,63 @@ impl Signature { } } +impl Display for Signature { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", bs58::encode(self.to_bytes()).into_string()) + } +} + +impl Serialize for Signature { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.to_bytes()) + } +} + +struct SignatureVisitor; + +impl<'de> Visitor<'de> for SignatureVisitor { + type Value = Signature; + + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { + formatter.write_str("an array of bytes which represents two scalars on the secp256k1 curve") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match Signature::try_from(value) { + Ok(s) => Ok(s), + Err(e) => Err(E::custom(format!("{:?}", e))), + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut v = Vec::new(); + + while let Ok(Some(x)) = seq.next_element() { + v.push(x); + } + + self.visit_bytes(&v) + } +} + +impl<'de> Deserialize<'de> for Signature { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_bytes(SignatureVisitor) + } +} + impl TryFrom<&[u8]> for Signature { type Error = Error; /// Create an Schnorr signature given a slice of signed data. @@ -108,7 +173,7 @@ impl From<[u8; 64]> for Signature { #[cfg(test)] mod tests { use super::*; - use rand_core::OsRng; + use rand_core::{OsRng, RngCore}; use sha2::{Digest, Sha256}; #[test] @@ -151,4 +216,21 @@ mod tests { let bad_msg_hash = hasher.finalize(); assert!(!sig.verify(&bad_msg_hash, &pub_key)); } + + #[test] + fn custom_serde() { + let mut rng = OsRng::default(); + let mut hash = [0u8; 32]; + rng.fill_bytes(&mut hash); + let private_key = Scalar::random(&mut rng); + let public_key = XOnlyPublicKey::new(&private_key).expect("failed to create public key"); + let sig = Signature::new(&hash, &private_key).expect("failed to create sig"); + + assert!(sig.verify(&hash, &public_key)); + + let ssig = serde_json::to_string(&sig).expect("failed to serialize"); + let dsig: Signature = serde_json::from_str(&ssig).expect("failed to deserialize"); + + assert!(dsig.verify(&hash, &public_key)); + } }