Skip to content

Commit

Permalink
SDP-1460 Fix single tenant mode SEP24 (#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
marwen-abid authored Jan 16, 2025
1 parent 2bbd7f8 commit ad6936c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 24 deletions.
38 changes: 16 additions & 22 deletions internal/anchorplatform/sep24_auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,7 @@ func SEP24QueryTokenAuthenticateMiddleware(jwtManager *JWTManager, networkPassph
return
}

tenantName, err := utils.ExtractTenantNameFromHostName(sep24Claims.HomeDomain())
if err != nil || tenantName == "" {
httperror.BadRequest("Tenant name not found in SEP24Claims or invalid", err, nil).Render(rw)
return
}

currentTenant, httpErr := getCurrentTenant(ctx, tenantManager, singleTenantMode, tenantName)
currentTenant, httpErr := getCurrentTenant(ctx, tenantManager, singleTenantMode, sep24Claims.HomeDomain())
if httpErr != nil {
httpErr.Render(rw)
return
Expand Down Expand Up @@ -160,13 +154,7 @@ func SEP24HeaderTokenAuthenticateMiddleware(jwtManager *JWTManager, networkPassp
return
}

tenantName, err := utils.ExtractTenantNameFromHostName(sep24Claims.HomeDomain())
if err != nil || tenantName == "" {
httperror.BadRequest("Tenant name not found in SEP24Claims or invalid", err, nil).Render(rw)
return
}

currentTenant, httpErr := getCurrentTenant(ctx, tenantManager, singleTenantMode, tenantName)
currentTenant, httpErr := getCurrentTenant(ctx, tenantManager, singleTenantMode, sep24Claims.HomeDomain())
if httpErr != nil {
httpErr.Render(rw)
return
Expand All @@ -182,21 +170,27 @@ func SEP24HeaderTokenAuthenticateMiddleware(jwtManager *JWTManager, networkPassp
}
}

func getCurrentTenant(ctx context.Context, tenantManager tenant.ManagerInterface, singleTenantMode bool, tenantName string) (currentTenant *tenant.Tenant, httpError *httperror.HTTPError) {
func getCurrentTenant(ctx context.Context, tenantManager tenant.ManagerInterface, singleTenantMode bool, homeDomain string) (currentTenant *tenant.Tenant, httpError *httperror.HTTPError) {
var err error
if singleTenantMode {
currentTenant, err = tenantManager.GetDefault(ctx)
if err != nil {
err = fmt.Errorf("failed to load default tenant: %w", err)
return nil, httperror.InternalError(ctx, "Failed to load default tenant", err, nil)
}
} else {
currentTenant, err = tenantManager.GetTenantByName(ctx, tenantName)
if err != nil {
err = fmt.Errorf("failed to load tenant by name for tenant name %s: %w", tenantName, err)
return nil, httperror.InternalError(ctx, "Failed to load tenant by name", err, nil)
}
return currentTenant, nil
}

tenantName, err := utils.ExtractTenantNameFromHostName(homeDomain)
if err != nil || tenantName == "" {
return nil, httperror.BadRequest("Tenant name not found in SEP24Claims or invalid", err, nil)
}

currentTenant, err = tenantManager.GetTenantByName(ctx, tenantName)
if err != nil {
err = fmt.Errorf("failed to load tenant by name for tenant name %s: %w", tenantName, err)
return nil, httperror.InternalError(ctx, "Failed to load tenant by name", err, nil)
}

return
return currentTenant, nil
}
4 changes: 2 additions & 2 deletions internal/anchorplatform/sep24_auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ func Test_getCurrentTenant(t *testing.T) {
Once()
defer tenantManagerMock.AssertExpectations(t)

currentTnt, httpErr := getCurrentTenant(ctx, tenantManagerMock, false, "tenant_name")
currentTnt, httpErr := getCurrentTenant(ctx, tenantManagerMock, false, "tenant_name.stellar.org")
assert.Equal(t,
httperror.InternalError(ctx, "Failed to load tenant by name", fmt.Errorf("failed to load tenant by name for tenant name tenant_name: %w", tenant.ErrTenantDoesNotExist), nil),
httpErr)
Expand All @@ -756,7 +756,7 @@ func Test_getCurrentTenant(t *testing.T) {
Once()
defer tenantManagerMock.AssertExpectations(t)

currentTnt, httpErr := getCurrentTenant(ctx, tenantManagerMock, false, "tenant_name")
currentTnt, httpErr := getCurrentTenant(ctx, tenantManagerMock, false, "tenant_name.stellar.org")
require.Nil(t, httpErr)
assert.Equal(t, &expectedTenant, currentTnt)
})
Expand Down

0 comments on commit ad6936c

Please sign in to comment.