diff --git a/internal/data/dibursements_state_machine.go b/internal/data/dibursements_state_machine.go index 7d078560..474b2a7a 100644 --- a/internal/data/dibursements_state_machine.go +++ b/internal/data/dibursements_state_machine.go @@ -15,6 +15,8 @@ const ( CompletedDisbursementStatus DisbursementStatus = "COMPLETED" ) +var NotStartedDisbursementStatuses = []DisbursementStatus{DraftDisbursementStatus, ReadyDisbursementStatus} + // TransitionTo transitions the disbursement status to the target state func (status DisbursementStatus) TransitionTo(targetState DisbursementStatus) error { return DisbursementStateMachineWithInitialState(status).TransitionTo(targetState.State()) diff --git a/internal/data/disbursement_instructions.go b/internal/data/disbursement_instructions.go index e529eeda..995aa5af 100644 --- a/internal/data/disbursement_instructions.go +++ b/internal/data/disbursement_instructions.go @@ -123,9 +123,9 @@ func (di DisbursementInstructionModel) ProcessAll(ctx context.Context, opts Disb } } - // Step 4: Delete all pre-existing payments tied to this disbursement for each receiver in one call - if err = di.paymentModel.DeleteAllForDisbursement(ctx, dbTx, opts.Disbursement.ID); err != nil { - return fmt.Errorf("deleting payments: %w", err) + // Step 4: Delete all pre-existing draft payments tied to this disbursement for each receiver in one call + if err = di.paymentModel.DeleteAllDraftForDisbursement(ctx, dbTx, opts.Disbursement.ID); err != nil { + return fmt.Errorf("deleting draft payments: %w", err) } // Step 5: Create payments for all receivers diff --git a/internal/data/disbursements.go b/internal/data/disbursements.go index e82fbefd..2b59144a 100644 --- a/internal/data/disbursements.go +++ b/internal/data/disbursements.go @@ -457,3 +457,25 @@ func (d *DisbursementModel) CompleteDisbursements(ctx context.Context, sqlExec d return nil } + +// Delete deletes a disbursement by ID +func (d *DisbursementModel) Delete(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) error { + disbursementQuery := `DELETE FROM disbursements WHERE id = $1 AND status = ANY($2)` + result, err := sqlExec.ExecContext(ctx, disbursementQuery, disbursementID, pq.Array(NotStartedDisbursementStatuses)) + if err != nil { + if strings.Contains(err.Error(), "violates foreign key constraint") { + return fmt.Errorf("deleting disbursement %s because it has associated payments: %w", disbursementID, err) + } + return fmt.Errorf("deleting disbursement %s: %w", disbursementID, err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("getting number of rows affected: %w", err) + } + if rowsAffected == 0 { + return ErrRecordNotFound + } + + return nil +} diff --git a/internal/data/disbursements_test.go b/internal/data/disbursements_test.go index 7d81fb36..170c9aa2 100644 --- a/internal/data/disbursements_test.go +++ b/internal/data/disbursements_test.go @@ -2,9 +2,11 @@ package data import ( "context" + "fmt" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -662,3 +664,107 @@ func Test_DisbursementModel_CompleteDisbursements(t *testing.T) { assert.Equal(t, CompletedDisbursementStatus, disbursement2.Status) }) } + +func Test_DisbursementModel_Delete(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + disbursementModel := &DisbursementModel{dbConnectionPool: dbConnectionPool} + ctx := context.Background() + + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVN") + + t.Run("successfully deletes draft disbursement", func(t *testing.T) { + disbursement := CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, Disbursement{ + Name: uuid.NewString(), + Asset: asset, + Wallet: wallet, + }) + + err := disbursementModel.Delete(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("successfully deletes ready disbursement", func(t *testing.T) { + disbursement := CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, Disbursement{ + Name: uuid.NewString(), + Status: ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + }) + + err := disbursementModel.Delete(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns error when disbursement not found", func(t *testing.T) { + err := disbursementModel.Delete(ctx, dbConnectionPool, "non-existent-id") + require.Error(t, err) + assert.EqualError(t, err, ErrRecordNotFound.Error()) + }) + + t.Run("returns error when disbursement is not in draft status", func(t *testing.T) { + disbursement := CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, Disbursement{ + Name: uuid.NewString(), + Status: StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + }) + + err := disbursementModel.Delete(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.EqualError(t, err, ErrRecordNotFound.Error()) + + // Verify disbursement still exists + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + }) + + t.Run("returns error when disbursement has associated payments", func(t *testing.T) { + disbursement := CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, Disbursement{ + Name: uuid.NewString(), + Asset: asset, + Wallet: wallet, + }) + + // Create a receiver and receiver wallet + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + // Create an associated payment + CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id", + StellarOperationID: "operation-id", + Status: SuccessPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + // Attempt to delete the disbursement + err := disbursementModel.Delete(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.ErrorContains(t, err, fmt.Sprintf("deleting disbursement %s because it has associated payments", disbursement.ID)) + + // Verify disbursement still exists + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + }) +} diff --git a/internal/data/payments.go b/internal/data/payments.go index 6d88cdaa..989dd802 100644 --- a/internal/data/payments.go +++ b/internal/data/payments.go @@ -302,14 +302,15 @@ func (p *PaymentModel) GetAll(ctx context.Context, queryParams *QueryParams, sql return payments, nil } -// DeleteAllForDisbursement deletes all payments for a given disbursement. -func (p *PaymentModel) DeleteAllForDisbursement(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) error { +// DeleteAllDraftForDisbursement deletes all payments for a given disbursement. +func (p *PaymentModel) DeleteAllDraftForDisbursement(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) error { query := ` DELETE FROM payments WHERE disbursement_id = $1 + AND status = $2 ` - result, err := sqlExec.ExecContext(ctx, query, disbursementID) + result, err := sqlExec.ExecContext(ctx, query, disbursementID, DraftPaymentStatus) if err != nil { return fmt.Errorf("error deleting payments for disbursement: %w", err) } diff --git a/internal/serve/httphandler/disbursement_handler.go b/internal/serve/httphandler/disbursement_handler.go index b9ff8dc7..dc58892d 100644 --- a/internal/serve/httphandler/disbursement_handler.go +++ b/internal/serve/httphandler/disbursement_handler.go @@ -19,6 +19,7 @@ import ( "github.com/stellar/go/support/log" "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/db" "github.com/stellar/stellar-disbursement-platform-backend/internal/data" "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" @@ -186,6 +187,53 @@ func (d DisbursementHandler) PostDisbursement(w http.ResponseWriter, r *http.Req httpjson.RenderStatus(w, http.StatusCreated, newDisbursement, httpjson.JSON) } +// DeleteDisbursement deletes a draft or ready disbursement and its associated payments +func (d DisbursementHandler) DeleteDisbursement(w http.ResponseWriter, r *http.Request) { + disbursementID := chi.URLParam(r, "id") + ctx := r.Context() + + ErrDisbursementStarted := errors.New("can't delete disbursement that has started") + + disbursement, err := db.RunInTransactionWithResult(ctx, d.Models.DBConnectionPool, nil, func(tx db.DBTransaction) (*data.Disbursement, error) { + // Check if disbursement exists and is in draft or ready status + disbursement, err := d.Models.Disbursements.Get(ctx, tx, disbursementID) + if err != nil { + return nil, fmt.Errorf("getting disbursement: %w", err) + } + + if !slices.Contains(data.NotStartedDisbursementStatuses, disbursement.Status) { + return nil, ErrDisbursementStarted + } + + // Delete associated payments + err = d.Models.Payment.DeleteAllDraftForDisbursement(ctx, tx, disbursementID) + if err != nil { + return nil, fmt.Errorf("deleting payments: %w", err) + } + + // Delete disbursement + err = d.Models.Disbursements.Delete(ctx, tx, disbursementID) + if err != nil { + return nil, fmt.Errorf("deleting draft or ready disbursement: %w", err) + } + + return disbursement, nil + }) + if err != nil { + switch { + case errors.Is(err, data.ErrRecordNotFound): + httperror.NotFound("Disbursement not found", err, nil).Render(w) + case errors.Is(err, ErrDisbursementStarted): + httperror.BadRequest("Cannot delete a disbursement that has started", err, nil).Render(w) + default: + httperror.InternalError(ctx, "Cannot delete disbursement", err, nil).Render(w) + } + return + } + + httpjson.RenderStatus(w, http.StatusOK, disbursement, httpjson.JSON) +} + // GetDisbursements returns a paginated list of disbursements func (d DisbursementHandler) GetDisbursements(w http.ResponseWriter, r *http.Request) { validator := validators.NewDisbursementQueryValidator() diff --git a/internal/serve/httphandler/disbursement_handler_test.go b/internal/serve/httphandler/disbursement_handler_test.go index 824d038d..ad1f87f1 100644 --- a/internal/serve/httphandler/disbursement_handler_test.go +++ b/internal/serve/httphandler/disbursement_handler_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -1932,6 +1933,147 @@ func Test_DisbursementHandler_GetDisbursementInstructions(t *testing.T) { } } +func Test_DisbursementHandler_DeleteDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + ctx := context.Background() + handler := &DisbursementHandler{ + Models: models, + } + + r := chi.NewRouter() + r.Delete("/disbursements/{id}", handler.DeleteDisbursement) + + // Create test fixtures + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + t.Run("successfully deletes draft disbursement", func(t *testing.T) { + disbursement := data.CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, data.Disbursement{ + Name: uuid.NewString(), + Asset: asset, + Wallet: wallet, + }) + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/disbursements/%s", disbursement.ID), nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var response data.Disbursement + require.NoError(t, json.NewDecoder(rr.Body).Decode(&response)) + assert.Equal(t, *disbursement, response) + + // Verify disbursement was deleted + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.Equal(t, data.ErrRecordNotFound, err) + }) + + t.Run("successfully deletes ready disbursement", func(t *testing.T) { + disbursement := data.CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, data.Disbursement{ + Name: uuid.NewString(), + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + }) + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/disbursements/%s", disbursement.ID), nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var response data.Disbursement + require.NoError(t, json.NewDecoder(rr.Body).Decode(&response)) + assert.Equal(t, *disbursement, response) + + // Verify disbursement was deleted + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.Error(t, err) + assert.Equal(t, data.ErrRecordNotFound, err) + }) + + t.Run("returns 404 when disbursement not found", func(t *testing.T) { + req, err := http.NewRequest(http.MethodDelete, "/disbursements/non-existent-id", nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("returns 400 when disbursement is not in draft status", func(t *testing.T) { + disbursement := data.CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, data.Disbursement{ + Name: uuid.NewString(), + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + }) + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/disbursements/%s", disbursement.ID), nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Cannot delete a disbursement that has started") + + // Verify disbursement still exists + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + }) + + t.Run("returns error when disbursement has associated payments", func(t *testing.T) { + disbursement := data.CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, data.Disbursement{ + Name: uuid.NewString(), + Asset: asset, + Wallet: wallet, + }) + + // Create a receiver and receiver wallet + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + // Create an associated payment + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id", + StellarOperationID: "operation-id", + Status: data.SuccessPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/disbursements/%s", disbursement.ID), nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Cannot delete disbursement") + + // Verify disbursement still exists + _, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + }) +} + func createCSVFile(t *testing.T, records [][]string) (io.Reader, error) { var buf bytes.Buffer writer := csv.NewWriter(&buf) diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 6f37cb53..3c85b6b5 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -266,6 +266,9 @@ func handleHTTP(o ServeOptions) *chi.Mux { r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). Post("/", handler.PostDisbursement) + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Delete("/{id}", handler.DeleteDisbursement) + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). Post("/{id}/instructions", handler.PostDisbursementInstructions) diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go index 84e60ff3..8276439a 100644 --- a/internal/serve/serve_test.go +++ b/internal/serve/serve_test.go @@ -443,6 +443,7 @@ func Test_handleHTTP_authenticatedEndpoints(t *testing.T) { {http.MethodGet, "/disbursements/1234"}, {http.MethodGet, "/disbursements/1234/receivers"}, {http.MethodPatch, "/disbursements/1234/status"}, + {http.MethodDelete, "/disbursements/1234"}, // Payments {http.MethodGet, "/payments"}, {http.MethodGet, "/payments/1234"},