-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* draft new context functions * add TODO and context management functions
- Loading branch information
Showing
3 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
- Update Session data to use defined types instead of string map | ||
- Create type assertions for Session data | ||
- Refactor functions to use newly defined types + assertions | ||
- Phase in new_context.go functions over context.go |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package sessions | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/pkg/errors" | ||
|
||
"github.com/theopenlane/utils/contextx" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
// NewContextWithToken returns a copy of ctx that stores the Token | ||
func NewContextWithToken(ctx context.Context, token *oauth2.Token) context.Context { | ||
return contextx.With(ctx, token) | ||
} | ||
|
||
// NewOhAuthTokenFromContext returns the Token from the ctx | ||
func NewOhAuthTokenFromContext(ctx context.Context) (*oauth2.Token, error) { | ||
token, ok := contextx.From[*oauth2.Token](ctx) | ||
if !ok { | ||
return nil, errors.New("context missing Token") | ||
} | ||
|
||
return token, nil | ||
} | ||
|
||
// NewUserIDFromContext returns the user ID from the ctx | ||
// this function assumes the session data is stored in a string map | ||
func NewUserIDFromContext(ctx context.Context) (string, error) { | ||
sessionDetails, ok := contextx.From[*Session[any]](ctx) | ||
if !ok { | ||
return "", ErrInvalidSession | ||
} | ||
|
||
sessionID := sessionDetails.GetKey() | ||
|
||
sessionData, ok := sessionDetails.GetOk(sessionID) | ||
if !ok { | ||
return "", ErrInvalidSession | ||
} | ||
|
||
sd, ok := sessionData.(map[string]string) | ||
if !ok { | ||
return "", ErrInvalidSession | ||
} | ||
|
||
userID, ok := sd["userID"] | ||
if !ok { | ||
return "", ErrInvalidSession | ||
} | ||
|
||
return userID, nil | ||
} | ||
|
||
type UserID string | ||
|
||
// NewContextWithUserID returns a copy of ctx that stores the user ID | ||
func NewContextWithUserID(ctx context.Context, userID UserID) context.Context { | ||
if userID == "" { | ||
return ctx | ||
} | ||
|
||
return contextx.With(ctx, userID) | ||
} | ||
|
||
// NewSessionToken returns the session token from the context | ||
func NewSessionToken(ctx context.Context) (string, error) { | ||
sd, err := newGetSessionDataFromContext(ctx) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
sd.mu.Lock() | ||
defer sd.mu.Unlock() | ||
|
||
return sd.store.EncodeCookie(sd) | ||
} | ||
|
||
// NewAddSessionDataToContext adds session data to the context | ||
func (s *Session[P]) NewAddSessionDataToContext(ctx context.Context) context.Context { | ||
return contextx.With(ctx, s) | ||
} | ||
|
||
// newGetSessionDataFromContext retrieves session data from the context | ||
func newGetSessionDataFromContext(ctx context.Context) (*Session[map[string]any], error) { | ||
sessionData, ok := contextx.From[*Session[map[string]any]](ctx) | ||
if !ok { | ||
return nil, errors.New("context missing session data") | ||
} | ||
|
||
return sessionData, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package sessions_test | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/theopenlane/utils/contextx" | ||
"golang.org/x/oauth2" | ||
|
||
"github.com/theopenlane/iam/sessions" | ||
) | ||
|
||
func TestNewContextWithToken(t *testing.T) { | ||
ctx := context.Background() | ||
token := &oauth2.Token{AccessToken: "test_token"} | ||
|
||
ctx = sessions.NewContextWithToken(ctx, token) | ||
retrievedToken, ok := contextx.From[*oauth2.Token](ctx) | ||
|
||
assert.True(t, ok) | ||
assert.Equal(t, token, retrievedToken) | ||
} | ||
|
||
func TestNewOhAuthTokenFromContext(t *testing.T) { | ||
ctx := context.Background() | ||
token := &oauth2.Token{AccessToken: "test_token"} | ||
|
||
ctx = sessions.NewContextWithToken(ctx, token) | ||
retrievedToken, err := sessions.NewOhAuthTokenFromContext(ctx) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, token, retrievedToken) | ||
} | ||
|
||
func TestNewOhAuthTokenFromContext_MissingToken(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
_, err := sessions.NewOhAuthTokenFromContext(ctx) | ||
|
||
assert.Error(t, err) | ||
assert.Equal(t, "context missing Token", err.Error()) | ||
} | ||
|
||
func TestNewUserIDFromContext_MissingSession(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
_, err := sessions.NewUserIDFromContext(ctx) | ||
|
||
assert.Error(t, err) | ||
assert.Equal(t, sessions.ErrInvalidSession, err) | ||
} | ||
|
||
func TestNewContextWithUserID(t *testing.T) { | ||
ctx := context.Background() | ||
userID := sessions.UserID("test_user") | ||
|
||
ctx = sessions.NewContextWithUserID(ctx, userID) | ||
retrievedUserID, ok := contextx.From[sessions.UserID](ctx) | ||
|
||
assert.True(t, ok) | ||
assert.Equal(t, userID, retrievedUserID) | ||
} | ||
|
||
func TestNewContextWithUserID_EmptyUserID(t *testing.T) { | ||
ctx := context.Background() | ||
userID := sessions.UserID("") | ||
|
||
ctx = sessions.NewContextWithUserID(ctx, userID) | ||
retrievedUserID, ok := contextx.From[sessions.UserID](ctx) | ||
|
||
assert.False(t, ok) | ||
assert.Equal(t, sessions.UserID(""), retrievedUserID) | ||
} |