Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Implement 2FA in the fake server #103

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,21 @@ func (s *Server) handlePostAuthRefresh() gin.HandlerFunc {
}
}

func (s *Server) handlePostAuth2FA() gin.HandlerFunc {
return func(c *gin.Context) {
var req proton.Auth2FAReq

if err := c.BindJSON(&req); err != nil {
return
}

if err := s.b.UpgradeAuth(c.GetString("AuthUID"), req.TwoFactorCode); err != nil {
_ = c.AbortWithError(http.StatusUnauthorized, err)
return
}
}
}

func (s *Server) handleDeleteAuth() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.DeleteSession(c.GetString("UserID"), c.GetString("AuthUID")); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions server/backend/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
"github.com/google/uuid"
)

// TODO: Add other options e.g. real 2FA time-based OTP.
type totp struct {
want *string
}

type account struct {
userID string
username string
Expand All @@ -17,6 +22,7 @@ type account struct {
userSettings proton.UserSettings
contacts map[string]*proton.Contact

totp totp
auth map[string]auth
authLock sync.RWMutex

Expand Down
49 changes: 46 additions & 3 deletions server/backend/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ func (b *Backend) NewAuth(username string, ephemeral, proof []byte, session stri
return proton.Auth{}, fmt.Errorf("invalid proof: %w", err)
}

authUID, auth := uuid.NewString(), newAuth(b.authLife)
var scope Scope

if acc.totp.want != nil {
scope = ScopeTOTP
} else {
scope = ScopeFull
}

authUID, auth := uuid.NewString(), newAuth(scope)

acc.authLock.Lock()
defer acc.authLock.Unlock()
Expand Down Expand Up @@ -83,7 +91,7 @@ func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) {
return proton.Auth{}, fmt.Errorf("invalid auth ref")
}

newAuth := newAuth(b.authLife)
newAuth := newAuth(auth.scope)

acc.auth[authUID] = newAuth

Expand All @@ -93,8 +101,43 @@ func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) {
return proton.Auth{}, fmt.Errorf("invalid auth")
}

func (b *Backend) VerifyAuth(authUID, authAcc string) (string, error) {
func (b *Backend) UpgradeAuth(authUID, totp string) error {
b.accLock.RLock()
defer b.accLock.RUnlock()

for _, acc := range b.accounts {
acc.authLock.Lock()
defer acc.authLock.Unlock()

auth, ok := acc.auth[authUID]
if !ok {
continue
}

if auth.scope != ScopeTOTP {
return fmt.Errorf("invalid scope")
} else if acc.totp.want == nil {
return fmt.Errorf("2FA not enabled")
} else if *acc.totp.want != totp {
return fmt.Errorf("invalid 2FA code")
}

auth.scope = ScopeFull

acc.auth[authUID] = auth

return nil
}

return fmt.Errorf("no such auth")
}

func (b *Backend) VerifyAuth(authUID, authAcc string, scope Scope) (string, error) {
return withAccAuth(b, authUID, authAcc, func(acc *account) (string, error) {
if acc.auth[authUID].scope != scope {
return "", fmt.Errorf("invalid scope")
}

return acc.userID, nil
})
}
Expand Down
9 changes: 8 additions & 1 deletion server/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func (b *Backend) SetAuthLife(authLife time.Duration) {
b.authLife = authLife
}

func (b *Backend) SetAuthTOTP(userID, totp string) error {
return b.withAcc(userID, func(acc *account) error {
acc.totp.want = &totp
return nil
})
}

func (b *Backend) SetMaxUpdatesPerEvent(max int) {
b.maxUpdatesPerEvent = max
}
Expand Down Expand Up @@ -556,7 +563,7 @@ func withAccAuth[T any](b *Backend, authUID, authAcc string, fn func(acc *accoun
}

if time.Since(val.creation) > b.authLife {
acc.auth[authUID] = auth{ref: val.ref, creation: val.creation}
acc.auth[authUID] = newAuthFromExpired(val)
} else if val.acc == authAcc {
return fn(acc)
}
Expand Down
2 changes: 1 addition & 1 deletion server/backend/contact.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin
ContactMetadata: proton.ContactMetadata{
ID: contactID,
Name: names[0].Value,
ContactEmails: []proton.ContactEmail{proton.ContactEmail{
ContactEmails: []proton.ContactEmail{{
ID: "1",
Name: names[0].Value,
Email: emails[0].Value,
Expand Down
28 changes: 24 additions & 4 deletions server/backend/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,42 @@ func (v *ID) FromString(s string) error {
return nil
}

type Scope int

// TODO: Add more scopes?
const (
ScopeNone Scope = iota
ScopeTOTP
ScopeFull
)

type auth struct {
acc string
ref string

scope Scope

creation time.Time
}

func newAuth(authLife time.Duration) auth {
func newAuth(scope Scope) auth {
return auth{
acc: uuid.NewString(),
ref: uuid.NewString(),

acc: uuid.NewString(),
ref: uuid.NewString(),
scope: scope,
creation: time.Now(),
}
}

func newAuthFromExpired(old auth) auth {
return auth{
acc: "",
ref: old.ref,
scope: old.scope,
creation: old.creation,
}
}

func (auth *auth) toAuth(userID, authUID string, proof []byte) proton.Auth {
return proton.Auth{
UserID: userID,
Expand Down
20 changes: 13 additions & 7 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
Expand All @@ -38,7 +39,7 @@ func initRouter(s *Server) {
}

// These routes require auth.
if core := core.Group("", s.requireAuth()); core != nil {
if core := core.Group("", s.requireAuth(backend.ScopeFull)); core != nil {
if users := core.Group("/users"); users != nil {
users.GET("", s.handleGetUsers())
}
Expand Down Expand Up @@ -77,7 +78,7 @@ func initRouter(s *Server) {
}

// All mail routes need authentication.
if mail := s.r.Group("/mail/v4", s.requireAuth()); mail != nil {
if mail := s.r.Group("/mail/v4", s.requireAuth(backend.ScopeFull)); mail != nil {
if settings := mail.Group("/settings"); settings != nil {
settings.GET("", s.handleGetMailSettings())
settings.PUT("/attachpublic", s.handlePutMailSettingsAttachPublicKey())
Expand Down Expand Up @@ -106,7 +107,7 @@ func initRouter(s *Server) {
}

// All contacts routes need authentication.
if contacts := s.r.Group("/contacts/v4", s.requireAuth()); contacts != nil {
if contacts := s.r.Group("/contacts/v4", s.requireAuth(backend.ScopeFull)); contacts != nil {
contacts.GET("", s.handleGetContacts())
contacts.POST("", s.handlePostContacts())
contacts.GET("/:contactID", s.handleGetContact())
Expand All @@ -115,7 +116,7 @@ func initRouter(s *Server) {
}

// All data routes need authentication.
if data := s.r.Group("/data/v1", s.requireAuth()); data != nil {
if data := s.r.Group("/data/v1", s.requireAuth(backend.ScopeFull)); data != nil {
if stats := data.Group("/stats"); stats != nil {
stats.POST("", s.handlePostDataStats())
stats.POST("/multiple", s.handlePostDataStatsMultiple())
Expand All @@ -128,8 +129,13 @@ func initRouter(s *Server) {
auth.POST("/info", s.handlePostAuthInfo())
auth.POST("/refresh", s.handlePostAuthRefresh())

// These routes require auth with only TOTP scope.
if auth := auth.Group("", s.requireAuth(backend.ScopeTOTP)); auth != nil {
auth.POST("/2fa", s.handlePostAuth2FA())
}

// These routes require auth.
if auth := auth.Group("", s.requireAuth()); auth != nil {
if auth := auth.Group("", s.requireAuth(backend.ScopeFull)); auth != nil {
auth.DELETE("", s.handleDeleteAuth())

if sessions := auth.Group("/sessions"); sessions != nil {
Expand Down Expand Up @@ -278,7 +284,7 @@ func (s *Server) handleOffline() gin.HandlerFunc {
}
}

func (s *Server) requireAuth() gin.HandlerFunc {
func (s *Server) requireAuth(scope backend.Scope) gin.HandlerFunc {
return func(c *gin.Context) {
authUID := c.Request.Header.Get("x-pm-uid")
if authUID == "" {
Expand All @@ -292,7 +298,7 @@ func (s *Server) requireAuth() gin.HandlerFunc {
return
}

userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1])
userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1], scope)
if err != nil {
c.AbortWithStatus(http.StatusUnauthorized)
return
Expand Down
5 changes: 4 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ func (s *Server) AddMessageCreatedEvent(userID, messageID string) error {
return s.b.AddMessageCreatedUpdate(userID, messageID)
}

// SetMaxUpdatesPerEvent
func (s *Server) SetMaxUpdatesPerEvent(max int) {
s.b.SetMaxUpdatesPerEvent(max)
}
Expand All @@ -213,6 +212,10 @@ func (s *Server) SetAuthLife(authLife time.Duration) {
s.b.SetAuthLife(authLife)
}

func (s *Server) SetAuthTOTP(userID, totp string) error {
return s.b.SetAuthTOTP(userID, totp)
}

func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) {
s.minAppVersion = minAppVersion
}
Expand Down
34 changes: 34 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,32 @@ func TestServer_LoginLogout(t *testing.T) {
})
}

func TestServer_Login_2FA(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
userID, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)

// Set the expected 2FA code.
require.NoError(t, s.SetAuthTOTP(userID, "123123"))

// Create a new client.
c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
require.NoError(t, err)
defer c.Close()

// Most requests should fail; we haven't provided the 2FA code.
must_fail(c.GetUser(ctx))

// Provide the 2FA code.
require.NoError(t, c.Auth2FA(ctx, proton.Auth2FAReq{
TwoFactorCode: "123123",
}))

// Now requests should succeed.
must(c.GetUser(ctx))
})
}

func TestServerMulti(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
_, _, err := s.CreateUser("user", []byte("pass"))
Expand Down Expand Up @@ -2251,6 +2277,14 @@ func must[T any](t T, err error) T {
return t
}

func must_fail[T any](t T, err error) T {
if err == nil {
panic(err)
}

return t
}

func elementsMatch[T comparable](want, got []T) bool {
if len(want) != len(got) {
return false
Expand Down