From 119529f807d1b9881d8117692bcbf16f7fd4b4ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5rten=20Blankfors?= Date: Thu, 21 Mar 2024 13:27:13 +0100 Subject: [PATCH] feat: Implement missing common traits for public types It is recommended to eagerly implement some traits, listed in https://rust-lang.github.io/api-guidelines/interoperability.html#c-common-traits While many central types already implement some of these traits, this commit fills in the gap with the remaining ones. --- p256k1/src/ecdsa.rs | 56 +++++++++++- p256k1/src/errors.rs | 9 ++ p256k1/src/field.rs | 38 +++++++++ p256k1/src/keys.rs | 192 +++++++++++++++++++++++++++++++++++++++--- p256k1/src/point.rs | 2 + p256k1/src/scalar.rs | 2 + p256k1/src/schnorr.rs | 56 +++++++++++- 7 files changed, 340 insertions(+), 15 deletions(-) diff --git a/p256k1/src/ecdsa.rs b/p256k1/src/ecdsa.rs index e189865..2c8c60a 100644 --- a/p256k1/src/ecdsa.rs +++ b/p256k1/src/ecdsa.rs @@ -3,7 +3,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, Serializer, }; -use std::array::TryFromSliceError; +use std::{array::TryFromSliceError, hash::Hash}; use crate::_rename::{ secp256k1_ecdsa_sign, secp256k1_ecdsa_signature_parse_compact, @@ -25,12 +25,15 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } + impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } +impl std::error::Error for Error {} + impl From for Error { fn from(e: TryFromSliceError) -> Self { Error::TryFrom(e.to_string()) @@ -105,6 +108,32 @@ impl Display for Signature { } } +impl PartialEq for Signature { + fn eq(&self, other: &Self) -> bool { + self.to_bytes().eq(&other.to_bytes()) + } +} + +impl Eq for Signature {} + +impl PartialOrd for Signature { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Signature { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.to_bytes().cmp(&other.to_bytes()) + } +} + +impl Hash for Signature { + fn hash(&self, state: &mut H) { + state.write(&self.to_bytes()) + } +} + impl Serialize for Signature { fn serialize(&self, serializer: S) -> Result where @@ -197,7 +226,7 @@ mod tests { use super::*; use rand_core::{OsRng, RngCore}; use sha2::{Digest, Sha256}; - use std::thread; + use std::{collections::HashSet, thread}; #[test] fn signature_generation() { @@ -305,4 +334,27 @@ mod tests { assert!(dsig.verify(&hash, &public_key)); } + + #[test] + fn hash() { + let msg = [0u8; 32]; + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let signatures = private_keys.map(|pk| Signature::new(&msg, &pk).unwrap()); + + let signatures_hash_set: HashSet<_> = signatures.into(); + + assert_eq!(signatures_hash_set.len(), 5); + } + + #[test] + fn sort() { + let msg = [0u8; 32]; + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let mut signatures = private_keys.map(|pk| Signature::new(&msg, &pk).unwrap()); + signatures.sort(); + + for idx in 0..4 { + assert!(signatures[idx] < signatures[idx + 1]); + } + } } diff --git a/p256k1/src/errors.rs b/p256k1/src/errors.rs index 18e5110..39d374a 100644 --- a/p256k1/src/errors.rs +++ b/p256k1/src/errors.rs @@ -31,3 +31,12 @@ pub enum ConversionError { /// Error converting a base58-related value Base58(Base58Error), } + +impl Display for ConversionError { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for Base58Error {} +impl std::error::Error for ConversionError {} diff --git a/p256k1/src/field.rs b/p256k1/src/field.rs index 0aba451..05f296e 100644 --- a/p256k1/src/field.rs +++ b/p256k1/src/field.rs @@ -13,6 +13,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, Serializer, }; +use std::cmp::Ordering; use crate::_rename::{ secp256k1_fe_add, secp256k1_fe_cmp_var, secp256k1_fe_get_b32, secp256k1_fe_inv, @@ -44,6 +45,8 @@ impl Display for Error { } } +impl std::error::Error for Error {} + #[derive(Copy, Clone, Debug)] /** Element is a wrapper around libsecp256k1's internal secp256k1_fe struct. It provides a field element, which is like a scalar but not necessarily reduced modulo the group order @@ -170,6 +173,23 @@ impl PartialEq for Element { impl Eq for Element {} +impl PartialOrd for Element { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Element { + fn cmp(&self, other: &Self) -> Ordering { + match unsafe { secp256k1_fe_cmp_var(&self.fe, &other.fe) } { + -1 => Ordering::Less, + 0 => Ordering::Equal, + 1 => Ordering::Greater, + _ => panic!("secp256k1_fe_cmp_var returned unexpected result"), // Unreachable + } + } +} + impl Serialize for Element { fn serialize(&self, serializer: S) -> Result where @@ -657,6 +677,24 @@ mod tests { } } + #[test] + fn cmp() { + let left = Element::from(1); + let right = Element::from(2); + + assert!(left < right); + assert!(right > left); + } + + #[test] + fn sort() { + let sorted = [1, 2, 3, 4, 5].map(Element::from); + let mut unsorted = [4, 2, 3, 1, 5].map(Element::from); + unsorted.sort(); + + assert_eq!(unsorted, sorted); + } + #[test] fn base58() { let mut rng = OsRng::default(); diff --git a/p256k1/src/keys.rs b/p256k1/src/keys.rs index 4a9d449..37233c4 100644 --- a/p256k1/src/keys.rs +++ b/p256k1/src/keys.rs @@ -4,7 +4,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, Serializer, }; -use std::array::TryFromSliceError; +use std::{array::TryFromSliceError, hash::Hash}; use crate::_rename::{ secp256k1_ec_pubkey_create, secp256k1_ec_pubkey_parse, secp256k1_ec_pubkey_serialize, @@ -33,12 +33,15 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } + impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } +impl std::error::Error for Error {} + impl From for Error { fn from(e: TryFromSliceError) -> Self { Error::TryFrom(e.to_string()) @@ -104,6 +107,32 @@ impl Display for PublicKey { } } +impl PartialEq for PublicKey { + fn eq(&self, other: &Self) -> bool { + self.to_bytes().eq(&other.to_bytes()) + } +} + +impl Eq for PublicKey {} + +impl PartialOrd for PublicKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PublicKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.to_bytes().cmp(&other.to_bytes()) + } +} + +impl Hash for PublicKey { + fn hash(&self, state: &mut H) { + state.write(&self.to_bytes()) + } +} + impl Serialize for PublicKey { fn serialize(&self, serializer: S) -> Result where @@ -155,6 +184,14 @@ impl<'de> Deserialize<'de> for PublicKey { } } +impl TryFrom<&Scalar> for PublicKey { + type Error = Error; + + fn try_from(value: &Scalar) -> Result { + Self::new(value) + } +} + impl TryFrom<&str> for PublicKey { type Error = Error; /// Create a pubkey from the passed byte slice @@ -235,6 +272,32 @@ impl Display for XOnlyPublicKey { } } +impl PartialEq for XOnlyPublicKey { + fn eq(&self, other: &Self) -> bool { + self.to_bytes().eq(&other.to_bytes()) + } +} + +impl Eq for XOnlyPublicKey {} + +impl PartialOrd for XOnlyPublicKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for XOnlyPublicKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.to_bytes().cmp(&other.to_bytes()) + } +} + +impl Hash for XOnlyPublicKey { + fn hash(&self, state: &mut H) { + state.write(&self.to_bytes()) + } +} + impl Serialize for XOnlyPublicKey { fn serialize(&self, serializer: S) -> Result where @@ -286,6 +349,14 @@ impl<'de> Deserialize<'de> for XOnlyPublicKey { } } +impl TryFrom<&Scalar> for XOnlyPublicKey { + type Error = Error; + + fn try_from(value: &Scalar) -> Result { + Self::new(value) + } +} + impl TryFrom<&str> for XOnlyPublicKey { type Error = Error; /// Create a pubkey from the passed byte slice @@ -366,6 +437,14 @@ impl KeyPair { } } +impl TryFrom<&Scalar> for KeyPair { + type Error = Error; + + fn try_from(value: &Scalar) -> Result { + Self::new(value) + } +} + impl Debug for KeyPair { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("KeyPair") @@ -374,6 +453,36 @@ impl Debug for KeyPair { } } +impl PartialEq for KeyPair { + fn eq(&self, other: &Self) -> bool { + self == other + } +} + +impl Eq for KeyPair {} + +impl PartialOrd for KeyPair { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for KeyPair { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let p1: PublicKey = self.into(); + let p2: PublicKey = other.into(); + + p1.cmp(&p2) + } +} + +impl Hash for KeyPair { + fn hash(&self, state: &mut H) { + let p1: PublicKey = self.into(); + p1.hash(state); + } +} + impl From<&KeyPair> for Scalar { fn from(kp: &KeyPair) -> Scalar { let mut bytes = [0u8; 32]; @@ -419,6 +528,8 @@ impl From<&KeyPair> for XOnlyPublicKey { #[cfg(test)] mod tests { + use std::collections::HashSet; + use super::*; use crate::point::Point; use rand_core::OsRng; @@ -432,12 +543,12 @@ mod tests { //Serialize with try_from and deseriailze with to_bytes let pubkey2 = PublicKey::try_from(pubkey.to_bytes().as_slice()).unwrap(); - assert_eq!(pubkey2.to_bytes(), pubkey.to_bytes()); + assert_eq!(pubkey2, pubkey); //Serialize again with str let pubkey3 = PublicKey::try_from(format!("{}", &pubkey).as_str()).unwrap(); - assert_eq!(pubkey3.to_bytes(), pubkey.to_bytes()); - assert_eq!(pubkey3.to_bytes(), pubkey2.to_bytes()); + assert_eq!(pubkey3, pubkey); + assert_eq!(pubkey3, pubkey2); } #[test] @@ -449,12 +560,12 @@ mod tests { //Serialize with try_from and deseriailze with to_bytes let xopubkey2 = XOnlyPublicKey::try_from(xopubkey.to_bytes().as_slice()).unwrap(); - assert_eq!(xopubkey2.to_bytes(), xopubkey.to_bytes()); + assert_eq!(xopubkey2, xopubkey); //Serialize again with str let xopubkey3 = XOnlyPublicKey::try_from(format!("{}", &xopubkey).as_str()).unwrap(); - assert_eq!(xopubkey3.to_bytes(), xopubkey.to_bytes()); - assert_eq!(xopubkey3.to_bytes(), xopubkey2.to_bytes()); + assert_eq!(xopubkey3, xopubkey); + assert_eq!(xopubkey3, xopubkey2); } #[test] @@ -485,12 +596,12 @@ mod tests { let pubkey2 = PublicKey::new(&scalar).unwrap(); let pubkey3 = PublicKey::new(&seckey).unwrap(); - assert_eq!(scalar.to_bytes(), seckey.to_bytes()); + assert_eq!(scalar, seckey); assert_eq!(xopubkey.to_bytes(), point.x().to_bytes()); assert_eq!(xopubkey2.to_bytes(), point.x().to_bytes()); assert_eq!(xopubkey3.to_bytes(), point.x().to_bytes()); - assert_eq!(pubkey.to_bytes(), pubkey2.to_bytes()); - assert_eq!(pubkey.to_bytes(), pubkey3.to_bytes()); + assert_eq!(pubkey, pubkey2); + assert_eq!(pubkey, pubkey3); } #[test] @@ -501,13 +612,70 @@ mod tests { let ser = serde_json::to_string(&public_key).expect("failed to serialize"); let deser: PublicKey = serde_json::from_str(&ser).expect("failed to deserialize"); - assert_eq!(public_key.to_bytes(), deser.to_bytes()); + assert_eq!(public_key, deser); let xonly_public_key = XOnlyPublicKey::new(&private_key).expect("failed to create XOnlyPublicKey"); let xoser = serde_json::to_string(&xonly_public_key).expect("failed to serialize"); let xodeser: XOnlyPublicKey = serde_json::from_str(&xoser).expect("failed to deserialize"); - assert_eq!(xonly_public_key.to_bytes(), xodeser.to_bytes()); + assert_eq!(xonly_public_key, xodeser); + } + + #[test] + fn pubkey_hash() { + hash_test::(); + } + + #[test] + fn pubkey_sort() { + sort_test::(); + } + + #[test] + fn xonlykey_hash() { + hash_test::(); + } + + #[test] + fn xonlykey_sort() { + sort_test::(); + } + + #[test] + fn keypair_hash() { + hash_test::(); + } + + #[test] + fn keypair_sort() { + sort_test::(); + } + + fn hash_test() + where + K: for<'a> TryFrom<&'a Scalar> + Hash + Eq, + for<'a> >::Error: Debug, + { + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let public_keys = private_keys.map(|pk| K::try_from(&pk).unwrap()); + + let public_keys_hash_set: HashSet<_> = public_keys.into(); + + assert_eq!(public_keys_hash_set.len(), 5); + } + + fn sort_test() + where + K: for<'a> TryFrom<&'a Scalar> + Hash + Ord, + for<'a> >::Error: Debug, + { + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let mut public_keys = private_keys.map(|pk| K::try_from(&pk).unwrap()); + public_keys.sort(); + + for idx in 0..4 { + assert!(public_keys[idx] < public_keys[idx + 1]); + } } } diff --git a/p256k1/src/point.rs b/p256k1/src/point.rs index 0d7248b..4740548 100644 --- a/p256k1/src/point.rs +++ b/p256k1/src/point.rs @@ -85,6 +85,8 @@ impl Display for Error { } } +impl std::error::Error for Error {} + #[derive(Copy, Clone)] /** Point is a wrapper around libsecp256k1's internal secp256k1_gej struct. It provides a point on the secp256k1 curve in Jacobian coordinates. This allows for extremely fast curve point operations, and avoids expensive conversions from byte buffers. diff --git a/p256k1/src/scalar.rs b/p256k1/src/scalar.rs index 2e4f84f..47069e7 100644 --- a/p256k1/src/scalar.rs +++ b/p256k1/src/scalar.rs @@ -37,6 +37,8 @@ impl Display for Error { } } +impl std::error::Error for Error {} + #[derive(Copy, Clone, Debug)] /** Scalar is a wrapper around libsecp256k1's internal secp256k1_scalar struct. It provides a scalar modulo the group order. Storing scalars in this format avoids unnecessary conversions from byte bffers, which provides a significant performance enhancement. diff --git a/p256k1/src/schnorr.rs b/p256k1/src/schnorr.rs index 4be3b3f..40313fc 100644 --- a/p256k1/src/schnorr.rs +++ b/p256k1/src/schnorr.rs @@ -3,7 +3,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, Serializer, }; -use std::array::TryFromSliceError; +use std::{array::TryFromSliceError, hash::Hash}; use crate::_rename::{secp256k1_schnorrsig_sign32, secp256k1_schnorrsig_verify}; use crate::{ @@ -25,12 +25,15 @@ pub enum Error { /// Error converting a scalar Conversion(ConversionError), } + impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{:?}", self) } } +impl std::error::Error for Error {} + impl From for Error { fn from(e: TryFromSliceError) -> Self { Error::TryFrom(e.to_string()) @@ -103,6 +106,32 @@ impl Display for Signature { } } +impl PartialEq for Signature { + fn eq(&self, other: &Self) -> bool { + self.to_bytes().eq(&other.to_bytes()) + } +} + +impl Eq for Signature {} + +impl PartialOrd for Signature { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Signature { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.to_bytes().cmp(&other.to_bytes()) + } +} + +impl Hash for Signature { + fn hash(&self, state: &mut H) { + state.write(&self.to_bytes()) + } +} + impl Serialize for Signature { fn serialize(&self, serializer: S) -> Result where @@ -172,6 +201,8 @@ impl From<[u8; 64]> for Signature { #[cfg(test)] mod tests { + use std::collections::HashSet; + use super::*; use rand_core::{OsRng, RngCore}; use sha2::{Digest, Sha256}; @@ -233,4 +264,27 @@ mod tests { assert!(dsig.verify(&hash, &public_key)); } + + #[test] + fn hash() { + let msg = [0u8; 32]; + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let signatures = private_keys.map(|pk| Signature::new(&msg, &pk).unwrap()); + + let signatures_hash_set: HashSet<_> = signatures.into(); + + assert_eq!(signatures_hash_set.len(), 5); + } + + #[test] + fn sort() { + let msg = [0u8; 32]; + let private_keys = [1, 2, 3, 4, 5].map(Scalar::from); + let mut signatures = private_keys.map(|pk| Signature::new(&msg, &pk).unwrap()); + signatures.sort(); + + for idx in 0..4 { + assert!(signatures[idx] < signatures[idx + 1]); + } + } }