diff --git a/internal/serve/httphandler/login_handler_test.go b/internal/serve/httphandler/login_handler_test.go index 3553ab35..9e9c6fb6 100644 --- a/internal/serve/httphandler/login_handler_test.go +++ b/internal/serve/httphandler/login_handler_test.go @@ -1,6 +1,8 @@ package httphandler import ( + "bytes" + "encoding/json" "errors" "io" "net/http" @@ -100,6 +102,12 @@ func Test_LoginHandler_validateRequest(t *testing.T) { } } +func requestToJSON(t *testing.T, req interface{}) io.Reader { + body, err := json.Marshal(req) + require.NoError(t, err) + return bytes.NewReader(body) +} + // TODO: tests with reCaptcha enabled and disabled func Test_LoginHandler_ServeHTTP(t *testing.T) { r := chi.NewRouter() diff --git a/internal/serve/httphandler/mfa_handler.go b/internal/serve/httphandler/mfa_handler.go index 77149f8b..be7fb374 100644 --- a/internal/serve/httphandler/mfa_handler.go +++ b/internal/serve/httphandler/mfa_handler.go @@ -33,51 +33,61 @@ type MFAHandler struct { const DeviceIDHeader = "Device-ID" +func (h MFAHandler) validateRequest(req MFARequest, deviceID string) *httperror.HTTPError { + lv := validators.NewValidator() + + lv.Check(req.MFACode != "", "mfa_code", "MFA Code is required") + lv.Check(h.ReCAPTCHADisabled || req.ReCAPTCHAToken != "", "recaptcha_token", "reCAPTCHA token is required") + + lv.Check(deviceID != "", DeviceIDHeader, DeviceIDHeader+" header is required") + + if lv.HasErrors() { + return httperror.BadRequest("", nil, lv.Errors) + } + + return nil +} + func (h MFAHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ctx := req.Context() + // Step 1: Decode and validate the incoming request var reqBody MFARequest if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { log.Ctx(ctx).Errorf("decoding the request body: %s", err.Error()) httperror.BadRequest("", err, nil).Render(rw) return } + deviceID := req.Header.Get(DeviceIDHeader) + if httpErr := h.validateRequest(reqBody, deviceID); httpErr != nil { + httpErr.Render(rw) + return + } - // validating reCAPTCHA Token + // Step 2: Run the reCAPTCHA validation if it is enabled if !h.ReCAPTCHADisabled { - isValid, recaptchaErr := h.ReCAPTCHAValidator.IsTokenValid(ctx, reqBody.ReCAPTCHAToken) - if recaptchaErr != nil { - httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", recaptchaErr, nil).Render(rw) + isValid, err := h.ReCAPTCHAValidator.IsTokenValid(ctx, reqBody.ReCAPTCHAToken) + if err != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", err, nil).Render(rw) return } if !isValid { - log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with email") + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with device ID %s", deviceID) httperror.BadRequest("reCAPTCHA token invalid", nil, nil).Render(rw) return } } - if reqBody.MFACode == "" { - extras := map[string]interface{}{"mfa_code": "MFA Code is required"} - httperror.BadRequest("Request invalid", nil, extras).Render(rw) - return - } - - deviceID := req.Header.Get(DeviceIDHeader) - if deviceID == "" { - httperror.BadRequest("Device-ID header is required", nil, nil).Render(rw) - return - } - + // Step 3: Authenticate the user with the MFA code token, err := h.AuthManager.AuthenticateMFA(ctx, deviceID, reqBody.MFACode, reqBody.RememberMe) if err != nil { if errors.Is(err, auth.ErrMFACodeInvalid) { httperror.Unauthorized("", err, nil).Render(rw) - return + } else { + log.Ctx(ctx).Errorf("authenticating user: %s", err.Error()) + httperror.InternalError(ctx, "Cannot authenticate user", err, nil).Render(rw) } - log.Ctx(ctx).Errorf("authenticating user: %s", err.Error()) - httperror.InternalError(ctx, "Cannot authenticate user", err, nil).Render(rw) return } diff --git a/internal/serve/httphandler/mfa_handler_test.go b/internal/serve/httphandler/mfa_handler_test.go index dcf7dac0..d73414ed 100644 --- a/internal/serve/httphandler/mfa_handler_test.go +++ b/internal/serve/httphandler/mfa_handler_test.go @@ -1,32 +1,96 @@ package httphandler import ( - "bytes" - "encoding/json" "errors" - "io" "net/http" "net/http/httptest" "strings" "testing" - "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stellar/stellar-disbursement-platform-backend/db" "github.com/stellar/stellar-disbursement-platform-backend/db/dbtest" "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" ) const mfaEndpoint = "/mfa" +func Test_MFAHandler_validateRequest(t *testing.T) { + type Req struct { + body MFARequest + deviceID string + } + testCases := []struct { + name string + handler MFAHandler + req Req + expected *httperror.HTTPError + }{ + { + name: "🔴 invalid body and headers fields", + expected: httperror.BadRequest("", nil, map[string]interface{}{ + "mfa_code": "MFA Code is required", + "recaptcha_token": "reCAPTCHA token is required", + "Device-ID": "Device-ID header is required", + }), + }, + { + name: "🔴 invalid body fields with reCAPTCHA disabled", + handler: MFAHandler{ + ReCAPTCHADisabled: true, + }, + expected: httperror.BadRequest("", nil, map[string]interface{}{ + "mfa_code": "MFA Code is required", + "Device-ID": "Device-ID header is required", + }), + }, + { + name: "🟢 valid request with reCAPTCHA enabled", + req: Req{ + body: MFARequest{ + MFACode: "123456", + ReCAPTCHAToken: "XyZ", + }, + deviceID: "safari-xyz", + }, + expected: nil, + }, + { + name: "🟢 valid request with reCAPTCHA disabled", + req: Req{ + body: MFARequest{ + MFACode: "123456", + }, + deviceID: "safari-xyz", + }, + handler: MFAHandler{ + ReCAPTCHADisabled: true, + }, + expected: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.handler.validateRequest(tc.req.body, tc.req.deviceID) + if tc.expected == nil { + require.Nil(t, err) + } else { + require.Equal(t, tc.expected, err) + } + }) + } +} + func Test_MFAHandler_ServeHTTP(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() - dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) require.NoError(t, outerErr) defer dbConnectionPool.Close() @@ -34,274 +98,198 @@ func Test_MFAHandler_ServeHTTP(t *testing.T) { models, outerErr := data.NewModels(dbConnectionPool) require.NoError(t, outerErr) - authenticatorMock := &auth.AuthenticatorMock{} - jwtManagerMock := &auth.JWTManagerMock{} - roleManagerMock := &auth.RoleManagerMock{} - reCAPTCHAValidatorMock := &validators.ReCAPTCHAValidatorMock{} - mfaManagerMock := &auth.MFAManagerMock{} - authManager := auth.NewAuthManager( - auth.WithCustomAuthenticatorOption(authenticatorMock), - auth.WithCustomJWTManagerOption(jwtManagerMock), - auth.WithCustomRoleManagerOption(roleManagerMock), - auth.WithCustomMFAManagerOption(mfaManagerMock), - ) - - mfaHandler := MFAHandler{ - AuthManager: authManager, - ReCAPTCHAValidator: reCAPTCHAValidatorMock, - Models: models, - ReCAPTCHADisabled: false, - } - deviceID := "safari-xyz" - t.Run("Test handler with invalid body", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, nil) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusBadRequest, rw.Code) - }) - - t.Run("Test handler with unexpected reCAPTCHA error", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(false, errors.New("unexpected error")). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusInternalServerError, rw.Code) - require.Contains(t, rw.Body.String(), "Cannot validate reCAPTCHA token") - }) - - t.Run("Test handler with invalid reCAPTCHA token", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(false, nil). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusBadRequest, rw.Code) - require.Contains(t, rw.Body.String(), "reCAPTCHA token invalid") - }) - - t.Run("Test Device ID header is empty", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusBadRequest, rw.Code) - require.Contains(t, rw.Body.String(), "Device-ID header is required") - }) - - t.Run("Test MFA code is empty", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - body := MFARequest{ReCAPTCHAToken: "token"} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusBadRequest, rw.Code) - require.Contains(t, rw.Body.String(), "MFA Code is required") - }) - - t.Run("Test MFA code is invalid", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - mfaManagerMock. - On("ValidateMFACode", mock.Anything, deviceID, "123456"). - Return("", auth.ErrMFACodeInvalid). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - req.Header.Set(DeviceIDHeader, deviceID) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusUnauthorized, rw.Code) - require.Contains(t, rw.Body.String(), "Not authorized.") - }) - - t.Run("Test MFA validation failed", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - mfaManagerMock. - On("ValidateMFACode", mock.Anything, deviceID, "123456"). - Return("", errors.New("weird error happened")). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - req.Header.Set(DeviceIDHeader, deviceID) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusInternalServerError, rw.Code) - require.Contains(t, rw.Body.String(), "Cannot authenticate user") - }) - - t.Run("Test MFA remember me failed", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - mfaManagerMock. - On("ValidateMFACode", mock.Anything, deviceID, "123456"). - Return("userID", nil). - Once() - - mfaManagerMock. - On("RememberDevice", mock.Anything, deviceID, "123456"). - Return(errors.New("weird error happened")). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - req.Header.Set(DeviceIDHeader, deviceID) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusInternalServerError, rw.Code) - require.Contains(t, rw.Body.String(), "Cannot authenticate user") - }) - - t.Run("Test MFA get user failed", func(t *testing.T) { - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - mfaManagerMock. - On("ValidateMFACode", mock.Anything, deviceID, "123456"). - Return("userID", nil). - Once() - - mfaManagerMock. - On("RememberDevice", mock.Anything, deviceID, "123456"). - Return(nil). - Once() - - authenticatorMock. - On("GetUser", mock.Anything, "userID"). - Return(nil, errors.New("weird error happened")). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - req.Header.Set(DeviceIDHeader, deviceID) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusInternalServerError, rw.Code) - require.Contains(t, rw.Body.String(), "Cannot authenticate user") - }) - - t.Run("Test MFA validation successful", func(t *testing.T) { - buf := new(strings.Builder) - log.DefaultLogger.SetOutput(buf) - log.SetLevel(log.InfoLevel) - - reCAPTCHAValidatorMock. - On("IsTokenValid", mock.Anything, "token"). - Return(true, nil). - Once() - - mfaManagerMock. - On("ValidateMFACode", mock.Anything, deviceID, "123456"). - Return("userID", nil). - Once() - - mfaManagerMock. - On("RememberDevice", mock.Anything, deviceID, "123456"). - Return(nil). - Once() - - user := &auth.User{ - ID: "user-id", - Email: "email@email.com", - } - - authenticatorMock. - On("GetUser", mock.Anything, "userID"). - Return(user, nil). - Once() - - roleManagerMock. - On("GetUserRoles", mock.Anything, user). - Return([]string{"role1"}, nil). - Once() - - jwtManagerMock. - On("GenerateToken", mock.Anything, user, mock.AnythingOfType("time.Time")). - Return("token123", nil). - On("ValidateToken", mock.Anything, "token123"). - Return(true, nil). - On("GetUserFromToken", mock.Anything, "token123"). - Return(user, nil). - Once() - - body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} - - req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) - req.Header.Set(DeviceIDHeader, deviceID) - rw := httptest.NewRecorder() - - mfaHandler.ServeHTTP(rw, req) - - require.Equal(t, http.StatusOK, rw.Code) - require.JSONEq(t, `{"token": "token123"}`, rw.Body.String()) - - // validate logs - require.Contains(t, buf.String(), "[UserLogin] - Logged in user with account ID user-id") - }) - - authenticatorMock.AssertExpectations(t) - reCAPTCHAValidatorMock.AssertExpectations(t) -} + testCases := []struct { + name string + ReCAPTCHADisabled bool + prepareMocks func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) + reqBody string + deviceID string + wantStatusCode int + wantResponseBody string + }{ + { + name: "🔴[400] invalid body", + reqBody: "invalid json", + wantStatusCode: http.StatusBadRequest, + wantResponseBody: `{"error":"The request was invalid in some way."}`, + }, + { + name: "🔴[400] missing [mfa_code,recaptcha_token,Device-ID]", + reqBody: "{}", + deviceID: "", + wantStatusCode: http.StatusBadRequest, + wantResponseBody: `{ + "error":"The request was invalid in some way.", + "extras": { + "mfa_code": "MFA Code is required", + "recaptcha_token": "reCAPTCHA token is required", + "Device-ID": "Device-ID header is required" + } + }`, + }, + { + name: "🔴[400](ReCAPTCHADisabled=true) missing [mfa_code,Device-ID]", + ReCAPTCHADisabled: true, + reqBody: "{}", + deviceID: "", + wantStatusCode: http.StatusBadRequest, + wantResponseBody: `{ + "error": "The request was invalid in some way.", + "extras": { + "mfa_code": "MFA Code is required", + "Device-ID": "Device-ID header is required" + } + }`, + }, + { + name: "🔴[500] when reCAPTCHA validator throws an unexpected error", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(false, errors.New("unexpected error")). + Once() + }, + wantStatusCode: http.StatusInternalServerError, + wantResponseBody: `{"error": "Cannot validate reCAPTCHA token"}`, + }, + { + name: "🔴[400] when reCAPTCHA token is deemed invalid", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(false, nil). + Once() + }, + wantStatusCode: http.StatusBadRequest, + wantResponseBody: `{"error": "reCAPTCHA token invalid"}`, + }, + { + name: "🔴[401] when mfa_code is invalid", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + authManagerMock. + On("AuthenticateMFA", mock.Anything, deviceID, "123456", mock.AnythingOfType("bool")). + Return("", auth.ErrMFACodeInvalid). + Once() + }, + wantStatusCode: http.StatusUnauthorized, + wantResponseBody: `{"error": "Not authorized."}`, + }, + { + name: "🔴[500] when the MFA validation returns an unexpedted error", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + authManagerMock. + On("AuthenticateMFA", mock.Anything, deviceID, "123456", mock.AnythingOfType("bool")). + Return("", errors.New("unexpected error")). + Once() + }, + wantStatusCode: http.StatusInternalServerError, + wantResponseBody: `{"error": "Cannot authenticate user"}`, + }, + { + name: "🔴[500] when GetUserID returns an unexpedted error", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + authManagerMock. + On("AuthenticateMFA", mock.Anything, deviceID, "123456", mock.AnythingOfType("bool")). + Return("token", nil). + Once() + authManagerMock. + On("GetUserID", mock.Anything, "token"). + Return("", errors.New("unexpected error")). + Once() + }, + wantStatusCode: http.StatusInternalServerError, + wantResponseBody: `{"error": "Cannot get user ID"}`, + }, + { + name: "🟢[200](ReCAPTCHADisabled=false) successfully validate MFA", + reqBody: `{"mfa_code":"123456","recaptcha_token":"token"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + authManagerMock. + On("AuthenticateMFA", mock.Anything, deviceID, "123456", mock.AnythingOfType("bool")). + Return("token", nil). + Once() + authManagerMock. + On("GetUserID", mock.Anything, "token"). + Return("user_id", nil). + Once() + }, + wantStatusCode: http.StatusOK, + wantResponseBody: `{"token": "token"}`, + }, + { + name: "🟢[200](ReCAPTCHADisabled=true) successfully validate MFA", + ReCAPTCHADisabled: true, + reqBody: `{"mfa_code":"123456"}`, + deviceID: deviceID, + prepareMocks: func(t *testing.T, reCAPTCHAValidatorMock *validators.ReCAPTCHAValidatorMock, authManagerMock *auth.AuthManagerMock) { + authManagerMock. + On("AuthenticateMFA", mock.Anything, deviceID, "123456", mock.AnythingOfType("bool")). + Return("token", nil). + Once() + authManagerMock. + On("GetUserID", mock.Anything, "token"). + Return("user_id", nil). + Once() + }, + wantStatusCode: http.StatusOK, + wantResponseBody: `{"token": "token"}`, + }, + } -func requestToJSON(t *testing.T, req interface{}) io.Reader { - body, err := json.Marshal(req) - require.NoError(t, err) - return bytes.NewReader(body) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reCAPTCHAValidatorMock := &validators.ReCAPTCHAValidatorMock{} + authManager := auth.NewAuthManagerMock(t) + if tc.prepareMocks != nil { + tc.prepareMocks(t, reCAPTCHAValidatorMock, authManager) + } + + mfaHandler := MFAHandler{ + AuthManager: authManager, + ReCAPTCHAValidator: reCAPTCHAValidatorMock, + Models: models, + ReCAPTCHADisabled: tc.ReCAPTCHADisabled, + } + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, strings.NewReader(tc.reqBody)) + if tc.deviceID != "" { + req.Header.Set(DeviceIDHeader, tc.deviceID) + } + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + assert.Equal(t, tc.wantStatusCode, rw.Code) + assert.JSONEq(t, tc.wantResponseBody, rw.Body.String()) + }) + } }