Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDP-1259] Export Disbursements with Filters #490

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/data/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import (
)

type Asset struct {
ID string `json:"id" db:"id"`
ID string `json:"id" csv:"-" db:"id"`
Code string `json:"code" db:"code"`
Issuer string `json:"issuer" db:"issuer"`
CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"`
CreatedAt *time.Time `json:"created_at,omitempty" csv:"-" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" csv:"-" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at" csv:"-" db:"deleted_at"`
}

// IsNative returns true if the asset is the native asset (XLM).
Expand Down
25 changes: 15 additions & 10 deletions internal/data/disbursements.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ type Disbursement struct {
Asset *Asset `json:"asset,omitempty" db:"asset"`
Status DisbursementStatus `json:"status" db:"status"`
VerificationField VerificationType `json:"verification_field,omitempty" db:"verification_field"`
StatusHistory DisbursementStatusHistory `json:"status_history,omitempty" db:"status_history"`
ReceiverRegistrationMessageTemplate string `json:"receiver_registration_message_template" db:"receiver_registration_message_template"`
FileName string `json:"file_name,omitempty" db:"file_name"`
FileContent []byte `json:"-" db:"file_content"`
StatusHistory DisbursementStatusHistory `json:"status_history,omitempty" csv:"-" db:"status_history"`
ReceiverRegistrationMessageTemplate string `json:"receiver_registration_message_template" csv:"-" db:"receiver_registration_message_template"`
FileName string `json:"file_name,omitempty" csv:"-" db:"file_name"`
FileContent []byte `json:"-" csv:"-" db:"file_content"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
RegistrationContactType RegistrationContactType `json:"registration_contact_type,omitempty" db:"registration_contact_type"`
Expand Down Expand Up @@ -255,7 +255,7 @@ func (d *DisbursementModel) Count(ctx context.Context, sqlExec db.SQLExecuter, q
JOIN assets a on d.asset_id = a.id
`

query, params := d.newDisbursementQuery(baseQuery, queryParams, false)
query, params := d.newDisbursementQuery(baseQuery, queryParams, QueryTypeCount)

err := sqlExec.GetContext(ctx, &count, query, params...)
if err != nil {
Expand All @@ -265,10 +265,10 @@ func (d *DisbursementModel) Count(ctx context.Context, sqlExec db.SQLExecuter, q
}

// GetAll returns all disbursements matching the given query parameters.
func (d *DisbursementModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) ([]*Disbursement, error) {
func (d *DisbursementModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams, queryType QueryType) ([]*Disbursement, error) {
disbursements := []*Disbursement{}

query, params := d.newDisbursementQuery(selectDisbursementQuery, queryParams, true)
query, params := d.newDisbursementQuery(selectDisbursementQuery, queryParams, queryType)
err := sqlExec.SelectContext(ctx, &disbursements, query, params...)
if err != nil {
return nil, fmt.Errorf("error querying disbursements: %w", err)
Expand Down Expand Up @@ -319,7 +319,7 @@ func (d *DisbursementModel) UpdateStatus(ctx context.Context, sqlExec db.SQLExec
}

// newDisbursementQuery generates the full query and parameters for a disbursement search query
func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *QueryParams, paginated bool) (string, []interface{}) {
func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *QueryParams, queryType QueryType) (string, []interface{}) {
qb := NewQueryBuilder(baseQuery)

if queryParams.Query != "" {
Expand All @@ -335,10 +335,15 @@ func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *
if queryParams.Filters[FilterKeyCreatedAtBefore] != nil {
qb.AddCondition("d.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore])
}
if paginated {
qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d")

if queryType == QueryTypeSelectPaginated {
qb.AddPagination(queryParams.Page, queryParams.PageLimit)
}

if queryType == QueryTypeSelectAll || queryType == QueryTypeSelectPaginated {
qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d")
}

query, params := qb.Build()
return d.dbConnectionPool.Rebind(query), params
}
Expand Down
20 changes: 12 additions & 8 deletions internal/data/disbursements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {

t.Run("returns empty list when no disbursements exist", func(t *testing.T) {
DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool)
disbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
disbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 0, len(disbursements))
})
Expand All @@ -296,7 +296,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Len(t, actualDisbursements, 2)
assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements)
Expand All @@ -311,7 +311,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 1, PageLimit: 1})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 1, PageLimit: 1}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected1}, actualDisbursements)
Expand All @@ -326,7 +326,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 2, PageLimit: 1})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 2, PageLimit: 1}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected2}, actualDisbursements)
Expand All @@ -341,7 +341,9 @@ func Test_DisbursementModelGetAll(t *testing.T) {
disbursement.Name = "disbursement2"
expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement)

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{SortBy: SortFieldName, SortOrder: SortOrderDESC})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool,
&QueryParams{SortBy: SortFieldName, SortOrder: SortOrderDESC},
QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 2, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements)
Expand All @@ -361,7 +363,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {
filters := map[FilterKey]interface{}{
FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus},
}
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expected1}, actualDisbursements)
Expand All @@ -383,7 +385,9 @@ func Test_DisbursementModelGetAll(t *testing.T) {
filters := map[FilterKey]interface{}{
FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus, CompletedDisbursementStatus},
}
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters, SortBy: SortFieldCreatedAt, SortOrder: SortOrderDESC})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool,
&QueryParams{Filters: filters, SortBy: SortFieldCreatedAt, SortOrder: SortOrderDESC},
QueryTypeSelectPaginated)

require.NoError(t, err)
assert.Equal(t, 2, len(actualDisbursements))
Expand Down Expand Up @@ -437,7 +441,7 @@ func Test_DisbursementModelGetAll(t *testing.T) {

expectedDisbursement.DisbursementStats = expectedStats

actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{})
actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated)
require.NoError(t, err)
assert.Equal(t, 1, len(actualDisbursements))
assert.Equal(t, []*Disbursement{expectedDisbursement}, actualDisbursements)
Expand Down
30 changes: 20 additions & 10 deletions internal/data/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,20 +536,30 @@ func CreateDisbursementFixture(t *testing.T, ctx context.Context, sqlExec db.SQL
Status: d.Status,
}}
}
id, err := model.Insert(ctx, d)
require.NoError(t, err)

// update created_at
const query = `
UPDATE disbursements
SET created_at = $1
WHERE id = $2
`
_, err = sqlExec.ExecContext(ctx, query, d.CreatedAt, id)
const q = `
INSERT INTO
disbursements (name, status, status_history, wallet_id, asset_id, verification_field, receiver_registration_message_template, registration_contact_type, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id
`
var newID string
err := sqlExec.GetContext(ctx, &newID, q,
d.Name,
d.Status,
d.StatusHistory,
d.Wallet.ID,
d.Asset.ID,
utils.SQLNullString(string(d.VerificationField)),
d.ReceiverRegistrationMessageTemplate,
d.RegistrationContactType,
d.CreatedAt,
)
require.NoError(t, err)

// get disbursement
disbursement, err := model.Get(ctx, model.dbConnectionPool, id)
disbursement, err := model.Get(ctx, sqlExec, newID)
require.NoError(t, err)
return disbursement
}
Expand Down
8 changes: 8 additions & 0 deletions internal/data/query_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ package data

import "fmt"

type QueryType string

const (
QueryTypeSelectPaginated QueryType = "SELECT_PAGINATED"
QueryTypeSelectAll QueryType = "SELECT_ALL"
QueryTypeCount QueryType = "COUNT"
)

type QueryParams struct {
Query string
Page int
Expand Down
15 changes: 11 additions & 4 deletions internal/data/registration_contact_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"slices"
"strings"

"github.com/gocarina/gocsv"
)

// RegistrationContactType represents the type of contact information to be used when creating and validating a disbursement.
Expand Down Expand Up @@ -89,9 +91,14 @@ func (rct *RegistrationContactType) Scan(value interface{}) error {
return rct.ParseFromString(strValue)
}

func (rct RegistrationContactType) MarshalCSV() (string, error) {
return rct.String(), nil
}

var (
_ json.Marshaler = (*RegistrationContactType)(nil)
_ json.Unmarshaler = (*RegistrationContactType)(nil)
_ driver.Valuer = (*RegistrationContactType)(nil)
_ sql.Scanner = (*RegistrationContactType)(nil)
_ gocsv.TypeMarshaller = (*RegistrationContactType)(nil)
_ json.Marshaler = (*RegistrationContactType)(nil)
_ json.Unmarshaler = (*RegistrationContactType)(nil)
_ driver.Valuer = (*RegistrationContactType)(nil)
_ sql.Scanner = (*RegistrationContactType)(nil)
)
20 changes: 10 additions & 10 deletions internal/data/wallets.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ var (
)

type Wallet struct {
ID string `json:"id" db:"id"`
ID string `json:"id" csv:"-" db:"id"`
Name string `json:"name" db:"name"`
Homepage string `json:"homepage,omitempty" db:"homepage"`
SEP10ClientDomain string `json:"sep_10_client_domain,omitempty" db:"sep_10_client_domain"`
DeepLinkSchema string `json:"deep_link_schema,omitempty" db:"deep_link_schema"`
Enabled bool `json:"enabled" db:"enabled"`
UserManaged bool `json:"user_managed,omitempty" db:"user_managed"`
Assets WalletAssets `json:"assets,omitempty" db:"assets"`
CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"`
DeletedAt *time.Time `json:"-" db:"deleted_at"`
Homepage string `json:"homepage,omitempty" csv:"-" db:"homepage"`
SEP10ClientDomain string `json:"sep_10_client_domain,omitempty" csv:"-" db:"sep_10_client_domain"`
DeepLinkSchema string `json:"deep_link_schema,omitempty" csv:"-" db:"deep_link_schema"`
Enabled bool `json:"enabled" csv:"-" db:"enabled"`
UserManaged bool `json:"user_managed,omitempty" csv:"-" db:"user_managed"`
Assets WalletAssets `json:"assets,omitempty" csv:"-" db:"assets"`
CreatedAt *time.Time `json:"created_at,omitempty" csv:"-" db:"created_at"`
UpdatedAt *time.Time `json:"updated_at,omitempty" csv:"-" db:"updated_at"`
DeletedAt *time.Time `json:"-" csv:"-" db:"deleted_at"`
}

type WalletInsert struct {
Expand Down
50 changes: 50 additions & 0 deletions internal/serve/httphandler/export_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package httphandler

import (
"fmt"
"net/http"
"time"

"github.com/gocarina/gocsv"

"github.com/stellar/stellar-disbursement-platform-backend/internal/data"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror"
"github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators"
)

type ExportHandler struct {
Models *data.Models
}

func (e ExportHandler) ExportDisbursements(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

validator := validators.NewDisbursementQueryValidator()
queryParams := validator.ParseParametersFromRequest(r)

if validator.HasErrors() {
httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw)
return
}

queryParams.Filters = validator.ValidateAndGetDisbursementFilters(queryParams.Filters)
if validator.HasErrors() {
httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw)
return
}

disbursements, err := e.Models.Disbursements.GetAll(ctx, e.Models.DBConnectionPool, queryParams, data.QueryTypeSelectAll)
if err != nil {
httperror.InternalError(ctx, "Failed to get disbursements", err, nil).Render(rw)
return
}

fileName := fmt.Sprintf("disbursements_%s.csv", time.Now().Format("2006-01-02-15-04-05"))
rw.Header().Set("Content-Type", "text/csv")
rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName))

if err := gocsv.Marshal(disbursements, rw); err != nil {
httperror.InternalError(ctx, "Failed to write CSV", err, nil).Render(rw)
return
}
}
Loading
Loading