Skip to content

Commit

Permalink
[SDP-1311] Make Dashboard User E-mails case insensitive (#485)
Browse files Browse the repository at this point in the history
* SDP-1311 Make E-mails case insensitive

* SDP-1311 Address PR feedback
  • Loading branch information
marwen-abid authored Dec 4, 2024
1 parent 036419b commit f965c80
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 57 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- +migrate Up
UPDATE auth_users
SET
email = LOWER(TRIM(email));

-- +migrate Down
-- No down migration needed as email sanitization cannot be reversed
2 changes: 1 addition & 1 deletion stellar-auth/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (am *defaultAuthManager) AnyRolesInTokenUser(ctx context.Context, tokenStri
func (am *defaultAuthManager) CreateUser(ctx context.Context, user *User, password string) (*User, error) {
user, err := am.authenticator.CreateUser(ctx, user, password)
if err != nil {
return nil, fmt.Errorf("error creating user: %w", err)
return nil, fmt.Errorf("creating user: %w", err)
}

return user, nil
Expand Down
2 changes: 1 addition & 1 deletion stellar-auth/pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ func Test_AuthManager_CreateUser(t *testing.T) {

u, err := authManager.CreateUser(ctx, user, password)

assert.EqualError(t, err, "error creating user: unexpected error")
assert.EqualError(t, err, "creating user: unexpected error")
assert.Nil(t, u)
})

Expand Down
40 changes: 24 additions & 16 deletions stellar-auth/pkg/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ type authUser struct {
}

func (a *defaultAuthenticator) ValidateCredentials(ctx context.Context, email, password string) (*User, error) {
email = strings.TrimSpace(strings.ToLower(email))

const query = `
SELECT
u.id,
u.email,
u.first_name,
u.last_name,
u.encrypted_password
Expand Down Expand Up @@ -91,7 +94,7 @@ func (a *defaultAuthenticator) ValidateCredentials(ctx context.Context, email, p

return &User{
ID: au.ID,
Email: email,
Email: au.Email,
FirstName: au.FirstName,
LastName: au.LastName,
}, nil
Expand All @@ -100,28 +103,28 @@ func (a *defaultAuthenticator) ValidateCredentials(ctx context.Context, email, p
// CreateUser creates a user in the database. If a empty password is passed by parameter, a random password is generated,
// so the user can go through the ForgotPassword flow.
func (a *defaultAuthenticator) CreateUser(ctx context.Context, user *User, password string) (*User, error) {
if err := user.Validate(); err != nil {
return nil, fmt.Errorf("error validating user fields: %w", err)
if err := user.SanitizeAndValidate(); err != nil {
return nil, fmt.Errorf("validating user fields: %w", err)
}

// In case no password is passed we generate a random OTP (One Time Password)
if password == "" {
// Random length pasword
randomNumber, err := rand.Int(rand.Reader, big.NewInt(MaxPasswordLength-MinPasswordLength+1))
if err != nil {
return nil, fmt.Errorf("error generating random number in create user: %w", err)
return nil, fmt.Errorf("generating random number in create user: %w", err)
}

passwordLength := int(randomNumber.Int64() + MinPasswordLength)
password, err = utils.StringWithCharset(passwordLength, utils.PasswordCharset)
if err != nil {
return nil, fmt.Errorf("error generating random password string in create user: %w", err)
return nil, fmt.Errorf("generating random password string in create user: %w", err)
}
}

encryptedPassword, err := a.passwordEncrypter.Encrypt(ctx, password)
if err != nil {
return nil, fmt.Errorf("error encrypting password: %w", err)
return nil, fmt.Errorf("encrypting password: %w", err)
}

const query = `
Expand All @@ -138,7 +141,7 @@ func (a *defaultAuthenticator) CreateUser(ctx context.Context, user *User, passw
if pqError, ok := err.(*pq.Error); ok && pqError.Constraint == "auth_users_email_key" {
return nil, ErrUserEmailAlreadyExists
}
return nil, fmt.Errorf("error inserting user: %w", err)
return nil, fmt.Errorf("inserting user: %w", err)
}

user.ID = userID
Expand All @@ -148,6 +151,10 @@ func (a *defaultAuthenticator) CreateUser(ctx context.Context, user *User, passw
}

func (a *defaultAuthenticator) UpdateUser(ctx context.Context, ID, firstName, lastName, email, password string) error {
firstName = strings.TrimSpace(firstName)
lastName = strings.TrimSpace(lastName)
email = strings.TrimSpace(strings.ToLower(email))

if firstName == "" && lastName == "" && email == "" && password == "" {
return fmt.Errorf("provide at least one of these values: firstName, lastName, email or password")
}
Expand All @@ -174,7 +181,7 @@ func (a *defaultAuthenticator) UpdateUser(ctx context.Context, ID, firstName, la

if email != "" {
if err := utils.ValidateEmail(email); err != nil {
return fmt.Errorf("error validating email: %w", err)
return fmt.Errorf("validating email: %w", err)
}

fields = append(fields, "email = ?")
Expand All @@ -185,7 +192,7 @@ func (a *defaultAuthenticator) UpdateUser(ctx context.Context, ID, firstName, la
encryptedPassword, err := a.passwordEncrypter.Encrypt(ctx, password)
if err != nil {
if !errors.Is(err, ErrPasswordTooShort) {
return fmt.Errorf("error encrypting password: %w", err)
return fmt.Errorf("encrypting password: %w", err)
}
return err
}
Expand All @@ -199,12 +206,12 @@ func (a *defaultAuthenticator) UpdateUser(ctx context.Context, ID, firstName, la

res, err := a.dbConnectionPool.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("error updating user in the database: %w", err)
return fmt.Errorf("updating user in the database: %w", err)
}

numRowsAffected, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("error getting the number of rows affected: %w", err)
return fmt.Errorf("getting the number of rows affected: %w", err)
}
if numRowsAffected == 0 {
return ErrNoRowsAffected
Expand Down Expand Up @@ -252,13 +259,14 @@ func (a *defaultAuthenticator) DeactivateUser(ctx context.Context, userID string
}

func (a *defaultAuthenticator) ForgotPassword(ctx context.Context, sqlExec db.SQLExecuter, email string) (string, error) {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return "", fmt.Errorf("error generating user reset password token: email cannot be empty")
return "", fmt.Errorf("generating user reset password token: email cannot be empty")
}

resetToken, err := utils.StringWithCharset(resetTokenLength, utils.DefaultCharset)
if err != nil {
return "", fmt.Errorf("error generating random reset token in forgot password: %w", err)
return "", fmt.Errorf("generating random reset token in forgot password: %w", err)
}

checkValidTokenQuery := `
Expand All @@ -274,7 +282,7 @@ func (a *defaultAuthenticator) ForgotPassword(ctx context.Context, sqlExec db.SQ
var hasValidToken bool
err = sqlExec.GetContext(ctx, &hasValidToken, checkValidTokenQuery, email)
if err != nil {
return "", fmt.Errorf("error checking if user has valid token: %w", err)
return "", fmt.Errorf("checking if user has valid token: %w", err)
}

if hasValidToken {
Expand All @@ -291,11 +299,11 @@ func (a *defaultAuthenticator) ForgotPassword(ctx context.Context, sqlExec db.SQ
`
result, err := sqlExec.ExecContext(ctx, q, email, resetToken)
if err != nil {
return "", fmt.Errorf("error inserting user reset password token in the database: %w", err)
return "", fmt.Errorf("inserting user reset password token in the database: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return "", fmt.Errorf("error getting rows affected inserting user reset password token in the database: %w", err)
return "", fmt.Errorf("getting rows affected inserting user reset password token in the database: %w", err)
}
if rowsAffected == 0 {
return "", ErrUserNotFound
Expand Down
97 changes: 84 additions & 13 deletions stellar-auth/pkg/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -135,6 +136,30 @@ func Test_DefaultAuthenticator_ValidateCredential(t *testing.T) {
assert.Equal(t, randUser.LastName, user.LastName)
})

t.Run("returns user successfully - case-insensitive", func(t *testing.T) {
encryptedPassword := "encryptedpassword"

passwordEncrypterMock.
On("Encrypt", ctx, mock.AnythingOfType("string")).
Return(encryptedPassword, nil).
Once()

randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false)

passwordEncrypterMock.
On("ComparePassword", ctx, randUser.EncryptedPassword, randUser.Password).
Return(true, nil).
Once()

uppercaseEmail := strings.ToUpper(randUser.Email)
user, err := authenticator.ValidateCredentials(ctx, uppercaseEmail, randUser.Password)
require.NoError(t, err)

assert.Equal(t, randUser.Email, user.Email)
assert.Equal(t, randUser.ID, user.ID)
assert.Equal(t, randUser.FirstName, user.FirstName)
assert.Equal(t, randUser.LastName, user.LastName)
})
passwordEncrypterMock.AssertExpectations(t)
}

Expand Down Expand Up @@ -163,27 +188,27 @@ func Test_DefaultAuthenticator_CreateUser(t *testing.T) {
u, err := authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, "error validating user fields: email is required")
assert.EqualError(t, err, "validating user fields: email is required")

user.Email = "invalid"
u, err = authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, `error validating user fields: email is invalid: the provided email "invalid" is not valid`)
assert.EqualError(t, err, `validating user fields: email is invalid: the provided email "invalid" is not valid`)

// First name
user.Email = "[email protected]"
u, err = authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, "error validating user fields: first name is required")
assert.EqualError(t, err, "validating user fields: first name is required")

// Last name
user.FirstName = "First"
u, err = authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, "error validating user fields: last name is required")
assert.EqualError(t, err, "validating user fields: last name is required")
})

t.Run("returns error when password is invalid", func(t *testing.T) {
Expand All @@ -203,7 +228,7 @@ func Test_DefaultAuthenticator_CreateUser(t *testing.T) {
u, err := authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, fmt.Sprintf("error encrypting password: password should have at least %d characters", MinPasswordLength))
assert.EqualError(t, err, fmt.Sprintf("encrypting password: password should have at least %d characters", MinPasswordLength))

passwordEncrypterMock.
On("Encrypt", ctx, password).
Expand All @@ -213,7 +238,7 @@ func Test_DefaultAuthenticator_CreateUser(t *testing.T) {
u, err = authenticator.CreateUser(ctx, user, password)

assert.Nil(t, u)
assert.EqualError(t, err, "error encrypting password: unexpected error")
assert.EqualError(t, err, "encrypting password: unexpected error")
})

t.Run("returns error when user is duplicated", func(t *testing.T) {
Expand Down Expand Up @@ -300,6 +325,35 @@ func Test_DefaultAuthenticator_CreateUser(t *testing.T) {
assert.Equal(t, "encryptedpassword", encryptedPassword)
})

t.Run("creates a new user correctly - case-insensitive", func(t *testing.T) {
user := &User{
Email: " [email protected]",
FirstName: " First",
LastName: " Last",
}

password := "mysecret"

passwordEncrypterMock.
On("Encrypt", ctx, password).
Return("encryptedpassword", nil).
Once()

u, err := authenticator.CreateUser(ctx, user, password)
require.NoError(t, err)

const query = "SELECT id, email, first_name, last_name FROM auth_users WHERE email = $1"

var newUser User
err = dbConnectionPool.QueryRowxContext(ctx, query, user.Email).Scan(&newUser.ID, &newUser.Email, &newUser.FirstName, &newUser.LastName)
require.NoError(t, err)

assert.Equal(t, u.ID, newUser.ID)
assert.Equal(t, strings.ToLower(strings.TrimSpace(u.Email)), newUser.Email)
assert.Equal(t, strings.TrimSpace(u.FirstName), newUser.FirstName)
assert.Equal(t, strings.TrimSpace(u.LastName), newUser.LastName)
})

passwordEncrypterMock.AssertExpectations(t)
}

Expand Down Expand Up @@ -504,7 +558,7 @@ func Test_DefaultAuthenticator_ForgotPassword(t *testing.T) {

t.Run("Should return an error if the email is empty", func(t *testing.T) {
resetToken, err := authenticator.ForgotPassword(ctx, dbConnectionPool, "")
assert.EqualError(t, err, "error generating user reset password token: email cannot be empty")
assert.EqualError(t, err, "generating user reset password token: email cannot be empty")
assert.Empty(t, resetToken)
})

Expand Down Expand Up @@ -590,6 +644,23 @@ func Test_DefaultAuthenticator_ForgotPassword(t *testing.T) {
assert.NotEmpty(t, resetToken)
})

t.Run("Should return reset token with a valid user - case-insensitive", func(t *testing.T) {
encryptedPassword := "encryptedpassword"

passwordEncrypterMock.
On("Encrypt", ctx, mock.AnythingOfType("string")).
Return(encryptedPassword, nil).
Once()

randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false)

uppercaseEmail := strings.ToUpper(randUser.Email)
resetToken, err := authenticator.ForgotPassword(ctx, dbConnectionPool, uppercaseEmail)
require.NoError(t, err)

assert.NotEmpty(t, resetToken)
})

passwordEncrypterMock.AssertExpectations(t)
}

Expand Down Expand Up @@ -753,7 +824,7 @@ func Test_DefaultAuthenticator_UpdateUser(t *testing.T) {

t.Run("returns error when email is invalid", func(t *testing.T) {
err := authenticator.UpdateUser(ctx, "user-id", "", "", "invalid", "")
assert.EqualError(t, err, `error validating email: the provided email "invalid" is not valid`)
assert.EqualError(t, err, `validating email: the provided email "invalid" is not valid`)
})

t.Run("returns error when password is too short", func(t *testing.T) {
Expand All @@ -777,7 +848,7 @@ func Test_DefaultAuthenticator_UpdateUser(t *testing.T) {
Once()

err := authenticator.UpdateUser(ctx, "user-id", "", "", "", "short")
assert.EqualError(t, err, "error encrypting password: unexpected error")
assert.EqualError(t, err, "encrypting password: unexpected error")
})

t.Run("updates first name successfully", func(t *testing.T) {
Expand Down Expand Up @@ -872,7 +943,7 @@ func Test_DefaultAuthenticator_UpdateUser(t *testing.T) {
})

t.Run("updates all fields successfully", func(t *testing.T) {
firstName, lastName, email, password := "FirstName", "LastName", "new_email@email.com", "newpassword"
firstName, lastName, email, password := "FirstName ", " LastName ", " new_EMail@email.com", "newpassword"

passwordEncrypterMock.
On("Encrypt", ctx, mock.AnythingOfType("string")).
Expand All @@ -893,9 +964,9 @@ func Test_DefaultAuthenticator_UpdateUser(t *testing.T) {

u := getUser(t, ctx, randUser.ID)

assert.Equal(t, firstName, u.FirstName)
assert.Equal(t, lastName, u.LastName)
assert.Equal(t, email, u.Email)
assert.Equal(t, "FirstName", u.FirstName)
assert.Equal(t, "LastName", u.LastName)
assert.Equal(t, "new_email@email.com", u.Email)
assert.Equal(t, "newpassowrdencrypted", u.EncryptedPassword)
})
}
Expand Down
2 changes: 2 additions & 0 deletions stellar-auth/pkg/auth/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"fmt"
"math/big"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -55,6 +56,7 @@ func randStringRunes(t *testing.T, n int) string {

func CreateRandomAuthUserFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, passwordEncrypter PasswordEncrypter, isAdmin bool, roles ...string) *RandomAuthUser {
randomSuffix := randStringRunes(t, 5)
randomSuffix = strings.TrimSpace(strings.ToLower(randomSuffix))
email := fmt.Sprintf("email%[email protected]", randomSuffix)
password := "password" + randomSuffix
firstName := "firstName" + randomSuffix
Expand Down
Loading

0 comments on commit f965c80

Please sign in to comment.