diff --git a/webauthn/const.go b/webauthn/const.go index 9ab74f2c..7f08c3d0 100644 --- a/webauthn/const.go +++ b/webauthn/const.go @@ -5,7 +5,6 @@ import ( ) const ( - errFmtFieldEmpty = "the field '%s' must be configured but it is empty" errFmtFieldNotValidURI = "field '%s' is not a valid URI: %w" errFmtConfigValidate = "error occurred validating the configuration: %w" ) diff --git a/webauthn/login.go b/webauthn/login.go index 69ecb979..dd6bd7ab 100644 --- a/webauthn/login.go +++ b/webauthn/login.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net/http" + "net/url" "time" "github.com/go-webauthn/webauthn/protocol" @@ -70,6 +71,12 @@ func (webauthn *WebAuthn) beginLogin(userID []byte, allowedCredentials []protoco opt(&assertion.Response) } + if len(assertion.Response.RelyingPartyID) == 0 { + return nil, nil, fmt.Errorf("error generating assertion: the relying party id must be provided via the configuration or a functional option for a login") + } else if _, err = url.Parse(assertion.Response.RelyingPartyID); err != nil { + return nil, nil, fmt.Errorf("error generating assertion: the relying party id failed to validate as it's not a valid uri with error: %w", err) + } + if assertion.Response.Timeout == 0 { switch { case assertion.Response.UserVerification == protocol.VerificationDiscouraged: @@ -147,6 +154,13 @@ func WithAppIdExtension(appid string) LoginOption { } } +// WithLoginRelyingPartyID sets the Relying Party ID for this particular login. +func WithLoginRelyingPartyID(id string) LoginOption { + return func(cco *protocol.PublicKeyCredentialRequestOptions) { + cco.RelyingPartyID = id + } +} + // FinishLogin takes the response from the client and validate it against the user credentials and stored session data. func (webauthn *WebAuthn) FinishLogin(user User, session SessionData, response *http.Request) (*Credential, error) { parsedResponse, err := protocol.ParseCredentialRequestResponse(response) diff --git a/webauthn/login_test.go b/webauthn/login_test.go index 29fdbc52..a869cf9e 100644 --- a/webauthn/login_test.go +++ b/webauthn/login_test.go @@ -3,6 +3,9 @@ package webauthn import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/go-webauthn/webauthn/protocol" ) @@ -26,3 +29,82 @@ func TestLogin_FinishLoginFailure(t *testing.T) { t.Errorf("FinishLogin() credential = %v, want nil", credential) } } + +func TestWithLoginRelyingPartyID(t *testing.T) { + testCases := []struct { + name string + have *Config + opts []LoginOption + expectedID string + err string + }{ + { + name: "OptionDefinedInConfig", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: nil, + expectedID: "https://example.com", + }, + { + name: "OptionDefinedInConfigAndOpts", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: []LoginOption{WithLoginRelyingPartyID("https://a.example.com")}, + expectedID: "https://a.example.com", + }, + { + name: "OptionDefinedInConfigWithNoErrAndInOptsWithError", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: []LoginOption{WithLoginRelyingPartyID("---::~!!~@#M!@OIK#N!@IOK@@@@@@@@@@")}, + err: "error generating assertion: the relying party id failed to validate as it's not a valid uri with error: parse \"---::~!!~@\": first path segment in URL cannot contain colon", + }, + { + name: "OptionDefinedInOpts", + have: &Config{ + RPOrigins: []string{"https://example.com"}, + }, + opts: []LoginOption{WithLoginRelyingPartyID("https://example.com")}, + expectedID: "https://example.com", + }, + { + name: "OptionIDNotDefined", + have: &Config{ + RPOrigins: []string{"https://example.com"}, + }, + opts: nil, + err: "error generating assertion: the relying party id must be provided via the configuration or a functional option for a login", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w, err := New(tc.have) + assert.NoError(t, err) + + user := &defaultUser{ + credentials: []Credential{ + {}, + }, + } + + creation, _, err := w.BeginLogin(user, tc.opts...) + if tc.err != "" { + assert.EqualError(t, err, tc.err) + } else { + assert.NoError(t, err) + require.NotNil(t, creation) + assert.Equal(t, tc.expectedID, creation.Response.RelyingPartyID) + } + }) + } +} diff --git a/webauthn/registration.go b/webauthn/registration.go index 92566bd0..a0d6e3a6 100644 --- a/webauthn/registration.go +++ b/webauthn/registration.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net/http" + "net/url" "time" "github.com/go-webauthn/webauthn/protocol" @@ -69,6 +70,16 @@ func (webauthn *WebAuthn) BeginRegistration(user User, opts ...RegistrationOptio opt(&creation.Response) } + if len(creation.Response.RelyingParty.ID) == 0 { + return nil, nil, fmt.Errorf("error generating credential creation: the relying party id must be provided via the configuration or a functional option for a creation") + } else if _, err = url.Parse(creation.Response.RelyingParty.ID); err != nil { + return nil, nil, fmt.Errorf("error generating credential creation: the relying party id failed to validate as it's not a valid uri with error: %w", err) + } + + if len(creation.Response.RelyingParty.Name) == 0 { + return nil, nil, fmt.Errorf("error generating credential creation: the relying party display name must be provided via the configuration or a functional option for a creation") + } + if creation.Response.Timeout == 0 { switch { case creation.Response.AuthenticatorSelection.UserVerification == protocol.VerificationDiscouraged: @@ -176,6 +187,20 @@ func WithAppIdExcludeExtension(appid string) RegistrationOption { } } +// WithRegistrationRelyingPartyID sets the relying party id for the registration. +func WithRegistrationRelyingPartyID(id string) RegistrationOption { + return func(cco *protocol.PublicKeyCredentialCreationOptions) { + cco.RelyingParty.ID = id + } +} + +// WithRegistrationRelyingPartyName sets the relying party name for the registration. +func WithRegistrationRelyingPartyName(name string) RegistrationOption { + return func(cco *protocol.PublicKeyCredentialCreationOptions) { + cco.RelyingParty.Name = name + } +} + // FinishRegistration takes the response from the authenticator and client and verify the credential against the user's // credentials and session data. func (webauthn *WebAuthn) FinishRegistration(user User, session SessionData, response *http.Request) (*Credential, error) { diff --git a/webauthn/registration_test.go b/webauthn/registration_test.go index a2c597de..3aa7b961 100644 --- a/webauthn/registration_test.go +++ b/webauthn/registration_test.go @@ -4,10 +4,100 @@ import ( "encoding/json" "testing" - "github.com/go-webauthn/webauthn/protocol" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-webauthn/webauthn/protocol" ) +func TestWithRegistrationRelyingPartyID(t *testing.T) { + testCases := []struct { + name string + have *Config + opts []RegistrationOption + expectedID string + expectedName string + err string + }{ + { + name: "OptionDefinedInConfig", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: nil, + expectedID: "https://example.com", + expectedName: "Test Display Name", + }, + { + name: "OptionDefinedInConfigAndOpts", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: []RegistrationOption{WithRegistrationRelyingPartyID("https://a.example.com"), WithRegistrationRelyingPartyName("Test Display Name2")}, + expectedID: "https://a.example.com", + expectedName: "Test Display Name2", + }, + { + name: "OptionDefinedInConfigWithNoErrAndInOptsWithError", + have: &Config{ + RPID: "https://example.com", + RPDisplayName: "Test Display Name", + RPOrigins: []string{"https://example.com"}, + }, + opts: []RegistrationOption{WithRegistrationRelyingPartyID("---::~!!~@#M!@OIK#N!@IOK@@@@@@@@@@"), WithRegistrationRelyingPartyName("Test Display Name2")}, + err: "error generating credential creation: the relying party id failed to validate as it's not a valid uri with error: parse \"---::~!!~@\": first path segment in URL cannot contain colon", + }, + { + name: "OptionDefinedInOpts", + have: &Config{ + RPOrigins: []string{"https://example.com"}, + }, + opts: []RegistrationOption{WithRegistrationRelyingPartyID("https://example.com"), WithRegistrationRelyingPartyName("Test Display Name")}, + expectedID: "https://example.com", + expectedName: "Test Display Name", + }, + { + name: "OptionDisplayNameNotDefined", + have: &Config{ + RPOrigins: []string{"https://example.com"}, + }, + opts: []RegistrationOption{WithRegistrationRelyingPartyID("https://example.com")}, + err: "error generating credential creation: the relying party display name must be provided via the configuration or a functional option for a creation", + }, + { + name: "OptionIDNotDefined", + have: &Config{ + RPOrigins: []string{"https://example.com"}, + }, + opts: []RegistrationOption{WithRegistrationRelyingPartyName("Test Display Name")}, + err: "error generating credential creation: the relying party id must be provided via the configuration or a functional option for a creation", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + w, err := New(tc.have) + assert.NoError(t, err) + + user := &defaultUser{} + + creation, _, err := w.BeginRegistration(user, tc.opts...) + if tc.err != "" { + assert.EqualError(t, err, tc.err) + } else { + assert.NoError(t, err) + require.NotNil(t, creation) + assert.Equal(t, tc.expectedID, creation.Response.RelyingParty.ID) + assert.Equal(t, tc.expectedName, creation.Response.RelyingParty.Name) + } + }) + } +} + func TestRegistration_FinishRegistrationFailure(t *testing.T) { user := &defaultUser{ id: []byte("123"), diff --git a/webauthn/types.go b/webauthn/types.go index 161c45b4..a0dcc678 100644 --- a/webauthn/types.go +++ b/webauthn/types.go @@ -91,18 +91,12 @@ func (config *Config) validate() error { return nil } - if len(config.RPDisplayName) == 0 { - return fmt.Errorf(errFmtFieldEmpty, "RPDisplayName") - } - - if len(config.RPID) == 0 { - return fmt.Errorf(errFmtFieldEmpty, "RPID") - } - var err error - if _, err = url.Parse(config.RPID); err != nil { - return fmt.Errorf(errFmtFieldNotValidURI, "RPID", err) + if len(config.RPID) != 0 { + if _, err = url.Parse(config.RPID); err != nil { + return fmt.Errorf(errFmtFieldNotValidURI, "RPID", err) + } } defaultTimeoutConfig := defaultTimeout diff --git a/webauthn/user.go b/webauthn/types_test.go similarity index 68% rename from webauthn/user.go rename to webauthn/types_test.go index 045ed8f6..b154bd93 100644 --- a/webauthn/user.go +++ b/webauthn/types_test.go @@ -1,8 +1,8 @@ package webauthn -// TODO: move this to a _test.go file. type defaultUser struct { - id []byte + id []byte + credentials []Credential } var _ User = (*defaultUser)(nil) @@ -19,10 +19,6 @@ func (user *defaultUser) WebAuthnDisplayName() string { return "New User" } -func (user *defaultUser) WebAuthnIcon() string { - return "https://pics.com/avatar.png" -} - func (user *defaultUser) WebAuthnCredentials() []Credential { - return []Credential{} + return user.credentials }