diff --git a/internal/data/disbursements.go b/internal/data/disbursements.go index 4e6f3aeb..839725ef 100644 --- a/internal/data/disbursements.go +++ b/internal/data/disbursements.go @@ -336,12 +336,14 @@ func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams * qb.AddCondition("d.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) } - if queryType == QueryTypeSelectPaginated { + switch queryType { + case QueryTypeSelectPaginated: qb.AddPagination(queryParams.Page, queryParams.PageLimit) - } - - if queryType == QueryTypeSelectAll || queryType == QueryTypeSelectPaginated { qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d") + case QueryTypeSelectAll: + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d") + case QueryTypeCount: + // no need to sort or paginate. } query, params := qb.Build() diff --git a/internal/data/payments.go b/internal/data/payments.go index 754c63a0..14ff6328 100644 --- a/internal/data/payments.go +++ b/internal/data/payments.go @@ -640,13 +640,17 @@ func newPaymentQuery(baseQuery string, queryParams *QueryParams, sqlExec db.SQLE if queryParams.Filters[FilterKeyCreatedAtBefore] != nil { qb.AddCondition("p.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) } - if queryType == QueryTypeSelectPaginated { - qb.AddPagination(queryParams.Page, queryParams.PageLimit) - } - if queryType == QueryTypeSelectAll || queryType == QueryTypeSelectPaginated { + switch queryType { + case QueryTypeSelectPaginated: + qb.AddPagination(queryParams.Page, queryParams.PageLimit) qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "p") + case QueryTypeSelectAll: + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "p") + case QueryTypeCount: + // no need to sort or paginate. } + query, params := qb.Build() return sqlExec.Rebind(query), params } diff --git a/internal/data/receivers.go b/internal/data/receivers.go index 07e5dd27..bc3e64fd 100644 --- a/internal/data/receivers.go +++ b/internal/data/receivers.go @@ -17,11 +17,11 @@ import ( type Receiver struct { ID string `json:"id" db:"id"` + Email string `json:"email,omitempty" db:"email"` + PhoneNumber string `json:"phone_number,omitempty" db:"phone_number"` ExternalID string `json:"external_id,omitempty" db:"external_id"` CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"` UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"` - Email string `json:"email,omitempty" db:"email"` - PhoneNumber string `json:"phone_number,omitempty" db:"phone_number"` ReceiverStats } @@ -225,7 +225,7 @@ func (r *ReceiverModel) Count(ctx context.Context, sqlExec db.SQLExecuter, query FROM receivers r LEFT JOIN receiver_wallets rw ON rw.receiver_id = r.id ` - query, params := newReceiverQuery(baseQuery, queryParams, false, sqlExec) + query, params := newReceiverQuery(baseQuery, queryParams, sqlExec, QueryTypeCount) err := sqlExec.GetContext(ctx, &count, query, params...) if err != nil { @@ -236,7 +236,7 @@ func (r *ReceiverModel) Count(ctx context.Context, sqlExec db.SQLExecuter, query } // GetAll returns all RECEIVERS matching the given query parameters. -func (r *ReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) ([]Receiver, error) { +func (r *ReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams, queryType QueryType) ([]Receiver, error) { receivers := []Receiver{} baseQuery := ` @@ -309,7 +309,7 @@ func (r *ReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, quer ` query := fmt.Sprintf(baseQuery, receiverQuery) - query, params := newReceiverQuery(query, queryParams, true, sqlExec) + query, params := newReceiverQuery(query, queryParams, sqlExec, queryType) err := sqlExec.SelectContext(ctx, &receivers, query, params...) if err != nil { @@ -320,7 +320,7 @@ func (r *ReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, quer } // newReceiverQuery generates the full query and parameters for a receiver search query -func newReceiverQuery(baseQuery string, queryParams *QueryParams, paginated bool, sqlExec db.SQLExecuter) (string, []interface{}) { +func newReceiverQuery(baseQuery string, queryParams *QueryParams, sqlExec db.SQLExecuter, queryType QueryType) (string, []interface{}) { qb := NewQueryBuilder(baseQuery) if queryParams.Query != "" { q := "%" + queryParams.Query + "%" @@ -336,10 +336,17 @@ func newReceiverQuery(baseQuery string, queryParams *QueryParams, paginated bool if queryParams.Filters[FilterKeyCreatedAtBefore] != nil { qb.AddCondition("r.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) } - if paginated { - qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "r") + + switch queryType { + case QueryTypeSelectPaginated: qb.AddPagination(queryParams.Page, queryParams.PageLimit) + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "r") + case QueryTypeSelectAll: + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "r") + case QueryTypeCount: + // no need to sort or paginate. } + query, params := qb.Build() return sqlExec.Rebind(query), params } diff --git a/internal/data/receivers_test.go b/internal/data/receivers_test.go index 06dbe5ef..5cff2a1d 100644 --- a/internal/data/receivers_test.go +++ b/internal/data/receivers_test.go @@ -448,7 +448,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { require.Error(t, err, "not in transaction") }() - receivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{}) + receivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 0, len(receivers)) @@ -490,7 +490,9 @@ func Test_ReceiversModel_GetAll(t *testing.T) { require.Error(t, err, "not in transaction") }() - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, + &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}, + QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 2, len(actualReceivers)) @@ -548,7 +550,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { SortOrder: SortOrderASC, Page: 1, PageLimit: 1, - }) + }, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -591,7 +593,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { SortOrder: SortOrderASC, Page: 2, PageLimit: 1, - }) + }, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -634,7 +636,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { filters := map[FilterKey]interface{}{ FilterKeyStatus: DraftReceiversWalletStatus, } - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -672,7 +674,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { require.Error(t, err, "not in transaction") }() - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: receiver1Email}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: receiver1Email}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -710,7 +712,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { require.Error(t, err, "not in transaction") }() - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: "+99992222"}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: "+99992222"}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -752,7 +754,7 @@ func Test_ReceiversModel_GetAll(t *testing.T) { FilterKeyCreatedAtAfter: "2023-01-01", FilterKeyCreatedAtBefore: "2023-03-01", } - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 1, len(actualReceivers)) @@ -790,7 +792,9 @@ func Test_ReceiversModel_GetAll(t *testing.T) { require.Error(t, err, "not in transaction") }() - actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, + &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}, + QueryTypeSelectPaginated) require.NoError(t, err) assert.Equal(t, 2, len(actualReceivers)) @@ -861,7 +865,7 @@ func Test_ReceiversModel_GetAll_makeSureReceiversWithMultipleWalletsWillReturnAS CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, ReadyReceiversWalletStatus) CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, RegisteredReceiversWalletStatus) - receivers, err := receiverModel.GetAll(ctx, dbConnectionPool, &QueryParams{}) + receivers, err := receiverModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, QueryTypeSelectPaginated) require.NoError(t, err) assert.Len(t, receivers, 1) @@ -889,7 +893,9 @@ func Test_ReceiversModel_ParseReceiverIDs(t *testing.T) { err = dbTx.Rollback() require.Error(t, err, "not in transaction") }() - receivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + receivers, err := receiverModel.GetAll(ctx, dbTx, + &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}, + QueryTypeSelectPaginated) require.NoError(t, err) receiverIds := receiverModel.ParseReceiverIDs(receivers) diff --git a/internal/serve/httphandler/export_handler.go b/internal/serve/httphandler/export_handler.go index 38090de5..cebfbf86 100644 --- a/internal/serve/httphandler/export_handler.go +++ b/internal/serve/httphandler/export_handler.go @@ -120,3 +120,35 @@ func (e ExportHandler) ExportPayments(rw http.ResponseWriter, r *http.Request) { return } } + +func (e ExportHandler) ExportReceivers(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + validator := validators.NewReceiverQueryValidator() + queryParams := validator.ParseParametersFromRequest(r) + if validator.HasErrors() { + httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw) + return + } + + queryParams.Filters = validator.ValidateAndGetReceiverFilters(queryParams.Filters) + if validator.HasErrors() { + httperror.BadRequest("Request invalid", nil, validator.Errors).Render(rw) + return + } + + receivers, err := e.Models.Receiver.GetAll(ctx, e.Models.DBConnectionPool, queryParams, data.QueryTypeSelectAll) + if err != nil { + httperror.InternalError(ctx, "Failed to get receivers", err, nil).Render(rw) + return + } + + fileName := fmt.Sprintf("receivers_%s.csv", time.Now().Format("2006-01-02-15-04-05")) + rw.Header().Set("Content-Type", "text/csv") + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", fileName)) + + if err := gocsv.Marshal(receivers, rw); err != nil { + httperror.InternalError(ctx, "Failed to write CSV", err, nil).Render(rw) + return + } +} diff --git a/internal/serve/httphandler/export_handler_test.go b/internal/serve/httphandler/export_handler_test.go index f2274a96..d9a43986 100644 --- a/internal/serve/httphandler/export_handler_test.go +++ b/internal/serve/httphandler/export_handler_test.go @@ -275,3 +275,136 @@ func Test_ExportHandler_ExportPayments(t *testing.T) { }) } } + +func Test_ExportHandler_ExportReceivers(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + ctx := context.Background() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ExportHandler{ + Models: models, + } + + r := chi.NewRouter() + r.Get("/exports/receivers", handler.ExportReceivers) + + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + createdFirst := time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC) + createdLast := time.Date(2023, 3, 21, 23, 40, 20, 1431, time.UTC) + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + CreatedAt: &createdLast, + }) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + CreatedAt: &createdFirst, + }) + _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + tests := []struct { + name string + queryParams string + expectedStatusCode int + expectedReceivers []*data.Receiver + }{ + { + name: "success - returns CSV with no receivers", + queryParams: "status=draft", + expectedStatusCode: http.StatusOK, + expectedReceivers: []*data.Receiver{}, + }, + { + name: "success - returns CSV with all receivers", + queryParams: "sort=created_at&direction=desc", + expectedStatusCode: http.StatusOK, + expectedReceivers: []*data.Receiver{receiver, receiver2}, + }, + { + name: "success - return CSV with reverse order of receivers", + expectedStatusCode: http.StatusOK, + queryParams: "sort=created_at&direction=asc", + expectedReceivers: []*data.Receiver{receiver2, receiver}, + }, + { + name: "success - return CSV with only registered receivers", + expectedStatusCode: http.StatusOK, + queryParams: "status=registered", + expectedReceivers: []*data.Receiver{receiver}, + }, + { + name: "success - return CSV with only ready receivers", + expectedStatusCode: http.StatusOK, + queryParams: "status=ready", + expectedReceivers: []*data.Receiver{receiver2}, + }, + { + name: "error - invalid status", + queryParams: "status=invalid", + expectedStatusCode: http.StatusBadRequest, + expectedReceivers: nil, + }, + { + name: "error - invalid sort field", + queryParams: "sort=invalid", + expectedStatusCode: http.StatusBadRequest, + expectedReceivers: nil, + }, + { + name: "error - invalid direction", + queryParams: "direction=invalid", + expectedStatusCode: http.StatusBadRequest, + expectedReceivers: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + url := "/exports/receivers" + if tc.queryParams != "" { + url += "?" + tc.queryParams + } + req, err := http.NewRequest(http.MethodGet, url, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatusCode, rr.Code) + + if tc.expectedStatusCode == http.StatusOK { + csvReader := csv.NewReader(strings.NewReader(rr.Body.String())) + + header, err := csvReader.Read() + require.NoError(t, err) + + expectedHeaders := []string{ + "ID", "Email", "PhoneNumber", "ExternalID", "CreatedAt", "UpdatedAt", + "TotalPayments", "SuccessfulPayments", "FailedPayments", "CanceledPayments", + "RemainingPayments", "RegisteredWallets", "ReceivedAmounts", + } + assert.Equal(t, expectedHeaders, header) + + assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) + today := time.Now().Format("2006-01-02") + assert.Contains(t, rr.Header().Get("Content-Disposition"), fmt.Sprintf("attachment; filename=receivers_%s", today)) + + rows, err := csvReader.ReadAll() + require.NoError(t, err) + assert.Len(t, rows, len(tc.expectedReceivers)) + + for i, row := range rows { + assert.Equal(t, tc.expectedReceivers[i].ID, row[0]) + assert.Equal(t, tc.expectedReceivers[i].Email, row[1]) + assert.Equal(t, tc.expectedReceivers[i].PhoneNumber, row[2]) + assert.Equal(t, tc.expectedReceivers[i].ExternalID, row[3]) + } + } + }) + } +} diff --git a/internal/serve/httphandler/receiver_handler.go b/internal/serve/httphandler/receiver_handler.go index 847c5d61..e881f157 100644 --- a/internal/serve/httphandler/receiver_handler.go +++ b/internal/serve/httphandler/receiver_handler.go @@ -108,7 +108,7 @@ func (rh ReceiverHandler) GetReceivers(w http.ResponseWriter, r *http.Request) { return &httpResponse, nil } - receivers, err := rh.Models.Receiver.GetAll(ctx, dbTx, queryParams) + receivers, err := rh.Models.Receiver.GetAll(ctx, dbTx, queryParams, data.QueryTypeSelectPaginated) if err != nil { return nil, fmt.Errorf("error retrieving receivers: %w", err) } diff --git a/internal/serve/httphandler/receiver_handler_test.go b/internal/serve/httphandler/receiver_handler_test.go index c72405e3..819e96c5 100644 --- a/internal/serve/httphandler/receiver_handler_test.go +++ b/internal/serve/httphandler/receiver_handler_test.go @@ -1466,7 +1466,9 @@ func Test_ReceiverHandler_BuildReceiversResponse(t *testing.T) { require.Error(t, err, "not in transaction") }() - receivers, err := handler.Models.Receiver.GetAll(ctx, dbTx, &data.QueryParams{SortBy: data.SortFieldUpdatedAt, SortOrder: data.SortOrderDESC}) + receivers, err := handler.Models.Receiver.GetAll(ctx, dbTx, + &data.QueryParams{SortBy: data.SortFieldUpdatedAt, SortOrder: data.SortOrderDESC}, + data.QueryTypeSelectPaginated) require.NoError(t, err) receiversId := handler.Models.Receiver.ParseReceiverIDs(receivers) receiversWallets, err := handler.Models.ReceiverWallet.GetWithReceiverIds(ctx, dbTx, receiversId) diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 7f3bbcd6..ab896c7c 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -420,6 +420,7 @@ func handleHTTP(o ServeOptions) *chi.Mux { Route("/exports", func(r chi.Router) { r.Get("/disbursements", exportHandler.ExportDisbursements) r.Get("/payments", exportHandler.ExportPayments) + r.Get("/receivers", exportHandler.ExportReceivers) }) }) diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go index cb506801..58135b51 100644 --- a/internal/serve/serve_test.go +++ b/internal/serve/serve_test.go @@ -481,6 +481,7 @@ func Test_handleHTTP_authenticatedEndpoints(t *testing.T) { // Exports {http.MethodGet, "/exports/disbursements"}, {http.MethodGet, "/exports/payments"}, + {http.MethodGet, "/exports/receivers"}, } // Expect 401 as a response: