diff --git a/docs/predicate/predicates.md b/docs/predicate/predicates.md index d5ec691f..7c7412cf 100644 --- a/docs/predicate/predicates.md +++ b/docs/predicate/predicates.md @@ -198,7 +198,83 @@ Examples: - did_components('did:example:123456?versionId=1', did(Method, ID, Path, Query, Fragment)). # Reconstruct a DID from its components. -- did_components(DID, did('example', '123456', null, 'versionId=1', _42)). +- did_components(DID, did('example', '123456', _, 'versionId=1', _42)). +``` + +## ecdsa_verify/4 + +ecdsa_verify/4 determines if a given signature is valid as per the ECDSA algorithm for the provided data, using the specified public key. + +The signature is as follows: + +```text +ecdsa_verify(+PubKey, +Data, +Signature, +Options), which is semi-deterministic. +``` + +Where: + +- PubKey is the 33\-byte compressed public key, as specified in section 4.3.6 of ANSI X9.62. + +- Data is the hash of the signed message, which can be either an atom or a list of bytes. + +- Signature represents the ASN.1 encoded signature corresponding to the Data. + +- Options are additional configurations for the verification process. Supported options include: encoding\(\+Format\) which specifies the encoding used for the data, and type\(\+Alg\) which chooses the algorithm within the ECDSA family \(see below for details\). + +For Format, the supported encodings are: + +- hex \(default\), the hexadecimal encoding represented as an atom. +- octet, the plain byte encoding depicted as a list of integers ranging from 0 to 255. + +For Alg, the supported algorithms are: + +- secp256r1 \(default\): Also known as P\-256 and prime256v1. +- secp256k1: The Koblitz elliptic curve used in Bitcoin's public\-key cryptography. + +Examples: + +```text +# Verify a signature for hexadecimal data using the ECDSA secp256r1 algorithm. +- ecdsa_verify([127, ...], '9b038f8ef6918cbb56040dfda401b56b...', [23, 56, ...], encoding(hex)) + +# Verify a signature for binary data using the ECDSA secp256k1 algorithm. +- ecdsa_verify([127, ...], [56, 90, ..], [23, 56, ...], [encoding(octet), type(secp256k1)]) +``` + +## eddsa_verify/4 + +eddsa_verify/4 determines if a given signature is valid as per the EdDSA algorithm for the provided data, using the specified public key. + +The signature is as follows: + +```text +eddsa_verify(+PubKey, +Data, +Signature, +Options) is semi-det +``` + +Where: + +- PubKey is the encoded public key as a list of bytes. +- Data is the message to verify, represented as either a hexadecimal atom or a list of bytes. It's important that the message isn't pre\-hashed since the Ed25519 algorithm processes messages in two passes when signing. +- Signature represents the signature corresponding to the data, provided as a list of bytes. +- Options are additional configurations for the verification process. Supported options include: encoding\(\+Format\) which specifies the encoding used for the Data, and type\(\+Alg\) which chooses the algorithm within the EdDSA family \(see below for details\). + +For Format, the supported encodings are: + +- hex \(default\), the hexadecimal encoding represented as an atom. +- octet, the plain byte encoding depicted as a list of integers ranging from 0 to 255. + +For Alg, the supported algorithms are: + +- ed25519 \(default\): The EdDSA signature scheme using SHA\-512 \(SHA\-2\) and Curve25519. + +Examples: + +```text +# Verify a signature for a given hexadecimal data. +- eddsa_verify([127, ...], '9b038f8ef6918cbb56040dfda401b56b...', [23, 56, ...], [encoding(hex), type(ed25519)]) + +# Verify a signature for binary data. +- eddsa_verify([127, ...], [56, 90, ..], [23, 56, ...], [encoding(octet), type(ed25519)]) ``` ## hex_bytes/2 diff --git a/go.mod b/go.mod index b6f05623..f0f3dc0e 100644 --- a/go.mod +++ b/go.mod @@ -101,6 +101,7 @@ require ( github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/dustinxie/ecc v0.0.0-20210511000915-959544187564 // indirect github.com/dvsekhvalnov/jose2go v1.5.0 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/fatih/color v1.15.0 // indirect diff --git a/go.sum b/go.sum index 8e4d28f7..9a12ab65 100644 --- a/go.sum +++ b/go.sum @@ -508,6 +508,8 @@ github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:Htrtb github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/dustinxie/ecc v0.0.0-20210511000915-959544187564 h1:I6KUy4CI6hHjqnyJLNCEi7YHVMkwwtfSr2k9splgdSM= +github.com/dustinxie/ecc v0.0.0-20210511000915-959544187564/go.mod h1:yekO+3ZShy19S+bsmnERmznGy9Rfg6dWWWpiGJjNAz8= github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= diff --git a/x/logic/interpreter/registry.go b/x/logic/interpreter/registry.go index 63833c2d..4c0f84ce 100644 --- a/x/logic/interpreter/registry.go +++ b/x/logic/interpreter/registry.go @@ -116,6 +116,8 @@ var registry = map[string]any{ "json_prolog/2": predicate.JSONProlog, "uri_encoded/3": predicate.URIEncoded, "read_string/3": predicate.ReadString, + "eddsa_verify/4": predicate.EDDSAVerify, + "ecdsa_verify/4": predicate.ECDSAVerify, } // RegistryNames is the list of the predicate names in the Registry. diff --git a/x/logic/predicate/address_test.go b/x/logic/predicate/address_test.go index 08fc6fdd..f97306ca 100644 --- a/x/logic/predicate/address_test.go +++ b/x/logic/predicate/address_test.go @@ -89,7 +89,7 @@ func TestBech32(t *testing.T) { }, { query: `bech32_address(-('okp4', ['8956',167,23,244,162,175,49,162,170,15,181,141,68,134,141,168,18,56,247,30]), Bech32).`, - wantError: fmt.Errorf("bech32_address/2: failed to convert term to bytes list: invalid term type in list engine.Atom, only integer allowed"), + wantError: fmt.Errorf("bech32_address/2: failed to convert term to bytes list: invalid term type in list at position 1: engine.Atom, only engine.Integer allowed"), wantSuccess: false, }, { diff --git a/x/logic/predicate/atom.go b/x/logic/predicate/atom.go index 14e343f0..4f79fc45 100644 --- a/x/logic/predicate/atom.go +++ b/x/logic/predicate/atom.go @@ -4,29 +4,31 @@ import ( "github.com/ichiban/prolog/engine" ) -// AtomPair are terms with principal functor (-)/2. -// For example, the term -(A, B) denotes the pair of elements A and B. -var AtomPair = engine.NewAtom("-") +var ( + // AtomPair are terms with principal functor (-)/2. + // For example, the term -(A, B) denotes the pair of elements A and B. + AtomPair = engine.NewAtom("-") -// AtomJSON are terms with principal functor json/1. -// It is used to represent json objects. -var AtomJSON = engine.NewAtom("json") + // AtomJSON are terms with principal functor json/1. + // It is used to represent json objects. + AtomJSON = engine.NewAtom("json") -// AtomAt are terms with principal functor (@)/1. -// It is used to represent special values in json objects. -var AtomAt = engine.NewAtom("@") + // AtomAt are terms with principal functor (@)/1. + // It is used to represent special values in json objects. + AtomAt = engine.NewAtom("@") -// AtomTrue is the term true. -var AtomTrue = engine.NewAtom("true") + // AtomTrue is the term true. + AtomTrue = engine.NewAtom("true") -// AtomFalse is the term false. -var AtomFalse = engine.NewAtom("false") + // AtomFalse is the term false. + AtomFalse = engine.NewAtom("false") -// AtomEmptyArray is the term []. -var AtomEmptyArray = engine.NewAtom("[]") + // AtomEmptyArray is the term []. + AtomEmptyArray = engine.NewAtom("[]") -// AtomNull is the term null. -var AtomNull = engine.NewAtom("null") + // AtomNull is the term null. + AtomNull = engine.NewAtom("null") +) // MakeNull returns the compound term @(null). // It is used to represent the null value in json objects. diff --git a/x/logic/predicate/crypto.go b/x/logic/predicate/crypto.go index 0f9bfa3c..74f8f514 100644 --- a/x/logic/predicate/crypto.go +++ b/x/logic/predicate/crypto.go @@ -4,10 +4,12 @@ import ( "context" "encoding/hex" "fmt" + "slices" + "strings" "github.com/ichiban/prolog/engine" - "github.com/cometbft/cometbft/crypto" + cometcrypto "github.com/cometbft/cometbft/crypto" "github.com/okp4/okp4d/x/logic/util" ) @@ -35,7 +37,7 @@ func SHAHash(vm *engine.VM, data, hash engine.Term, cont engine.Cont, env *engin var result []byte switch d := env.Resolve(data).(type) { case engine.Atom: - result = crypto.Sha256([]byte(d.String())) + result = cometcrypto.Sha256([]byte(d.String())) return engine.Unify(vm, hash, BytesToList(result), cont, env) default: return engine.Error(fmt.Errorf("sha_hash/2: invalid data type: %T, should be Atom", d)) @@ -97,3 +99,132 @@ func HexBytes(vm *engine.VM, hexa, bts engine.Term, cont engine.Cont, env *engin } }) } + +// EDDSAVerify determines if a given signature is valid as per the EdDSA algorithm for the provided data, using the +// specified public key. +// +// The signature is as follows: +// +// eddsa_verify(+PubKey, +Data, +Signature, +Options) is semi-det +// +// Where: +// - PubKey is the encoded public key as a list of bytes. +// - Data is the message to verify, represented as either a hexadecimal atom or a list of bytes. +// It's important that the message isn't pre-hashed since the Ed25519 algorithm processes +// messages in two passes when signing. +// - Signature represents the signature corresponding to the data, provided as a list of bytes. +// - Options are additional configurations for the verification process. Supported options include: +// encoding(+Format) which specifies the encoding used for the Data, and type(+Alg) which chooses the algorithm +// within the EdDSA family (see below for details). +// +// For Format, the supported encodings are: +// +// - hex (default), the hexadecimal encoding represented as an atom. +// - octet, the plain byte encoding depicted as a list of integers ranging from 0 to 255. +// +// For Alg, the supported algorithms are: +// +// - ed25519 (default): The EdDSA signature scheme using SHA-512 (SHA-2) and Curve25519. +// +// Examples: +// +// # Verify a signature for a given hexadecimal data. +// - eddsa_verify([127, ...], '9b038f8ef6918cbb56040dfda401b56b...', [23, 56, ...], [encoding(hex), type(ed25519)]) +// +// # Verify a signature for binary data. +// - eddsa_verify([127, ...], [56, 90, ..], [23, 56, ...], [encoding(octet), type(ed25519)]) +func EDDSAVerify(_ *engine.VM, key, data, sig, options engine.Term, cont engine.Cont, env *engine.Env) *engine.Promise { + return xVerify("eddsa_verify/4", key, data, sig, options, util.Ed25519, []util.Alg{util.Ed25519}, cont, env) +} + +// ECDSAVerify determines if a given signature is valid as per the ECDSA algorithm for the provided data, using the +// specified public key. +// +// The signature is as follows: +// +// ecdsa_verify(+PubKey, +Data, +Signature, +Options), which is semi-deterministic. +// +// Where: +// +// - PubKey is the 33-byte compressed public key, as specified in section 4.3.6 of ANSI X9.62. +// +// - Data is the hash of the signed message, which can be either an atom or a list of bytes. +// +// - Signature represents the ASN.1 encoded signature corresponding to the Data. +// +// - Options are additional configurations for the verification process. Supported options include: +// encoding(+Format) which specifies the encoding used for the data, and type(+Alg) which chooses the algorithm +// within the ECDSA family (see below for details). +// +// For Format, the supported encodings are: +// +// - hex (default), the hexadecimal encoding represented as an atom. +// - octet, the plain byte encoding depicted as a list of integers ranging from 0 to 255. +// +// For Alg, the supported algorithms are: +// +// - secp256r1 (default): Also known as P-256 and prime256v1. +// - secp256k1: The Koblitz elliptic curve used in Bitcoin's public-key cryptography. +// +// Examples: +// +// # Verify a signature for hexadecimal data using the ECDSA secp256r1 algorithm. +// - ecdsa_verify([127, ...], '9b038f8ef6918cbb56040dfda401b56b...', [23, 56, ...], encoding(hex)) +// +// # Verify a signature for binary data using the ECDSA secp256k1 algorithm. +// - ecdsa_verify([127, ...], [56, 90, ..], [23, 56, ...], [encoding(octet), type(secp256k1)]) +func ECDSAVerify(_ *engine.VM, key, data, sig, options engine.Term, cont engine.Cont, env *engine.Env) *engine.Promise { + return xVerify("ecdsa_verify/4", key, data, sig, options, util.Secp256r1, []util.Alg{util.Secp256r1, util.Secp256k1}, cont, env) +} + +// xVerify return `true` if the Signature can be verified as the signature for Data, using the given PubKey for a +// considered algorithm. +// This is a generic predicate implementation that can be used to verify any signature. +func xVerify(functor string, key, data, sig, options engine.Term, defaultAlgo util.Alg, + algos []util.Alg, cont engine.Cont, env *engine.Env, +) *engine.Promise { + typeOpt := engine.NewAtom("type") + return engine.Delay(func(ctx context.Context) *engine.Promise { + typeTerm, err := util.GetOptionWithDefault(typeOpt, options, engine.NewAtom(defaultAlgo.String()), env) + if err != nil { + return engine.Error(fmt.Errorf("%s: %w", functor, err)) + } + typeAtom, err := util.ResolveToAtom(env, typeTerm) + if err != nil { + return engine.Error(fmt.Errorf("%s: %w", functor, err)) + } + + if idx := slices.IndexFunc(algos, func(a util.Alg) bool { return a.String() == typeAtom.String() }); idx == -1 { + return engine.Error(fmt.Errorf("%s: invalid type: %s. Possible values: %s", + functor, + typeAtom.String(), + strings.Join(util.Map(algos, func(a util.Alg) string { return a.String() }), ", "))) + } + + decodedKey, err := TermToBytes(key, AtomEncoding.Apply(AtomOctet), env) + if err != nil { + return engine.Error(fmt.Errorf("%s: failed to decode public key: %w", functor, err)) + } + + decodedData, err := TermToBytes(data, options, env) + if err != nil { + return engine.Error(fmt.Errorf("%s: failed to decode data: %w", functor, err)) + } + + decodedSignature, err := TermToBytes(sig, AtomEncoding.Apply(AtomOctet), env) + if err != nil { + return engine.Error(fmt.Errorf("%s: failed to decode signature: %w", functor, err)) + } + + r, err := util.VerifySignature(util.Alg(typeAtom.String()), decodedKey, decodedData, decodedSignature) + if err != nil { + return engine.Error(fmt.Errorf("%s: failed to verify signature: %w", functor, err)) + } + + if !r { + return engine.Bool(false) + } + + return cont(env) + }) +} diff --git a/x/logic/predicate/crypto_test.go b/x/logic/predicate/crypto_test.go index fb43e5f6..80d3f303 100644 --- a/x/logic/predicate/crypto_test.go +++ b/x/logic/predicate/crypto_test.go @@ -103,14 +103,14 @@ H == [2252,222,43,46,219,165,107,244,8,96,31,183,33,254,155,92,51,141,16,238,66, }, { query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae', -[345,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`, +[45,38,180,107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`, wantSuccess: false, }, { query: `hex_bytes('2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae', [345,38,'hey',107,104,255,198,143,249,155,69,60,29,48,65,52,19,66,45,112,100,131,191,160,249,138,94,136,98,102,231,174]).`, wantSuccess: false, - wantError: fmt.Errorf("hex_bytes/2: failed convert list into bytes: invalid term type in list engine.Atom, only integer allowed"), + wantError: fmt.Errorf("hex_bytes/2: failed convert list into bytes: invalid integer value in list at position 1: 345 is out of byte range (0-255)"), }, } for nc, tc := range cases { @@ -171,3 +171,207 @@ H == [2252,222,43,46,219,165,107,244,8,96,31,183,33,254,155,92,51,141,16,238,66, } }) } + +func TestXVerify(t *testing.T) { + Convey("Given a test cases", t, func() { + cases := []struct { + program string + query string + wantResult []types.TermResults + wantError error + wantSuccess bool + }{ + // ed25519 + { // All good + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5a', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + eddsa_verify(PubKey, Msg, Sig, [encoding(octet), type(ed25519)]).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: true, + }, + { // All good with hex encoding + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5a', PubKey), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + eddsa_verify(PubKey, '9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Sig, encoding(hex)).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: true, + }, + { // Wrong Msg + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5a', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9e', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + eddsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantSuccess: false, + }, + { // Wrong public key + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5b5b', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + eddsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantSuccess: false, + wantError: fmt.Errorf("eddsa_verify/4: failed to verify signature: ed25519: bad public key length: 33"), + }, + { // Wrong signature + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5a', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff', Sig), + eddsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantSuccess: false, + }, + { // Unsupported algo + program: `verify :- + hex_bytes('53167ac3fc4b720daa45b04fc73fe752578fa23a10048422d6904b7f4f7bba5a', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + eddsa_verify(PubKey, Msg, Sig, [encoding(octet), type(foo)]).`, + query: `verify.`, + wantSuccess: false, + wantError: fmt.Errorf("eddsa_verify/4: invalid type: foo. Possible values: ed25519"), + }, + // ECDSA - secp256r1 + { + // All good + program: `verify :- + hex_bytes('0213c8426be471e55506f7ce4f7df557a42e310df09f92eb732ca3085e797cef9b', PubKey), + hex_bytes('e50c26e89f734b2ee12041ff27874c901891f74a0f0cf470333312a3034ce3be', Msg), + hex_bytes('30450220099e6f9dd218e0e304efa7a4224b0058a8e3aec73367ec239bee4ed8ed7d85db022100b504d3d0d2e879b04705c0e5a2b40b0521a5ab647ea207bd81134e1a4eb79e47', Sig), + ecdsa_verify(PubKey, Msg, Sig, [encoding(octet), type(secp256r1)]).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: true, + }, + { // Invalid secp signature + program: `verify :- + hex_bytes('0213c8426be471e55506f7ce4f7df557', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + ecdsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantSuccess: false, + wantError: fmt.Errorf("ecdsa_verify/4: failed to verify signature: failed to parse compressed public key (first 10 bytes): 0213c8426be471e55506"), + }, + { // Unsupported algo + program: `verify :- + hex_bytes('0213c8426be471e55506f7ce4f7df557a42e310df09f92eb732ca3085e797cef9b', PubKey), + hex_bytes('9b038f8ef6918cbb56040dfda401b56bb1ce79c472e7736e8677758c83367a9d', Msg), + hex_bytes('889bcfd331e8e43b5ebf430301dffb6ac9e2fce69f6227b43552fe3dc8cc1ee00c1cc53452a8712e9d5f80086dff8cf4999c1b93ed6c6e403c09334cb61ddd0b', Sig), + ecdsa_verify(PubKey, Msg, Sig, [encoding(octet), type(foo)]).`, + query: `verify.`, + wantSuccess: false, + wantError: fmt.Errorf("ecdsa_verify/4: invalid type: foo. Possible values: secp256r1, secp256k1"), + }, + { + // Wrong msg + program: `verify :- + hex_bytes('0213c8426be471e55506f7ce4f7df557a42e310df09f92eb732ca3085e797cef9b', PubKey), + hex_bytes('e50c26e89f734b2ee12041ff27874c901891f74a0f0cf470333312a3034ce3bf', Msg), + hex_bytes('30450220099e6f9dd218e0e304efa7a4224b0058a8e3aec73367ec239bee4ed8ed7d85db022100b504d3d0d2e879b04705c0e5a2b40b0521a5ab647ea207bd81134e1a4eb79e47', Sig), + ecdsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: false, + }, + { + // Wrong signature + program: `verify :- + hex_bytes('0213c8426be471e55506f7ce4f7df557a42e310df09f92eb732ca3085e797cef9b', PubKey), + hex_bytes('e50c26e89f734b2ee12041ff27874c901891f74a0f0cf470333312a3034ce3be', Msg), + hex_bytes('30450220099e6f9dd218e0e304efa7a4224b0058a8e3aec73367ec239bee4ed8ed7d85db022100b504d3d0d2e879b04705c0e5a2b40b0521a5ab647ea207bd81134e1a4eb79e48', Sig), + ecdsa_verify(PubKey, Msg, Sig, encoding(octet)).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: false, + }, + // ECDSA - secp256k1 + { + // All good + program: `verify :- + hex_bytes('026b5450187ee9c63ba9e42cb6018d8469c903aca116178e223de76e49fe63b71c', PubKey), + hex_bytes('dece063885d3648078f903b6a3e8989f649dc3368cd9c8d69755ed9dcb6a0995', Msg), + hex_bytes('304402201448201bb4408549b0997f4b9ad9ed36f3cf8bb9c433fc7f3ba48c6b6e39476e022053f7d056f7ffeab9a79f3a36bc2ba969ddd530a3a1495d1ed7bba00039820223', Sig), + ecdsa_verify(PubKey, Msg, Sig, [encoding(octet), type(secp256k1)]).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: true, + }, + { + // Wrong signature + program: `verify :- + hex_bytes('026b5450187ee9c63ba9e42cb6018d8469c903aca116178e223de76e49fe63b71c', PubKey), + hex_bytes('dece063885d3648078f903b6a3e8989f649dc3368cd9c8d69755ed9dcb6a0996', Msg), + hex_bytes('304402201448201bb4408549b0997f4b9ad9ed36f3cf8bb9c433fc7f3ba48c6b6e39476e022053f7d056f7ffeab9a79f3a36bc2ba969ddd530a3a1495d1ed7bba00039820223', Sig), + ecdsa_verify(PubKey, Msg, Sig, [encoding(octet), type(secp256k1)]).`, + query: `verify.`, + wantResult: []types.TermResults{{}}, + wantSuccess: false, + }, + } + for nc, tc := range cases { + Convey(fmt.Sprintf("Given the query #%d: %s", nc, tc.query), func() { + Convey("and a context", func() { + db := tmdb.NewMemDB() + stateStore := store.NewCommitMultiStore(db) + ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) + + Convey("and a vm", func() { + interpreter := testutil.NewLightInterpreterMust(ctx) + interpreter.Register2(engine.NewAtom("hex_bytes"), HexBytes) + interpreter.Register4(engine.NewAtom("eddsa_verify"), EDDSAVerify) + interpreter.Register4(engine.NewAtom("ecdsa_verify"), ECDSAVerify) + + err := interpreter.Compile(ctx, tc.program) + So(err, ShouldBeNil) + + Convey("When the predicate is called", func() { + sols, err := interpreter.QueryContext(ctx, tc.query) + + Convey("Then the error should be nil", func() { + So(err, ShouldBeNil) + So(sols, ShouldNotBeNil) + + Convey("and the bindings should be as expected", func() { + var got []types.TermResults + for sols.Next() { + m := types.TermResults{} + err := sols.Scan(m) + So(err, ShouldBeNil) + + got = append(got, m) + } + if tc.wantError != nil { + So(sols.Err(), ShouldBeError, tc.wantError.Error()) + } else { + So(sols.Err(), ShouldBeNil) + + if tc.wantSuccess { + So(len(got), ShouldBeGreaterThan, 0) + So(len(got), ShouldEqual, len(tc.wantResult)) + for iGot, resultGot := range got { + for varGot, termGot := range resultGot { + So(testutil.ReindexUnknownVariables(termGot), ShouldEqual, tc.wantResult[iGot][varGot]) + } + } + } else { + So(len(got), ShouldEqual, 0) + } + } + }) + }) + }) + }) + }) + }) + } + }) +} diff --git a/x/logic/predicate/did.go b/x/logic/predicate/did.go index 933859c4..b00f77f2 100644 --- a/x/logic/predicate/did.go +++ b/x/logic/predicate/did.go @@ -14,6 +14,9 @@ import ( // AtomDID is a term which represents a DID as a compound term `did(Method, ID, Path, Query, Fragment)`. var AtomDID = engine.NewAtom("did") +// DIDPrefix is the prefix for a DID. +const DIDPrefix = "did:" + // DIDComponents is a predicate which breaks down a DID into its components according to the [W3C DID] specification. // // The signature is as follows: @@ -34,10 +37,12 @@ var AtomDID = engine.NewAtom("did") // - did_components('did:example:123456?versionId=1', did(Method, ID, Path, Query, Fragment)). // // # Reconstruct a DID from its components. -// - did_components(DID, did('example', '123456', null, 'versionId=1', _42)). +// - did_components(DID, did('example', '123456', _, 'versionId=1', _42)). // // [W3C DID]: https://w3c.github.io/did-core // [DID syntax]: https://w3c.github.io/did-core/#did-syntax +// +//nolint:funlen func DIDComponents(vm *engine.VM, did, components engine.Term, cont engine.Cont, env *engine.Env) *engine.Promise { switch t1 := env.Resolve(did).(type) { case engine.Variable: @@ -69,34 +74,60 @@ func DIDComponents(vm *engine.VM, did, components engine.Term, cont engine.Cont, } buf := strings.Builder{} - buf.WriteString("did:") - if segment, ok := util.Resolve(env, t2.Arg(0)); ok { - buf.WriteString(url.PathEscape(segment.String())) - } - if segment, ok := util.Resolve(env, t2.Arg(1)); ok { - buf.WriteString(":") - buf.WriteString(url.PathEscape(segment.String())) + buf.WriteString(DIDPrefix) + + processors := []func(engine.Atom){ + func(segment engine.Atom) { + buf.WriteString(segment.String()) + }, + func(segment engine.Atom) { + buf.WriteString(":") + buf.WriteString(url.PathEscape(segment.String())) + }, + func(segment engine.Atom) { + for _, s := range strings.FieldsFunc(segment.String(), func(c rune) bool { return c == '/' }) { + buf.WriteString("/") + buf.WriteString(url.PathEscape(s)) + } + }, + func(segment engine.Atom) { + buf.WriteString("?") + buf.WriteString(url.PathEscape(segment.String())) + }, + func(segment engine.Atom) { + buf.WriteString("#") + buf.WriteString(url.PathEscape(segment.String())) + }, } - if segment, ok := util.Resolve(env, t2.Arg(2)); ok { - for _, s := range strings.FieldsFunc(segment.String(), func(c rune) bool { return c == '/' }) { - buf.WriteString("/") - buf.WriteString(url.PathEscape(s)) + + for i := 0; i < t2.Arity(); i++ { + if err := processSegment(t2, uint8(i), processors[i], env); err != nil { + return engine.Error(fmt.Errorf("did_components/2: %w", err)) } } - if segment, ok := util.Resolve(env, t2.Arg(3)); ok { - buf.WriteString("?") - buf.WriteString(url.PathEscape(segment.String())) - } - if segment, ok := util.Resolve(env, t2.Arg(4)); ok { - buf.WriteString("#") - buf.WriteString(url.PathEscape(segment.String())) - } + return engine.Unify(vm, did, engine.NewAtom(buf.String()), cont, env) default: return engine.Error(fmt.Errorf("did_components/2: cannot unify did with %T", t2)) } } +// processSegment processes a segment of a DID. +func processSegment(segments engine.Compound, segmentNumber uint8, fn func(segment engine.Atom), env *engine.Env) error { + term := env.Resolve(segments.Arg(int(segmentNumber))) + if _, ok := term.(engine.Variable); ok { + return nil + } + segment, err := util.ResolveToAtom(env, segments.Arg(int(segmentNumber))) + if err != nil { + return fmt.Errorf("failed to resolve atom at segment %d: %w", segmentNumber, err) + } + + fn(segment) + + return nil +} + // didToTerms converts a DID to a "tuple" of terms (either an Atom or a Variable), // or returns an error if the conversion fails. // The returned atoms are url decoded. diff --git a/x/logic/predicate/did_test.go b/x/logic/predicate/did_test.go index cd26ca18..3c46aade 100644 --- a/x/logic/predicate/did_test.go +++ b/x/logic/predicate/did_test.go @@ -90,6 +90,12 @@ func TestDID(t *testing.T) { wantResult: []types.TermResults{}, wantError: fmt.Errorf("did_components/2: invalid arity 1. Expected 5"), }, + { + query: `did_components(X,did(example,'123456','path with/space',5,test)).`, + wantResult: []types.TermResults{}, + wantError: fmt.Errorf( + "did_components/2: failed to resolve atom at segment 3: invalid term '%%!s(engine.Integer=5)' - expected engine.Atom but got engine.Integer"), //nolint:lll + }, { query: `did_components('did:example:123456',foo(X)).`, wantResult: []types.TermResults{}, diff --git a/x/logic/predicate/util.go b/x/logic/predicate/util.go index 4718f3ee..d932e9b5 100644 --- a/x/logic/predicate/util.go +++ b/x/logic/predicate/util.go @@ -1,6 +1,7 @@ package predicate import ( + "encoding/hex" "fmt" "sort" @@ -9,6 +10,18 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/okp4/okp4d/x/logic/types" + "github.com/okp4/okp4d/x/logic/util" +) + +var ( + // AtomEncoding is the term used to indicate the encoding type option. + AtomEncoding = engine.NewAtom("encoding") + + // AtomHex is the term used to indicate the hexadecimal encoding type option. + AtomHex = engine.NewAtom("hex") + + // AtomOctet is the term used to indicate the byte encoding type option. + AtomOctet = engine.NewAtom("octet") ) // SortBalances by coin denomination. @@ -64,17 +77,67 @@ func BytesToList(bt []byte) engine.Term { return engine.List(terms...) } +// TermToBytes try to convert a term to native golang []byte. +// By default, if no encoding options is given the term is considered as hexadecimal value. +// +// Options are: +// - encoding(+Format). +// where Format is the encoding format to use. Possible values are: +// -- `hex` (default): hexadecimal encoding represented as an atom. +// -- `octet`: plain bytes encoding represented as a list of integers between 0 and 255. +func TermToBytes(term, options engine.Term, env *engine.Env) ([]byte, error) { + encoding, err := util.GetOptionWithDefault(AtomEncoding, options, AtomHex, env) + if err != nil { + return nil, err + } + + switch enc := env.Resolve(encoding).(type) { + case engine.Atom: + switch enc { + case AtomOctet: + v := env.Resolve(term) + if c, ok := v.(engine.Compound); ok && util.IsList(c) { + iter := engine.ListIterator{List: v, Env: env} + return ListToBytes(iter, env) + } + return nil, fmt.Errorf("term should be a List, given %T", term) + case AtomHex: + v := env.Resolve(term) + if atom, ok := v.(engine.Atom); ok { + src := []byte(atom.String()) + result := make([]byte, hex.DecodedLen(len(src))) + _, err := hex.Decode(result, src) + return result, err + } + return nil, fmt.Errorf("invalid term type: %T, should be an atom", term) + default: + return nil, fmt.Errorf("invalid encoding option: %s, valid values are '%s' or '%s'", enc, AtomHex, AtomOctet) + } + default: + return nil, fmt.Errorf("invalid term '%s' - expected engine.Atom but got %T", encoding, encoding) + } +} + func ListToBytes(terms engine.ListIterator, env *engine.Env) ([]byte, error) { bt := make([]byte, 0) + index := 0 + for terms.Next() { term := env.Resolve(terms.Current()) + index++ + switch t := term.(type) { case engine.Integer: - bt = append(bt, byte(t)) + if t >= 0 && t <= 255 { + bt = append(bt, byte(t)) + } else { + return nil, fmt.Errorf("invalid integer value in list at position %d: %d is out of byte range (0-255)", index, t) + } default: - return nil, fmt.Errorf("invalid term type in list %T, only integer allowed", term) + return nil, fmt.Errorf("invalid term type in list at position %d: %T, only engine.Integer allowed", index, term) } } + return bt, nil } diff --git a/x/logic/predicate/util_test.go b/x/logic/predicate/util_test.go index b975fc3a..8f521f76 100644 --- a/x/logic/predicate/util_test.go +++ b/x/logic/predicate/util_test.go @@ -1,3 +1,4 @@ +//nolint:lll package predicate import ( @@ -79,3 +80,109 @@ func TestExtractJsonTerm(t *testing.T) { } }) } + +func TestTermToBytes(t *testing.T) { + Convey("Given a test cases", t, func() { + cases := []struct { + term engine.Term + options engine.Term + result []byte + wantSuccess bool + wantError error + }{ + { // If no option, by default, given term is in hexadecimal format. + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: nil, + result: []byte{72, 101, 121, 32, 33, 32, 89, 111, 117, 32, 119, 97, 110, 116, 32, 116, 111, 32, 115, 101, 101, 32, 116, 104, 105, 115, 32, 116, 101, 120, 116, 44, 32, 119, 111, 110, 100, 101, 114, 102, 117, 108, 33}, + wantSuccess: true, + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("hex")), + result: []byte{72, 101, 121, 32, 33, 32, 89, 111, 117, 32, 119, 97, 110, 116, 32, 116, 111, 32, 115, 101, 101, 32, 116, 104, 105, 115, 32, 116, 101, 120, 116, 44, 32, 119, 111, 110, 100, 101, 114, 102, 117, 108, 33}, + wantSuccess: true, + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("octet")), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("term should be a List, given engine.Atom"), + }, + { + term: engine.List(engine.Integer(72), engine.Integer(101), engine.Integer(121), engine.Integer(32), engine.Integer(33), engine.Integer(32), engine.Integer(89), engine.Integer(111), engine.Integer(117), engine.Integer(32), engine.Integer(119), engine.Integer(97), engine.Integer(110), engine.Integer(116), engine.Integer(32), engine.Integer(116), engine.Integer(111), engine.Integer(32), engine.Integer(115), engine.Integer(101), engine.Integer(101), engine.Integer(32), engine.Integer(116), engine.Integer(104), engine.Integer(105), engine.Integer(115), engine.Integer(32), engine.Integer(116), engine.Integer(101), engine.Integer(120), engine.Integer(116), engine.Integer(44), engine.Integer(32), engine.Integer(119), engine.Integer(111), engine.Integer(110), engine.Integer(100), engine.Integer(101), engine.Integer(114), engine.Integer(102), engine.Integer(117), engine.Integer(108), engine.Integer(33)), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("octet")), + result: []byte{72, 101, 121, 32, 33, 32, 89, 111, 117, 32, 119, 97, 110, 116, 32, 116, 111, 32, 115, 101, 101, 32, 116, 104, 105, 115, 32, 116, 101, 120, 116, 44, 32, 119, 111, 110, 100, 101, 114, 102, 117, 108, 33}, + wantSuccess: true, + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("foo")), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("invalid encoding option: foo, valid values are 'hex' or 'octet'"), + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("foo"), engine.NewAtom("bar")), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("invalid arity for compound 'encoding': 2 but expected 1"), + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("encoding").Apply(engine.Integer(10)), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("invalid term '%%!s(engine.Integer=10)' - expected engine.Atom but got engine.Integer"), + }, + { + term: engine.NewAtom("foo").Apply(engine.NewAtom("bar")), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("octet")), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("term should be a List, given *engine.compound"), + }, + { + term: engine.NewAtom("foo").Apply(engine.NewAtom("bar")), + options: engine.NewAtom("encoding").Apply(engine.NewAtom("hex")), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("invalid term type: *engine.compound, should be an atom"), + }, + { + term: engine.NewAtom("486579202120596f752077616e7420746f20736565207468697320746578742c20776f6e64657266756c21"), + options: engine.NewAtom("foo"), + result: nil, + wantSuccess: false, + wantError: fmt.Errorf("invalid term 'foo' - expected engine.Compound but got engine.Atom"), + }, + } + for nc, tc := range cases { + Convey(fmt.Sprintf("Given the term #%d: %s", nc, tc.term), func() { + Convey("when check try convert", func() { + env := engine.Env{} + result, err := TermToBytes(tc.term, tc.options, &env) + + if tc.wantSuccess { + Convey("then no error should be thrown", func() { + So(err, ShouldBeNil) + + Convey("and result should be as expected", func() { + So(result, ShouldResemble, tc.result) + }) + }) + } else { + Convey("then error should occurs", func() { + So(err, ShouldNotBeNil) + + Convey("and should be as expected", func() { + So(err, ShouldResemble, tc.wantError) + }) + }) + } + }) + }) + } + }) +} diff --git a/x/logic/util/crypto.go b/x/logic/util/crypto.go new file mode 100644 index 00000000..f32b0ffa --- /dev/null +++ b/x/logic/util/crypto.go @@ -0,0 +1,62 @@ +package util + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "fmt" + + "github.com/dustinxie/ecc" +) + +// Alg is the type of algorithm supported by the crypto util functions. +type Alg string + +// String returns the string representation of the algorithm. +func (a Alg) String() string { + return string(a) +} + +const ( + Secp256k1 Alg = "secp256k1" + Secp256r1 Alg = "secp256r1" + Ed25519 Alg = "ed25519" +) + +// VerifySignature verifies the signature of the given message with the given public key using the given algorithm. +func VerifySignature(alg Alg, pubKey []byte, msg, sig []byte) (_ bool, err error) { + defer func() { + if recoveredErr := recover(); recoveredErr != nil { + err = fmt.Errorf("%s", recoveredErr) + } + }() + + switch alg { + case Ed25519: + return ed25519.Verify(pubKey, msg, sig), nil + case Secp256r1: + return verifySignatureWithCurve(elliptic.P256(), pubKey, msg, sig) + case Secp256k1: + return verifySignatureWithCurve(ecc.P256k1(), pubKey, msg, sig) + default: + return false, fmt.Errorf("algo %s not supported", alg) + } +} + +// verifySignatureWithCurve verifies the ASN1 signature of the given message with the given +// public key (in compressed form specified in section 4.3.6 of ANSI X9.62.) using the given +// elliptic curve. +func verifySignatureWithCurve(curve elliptic.Curve, pubKey, msg, sig []byte) (bool, error) { + x, y := ecc.UnmarshalCompressed(curve, pubKey) + if x == nil || y == nil { + return false, fmt.Errorf("failed to parse compressed public key (first 10 bytes): %x", pubKey[:10]) + } + + pk := &ecdsa.PublicKey{ + Curve: curve, + X: x, + Y: y, + } + + return ecc.VerifyASN1(pk, msg, sig), nil +} diff --git a/x/logic/util/prolog.go b/x/logic/util/prolog.go index dada7a76..e55ebe38 100644 --- a/x/logic/util/prolog.go +++ b/x/logic/util/prolog.go @@ -1,23 +1,34 @@ package util import ( + "fmt" "strings" "github.com/ichiban/prolog/engine" ) +var ( + // AtomDot is the term used to represent the dot in a list. + AtomDot = engine.NewAtom(".") + + // AtomEmpty is the term used to represent empty. + AtomEmpty = engine.NewAtom("") +) + // StringToTerm converts a string to a term. func StringToTerm(s string) engine.Term { return engine.NewAtom(s) } -// Resolve resolves a term and returns the resolved term and a boolean indicating whether the term is instantiated. -func Resolve(env *engine.Env, t engine.Term) (engine.Atom, bool) { +// ResolveToAtom resolves a term and attempts to convert it into an engine.Atom if possible. +// If conversion fails, the function returns the empty atom and the error. +func ResolveToAtom(env *engine.Env, t engine.Term) (engine.Atom, error) { switch t := env.Resolve(t).(type) { case engine.Atom: - return t, true + return t, nil default: - return engine.NewAtom(""), false + return AtomEmpty, + fmt.Errorf("invalid term '%s' - expected engine.Atom but got %T", t, t) } } @@ -39,3 +50,67 @@ func PredicateMatches(this string) func(string) bool { return strings.Split(this, "/")[0] == that } } + +// IsList returns true if the given compound is a list. +func IsList(compound engine.Compound) bool { + return compound.Functor() == AtomDot && compound.Arity() == 2 +} + +// GetOption returns the value of the first option with the given name in the given options. +// An option is a compound with the given name as functor and one argument which is +// a term, for instance `opt(v)`. +// The options are either a list of options or an option. +// If no option is found nil is returned. +func GetOption(name engine.Atom, options engine.Term, env *engine.Env) (engine.Term, error) { + extractOption := func(term engine.Term) (engine.Term, error) { + switch v := term.(type) { + case engine.Compound: + if v.Functor() == name { + if v.Arity() != 1 { + return nil, fmt.Errorf("invalid arity for compound '%s': %d but expected 1", name, v.Arity()) + } + + return v.Arg(0), nil + } + return nil, nil + case nil: + return nil, nil + default: + return nil, fmt.Errorf("invalid term '%s' - expected engine.Compound but got %T", term, v) + } + } + + resolvedTerm := env.Resolve(options) + + compound, ok := resolvedTerm.(engine.Compound) + if ok && IsList(compound) { + iter := engine.ListIterator{List: compound, Env: env} + + for iter.Next() { + opt := env.Resolve(iter.Current()) + + term, err := extractOption(opt) + if err != nil { + return nil, err + } + + if term != nil { + return term, nil + } + } + return nil, nil + } + + return extractOption(resolvedTerm) +} + +// GetOptionWithDefault returns the value of the first option with the given name in the given options, or the given +// default value if no option is found. +func GetOptionWithDefault(name engine.Atom, options engine.Term, defaultValue engine.Term, env *engine.Env) (engine.Term, error) { + if term, err := GetOption(name, options, env); err != nil { + return nil, err + } else if term != nil { + return term, nil + } + return defaultValue, nil +} diff --git a/x/logic/util/prolog_test.go b/x/logic/util/prolog_test.go new file mode 100644 index 00000000..c9a21375 --- /dev/null +++ b/x/logic/util/prolog_test.go @@ -0,0 +1,169 @@ +package util + +import ( + "fmt" + "testing" + + "github.com/ichiban/prolog/engine" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestGetOption(t *testing.T) { + Convey("Given a test cases", t, func() { + cases := []struct { + option engine.Atom + options engine.Term + wantResult engine.Term + wantError error + }{ + { + option: engine.NewAtom("foo"), + options: nil, + wantResult: nil, + wantError: nil, + }, + { + option: engine.NewAtom("foo"), + options: engine.NewAtom("foo").Apply(engine.NewAtom("bar")), + wantResult: engine.NewAtom("bar"), + wantError: nil, + }, + { + option: engine.NewAtom("bar"), + options: engine.NewAtom("foo").Apply(engine.NewAtom("bar")), + wantResult: nil, + wantError: nil, + }, + { + option: engine.NewAtom("foo"), + options: engine.List(engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.NewAtom("bar"), + wantError: nil, + }, + { + option: engine.NewAtom("bar"), + options: engine.List(engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: nil, + wantError: nil, + }, + { + option: engine.NewAtom("foo"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.NewAtom("hoo")), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.NewAtom("bar"), + wantError: nil, + }, + { + option: engine.NewAtom("foo"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("foo").Apply(engine.NewAtom("bar1")), + engine.NewAtom("hey").Apply(engine.NewAtom("hoo")), + engine.NewAtom("foo").Apply(engine.NewAtom("bar1"))), + wantResult: engine.NewAtom("bar1"), + wantError: nil, + }, + { + option: engine.NewAtom("hey"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.NewAtom("hoo")), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.NewAtom("hoo"), + wantError: nil, + }, + { + option: engine.NewAtom("hey"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.NewAtom("jo").Apply(engine.NewAtom("bi"))), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + wantError: nil, + }, + { + option: engine.NewAtom("hey"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.NewAtom("jo").Apply(engine.NewAtom("bi"))), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + wantError: nil, + }, + { + option: engine.NewAtom("hey"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.List(engine.NewAtom("bi"), engine.NewAtom("bar"))), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: engine.List(engine.NewAtom("bi"), engine.NewAtom("bar")), + wantError: nil, + }, + { + option: engine.NewAtom("hey"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.List(engine.NewAtom("hey").Apply(engine.NewAtom("joe"))), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: nil, + wantError: nil, + }, + { + option: engine.NewAtom("foo"), + options: engine.NewAtom("foo"), + wantResult: nil, + wantError: fmt.Errorf("invalid term 'foo' - expected engine.Compound but got engine.Atom"), + }, + { + option: engine.NewAtom("foo"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey"), + engine.NewAtom("foo").Apply(engine.NewAtom("bar"))), + wantResult: nil, + wantError: fmt.Errorf("invalid term 'hey' - expected engine.Compound but got engine.Atom"), + }, + { + option: engine.NewAtom("foo"), + options: engine.List( + engine.NewAtom("jo").Apply(engine.NewAtom("bi")), + engine.NewAtom("hey").Apply(engine.NewAtom("hoo")), + engine.NewAtom("foo").Apply(engine.NewAtom("bar1"), engine.NewAtom("bar2"))), + wantResult: nil, + wantError: fmt.Errorf("invalid arity for compound 'foo': 2 but expected 1"), + }, + } + for nc, tc := range cases { + Convey(fmt.Sprintf("Given the term option #%d: %s", nc, tc.option), func() { + Convey("when getting option", func() { + env := engine.Env{} + result, err := GetOption(tc.option, tc.options, &env) + + if tc.wantError == nil { + Convey("then no error should be thrown", func() { + So(err, ShouldBeNil) + + Convey("and result should be as expected", func() { + So(result, ShouldEqual, tc.wantResult) + }) + }) + } else { + Convey("then atom returned should be the empty one", func() { + So(result, ShouldEqual, tc.wantResult) + }) + Convey("then error should occurs", func() { + So(err, ShouldNotBeNil) + + Convey("and should be as expected", func() { + So(err, ShouldBeError, tc.wantError) + }) + }) + } + }) + }) + } + }) +} diff --git a/x/logic/util/slice.go b/x/logic/util/slice.go new file mode 100644 index 00000000..b6144de1 --- /dev/null +++ b/x/logic/util/slice.go @@ -0,0 +1,10 @@ +package util + +// Map applies the given function to each element of the given slice and returns a new slice with the results. +func Map[T, M any](s []T, f func(T) M) []M { + m := make([]M, len(s)) + for i, v := range s { + m[i] = f(v) + } + return m +}