Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDP-946] stellar-multitenant: Get tenants API #93

Merged
merged 14 commits into from
Nov 16, 2023
46 changes: 46 additions & 0 deletions stellar-multitenant/pkg/internal/httphandler/tenants_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package httphandler

import (
"errors"
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
"github.com/stellar/go/support/render/httpjson"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)

type TenantsHandler struct {
Manager *tenant.Manager
}

func (t TenantsHandler) GetAll(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

tnts, err := t.Manager.GetAllTenants(ctx)
if err != nil {
httperror.InternalError(ctx, "Cannot get tenants", err, nil).Render(w)
return
}

httpjson.RenderStatus(w, http.StatusOK, tnts, httpjson.JSON)
}

func (t TenantsHandler) GetByIDOrName(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
arg := chi.URLParam(r, "arg")

tnt, err := t.Manager.GetTenantByIDOrName(ctx, arg)
if err != nil {
if errors.Is(tenant.ErrTenantDoesNotExist, err) {
errorMsg := fmt.Sprintf("tenant %s does not exist", arg)
httperror.NotFound(errorMsg, err, nil).Render(w)
return
}
httperror.InternalError(ctx, "Cannot get tenant by ID or name", err, nil).Render(w)
return
}

httpjson.RenderStatus(w, http.StatusOK, tnt, httpjson.JSON)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package httphandler

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
"github.com/stellar/stellar-disbursement-platform-backend/db"
"github.com/stellar/stellar-disbursement-platform-backend/db/dbtest"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Get(t *testing.T) {
dbt := dbtest.OpenWithTenantMigrationsOnly(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()

handler := TenantsHandler{
Manager: tenant.NewManager(tenant.WithDatabase(dbConnectionPool)),
}

r := chi.NewRouter()
r.Get("/tenants/{arg}", handler.GetByIDOrName)

tenant.DeleteAllTenantsFixture(t, ctx, dbConnectionPool)
tnt1 := tenant.CreateTenantFixture(t, ctx, dbConnectionPool, "myorg1")
tnt2 := tenant.CreateTenantFixture(t, ctx, dbConnectionPool, "myorg2")

tnt1JSON, err := json.Marshal(tnt1)
require.NoError(t, err)
tnt2JSON, err := json.Marshal(tnt2)
require.NoError(t, err)

t.Run("GetAll successfully returns a list of all tenants", func(t *testing.T) {
expectedJSON := fmt.Sprintf("[%s, %s]", tnt1JSON, tnt2JSON)

rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/tenants", nil)
http.HandlerFunc(handler.GetAll).ServeHTTP(rr, req)

resp := rr.Result()

respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.JSONEq(t, expectedJSON, string(respBody))
})

t.Run("successfully returns a tenant by ID", func(t *testing.T) {
url := fmt.Sprintf("/tenants/%s", tnt1.ID)
rr := httptest.NewRecorder()
req, err := http.NewRequest("GET", url, nil)
require.NoError(t, err)
r.ServeHTTP(rr, req)

resp := rr.Result()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.JSONEq(t, string(tnt1JSON), string(respBody))
})

t.Run("successfully returns a tenant by name", func(t *testing.T) {
url := fmt.Sprintf("/tenants/%s", tnt2.Name)
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", url, nil)
r.ServeHTTP(rr, req)

resp := rr.Result()

respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.JSONEq(t, string(tnt2JSON), string(respBody))
})
}
8 changes: 8 additions & 0 deletions stellar-multitenant/pkg/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,13 @@ func handleHTTP(opts *ServeOptions) *chi.Mux {
Version: opts.Version,
}.ServeHTTP)

mux.Group(func(r chi.Router) {
r.Route("/tenants", func(r chi.Router) {
tenantsHandler := httphandler.TenantsHandler{Manager: opts.tenantManager}
r.Get("/", tenantsHandler.GetAll)
r.Get("/{arg}", tenantsHandler.GetByIDOrName)
})
})

return mux
}
42 changes: 42 additions & 0 deletions stellar-multitenant/pkg/tenant/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/db"
authmigrations "github.com/stellar/stellar-disbursement-platform-backend/db/migrations/auth-migrations"
sdpmigrations "github.com/stellar/stellar-disbursement-platform-backend/db/migrations/sdp-migrations"
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -92,6 +93,47 @@ func AssertRegisteredUser(t *testing.T, ctx context.Context, dbConnectionPool db
assert.True(t, user.IsOwner)
}

func CreateTenantFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, name string) *Tenant {
tenantName := name
if name == "" {
name, err := utils.RandomString(56)
require.NoError(t, err)
tenantName = name
}

const query = `
WITH create_tenant AS (
INSERT INTO tenants
(name)
VALUES
($1)
ON CONFLICT DO NOTHING
RETURNING *
)
SELECT
ct.id,
ct.name,
ct.status,
ct.email_sender_type,
ct.sms_sender_type,
ct.enable_mfa,
ct.enable_recaptcha,
ct.created_at,
ct.updated_at
FROM
create_tenant ct
`

tnt := &Tenant{
Name: tenantName,
}

err := sqlExec.QueryRowxContext(ctx, query, tnt.Name).Scan(&tnt.ID, &tnt.Name, &tnt.Status, &tnt.EmailSenderType, &tnt.SMSSenderType, &tnt.EnableMFA, &tnt.EnableReCAPTCHA, &tnt.CreatedAt, &tnt.UpdatedAt)
require.NoError(t, err)

return tnt
}

func CheckSchemaExistsFixture(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, schemaName string) bool {
t.Helper()

Expand Down
43 changes: 37 additions & 6 deletions stellar-multitenant/pkg/tenant/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ type tenantContextKey struct{}

type ManagerInterface interface {
GetDSNForTenant(ctx context.Context, tenantName string) (string, error)
GetAllTenants(ctx context.Context) ([]Tenant, error)
GetTenantByID(ctx context.Context, id string) (*Tenant, error)
GetTenantByName(ctx context.Context, name string) (*Tenant, error)
GetTenantByIDOrName(ctx context.Context, arg string) (*Tenant, error)
AddTenant(ctx context.Context, name string) (*Tenant, error)
UpdateTenantConfig(ctx context.Context, tu *TenantUpdate) (*Tenant, error)
}
Expand All @@ -48,17 +50,30 @@ func (m *Manager) GetDSNForTenant(ctx context.Context, tenantName string) (strin
return u.String(), nil
}

var selectQuery string = `
SELECT
*
FROM
tenants t
%s
`

// GetAllTenants returns all tenants in the database.
func (m *Manager) GetAllTenants(ctx context.Context) ([]Tenant, error) {
const q = `SELECT * FROM tenants`
var tenants []Tenant
if err := m.db.SelectContext(ctx, &tenants, q); err != nil {
var tnts []Tenant

query := fmt.Sprintf(selectQuery, "ORDER BY t.name ASC")

err := m.db.SelectContext(ctx, &tnts, query)
if err != nil {
return nil, fmt.Errorf("getting all tenants: %w", err)
}
return tenants, nil

return tnts, nil
}

func (m *Manager) GetTenantByID(ctx context.Context, id string) (*Tenant, error) {
const q = "SELECT * FROM tenants WHERE id = $1"
q := fmt.Sprintf(selectQuery, "WHERE t.id = $1")
var t Tenant
if err := m.db.GetContext(ctx, &t, q, id); err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -70,7 +85,7 @@ func (m *Manager) GetTenantByID(ctx context.Context, id string) (*Tenant, error)
}

func (m *Manager) GetTenantByName(ctx context.Context, name string) (*Tenant, error) {
const q = "SELECT * FROM tenants WHERE name = $1"
q := fmt.Sprintf(selectQuery, "WHERE t.name = $1")
var t Tenant
if err := m.db.GetContext(ctx, &t, q, name); err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -81,6 +96,22 @@ func (m *Manager) GetTenantByName(ctx context.Context, name string) (*Tenant, er
return &t, nil
}

// GetTenantByIDOrName returns the tenant with a given id or name.
func (m *Manager) GetTenantByIDOrName(ctx context.Context, arg string) (*Tenant, error) {
var tnt Tenant
query := fmt.Sprintf(selectQuery, "WHERE t.id = $1 OR t.name = $1")

err := m.db.GetContext(ctx, &tnt, query, arg)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrTenantDoesNotExist
}
return nil, fmt.Errorf("getting tenant %s: %w", arg, err)
}

return &tnt, nil
}

func (m *Manager) AddTenant(ctx context.Context, name string) (*Tenant, error) {
if name == "" {
return nil, ErrEmptyTenantName
Expand Down
64 changes: 64 additions & 0 deletions stellar-multitenant/pkg/tenant/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,35 @@ func Test_Manager_GetAllTenants(t *testing.T) {
assert.ElementsMatch(t, tenants, []Tenant{*tnt1, *tnt2})
}

func Test_Manager_GetTenantByID(t *testing.T) {
dbt := dbtest.OpenWithTenantMigrationsOnly(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()

m := NewManager(WithDatabase(dbConnectionPool))
_, err = m.AddTenant(ctx, "myorg1")
require.NoError(t, err)
tnt2, err := m.AddTenant(ctx, "myorg2")
require.NoError(t, err)

t.Run("gets tenant successfully", func(t *testing.T) {
tntDB, err := m.GetTenantByID(ctx, tnt2.ID)
require.NoError(t, err)
assert.Equal(t, tnt2, tntDB)
})

t.Run("returns error when tenant is not found", func(t *testing.T) {
tntDB, err := m.GetTenantByID(ctx, "unknown")
assert.ErrorIs(t, err, ErrTenantDoesNotExist)
assert.Nil(t, tntDB)
})
}

func Test_Manager_GetTenantByName(t *testing.T) {
dbt := dbtest.OpenWithTenantMigrationsOnly(t)
defer dbt.Close()
Expand Down Expand Up @@ -188,3 +217,38 @@ func Test_Manager_GetTenantByName(t *testing.T) {
assert.Nil(t, tntDB)
})
}

func Test_Manager_GetTenantByIDOrName(t *testing.T) {
dbt := dbtest.OpenWithTenantMigrationsOnly(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()

m := NewManager(WithDatabase(dbConnectionPool))
tnt1, err := m.AddTenant(ctx, "myorg1")
require.NoError(t, err)
tnt2, err := m.AddTenant(ctx, "myorg2")
require.NoError(t, err)

t.Run("gets tenant by ID successfully", func(t *testing.T) {
tntDB, err := m.GetTenantByIDOrName(ctx, tnt1.ID)
require.NoError(t, err)
assert.Equal(t, tnt1, tntDB)
})

t.Run("gets tenant by name successfully", func(t *testing.T) {
tntDB, err := m.GetTenantByIDOrName(ctx, tnt2.Name)
require.NoError(t, err)
assert.Equal(t, tnt2, tntDB)
})

t.Run("returns error when tenant is not found", func(t *testing.T) {
tntDB, err := m.GetTenantByIDOrName(ctx, "unknown")
assert.ErrorIs(t, err, ErrTenantDoesNotExist)
assert.Nil(t, tntDB)
})
}
12 changes: 10 additions & 2 deletions stellar-multitenant/pkg/tenant/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ func (m *TenantManagerMock) GetDSNForTenant(ctx context.Context, tenantName stri
return args.String(0), args.Error(1)
}

func (m *TenantManagerMock) GetTenant(ctx context.Context) (*Tenant, error) {
func (m *TenantManagerMock) GetAllTenants(ctx context.Context) ([]Tenant, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*Tenant), args.Error(1)
return args.Get(0).([]Tenant), args.Error(1)
}

func (m *TenantManagerMock) GetTenantByName(ctx context.Context, name string) (*Tenant, error) {
Expand All @@ -39,6 +39,14 @@ func (m *TenantManagerMock) GetTenantByID(ctx context.Context, id string) (*Tena
return args.Get(0).(*Tenant), args.Error(1)
}

func (m *TenantManagerMock) GetTenantByIDOrName(ctx context.Context, arg string) (*Tenant, error) {
args := m.Called(ctx, arg)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*Tenant), args.Error(1)
}

func (m *TenantManagerMock) AddTenant(ctx context.Context, name string) (*Tenant, error) {
args := m.Called(ctx, name)
if args.Get(0) == nil {
Expand Down
Loading