Skip to content

Commit

Permalink
[SDP-1245] Invalidate Circle Distribution Account Status upon receivi…
Browse files Browse the repository at this point in the history
…ng auth error (#350)

We should mark Circle distribution accounts whose API keys are not either expired (401) or lack transfer request permissions through Circle (403) as "deactivated". Requests to SDP to process payments will not trigger any transfer request to Circle that will determinately fail
  • Loading branch information
ziyliu authored Jul 17, 2024
1 parent c13c71d commit 7ba0af5
Show file tree
Hide file tree
Showing 12 changed files with 330 additions and 98 deletions.
2 changes: 2 additions & 0 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/signing"
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
serveadmin "github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/serve"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)

type ServeCommand struct{}
Expand Down Expand Up @@ -603,6 +604,7 @@ func (c *ServeCommand) Command(serverService ServerServiceInterface, monitorServ
ClientConfigModel: circle.NewClientConfigModel(serveOpts.MtnDBConnectionPool),
NetworkType: serveOpts.NetworkType,
EncryptionPassphrase: serveOpts.DistAccEncryptionPassphrase,
TenantManager: tenant.NewManager(tenant.WithDatabase(serveOpts.AdminDBConnectionPool)),
})
if err != nil {
log.Ctx(ctx).Fatalf("error creating Circle service: %v", err)
Expand Down
64 changes: 44 additions & 20 deletions internal/circle/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"io"
"net/http"
"net/url"
"slices"

"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient"
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)

const (
Expand All @@ -19,6 +21,8 @@ const (
walletPath = "/v1/wallets"
)

var authErrorStatusCodes = []int{http.StatusUnauthorized, http.StatusForbidden}

// ClientInterface defines the interface for interacting with the Circle API.
//
//go:generate mockery --name=ClientInterface --case=underscore --structname=MockClient --filename=client_mock.go --inpackage
Expand All @@ -31,27 +35,29 @@ type ClientInterface interface {

// Client provides methods to interact with the Circle API.
type Client struct {
BasePath string
APIKey string
httpClient httpclient.HttpClientInterface
BasePath string
APIKey string
httpClient httpclient.HttpClientInterface
tenantManager tenant.ManagerInterface
}

// ClientFactory is a function that creates a ClientInterface.
type ClientFactory func(networkType utils.NetworkType, apiKey string) ClientInterface
type ClientFactory func(networkType utils.NetworkType, apiKey string, tntManager tenant.ManagerInterface) ClientInterface

var _ ClientFactory = NewClient

// NewClient creates a new instance of Circle Client.
func NewClient(networkType utils.NetworkType, apiKey string) ClientInterface {
func NewClient(networkType utils.NetworkType, apiKey string, tntManager tenant.ManagerInterface) ClientInterface {
circleEnv := Sandbox
if networkType == utils.PubnetNetworkType {
circleEnv = Production
}

return &Client{
BasePath: string(circleEnv),
APIKey: apiKey,
httpClient: httpclient.DefaultClient(),
BasePath: string(circleEnv),
APIKey: apiKey,
httpClient: httpclient.DefaultClient(),
tenantManager: tntManager,
}
}

Expand Down Expand Up @@ -113,11 +119,10 @@ func (client *Client) PostTransfer(ctx context.Context, transferReq TransferRequ
}

if resp.StatusCode != http.StatusCreated {
apiError, parseErr := parseAPIError(resp)
if parseErr != nil {
return nil, fmt.Errorf("parsing API error: %w", parseErr)
handleErr := client.handleError(ctx, resp)
if handleErr != nil {
return nil, fmt.Errorf("handling API response error: %w", handleErr)
}
return nil, fmt.Errorf("API error: %w", apiError)
}

return parseTransferResponse(resp)
Expand All @@ -138,11 +143,10 @@ func (client *Client) GetTransferByID(ctx context.Context, id string) (*Transfer
}

if resp.StatusCode != http.StatusOK {
apiError, parseErr := parseAPIError(resp)
if parseErr != nil {
return nil, fmt.Errorf("parsing API error: %w", parseErr)
handleErr := client.handleError(ctx, resp)
if handleErr != nil {
return nil, fmt.Errorf("handling API response error: %w", handleErr)
}
return nil, fmt.Errorf("API error: %w", apiError)
}

return parseTransferResponse(resp)
Expand All @@ -164,11 +168,10 @@ func (client *Client) GetWalletByID(ctx context.Context, id string) (*Wallet, er
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
apiError, parseErr := parseAPIError(resp)
if parseErr != nil {
return nil, fmt.Errorf("parsing API error: %w", parseErr)
handleErr := client.handleError(ctx, resp)
if handleErr != nil {
return nil, fmt.Errorf("handling API response error: %w", handleErr)
}
return nil, fmt.Errorf("API error: %w", apiError)
}

return parseWalletResponse(resp)
Expand All @@ -192,4 +195,25 @@ func (client *Client) request(ctx context.Context, u string, method string, isAu
return client.httpClient.Do(req)
}

func (client *Client) handleError(ctx context.Context, resp *http.Response) error {
if slices.Contains(authErrorStatusCodes, resp.StatusCode) {
tnt, getCtxTntErr := tenant.GetTenantFromContext(ctx)
if getCtxTntErr != nil {
return fmt.Errorf("getting tenant from context: %w", getCtxTntErr)
}

deactivateTntErr := client.tenantManager.DeactivateTenantDistributionAccount(ctx, tnt.ID)
if deactivateTntErr != nil {
return fmt.Errorf("deactivating tenant distribution account: %w", deactivateTntErr)
}
}

apiError, err := parseAPIError(resp)
if err != nil {
return fmt.Errorf("parsing API error: %w", err)
}

return fmt.Errorf("Circle API error: %w", apiError) //nolint:golint,unused
}

var _ ClientInterface = (*Client)(nil)
121 changes: 99 additions & 22 deletions internal/circle/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ import (

httpclientMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient/mocks"
"github.com/stellar/stellar-disbursement-platform-backend/internal/utils"
"github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant"
)

func Test_NewClient(t *testing.T) {
mockTntManager := &tenant.TenantManagerMock{}
t.Run("production environment", func(t *testing.T) {
clientInterface := NewClient(utils.PubnetNetworkType, "test-key")
clientInterface := NewClient(utils.PubnetNetworkType, "test-key", mockTntManager)
cc, ok := clientInterface.(*Client)
assert.True(t, ok)
assert.Equal(t, string(Production), cc.BasePath)
assert.Equal(t, "test-key", cc.APIKey)
})

t.Run("sandbox environment", func(t *testing.T) {
clientInterface := NewClient(utils.TestnetNetworkType, "test-key")
clientInterface := NewClient(utils.TestnetNetworkType, "test-key", mockTntManager)
cc, ok := clientInterface.(*Client)
assert.True(t, ok)
assert.Equal(t, string(Sandbox), cc.BasePath)
Expand All @@ -39,7 +41,7 @@ func Test_Client_Ping(t *testing.T) {
ctx := context.Background()

t.Run("ping error", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
testError := errors.New("test error")
httpClientMock.
On("Do", mock.Anything).
Expand All @@ -52,7 +54,7 @@ func Test_Client_Ping(t *testing.T) {
})

t.Run("ping successful", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
Expand Down Expand Up @@ -85,7 +87,7 @@ func Test_Client_PostTransfer(t *testing.T) {
}

t.Run("post transfer error", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
testError := errors.New("test error")
httpClientMock.
On("Do", mock.Anything).
Expand All @@ -98,15 +100,17 @@ func Test_Client_PostTransfer(t *testing.T) {
})

t.Run("post transfer fails to validate request", func(t *testing.T) {
cc, _ := newClientWithMock(t)
cc, _, _ := newClientWithMocks(t)
transfer, err := cc.PostTransfer(ctx, TransferRequest{})
assert.EqualError(t, err, fmt.Errorf("validating transfer request: %w", errors.New("source type must be provided")).Error())
assert.Nil(t, transfer)
})

t.Run("post transfer fails auth", func(t *testing.T) {
unauthorizedResponse := `{"code": 401, "message": "Malformed key. Does it contain three parts?"}`
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, tntManagerMock := newClientWithMocks(t)
tnt := &tenant.Tenant{ID: "test-id"}
ctx = tenant.SaveTenantInContext(ctx, tnt)

httpClientMock.
On("Do", mock.Anything).
Expand All @@ -115,14 +119,17 @@ func Test_Client_PostTransfer(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(unauthorizedResponse)),
}, nil).
Once()
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(nil).Once()

transfer, err := cc.PostTransfer(ctx, validTransferReq)
assert.EqualError(t, err, "API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.EqualError(t, err, "handling API response error: Circle API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.Nil(t, transfer)
})

t.Run("post transfer successful", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
Expand All @@ -149,7 +156,7 @@ func Test_Client_PostTransfer(t *testing.T) {
func Test_Client_GetTransferByID(t *testing.T) {
ctx := context.Background()
t.Run("get transfer by id error", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
testError := errors.New("test error")
httpClientMock.
On("Do", mock.Anything).
Expand All @@ -163,22 +170,28 @@ func Test_Client_GetTransferByID(t *testing.T) {

t.Run("get transfer by id fails auth", func(t *testing.T) {
unauthorizedResponse := `{"code": 401, "message": "Malformed key. Does it contain three parts?"}`
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, tntManagerMock := newClientWithMocks(t)
tnt := &tenant.Tenant{ID: "test-id"}
ctx = tenant.SaveTenantInContext(ctx, tnt)

httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
StatusCode: http.StatusUnauthorized,
Body: io.NopCloser(bytes.NewBufferString(unauthorizedResponse)),
}, nil).
Once()
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(nil).Once()

transfer, err := cc.GetTransferByID(ctx, "test-id")
assert.EqualError(t, err, "API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.EqualError(t, err, "handling API response error: Circle API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.Nil(t, transfer)
})

t.Run("get transfer by id successful", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
Expand All @@ -204,7 +217,7 @@ func Test_Client_GetTransferByID(t *testing.T) {
func Test_Client_GetWalletByID(t *testing.T) {
ctx := context.Background()
t.Run("get wallet by id error", func(t *testing.T) {
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
testError := errors.New("test error")
httpClientMock.
On("Do", mock.Anything).
Expand All @@ -229,17 +242,23 @@ func Test_Client_GetWalletByID(t *testing.T) {
"code": 401,
"message": "Malformed key. Does it contain three parts?"
}`
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, tntManagerMock := newClientWithMocks(t)
tnt := &tenant.Tenant{ID: "test-id"}
ctx = tenant.SaveTenantInContext(ctx, tnt)

httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
StatusCode: http.StatusUnauthorized,
Body: io.NopCloser(bytes.NewBufferString(unauthorizedResponse)),
}, nil).
Once()
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(nil).Once()

transfer, err := cc.GetWalletByID(ctx, "test-id")
assert.EqualError(t, err, "API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.EqualError(t, err, "handling API response error: Circle API error: APIError: Code=401, Message=Malformed key. Does it contain three parts?, Errors=[], StatusCode=401")
assert.Nil(t, transfer)
})

Expand All @@ -258,7 +277,7 @@ func Test_Client_GetWalletByID(t *testing.T) {
]
}
}`
cc, httpClientMock := newClientWithMock(t)
cc, httpClientMock, _ := newClientWithMocks(t)
httpClientMock.
On("Do", mock.Anything).
Return(&http.Response{
Expand Down Expand Up @@ -290,12 +309,70 @@ func Test_Client_GetWalletByID(t *testing.T) {
})
}

func newClientWithMock(t *testing.T) (Client, *httpclientMocks.HttpClientMock) {
func Test_Client_handleError(t *testing.T) {
ctx := context.Background()
tnt := &tenant.Tenant{ID: "test-id"}
ctx = tenant.SaveTenantInContext(ctx, tnt)

cc, _, tntManagerMock := newClientWithMocks(t)

t.Run("deactivate tenant distribution account error", func(t *testing.T) {
testError := errors.New("foo")
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(testError).Once()

err := cc.handleError(ctx, &http.Response{StatusCode: http.StatusUnauthorized})
assert.EqualError(t, err, fmt.Errorf("deactivating tenant distribution account: %w", testError).Error())
})

t.Run("deactivates tenant distribution account if Circle error response is unauthorized", func(t *testing.T) {
unauthorizedResponse := &http.Response{
StatusCode: http.StatusUnauthorized,
Body: io.NopCloser(bytes.NewBufferString(`{"code": 401, "message": "Unauthorized"}`)),
}
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(nil).Once()

err := cc.handleError(ctx, unauthorizedResponse)
assert.EqualError(t, fmt.Errorf("Circle API error: %w", errors.New("APIError: Code=401, Message=Unauthorized, Errors=[], StatusCode=401")), err.Error())
})

t.Run("deactivates tenant distribution account if Circle error response is forbidden", func(t *testing.T) {
unauthorizedResponse := &http.Response{
StatusCode: http.StatusForbidden,
Body: io.NopCloser(bytes.NewBufferString(`{"code": 403, "message": "Forbidden"}`)),
}
tntManagerMock.
On("DeactivateTenantDistributionAccount", mock.Anything, tnt.ID).
Return(nil).Once()

err := cc.handleError(ctx, unauthorizedResponse)
assert.EqualError(t, fmt.Errorf("Circle API error: %w", errors.New("APIError: Code=403, Message=Forbidden, Errors=[], StatusCode=403")), err.Error())
})

t.Run("does not deactivate tenant distribution account if Circle error response is not unauthorized or forbidden", func(t *testing.T) {
unauthorizedResponse := &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(bytes.NewBufferString(`{"code": 400, "message": "Bad Request"}`)),
}

err := cc.handleError(ctx, unauthorizedResponse)
assert.EqualError(t, fmt.Errorf("Circle API error: %w", errors.New("APIError: Code=400, Message=Bad Request, Errors=[], StatusCode=400")), err.Error())
})

tntManagerMock.AssertExpectations(t)
}

func newClientWithMocks(t *testing.T) (Client, *httpclientMocks.HttpClientMock, *tenant.TenantManagerMock) {
httpClientMock := httpclientMocks.NewHttpClientMock(t)
tntManagerMock := tenant.NewTenantManagerMock(t)

return Client{
BasePath: "http://localhost:8080",
APIKey: "test-key",
httpClient: httpClientMock,
}, httpClientMock
BasePath: "http://localhost:8080",
APIKey: "test-key",
httpClient: httpClientMock,
tenantManager: tntManagerMock,
}, httpClientMock, tntManagerMock
}
Loading

0 comments on commit 7ba0af5

Please sign in to comment.