Skip to content

Commit

Permalink
Feat adoptctxx (#101)
Browse files Browse the repository at this point in the history
* draft new context functions

* add TODO and context management functions
  • Loading branch information
matoszz authored Dec 13, 2024
1 parent a003a04 commit d5d4653
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sessions/TODO
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Update Session data to use defined types instead of string map
- Create type assertions for Session data
- Refactor functions to use newly defined types + assertions
- Phase in new_context.go functions over context.go
92 changes: 92 additions & 0 deletions sessions/new_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package sessions

import (
"context"

"github.com/pkg/errors"

"github.com/theopenlane/utils/contextx"
"golang.org/x/oauth2"
)

// NewContextWithToken returns a copy of ctx that stores the Token
func NewContextWithToken(ctx context.Context, token *oauth2.Token) context.Context {
return contextx.With(ctx, token)
}

// NewOhAuthTokenFromContext returns the Token from the ctx
func NewOhAuthTokenFromContext(ctx context.Context) (*oauth2.Token, error) {
token, ok := contextx.From[*oauth2.Token](ctx)
if !ok {
return nil, errors.New("context missing Token")
}

return token, nil
}

// NewUserIDFromContext returns the user ID from the ctx
// this function assumes the session data is stored in a string map
func NewUserIDFromContext(ctx context.Context) (string, error) {
sessionDetails, ok := contextx.From[*Session[any]](ctx)
if !ok {
return "", ErrInvalidSession
}

sessionID := sessionDetails.GetKey()

sessionData, ok := sessionDetails.GetOk(sessionID)
if !ok {
return "", ErrInvalidSession
}

sd, ok := sessionData.(map[string]string)
if !ok {
return "", ErrInvalidSession
}

userID, ok := sd["userID"]
if !ok {
return "", ErrInvalidSession
}

return userID, nil
}

type UserID string

// NewContextWithUserID returns a copy of ctx that stores the user ID
func NewContextWithUserID(ctx context.Context, userID UserID) context.Context {
if userID == "" {
return ctx
}

return contextx.With(ctx, userID)
}

// NewSessionToken returns the session token from the context
func NewSessionToken(ctx context.Context) (string, error) {
sd, err := newGetSessionDataFromContext(ctx)
if err != nil {
return "", err
}

sd.mu.Lock()
defer sd.mu.Unlock()

return sd.store.EncodeCookie(sd)
}

// NewAddSessionDataToContext adds session data to the context
func (s *Session[P]) NewAddSessionDataToContext(ctx context.Context) context.Context {
return contextx.With(ctx, s)
}

// newGetSessionDataFromContext retrieves session data from the context
func newGetSessionDataFromContext(ctx context.Context) (*Session[map[string]any], error) {
sessionData, ok := contextx.From[*Session[map[string]any]](ctx)
if !ok {
return nil, errors.New("context missing session data")
}

return sessionData, nil
}
74 changes: 74 additions & 0 deletions sessions/newcontext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package sessions_test

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/theopenlane/utils/contextx"
"golang.org/x/oauth2"

"github.com/theopenlane/iam/sessions"
)

func TestNewContextWithToken(t *testing.T) {
ctx := context.Background()
token := &oauth2.Token{AccessToken: "test_token"}

ctx = sessions.NewContextWithToken(ctx, token)
retrievedToken, ok := contextx.From[*oauth2.Token](ctx)

assert.True(t, ok)
assert.Equal(t, token, retrievedToken)
}

func TestNewOhAuthTokenFromContext(t *testing.T) {
ctx := context.Background()
token := &oauth2.Token{AccessToken: "test_token"}

ctx = sessions.NewContextWithToken(ctx, token)
retrievedToken, err := sessions.NewOhAuthTokenFromContext(ctx)

assert.NoError(t, err)
assert.Equal(t, token, retrievedToken)
}

func TestNewOhAuthTokenFromContext_MissingToken(t *testing.T) {
ctx := context.Background()

_, err := sessions.NewOhAuthTokenFromContext(ctx)

assert.Error(t, err)
assert.Equal(t, "context missing Token", err.Error())
}

func TestNewUserIDFromContext_MissingSession(t *testing.T) {
ctx := context.Background()

_, err := sessions.NewUserIDFromContext(ctx)

assert.Error(t, err)
assert.Equal(t, sessions.ErrInvalidSession, err)
}

func TestNewContextWithUserID(t *testing.T) {
ctx := context.Background()
userID := sessions.UserID("test_user")

ctx = sessions.NewContextWithUserID(ctx, userID)
retrievedUserID, ok := contextx.From[sessions.UserID](ctx)

assert.True(t, ok)
assert.Equal(t, userID, retrievedUserID)
}

func TestNewContextWithUserID_EmptyUserID(t *testing.T) {
ctx := context.Background()
userID := sessions.UserID("")

ctx = sessions.NewContextWithUserID(ctx, userID)
retrievedUserID, ok := contextx.From[sessions.UserID](ctx)

assert.False(t, ok)
assert.Equal(t, sessions.UserID(""), retrievedUserID)
}

0 comments on commit d5d4653

Please sign in to comment.