Skip to content

Commit

Permalink
[SDP-1259] Export Disbursements with Filters (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
marwen-abid authored Dec 10, 2024
1 parent ca37849 commit bf242be
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 47 deletions.
8 changes: 4 additions & 4 deletions internal/data/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import (
)

type Asset struct {
ID string `json:"id" db:"id"`
ID string `json:"id" csv:"-" db:"id"`
Code string `json:"code" db:"code"`
Issuer string `json:"issuer" db:"issuer"`
CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"`
CreatedAt *time.Time `json:"created_at,omitempty" csv:"-" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" csv:"-" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at" csv:"-" db:"deleted_at"`
}

// IsNative returns true if the asset is the native asset (XLM).
Expand Down
25 changes: 15 additions & 10 deletions internal/data/disbursements.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ type Disbursement struct {
Asset *Asset `json:"asset,omitempty" db:"asset"`
Status DisbursementStatus `json:"status" db:"status"`
VerificationField VerificationType `json:"verification_field,omitempty" db:"verification_field"`
StatusHistory DisbursementStatusHistory `json:"status_history,omitempty" db:"status_history"`
ReceiverRegistrationMessageTemplate string `json:"receiver_registration_message_template" db:"receiver_registration_message_template"`
FileName string `json:"file_name,omitempty" db:"file_name"`
FileContent []byte `json:"-" db:"file_content"`
StatusHistory DisbursementStatusHistory `json:"status_history,omitempty" csv:"-" db:"status_history"`
ReceiverRegistrationMessageTemplate string `json:"receiver_registration_message_template" csv:"-" db:"receiver_registration_message_template"`
FileName string `json:"file_name,omitempty" csv:"-" db:"file_name"`
FileContent []byte `json:"-" csv:"-" db:"file_content"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
RegistrationContactType RegistrationContactType `json:"registration_contact_type,omitempty" db:"registration_contact_type"`
Expand Down Expand Up @@ -255,7 +255,7 @@ func (d *DisbursementModel) Count(ctx context.Context, sqlExec db.SQLExecuter, q
JOIN assets a on d.asset_id = a.id
`

query, params := d.newDisbursementQuery(baseQuery, queryParams, false)
query, params := d.newDisbursementQuery(baseQuery, queryParams, QueryTypeCount)

err := sqlExec.GetContext(ctx, &count, query, params...)
if err != nil {
Expand All @@ -265,10 +265,10 @@ func (d *DisbursementModel) Count(ctx context.Context, sqlExec db.SQLExecuter, q
}

// GetAll returns all disbursements matching the given query parameters.
func (d *DisbursementModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) ([]*Disbursement, error) {
func (d *DisbursementModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams, queryType QueryType) ([]*Disbursement, error) {
disbursements := []*Disbursement{}

query, params := d.newDisbursementQuery(selectDisbursementQuery, queryParams, true)
query, params := d.newDisbursementQuery(selectDisbursementQuery, queryParams, queryType)
err := sqlExec.SelectContext(ctx, &disbursements, query, params...)
if err != nil {
return nil, fmt.Errorf("error querying disbursements: %w", err)
Expand Down Expand Up @@ -319,7 +319,7 @@ func (d *DisbursementModel) UpdateStatus(ctx context.Context, sqlExec db.SQLExec
}

// newDisbursementQuery generates the full query and parameters for a disbursement search query
func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *QueryParams, paginated bool) (string, []interface{}) {
func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *QueryParams, queryType QueryType) (string, []interface{}) {
qb := NewQueryBuilder(baseQuery)

if queryParams.Query != "" {
Expand All @@ -335,10 +335,15 @@ func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *
if queryParams.Filters[FilterKeyCreatedAtBefore] != nil {
qb.AddCondition("d.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore])
}
if paginated {
qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d")

if queryType == QueryTypeSelectPaginated {
qb.AddPagination(queryParams.Page, queryParams.PageLimit)
}

if queryType == QueryTypeSelectAll || queryType == QueryTypeSelectPaginated {
qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d")
}

query, params := qb.Build()
return d.dbConnectionPool.Rebind(query), params
}
Expand Down
20 changes: 12 additions & 8 deletions internal/data/disbursements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {

t.Run("returns empty list when no disbursements exist", func(t *testing.T) {
DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
disbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
disbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 0, len(disbursements))
})
Expand All @@ -296,7 +296,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Len(t, actualDisbursements, 2)
assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements)
Expand All @@ -311,7 +311,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 1, PageLimit: 1})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 1, PageLimit: 1}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected1}, actualDisbursements)
Expand All @@ -326,7 +326,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 2, PageLimit: 1})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 2, PageLimit: 1}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected2}, actualDisbursements)
Expand All @@ -341,7 +341,9 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{SortBy: SortFieldName, SortOrder: SortOrderDESC})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool,
&QueryParams{SortBy: SortFieldName, SortOrder: SortOrderDESC},
QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 2, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements)
Expand All @@ -361,7 +363,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
filters := map[FilterKey]interface{}{
FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus},
}
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected1}, actualDisbursements)
Expand All @@ -383,7 +385,9 @@ func Test_DisbursementModelGetAll(t *testing.T) {
filters := map[FilterKey]interface{}{
FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus, CompletedDisbursementStatus},
}
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters, SortBy: SortFieldCreatedAt, SortOrder: SortOrderDESC})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool,
&QueryParams{Filters: filters, SortBy: SortFieldCreatedAt, SortOrder: SortOrderDESC},
QueryTypeSelectPaginated)

require.NoError(t, err)
assert.Equal(t, 2, len(actualDisbursements))
Expand Down Expand Up @@ -437,7 +441,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {

expectedDisbursement.DisbursementStats = expectedStats

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expectedDisbursement}, actualDisbursements)
Expand Down
30 changes: 20 additions & 10 deletions internal/data/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,20 +536,30 @@ func CreateDisbursementFixture(t *testing.T, ctx context.Context, sqlExec db.SQL
Status: d.Status,
}}
}
id, err := model.Insert(ctx, d)
require.NoError(t, err)

// update created_at
const query = `
UPDATE disbursements
SET created_at = $1
WHERE id = $2
`
_, err = sqlExec.ExecContext(ctx, query, d.CreatedAt, id)
const q = `
INSERT INTO
disbursements (name, status, status_history, wallet_id, asset_id, verification_field, receiver_registration_message_template, registration_contact_type, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id
`
var newID string
err := sqlExec.GetContext(ctx, &newID, q,
d.Name,
d.Status,
d.StatusHistory,
d.Wallet.ID,
d.Asset.ID,
utils.SQLNullString(string(d.VerificationField)),
d.ReceiverRegistrationMessageTemplate,
d.RegistrationContactType,
d.CreatedAt,
)
require.NoError(t, err)

// get disbursement
disbursement, err := model.Get(ctx, model.dbConnectionPool, id)
disbursement, err := model.Get(ctx, sqlExec, newID)
require.NoError(t, err)
return disbursement
}
Expand Down
8 changes: 8 additions & 0 deletions internal/data/query_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ package data

import "fmt"

type QueryType string

const (
QueryTypeSelectPaginated QueryType = "SELECT_PAGINATED"
QueryTypeSelectAll QueryType = "SELECT_ALL"
QueryTypeCount QueryType = "COUNT"
)

type QueryParams struct {
Query string
Page int
Expand Down
15 changes: 11 additions & 4 deletions internal/data/registration_contact_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"slices"
"strings"

"github.com/gocarina/gocsv"
)

// RegistrationContactType represents the type of contact information to be used when creating and validating a disbursement.
Expand Down Expand Up @@ -89,9 +91,14 @@ func (rct *RegistrationContactType) Scan(value interface{}) error {
return rct.ParseFromString(strValue)
}

func (rct RegistrationContactType) MarshalCSV() (string, error) {
return rct.String(), nil
}

var (
_ json.Marshaler = (*RegistrationContactType)(nil)
_ json.Unmarshaler = (*RegistrationContactType)(nil)
_ driver.Valuer = (*RegistrationContactType)(nil)
_ sql.Scanner = (*RegistrationContactType)(nil)
_ gocsv.TypeMarshaller = (*RegistrationContactType)(nil)
_ json.Marshaler = (*RegistrationContactType)(nil)
_ json.Unmarshaler = (*RegistrationContactType)(nil)
_ driver.Valuer = (*RegistrationContactType)(nil)
_ sql.Scanner = (*RegistrationContactType)(nil)
)
20 changes: 10 additions & 10 deletions internal/data/wallets.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ var (
)

type Wallet struct {
ID string `json:"id" db:"id"`
ID string `json:"id" csv:"-" db:"id"`
Name string `json:"name" db:"name"`
Homepage string `json:"homepage,omitempty" db:"homepage"`
SEP10ClientDomain string `json:"sep_10_client_domain,omitempty" db:"sep_10_client_domain"`
DeepLinkSchema string `json:"deep_link_schema,omitempty" db:"deep_link_schema"`
Enabled bool `json:"enabled" db:"enabled"`
UserManaged bool `json:"user_managed,omitempty" db:"user_managed"`
Assets WalletAssets `json:"assets,omitempty" db:"assets"`
CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
DeletedAt *time.Time `json:"-" db:"deleted_at"`
Homepage string `json:"homepage,omitempty" csv:"-" db:"homepage"`
SEP10ClientDomain string `json:"sep_10_client_domain,omitempty" csv:"-" db:"sep_10_client_domain"`
DeepLinkSchema string `json:"deep_link_schema,omitempty" csv:"-" db:"deep_link_schema"`
Enabled bool `json:"enabled" csv:"-" db:"enabled"`
UserManaged bool `json:"user_managed,omitempty" csv:"-" db:"user_managed"`
Assets WalletAssets `json:"assets,omitempty" csv:"-" db:"assets"`
CreatedAt *time.Time `json:"created_at,omitempty" csv:"-" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" csv:"-" db:"updated_at"`
DeletedAt *time.Time `json:"-" csv:"-" db:"deleted_at"`
}

type WalletInsert struct {
Expand Down
50 changes: 50 additions & 0 deletions internal/serve/httphandler/export_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package httphandler

import (
"fmt"
"net/http"
"time"

"github.com/gocarina/gocsv"

"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
)

type ExportHandler struct {
Models *data.Models
}

func (e ExportHandler) ExportDisbursements(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

validator := validators.NewDisbursementQueryValidator()
queryParams := validator.ParseParametersFromRequest(r)

if validator.HasErrors() {
httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw)
return
}

queryParams.Filters = validator.ValidateAndGetDisbursementFilters(queryParams.Filters)
if validator.HasErrors() {
httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw)
return
}

disbursements, err := e.Models.Disbursements.GetAll(ctx, e.Models.DBConnectionPool, queryParams, data.QueryTypeSelectAll)
if err != nil {
httperror.InternalError(ctx, "Failed to get disbursements", err, nil).Render(rw)
return
}

fileName := fmt.Sprintf("disbursements_%s.csv", time.Now().Format("2006-01-02-15-04-05"))
rw.Header().Set("Content-Type", "text/csv")
rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))

if err := gocsv.Marshal(disbursements, rw); err != nil {
httperror.InternalError(ctx, "Failed to write CSV", err, nil).Render(rw)
return
}
}
Loading

0 comments on commit bf242be

Please sign in to comment.